mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-03 06:40:25 +08:00
使用KV数据库来管理IP名单
This commit is contained in:
@@ -97,7 +97,7 @@ func (this *KVListFileStore) ExistItem(hash string) (bool, error) {
|
||||
|
||||
item, err := this.itemsTable.Get(hash)
|
||||
if err != nil {
|
||||
if kvstore.IsKeyNotFound(err) {
|
||||
if kvstore.IsNotFound(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
|
||||
@@ -1,305 +1,13 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/dbs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
import "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
|
||||
type IPListDB struct {
|
||||
db *dbs.DB
|
||||
|
||||
itemTableName string
|
||||
versionTableName string
|
||||
|
||||
deleteExpiredItemsStmt *dbs.Stmt
|
||||
deleteItemStmt *dbs.Stmt
|
||||
insertItemStmt *dbs.Stmt
|
||||
selectItemsStmt *dbs.Stmt
|
||||
selectMaxItemVersionStmt *dbs.Stmt
|
||||
|
||||
selectVersionStmt *dbs.Stmt
|
||||
updateVersionStmt *dbs.Stmt
|
||||
|
||||
cleanTicker *time.Ticker
|
||||
|
||||
dir string
|
||||
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
func NewIPListDB() (*IPListDB, error) {
|
||||
var db = &IPListDB{
|
||||
itemTableName: "ipItems",
|
||||
versionTableName: "versions",
|
||||
dir: filepath.Clean(Tea.Root + "/data"),
|
||||
cleanTicker: time.NewTicker(24 * time.Hour),
|
||||
}
|
||||
err := db.init()
|
||||
return db, err
|
||||
}
|
||||
|
||||
func (this *IPListDB) init() error {
|
||||
// 检查目录是否存在
|
||||
_, err := os.Stat(this.dir)
|
||||
if err != nil {
|
||||
err = os.MkdirAll(this.dir, 0777)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
remotelogs.Println("IP_LIST_DB", "create data dir '"+this.dir+"'")
|
||||
}
|
||||
|
||||
var path = this.dir + "/ip_list.db"
|
||||
|
||||
db, err := dbs.OpenWriter("file:" + path + "?cache=shared&mode=rwc&_journal_mode=WAL&_sync=" + dbs.SyncMode + "&_locking_mode=EXCLUSIVE")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.SetMaxOpenConns(1)
|
||||
|
||||
//_, err = db.Exec("VACUUM")
|
||||
//if err != nil {
|
||||
// return err
|
||||
//}
|
||||
|
||||
this.db = db
|
||||
|
||||
// 恢复数据库
|
||||
var recoverEnv, _ = os.LookupEnv("EdgeRecover")
|
||||
if len(recoverEnv) > 0 {
|
||||
for _, indexName := range []string{"ip_list_itemId", "ip_list_expiredAt"} {
|
||||
_, _ = db.Exec(`REINDEX "` + indexName + `"`)
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化数据库
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemTableName + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"listId" integer DEFAULT 0,
|
||||
"listType" varchar(32),
|
||||
"isGlobal" integer(1) DEFAULT 0,
|
||||
"type" varchar(16),
|
||||
"itemId" integer DEFAULT 0,
|
||||
"ipFrom" varchar(64) DEFAULT 0,
|
||||
"ipTo" varchar(64) DEFAULT 0,
|
||||
"expiredAt" integer DEFAULT 0,
|
||||
"eventLevel" varchar(32),
|
||||
"isDeleted" integer(1) DEFAULT 0,
|
||||
"version" integer DEFAULT 0,
|
||||
"nodeId" integer DEFAULT 0,
|
||||
"serverId" integer DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "ip_list_itemId"
|
||||
ON "` + this.itemTableName + `" (
|
||||
"itemId" ASC
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "ip_list_expiredAt"
|
||||
ON "` + this.itemTableName + `" (
|
||||
"expiredAt" ASC
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.versionTableName + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"version" integer DEFAULT 0
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 初始化SQL语句
|
||||
this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "expiredAt">0 AND "expiredAt"<?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.deleteItemStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "itemId"=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.insertItemStmt, err = this.db.Prepare(`INSERT INTO "` + this.itemTableName + `" ("listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" WHERE isDeleted=0 ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.selectMaxItemVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.selectVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.versionTableName + `" LIMIT 1`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.updateVersionStmt, err = this.db.Prepare(`REPLACE INTO "` + this.versionTableName + `" ("id", "version") VALUES (1, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.db = db
|
||||
|
||||
goman.New(func() {
|
||||
events.OnClose(func() {
|
||||
_ = this.Close()
|
||||
this.cleanTicker.Stop()
|
||||
})
|
||||
|
||||
for range this.cleanTicker.C {
|
||||
err := this.DeleteExpiredItems()
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteExpiredItems 删除过期的条目
|
||||
func (this *IPListDB) DeleteExpiredItems() error {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := this.deleteExpiredItemsStmt.Exec(time.Now().Unix() - 7*86400)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *IPListDB) AddItem(item *pb.IPItem) error {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := this.deleteItemStmt.Exec(item.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果是删除,则不再创建新记录
|
||||
if item.IsDeleted {
|
||||
return this.UpdateMaxVersion(item.Version)
|
||||
}
|
||||
|
||||
_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return this.UpdateMaxVersion(item.Version)
|
||||
}
|
||||
|
||||
func (this *IPListDB) ReadItems(offset int64, size int64) (items []*pb.IPItem, err error) {
|
||||
if this.isClosed {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := this.selectItemsStmt.Query(offset, size)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
// "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId"
|
||||
var pbItem = &pb.IPItem{}
|
||||
err = rows.Scan(&pbItem.ListId, &pbItem.ListType, &pbItem.IsGlobal, &pbItem.Type, &pbItem.Id, &pbItem.IpFrom, &pbItem.IpTo, &pbItem.ExpiredAt, &pbItem.EventLevel, &pbItem.IsDeleted, &pbItem.Version, &pbItem.NodeId, &pbItem.ServerId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, pbItem)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ReadMaxVersion 读取当前最大版本号
|
||||
func (this *IPListDB) ReadMaxVersion() int64 {
|
||||
if this.isClosed {
|
||||
return 0
|
||||
}
|
||||
|
||||
// from version table
|
||||
{
|
||||
var row = this.selectVersionStmt.QueryRow()
|
||||
if row == nil {
|
||||
return 0
|
||||
}
|
||||
var version int64
|
||||
err := row.Scan(&version)
|
||||
if err == nil {
|
||||
return version
|
||||
}
|
||||
}
|
||||
|
||||
// from items table
|
||||
{
|
||||
var row = this.selectMaxItemVersionStmt.QueryRow()
|
||||
if row == nil {
|
||||
return 0
|
||||
}
|
||||
var version int64
|
||||
err := row.Scan(&version)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return version
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateMaxVersion 修改版本号
|
||||
func (this *IPListDB) UpdateMaxVersion(version int64) error {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := this.updateVersionStmt.Exec(version)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *IPListDB) Close() error {
|
||||
this.isClosed = true
|
||||
|
||||
if this.db != nil {
|
||||
for _, stmt := range []*dbs.Stmt{
|
||||
this.deleteExpiredItemsStmt,
|
||||
this.deleteItemStmt,
|
||||
this.insertItemStmt,
|
||||
this.selectItemsStmt,
|
||||
this.selectMaxItemVersionStmt, // ipItems table
|
||||
|
||||
this.selectVersionStmt, // versions table
|
||||
this.updateVersionStmt,
|
||||
} {
|
||||
if stmt != nil {
|
||||
_ = stmt.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return this.db.Close()
|
||||
}
|
||||
return nil
|
||||
type IPListDB interface {
|
||||
Name() string
|
||||
DeleteExpiredItems() error
|
||||
ReadMaxVersion() (int64, error)
|
||||
ReadItems(offset int64, size int64) (items []*pb.IPItem, goNext bool, err error)
|
||||
AddItem(item *pb.IPItem) error
|
||||
}
|
||||
|
||||
229
internal/iplibrary/ip_list_kv.go
Normal file
229
internal/iplibrary/ip_list_kv.go
Normal file
@@ -0,0 +1,229 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/kvstore"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type IPListKV struct {
|
||||
ipTable *kvstore.Table[*pb.IPItem]
|
||||
versionsTable *kvstore.Table[int64]
|
||||
|
||||
encoder *IPItemEncoder[*pb.IPItem]
|
||||
|
||||
cleanTicker *time.Ticker
|
||||
|
||||
isClosed bool
|
||||
|
||||
offsetItemKey string
|
||||
}
|
||||
|
||||
func NewIPListKV() (*IPListKV, error) {
|
||||
var db = &IPListKV{
|
||||
cleanTicker: time.NewTicker(24 * time.Hour),
|
||||
encoder: &IPItemEncoder[*pb.IPItem]{},
|
||||
}
|
||||
err := db.init()
|
||||
return db, err
|
||||
}
|
||||
|
||||
func (this *IPListKV) init() error {
|
||||
store, storeErr := kvstore.DefaultStore()
|
||||
if storeErr != nil {
|
||||
return storeErr
|
||||
}
|
||||
db, dbErr := store.NewDB("ip_list")
|
||||
if dbErr != nil {
|
||||
return dbErr
|
||||
}
|
||||
|
||||
{
|
||||
table, err := kvstore.NewTable[*pb.IPItem]("ip_items", this.encoder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.ipTable = table
|
||||
|
||||
err = table.AddFields("expiresAt")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db.AddTable(table)
|
||||
}
|
||||
|
||||
{
|
||||
table, err := kvstore.NewTable[int64]("versions", kvstore.NewIntValueEncoder[int64]())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.versionsTable = table
|
||||
db.AddTable(table)
|
||||
}
|
||||
|
||||
goman.New(func() {
|
||||
events.OnClose(func() {
|
||||
_ = this.Close()
|
||||
this.cleanTicker.Stop()
|
||||
})
|
||||
|
||||
for range this.cleanTicker.C {
|
||||
deleteErr := this.DeleteExpiredItems()
|
||||
if deleteErr != nil {
|
||||
remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+deleteErr.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Name 数据库名称代号
|
||||
func (this *IPListKV) Name() string {
|
||||
return "kvstore"
|
||||
}
|
||||
|
||||
// DeleteExpiredItems 删除过期的条目
|
||||
func (this *IPListKV) DeleteExpiredItems() error {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
var found bool
|
||||
var currentTime = fasttime.Now().Unix()
|
||||
err := this.ipTable.
|
||||
Query().
|
||||
FieldAsc("expiresAt").
|
||||
ForUpdate().
|
||||
Limit(1000).
|
||||
FindAll(func(tx *kvstore.Tx[*pb.IPItem], item kvstore.Item[*pb.IPItem]) (goNext bool, err error) {
|
||||
if !item.Value.IsDeleted && item.Value.ExpiredAt == 0 { // never expires
|
||||
return kvstore.Skip()
|
||||
}
|
||||
if item.Value.ExpiredAt < currentTime-7*86400 /** keep for 7 days **/ {
|
||||
err = tx.Delete(item.Key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
found = true
|
||||
return true, nil
|
||||
}
|
||||
|
||||
found = false
|
||||
return false, nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !found {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *IPListKV) AddItem(item *pb.IPItem) error {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 先删除
|
||||
var key = this.encoder.EncodeKey(item)
|
||||
err := this.ipTable.Delete(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果是删除,则不再创建新记录
|
||||
if item.IsDeleted {
|
||||
return this.UpdateMaxVersion(item.Version)
|
||||
}
|
||||
|
||||
err = this.ipTable.Set(key, item)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return this.UpdateMaxVersion(item.Version)
|
||||
}
|
||||
|
||||
func (this *IPListKV) ReadItems(offset int64, size int64) (items []*pb.IPItem, goNextLoop bool, err error) {
|
||||
if this.isClosed {
|
||||
return
|
||||
}
|
||||
|
||||
err = this.ipTable.
|
||||
Query().
|
||||
Offset(this.offsetItemKey).
|
||||
Limit(int(size)).
|
||||
FindAll(func(tx *kvstore.Tx[*pb.IPItem], item kvstore.Item[*pb.IPItem]) (goNext bool, err error) {
|
||||
this.offsetItemKey = item.Key
|
||||
goNextLoop = true
|
||||
|
||||
if !item.Value.IsDeleted {
|
||||
items = append(items, item.Value)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// ReadMaxVersion 读取当前最大版本号
|
||||
func (this *IPListKV) ReadMaxVersion() (int64, error) {
|
||||
if this.isClosed {
|
||||
return 0, errors.New("database has been closed")
|
||||
}
|
||||
|
||||
version, err := this.versionsTable.Get("version")
|
||||
if err != nil {
|
||||
if kvstore.IsNotFound(err) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// UpdateMaxVersion 修改版本号
|
||||
func (this *IPListKV) UpdateMaxVersion(version int64) error {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return this.versionsTable.SetSync("version", version)
|
||||
}
|
||||
|
||||
func (this *IPListKV) TestInspect(t *testing.T) error {
|
||||
return this.ipTable.
|
||||
Query().
|
||||
FindAll(func(tx *kvstore.Tx[*pb.IPItem], item kvstore.Item[*pb.IPItem]) (goNext bool, err error) {
|
||||
if len(item.Key) != 8 {
|
||||
return false, errors.New("invalid key '" + item.Key + "'")
|
||||
}
|
||||
|
||||
t.Log(binary.BigEndian.Uint64([]byte(item.Key)), "=>", item.Value)
|
||||
return true, nil
|
||||
})
|
||||
}
|
||||
|
||||
// Flush to disk
|
||||
func (this *IPListKV) Flush() error {
|
||||
return this.ipTable.DB().Store().Flush()
|
||||
}
|
||||
|
||||
func (this *IPListKV) Close() error {
|
||||
this.isClosed = true
|
||||
return nil
|
||||
}
|
||||
55
internal/iplibrary/ip_list_kv_objects.go
Normal file
55
internal/iplibrary/ip_list_kv_objects.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"math"
|
||||
)
|
||||
|
||||
type IPItemEncoder[T interface{ *pb.IPItem }] struct {
|
||||
}
|
||||
|
||||
func NewIPItemEncoder[T interface{ *pb.IPItem }]() *IPItemEncoder[T] {
|
||||
return &IPItemEncoder[T]{}
|
||||
}
|
||||
|
||||
func (this *IPItemEncoder[T]) Encode(value T) ([]byte, error) {
|
||||
return proto.Marshal(any(value).(*pb.IPItem))
|
||||
}
|
||||
|
||||
func (this *IPItemEncoder[T]) EncodeField(value T, fieldName string) ([]byte, error) {
|
||||
switch fieldName {
|
||||
case "expiresAt":
|
||||
var expiresAt = any(value).(*pb.IPItem).ExpiredAt
|
||||
if expiresAt < 0 || expiresAt > int64(math.MaxUint32) {
|
||||
expiresAt = 0
|
||||
}
|
||||
var b = make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(b, uint32(expiresAt))
|
||||
return b, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("field '" + fieldName + "' not found")
|
||||
}
|
||||
|
||||
func (this *IPItemEncoder[T]) Decode(valueBytes []byte) (value T, err error) {
|
||||
var item = &pb.IPItem{}
|
||||
err = proto.Unmarshal(valueBytes, item)
|
||||
value = item
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeKey generate key for ip item
|
||||
func (this *IPItemEncoder[T]) EncodeKey(item *pb.IPItem) string {
|
||||
var b = make([]byte, 8)
|
||||
if item.Id < 0 {
|
||||
item.Id = 0
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint64(b, uint64(item.Id))
|
||||
return string(b)
|
||||
}
|
||||
221
internal/iplibrary/ip_list_kv_test.go
Normal file
221
internal/iplibrary/ip_list_kv_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package iplibrary_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIPListKV_AddItem(t *testing.T) {
|
||||
kv, err := iplibrary.NewIPListKV()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = kv.Flush()
|
||||
}()
|
||||
|
||||
{
|
||||
err = kv.AddItem(&pb.IPItem{
|
||||
Id: 1,
|
||||
IpFrom: "192.168.1.101",
|
||||
IpTo: "",
|
||||
Version: 1,
|
||||
ExpiredAt: fasttime.NewFastTime().Unix() + 60,
|
||||
ListId: 1,
|
||||
IsDeleted: false,
|
||||
ListType: "white",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
err = kv.AddItem(&pb.IPItem{
|
||||
Id: 2,
|
||||
IpFrom: "192.168.1.102",
|
||||
IpTo: "",
|
||||
Version: 2,
|
||||
ExpiredAt: fasttime.NewFastTime().Unix() + 60,
|
||||
ListId: 1,
|
||||
IsDeleted: false,
|
||||
ListType: "white",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
err = kv.AddItem(&pb.IPItem{
|
||||
Id: 3,
|
||||
IpFrom: "192.168.1.103",
|
||||
IpTo: "",
|
||||
Version: 3,
|
||||
ExpiredAt: fasttime.NewFastTime().Unix() + 60,
|
||||
ListId: 1,
|
||||
IsDeleted: false,
|
||||
ListType: "white",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPListKV_AddItems_Many(t *testing.T) {
|
||||
kv, err := iplibrary.NewIPListKV()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = kv.Flush()
|
||||
}()
|
||||
|
||||
var count = 2
|
||||
var from = 1
|
||||
if testutils.IsSingleTesting() {
|
||||
count = 2_000_000
|
||||
}
|
||||
|
||||
var before = time.Now()
|
||||
defer func() {
|
||||
t.Logf("cost: %.2f s", time.Since(before).Seconds())
|
||||
}()
|
||||
|
||||
for i := from; i <= from+count; i++ {
|
||||
err = kv.AddItem(&pb.IPItem{
|
||||
Id: int64(i),
|
||||
IpFrom: testutils.RandIP(),
|
||||
IpTo: "",
|
||||
Version: int64(i),
|
||||
ExpiredAt: fasttime.NewFastTime().Unix() + 86400,
|
||||
ListId: 1,
|
||||
IsDeleted: false,
|
||||
ListType: "white",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPListKV_DeleteExpiredItems(t *testing.T) {
|
||||
kv, err := iplibrary.NewIPListKV()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = kv.Flush()
|
||||
}()
|
||||
|
||||
err = kv.DeleteExpiredItems()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPListKV_UpdateMaxVersion(t *testing.T) {
|
||||
kv, err := iplibrary.NewIPListKV()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = kv.Flush()
|
||||
}()
|
||||
|
||||
err = kv.UpdateMaxVersion(101)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
maxVersion, err := kv.ReadMaxVersion()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Log("version:", maxVersion)
|
||||
}
|
||||
|
||||
func TestIPListKV_ReadMaxVersion(t *testing.T) {
|
||||
kv, err := iplibrary.NewIPListKV()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
maxVersion, err := kv.ReadMaxVersion()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Log("version:", maxVersion)
|
||||
}
|
||||
|
||||
func TestIPListKV_ReadItems(t *testing.T) {
|
||||
kv, err := iplibrary.NewIPListKV()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for {
|
||||
items, goNext, readErr := kv.ReadItems(0, 2)
|
||||
if readErr != nil {
|
||||
t.Fatal(readErr)
|
||||
}
|
||||
t.Log("====")
|
||||
for _, item := range items {
|
||||
t.Log(item.Id)
|
||||
}
|
||||
|
||||
if !goNext {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPListKV_CountItems(t *testing.T) {
|
||||
kv, err := iplibrary.NewIPListKV()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var count int
|
||||
var m = map[int64]zero.Zero{}
|
||||
for {
|
||||
items, goNext, readErr := kv.ReadItems(0, 1000)
|
||||
if readErr != nil {
|
||||
t.Fatal(readErr)
|
||||
}
|
||||
for _, item := range items {
|
||||
count++
|
||||
m[item.Id] = zero.Zero{}
|
||||
}
|
||||
|
||||
if !goNext {
|
||||
break
|
||||
}
|
||||
}
|
||||
t.Log("count:", count, "len:", len(m))
|
||||
}
|
||||
|
||||
func TestIPListKV_Inspect(t *testing.T) {
|
||||
kv, err := iplibrary.NewIPListKV()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = kv.TestInspect(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
312
internal/iplibrary/ip_list_sqlite.go
Normal file
312
internal/iplibrary/ip_list_sqlite.go
Normal file
@@ -0,0 +1,312 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/dbs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
type IPListSQLite struct {
|
||||
db *dbs.DB
|
||||
|
||||
itemTableName string
|
||||
versionTableName string
|
||||
|
||||
deleteExpiredItemsStmt *dbs.Stmt
|
||||
deleteItemStmt *dbs.Stmt
|
||||
insertItemStmt *dbs.Stmt
|
||||
selectItemsStmt *dbs.Stmt
|
||||
selectMaxItemVersionStmt *dbs.Stmt
|
||||
|
||||
selectVersionStmt *dbs.Stmt
|
||||
updateVersionStmt *dbs.Stmt
|
||||
|
||||
cleanTicker *time.Ticker
|
||||
|
||||
dir string
|
||||
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
func NewIPListSqlite() (*IPListSQLite, error) {
|
||||
var db = &IPListSQLite{
|
||||
itemTableName: "ipItems",
|
||||
versionTableName: "versions",
|
||||
dir: filepath.Clean(Tea.Root + "/data"),
|
||||
cleanTicker: time.NewTicker(24 * time.Hour),
|
||||
}
|
||||
err := db.init()
|
||||
return db, err
|
||||
}
|
||||
|
||||
func (this *IPListSQLite) init() error {
|
||||
// 检查目录是否存在
|
||||
_, err := os.Stat(this.dir)
|
||||
if err != nil {
|
||||
err = os.MkdirAll(this.dir, 0777)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
remotelogs.Println("IP_LIST_DB", "create data dir '"+this.dir+"'")
|
||||
}
|
||||
|
||||
var path = this.dir + "/ip_list.db"
|
||||
|
||||
db, err := dbs.OpenWriter("file:" + path + "?cache=shared&mode=rwc&_journal_mode=WAL&_sync=" + dbs.SyncMode + "&_locking_mode=EXCLUSIVE")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.SetMaxOpenConns(1)
|
||||
|
||||
//_, err = db.Exec("VACUUM")
|
||||
//if err != nil {
|
||||
// return err
|
||||
//}
|
||||
|
||||
this.db = db
|
||||
|
||||
// 恢复数据库
|
||||
var recoverEnv, _ = os.LookupEnv("EdgeRecover")
|
||||
if len(recoverEnv) > 0 {
|
||||
for _, indexName := range []string{"ip_list_itemId", "ip_list_expiredAt"} {
|
||||
_, _ = db.Exec(`REINDEX "` + indexName + `"`)
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化数据库
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemTableName + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"listId" integer DEFAULT 0,
|
||||
"listType" varchar(32),
|
||||
"isGlobal" integer(1) DEFAULT 0,
|
||||
"type" varchar(16),
|
||||
"itemId" integer DEFAULT 0,
|
||||
"ipFrom" varchar(64) DEFAULT 0,
|
||||
"ipTo" varchar(64) DEFAULT 0,
|
||||
"expiredAt" integer DEFAULT 0,
|
||||
"eventLevel" varchar(32),
|
||||
"isDeleted" integer(1) DEFAULT 0,
|
||||
"version" integer DEFAULT 0,
|
||||
"nodeId" integer DEFAULT 0,
|
||||
"serverId" integer DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "ip_list_itemId"
|
||||
ON "` + this.itemTableName + `" (
|
||||
"itemId" ASC
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "ip_list_expiredAt"
|
||||
ON "` + this.itemTableName + `" (
|
||||
"expiredAt" ASC
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.versionTableName + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"version" integer DEFAULT 0
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 初始化SQL语句
|
||||
this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "expiredAt">0 AND "expiredAt"<?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.deleteItemStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "itemId"=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.insertItemStmt, err = this.db.Prepare(`INSERT INTO "` + this.itemTableName + `" ("listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" WHERE isDeleted=0 ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.selectMaxItemVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.selectVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.versionTableName + `" LIMIT 1`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.updateVersionStmt, err = this.db.Prepare(`REPLACE INTO "` + this.versionTableName + `" ("id", "version") VALUES (1, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.db = db
|
||||
|
||||
goman.New(func() {
|
||||
events.OnClose(func() {
|
||||
_ = this.Close()
|
||||
this.cleanTicker.Stop()
|
||||
})
|
||||
|
||||
for range this.cleanTicker.C {
|
||||
err := this.DeleteExpiredItems()
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Name 数据库名称代号
|
||||
func (this *IPListSQLite) Name() string {
|
||||
return "sqlite"
|
||||
}
|
||||
|
||||
// DeleteExpiredItems 删除过期的条目
|
||||
func (this *IPListSQLite) DeleteExpiredItems() error {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := this.deleteExpiredItemsStmt.Exec(time.Now().Unix() - 7*86400)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *IPListSQLite) AddItem(item *pb.IPItem) error {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := this.deleteItemStmt.Exec(item.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果是删除,则不再创建新记录
|
||||
if item.IsDeleted {
|
||||
return this.UpdateMaxVersion(item.Version)
|
||||
}
|
||||
|
||||
_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return this.UpdateMaxVersion(item.Version)
|
||||
}
|
||||
|
||||
func (this *IPListSQLite) ReadItems(offset int64, size int64) (items []*pb.IPItem, goNext bool, err error) {
|
||||
if this.isClosed {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := this.selectItemsStmt.Query(offset, size)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
// "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId"
|
||||
var pbItem = &pb.IPItem{}
|
||||
err = rows.Scan(&pbItem.ListId, &pbItem.ListType, &pbItem.IsGlobal, &pbItem.Type, &pbItem.Id, &pbItem.IpFrom, &pbItem.IpTo, &pbItem.ExpiredAt, &pbItem.EventLevel, &pbItem.IsDeleted, &pbItem.Version, &pbItem.NodeId, &pbItem.ServerId)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
items = append(items, pbItem)
|
||||
}
|
||||
|
||||
goNext = int64(len(items)) == size
|
||||
return
|
||||
}
|
||||
|
||||
// ReadMaxVersion 读取当前最大版本号
|
||||
func (this *IPListSQLite) ReadMaxVersion() (int64, error) {
|
||||
if this.isClosed {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// from version table
|
||||
{
|
||||
var row = this.selectVersionStmt.QueryRow()
|
||||
if row == nil {
|
||||
return 0, nil
|
||||
}
|
||||
var version int64
|
||||
err := row.Scan(&version)
|
||||
if err == nil {
|
||||
return version, nil
|
||||
}
|
||||
}
|
||||
|
||||
// from items table
|
||||
{
|
||||
var row = this.selectMaxItemVersionStmt.QueryRow()
|
||||
if row == nil {
|
||||
return 0, nil
|
||||
}
|
||||
var version int64
|
||||
err := row.Scan(&version)
|
||||
if err != nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateMaxVersion 修改版本号
|
||||
func (this *IPListSQLite) UpdateMaxVersion(version int64) error {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := this.updateVersionStmt.Exec(version)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *IPListSQLite) Close() error {
|
||||
this.isClosed = true
|
||||
|
||||
if this.db != nil {
|
||||
for _, stmt := range []*dbs.Stmt{
|
||||
this.deleteExpiredItemsStmt,
|
||||
this.deleteItemStmt,
|
||||
this.insertItemStmt,
|
||||
this.selectItemsStmt,
|
||||
this.selectMaxItemVersionStmt, // ipItems table
|
||||
|
||||
this.selectVersionStmt, // versions table
|
||||
this.updateVersionStmt,
|
||||
} {
|
||||
if stmt != nil {
|
||||
_ = stmt.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return this.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func TestIPListDB_AddItem(t *testing.T) {
|
||||
db, err := iplibrary.NewIPListDB()
|
||||
db, err := iplibrary.NewIPListSqlite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func TestIPListDB_AddItem(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIPListDB_ReadItems(t *testing.T) {
|
||||
db, err := iplibrary.NewIPListDB()
|
||||
db, err := iplibrary.NewIPListSqlite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -71,15 +71,16 @@ func TestIPListDB_ReadItems(t *testing.T) {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
items, err := db.ReadItems(0, 2)
|
||||
items, goNext, err := db.ReadItems(0, 2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("goNext:", goNext)
|
||||
logs.PrintAsJSON(items, t)
|
||||
}
|
||||
|
||||
func TestIPListDB_ReadMaxVersion(t *testing.T) {
|
||||
db, err := iplibrary.NewIPListDB()
|
||||
db, err := iplibrary.NewIPListSqlite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -90,7 +91,7 @@ func TestIPListDB_ReadMaxVersion(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIPListDB_UpdateMaxVersion(t *testing.T) {
|
||||
db, err := iplibrary.NewIPListDB()
|
||||
db, err := iplibrary.NewIPListSqlite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -8,10 +8,13 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -45,7 +48,7 @@ func init() {
|
||||
type IPListManager struct {
|
||||
ticker *time.Ticker
|
||||
|
||||
db *IPListDB
|
||||
db IPListDB
|
||||
|
||||
lastVersion int64
|
||||
fetchPageSize int64
|
||||
@@ -83,7 +86,7 @@ func (this *IPListManager) Start() {
|
||||
case <-this.ticker.C:
|
||||
case <-IPListUpdateNotify:
|
||||
}
|
||||
err := this.loop()
|
||||
err = this.loop()
|
||||
if err != nil {
|
||||
countErrors++
|
||||
|
||||
@@ -110,7 +113,18 @@ func (this *IPListManager) Stop() {
|
||||
|
||||
func (this *IPListManager) init() {
|
||||
// 从数据库中当中读取数据
|
||||
db, err := NewIPListDB()
|
||||
// 检查sqlite文件是否存在,以便决定使用sqlite还是kv
|
||||
var sqlitePath = Tea.Root + "/data/ip_list.db"
|
||||
_, sqliteErr := os.Stat(sqlitePath)
|
||||
|
||||
var db IPListDB
|
||||
var err error
|
||||
if sqliteErr == nil {
|
||||
db, err = NewIPListSqlite()
|
||||
} else {
|
||||
db, err = NewIPListKV()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIST_MANAGER", "create ip list local database failed: "+err.Error())
|
||||
} else {
|
||||
@@ -120,24 +134,30 @@ func (this *IPListManager) init() {
|
||||
_ = db.DeleteExpiredItems()
|
||||
|
||||
// 本地数据库中最大版本号
|
||||
this.lastVersion = db.ReadMaxVersion()
|
||||
this.lastVersion, err = db.ReadMaxVersion()
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIST_MANAGER", "find max version failed: "+err.Error())
|
||||
this.lastVersion = 0
|
||||
}
|
||||
remotelogs.Println("IP_LIST_MANAGER", "starting from '"+db.Name()+"' version '"+types.String(this.lastVersion)+"' ...")
|
||||
|
||||
// 从本地数据库中加载
|
||||
var offset int64 = 0
|
||||
var size int64 = 2_000
|
||||
|
||||
var tr = trackers.Begin("IP_LIST_MANAGER:load")
|
||||
defer tr.End()
|
||||
|
||||
for {
|
||||
items, err := db.ReadItems(offset, size)
|
||||
items, goNext, readErr := db.ReadItems(offset, size)
|
||||
var l = len(items)
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIST_MANAGER", "read ip list from local database failed: "+err.Error())
|
||||
if readErr != nil {
|
||||
remotelogs.Error("IP_LIST_MANAGER", "read ip list from local database failed: "+readErr.Error())
|
||||
} else {
|
||||
if l == 0 {
|
||||
if !goNext {
|
||||
break
|
||||
}
|
||||
this.processItems(items, false)
|
||||
if int64(l) < size {
|
||||
break
|
||||
}
|
||||
}
|
||||
offset += int64(l)
|
||||
}
|
||||
@@ -310,9 +330,14 @@ func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) {
|
||||
|
||||
// 调试IP信息
|
||||
func (this *IPListManager) debugItem(item *pb.IPItem) {
|
||||
var ipRange = item.IpFrom
|
||||
if len(item.IpTo) > 0 {
|
||||
ipRange += " - " + item.IpTo
|
||||
}
|
||||
|
||||
if item.IsDeleted {
|
||||
remotelogs.Debug("IP_ITEM_DEBUG", "delete '"+item.IpFrom+"'")
|
||||
remotelogs.Debug("IP_ITEM_DEBUG", "delete '"+ipRange+"'")
|
||||
} else {
|
||||
remotelogs.Debug("IP_ITEM_DEBUG", "add '"+item.IpFrom+"'")
|
||||
remotelogs.Debug("IP_ITEM_DEBUG", "add '"+ipRange+"'")
|
||||
}
|
||||
}
|
||||
|
||||
28
internal/utils/byte/utils.go
Normal file
28
internal/utils/byte/utils.go
Normal file
@@ -0,0 +1,28 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package byteutils
|
||||
|
||||
// Copy bytes
|
||||
func Copy(b []byte) []byte {
|
||||
var l = len(b)
|
||||
if l == 0 {
|
||||
return []byte{}
|
||||
}
|
||||
var d = make([]byte, l)
|
||||
copy(d, b)
|
||||
return d
|
||||
}
|
||||
|
||||
// Append bytes
|
||||
func Append(b []byte, b2 ...byte) []byte {
|
||||
return append(Copy(b), b2...)
|
||||
}
|
||||
|
||||
// Contact bytes
|
||||
func Contact(b []byte, b2 ...[]byte) []byte {
|
||||
b = Copy(b)
|
||||
for _, b3 := range b2 {
|
||||
b = append(b, b3...)
|
||||
}
|
||||
return b
|
||||
}
|
||||
56
internal/utils/byte/utils_test.go
Normal file
56
internal/utils/byte/utils_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package byteutils_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
byteutils "github.com/TeaOSLab/EdgeNode/internal/utils/byte"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCopy(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
|
||||
var prefix []byte
|
||||
prefix = append(prefix, 1, 2, 3)
|
||||
t.Log(prefix, byteutils.Copy(prefix))
|
||||
a.IsTrue(bytes.Equal(byteutils.Copy(prefix), []byte{1, 2, 3}))
|
||||
}
|
||||
|
||||
func TestAppend(t *testing.T) {
|
||||
var as = assert.NewAssertion(t)
|
||||
|
||||
var prefix []byte
|
||||
prefix = append(prefix, 1, 2, 3)
|
||||
|
||||
// [1 2 3 4 5 6] [1 2 3 7]
|
||||
var a = byteutils.Append(prefix, 4, 5, 6)
|
||||
var b = byteutils.Append(prefix, 7)
|
||||
t.Log(a, b)
|
||||
|
||||
as.IsTrue(bytes.Equal(a, []byte{1, 2, 3, 4, 5, 6}))
|
||||
as.IsTrue(bytes.Equal(b, []byte{1, 2, 3, 7}))
|
||||
}
|
||||
|
||||
func TestConcat(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
|
||||
var prefix []byte
|
||||
prefix = append(prefix, 1, 2, 3)
|
||||
|
||||
var b = byteutils.Contact(prefix, []byte{4, 5, 6}, []byte{7})
|
||||
t.Log(b)
|
||||
|
||||
a.IsTrue(bytes.Equal(b, []byte{1, 2, 3, 4, 5, 6, 7}))
|
||||
}
|
||||
|
||||
func TestAppend_Raw(t *testing.T) {
|
||||
var prefix []byte
|
||||
prefix = append(prefix, 1, 2, 3)
|
||||
|
||||
// [1 2 3 7 5 6] [1 2 3 7]
|
||||
var a = append(prefix, 4, 5, 6)
|
||||
var b = append(prefix, 7)
|
||||
t.Log(a, b)
|
||||
}
|
||||
@@ -9,10 +9,16 @@ import (
|
||||
|
||||
var ErrTableNotFound = errors.New("table not found")
|
||||
var ErrKeyTooLong = errors.New("too long key")
|
||||
var ErrSkip= errors.New("skip") // skip count in iterator
|
||||
|
||||
func IsKeyNotFound(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return errors.Is(err, pebble.ErrNotFound)
|
||||
func IsNotFound(err error) bool {
|
||||
return err != nil && errors.Is(err, pebble.ErrNotFound)
|
||||
}
|
||||
|
||||
func IsSkipError(err error) bool {
|
||||
return err != nil && errors.Is(err, ErrSkip)
|
||||
}
|
||||
|
||||
func Skip() (bool, error) {
|
||||
return true, ErrSkip
|
||||
}
|
||||
|
||||
@@ -7,3 +7,7 @@ import "github.com/cockroachdb/pebble"
|
||||
var DefaultWriteOptions = &pebble.WriteOptions{
|
||||
Sync: false,
|
||||
}
|
||||
|
||||
var DefaultWriteSyncOptions = &pebble.WriteOptions{
|
||||
Sync: true,
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
byteutils "github.com/TeaOSLab/EdgeNode/internal/utils/byte"
|
||||
)
|
||||
|
||||
type DataType = int
|
||||
@@ -222,11 +223,11 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
|
||||
var prefix []byte
|
||||
switch this.dataType {
|
||||
case DataTypeKey:
|
||||
prefix = append(this.table.Namespace(), KeyPrefix...)
|
||||
prefix = byteutils.Append(this.table.Namespace(), []byte(KeyPrefix)...)
|
||||
case DataTypeField:
|
||||
prefix = append(this.table.Namespace(), FieldPrefix...)
|
||||
prefix = byteutils.Append(this.table.Namespace(), []byte(FieldPrefix)...)
|
||||
default:
|
||||
prefix = append(this.table.Namespace(), KeyPrefix...)
|
||||
prefix = byteutils.Append(this.table.Namespace(), []byte(KeyPrefix)...)
|
||||
}
|
||||
|
||||
var prefixLen = len(prefix)
|
||||
@@ -238,21 +239,21 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
|
||||
var offsetKey []byte
|
||||
if this.reverse {
|
||||
if len(this.offsetKey) > 0 {
|
||||
offsetKey = append(prefix, this.offsetKey...)
|
||||
offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
|
||||
} else {
|
||||
offsetKey = append(prefix, 0xFF)
|
||||
offsetKey = byteutils.Append(prefix, 0xFF)
|
||||
}
|
||||
|
||||
opt.LowerBound = prefix
|
||||
opt.UpperBound = offsetKey
|
||||
} else {
|
||||
if len(this.offsetKey) > 0 {
|
||||
offsetKey = append(prefix, this.offsetKey...)
|
||||
offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
|
||||
} else {
|
||||
offsetKey = prefix
|
||||
}
|
||||
opt.LowerBound = offsetKey
|
||||
opt.UpperBound = append(offsetKey, 0xFF)
|
||||
opt.UpperBound = byteutils.Append(prefix, 0xFF)
|
||||
}
|
||||
|
||||
var hasOffsetKey = len(this.offsetKey) > 0
|
||||
@@ -267,7 +268,7 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
|
||||
|
||||
var count int
|
||||
|
||||
var itemFn = func() (goNext bool, err error) {
|
||||
var itemFn = func() (goNextItem bool, err error) {
|
||||
var keyBytes = it.Key()
|
||||
|
||||
// skip first offset key
|
||||
@@ -297,7 +298,11 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
|
||||
Value: value,
|
||||
})
|
||||
if callbackErr != nil {
|
||||
return false, callbackErr
|
||||
if IsSkipError(callbackErr) {
|
||||
return true, nil
|
||||
} else {
|
||||
return false, callbackErr
|
||||
}
|
||||
}
|
||||
if !goNext {
|
||||
return false, nil
|
||||
@@ -361,9 +366,9 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
|
||||
if len(this.fieldOffsetKey) > 0 {
|
||||
offsetKey = this.fieldOffsetKey
|
||||
} else if len(this.offsetKey) > 0 {
|
||||
offsetKey = append(prefix, this.offsetKey...)
|
||||
offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
|
||||
} else {
|
||||
offsetKey = append(prefix, 0xFF)
|
||||
offsetKey = byteutils.Append(prefix, 0xFF)
|
||||
}
|
||||
opt.LowerBound = prefix
|
||||
opt.UpperBound = offsetKey
|
||||
@@ -371,14 +376,14 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
|
||||
if len(this.fieldOffsetKey) > 0 {
|
||||
offsetKey = this.fieldOffsetKey
|
||||
} else if len(this.offsetKey) > 0 {
|
||||
offsetKey = append(prefix, this.offsetKey...)
|
||||
offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
|
||||
offsetKey = append(offsetKey, 0)
|
||||
} else {
|
||||
offsetKey = prefix
|
||||
}
|
||||
|
||||
opt.LowerBound = offsetKey
|
||||
opt.UpperBound = append(prefix, 0xFF)
|
||||
opt.UpperBound = byteutils.Append(prefix, 0xFF)
|
||||
}
|
||||
|
||||
it, itErr := this.tx.NewIterator(opt)
|
||||
@@ -391,7 +396,7 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
|
||||
|
||||
var count int
|
||||
|
||||
var itemFn = func() (goNext bool, err error) {
|
||||
var itemFn = func() (goNextItem bool, err error) {
|
||||
var fieldKeyBytes = it.Key()
|
||||
|
||||
fieldValueBytes, keyBytes, decodeKeyErr := this.table.DecodeFieldKey(this.fieldName, fieldKeyBytes)
|
||||
@@ -423,7 +428,7 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
|
||||
if !this.keysOnly {
|
||||
value, getErr := this.table.getWithKeyBytes(this.tx, this.table.FullKeyBytes(keyBytes))
|
||||
if getErr != nil {
|
||||
if IsKeyNotFound(getErr) {
|
||||
if IsNotFound(getErr) {
|
||||
return true, nil
|
||||
}
|
||||
return false, getErr
|
||||
@@ -432,11 +437,15 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
|
||||
resultItem.Value = value
|
||||
}
|
||||
|
||||
goNext, err = fn(this.tx, resultItem)
|
||||
goNextItem, err = fn(this.tx, resultItem)
|
||||
if err != nil {
|
||||
return
|
||||
if IsSkipError(err) {
|
||||
return true, nil
|
||||
} else {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
if !goNext {
|
||||
if !goNextItem {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -138,6 +138,26 @@ func TestQuery_FindAll_Offset(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuery_FindAll_Skip(t *testing.T) {
|
||||
var table = testOpenStoreTable[*testCachedItem](t, "cache_items", &testCacheItemEncoder[*testCachedItem]{})
|
||||
|
||||
{
|
||||
err := table.Query().
|
||||
Offset("a3").
|
||||
Limit(10).
|
||||
FindAll(func(tx *kvstore.Tx[*testCachedItem], item kvstore.Item[*testCachedItem]) (goNext bool, err error) {
|
||||
if item.Key == "a30" || item.Key == "a3000005" {
|
||||
return kvstore.Skip()
|
||||
}
|
||||
t.Log("key:", item.Key, "value:", item.Value)
|
||||
return true, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuery_FindAll_Count(t *testing.T) {
|
||||
var table = testOpenStoreTable[*testCachedItem](t, "cache_items", &testCacheItemEncoder[*testCachedItem]{})
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
memutils "github.com/TeaOSLab/EdgeNode/internal/utils/mem"
|
||||
"github.com/cockroachdb/pebble"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
@@ -85,6 +86,31 @@ func OpenStoreDir(dir string, storeName string) (*Store, error) {
|
||||
return store, nil
|
||||
}
|
||||
|
||||
var storeOnce = &sync.Once{}
|
||||
var defaultSore *Store
|
||||
|
||||
func DefaultStore() (*Store, error) {
|
||||
if defaultSore != nil {
|
||||
return defaultSore, nil
|
||||
}
|
||||
|
||||
storeOnce.Do(func() {
|
||||
store, err := NewStore("default")
|
||||
if err != nil {
|
||||
remotelogs.Error("KV", "create default store failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
err = store.Open()
|
||||
if err != nil {
|
||||
remotelogs.Error("KV", "open default store failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
defaultSore = store
|
||||
})
|
||||
|
||||
return defaultSore, nil
|
||||
}
|
||||
|
||||
func (this *Store) Open() error {
|
||||
var opt = &pebble.Options{
|
||||
Logger: NewLogger(),
|
||||
@@ -144,6 +170,10 @@ func (this *Store) RawDB() *pebble.DB {
|
||||
return this.rawDB
|
||||
}
|
||||
|
||||
func (this *Store) Flush() error {
|
||||
return this.rawDB.Flush()
|
||||
}
|
||||
|
||||
func (this *Store) Close() error {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/kvstore"
|
||||
"github.com/cockroachdb/pebble"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -19,6 +21,44 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Default(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
|
||||
store, err := kvstore.DefaultStore()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a.IsTrue(store != nil)
|
||||
}
|
||||
|
||||
func TestStore_Default_Concurrent(t *testing.T) {
|
||||
var lastStore *kvstore.Store
|
||||
|
||||
const threads = 32
|
||||
|
||||
var wg = &sync.WaitGroup{}
|
||||
wg.Add(threads)
|
||||
for i := 0; i < threads; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
store, err := kvstore.DefaultStore()
|
||||
if err != nil {
|
||||
t.Log("ERROR", err)
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
if lastStore != nil && lastStore != store {
|
||||
t.Log("ERROR", "should be single instance")
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
lastStore = store
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestStore_Open(t *testing.T) {
|
||||
store, err := kvstore.OpenStore("test")
|
||||
if err != nil {
|
||||
@@ -29,6 +69,7 @@ func TestStore_Open(t *testing.T) {
|
||||
}()
|
||||
|
||||
t.Log("opened")
|
||||
_ = store
|
||||
}
|
||||
|
||||
func TestStore_RawDB(t *testing.T) {
|
||||
|
||||
@@ -64,6 +64,10 @@ func (this *Table[T]) DB() *DB {
|
||||
return this.db
|
||||
}
|
||||
|
||||
func (this *Table[T]) Encoder() ValueEncoder[T] {
|
||||
return this.encoder
|
||||
}
|
||||
|
||||
func (this *Table[T]) Set(key string, value T) error {
|
||||
if len(key) > KeyMaxLength {
|
||||
return ErrKeyTooLong
|
||||
@@ -75,7 +79,22 @@ func (this *Table[T]) Set(key string, value T) error {
|
||||
}
|
||||
|
||||
return this.WriteTx(func(tx *Tx[T]) error {
|
||||
return this.set(tx, key, valueBytes, value, false)
|
||||
return this.set(tx, key, valueBytes, value, false, false)
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Table[T]) SetSync(key string, value T) error {
|
||||
if len(key) > KeyMaxLength {
|
||||
return ErrKeyTooLong
|
||||
}
|
||||
|
||||
valueBytes, err := this.encoder.Encode(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return this.WriteTxSync(func(tx *Tx[T]) error {
|
||||
return this.set(tx, key, valueBytes, value, false, true)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -90,7 +109,7 @@ func (this *Table[T]) Insert(key string, value T) error {
|
||||
}
|
||||
|
||||
return this.WriteTx(func(tx *Tx[T]) error {
|
||||
return this.set(tx, key, valueBytes, value, true)
|
||||
return this.set(tx, key, valueBytes, value, true, false)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -111,7 +130,7 @@ func (this *Table[T]) ComposeFieldKey(keyBytes []byte, fieldName string, fieldVa
|
||||
func (this *Table[T]) Exist(key string) (found bool, err error) {
|
||||
_, closer, err := this.db.store.rawDB.Get(this.FullKey(key))
|
||||
if err != nil {
|
||||
if IsKeyNotFound(err) {
|
||||
if IsNotFound(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
@@ -173,6 +192,20 @@ func (this *Table[T]) WriteTx(fn func(tx *Tx[T]) error) error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (this *Table[T]) WriteTxSync(fn func(tx *Tx[T]) error) error {
|
||||
var tx = NewTx[T](this, false)
|
||||
defer func() {
|
||||
_ = tx.Close()
|
||||
}()
|
||||
|
||||
err := fn(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.CommitSync()
|
||||
}
|
||||
|
||||
func (this *Table[T]) Truncate() error {
|
||||
this.mu.Lock()
|
||||
defer this.mu.Unlock()
|
||||
@@ -256,7 +289,7 @@ func (this *Table[T]) deleteKeys(tx *Tx[T], key ...string) error {
|
||||
if len(this.fieldNames) > 0 {
|
||||
valueBytes, closer, getErr := batch.Get(keyBytes)
|
||||
if getErr != nil {
|
||||
if IsKeyNotFound(getErr) {
|
||||
if IsNotFound(getErr) {
|
||||
return nil
|
||||
}
|
||||
return getErr
|
||||
@@ -298,8 +331,12 @@ func (this *Table[T]) deleteKeys(tx *Tx[T], key ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, insertOnly bool) error {
|
||||
func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, insertOnly bool, syncMode bool) error {
|
||||
var keyBytes = this.FullKey(key)
|
||||
var writeOptions = DefaultWriteOptions
|
||||
if syncMode {
|
||||
writeOptions = DefaultWriteSyncOptions
|
||||
}
|
||||
|
||||
var batch = tx.batch
|
||||
|
||||
@@ -312,7 +349,7 @@ func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, ins
|
||||
if countFields > 0 {
|
||||
oldValueBytes, closer, getErr := batch.Get(keyBytes)
|
||||
if getErr != nil {
|
||||
if !IsKeyNotFound(getErr) {
|
||||
if !IsNotFound(getErr) {
|
||||
return getErr
|
||||
}
|
||||
} else {
|
||||
@@ -330,7 +367,7 @@ func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, ins
|
||||
}
|
||||
}
|
||||
|
||||
setErr := batch.Set(keyBytes, valueBytes, DefaultWriteOptions)
|
||||
setErr := batch.Set(keyBytes, valueBytes, writeOptions)
|
||||
if setErr != nil {
|
||||
return setErr
|
||||
}
|
||||
@@ -362,14 +399,14 @@ func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, ins
|
||||
// skip the field
|
||||
continue
|
||||
}
|
||||
deleteFieldErr := batch.Delete(oldFieldKeyBytes, DefaultWriteOptions)
|
||||
deleteFieldErr := batch.Delete(oldFieldKeyBytes, writeOptions)
|
||||
if deleteFieldErr != nil {
|
||||
return deleteFieldErr
|
||||
}
|
||||
}
|
||||
|
||||
// set new field key
|
||||
setFieldErr := batch.Set(newFieldKeyBytes, nil, DefaultWriteOptions)
|
||||
setFieldErr := batch.Set(newFieldKeyBytes, nil, writeOptions)
|
||||
if setFieldErr != nil {
|
||||
return setFieldErr
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ func (this *CounterTable[T]) Increase(key string, delta T) (newValue T, err erro
|
||||
err = this.Table.WriteTx(func(tx *Tx[T]) error {
|
||||
value, getErr := tx.Get(key)
|
||||
if getErr != nil {
|
||||
if !IsKeyNotFound(getErr) {
|
||||
if !IsNotFound(getErr) {
|
||||
return getErr
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func TestTable_Set(t *testing.T) {
|
||||
|
||||
value, err := table.Get("a")
|
||||
if err != nil {
|
||||
if kvstore.IsKeyNotFound(err) {
|
||||
if kvstore.IsNotFound(err) {
|
||||
t.Log("not found key")
|
||||
return
|
||||
}
|
||||
@@ -81,7 +81,7 @@ func TestTable_Get(t *testing.T) {
|
||||
for _, key := range []string{"a", "b", "c"} {
|
||||
value, getErr := table.Get(key)
|
||||
if getErr != nil {
|
||||
if kvstore.IsKeyNotFound(getErr) {
|
||||
if kvstore.IsNotFound(getErr) {
|
||||
t.Log("not found key", key)
|
||||
continue
|
||||
}
|
||||
@@ -146,7 +146,7 @@ func TestTable_Delete(t *testing.T) {
|
||||
|
||||
value, err := table.Get("a123")
|
||||
if err != nil {
|
||||
if !kvstore.IsKeyNotFound(err) {
|
||||
if !kvstore.IsNotFound(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
@@ -173,7 +173,7 @@ func TestTable_Delete(t *testing.T) {
|
||||
|
||||
{
|
||||
_, err = table.Get("a123")
|
||||
a.IsTrue(kvstore.IsKeyNotFound(err))
|
||||
a.IsTrue(kvstore.IsNotFound(err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -357,7 +357,7 @@ func BenchmarkTable_Get(b *testing.B) {
|
||||
for pb.Next() {
|
||||
_, putErr := table.Get(types.String(rand.Int()))
|
||||
if putErr != nil {
|
||||
if kvstore.IsKeyNotFound(putErr) {
|
||||
if kvstore.IsNotFound(putErr) {
|
||||
continue
|
||||
}
|
||||
b.Fatal(putErr)
|
||||
|
||||
@@ -37,7 +37,24 @@ func (this *Tx[T]) Set(key string, value T) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return this.table.set(this, key, valueBytes, value, false)
|
||||
return this.table.set(this, key, valueBytes, value, false, false)
|
||||
}
|
||||
|
||||
func (this *Tx[T]) SetSync(key string, value T) error {
|
||||
if this.readOnly {
|
||||
return errors.New("can not set value in readonly transaction")
|
||||
}
|
||||
|
||||
if len(key) > KeyMaxLength {
|
||||
return ErrKeyTooLong
|
||||
}
|
||||
|
||||
valueBytes, err := this.table.encoder.Encode(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return this.table.set(this, key, valueBytes, value, false, true)
|
||||
}
|
||||
|
||||
func (this *Tx[T]) Insert(key string, value T) error {
|
||||
@@ -54,7 +71,7 @@ func (this *Tx[T]) Insert(key string, value T) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return this.table.set(this, key, valueBytes, value, true)
|
||||
return this.table.set(this, key, valueBytes, value, true, false)
|
||||
}
|
||||
|
||||
func (this *Tx[T]) Get(key string) (value T, err error) {
|
||||
@@ -78,6 +95,20 @@ func (this *Tx[T]) Close() error {
|
||||
}
|
||||
|
||||
func (this *Tx[T]) Commit() (err error) {
|
||||
return this.commit(DefaultWriteOptions)
|
||||
}
|
||||
|
||||
func (this *Tx[T]) CommitSync() (err error) {
|
||||
return this.commit(DefaultWriteSyncOptions)
|
||||
}
|
||||
|
||||
func (this *Tx[T]) Query() *Query[T] {
|
||||
var query = NewQuery[T]()
|
||||
query.SetTx(this)
|
||||
return query
|
||||
}
|
||||
|
||||
func (this *Tx[T]) commit(opt *pebble.WriteOptions) (err error) {
|
||||
defer func() {
|
||||
var panicErr = recover()
|
||||
if panicErr != nil {
|
||||
@@ -88,11 +119,5 @@ func (this *Tx[T]) Commit() (err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
return this.batch.Commit(DefaultWriteOptions)
|
||||
}
|
||||
|
||||
func (this *Tx[T]) Query() *Query[T] {
|
||||
var query = NewQuery[T]()
|
||||
query.SetTx(this)
|
||||
return query
|
||||
return this.batch.Commit(opt)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,11 @@
|
||||
|
||||
package testutils
|
||||
|
||||
import "os"
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
)
|
||||
|
||||
// IsSingleTesting 判断当前测试环境是否为单个函数测试
|
||||
func IsSingleTesting() bool {
|
||||
@@ -12,4 +16,9 @@ func IsSingleTesting() bool {
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// RandIP 生成一个随机IP用于测试
|
||||
func RandIP() string {
|
||||
return fmt.Sprintf("%d.%d.%d.%d", rand.Int()%255, rand.Int()%255, rand.Int()%255, rand.Int()%255)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user