mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 07:40:56 +08:00 
			
		
		
		
	使用KV数据库来管理IP名单
This commit is contained in:
		@@ -97,7 +97,7 @@ func (this *KVListFileStore) ExistItem(hash string) (bool, error) {
 | 
			
		||||
 | 
			
		||||
	item, err := this.itemsTable.Get(hash)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if kvstore.IsKeyNotFound(err) {
 | 
			
		||||
		if kvstore.IsNotFound(err) {
 | 
			
		||||
			return false, nil
 | 
			
		||||
		}
 | 
			
		||||
		return false, err
 | 
			
		||||
 
 | 
			
		||||
@@ -1,305 +1,13 @@
 | 
			
		||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
 | 
			
		||||
 | 
			
		||||
package iplibrary
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/goman"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils/dbs"
 | 
			
		||||
	"github.com/iwind/TeaGo/Tea"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
import "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
			
		||||
 | 
			
		||||
type IPListDB struct {
 | 
			
		||||
	db *dbs.DB
 | 
			
		||||
 | 
			
		||||
	itemTableName    string
 | 
			
		||||
	versionTableName string
 | 
			
		||||
 | 
			
		||||
	deleteExpiredItemsStmt   *dbs.Stmt
 | 
			
		||||
	deleteItemStmt           *dbs.Stmt
 | 
			
		||||
	insertItemStmt           *dbs.Stmt
 | 
			
		||||
	selectItemsStmt          *dbs.Stmt
 | 
			
		||||
	selectMaxItemVersionStmt *dbs.Stmt
 | 
			
		||||
 | 
			
		||||
	selectVersionStmt *dbs.Stmt
 | 
			
		||||
	updateVersionStmt *dbs.Stmt
 | 
			
		||||
 | 
			
		||||
	cleanTicker *time.Ticker
 | 
			
		||||
 | 
			
		||||
	dir string
 | 
			
		||||
 | 
			
		||||
	isClosed bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewIPListDB() (*IPListDB, error) {
 | 
			
		||||
	var db = &IPListDB{
 | 
			
		||||
		itemTableName:    "ipItems",
 | 
			
		||||
		versionTableName: "versions",
 | 
			
		||||
		dir:              filepath.Clean(Tea.Root + "/data"),
 | 
			
		||||
		cleanTicker:      time.NewTicker(24 * time.Hour),
 | 
			
		||||
	}
 | 
			
		||||
	err := db.init()
 | 
			
		||||
	return db, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPListDB) init() error {
 | 
			
		||||
	// 检查目录是否存在
 | 
			
		||||
	_, err := os.Stat(this.dir)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		err = os.MkdirAll(this.dir, 0777)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		remotelogs.Println("IP_LIST_DB", "create data dir '"+this.dir+"'")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var path = this.dir + "/ip_list.db"
 | 
			
		||||
 | 
			
		||||
	db, err := dbs.OpenWriter("file:" + path + "?cache=shared&mode=rwc&_journal_mode=WAL&_sync=" + dbs.SyncMode + "&_locking_mode=EXCLUSIVE")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	db.SetMaxOpenConns(1)
 | 
			
		||||
 | 
			
		||||
	//_, err = db.Exec("VACUUM")
 | 
			
		||||
	//if err != nil {
 | 
			
		||||
	//	return err
 | 
			
		||||
	//}
 | 
			
		||||
 | 
			
		||||
	this.db = db
 | 
			
		||||
 | 
			
		||||
	// 恢复数据库
 | 
			
		||||
	var recoverEnv, _ = os.LookupEnv("EdgeRecover")
 | 
			
		||||
	if len(recoverEnv) > 0 {
 | 
			
		||||
		for _, indexName := range []string{"ip_list_itemId", "ip_list_expiredAt"} {
 | 
			
		||||
			_, _ = db.Exec(`REINDEX "` + indexName + `"`)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 初始化数据库
 | 
			
		||||
	_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemTableName + `" (
 | 
			
		||||
  "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
 | 
			
		||||
  "listId" integer DEFAULT 0,
 | 
			
		||||
  "listType" varchar(32),
 | 
			
		||||
  "isGlobal" integer(1) DEFAULT 0,
 | 
			
		||||
  "type" varchar(16),
 | 
			
		||||
  "itemId" integer DEFAULT 0,
 | 
			
		||||
  "ipFrom" varchar(64) DEFAULT 0,
 | 
			
		||||
  "ipTo" varchar(64) DEFAULT 0,
 | 
			
		||||
  "expiredAt" integer DEFAULT 0,
 | 
			
		||||
  "eventLevel" varchar(32),
 | 
			
		||||
  "isDeleted" integer(1) DEFAULT 0,
 | 
			
		||||
  "version" integer DEFAULT 0,
 | 
			
		||||
  "nodeId" integer DEFAULT 0,
 | 
			
		||||
  "serverId" integer DEFAULT 0
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
CREATE INDEX IF NOT EXISTS "ip_list_itemId"
 | 
			
		||||
ON "` + this.itemTableName + `" (
 | 
			
		||||
  "itemId" ASC
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
CREATE INDEX IF NOT EXISTS "ip_list_expiredAt"
 | 
			
		||||
ON "` + this.itemTableName + `" (
 | 
			
		||||
  "expiredAt" ASC
 | 
			
		||||
);
 | 
			
		||||
`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.versionTableName + `" (
 | 
			
		||||
  "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
 | 
			
		||||
  "version" integer DEFAULT 0
 | 
			
		||||
);
 | 
			
		||||
`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 初始化SQL语句
 | 
			
		||||
	this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE  "expiredAt">0 AND "expiredAt"<?`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.deleteItemStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "itemId"=?`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.insertItemStmt, err = this.db.Prepare(`INSERT INTO "` + this.itemTableName + `" ("listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" WHERE isDeleted=0 ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.selectMaxItemVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.selectVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.versionTableName + `" LIMIT 1`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.updateVersionStmt, err = this.db.Prepare(`REPLACE INTO "` + this.versionTableName + `" ("id", "version") VALUES (1, ?)`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.db = db
 | 
			
		||||
 | 
			
		||||
	goman.New(func() {
 | 
			
		||||
		events.OnClose(func() {
 | 
			
		||||
			_ = this.Close()
 | 
			
		||||
			this.cleanTicker.Stop()
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		for range this.cleanTicker.C {
 | 
			
		||||
			err := this.DeleteExpiredItems()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+err.Error())
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteExpiredItems 删除过期的条目
 | 
			
		||||
func (this *IPListDB) DeleteExpiredItems() error {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err := this.deleteExpiredItemsStmt.Exec(time.Now().Unix() - 7*86400)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPListDB) AddItem(item *pb.IPItem) error {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err := this.deleteItemStmt.Exec(item.Id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果是删除,则不再创建新记录
 | 
			
		||||
	if item.IsDeleted {
 | 
			
		||||
		return this.UpdateMaxVersion(item.Version)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return this.UpdateMaxVersion(item.Version)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPListDB) ReadItems(offset int64, size int64) (items []*pb.IPItem, err error) {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := this.selectItemsStmt.Query(offset, size)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = rows.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		//  "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId"
 | 
			
		||||
		var pbItem = &pb.IPItem{}
 | 
			
		||||
		err = rows.Scan(&pbItem.ListId, &pbItem.ListType, &pbItem.IsGlobal, &pbItem.Type, &pbItem.Id, &pbItem.IpFrom, &pbItem.IpTo, &pbItem.ExpiredAt, &pbItem.EventLevel, &pbItem.IsDeleted, &pbItem.Version, &pbItem.NodeId, &pbItem.ServerId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		items = append(items, pbItem)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ReadMaxVersion 读取当前最大版本号
 | 
			
		||||
func (this *IPListDB) ReadMaxVersion() int64 {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// from version table
 | 
			
		||||
	{
 | 
			
		||||
		var row = this.selectVersionStmt.QueryRow()
 | 
			
		||||
		if row == nil {
 | 
			
		||||
			return 0
 | 
			
		||||
		}
 | 
			
		||||
		var version int64
 | 
			
		||||
		err := row.Scan(&version)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			return version
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// from items table
 | 
			
		||||
	{
 | 
			
		||||
		var row = this.selectMaxItemVersionStmt.QueryRow()
 | 
			
		||||
		if row == nil {
 | 
			
		||||
			return 0
 | 
			
		||||
		}
 | 
			
		||||
		var version int64
 | 
			
		||||
		err := row.Scan(&version)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return 0
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return version
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateMaxVersion 修改版本号
 | 
			
		||||
func (this *IPListDB) UpdateMaxVersion(version int64) error {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err := this.updateVersionStmt.Exec(version)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPListDB) Close() error {
 | 
			
		||||
	this.isClosed = true
 | 
			
		||||
 | 
			
		||||
	if this.db != nil {
 | 
			
		||||
		for _, stmt := range []*dbs.Stmt{
 | 
			
		||||
			this.deleteExpiredItemsStmt,
 | 
			
		||||
			this.deleteItemStmt,
 | 
			
		||||
			this.insertItemStmt,
 | 
			
		||||
			this.selectItemsStmt,
 | 
			
		||||
			this.selectMaxItemVersionStmt, // ipItems table
 | 
			
		||||
 | 
			
		||||
			this.selectVersionStmt, // versions table
 | 
			
		||||
			this.updateVersionStmt,
 | 
			
		||||
		} {
 | 
			
		||||
			if stmt != nil {
 | 
			
		||||
				_ = stmt.Close()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return this.db.Close()
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
type IPListDB interface {
 | 
			
		||||
	Name() string
 | 
			
		||||
	DeleteExpiredItems() error
 | 
			
		||||
	ReadMaxVersion() (int64, error)
 | 
			
		||||
	ReadItems(offset int64, size int64) (items []*pb.IPItem, goNext bool, err error)
 | 
			
		||||
	AddItem(item *pb.IPItem) error
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										229
									
								
								internal/iplibrary/ip_list_kv.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										229
									
								
								internal/iplibrary/ip_list_kv.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,229 @@
 | 
			
		||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
 | 
			
		||||
 | 
			
		||||
package iplibrary
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/goman"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils/kvstore"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type IPListKV struct {
 | 
			
		||||
	ipTable       *kvstore.Table[*pb.IPItem]
 | 
			
		||||
	versionsTable *kvstore.Table[int64]
 | 
			
		||||
 | 
			
		||||
	encoder *IPItemEncoder[*pb.IPItem]
 | 
			
		||||
 | 
			
		||||
	cleanTicker *time.Ticker
 | 
			
		||||
 | 
			
		||||
	isClosed bool
 | 
			
		||||
 | 
			
		||||
	offsetItemKey string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewIPListKV() (*IPListKV, error) {
 | 
			
		||||
	var db = &IPListKV{
 | 
			
		||||
		cleanTicker: time.NewTicker(24 * time.Hour),
 | 
			
		||||
		encoder:     &IPItemEncoder[*pb.IPItem]{},
 | 
			
		||||
	}
 | 
			
		||||
	err := db.init()
 | 
			
		||||
	return db, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPListKV) init() error {
 | 
			
		||||
	store, storeErr := kvstore.DefaultStore()
 | 
			
		||||
	if storeErr != nil {
 | 
			
		||||
		return storeErr
 | 
			
		||||
	}
 | 
			
		||||
	db, dbErr := store.NewDB("ip_list")
 | 
			
		||||
	if dbErr != nil {
 | 
			
		||||
		return dbErr
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		table, err := kvstore.NewTable[*pb.IPItem]("ip_items", this.encoder)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		this.ipTable = table
 | 
			
		||||
 | 
			
		||||
		err = table.AddFields("expiresAt")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		db.AddTable(table)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		table, err := kvstore.NewTable[int64]("versions", kvstore.NewIntValueEncoder[int64]())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		this.versionsTable = table
 | 
			
		||||
		db.AddTable(table)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	goman.New(func() {
 | 
			
		||||
		events.OnClose(func() {
 | 
			
		||||
			_ = this.Close()
 | 
			
		||||
			this.cleanTicker.Stop()
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		for range this.cleanTicker.C {
 | 
			
		||||
			deleteErr := this.DeleteExpiredItems()
 | 
			
		||||
			if deleteErr != nil {
 | 
			
		||||
				remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+deleteErr.Error())
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Name 数据库名称代号
 | 
			
		||||
func (this *IPListKV) Name() string {
 | 
			
		||||
	return "kvstore"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteExpiredItems 删除过期的条目
 | 
			
		||||
func (this *IPListKV) DeleteExpiredItems() error {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		var found bool
 | 
			
		||||
		var currentTime = fasttime.Now().Unix()
 | 
			
		||||
		err := this.ipTable.
 | 
			
		||||
			Query().
 | 
			
		||||
			FieldAsc("expiresAt").
 | 
			
		||||
			ForUpdate().
 | 
			
		||||
			Limit(1000).
 | 
			
		||||
			FindAll(func(tx *kvstore.Tx[*pb.IPItem], item kvstore.Item[*pb.IPItem]) (goNext bool, err error) {
 | 
			
		||||
				if !item.Value.IsDeleted && item.Value.ExpiredAt == 0 { // never expires
 | 
			
		||||
					return kvstore.Skip()
 | 
			
		||||
				}
 | 
			
		||||
				if item.Value.ExpiredAt < currentTime-7*86400 /** keep for 7 days **/ {
 | 
			
		||||
					err = tx.Delete(item.Key)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return false, err
 | 
			
		||||
					}
 | 
			
		||||
					found = true
 | 
			
		||||
					return true, nil
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				found = false
 | 
			
		||||
				return false, nil
 | 
			
		||||
			})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		if !found {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPListKV) AddItem(item *pb.IPItem) error {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 先删除
 | 
			
		||||
	var key = this.encoder.EncodeKey(item)
 | 
			
		||||
	err := this.ipTable.Delete(key)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果是删除,则不再创建新记录
 | 
			
		||||
	if item.IsDeleted {
 | 
			
		||||
		return this.UpdateMaxVersion(item.Version)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = this.ipTable.Set(key, item)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return this.UpdateMaxVersion(item.Version)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPListKV) ReadItems(offset int64, size int64) (items []*pb.IPItem, goNextLoop bool, err error) {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = this.ipTable.
 | 
			
		||||
		Query().
 | 
			
		||||
		Offset(this.offsetItemKey).
 | 
			
		||||
		Limit(int(size)).
 | 
			
		||||
		FindAll(func(tx *kvstore.Tx[*pb.IPItem], item kvstore.Item[*pb.IPItem]) (goNext bool, err error) {
 | 
			
		||||
			this.offsetItemKey = item.Key
 | 
			
		||||
			goNextLoop = true
 | 
			
		||||
 | 
			
		||||
			if !item.Value.IsDeleted {
 | 
			
		||||
				items = append(items, item.Value)
 | 
			
		||||
			}
 | 
			
		||||
			return true, nil
 | 
			
		||||
		})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ReadMaxVersion 读取当前最大版本号
 | 
			
		||||
func (this *IPListKV) ReadMaxVersion() (int64, error) {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return 0, errors.New("database has been closed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	version, err := this.versionsTable.Get("version")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if kvstore.IsNotFound(err) {
 | 
			
		||||
			return 0, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	return version, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateMaxVersion 修改版本号
 | 
			
		||||
func (this *IPListKV) UpdateMaxVersion(version int64) error {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return this.versionsTable.SetSync("version", version)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPListKV) TestInspect(t *testing.T) error {
 | 
			
		||||
	return this.ipTable.
 | 
			
		||||
		Query().
 | 
			
		||||
		FindAll(func(tx *kvstore.Tx[*pb.IPItem], item kvstore.Item[*pb.IPItem]) (goNext bool, err error) {
 | 
			
		||||
			if len(item.Key) != 8 {
 | 
			
		||||
				return false, errors.New("invalid key '" + item.Key + "'")
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			t.Log(binary.BigEndian.Uint64([]byte(item.Key)), "=>", item.Value)
 | 
			
		||||
			return true, nil
 | 
			
		||||
		})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Flush to disk
 | 
			
		||||
func (this *IPListKV) Flush() error {
 | 
			
		||||
	return this.ipTable.DB().Store().Flush()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPListKV) Close() error {
 | 
			
		||||
	this.isClosed = true
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										55
									
								
								internal/iplibrary/ip_list_kv_objects.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								internal/iplibrary/ip_list_kv_objects.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,55 @@
 | 
			
		||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
 | 
			
		||||
 | 
			
		||||
package iplibrary
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
			
		||||
	"google.golang.org/protobuf/proto"
 | 
			
		||||
	"math"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type IPItemEncoder[T interface{ *pb.IPItem }] struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewIPItemEncoder[T interface{ *pb.IPItem }]() *IPItemEncoder[T] {
 | 
			
		||||
	return &IPItemEncoder[T]{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPItemEncoder[T]) Encode(value T) ([]byte, error) {
 | 
			
		||||
	return proto.Marshal(any(value).(*pb.IPItem))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPItemEncoder[T]) EncodeField(value T, fieldName string) ([]byte, error) {
 | 
			
		||||
	switch fieldName {
 | 
			
		||||
	case "expiresAt":
 | 
			
		||||
		var expiresAt = any(value).(*pb.IPItem).ExpiredAt
 | 
			
		||||
		if expiresAt < 0 || expiresAt > int64(math.MaxUint32) {
 | 
			
		||||
			expiresAt = 0
 | 
			
		||||
		}
 | 
			
		||||
		var b = make([]byte, 4)
 | 
			
		||||
		binary.BigEndian.PutUint32(b, uint32(expiresAt))
 | 
			
		||||
		return b, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, errors.New("field '" + fieldName + "' not found")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPItemEncoder[T]) Decode(valueBytes []byte) (value T, err error) {
 | 
			
		||||
	var item = &pb.IPItem{}
 | 
			
		||||
	err = proto.Unmarshal(valueBytes, item)
 | 
			
		||||
	value = item
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// EncodeKey generate key for ip item
 | 
			
		||||
func (this *IPItemEncoder[T]) EncodeKey(item *pb.IPItem) string {
 | 
			
		||||
	var b = make([]byte, 8)
 | 
			
		||||
	if item.Id < 0 {
 | 
			
		||||
		item.Id = 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	binary.BigEndian.PutUint64(b, uint64(item.Id))
 | 
			
		||||
	return string(b)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										221
									
								
								internal/iplibrary/ip_list_kv_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										221
									
								
								internal/iplibrary/ip_list_kv_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,221 @@
 | 
			
		||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
 | 
			
		||||
 | 
			
		||||
package iplibrary_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/zero"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestIPListKV_AddItem(t *testing.T) {
 | 
			
		||||
	kv, err := iplibrary.NewIPListKV()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = kv.Flush()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		err = kv.AddItem(&pb.IPItem{
 | 
			
		||||
			Id:        1,
 | 
			
		||||
			IpFrom:    "192.168.1.101",
 | 
			
		||||
			IpTo:      "",
 | 
			
		||||
			Version:   1,
 | 
			
		||||
			ExpiredAt: fasttime.NewFastTime().Unix() + 60,
 | 
			
		||||
			ListId:    1,
 | 
			
		||||
			IsDeleted: false,
 | 
			
		||||
			ListType:  "white",
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		err = kv.AddItem(&pb.IPItem{
 | 
			
		||||
			Id:        2,
 | 
			
		||||
			IpFrom:    "192.168.1.102",
 | 
			
		||||
			IpTo:      "",
 | 
			
		||||
			Version:   2,
 | 
			
		||||
			ExpiredAt: fasttime.NewFastTime().Unix() + 60,
 | 
			
		||||
			ListId:    1,
 | 
			
		||||
			IsDeleted: false,
 | 
			
		||||
			ListType:  "white",
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		err = kv.AddItem(&pb.IPItem{
 | 
			
		||||
			Id:        3,
 | 
			
		||||
			IpFrom:    "192.168.1.103",
 | 
			
		||||
			IpTo:      "",
 | 
			
		||||
			Version:   3,
 | 
			
		||||
			ExpiredAt: fasttime.NewFastTime().Unix() + 60,
 | 
			
		||||
			ListId:    1,
 | 
			
		||||
			IsDeleted: false,
 | 
			
		||||
			ListType:  "white",
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIPListKV_AddItems_Many(t *testing.T) {
 | 
			
		||||
	kv, err := iplibrary.NewIPListKV()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = kv.Flush()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	var count = 2
 | 
			
		||||
	var from = 1
 | 
			
		||||
	if testutils.IsSingleTesting() {
 | 
			
		||||
		count = 2_000_000
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var before = time.Now()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		t.Logf("cost: %.2f s", time.Since(before).Seconds())
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for i := from; i <= from+count; i++ {
 | 
			
		||||
		err = kv.AddItem(&pb.IPItem{
 | 
			
		||||
			Id:        int64(i),
 | 
			
		||||
			IpFrom:    testutils.RandIP(),
 | 
			
		||||
			IpTo:      "",
 | 
			
		||||
			Version:   int64(i),
 | 
			
		||||
			ExpiredAt: fasttime.NewFastTime().Unix() + 86400,
 | 
			
		||||
			ListId:    1,
 | 
			
		||||
			IsDeleted: false,
 | 
			
		||||
			ListType:  "white",
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIPListKV_DeleteExpiredItems(t *testing.T) {
 | 
			
		||||
	kv, err := iplibrary.NewIPListKV()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = kv.Flush()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	err = kv.DeleteExpiredItems()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIPListKV_UpdateMaxVersion(t *testing.T) {
 | 
			
		||||
	kv, err := iplibrary.NewIPListKV()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = kv.Flush()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	err = kv.UpdateMaxVersion(101)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	maxVersion, err := kv.ReadMaxVersion()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Log("version:", maxVersion)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIPListKV_ReadMaxVersion(t *testing.T) {
 | 
			
		||||
	kv, err := iplibrary.NewIPListKV()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	maxVersion, err := kv.ReadMaxVersion()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Log("version:", maxVersion)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIPListKV_ReadItems(t *testing.T) {
 | 
			
		||||
	kv, err := iplibrary.NewIPListKV()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		items, goNext, readErr := kv.ReadItems(0, 2)
 | 
			
		||||
		if readErr != nil {
 | 
			
		||||
			t.Fatal(readErr)
 | 
			
		||||
		}
 | 
			
		||||
		t.Log("====")
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			t.Log(item.Id)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !goNext {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIPListKV_CountItems(t *testing.T) {
 | 
			
		||||
	kv, err := iplibrary.NewIPListKV()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var count int
 | 
			
		||||
	var m = map[int64]zero.Zero{}
 | 
			
		||||
	for {
 | 
			
		||||
		items, goNext, readErr := kv.ReadItems(0, 1000)
 | 
			
		||||
		if readErr != nil {
 | 
			
		||||
			t.Fatal(readErr)
 | 
			
		||||
		}
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			count++
 | 
			
		||||
			m[item.Id] = zero.Zero{}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !goNext {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	t.Log("count:", count, "len:", len(m))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIPListKV_Inspect(t *testing.T) {
 | 
			
		||||
	kv, err := iplibrary.NewIPListKV()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	err = kv.TestInspect(t)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										312
									
								
								internal/iplibrary/ip_list_sqlite.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										312
									
								
								internal/iplibrary/ip_list_sqlite.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,312 @@
 | 
			
		||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
 | 
			
		||||
package iplibrary
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/goman"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils/dbs"
 | 
			
		||||
	"github.com/iwind/TeaGo/Tea"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type IPListSQLite struct {
 | 
			
		||||
	db *dbs.DB
 | 
			
		||||
 | 
			
		||||
	itemTableName    string
 | 
			
		||||
	versionTableName string
 | 
			
		||||
 | 
			
		||||
	deleteExpiredItemsStmt   *dbs.Stmt
 | 
			
		||||
	deleteItemStmt           *dbs.Stmt
 | 
			
		||||
	insertItemStmt           *dbs.Stmt
 | 
			
		||||
	selectItemsStmt          *dbs.Stmt
 | 
			
		||||
	selectMaxItemVersionStmt *dbs.Stmt
 | 
			
		||||
 | 
			
		||||
	selectVersionStmt *dbs.Stmt
 | 
			
		||||
	updateVersionStmt *dbs.Stmt
 | 
			
		||||
 | 
			
		||||
	cleanTicker *time.Ticker
 | 
			
		||||
 | 
			
		||||
	dir string
 | 
			
		||||
 | 
			
		||||
	isClosed bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewIPListSqlite() (*IPListSQLite, error) {
 | 
			
		||||
	var db = &IPListSQLite{
 | 
			
		||||
		itemTableName:    "ipItems",
 | 
			
		||||
		versionTableName: "versions",
 | 
			
		||||
		dir:              filepath.Clean(Tea.Root + "/data"),
 | 
			
		||||
		cleanTicker:      time.NewTicker(24 * time.Hour),
 | 
			
		||||
	}
 | 
			
		||||
	err := db.init()
 | 
			
		||||
	return db, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPListSQLite) init() error {
 | 
			
		||||
	// 检查目录是否存在
 | 
			
		||||
	_, err := os.Stat(this.dir)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		err = os.MkdirAll(this.dir, 0777)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		remotelogs.Println("IP_LIST_DB", "create data dir '"+this.dir+"'")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var path = this.dir + "/ip_list.db"
 | 
			
		||||
 | 
			
		||||
	db, err := dbs.OpenWriter("file:" + path + "?cache=shared&mode=rwc&_journal_mode=WAL&_sync=" + dbs.SyncMode + "&_locking_mode=EXCLUSIVE")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	db.SetMaxOpenConns(1)
 | 
			
		||||
 | 
			
		||||
	//_, err = db.Exec("VACUUM")
 | 
			
		||||
	//if err != nil {
 | 
			
		||||
	//	return err
 | 
			
		||||
	//}
 | 
			
		||||
 | 
			
		||||
	this.db = db
 | 
			
		||||
 | 
			
		||||
	// 恢复数据库
 | 
			
		||||
	var recoverEnv, _ = os.LookupEnv("EdgeRecover")
 | 
			
		||||
	if len(recoverEnv) > 0 {
 | 
			
		||||
		for _, indexName := range []string{"ip_list_itemId", "ip_list_expiredAt"} {
 | 
			
		||||
			_, _ = db.Exec(`REINDEX "` + indexName + `"`)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 初始化数据库
 | 
			
		||||
	_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemTableName + `" (
 | 
			
		||||
  "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
 | 
			
		||||
  "listId" integer DEFAULT 0,
 | 
			
		||||
  "listType" varchar(32),
 | 
			
		||||
  "isGlobal" integer(1) DEFAULT 0,
 | 
			
		||||
  "type" varchar(16),
 | 
			
		||||
  "itemId" integer DEFAULT 0,
 | 
			
		||||
  "ipFrom" varchar(64) DEFAULT 0,
 | 
			
		||||
  "ipTo" varchar(64) DEFAULT 0,
 | 
			
		||||
  "expiredAt" integer DEFAULT 0,
 | 
			
		||||
  "eventLevel" varchar(32),
 | 
			
		||||
  "isDeleted" integer(1) DEFAULT 0,
 | 
			
		||||
  "version" integer DEFAULT 0,
 | 
			
		||||
  "nodeId" integer DEFAULT 0,
 | 
			
		||||
  "serverId" integer DEFAULT 0
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
CREATE INDEX IF NOT EXISTS "ip_list_itemId"
 | 
			
		||||
ON "` + this.itemTableName + `" (
 | 
			
		||||
  "itemId" ASC
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
CREATE INDEX IF NOT EXISTS "ip_list_expiredAt"
 | 
			
		||||
ON "` + this.itemTableName + `" (
 | 
			
		||||
  "expiredAt" ASC
 | 
			
		||||
);
 | 
			
		||||
`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.versionTableName + `" (
 | 
			
		||||
  "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
 | 
			
		||||
  "version" integer DEFAULT 0
 | 
			
		||||
);
 | 
			
		||||
`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 初始化SQL语句
 | 
			
		||||
	this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE  "expiredAt">0 AND "expiredAt"<?`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.deleteItemStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "itemId"=?`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.insertItemStmt, err = this.db.Prepare(`INSERT INTO "` + this.itemTableName + `" ("listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" WHERE isDeleted=0 ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.selectMaxItemVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.selectVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.versionTableName + `" LIMIT 1`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.updateVersionStmt, err = this.db.Prepare(`REPLACE INTO "` + this.versionTableName + `" ("id", "version") VALUES (1, ?)`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.db = db
 | 
			
		||||
 | 
			
		||||
	goman.New(func() {
 | 
			
		||||
		events.OnClose(func() {
 | 
			
		||||
			_ = this.Close()
 | 
			
		||||
			this.cleanTicker.Stop()
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		for range this.cleanTicker.C {
 | 
			
		||||
			err := this.DeleteExpiredItems()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+err.Error())
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Name 数据库名称代号
 | 
			
		||||
func (this *IPListSQLite) Name() string {
 | 
			
		||||
	return "sqlite"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteExpiredItems 删除过期的条目
 | 
			
		||||
func (this *IPListSQLite) DeleteExpiredItems() error {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err := this.deleteExpiredItemsStmt.Exec(time.Now().Unix() - 7*86400)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPListSQLite) AddItem(item *pb.IPItem) error {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err := this.deleteItemStmt.Exec(item.Id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果是删除,则不再创建新记录
 | 
			
		||||
	if item.IsDeleted {
 | 
			
		||||
		return this.UpdateMaxVersion(item.Version)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return this.UpdateMaxVersion(item.Version)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPListSQLite) ReadItems(offset int64, size int64) (items []*pb.IPItem, goNext bool, err error) {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := this.selectItemsStmt.Query(offset, size)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, false, err
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = rows.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		//  "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId"
 | 
			
		||||
		var pbItem = &pb.IPItem{}
 | 
			
		||||
		err = rows.Scan(&pbItem.ListId, &pbItem.ListType, &pbItem.IsGlobal, &pbItem.Type, &pbItem.Id, &pbItem.IpFrom, &pbItem.IpTo, &pbItem.ExpiredAt, &pbItem.EventLevel, &pbItem.IsDeleted, &pbItem.Version, &pbItem.NodeId, &pbItem.ServerId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, false, err
 | 
			
		||||
		}
 | 
			
		||||
		items = append(items, pbItem)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	goNext = int64(len(items)) == size
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ReadMaxVersion 读取当前最大版本号
 | 
			
		||||
func (this *IPListSQLite) ReadMaxVersion() (int64, error) {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return 0, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// from version table
 | 
			
		||||
	{
 | 
			
		||||
		var row = this.selectVersionStmt.QueryRow()
 | 
			
		||||
		if row == nil {
 | 
			
		||||
			return 0, nil
 | 
			
		||||
		}
 | 
			
		||||
		var version int64
 | 
			
		||||
		err := row.Scan(&version)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			return version, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// from items table
 | 
			
		||||
	{
 | 
			
		||||
		var row = this.selectMaxItemVersionStmt.QueryRow()
 | 
			
		||||
		if row == nil {
 | 
			
		||||
			return 0, nil
 | 
			
		||||
		}
 | 
			
		||||
		var version int64
 | 
			
		||||
		err := row.Scan(&version)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return 0, nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return version, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateMaxVersion 修改版本号
 | 
			
		||||
func (this *IPListSQLite) UpdateMaxVersion(version int64) error {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err := this.updateVersionStmt.Exec(version)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *IPListSQLite) Close() error {
 | 
			
		||||
	this.isClosed = true
 | 
			
		||||
 | 
			
		||||
	if this.db != nil {
 | 
			
		||||
		for _, stmt := range []*dbs.Stmt{
 | 
			
		||||
			this.deleteExpiredItemsStmt,
 | 
			
		||||
			this.deleteItemStmt,
 | 
			
		||||
			this.insertItemStmt,
 | 
			
		||||
			this.selectItemsStmt,
 | 
			
		||||
			this.selectMaxItemVersionStmt, // ipItems table
 | 
			
		||||
 | 
			
		||||
			this.selectVersionStmt, // versions table
 | 
			
		||||
			this.updateVersionStmt,
 | 
			
		||||
		} {
 | 
			
		||||
			if stmt != nil {
 | 
			
		||||
				_ = stmt.Close()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return this.db.Close()
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -12,7 +12,7 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestIPListDB_AddItem(t *testing.T) {
 | 
			
		||||
	db, err := iplibrary.NewIPListDB()
 | 
			
		||||
	db, err := iplibrary.NewIPListSqlite()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -59,7 +59,7 @@ func TestIPListDB_AddItem(t *testing.T) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIPListDB_ReadItems(t *testing.T) {
 | 
			
		||||
	db, err := iplibrary.NewIPListDB()
 | 
			
		||||
	db, err := iplibrary.NewIPListSqlite()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -71,15 +71,16 @@ func TestIPListDB_ReadItems(t *testing.T) {
 | 
			
		||||
		_ = db.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	items, err := db.ReadItems(0, 2)
 | 
			
		||||
	items, goNext, err := db.ReadItems(0, 2)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	t.Log("goNext:", goNext)
 | 
			
		||||
	logs.PrintAsJSON(items, t)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIPListDB_ReadMaxVersion(t *testing.T) {
 | 
			
		||||
	db, err := iplibrary.NewIPListDB()
 | 
			
		||||
	db, err := iplibrary.NewIPListSqlite()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -90,7 +91,7 @@ func TestIPListDB_ReadMaxVersion(t *testing.T) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIPListDB_UpdateMaxVersion(t *testing.T) {
 | 
			
		||||
	db, err := iplibrary.NewIPListDB()
 | 
			
		||||
	db, err := iplibrary.NewIPListSqlite()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -8,10 +8,13 @@ import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/goman"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/rpc"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/trackers"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/waf"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/zero"
 | 
			
		||||
	"github.com/iwind/TeaGo/Tea"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	"os"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
@@ -45,7 +48,7 @@ func init() {
 | 
			
		||||
type IPListManager struct {
 | 
			
		||||
	ticker *time.Ticker
 | 
			
		||||
 | 
			
		||||
	db *IPListDB
 | 
			
		||||
	db IPListDB
 | 
			
		||||
 | 
			
		||||
	lastVersion   int64
 | 
			
		||||
	fetchPageSize int64
 | 
			
		||||
@@ -83,7 +86,7 @@ func (this *IPListManager) Start() {
 | 
			
		||||
		case <-this.ticker.C:
 | 
			
		||||
		case <-IPListUpdateNotify:
 | 
			
		||||
		}
 | 
			
		||||
		err := this.loop()
 | 
			
		||||
		err = this.loop()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			countErrors++
 | 
			
		||||
 | 
			
		||||
@@ -110,7 +113,18 @@ func (this *IPListManager) Stop() {
 | 
			
		||||
 | 
			
		||||
func (this *IPListManager) init() {
 | 
			
		||||
	// 从数据库中当中读取数据
 | 
			
		||||
	db, err := NewIPListDB()
 | 
			
		||||
	// 检查sqlite文件是否存在,以便决定使用sqlite还是kv
 | 
			
		||||
	var sqlitePath = Tea.Root + "/data/ip_list.db"
 | 
			
		||||
	_, sqliteErr := os.Stat(sqlitePath)
 | 
			
		||||
 | 
			
		||||
	var db IPListDB
 | 
			
		||||
	var err error
 | 
			
		||||
	if sqliteErr == nil {
 | 
			
		||||
		db, err = NewIPListSqlite()
 | 
			
		||||
	} else {
 | 
			
		||||
		db, err = NewIPListKV()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		remotelogs.Error("IP_LIST_MANAGER", "create ip list local database failed: "+err.Error())
 | 
			
		||||
	} else {
 | 
			
		||||
@@ -120,24 +134,30 @@ func (this *IPListManager) init() {
 | 
			
		||||
		_ = db.DeleteExpiredItems()
 | 
			
		||||
 | 
			
		||||
		// 本地数据库中最大版本号
 | 
			
		||||
		this.lastVersion = db.ReadMaxVersion()
 | 
			
		||||
		this.lastVersion, err = db.ReadMaxVersion()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			remotelogs.Error("IP_LIST_MANAGER", "find max version failed: "+err.Error())
 | 
			
		||||
			this.lastVersion = 0
 | 
			
		||||
		}
 | 
			
		||||
		remotelogs.Println("IP_LIST_MANAGER", "starting from '"+db.Name()+"' version '"+types.String(this.lastVersion)+"' ...")
 | 
			
		||||
 | 
			
		||||
		// 从本地数据库中加载
 | 
			
		||||
		var offset int64 = 0
 | 
			
		||||
		var size int64 = 2_000
 | 
			
		||||
 | 
			
		||||
		var tr = trackers.Begin("IP_LIST_MANAGER:load")
 | 
			
		||||
		defer tr.End()
 | 
			
		||||
 | 
			
		||||
		for {
 | 
			
		||||
			items, err := db.ReadItems(offset, size)
 | 
			
		||||
			items, goNext, readErr := db.ReadItems(offset, size)
 | 
			
		||||
			var l = len(items)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				remotelogs.Error("IP_LIST_MANAGER", "read ip list from local database failed: "+err.Error())
 | 
			
		||||
			if readErr != nil {
 | 
			
		||||
				remotelogs.Error("IP_LIST_MANAGER", "read ip list from local database failed: "+readErr.Error())
 | 
			
		||||
			} else {
 | 
			
		||||
				if l == 0 {
 | 
			
		||||
				if !goNext {
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
				this.processItems(items, false)
 | 
			
		||||
				if int64(l) < size {
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			offset += int64(l)
 | 
			
		||||
		}
 | 
			
		||||
@@ -310,9 +330,14 @@ func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) {
 | 
			
		||||
 | 
			
		||||
// 调试IP信息
 | 
			
		||||
func (this *IPListManager) debugItem(item *pb.IPItem) {
 | 
			
		||||
	var ipRange = item.IpFrom
 | 
			
		||||
	if len(item.IpTo) > 0 {
 | 
			
		||||
		ipRange += " - " + item.IpTo
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if item.IsDeleted {
 | 
			
		||||
		remotelogs.Debug("IP_ITEM_DEBUG", "delete '"+item.IpFrom+"'")
 | 
			
		||||
		remotelogs.Debug("IP_ITEM_DEBUG", "delete '"+ipRange+"'")
 | 
			
		||||
	} else {
 | 
			
		||||
		remotelogs.Debug("IP_ITEM_DEBUG", "add '"+item.IpFrom+"'")
 | 
			
		||||
		remotelogs.Debug("IP_ITEM_DEBUG", "add '"+ipRange+"'")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										28
									
								
								internal/utils/byte/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								internal/utils/byte/utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,28 @@
 | 
			
		||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
 | 
			
		||||
 | 
			
		||||
package byteutils
 | 
			
		||||
 | 
			
		||||
// Copy bytes
 | 
			
		||||
func Copy(b []byte) []byte {
 | 
			
		||||
	var l = len(b)
 | 
			
		||||
	if l == 0 {
 | 
			
		||||
		return []byte{}
 | 
			
		||||
	}
 | 
			
		||||
	var d = make([]byte, l)
 | 
			
		||||
	copy(d, b)
 | 
			
		||||
	return d
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Append bytes
 | 
			
		||||
func Append(b []byte, b2 ...byte) []byte {
 | 
			
		||||
	return append(Copy(b), b2...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Contact bytes
 | 
			
		||||
func Contact(b []byte, b2 ...[]byte) []byte {
 | 
			
		||||
	b = Copy(b)
 | 
			
		||||
	for _, b3 := range b2 {
 | 
			
		||||
		b = append(b, b3...)
 | 
			
		||||
	}
 | 
			
		||||
	return b
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										56
									
								
								internal/utils/byte/utils_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								internal/utils/byte/utils_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,56 @@
 | 
			
		||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
 | 
			
		||||
 | 
			
		||||
package byteutils_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	byteutils "github.com/TeaOSLab/EdgeNode/internal/utils/byte"
 | 
			
		||||
	"github.com/iwind/TeaGo/assert"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestCopy(t *testing.T) {
 | 
			
		||||
	var a = assert.NewAssertion(t)
 | 
			
		||||
 | 
			
		||||
	var prefix []byte
 | 
			
		||||
	prefix = append(prefix, 1, 2, 3)
 | 
			
		||||
	t.Log(prefix, byteutils.Copy(prefix))
 | 
			
		||||
	a.IsTrue(bytes.Equal(byteutils.Copy(prefix), []byte{1, 2, 3}))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestAppend(t *testing.T) {
 | 
			
		||||
	var as = assert.NewAssertion(t)
 | 
			
		||||
 | 
			
		||||
	var prefix []byte
 | 
			
		||||
	prefix = append(prefix, 1, 2, 3)
 | 
			
		||||
 | 
			
		||||
	// [1 2 3 4 5 6] [1 2 3 7]
 | 
			
		||||
	var a = byteutils.Append(prefix, 4, 5, 6)
 | 
			
		||||
	var b = byteutils.Append(prefix, 7)
 | 
			
		||||
	t.Log(a, b)
 | 
			
		||||
 | 
			
		||||
	as.IsTrue(bytes.Equal(a, []byte{1, 2, 3, 4, 5, 6}))
 | 
			
		||||
	as.IsTrue(bytes.Equal(b, []byte{1, 2, 3, 7}))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestConcat(t *testing.T) {
 | 
			
		||||
	var a = assert.NewAssertion(t)
 | 
			
		||||
 | 
			
		||||
	var prefix []byte
 | 
			
		||||
	prefix = append(prefix, 1, 2, 3)
 | 
			
		||||
 | 
			
		||||
	var b = byteutils.Contact(prefix, []byte{4, 5, 6}, []byte{7})
 | 
			
		||||
	t.Log(b)
 | 
			
		||||
 | 
			
		||||
	a.IsTrue(bytes.Equal(b, []byte{1, 2, 3, 4, 5, 6, 7}))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestAppend_Raw(t *testing.T) {
 | 
			
		||||
	var prefix []byte
 | 
			
		||||
	prefix = append(prefix, 1, 2, 3)
 | 
			
		||||
 | 
			
		||||
	// [1 2 3 7 5 6] [1 2 3 7]
 | 
			
		||||
	var a = append(prefix, 4, 5, 6)
 | 
			
		||||
	var b = append(prefix, 7)
 | 
			
		||||
	t.Log(a, b)
 | 
			
		||||
}
 | 
			
		||||
@@ -9,10 +9,16 @@ import (
 | 
			
		||||
 | 
			
		||||
var ErrTableNotFound = errors.New("table not found")
 | 
			
		||||
var ErrKeyTooLong = errors.New("too long key")
 | 
			
		||||
var ErrSkip= errors.New("skip") // skip count in iterator
 | 
			
		||||
 | 
			
		||||
func IsKeyNotFound(err error) bool {
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return errors.Is(err, pebble.ErrNotFound)
 | 
			
		||||
func IsNotFound(err error) bool {
 | 
			
		||||
	return err != nil && errors.Is(err, pebble.ErrNotFound)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IsSkipError(err error) bool {
 | 
			
		||||
	return err != nil && errors.Is(err, ErrSkip)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Skip() (bool, error) {
 | 
			
		||||
	return true, ErrSkip
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -7,3 +7,7 @@ import "github.com/cockroachdb/pebble"
 | 
			
		||||
var DefaultWriteOptions = &pebble.WriteOptions{
 | 
			
		||||
	Sync: false,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var DefaultWriteSyncOptions = &pebble.WriteOptions{
 | 
			
		||||
	Sync: true,
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	byteutils "github.com/TeaOSLab/EdgeNode/internal/utils/byte"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type DataType = int
 | 
			
		||||
@@ -222,11 +223,11 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
 | 
			
		||||
	var prefix []byte
 | 
			
		||||
	switch this.dataType {
 | 
			
		||||
	case DataTypeKey:
 | 
			
		||||
		prefix = append(this.table.Namespace(), KeyPrefix...)
 | 
			
		||||
		prefix = byteutils.Append(this.table.Namespace(), []byte(KeyPrefix)...)
 | 
			
		||||
	case DataTypeField:
 | 
			
		||||
		prefix = append(this.table.Namespace(), FieldPrefix...)
 | 
			
		||||
		prefix = byteutils.Append(this.table.Namespace(), []byte(FieldPrefix)...)
 | 
			
		||||
	default:
 | 
			
		||||
		prefix = append(this.table.Namespace(), KeyPrefix...)
 | 
			
		||||
		prefix = byteutils.Append(this.table.Namespace(), []byte(KeyPrefix)...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var prefixLen = len(prefix)
 | 
			
		||||
@@ -238,21 +239,21 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
 | 
			
		||||
	var offsetKey []byte
 | 
			
		||||
	if this.reverse {
 | 
			
		||||
		if len(this.offsetKey) > 0 {
 | 
			
		||||
			offsetKey = append(prefix, this.offsetKey...)
 | 
			
		||||
			offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
 | 
			
		||||
		} else {
 | 
			
		||||
			offsetKey = append(prefix, 0xFF)
 | 
			
		||||
			offsetKey = byteutils.Append(prefix, 0xFF)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		opt.LowerBound = prefix
 | 
			
		||||
		opt.UpperBound = offsetKey
 | 
			
		||||
	} else {
 | 
			
		||||
		if len(this.offsetKey) > 0 {
 | 
			
		||||
			offsetKey = append(prefix, this.offsetKey...)
 | 
			
		||||
			offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
 | 
			
		||||
		} else {
 | 
			
		||||
			offsetKey = prefix
 | 
			
		||||
		}
 | 
			
		||||
		opt.LowerBound = offsetKey
 | 
			
		||||
		opt.UpperBound = append(offsetKey, 0xFF)
 | 
			
		||||
		opt.UpperBound = byteutils.Append(prefix, 0xFF)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var hasOffsetKey = len(this.offsetKey) > 0
 | 
			
		||||
@@ -267,7 +268,7 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
 | 
			
		||||
 | 
			
		||||
	var count int
 | 
			
		||||
 | 
			
		||||
	var itemFn = func() (goNext bool, err error) {
 | 
			
		||||
	var itemFn = func() (goNextItem bool, err error) {
 | 
			
		||||
		var keyBytes = it.Key()
 | 
			
		||||
 | 
			
		||||
		// skip first offset key
 | 
			
		||||
@@ -297,7 +298,11 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
 | 
			
		||||
			Value: value,
 | 
			
		||||
		})
 | 
			
		||||
		if callbackErr != nil {
 | 
			
		||||
			return false, callbackErr
 | 
			
		||||
			if IsSkipError(callbackErr) {
 | 
			
		||||
				return true, nil
 | 
			
		||||
			} else {
 | 
			
		||||
				return false, callbackErr
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if !goNext {
 | 
			
		||||
			return false, nil
 | 
			
		||||
@@ -361,9 +366,9 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
 | 
			
		||||
		if len(this.fieldOffsetKey) > 0 {
 | 
			
		||||
			offsetKey = this.fieldOffsetKey
 | 
			
		||||
		} else if len(this.offsetKey) > 0 {
 | 
			
		||||
			offsetKey = append(prefix, this.offsetKey...)
 | 
			
		||||
			offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
 | 
			
		||||
		} else {
 | 
			
		||||
			offsetKey = append(prefix, 0xFF)
 | 
			
		||||
			offsetKey = byteutils.Append(prefix, 0xFF)
 | 
			
		||||
		}
 | 
			
		||||
		opt.LowerBound = prefix
 | 
			
		||||
		opt.UpperBound = offsetKey
 | 
			
		||||
@@ -371,14 +376,14 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
 | 
			
		||||
		if len(this.fieldOffsetKey) > 0 {
 | 
			
		||||
			offsetKey = this.fieldOffsetKey
 | 
			
		||||
		} else if len(this.offsetKey) > 0 {
 | 
			
		||||
			offsetKey = append(prefix, this.offsetKey...)
 | 
			
		||||
			offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
 | 
			
		||||
			offsetKey = append(offsetKey, 0)
 | 
			
		||||
		} else {
 | 
			
		||||
			offsetKey = prefix
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		opt.LowerBound = offsetKey
 | 
			
		||||
		opt.UpperBound = append(prefix, 0xFF)
 | 
			
		||||
		opt.UpperBound = byteutils.Append(prefix, 0xFF)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	it, itErr := this.tx.NewIterator(opt)
 | 
			
		||||
@@ -391,7 +396,7 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
 | 
			
		||||
 | 
			
		||||
	var count int
 | 
			
		||||
 | 
			
		||||
	var itemFn = func() (goNext bool, err error) {
 | 
			
		||||
	var itemFn = func() (goNextItem bool, err error) {
 | 
			
		||||
		var fieldKeyBytes = it.Key()
 | 
			
		||||
 | 
			
		||||
		fieldValueBytes, keyBytes, decodeKeyErr := this.table.DecodeFieldKey(this.fieldName, fieldKeyBytes)
 | 
			
		||||
@@ -423,7 +428,7 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
 | 
			
		||||
		if !this.keysOnly {
 | 
			
		||||
			value, getErr := this.table.getWithKeyBytes(this.tx, this.table.FullKeyBytes(keyBytes))
 | 
			
		||||
			if getErr != nil {
 | 
			
		||||
				if IsKeyNotFound(getErr) {
 | 
			
		||||
				if IsNotFound(getErr) {
 | 
			
		||||
					return true, nil
 | 
			
		||||
				}
 | 
			
		||||
				return false, getErr
 | 
			
		||||
@@ -432,11 +437,15 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
 | 
			
		||||
			resultItem.Value = value
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		goNext, err = fn(this.tx, resultItem)
 | 
			
		||||
		goNextItem, err = fn(this.tx, resultItem)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return
 | 
			
		||||
			if IsSkipError(err) {
 | 
			
		||||
				return true, nil
 | 
			
		||||
			} else {
 | 
			
		||||
				return false, err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if !goNext {
 | 
			
		||||
		if !goNextItem {
 | 
			
		||||
			return false, nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -138,6 +138,26 @@ func TestQuery_FindAll_Offset(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestQuery_FindAll_Skip(t *testing.T) {
 | 
			
		||||
	var table = testOpenStoreTable[*testCachedItem](t, "cache_items", &testCacheItemEncoder[*testCachedItem]{})
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		err := table.Query().
 | 
			
		||||
			Offset("a3").
 | 
			
		||||
			Limit(10).
 | 
			
		||||
			FindAll(func(tx *kvstore.Tx[*testCachedItem], item kvstore.Item[*testCachedItem]) (goNext bool, err error) {
 | 
			
		||||
				if item.Key == "a30" || item.Key == "a3000005" {
 | 
			
		||||
					return kvstore.Skip()
 | 
			
		||||
				}
 | 
			
		||||
				t.Log("key:", item.Key, "value:", item.Value)
 | 
			
		||||
				return true, nil
 | 
			
		||||
			})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestQuery_FindAll_Count(t *testing.T) {
 | 
			
		||||
	var table = testOpenStoreTable[*testCachedItem](t, "cache_items", &testCacheItemEncoder[*testCachedItem]{})
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
			
		||||
	memutils "github.com/TeaOSLab/EdgeNode/internal/utils/mem"
 | 
			
		||||
	"github.com/cockroachdb/pebble"
 | 
			
		||||
	"github.com/iwind/TeaGo/Tea"
 | 
			
		||||
@@ -85,6 +86,31 @@ func OpenStoreDir(dir string, storeName string) (*Store, error) {
 | 
			
		||||
	return store, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var storeOnce = &sync.Once{}
 | 
			
		||||
var defaultSore *Store
 | 
			
		||||
 | 
			
		||||
func DefaultStore() (*Store, error) {
 | 
			
		||||
	if defaultSore != nil {
 | 
			
		||||
		return defaultSore, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	storeOnce.Do(func() {
 | 
			
		||||
		store, err := NewStore("default")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			remotelogs.Error("KV", "create default store failed: "+err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		err = store.Open()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			remotelogs.Error("KV", "open default store failed: "+err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		defaultSore = store
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return defaultSore, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Store) Open() error {
 | 
			
		||||
	var opt = &pebble.Options{
 | 
			
		||||
		Logger: NewLogger(),
 | 
			
		||||
@@ -144,6 +170,10 @@ func (this *Store) RawDB() *pebble.DB {
 | 
			
		||||
	return this.rawDB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Store) Flush() error {
 | 
			
		||||
	return this.rawDB.Flush()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Store) Close() error {
 | 
			
		||||
	if this.isClosed {
 | 
			
		||||
		return nil
 | 
			
		||||
 
 | 
			
		||||
@@ -6,7 +6,9 @@ import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils/kvstore"
 | 
			
		||||
	"github.com/cockroachdb/pebble"
 | 
			
		||||
	"github.com/iwind/TeaGo/Tea"
 | 
			
		||||
	"github.com/iwind/TeaGo/assert"
 | 
			
		||||
	_ "github.com/iwind/TeaGo/bootstrap"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
@@ -19,6 +21,44 @@ func TestMain(m *testing.M) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestStore_Default(t *testing.T) {
 | 
			
		||||
	var a = assert.NewAssertion(t)
 | 
			
		||||
 | 
			
		||||
	store, err := kvstore.DefaultStore()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	a.IsTrue(store != nil)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestStore_Default_Concurrent(t *testing.T) {
 | 
			
		||||
	var lastStore *kvstore.Store
 | 
			
		||||
 | 
			
		||||
	const threads = 32
 | 
			
		||||
 | 
			
		||||
	var wg = &sync.WaitGroup{}
 | 
			
		||||
	wg.Add(threads)
 | 
			
		||||
	for i := 0; i < threads; i++ {
 | 
			
		||||
		go func() {
 | 
			
		||||
			defer wg.Done()
 | 
			
		||||
 | 
			
		||||
			store, err := kvstore.DefaultStore()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Log("ERROR", err)
 | 
			
		||||
				t.Fail()
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if lastStore != nil && lastStore != store {
 | 
			
		||||
				t.Log("ERROR", "should be single instance")
 | 
			
		||||
				t.Fail()
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			lastStore = store
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestStore_Open(t *testing.T) {
 | 
			
		||||
	store, err := kvstore.OpenStore("test")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -29,6 +69,7 @@ func TestStore_Open(t *testing.T) {
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	t.Log("opened")
 | 
			
		||||
	_ = store
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestStore_RawDB(t *testing.T) {
 | 
			
		||||
 
 | 
			
		||||
@@ -64,6 +64,10 @@ func (this *Table[T]) DB() *DB {
 | 
			
		||||
	return this.db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table[T]) Encoder() ValueEncoder[T] {
 | 
			
		||||
	return this.encoder
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table[T]) Set(key string, value T) error {
 | 
			
		||||
	if len(key) > KeyMaxLength {
 | 
			
		||||
		return ErrKeyTooLong
 | 
			
		||||
@@ -75,7 +79,22 @@ func (this *Table[T]) Set(key string, value T) error {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return this.WriteTx(func(tx *Tx[T]) error {
 | 
			
		||||
		return this.set(tx, key, valueBytes, value, false)
 | 
			
		||||
		return this.set(tx, key, valueBytes, value, false, false)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table[T]) SetSync(key string, value T) error {
 | 
			
		||||
	if len(key) > KeyMaxLength {
 | 
			
		||||
		return ErrKeyTooLong
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	valueBytes, err := this.encoder.Encode(value)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return this.WriteTxSync(func(tx *Tx[T]) error {
 | 
			
		||||
		return this.set(tx, key, valueBytes, value, false, true)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -90,7 +109,7 @@ func (this *Table[T]) Insert(key string, value T) error {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return this.WriteTx(func(tx *Tx[T]) error {
 | 
			
		||||
		return this.set(tx, key, valueBytes, value, true)
 | 
			
		||||
		return this.set(tx, key, valueBytes, value, true, false)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -111,7 +130,7 @@ func (this *Table[T]) ComposeFieldKey(keyBytes []byte, fieldName string, fieldVa
 | 
			
		||||
func (this *Table[T]) Exist(key string) (found bool, err error) {
 | 
			
		||||
	_, closer, err := this.db.store.rawDB.Get(this.FullKey(key))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if IsKeyNotFound(err) {
 | 
			
		||||
		if IsNotFound(err) {
 | 
			
		||||
			return false, nil
 | 
			
		||||
		}
 | 
			
		||||
		return false, err
 | 
			
		||||
@@ -173,6 +192,20 @@ func (this *Table[T]) WriteTx(fn func(tx *Tx[T]) error) error {
 | 
			
		||||
	return tx.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table[T]) WriteTxSync(fn func(tx *Tx[T]) error) error {
 | 
			
		||||
	var tx = NewTx[T](this, false)
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = tx.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	err := fn(tx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return tx.CommitSync()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table[T]) Truncate() error {
 | 
			
		||||
	this.mu.Lock()
 | 
			
		||||
	defer this.mu.Unlock()
 | 
			
		||||
@@ -256,7 +289,7 @@ func (this *Table[T]) deleteKeys(tx *Tx[T], key ...string) error {
 | 
			
		||||
			if len(this.fieldNames) > 0 {
 | 
			
		||||
				valueBytes, closer, getErr := batch.Get(keyBytes)
 | 
			
		||||
				if getErr != nil {
 | 
			
		||||
					if IsKeyNotFound(getErr) {
 | 
			
		||||
					if IsNotFound(getErr) {
 | 
			
		||||
						return nil
 | 
			
		||||
					}
 | 
			
		||||
					return getErr
 | 
			
		||||
@@ -298,8 +331,12 @@ func (this *Table[T]) deleteKeys(tx *Tx[T], key ...string) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, insertOnly bool) error {
 | 
			
		||||
func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, insertOnly bool, syncMode bool) error {
 | 
			
		||||
	var keyBytes = this.FullKey(key)
 | 
			
		||||
	var writeOptions = DefaultWriteOptions
 | 
			
		||||
	if syncMode {
 | 
			
		||||
		writeOptions = DefaultWriteSyncOptions
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var batch = tx.batch
 | 
			
		||||
 | 
			
		||||
@@ -312,7 +349,7 @@ func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, ins
 | 
			
		||||
		if countFields > 0 {
 | 
			
		||||
			oldValueBytes, closer, getErr := batch.Get(keyBytes)
 | 
			
		||||
			if getErr != nil {
 | 
			
		||||
				if !IsKeyNotFound(getErr) {
 | 
			
		||||
				if !IsNotFound(getErr) {
 | 
			
		||||
					return getErr
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
@@ -330,7 +367,7 @@ func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, ins
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	setErr := batch.Set(keyBytes, valueBytes, DefaultWriteOptions)
 | 
			
		||||
	setErr := batch.Set(keyBytes, valueBytes, writeOptions)
 | 
			
		||||
	if setErr != nil {
 | 
			
		||||
		return setErr
 | 
			
		||||
	}
 | 
			
		||||
@@ -362,14 +399,14 @@ func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, ins
 | 
			
		||||
					// skip the field
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				deleteFieldErr := batch.Delete(oldFieldKeyBytes, DefaultWriteOptions)
 | 
			
		||||
				deleteFieldErr := batch.Delete(oldFieldKeyBytes, writeOptions)
 | 
			
		||||
				if deleteFieldErr != nil {
 | 
			
		||||
					return deleteFieldErr
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// set new field key
 | 
			
		||||
			setFieldErr := batch.Set(newFieldKeyBytes, nil, DefaultWriteOptions)
 | 
			
		||||
			setFieldErr := batch.Set(newFieldKeyBytes, nil, writeOptions)
 | 
			
		||||
			if setFieldErr != nil {
 | 
			
		||||
				return setFieldErr
 | 
			
		||||
			}
 | 
			
		||||
 
 | 
			
		||||
@@ -21,7 +21,7 @@ func (this *CounterTable[T]) Increase(key string, delta T) (newValue T, err erro
 | 
			
		||||
	err = this.Table.WriteTx(func(tx *Tx[T]) error {
 | 
			
		||||
		value, getErr := tx.Get(key)
 | 
			
		||||
		if getErr != nil {
 | 
			
		||||
			if !IsKeyNotFound(getErr) {
 | 
			
		||||
			if !IsNotFound(getErr) {
 | 
			
		||||
				return getErr
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -45,7 +45,7 @@ func TestTable_Set(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	value, err := table.Get("a")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if kvstore.IsKeyNotFound(err) {
 | 
			
		||||
		if kvstore.IsNotFound(err) {
 | 
			
		||||
			t.Log("not found key")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
@@ -81,7 +81,7 @@ func TestTable_Get(t *testing.T) {
 | 
			
		||||
	for _, key := range []string{"a", "b", "c"} {
 | 
			
		||||
		value, getErr := table.Get(key)
 | 
			
		||||
		if getErr != nil {
 | 
			
		||||
			if kvstore.IsKeyNotFound(getErr) {
 | 
			
		||||
			if kvstore.IsNotFound(getErr) {
 | 
			
		||||
				t.Log("not found key", key)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
@@ -146,7 +146,7 @@ func TestTable_Delete(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	value, err := table.Get("a123")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if !kvstore.IsKeyNotFound(err) {
 | 
			
		||||
		if !kvstore.IsNotFound(err) {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
@@ -173,7 +173,7 @@ func TestTable_Delete(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		_, err = table.Get("a123")
 | 
			
		||||
		a.IsTrue(kvstore.IsKeyNotFound(err))
 | 
			
		||||
		a.IsTrue(kvstore.IsNotFound(err))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -357,7 +357,7 @@ func BenchmarkTable_Get(b *testing.B) {
 | 
			
		||||
		for pb.Next() {
 | 
			
		||||
			_, putErr := table.Get(types.String(rand.Int()))
 | 
			
		||||
			if putErr != nil {
 | 
			
		||||
				if kvstore.IsKeyNotFound(putErr) {
 | 
			
		||||
				if kvstore.IsNotFound(putErr) {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				b.Fatal(putErr)
 | 
			
		||||
 
 | 
			
		||||
@@ -37,7 +37,24 @@ func (this *Tx[T]) Set(key string, value T) error {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return this.table.set(this, key, valueBytes, value, false)
 | 
			
		||||
	return this.table.set(this, key, valueBytes, value, false, false)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Tx[T]) SetSync(key string, value T) error {
 | 
			
		||||
	if this.readOnly {
 | 
			
		||||
		return errors.New("can not set value in readonly transaction")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(key) > KeyMaxLength {
 | 
			
		||||
		return ErrKeyTooLong
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	valueBytes, err := this.table.encoder.Encode(value)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return this.table.set(this, key, valueBytes, value, false, true)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Tx[T]) Insert(key string, value T) error {
 | 
			
		||||
@@ -54,7 +71,7 @@ func (this *Tx[T]) Insert(key string, value T) error {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return this.table.set(this, key, valueBytes, value, true)
 | 
			
		||||
	return this.table.set(this, key, valueBytes, value, true, false)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Tx[T]) Get(key string) (value T, err error) {
 | 
			
		||||
@@ -78,6 +95,20 @@ func (this *Tx[T]) Close() error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Tx[T]) Commit() (err error) {
 | 
			
		||||
	return this.commit(DefaultWriteOptions)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Tx[T]) CommitSync() (err error) {
 | 
			
		||||
	return this.commit(DefaultWriteSyncOptions)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Tx[T]) Query() *Query[T] {
 | 
			
		||||
	var query = NewQuery[T]()
 | 
			
		||||
	query.SetTx(this)
 | 
			
		||||
	return query
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Tx[T]) commit(opt *pebble.WriteOptions) (err error) {
 | 
			
		||||
	defer func() {
 | 
			
		||||
		var panicErr = recover()
 | 
			
		||||
		if panicErr != nil {
 | 
			
		||||
@@ -88,11 +119,5 @@ func (this *Tx[T]) Commit() (err error) {
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	return this.batch.Commit(DefaultWriteOptions)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Tx[T]) Query() *Query[T] {
 | 
			
		||||
	var query = NewQuery[T]()
 | 
			
		||||
	query.SetTx(this)
 | 
			
		||||
	return query
 | 
			
		||||
	return this.batch.Commit(opt)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,11 @@
 | 
			
		||||
 | 
			
		||||
package testutils
 | 
			
		||||
 | 
			
		||||
import "os"
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"os"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// IsSingleTesting 判断当前测试环境是否为单个函数测试
 | 
			
		||||
func IsSingleTesting() bool {
 | 
			
		||||
@@ -12,4 +16,9 @@ func IsSingleTesting() bool {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RandIP 生成一个随机IP用于测试
 | 
			
		||||
func RandIP() string {
 | 
			
		||||
	return fmt.Sprintf("%d.%d.%d.%d", rand.Int()%255, rand.Int()%255, rand.Int()%255, rand.Int()%255)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user