使用KV数据库来管理IP名单

This commit is contained in:
GoEdgeLab
2024-03-31 10:08:53 +08:00
parent 93e7c6bb48
commit 0f22e87711
21 changed files with 1184 additions and 368 deletions

View File

@@ -97,7 +97,7 @@ func (this *KVListFileStore) ExistItem(hash string) (bool, error) {
item, err := this.itemsTable.Get(hash)
if err != nil {
if kvstore.IsKeyNotFound(err) {
if kvstore.IsNotFound(err) {
return false, nil
}
return false, err

View File

@@ -1,305 +1,13 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils/dbs"
"github.com/iwind/TeaGo/Tea"
"os"
"path/filepath"
"time"
)
import "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
type IPListDB struct {
db *dbs.DB
itemTableName string
versionTableName string
deleteExpiredItemsStmt *dbs.Stmt
deleteItemStmt *dbs.Stmt
insertItemStmt *dbs.Stmt
selectItemsStmt *dbs.Stmt
selectMaxItemVersionStmt *dbs.Stmt
selectVersionStmt *dbs.Stmt
updateVersionStmt *dbs.Stmt
cleanTicker *time.Ticker
dir string
isClosed bool
}
func NewIPListDB() (*IPListDB, error) {
var db = &IPListDB{
itemTableName: "ipItems",
versionTableName: "versions",
dir: filepath.Clean(Tea.Root + "/data"),
cleanTicker: time.NewTicker(24 * time.Hour),
}
err := db.init()
return db, err
}
func (this *IPListDB) init() error {
// 检查目录是否存在
_, err := os.Stat(this.dir)
if err != nil {
err = os.MkdirAll(this.dir, 0777)
if err != nil {
return err
}
remotelogs.Println("IP_LIST_DB", "create data dir '"+this.dir+"'")
}
var path = this.dir + "/ip_list.db"
db, err := dbs.OpenWriter("file:" + path + "?cache=shared&mode=rwc&_journal_mode=WAL&_sync=" + dbs.SyncMode + "&_locking_mode=EXCLUSIVE")
if err != nil {
return err
}
db.SetMaxOpenConns(1)
//_, err = db.Exec("VACUUM")
//if err != nil {
// return err
//}
this.db = db
// 恢复数据库
var recoverEnv, _ = os.LookupEnv("EdgeRecover")
if len(recoverEnv) > 0 {
for _, indexName := range []string{"ip_list_itemId", "ip_list_expiredAt"} {
_, _ = db.Exec(`REINDEX "` + indexName + `"`)
}
}
// 初始化数据库
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemTableName + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"listId" integer DEFAULT 0,
"listType" varchar(32),
"isGlobal" integer(1) DEFAULT 0,
"type" varchar(16),
"itemId" integer DEFAULT 0,
"ipFrom" varchar(64) DEFAULT 0,
"ipTo" varchar(64) DEFAULT 0,
"expiredAt" integer DEFAULT 0,
"eventLevel" varchar(32),
"isDeleted" integer(1) DEFAULT 0,
"version" integer DEFAULT 0,
"nodeId" integer DEFAULT 0,
"serverId" integer DEFAULT 0
);
CREATE INDEX IF NOT EXISTS "ip_list_itemId"
ON "` + this.itemTableName + `" (
"itemId" ASC
);
CREATE INDEX IF NOT EXISTS "ip_list_expiredAt"
ON "` + this.itemTableName + `" (
"expiredAt" ASC
);
`)
if err != nil {
return err
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.versionTableName + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"version" integer DEFAULT 0
);
`)
if err != nil {
return err
}
// 初始化SQL语句
this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "expiredAt">0 AND "expiredAt"<?`)
if err != nil {
return err
}
this.deleteItemStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "itemId"=?`)
if err != nil {
return err
}
this.insertItemStmt, err = this.db.Prepare(`INSERT INTO "` + this.itemTableName + `" ("listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
if err != nil {
return err
}
this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" WHERE isDeleted=0 ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`)
if err != nil {
return err
}
this.selectMaxItemVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`)
if err != nil {
return err
}
this.selectVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.versionTableName + `" LIMIT 1`)
if err != nil {
return err
}
this.updateVersionStmt, err = this.db.Prepare(`REPLACE INTO "` + this.versionTableName + `" ("id", "version") VALUES (1, ?)`)
if err != nil {
return err
}
this.db = db
goman.New(func() {
events.OnClose(func() {
_ = this.Close()
this.cleanTicker.Stop()
})
for range this.cleanTicker.C {
err := this.DeleteExpiredItems()
if err != nil {
remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+err.Error())
}
}
})
return nil
}
// DeleteExpiredItems 删除过期的条目
func (this *IPListDB) DeleteExpiredItems() error {
if this.isClosed {
return nil
}
_, err := this.deleteExpiredItemsStmt.Exec(time.Now().Unix() - 7*86400)
return err
}
func (this *IPListDB) AddItem(item *pb.IPItem) error {
if this.isClosed {
return nil
}
_, err := this.deleteItemStmt.Exec(item.Id)
if err != nil {
return err
}
// 如果是删除,则不再创建新记录
if item.IsDeleted {
return this.UpdateMaxVersion(item.Version)
}
_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
if err != nil {
return err
}
return this.UpdateMaxVersion(item.Version)
}
func (this *IPListDB) ReadItems(offset int64, size int64) (items []*pb.IPItem, err error) {
if this.isClosed {
return
}
rows, err := this.selectItemsStmt.Query(offset, size)
if err != nil {
return nil, err
}
defer func() {
_ = rows.Close()
}()
for rows.Next() {
// "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId"
var pbItem = &pb.IPItem{}
err = rows.Scan(&pbItem.ListId, &pbItem.ListType, &pbItem.IsGlobal, &pbItem.Type, &pbItem.Id, &pbItem.IpFrom, &pbItem.IpTo, &pbItem.ExpiredAt, &pbItem.EventLevel, &pbItem.IsDeleted, &pbItem.Version, &pbItem.NodeId, &pbItem.ServerId)
if err != nil {
return nil, err
}
items = append(items, pbItem)
}
return
}
// ReadMaxVersion 读取当前最大版本号
func (this *IPListDB) ReadMaxVersion() int64 {
if this.isClosed {
return 0
}
// from version table
{
var row = this.selectVersionStmt.QueryRow()
if row == nil {
return 0
}
var version int64
err := row.Scan(&version)
if err == nil {
return version
}
}
// from items table
{
var row = this.selectMaxItemVersionStmt.QueryRow()
if row == nil {
return 0
}
var version int64
err := row.Scan(&version)
if err != nil {
return 0
}
return version
}
}
// UpdateMaxVersion 修改版本号
func (this *IPListDB) UpdateMaxVersion(version int64) error {
if this.isClosed {
return nil
}
_, err := this.updateVersionStmt.Exec(version)
return err
}
func (this *IPListDB) Close() error {
this.isClosed = true
if this.db != nil {
for _, stmt := range []*dbs.Stmt{
this.deleteExpiredItemsStmt,
this.deleteItemStmt,
this.insertItemStmt,
this.selectItemsStmt,
this.selectMaxItemVersionStmt, // ipItems table
this.selectVersionStmt, // versions table
this.updateVersionStmt,
} {
if stmt != nil {
_ = stmt.Close()
}
}
return this.db.Close()
}
return nil
type IPListDB interface {
Name() string
DeleteExpiredItems() error
ReadMaxVersion() (int64, error)
ReadItems(offset int64, size int64) (items []*pb.IPItem, goNext bool, err error)
AddItem(item *pb.IPItem) error
}

View File

@@ -0,0 +1,229 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package iplibrary
import (
"encoding/binary"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/utils/kvstore"
"testing"
"time"
)
type IPListKV struct {
ipTable *kvstore.Table[*pb.IPItem]
versionsTable *kvstore.Table[int64]
encoder *IPItemEncoder[*pb.IPItem]
cleanTicker *time.Ticker
isClosed bool
offsetItemKey string
}
func NewIPListKV() (*IPListKV, error) {
var db = &IPListKV{
cleanTicker: time.NewTicker(24 * time.Hour),
encoder: &IPItemEncoder[*pb.IPItem]{},
}
err := db.init()
return db, err
}
func (this *IPListKV) init() error {
store, storeErr := kvstore.DefaultStore()
if storeErr != nil {
return storeErr
}
db, dbErr := store.NewDB("ip_list")
if dbErr != nil {
return dbErr
}
{
table, err := kvstore.NewTable[*pb.IPItem]("ip_items", this.encoder)
if err != nil {
return err
}
this.ipTable = table
err = table.AddFields("expiresAt")
if err != nil {
return err
}
db.AddTable(table)
}
{
table, err := kvstore.NewTable[int64]("versions", kvstore.NewIntValueEncoder[int64]())
if err != nil {
return err
}
this.versionsTable = table
db.AddTable(table)
}
goman.New(func() {
events.OnClose(func() {
_ = this.Close()
this.cleanTicker.Stop()
})
for range this.cleanTicker.C {
deleteErr := this.DeleteExpiredItems()
if deleteErr != nil {
remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+deleteErr.Error())
}
}
})
return nil
}
// Name 数据库名称代号
func (this *IPListKV) Name() string {
return "kvstore"
}
// DeleteExpiredItems 删除过期的条目
func (this *IPListKV) DeleteExpiredItems() error {
if this.isClosed {
return nil
}
for {
var found bool
var currentTime = fasttime.Now().Unix()
err := this.ipTable.
Query().
FieldAsc("expiresAt").
ForUpdate().
Limit(1000).
FindAll(func(tx *kvstore.Tx[*pb.IPItem], item kvstore.Item[*pb.IPItem]) (goNext bool, err error) {
if !item.Value.IsDeleted && item.Value.ExpiredAt == 0 { // never expires
return kvstore.Skip()
}
if item.Value.ExpiredAt < currentTime-7*86400 /** keep for 7 days **/ {
err = tx.Delete(item.Key)
if err != nil {
return false, err
}
found = true
return true, nil
}
found = false
return false, nil
})
if err != nil {
return err
}
if !found {
break
}
}
return nil
}
func (this *IPListKV) AddItem(item *pb.IPItem) error {
if this.isClosed {
return nil
}
// 先删除
var key = this.encoder.EncodeKey(item)
err := this.ipTable.Delete(key)
if err != nil {
return err
}
// 如果是删除,则不再创建新记录
if item.IsDeleted {
return this.UpdateMaxVersion(item.Version)
}
err = this.ipTable.Set(key, item)
if err != nil {
return err
}
return this.UpdateMaxVersion(item.Version)
}
func (this *IPListKV) ReadItems(offset int64, size int64) (items []*pb.IPItem, goNextLoop bool, err error) {
if this.isClosed {
return
}
err = this.ipTable.
Query().
Offset(this.offsetItemKey).
Limit(int(size)).
FindAll(func(tx *kvstore.Tx[*pb.IPItem], item kvstore.Item[*pb.IPItem]) (goNext bool, err error) {
this.offsetItemKey = item.Key
goNextLoop = true
if !item.Value.IsDeleted {
items = append(items, item.Value)
}
return true, nil
})
return
}
// ReadMaxVersion 读取当前最大版本号
func (this *IPListKV) ReadMaxVersion() (int64, error) {
if this.isClosed {
return 0, errors.New("database has been closed")
}
version, err := this.versionsTable.Get("version")
if err != nil {
if kvstore.IsNotFound(err) {
return 0, nil
}
return 0, err
}
return version, nil
}
// UpdateMaxVersion 修改版本号
func (this *IPListKV) UpdateMaxVersion(version int64) error {
if this.isClosed {
return nil
}
return this.versionsTable.SetSync("version", version)
}
func (this *IPListKV) TestInspect(t *testing.T) error {
return this.ipTable.
Query().
FindAll(func(tx *kvstore.Tx[*pb.IPItem], item kvstore.Item[*pb.IPItem]) (goNext bool, err error) {
if len(item.Key) != 8 {
return false, errors.New("invalid key '" + item.Key + "'")
}
t.Log(binary.BigEndian.Uint64([]byte(item.Key)), "=>", item.Value)
return true, nil
})
}
// Flush to disk
func (this *IPListKV) Flush() error {
return this.ipTable.DB().Store().Flush()
}
func (this *IPListKV) Close() error {
this.isClosed = true
return nil
}

View File

@@ -0,0 +1,55 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package iplibrary
import (
"encoding/binary"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"google.golang.org/protobuf/proto"
"math"
)
type IPItemEncoder[T interface{ *pb.IPItem }] struct {
}
func NewIPItemEncoder[T interface{ *pb.IPItem }]() *IPItemEncoder[T] {
return &IPItemEncoder[T]{}
}
func (this *IPItemEncoder[T]) Encode(value T) ([]byte, error) {
return proto.Marshal(any(value).(*pb.IPItem))
}
func (this *IPItemEncoder[T]) EncodeField(value T, fieldName string) ([]byte, error) {
switch fieldName {
case "expiresAt":
var expiresAt = any(value).(*pb.IPItem).ExpiredAt
if expiresAt < 0 || expiresAt > int64(math.MaxUint32) {
expiresAt = 0
}
var b = make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(expiresAt))
return b, nil
}
return nil, errors.New("field '" + fieldName + "' not found")
}
func (this *IPItemEncoder[T]) Decode(valueBytes []byte) (value T, err error) {
var item = &pb.IPItem{}
err = proto.Unmarshal(valueBytes, item)
value = item
return
}
// EncodeKey generate key for ip item
func (this *IPItemEncoder[T]) EncodeKey(item *pb.IPItem) string {
var b = make([]byte, 8)
if item.Id < 0 {
item.Id = 0
}
binary.BigEndian.PutUint64(b, uint64(item.Id))
return string(b)
}

View File

@@ -0,0 +1,221 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package iplibrary_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/TeaOSLab/EdgeNode/internal/zero"
"testing"
"time"
)
func TestIPListKV_AddItem(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = kv.Flush()
}()
{
err = kv.AddItem(&pb.IPItem{
Id: 1,
IpFrom: "192.168.1.101",
IpTo: "",
Version: 1,
ExpiredAt: fasttime.NewFastTime().Unix() + 60,
ListId: 1,
IsDeleted: false,
ListType: "white",
})
if err != nil {
t.Fatal(err)
}
}
{
err = kv.AddItem(&pb.IPItem{
Id: 2,
IpFrom: "192.168.1.102",
IpTo: "",
Version: 2,
ExpiredAt: fasttime.NewFastTime().Unix() + 60,
ListId: 1,
IsDeleted: false,
ListType: "white",
})
if err != nil {
t.Fatal(err)
}
}
{
err = kv.AddItem(&pb.IPItem{
Id: 3,
IpFrom: "192.168.1.103",
IpTo: "",
Version: 3,
ExpiredAt: fasttime.NewFastTime().Unix() + 60,
ListId: 1,
IsDeleted: false,
ListType: "white",
})
if err != nil {
t.Fatal(err)
}
}
}
func TestIPListKV_AddItems_Many(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = kv.Flush()
}()
var count = 2
var from = 1
if testutils.IsSingleTesting() {
count = 2_000_000
}
var before = time.Now()
defer func() {
t.Logf("cost: %.2f s", time.Since(before).Seconds())
}()
for i := from; i <= from+count; i++ {
err = kv.AddItem(&pb.IPItem{
Id: int64(i),
IpFrom: testutils.RandIP(),
IpTo: "",
Version: int64(i),
ExpiredAt: fasttime.NewFastTime().Unix() + 86400,
ListId: 1,
IsDeleted: false,
ListType: "white",
})
if err != nil {
t.Fatal(err)
}
}
}
func TestIPListKV_DeleteExpiredItems(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = kv.Flush()
}()
err = kv.DeleteExpiredItems()
if err != nil {
t.Fatal(err)
}
}
func TestIPListKV_UpdateMaxVersion(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = kv.Flush()
}()
err = kv.UpdateMaxVersion(101)
if err != nil {
t.Fatal(err)
}
maxVersion, err := kv.ReadMaxVersion()
if err != nil {
t.Fatal(err)
}
t.Log("version:", maxVersion)
}
func TestIPListKV_ReadMaxVersion(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
maxVersion, err := kv.ReadMaxVersion()
if err != nil {
t.Fatal(err)
}
t.Log("version:", maxVersion)
}
func TestIPListKV_ReadItems(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
for {
items, goNext, readErr := kv.ReadItems(0, 2)
if readErr != nil {
t.Fatal(readErr)
}
t.Log("====")
for _, item := range items {
t.Log(item.Id)
}
if !goNext {
break
}
}
}
func TestIPListKV_CountItems(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
var count int
var m = map[int64]zero.Zero{}
for {
items, goNext, readErr := kv.ReadItems(0, 1000)
if readErr != nil {
t.Fatal(readErr)
}
for _, item := range items {
count++
m[item.Id] = zero.Zero{}
}
if !goNext {
break
}
}
t.Log("count:", count, "len:", len(m))
}
func TestIPListKV_Inspect(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
err = kv.TestInspect(t)
if err != nil {
t.Fatal(err)
}
}

View File

@@ -0,0 +1,312 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils/dbs"
"github.com/iwind/TeaGo/Tea"
"os"
"path/filepath"
"time"
)
type IPListSQLite struct {
db *dbs.DB
itemTableName string
versionTableName string
deleteExpiredItemsStmt *dbs.Stmt
deleteItemStmt *dbs.Stmt
insertItemStmt *dbs.Stmt
selectItemsStmt *dbs.Stmt
selectMaxItemVersionStmt *dbs.Stmt
selectVersionStmt *dbs.Stmt
updateVersionStmt *dbs.Stmt
cleanTicker *time.Ticker
dir string
isClosed bool
}
func NewIPListSqlite() (*IPListSQLite, error) {
var db = &IPListSQLite{
itemTableName: "ipItems",
versionTableName: "versions",
dir: filepath.Clean(Tea.Root + "/data"),
cleanTicker: time.NewTicker(24 * time.Hour),
}
err := db.init()
return db, err
}
func (this *IPListSQLite) init() error {
// 检查目录是否存在
_, err := os.Stat(this.dir)
if err != nil {
err = os.MkdirAll(this.dir, 0777)
if err != nil {
return err
}
remotelogs.Println("IP_LIST_DB", "create data dir '"+this.dir+"'")
}
var path = this.dir + "/ip_list.db"
db, err := dbs.OpenWriter("file:" + path + "?cache=shared&mode=rwc&_journal_mode=WAL&_sync=" + dbs.SyncMode + "&_locking_mode=EXCLUSIVE")
if err != nil {
return err
}
db.SetMaxOpenConns(1)
//_, err = db.Exec("VACUUM")
//if err != nil {
// return err
//}
this.db = db
// 恢复数据库
var recoverEnv, _ = os.LookupEnv("EdgeRecover")
if len(recoverEnv) > 0 {
for _, indexName := range []string{"ip_list_itemId", "ip_list_expiredAt"} {
_, _ = db.Exec(`REINDEX "` + indexName + `"`)
}
}
// 初始化数据库
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemTableName + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"listId" integer DEFAULT 0,
"listType" varchar(32),
"isGlobal" integer(1) DEFAULT 0,
"type" varchar(16),
"itemId" integer DEFAULT 0,
"ipFrom" varchar(64) DEFAULT 0,
"ipTo" varchar(64) DEFAULT 0,
"expiredAt" integer DEFAULT 0,
"eventLevel" varchar(32),
"isDeleted" integer(1) DEFAULT 0,
"version" integer DEFAULT 0,
"nodeId" integer DEFAULT 0,
"serverId" integer DEFAULT 0
);
CREATE INDEX IF NOT EXISTS "ip_list_itemId"
ON "` + this.itemTableName + `" (
"itemId" ASC
);
CREATE INDEX IF NOT EXISTS "ip_list_expiredAt"
ON "` + this.itemTableName + `" (
"expiredAt" ASC
);
`)
if err != nil {
return err
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.versionTableName + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"version" integer DEFAULT 0
);
`)
if err != nil {
return err
}
// 初始化SQL语句
this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "expiredAt">0 AND "expiredAt"<?`)
if err != nil {
return err
}
this.deleteItemStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "itemId"=?`)
if err != nil {
return err
}
this.insertItemStmt, err = this.db.Prepare(`INSERT INTO "` + this.itemTableName + `" ("listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
if err != nil {
return err
}
this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" WHERE isDeleted=0 ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`)
if err != nil {
return err
}
this.selectMaxItemVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`)
if err != nil {
return err
}
this.selectVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.versionTableName + `" LIMIT 1`)
if err != nil {
return err
}
this.updateVersionStmt, err = this.db.Prepare(`REPLACE INTO "` + this.versionTableName + `" ("id", "version") VALUES (1, ?)`)
if err != nil {
return err
}
this.db = db
goman.New(func() {
events.OnClose(func() {
_ = this.Close()
this.cleanTicker.Stop()
})
for range this.cleanTicker.C {
err := this.DeleteExpiredItems()
if err != nil {
remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+err.Error())
}
}
})
return nil
}
// Name 数据库名称代号
func (this *IPListSQLite) Name() string {
return "sqlite"
}
// DeleteExpiredItems 删除过期的条目
func (this *IPListSQLite) DeleteExpiredItems() error {
if this.isClosed {
return nil
}
_, err := this.deleteExpiredItemsStmt.Exec(time.Now().Unix() - 7*86400)
return err
}
func (this *IPListSQLite) AddItem(item *pb.IPItem) error {
if this.isClosed {
return nil
}
_, err := this.deleteItemStmt.Exec(item.Id)
if err != nil {
return err
}
// 如果是删除,则不再创建新记录
if item.IsDeleted {
return this.UpdateMaxVersion(item.Version)
}
_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
if err != nil {
return err
}
return this.UpdateMaxVersion(item.Version)
}
func (this *IPListSQLite) ReadItems(offset int64, size int64) (items []*pb.IPItem, goNext bool, err error) {
if this.isClosed {
return
}
rows, err := this.selectItemsStmt.Query(offset, size)
if err != nil {
return nil, false, err
}
defer func() {
_ = rows.Close()
}()
for rows.Next() {
// "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId"
var pbItem = &pb.IPItem{}
err = rows.Scan(&pbItem.ListId, &pbItem.ListType, &pbItem.IsGlobal, &pbItem.Type, &pbItem.Id, &pbItem.IpFrom, &pbItem.IpTo, &pbItem.ExpiredAt, &pbItem.EventLevel, &pbItem.IsDeleted, &pbItem.Version, &pbItem.NodeId, &pbItem.ServerId)
if err != nil {
return nil, false, err
}
items = append(items, pbItem)
}
goNext = int64(len(items)) == size
return
}
// ReadMaxVersion 读取当前最大版本号
func (this *IPListSQLite) ReadMaxVersion() (int64, error) {
if this.isClosed {
return 0, nil
}
// from version table
{
var row = this.selectVersionStmt.QueryRow()
if row == nil {
return 0, nil
}
var version int64
err := row.Scan(&version)
if err == nil {
return version, nil
}
}
// from items table
{
var row = this.selectMaxItemVersionStmt.QueryRow()
if row == nil {
return 0, nil
}
var version int64
err := row.Scan(&version)
if err != nil {
return 0, nil
}
return version, nil
}
}
// UpdateMaxVersion 修改版本号
func (this *IPListSQLite) UpdateMaxVersion(version int64) error {
if this.isClosed {
return nil
}
_, err := this.updateVersionStmt.Exec(version)
return err
}
func (this *IPListSQLite) Close() error {
this.isClosed = true
if this.db != nil {
for _, stmt := range []*dbs.Stmt{
this.deleteExpiredItemsStmt,
this.deleteItemStmt,
this.insertItemStmt,
this.selectItemsStmt,
this.selectMaxItemVersionStmt, // ipItems table
this.selectVersionStmt, // versions table
this.updateVersionStmt,
} {
if stmt != nil {
_ = stmt.Close()
}
}
return this.db.Close()
}
return nil
}

View File

@@ -12,7 +12,7 @@ import (
)
func TestIPListDB_AddItem(t *testing.T) {
db, err := iplibrary.NewIPListDB()
db, err := iplibrary.NewIPListSqlite()
if err != nil {
t.Fatal(err)
}
@@ -59,7 +59,7 @@ func TestIPListDB_AddItem(t *testing.T) {
}
func TestIPListDB_ReadItems(t *testing.T) {
db, err := iplibrary.NewIPListDB()
db, err := iplibrary.NewIPListSqlite()
if err != nil {
t.Fatal(err)
}
@@ -71,15 +71,16 @@ func TestIPListDB_ReadItems(t *testing.T) {
_ = db.Close()
}()
items, err := db.ReadItems(0, 2)
items, goNext, err := db.ReadItems(0, 2)
if err != nil {
t.Fatal(err)
}
t.Log("goNext:", goNext)
logs.PrintAsJSON(items, t)
}
func TestIPListDB_ReadMaxVersion(t *testing.T) {
db, err := iplibrary.NewIPListDB()
db, err := iplibrary.NewIPListSqlite()
if err != nil {
t.Fatal(err)
}
@@ -90,7 +91,7 @@ func TestIPListDB_ReadMaxVersion(t *testing.T) {
}
func TestIPListDB_UpdateMaxVersion(t *testing.T) {
db, err := iplibrary.NewIPListDB()
db, err := iplibrary.NewIPListSqlite()
if err != nil {
t.Fatal(err)
}

View File

@@ -8,10 +8,13 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/trackers"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/TeaOSLab/EdgeNode/internal/zero"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/types"
"os"
"sync"
"time"
)
@@ -45,7 +48,7 @@ func init() {
type IPListManager struct {
ticker *time.Ticker
db *IPListDB
db IPListDB
lastVersion int64
fetchPageSize int64
@@ -83,7 +86,7 @@ func (this *IPListManager) Start() {
case <-this.ticker.C:
case <-IPListUpdateNotify:
}
err := this.loop()
err = this.loop()
if err != nil {
countErrors++
@@ -110,7 +113,18 @@ func (this *IPListManager) Stop() {
func (this *IPListManager) init() {
// 从数据库中当中读取数据
db, err := NewIPListDB()
// 检查sqlite文件是否存在以便决定使用sqlite还是kv
var sqlitePath = Tea.Root + "/data/ip_list.db"
_, sqliteErr := os.Stat(sqlitePath)
var db IPListDB
var err error
if sqliteErr == nil {
db, err = NewIPListSqlite()
} else {
db, err = NewIPListKV()
}
if err != nil {
remotelogs.Error("IP_LIST_MANAGER", "create ip list local database failed: "+err.Error())
} else {
@@ -120,24 +134,30 @@ func (this *IPListManager) init() {
_ = db.DeleteExpiredItems()
// 本地数据库中最大版本号
this.lastVersion = db.ReadMaxVersion()
this.lastVersion, err = db.ReadMaxVersion()
if err != nil {
remotelogs.Error("IP_LIST_MANAGER", "find max version failed: "+err.Error())
this.lastVersion = 0
}
remotelogs.Println("IP_LIST_MANAGER", "starting from '"+db.Name()+"' version '"+types.String(this.lastVersion)+"' ...")
// 从本地数据库中加载
var offset int64 = 0
var size int64 = 2_000
var tr = trackers.Begin("IP_LIST_MANAGER:load")
defer tr.End()
for {
items, err := db.ReadItems(offset, size)
items, goNext, readErr := db.ReadItems(offset, size)
var l = len(items)
if err != nil {
remotelogs.Error("IP_LIST_MANAGER", "read ip list from local database failed: "+err.Error())
if readErr != nil {
remotelogs.Error("IP_LIST_MANAGER", "read ip list from local database failed: "+readErr.Error())
} else {
if l == 0 {
if !goNext {
break
}
this.processItems(items, false)
if int64(l) < size {
break
}
}
offset += int64(l)
}
@@ -310,9 +330,14 @@ func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) {
// 调试IP信息
func (this *IPListManager) debugItem(item *pb.IPItem) {
var ipRange = item.IpFrom
if len(item.IpTo) > 0 {
ipRange += " - " + item.IpTo
}
if item.IsDeleted {
remotelogs.Debug("IP_ITEM_DEBUG", "delete '"+item.IpFrom+"'")
remotelogs.Debug("IP_ITEM_DEBUG", "delete '"+ipRange+"'")
} else {
remotelogs.Debug("IP_ITEM_DEBUG", "add '"+item.IpFrom+"'")
remotelogs.Debug("IP_ITEM_DEBUG", "add '"+ipRange+"'")
}
}

View File

@@ -0,0 +1,28 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package byteutils
// Copy bytes
func Copy(b []byte) []byte {
var l = len(b)
if l == 0 {
return []byte{}
}
var d = make([]byte, l)
copy(d, b)
return d
}
// Append bytes
func Append(b []byte, b2 ...byte) []byte {
return append(Copy(b), b2...)
}
// Contact bytes
func Contact(b []byte, b2 ...[]byte) []byte {
b = Copy(b)
for _, b3 := range b2 {
b = append(b, b3...)
}
return b
}

View File

@@ -0,0 +1,56 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package byteutils_test
import (
"bytes"
byteutils "github.com/TeaOSLab/EdgeNode/internal/utils/byte"
"github.com/iwind/TeaGo/assert"
"testing"
)
func TestCopy(t *testing.T) {
var a = assert.NewAssertion(t)
var prefix []byte
prefix = append(prefix, 1, 2, 3)
t.Log(prefix, byteutils.Copy(prefix))
a.IsTrue(bytes.Equal(byteutils.Copy(prefix), []byte{1, 2, 3}))
}
func TestAppend(t *testing.T) {
var as = assert.NewAssertion(t)
var prefix []byte
prefix = append(prefix, 1, 2, 3)
// [1 2 3 4 5 6] [1 2 3 7]
var a = byteutils.Append(prefix, 4, 5, 6)
var b = byteutils.Append(prefix, 7)
t.Log(a, b)
as.IsTrue(bytes.Equal(a, []byte{1, 2, 3, 4, 5, 6}))
as.IsTrue(bytes.Equal(b, []byte{1, 2, 3, 7}))
}
func TestConcat(t *testing.T) {
var a = assert.NewAssertion(t)
var prefix []byte
prefix = append(prefix, 1, 2, 3)
var b = byteutils.Contact(prefix, []byte{4, 5, 6}, []byte{7})
t.Log(b)
a.IsTrue(bytes.Equal(b, []byte{1, 2, 3, 4, 5, 6, 7}))
}
func TestAppend_Raw(t *testing.T) {
var prefix []byte
prefix = append(prefix, 1, 2, 3)
// [1 2 3 7 5 6] [1 2 3 7]
var a = append(prefix, 4, 5, 6)
var b = append(prefix, 7)
t.Log(a, b)
}

View File

@@ -9,10 +9,16 @@ import (
var ErrTableNotFound = errors.New("table not found")
var ErrKeyTooLong = errors.New("too long key")
var ErrSkip= errors.New("skip") // skip count in iterator
func IsKeyNotFound(err error) bool {
if err == nil {
return false
}
return errors.Is(err, pebble.ErrNotFound)
func IsNotFound(err error) bool {
return err != nil && errors.Is(err, pebble.ErrNotFound)
}
func IsSkipError(err error) bool {
return err != nil && errors.Is(err, ErrSkip)
}
func Skip() (bool, error) {
return true, ErrSkip
}

View File

@@ -7,3 +7,7 @@ import "github.com/cockroachdb/pebble"
var DefaultWriteOptions = &pebble.WriteOptions{
Sync: false,
}
var DefaultWriteSyncOptions = &pebble.WriteOptions{
Sync: true,
}

View File

@@ -6,6 +6,7 @@ import (
"bytes"
"errors"
"fmt"
byteutils "github.com/TeaOSLab/EdgeNode/internal/utils/byte"
)
type DataType = int
@@ -222,11 +223,11 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
var prefix []byte
switch this.dataType {
case DataTypeKey:
prefix = append(this.table.Namespace(), KeyPrefix...)
prefix = byteutils.Append(this.table.Namespace(), []byte(KeyPrefix)...)
case DataTypeField:
prefix = append(this.table.Namespace(), FieldPrefix...)
prefix = byteutils.Append(this.table.Namespace(), []byte(FieldPrefix)...)
default:
prefix = append(this.table.Namespace(), KeyPrefix...)
prefix = byteutils.Append(this.table.Namespace(), []byte(KeyPrefix)...)
}
var prefixLen = len(prefix)
@@ -238,21 +239,21 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
var offsetKey []byte
if this.reverse {
if len(this.offsetKey) > 0 {
offsetKey = append(prefix, this.offsetKey...)
offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
} else {
offsetKey = append(prefix, 0xFF)
offsetKey = byteutils.Append(prefix, 0xFF)
}
opt.LowerBound = prefix
opt.UpperBound = offsetKey
} else {
if len(this.offsetKey) > 0 {
offsetKey = append(prefix, this.offsetKey...)
offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
} else {
offsetKey = prefix
}
opt.LowerBound = offsetKey
opt.UpperBound = append(offsetKey, 0xFF)
opt.UpperBound = byteutils.Append(prefix, 0xFF)
}
var hasOffsetKey = len(this.offsetKey) > 0
@@ -267,7 +268,7 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
var count int
var itemFn = func() (goNext bool, err error) {
var itemFn = func() (goNextItem bool, err error) {
var keyBytes = it.Key()
// skip first offset key
@@ -297,7 +298,11 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
Value: value,
})
if callbackErr != nil {
return false, callbackErr
if IsSkipError(callbackErr) {
return true, nil
} else {
return false, callbackErr
}
}
if !goNext {
return false, nil
@@ -361,9 +366,9 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
if len(this.fieldOffsetKey) > 0 {
offsetKey = this.fieldOffsetKey
} else if len(this.offsetKey) > 0 {
offsetKey = append(prefix, this.offsetKey...)
offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
} else {
offsetKey = append(prefix, 0xFF)
offsetKey = byteutils.Append(prefix, 0xFF)
}
opt.LowerBound = prefix
opt.UpperBound = offsetKey
@@ -371,14 +376,14 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
if len(this.fieldOffsetKey) > 0 {
offsetKey = this.fieldOffsetKey
} else if len(this.offsetKey) > 0 {
offsetKey = append(prefix, this.offsetKey...)
offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
offsetKey = append(offsetKey, 0)
} else {
offsetKey = prefix
}
opt.LowerBound = offsetKey
opt.UpperBound = append(prefix, 0xFF)
opt.UpperBound = byteutils.Append(prefix, 0xFF)
}
it, itErr := this.tx.NewIterator(opt)
@@ -391,7 +396,7 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
var count int
var itemFn = func() (goNext bool, err error) {
var itemFn = func() (goNextItem bool, err error) {
var fieldKeyBytes = it.Key()
fieldValueBytes, keyBytes, decodeKeyErr := this.table.DecodeFieldKey(this.fieldName, fieldKeyBytes)
@@ -423,7 +428,7 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
if !this.keysOnly {
value, getErr := this.table.getWithKeyBytes(this.tx, this.table.FullKeyBytes(keyBytes))
if getErr != nil {
if IsKeyNotFound(getErr) {
if IsNotFound(getErr) {
return true, nil
}
return false, getErr
@@ -432,11 +437,15 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
resultItem.Value = value
}
goNext, err = fn(this.tx, resultItem)
goNextItem, err = fn(this.tx, resultItem)
if err != nil {
return
if IsSkipError(err) {
return true, nil
} else {
return false, err
}
}
if !goNext {
if !goNextItem {
return false, nil
}

View File

@@ -138,6 +138,26 @@ func TestQuery_FindAll_Offset(t *testing.T) {
}
}
func TestQuery_FindAll_Skip(t *testing.T) {
var table = testOpenStoreTable[*testCachedItem](t, "cache_items", &testCacheItemEncoder[*testCachedItem]{})
{
err := table.Query().
Offset("a3").
Limit(10).
FindAll(func(tx *kvstore.Tx[*testCachedItem], item kvstore.Item[*testCachedItem]) (goNext bool, err error) {
if item.Key == "a30" || item.Key == "a3000005" {
return kvstore.Skip()
}
t.Log("key:", item.Key, "value:", item.Value)
return true, nil
})
if err != nil {
t.Fatal(err)
}
}
}
func TestQuery_FindAll_Count(t *testing.T) {
var table = testOpenStoreTable[*testCachedItem](t, "cache_items", &testCacheItemEncoder[*testCachedItem]{})

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
memutils "github.com/TeaOSLab/EdgeNode/internal/utils/mem"
"github.com/cockroachdb/pebble"
"github.com/iwind/TeaGo/Tea"
@@ -85,6 +86,31 @@ func OpenStoreDir(dir string, storeName string) (*Store, error) {
return store, nil
}
var storeOnce = &sync.Once{}
var defaultSore *Store
func DefaultStore() (*Store, error) {
if defaultSore != nil {
return defaultSore, nil
}
storeOnce.Do(func() {
store, err := NewStore("default")
if err != nil {
remotelogs.Error("KV", "create default store failed: "+err.Error())
return
}
err = store.Open()
if err != nil {
remotelogs.Error("KV", "open default store failed: "+err.Error())
return
}
defaultSore = store
})
return defaultSore, nil
}
func (this *Store) Open() error {
var opt = &pebble.Options{
Logger: NewLogger(),
@@ -144,6 +170,10 @@ func (this *Store) RawDB() *pebble.DB {
return this.rawDB
}
func (this *Store) Flush() error {
return this.rawDB.Flush()
}
func (this *Store) Close() error {
if this.isClosed {
return nil

View File

@@ -6,7 +6,9 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/utils/kvstore"
"github.com/cockroachdb/pebble"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/assert"
_ "github.com/iwind/TeaGo/bootstrap"
"sync"
"testing"
"time"
)
@@ -19,6 +21,44 @@ func TestMain(m *testing.M) {
}
}
func TestStore_Default(t *testing.T) {
var a = assert.NewAssertion(t)
store, err := kvstore.DefaultStore()
if err != nil {
t.Fatal(err)
}
a.IsTrue(store != nil)
}
func TestStore_Default_Concurrent(t *testing.T) {
var lastStore *kvstore.Store
const threads = 32
var wg = &sync.WaitGroup{}
wg.Add(threads)
for i := 0; i < threads; i++ {
go func() {
defer wg.Done()
store, err := kvstore.DefaultStore()
if err != nil {
t.Log("ERROR", err)
t.Fail()
}
if lastStore != nil && lastStore != store {
t.Log("ERROR", "should be single instance")
t.Fail()
}
lastStore = store
}()
}
wg.Wait()
}
func TestStore_Open(t *testing.T) {
store, err := kvstore.OpenStore("test")
if err != nil {
@@ -29,6 +69,7 @@ func TestStore_Open(t *testing.T) {
}()
t.Log("opened")
_ = store
}
func TestStore_RawDB(t *testing.T) {

View File

@@ -64,6 +64,10 @@ func (this *Table[T]) DB() *DB {
return this.db
}
func (this *Table[T]) Encoder() ValueEncoder[T] {
return this.encoder
}
func (this *Table[T]) Set(key string, value T) error {
if len(key) > KeyMaxLength {
return ErrKeyTooLong
@@ -75,7 +79,22 @@ func (this *Table[T]) Set(key string, value T) error {
}
return this.WriteTx(func(tx *Tx[T]) error {
return this.set(tx, key, valueBytes, value, false)
return this.set(tx, key, valueBytes, value, false, false)
})
}
func (this *Table[T]) SetSync(key string, value T) error {
if len(key) > KeyMaxLength {
return ErrKeyTooLong
}
valueBytes, err := this.encoder.Encode(value)
if err != nil {
return err
}
return this.WriteTxSync(func(tx *Tx[T]) error {
return this.set(tx, key, valueBytes, value, false, true)
})
}
@@ -90,7 +109,7 @@ func (this *Table[T]) Insert(key string, value T) error {
}
return this.WriteTx(func(tx *Tx[T]) error {
return this.set(tx, key, valueBytes, value, true)
return this.set(tx, key, valueBytes, value, true, false)
})
}
@@ -111,7 +130,7 @@ func (this *Table[T]) ComposeFieldKey(keyBytes []byte, fieldName string, fieldVa
func (this *Table[T]) Exist(key string) (found bool, err error) {
_, closer, err := this.db.store.rawDB.Get(this.FullKey(key))
if err != nil {
if IsKeyNotFound(err) {
if IsNotFound(err) {
return false, nil
}
return false, err
@@ -173,6 +192,20 @@ func (this *Table[T]) WriteTx(fn func(tx *Tx[T]) error) error {
return tx.Commit()
}
func (this *Table[T]) WriteTxSync(fn func(tx *Tx[T]) error) error {
var tx = NewTx[T](this, false)
defer func() {
_ = tx.Close()
}()
err := fn(tx)
if err != nil {
return err
}
return tx.CommitSync()
}
func (this *Table[T]) Truncate() error {
this.mu.Lock()
defer this.mu.Unlock()
@@ -256,7 +289,7 @@ func (this *Table[T]) deleteKeys(tx *Tx[T], key ...string) error {
if len(this.fieldNames) > 0 {
valueBytes, closer, getErr := batch.Get(keyBytes)
if getErr != nil {
if IsKeyNotFound(getErr) {
if IsNotFound(getErr) {
return nil
}
return getErr
@@ -298,8 +331,12 @@ func (this *Table[T]) deleteKeys(tx *Tx[T], key ...string) error {
return nil
}
func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, insertOnly bool) error {
func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, insertOnly bool, syncMode bool) error {
var keyBytes = this.FullKey(key)
var writeOptions = DefaultWriteOptions
if syncMode {
writeOptions = DefaultWriteSyncOptions
}
var batch = tx.batch
@@ -312,7 +349,7 @@ func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, ins
if countFields > 0 {
oldValueBytes, closer, getErr := batch.Get(keyBytes)
if getErr != nil {
if !IsKeyNotFound(getErr) {
if !IsNotFound(getErr) {
return getErr
}
} else {
@@ -330,7 +367,7 @@ func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, ins
}
}
setErr := batch.Set(keyBytes, valueBytes, DefaultWriteOptions)
setErr := batch.Set(keyBytes, valueBytes, writeOptions)
if setErr != nil {
return setErr
}
@@ -362,14 +399,14 @@ func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, ins
// skip the field
continue
}
deleteFieldErr := batch.Delete(oldFieldKeyBytes, DefaultWriteOptions)
deleteFieldErr := batch.Delete(oldFieldKeyBytes, writeOptions)
if deleteFieldErr != nil {
return deleteFieldErr
}
}
// set new field key
setFieldErr := batch.Set(newFieldKeyBytes, nil, DefaultWriteOptions)
setFieldErr := batch.Set(newFieldKeyBytes, nil, writeOptions)
if setFieldErr != nil {
return setFieldErr
}

View File

@@ -21,7 +21,7 @@ func (this *CounterTable[T]) Increase(key string, delta T) (newValue T, err erro
err = this.Table.WriteTx(func(tx *Tx[T]) error {
value, getErr := tx.Get(key)
if getErr != nil {
if !IsKeyNotFound(getErr) {
if !IsNotFound(getErr) {
return getErr
}
}

View File

@@ -45,7 +45,7 @@ func TestTable_Set(t *testing.T) {
value, err := table.Get("a")
if err != nil {
if kvstore.IsKeyNotFound(err) {
if kvstore.IsNotFound(err) {
t.Log("not found key")
return
}
@@ -81,7 +81,7 @@ func TestTable_Get(t *testing.T) {
for _, key := range []string{"a", "b", "c"} {
value, getErr := table.Get(key)
if getErr != nil {
if kvstore.IsKeyNotFound(getErr) {
if kvstore.IsNotFound(getErr) {
t.Log("not found key", key)
continue
}
@@ -146,7 +146,7 @@ func TestTable_Delete(t *testing.T) {
value, err := table.Get("a123")
if err != nil {
if !kvstore.IsKeyNotFound(err) {
if !kvstore.IsNotFound(err) {
t.Fatal(err)
}
} else {
@@ -173,7 +173,7 @@ func TestTable_Delete(t *testing.T) {
{
_, err = table.Get("a123")
a.IsTrue(kvstore.IsKeyNotFound(err))
a.IsTrue(kvstore.IsNotFound(err))
}
}
@@ -357,7 +357,7 @@ func BenchmarkTable_Get(b *testing.B) {
for pb.Next() {
_, putErr := table.Get(types.String(rand.Int()))
if putErr != nil {
if kvstore.IsKeyNotFound(putErr) {
if kvstore.IsNotFound(putErr) {
continue
}
b.Fatal(putErr)

View File

@@ -37,7 +37,24 @@ func (this *Tx[T]) Set(key string, value T) error {
return err
}
return this.table.set(this, key, valueBytes, value, false)
return this.table.set(this, key, valueBytes, value, false, false)
}
func (this *Tx[T]) SetSync(key string, value T) error {
if this.readOnly {
return errors.New("can not set value in readonly transaction")
}
if len(key) > KeyMaxLength {
return ErrKeyTooLong
}
valueBytes, err := this.table.encoder.Encode(value)
if err != nil {
return err
}
return this.table.set(this, key, valueBytes, value, false, true)
}
func (this *Tx[T]) Insert(key string, value T) error {
@@ -54,7 +71,7 @@ func (this *Tx[T]) Insert(key string, value T) error {
return err
}
return this.table.set(this, key, valueBytes, value, true)
return this.table.set(this, key, valueBytes, value, true, false)
}
func (this *Tx[T]) Get(key string) (value T, err error) {
@@ -78,6 +95,20 @@ func (this *Tx[T]) Close() error {
}
func (this *Tx[T]) Commit() (err error) {
return this.commit(DefaultWriteOptions)
}
func (this *Tx[T]) CommitSync() (err error) {
return this.commit(DefaultWriteSyncOptions)
}
func (this *Tx[T]) Query() *Query[T] {
var query = NewQuery[T]()
query.SetTx(this)
return query
}
func (this *Tx[T]) commit(opt *pebble.WriteOptions) (err error) {
defer func() {
var panicErr = recover()
if panicErr != nil {
@@ -88,11 +119,5 @@ func (this *Tx[T]) Commit() (err error) {
}
}()
return this.batch.Commit(DefaultWriteOptions)
}
func (this *Tx[T]) Query() *Query[T] {
var query = NewQuery[T]()
query.SetTx(this)
return query
return this.batch.Commit(opt)
}

View File

@@ -2,7 +2,11 @@
package testutils
import "os"
import (
"fmt"
"math/rand"
"os"
)
// IsSingleTesting 判断当前测试环境是否为单个函数测试
func IsSingleTesting() bool {
@@ -12,4 +16,9 @@ func IsSingleTesting() bool {
}
}
return false
}
}
// RandIP 生成一个随机IP用于测试
func RandIP() string {
return fmt.Sprintf("%d.%d.%d.%d", rand.Int()%255, rand.Int()%255, rand.Int()%255, rand.Int()%255)
}