IP库加密/其他对IP库的改进

This commit is contained in:
刘祥超
2023-03-30 20:02:46 +08:00
parent 67941b5379
commit cfc1a7461c
21 changed files with 566 additions and 363 deletions

View File

@@ -7,26 +7,48 @@ import (
"compress/gzip"
_ "embed"
"net"
"sync"
)
//go:embed internal-ip-library.db
var ipLibraryData []byte
var defaultLibrary = NewIPLibrary()
var commonLibrary *IPLibrary
var libraryLocker = &sync.Mutex{} // 为了保持加载顺序性
func DefaultIPLibraryData() []byte {
return ipLibraryData
}
// InitDefault 加载默认的IP库
func InitDefault() error {
defaultLibrary.reader = nil
return defaultLibrary.InitFromData(ipLibraryData)
libraryLocker.Lock()
defer libraryLocker.Unlock()
if commonLibrary != nil {
defaultLibrary = commonLibrary
return nil
}
var library = NewIPLibrary()
err := library.InitFromData(ipLibraryData, "")
if err != nil {
return err
}
commonLibrary = library
defaultLibrary = commonLibrary
return nil
}
// Lookup 查询IP信息
func Lookup(ip net.IP) *QueryResult {
return defaultLibrary.Lookup(ip)
}
// LookupIP 查询IP信息
func LookupIP(ip string) *QueryResult {
return defaultLibrary.LookupIP(ip)
}
@@ -43,10 +65,19 @@ func NewIPLibraryWithReader(reader *Reader) *IPLibrary {
return &IPLibrary{reader: reader}
}
func (this *IPLibrary) InitFromData(data []byte) error {
func (this *IPLibrary) InitFromData(data []byte, password string) error {
if len(data) == 0 || this.reader != nil {
return nil
}
if len(password) > 0 {
srcData, err := NewEncrypt().Decode(data, password)
if err != nil {
return err
}
data = srcData
}
var reader = bytes.NewReader(data)
gzipReader, err := gzip.NewReader(reader)
if err != nil {

View File

@@ -15,12 +15,21 @@ import (
func TestIPLibrary_Init(t *testing.T) {
var lib = iplibrary.NewIPLibrary()
err := lib.InitFromData(iplibrary.DefaultIPLibraryData())
err := lib.InitFromData(iplibrary.DefaultIPLibraryData(), "")
if err != nil {
t.Fatal(err)
}
}
func TestIPLibrary_Load(t *testing.T) {
for i := 0; i < 10; i++ {
err := iplibrary.InitDefault()
if err != nil {
t.Fatal(err)
}
}
}
func TestIPLibrary_Lookup(t *testing.T) {
var stat1 = &runtime.MemStats{}
runtime.ReadMemStats(stat1)
@@ -29,7 +38,7 @@ func TestIPLibrary_Lookup(t *testing.T) {
var before = time.Now()
err := lib.InitFromData(iplibrary.DefaultIPLibraryData())
err := lib.InitFromData(iplibrary.DefaultIPLibraryData(), "")
if err != nil {
t.Fatal(err)
}
@@ -48,6 +57,7 @@ func TestIPLibrary_Lookup(t *testing.T) {
"8.8.8.8",
"4.4.4.4",
"202.96.0.20",
"111.197.165.199",
"66.249.66.69",
"2222", // wrong ip
"2406:8c00:0:3401:133:18:168:70", // ipv6
@@ -59,7 +69,7 @@ func TestIPLibrary_Lookup(t *testing.T) {
func TestIPLibrary_LookupIP(t *testing.T) {
var lib = iplibrary.NewIPLibrary()
err := lib.InitFromData(iplibrary.DefaultIPLibraryData())
err := lib.InitFromData(iplibrary.DefaultIPLibraryData(), "")
if err != nil {
t.Fatal(err)
}
@@ -78,7 +88,7 @@ func TestIPLibrary_LookupIP(t *testing.T) {
func BenchmarkIPLibrary_Lookup(b *testing.B) {
var lib = iplibrary.NewIPLibrary()
err := lib.InitFromData(iplibrary.DefaultIPLibraryData())
err := lib.InitFromData(iplibrary.DefaultIPLibraryData(), "")
if err != nil {
b.Fatal(err)
}

32
pkg/iplibrary/encrypt.go Normal file
View File

@@ -0,0 +1,32 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package iplibrary
import "github.com/TeaOSLab/EdgeCommon/pkg/nodeutils"
type Encrypt struct {
}
func NewEncrypt() *Encrypt {
return &Encrypt{}
}
func (this *Encrypt) Encode(srcData []byte, password string) ([]byte, error) {
var method = nodeutils.AES256CFBMethod{}
err := method.Init([]byte(password), []byte(password))
if err != nil {
return nil, err
}
return method.Encrypt(srcData)
}
func (this *Encrypt) Decode(encodedData []byte, password string) ([]byte, error) {
var method = nodeutils.AES256CFBMethod{}
err := method.Init([]byte(password), []byte(password))
if err != nil {
return nil, err
}
return method.Decrypt(encodedData)
}

Binary file not shown.

View File

@@ -3,12 +3,17 @@
package iplibrary
import (
"bytes"
"encoding/binary"
"github.com/iwind/TeaGo/types"
)
type ipItem struct {
type ipv4Item struct {
IPFrom uint32
IPTo uint32
Region *ipRegion
}
type ipv6Item struct {
IPFrom uint64
IPTo uint64
@@ -16,23 +21,14 @@ type ipItem struct {
}
type ipRegion struct {
CountryId uint32
ProvinceId uint32
CountryId uint16
ProvinceId uint16
CityId uint32
TownId uint32
ProviderId uint32
ProviderId uint16
}
func (this *ipItem) AsBinary() ([]byte, error) {
var buf = &bytes.Buffer{}
err := binary.Write(buf, binary.BigEndian, this)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func HashRegion(countryId uint32, provinceId uint32, cityId uint32, townId uint32, providerId uint32) string {
func HashRegion(countryId uint16, provinceId uint16, cityId uint32, townId uint32, providerId uint16) string {
var providerHash = ""
if providerId > 0 {
providerHash = "_" + types.String(providerId)

View File

@@ -3,13 +3,13 @@
package iplibrary
type Country struct {
Id uint32 `json:"id"`
Id uint16 `json:"id"`
Name string `json:"name"`
Codes []string `json:"codes"`
}
type Province struct {
Id uint32 `json:"id"`
Id uint16 `json:"id"`
Name string `json:"name"`
Codes []string `json:"codes"`
}
@@ -27,7 +27,7 @@ type Town struct {
}
type Provider struct {
Id uint32 `json:"id"`
Id uint16 `json:"id"`
Name string `json:"name"`
Codes []string `json:"codes"`
}
@@ -43,19 +43,19 @@ type Meta struct {
Providers []*Provider `json:"providers"`
CreatedAt int64 `json:"createdAt"`
countryMap map[uint32]*Country // id => *Country
provinceMap map[uint32]*Province // id => *Province
countryMap map[uint16]*Country // id => *Country
provinceMap map[uint16]*Province // id => *Province
cityMap map[uint32]*City // id => *City
townMap map[uint32]*Town // id => *Town
providerMap map[uint32]*Provider // id => *Provider
providerMap map[uint16]*Provider // id => *Provider
}
func (this *Meta) Init() {
this.countryMap = map[uint32]*Country{}
this.provinceMap = map[uint32]*Province{}
this.countryMap = map[uint16]*Country{}
this.provinceMap = map[uint16]*Province{}
this.cityMap = map[uint32]*City{}
this.townMap = map[uint32]*Town{}
this.providerMap = map[uint32]*Provider{}
this.providerMap = map[uint16]*Provider{}
for _, country := range this.Countries {
this.countryMap[country.Id] = country
@@ -74,11 +74,11 @@ func (this *Meta) Init() {
}
}
func (this *Meta) CountryWithId(countryId uint32) *Country {
func (this *Meta) CountryWithId(countryId uint16) *Country {
return this.countryMap[countryId]
}
func (this *Meta) ProvinceWithId(provinceId uint32) *Province {
func (this *Meta) ProvinceWithId(provinceId uint16) *Province {
return this.provinceMap[provinceId]
}
@@ -90,6 +90,6 @@ func (this *Meta) TownWithId(townId uint32) *Town {
return this.townMap[townId]
}
func (this *Meta) ProviderWithId(providerId uint32) *Provider {
func (this *Meta) ProviderWithId(providerId uint16) *Provider {
return this.providerMap[providerId]
}

View File

@@ -7,10 +7,10 @@ import (
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/iwind/TeaGo/types"
"io"
"net"
"sort"
"strconv"
"strings"
)
@@ -18,17 +18,17 @@ import (
type Reader struct {
meta *Meta
regionMap map[string]*ipRegion
regionMap map[string]*ipRegion // 缓存重复的区域用来节约内存
ipV4Items []*ipItem
ipV6Items []*ipItem
ipV4Items []*ipv4Item
ipV6Items []*ipv6Item
lastIPFrom uint64
lastCountryId uint32
lastProvinceId uint32
lastCountryId uint16
lastProvinceId uint16
lastCityId uint32
lastTownId uint32
lastProviderId uint32
lastProviderId uint16
}
// NewReader 创建新Reader对象
@@ -112,6 +112,9 @@ func (this *Reader) load(reader io.Reader) error {
return from0 < from1
})
// 清理内存
this.regionMap = nil
return nil
}
@@ -122,12 +125,12 @@ func (this *Reader) Lookup(ip net.IP) *QueryResult {
var ipLong = configutils.IP2Long(ip)
var isV4 = configutils.IsIPv4(ip)
var resultItem *ipItem
var resultItem any
if isV4 {
sort.Search(len(this.ipV4Items), func(i int) bool {
var item = this.ipV4Items[i]
if item.IPFrom <= ipLong {
if item.IPTo >= ipLong {
if item.IPFrom <= uint32(ipLong) {
if item.IPTo >= uint32(ipLong) {
resultItem = item
return false
}
@@ -159,11 +162,11 @@ func (this *Reader) Meta() *Meta {
return this.meta
}
func (this *Reader) IPv4Items() []*ipItem {
func (this *Reader) IPv4Items() []*ipv4Item {
return this.ipV4Items
}
func (this *Reader) IPv6Items() []*ipItem {
func (this *Reader) IPv6Items() []*ipv6Item {
return this.ipV6Items
}
@@ -216,31 +219,31 @@ func (this *Reader) parseLine(line []byte) error {
var ipFrom uint64
var ipTo uint64
if strings.HasPrefix(pieces[1], "+") {
ipFrom = this.lastIPFrom + types.Uint64(pieces[1][1:])
ipFrom = this.lastIPFrom + this.decodeUint64(pieces[1][1:])
} else {
ipFrom = types.Uint64(pieces[1])
ipFrom = this.decodeUint64(pieces[1])
}
if len(pieces[2]) == 0 {
ipTo = ipFrom
} else {
ipTo = types.Uint64(pieces[2]) + ipFrom
ipTo = this.decodeUint64(pieces[2]) + ipFrom
}
this.lastIPFrom = ipFrom
// country
var countryId uint32
var countryId uint16
if pieces[3] == "+" {
countryId = this.lastCountryId
} else {
countryId = types.Uint32(pieces[3])
countryId = uint16(this.decodeUint64(pieces[3]))
}
this.lastCountryId = countryId
var provinceId uint32
var provinceId uint16
if pieces[4] == "+" {
provinceId = this.lastProvinceId
} else {
provinceId = types.Uint32(pieces[4])
provinceId = uint16(this.decodeUint64(pieces[4]))
}
this.lastProvinceId = provinceId
@@ -249,7 +252,7 @@ func (this *Reader) parseLine(line []byte) error {
if pieces[5] == "+" {
cityId = this.lastCityId
} else {
cityId = types.Uint32(pieces[5])
cityId = uint32(this.decodeUint64(pieces[5]))
}
this.lastCityId = cityId
@@ -258,16 +261,16 @@ func (this *Reader) parseLine(line []byte) error {
if pieces[6] == "+" {
townId = this.lastTownId
} else {
townId = types.Uint32(pieces[6])
townId = uint32(this.decodeUint64(pieces[6]))
}
this.lastTownId = townId
// provider
var providerId uint32
var providerId uint16
if pieces[7] == "+" {
providerId = this.lastProviderId
} else {
providerId = types.Uint32(pieces[7])
providerId = uint16(this.decodeUint64(pieces[7]))
}
this.lastProviderId = providerId
@@ -286,13 +289,13 @@ func (this *Reader) parseLine(line []byte) error {
}
if version == "4" {
this.ipV4Items = append(this.ipV4Items, &ipItem{
IPFrom: ipFrom,
IPTo: ipTo,
this.ipV4Items = append(this.ipV4Items, &ipv4Item{
IPFrom: uint32(ipFrom),
IPTo: uint32(ipTo),
Region: region,
})
} else {
this.ipV6Items = append(this.ipV6Items, &ipItem{
this.ipV6Items = append(this.ipV6Items, &ipv6Item{
IPFrom: ipFrom,
IPTo: ipTo,
Region: region,
@@ -301,3 +304,12 @@ func (this *Reader) parseLine(line []byte) error {
return nil
}
func (this *Reader) decodeUint64(s string) uint64 {
if this.meta != nil && this.meta.Version == Version2 {
i, _ := strconv.ParseUint(s, 32, 64)
return i
}
i, _ := strconv.ParseUint(s, 10, 64)
return i
}

View File

@@ -3,6 +3,7 @@
package iplibrary
import (
"bytes"
"compress/gzip"
"errors"
"io"
@@ -12,9 +13,10 @@ import (
type FileReader struct {
rawReader *Reader
password string
}
func NewFileReader(path string) (*FileReader, error) {
func NewFileReader(path string, password string) (*FileReader, error) {
fp, err := os.Open(path)
if err != nil {
return nil, err
@@ -23,10 +25,24 @@ func NewFileReader(path string) (*FileReader, error) {
_ = fp.Close()
}()
return NewFileDataReader(fp)
return NewFileDataReader(fp, password)
}
func NewFileDataReader(dataReader io.Reader) (*FileReader, error) {
func NewFileDataReader(dataReader io.Reader, password string) (*FileReader, error) {
if len(password) > 0 {
data, err := io.ReadAll(dataReader)
if err != nil {
return nil, err
}
sourceData, err := NewEncrypt().Decode(data, password)
if err != nil {
return nil, err
}
dataReader = bytes.NewReader(sourceData)
}
gzReader, err := gzip.NewReader(dataReader)
if err != nil {
return nil, errors.New("create gzip reader failed: " + err.Error())

View File

@@ -11,7 +11,7 @@ import (
)
func TestNewFileReader(t *testing.T) {
reader, err := iplibrary.NewFileReader("./ip")
reader, err := iplibrary.NewFileReader("./ip-20c1461c.db", "123456")
if err != nil {
t.Fatal(err)
}

View File

@@ -8,7 +8,7 @@ import (
)
type QueryResult struct {
item *ipItem
item any
meta *Meta
}
@@ -17,18 +17,16 @@ func (this *QueryResult) IsOk() bool {
}
func (this *QueryResult) CountryId() int64 {
if this.item != nil {
return int64(this.item.Region.CountryId)
}
return 0
return int64(this.realCountryId())
}
func (this *QueryResult) CountryName() string {
if this.item == nil {
return ""
}
if this.item.Region.CountryId > 0 {
var country = this.meta.CountryWithId(this.item.Region.CountryId)
var countryId = this.realCountryId()
if countryId > 0 {
var country = this.meta.CountryWithId(countryId)
if country != nil {
return country.Name
}
@@ -40,8 +38,9 @@ func (this *QueryResult) CountryCodes() []string {
if this.item == nil {
return nil
}
if this.item.Region.CountryId > 0 {
var country = this.meta.CountryWithId(this.item.Region.CountryId)
var countryId = this.realCountryId()
if countryId > 0 {
var country = this.meta.CountryWithId(countryId)
if country != nil {
return country.Codes
}
@@ -50,18 +49,16 @@ func (this *QueryResult) CountryCodes() []string {
}
func (this *QueryResult) ProvinceId() int64 {
if this.item != nil {
return int64(this.item.Region.ProvinceId)
}
return 0
return int64(this.realProvinceId())
}
func (this *QueryResult) ProvinceName() string {
if this.item == nil {
return ""
}
if this.item.Region.ProvinceId > 0 {
var province = this.meta.ProvinceWithId(this.item.Region.ProvinceId)
var provinceId = this.realProvinceId()
if provinceId > 0 {
var province = this.meta.ProvinceWithId(provinceId)
if province != nil {
return province.Name
}
@@ -73,8 +70,9 @@ func (this *QueryResult) ProvinceCodes() []string {
if this.item == nil {
return nil
}
if this.item.Region.ProvinceId > 0 {
var province = this.meta.ProvinceWithId(this.item.Region.ProvinceId)
var provinceId = this.realProvinceId()
if provinceId > 0 {
var province = this.meta.ProvinceWithId(provinceId)
if province != nil {
return province.Codes
}
@@ -83,18 +81,16 @@ func (this *QueryResult) ProvinceCodes() []string {
}
func (this *QueryResult) CityId() int64 {
if this.item != nil {
return int64(this.item.Region.CityId)
}
return 0
return int64(this.realCityId())
}
func (this *QueryResult) CityName() string {
if this.item == nil {
return ""
}
if this.item.Region.CityId > 0 {
var city = this.meta.CityWithId(this.item.Region.CityId)
var cityId = this.realCityId()
if cityId > 0 {
var city = this.meta.CityWithId(cityId)
if city != nil {
return city.Name
}
@@ -103,18 +99,16 @@ func (this *QueryResult) CityName() string {
}
func (this *QueryResult) TownId() int64 {
if this.item != nil {
return int64(this.item.Region.TownId)
}
return 0
return int64(this.realTownId())
}
func (this *QueryResult) TownName() string {
if this.item == nil {
return ""
}
if this.item.Region.TownId > 0 {
var town = this.meta.TownWithId(this.item.Region.TownId)
var townId = this.realTownId()
if townId > 0 {
var town = this.meta.TownWithId(townId)
if town != nil {
return town.Name
}
@@ -123,18 +117,16 @@ func (this *QueryResult) TownName() string {
}
func (this *QueryResult) ProviderId() int64 {
if this.item != nil {
return int64(this.item.Region.ProviderId)
}
return 0
return int64(this.realProviderId())
}
func (this *QueryResult) ProviderName() string {
if this.item == nil {
return ""
}
if this.item.Region.ProviderId > 0 {
var provider = this.meta.ProviderWithId(this.item.Region.ProviderId)
var providerId = this.realProviderId()
if providerId > 0 {
var provider = this.meta.ProviderWithId(providerId)
if provider != nil {
return provider.Name
}
@@ -146,8 +138,9 @@ func (this *QueryResult) ProviderCodes() []string {
if this.item == nil {
return nil
}
if this.item.Region.ProviderId > 0 {
var provider = this.meta.ProviderWithId(this.item.Region.ProviderId)
var providerId = this.realProviderId()
if providerId > 0 {
var provider = this.meta.ProviderWithId(providerId)
if provider != nil {
return provider.Codes
}
@@ -189,3 +182,68 @@ func (this *QueryResult) Summary() string {
return strings.Join(pieces, " ")
}
func (this *QueryResult) realCountryId() uint16 {
if this.item != nil {
switch item := this.item.(type) {
case *ipv4Item:
return item.Region.CountryId
case *ipv6Item:
return item.Region.CountryId
}
}
return 0
}
func (this *QueryResult) realProvinceId() uint16 {
if this.item != nil {
switch item := this.item.(type) {
case *ipv4Item:
return item.Region.ProvinceId
case *ipv6Item:
return item.Region.ProvinceId
}
}
return 0
}
func (this *QueryResult) realCityId() uint32 {
if this.item != nil {
switch item := this.item.(type) {
case *ipv4Item:
return item.Region.CityId
case *ipv6Item:
return item.Region.CityId
}
}
return 0
}
func (this *QueryResult) realTownId() uint32 {
if this.item != nil {
switch item := this.item.(type) {
case *ipv4Item:
return item.Region.TownId
case *ipv6Item:
return item.Region.TownId
}
}
return 0
}
func (this *QueryResult) realProviderId() uint16 {
if this.item != nil {
switch item := this.item.(type) {
case *ipv4Item:
return item.Region.ProviderId
case *ipv6Item:
return item.Region.ProviderId
}
}
return 0
}

View File

@@ -213,7 +213,7 @@ func (this *Updater) Loop() error {
func (this *Updater) loadFile(fp *os.File) error {
this.source.LogInfo("load ip library from '" + fp.Name() + "' ...")
fileReader, err := NewFileDataReader(fp)
fileReader, err := NewFileDataReader(fp, "")
if err != nil {
return errors.New("load ip library from reader failed: " + err.Error())
}

View File

@@ -6,4 +6,5 @@ type Version = int
const (
Version1 Version = 1
Version2 Version = 2 // 主要变更为数字使用32进制
)

View File

@@ -8,10 +8,10 @@ import (
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/iwind/TeaGo/types"
"hash"
"io"
"net"
"strconv"
"strings"
"time"
)
@@ -54,7 +54,7 @@ func NewWriter(writer io.Writer, meta *Meta) *Writer {
if meta == nil {
meta = &Meta{}
}
meta.Version = Version1
meta.Version = Version2
meta.CreatedAt = time.Now().Unix()
var libWriter = &Writer{
@@ -111,9 +111,9 @@ func (this *Writer) Write(ipFrom string, ipTo string, countryId int64, provinceI
}
if this.lastIPFrom > 0 && fromIPLong > this.lastIPFrom {
pieces = append(pieces, "+"+types.String(fromIPLong-this.lastIPFrom))
pieces = append(pieces, "+"+this.formatUint64(fromIPLong-this.lastIPFrom))
} else {
pieces = append(pieces, types.String(fromIPLong))
pieces = append(pieces, this.formatUint64(fromIPLong))
}
this.lastIPFrom = fromIPLong
if ipFrom == ipTo {
@@ -121,7 +121,7 @@ func (this *Writer) Write(ipFrom string, ipTo string, countryId int64, provinceI
pieces = append(pieces, "")
} else {
// 2
pieces = append(pieces, types.String(toIPLong-fromIPLong))
pieces = append(pieces, this.formatUint64(toIPLong-fromIPLong))
}
// 3
@@ -129,7 +129,7 @@ func (this *Writer) Write(ipFrom string, ipTo string, countryId int64, provinceI
if countryId == this.lastCountryId {
pieces = append(pieces, "+")
} else {
pieces = append(pieces, types.String(countryId))
pieces = append(pieces, this.formatUint64(uint64(countryId)))
}
} else {
pieces = append(pieces, "")
@@ -141,7 +141,7 @@ func (this *Writer) Write(ipFrom string, ipTo string, countryId int64, provinceI
if provinceId == this.lastProvinceId {
pieces = append(pieces, "+")
} else {
pieces = append(pieces, types.String(provinceId))
pieces = append(pieces, this.formatUint64(uint64(provinceId)))
}
} else {
pieces = append(pieces, "")
@@ -153,7 +153,7 @@ func (this *Writer) Write(ipFrom string, ipTo string, countryId int64, provinceI
if cityId == this.lastCityId {
pieces = append(pieces, "+")
} else {
pieces = append(pieces, types.String(cityId))
pieces = append(pieces, this.formatUint64(uint64(cityId)))
}
} else {
pieces = append(pieces, "")
@@ -165,7 +165,7 @@ func (this *Writer) Write(ipFrom string, ipTo string, countryId int64, provinceI
if townId == this.lastTownId {
pieces = append(pieces, "+")
} else {
pieces = append(pieces, types.String(townId))
pieces = append(pieces, this.formatUint64(uint64(townId)))
}
} else {
pieces = append(pieces, "")
@@ -177,7 +177,7 @@ func (this *Writer) Write(ipFrom string, ipTo string, countryId int64, provinceI
if providerId == this.lastProviderId {
pieces = append(pieces, "+")
} else {
pieces = append(pieces, types.String(providerId))
pieces = append(pieces, this.formatUint64(uint64(providerId)))
}
} else {
pieces = append(pieces, "")
@@ -196,3 +196,7 @@ func (this *Writer) Write(ipFrom string, ipTo string, countryId int64, provinceI
func (this *Writer) Sum() string {
return this.writer.Sum()
}
func (this *Writer) formatUint64(i uint64) string {
return strconv.FormatUint(i, 32)
}

View File

@@ -10,11 +10,12 @@ import (
type FileWriter struct {
fp *os.File
gzWriter *gzip.Writer
password string
rawWriter *Writer
}
func NewFileWriter(path string, meta *Meta) (*FileWriter, error) {
func NewFileWriter(path string, meta *Meta, password string) (*FileWriter, error) {
fp, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
if err != nil {
return nil, err
@@ -29,6 +30,7 @@ func NewFileWriter(path string, meta *Meta) (*FileWriter, error) {
fp: fp,
gzWriter: gzWriter,
rawWriter: NewWriter(gzWriter, meta),
password: password,
}
return writer, nil
}
@@ -54,5 +56,25 @@ func (this *FileWriter) Close() error {
if err2 != nil {
return err2
}
// 加密内容
if len(this.password) > 0 {
var filePath = this.fp.Name()
data, err := os.ReadFile(filePath)
if err != nil {
return err
}
if len(data) > 0 {
encodedData, err := NewEncrypt().Encode(data, this.password)
if err != nil {
return err
}
err = os.WriteFile(filePath, encodedData, 0666)
if err != nil {
return err
}
}
}
return nil
}

View File

@@ -10,9 +10,9 @@ import (
)
func TestNewFileWriter(t *testing.T) {
writer, err := iplibrary.NewFileWriter("./internal-ip-library.db", &iplibrary.Meta{
writer, err := iplibrary.NewFileWriter("./internal-ip-library-test.db", &iplibrary.Meta{
Author: "GoEdge",
})
}, "")
if err != nil {
t.Fatal(err)
}
@@ -41,7 +41,7 @@ func TestNewFileWriter(t *testing.T) {
return types.String(rands.Int(0, 255))
}
for i := 0; i < 1_000_000; i++ {
for i := 0; i < 1; i++ {
err = writer.Write(n()+"."+n()+"."+n()+"."+n(), n()+"."+n()+"."+n()+"."+n(), int64(i)+100, 201, 301, 401, 501)
if err != nil {
t.Fatal(err)