使用KV数据库来管理IP名单

This commit is contained in:
刘祥超
2024-03-31 10:08:53 +08:00
parent b4995868c9
commit d2e9c8c10f
21 changed files with 1184 additions and 368 deletions

View File

@@ -97,7 +97,7 @@ func (this *KVListFileStore) ExistItem(hash string) (bool, error) {
item, err := this.itemsTable.Get(hash) item, err := this.itemsTable.Get(hash)
if err != nil { if err != nil {
if kvstore.IsKeyNotFound(err) { if kvstore.IsNotFound(err) {
return false, nil return false, nil
} }
return false, err return false, err

View File

@@ -1,305 +1,13 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. // Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package iplibrary package iplibrary
import ( import "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils/dbs"
"github.com/iwind/TeaGo/Tea"
"os"
"path/filepath"
"time"
)
type IPListDB struct { type IPListDB interface {
db *dbs.DB Name() string
DeleteExpiredItems() error
itemTableName string ReadMaxVersion() (int64, error)
versionTableName string ReadItems(offset int64, size int64) (items []*pb.IPItem, goNext bool, err error)
AddItem(item *pb.IPItem) error
deleteExpiredItemsStmt *dbs.Stmt
deleteItemStmt *dbs.Stmt
insertItemStmt *dbs.Stmt
selectItemsStmt *dbs.Stmt
selectMaxItemVersionStmt *dbs.Stmt
selectVersionStmt *dbs.Stmt
updateVersionStmt *dbs.Stmt
cleanTicker *time.Ticker
dir string
isClosed bool
}
func NewIPListDB() (*IPListDB, error) {
var db = &IPListDB{
itemTableName: "ipItems",
versionTableName: "versions",
dir: filepath.Clean(Tea.Root + "/data"),
cleanTicker: time.NewTicker(24 * time.Hour),
}
err := db.init()
return db, err
}
func (this *IPListDB) init() error {
// 检查目录是否存在
_, err := os.Stat(this.dir)
if err != nil {
err = os.MkdirAll(this.dir, 0777)
if err != nil {
return err
}
remotelogs.Println("IP_LIST_DB", "create data dir '"+this.dir+"'")
}
var path = this.dir + "/ip_list.db"
db, err := dbs.OpenWriter("file:" + path + "?cache=shared&mode=rwc&_journal_mode=WAL&_sync=" + dbs.SyncMode + "&_locking_mode=EXCLUSIVE")
if err != nil {
return err
}
db.SetMaxOpenConns(1)
//_, err = db.Exec("VACUUM")
//if err != nil {
// return err
//}
this.db = db
// 恢复数据库
var recoverEnv, _ = os.LookupEnv("EdgeRecover")
if len(recoverEnv) > 0 {
for _, indexName := range []string{"ip_list_itemId", "ip_list_expiredAt"} {
_, _ = db.Exec(`REINDEX "` + indexName + `"`)
}
}
// 初始化数据库
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemTableName + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"listId" integer DEFAULT 0,
"listType" varchar(32),
"isGlobal" integer(1) DEFAULT 0,
"type" varchar(16),
"itemId" integer DEFAULT 0,
"ipFrom" varchar(64) DEFAULT 0,
"ipTo" varchar(64) DEFAULT 0,
"expiredAt" integer DEFAULT 0,
"eventLevel" varchar(32),
"isDeleted" integer(1) DEFAULT 0,
"version" integer DEFAULT 0,
"nodeId" integer DEFAULT 0,
"serverId" integer DEFAULT 0
);
CREATE INDEX IF NOT EXISTS "ip_list_itemId"
ON "` + this.itemTableName + `" (
"itemId" ASC
);
CREATE INDEX IF NOT EXISTS "ip_list_expiredAt"
ON "` + this.itemTableName + `" (
"expiredAt" ASC
);
`)
if err != nil {
return err
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.versionTableName + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"version" integer DEFAULT 0
);
`)
if err != nil {
return err
}
// 初始化SQL语句
this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "expiredAt">0 AND "expiredAt"<?`)
if err != nil {
return err
}
this.deleteItemStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "itemId"=?`)
if err != nil {
return err
}
this.insertItemStmt, err = this.db.Prepare(`INSERT INTO "` + this.itemTableName + `" ("listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
if err != nil {
return err
}
this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" WHERE isDeleted=0 ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`)
if err != nil {
return err
}
this.selectMaxItemVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`)
if err != nil {
return err
}
this.selectVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.versionTableName + `" LIMIT 1`)
if err != nil {
return err
}
this.updateVersionStmt, err = this.db.Prepare(`REPLACE INTO "` + this.versionTableName + `" ("id", "version") VALUES (1, ?)`)
if err != nil {
return err
}
this.db = db
goman.New(func() {
events.OnClose(func() {
_ = this.Close()
this.cleanTicker.Stop()
})
for range this.cleanTicker.C {
err := this.DeleteExpiredItems()
if err != nil {
remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+err.Error())
}
}
})
return nil
}
// DeleteExpiredItems 删除过期的条目
func (this *IPListDB) DeleteExpiredItems() error {
if this.isClosed {
return nil
}
_, err := this.deleteExpiredItemsStmt.Exec(time.Now().Unix() - 7*86400)
return err
}
func (this *IPListDB) AddItem(item *pb.IPItem) error {
if this.isClosed {
return nil
}
_, err := this.deleteItemStmt.Exec(item.Id)
if err != nil {
return err
}
// 如果是删除,则不再创建新记录
if item.IsDeleted {
return this.UpdateMaxVersion(item.Version)
}
_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
if err != nil {
return err
}
return this.UpdateMaxVersion(item.Version)
}
func (this *IPListDB) ReadItems(offset int64, size int64) (items []*pb.IPItem, err error) {
if this.isClosed {
return
}
rows, err := this.selectItemsStmt.Query(offset, size)
if err != nil {
return nil, err
}
defer func() {
_ = rows.Close()
}()
for rows.Next() {
// "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId"
var pbItem = &pb.IPItem{}
err = rows.Scan(&pbItem.ListId, &pbItem.ListType, &pbItem.IsGlobal, &pbItem.Type, &pbItem.Id, &pbItem.IpFrom, &pbItem.IpTo, &pbItem.ExpiredAt, &pbItem.EventLevel, &pbItem.IsDeleted, &pbItem.Version, &pbItem.NodeId, &pbItem.ServerId)
if err != nil {
return nil, err
}
items = append(items, pbItem)
}
return
}
// ReadMaxVersion 读取当前最大版本号
func (this *IPListDB) ReadMaxVersion() int64 {
if this.isClosed {
return 0
}
// from version table
{
var row = this.selectVersionStmt.QueryRow()
if row == nil {
return 0
}
var version int64
err := row.Scan(&version)
if err == nil {
return version
}
}
// from items table
{
var row = this.selectMaxItemVersionStmt.QueryRow()
if row == nil {
return 0
}
var version int64
err := row.Scan(&version)
if err != nil {
return 0
}
return version
}
}
// UpdateMaxVersion 修改版本号
func (this *IPListDB) UpdateMaxVersion(version int64) error {
if this.isClosed {
return nil
}
_, err := this.updateVersionStmt.Exec(version)
return err
}
func (this *IPListDB) Close() error {
this.isClosed = true
if this.db != nil {
for _, stmt := range []*dbs.Stmt{
this.deleteExpiredItemsStmt,
this.deleteItemStmt,
this.insertItemStmt,
this.selectItemsStmt,
this.selectMaxItemVersionStmt, // ipItems table
this.selectVersionStmt, // versions table
this.updateVersionStmt,
} {
if stmt != nil {
_ = stmt.Close()
}
}
return this.db.Close()
}
return nil
} }

View File

@@ -0,0 +1,229 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package iplibrary
import (
"encoding/binary"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/utils/kvstore"
"testing"
"time"
)
type IPListKV struct {
ipTable *kvstore.Table[*pb.IPItem]
versionsTable *kvstore.Table[int64]
encoder *IPItemEncoder[*pb.IPItem]
cleanTicker *time.Ticker
isClosed bool
offsetItemKey string
}
func NewIPListKV() (*IPListKV, error) {
var db = &IPListKV{
cleanTicker: time.NewTicker(24 * time.Hour),
encoder: &IPItemEncoder[*pb.IPItem]{},
}
err := db.init()
return db, err
}
func (this *IPListKV) init() error {
store, storeErr := kvstore.DefaultStore()
if storeErr != nil {
return storeErr
}
db, dbErr := store.NewDB("ip_list")
if dbErr != nil {
return dbErr
}
{
table, err := kvstore.NewTable[*pb.IPItem]("ip_items", this.encoder)
if err != nil {
return err
}
this.ipTable = table
err = table.AddFields("expiresAt")
if err != nil {
return err
}
db.AddTable(table)
}
{
table, err := kvstore.NewTable[int64]("versions", kvstore.NewIntValueEncoder[int64]())
if err != nil {
return err
}
this.versionsTable = table
db.AddTable(table)
}
goman.New(func() {
events.OnClose(func() {
_ = this.Close()
this.cleanTicker.Stop()
})
for range this.cleanTicker.C {
deleteErr := this.DeleteExpiredItems()
if deleteErr != nil {
remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+deleteErr.Error())
}
}
})
return nil
}
// Name 数据库名称代号
func (this *IPListKV) Name() string {
return "kvstore"
}
// DeleteExpiredItems 删除过期的条目
func (this *IPListKV) DeleteExpiredItems() error {
if this.isClosed {
return nil
}
for {
var found bool
var currentTime = fasttime.Now().Unix()
err := this.ipTable.
Query().
FieldAsc("expiresAt").
ForUpdate().
Limit(1000).
FindAll(func(tx *kvstore.Tx[*pb.IPItem], item kvstore.Item[*pb.IPItem]) (goNext bool, err error) {
if !item.Value.IsDeleted && item.Value.ExpiredAt == 0 { // never expires
return kvstore.Skip()
}
if item.Value.ExpiredAt < currentTime-7*86400 /** keep for 7 days **/ {
err = tx.Delete(item.Key)
if err != nil {
return false, err
}
found = true
return true, nil
}
found = false
return false, nil
})
if err != nil {
return err
}
if !found {
break
}
}
return nil
}
func (this *IPListKV) AddItem(item *pb.IPItem) error {
if this.isClosed {
return nil
}
// 先删除
var key = this.encoder.EncodeKey(item)
err := this.ipTable.Delete(key)
if err != nil {
return err
}
// 如果是删除,则不再创建新记录
if item.IsDeleted {
return this.UpdateMaxVersion(item.Version)
}
err = this.ipTable.Set(key, item)
if err != nil {
return err
}
return this.UpdateMaxVersion(item.Version)
}
func (this *IPListKV) ReadItems(offset int64, size int64) (items []*pb.IPItem, goNextLoop bool, err error) {
if this.isClosed {
return
}
err = this.ipTable.
Query().
Offset(this.offsetItemKey).
Limit(int(size)).
FindAll(func(tx *kvstore.Tx[*pb.IPItem], item kvstore.Item[*pb.IPItem]) (goNext bool, err error) {
this.offsetItemKey = item.Key
goNextLoop = true
if !item.Value.IsDeleted {
items = append(items, item.Value)
}
return true, nil
})
return
}
// ReadMaxVersion 读取当前最大版本号
func (this *IPListKV) ReadMaxVersion() (int64, error) {
if this.isClosed {
return 0, errors.New("database has been closed")
}
version, err := this.versionsTable.Get("version")
if err != nil {
if kvstore.IsNotFound(err) {
return 0, nil
}
return 0, err
}
return version, nil
}
// UpdateMaxVersion 修改版本号
func (this *IPListKV) UpdateMaxVersion(version int64) error {
if this.isClosed {
return nil
}
return this.versionsTable.SetSync("version", version)
}
func (this *IPListKV) TestInspect(t *testing.T) error {
return this.ipTable.
Query().
FindAll(func(tx *kvstore.Tx[*pb.IPItem], item kvstore.Item[*pb.IPItem]) (goNext bool, err error) {
if len(item.Key) != 8 {
return false, errors.New("invalid key '" + item.Key + "'")
}
t.Log(binary.BigEndian.Uint64([]byte(item.Key)), "=>", item.Value)
return true, nil
})
}
// Flush to disk
func (this *IPListKV) Flush() error {
return this.ipTable.DB().Store().Flush()
}
func (this *IPListKV) Close() error {
this.isClosed = true
return nil
}

View File

@@ -0,0 +1,55 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package iplibrary
import (
"encoding/binary"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"google.golang.org/protobuf/proto"
"math"
)
type IPItemEncoder[T interface{ *pb.IPItem }] struct {
}
func NewIPItemEncoder[T interface{ *pb.IPItem }]() *IPItemEncoder[T] {
return &IPItemEncoder[T]{}
}
func (this *IPItemEncoder[T]) Encode(value T) ([]byte, error) {
return proto.Marshal(any(value).(*pb.IPItem))
}
func (this *IPItemEncoder[T]) EncodeField(value T, fieldName string) ([]byte, error) {
switch fieldName {
case "expiresAt":
var expiresAt = any(value).(*pb.IPItem).ExpiredAt
if expiresAt < 0 || expiresAt > int64(math.MaxUint32) {
expiresAt = 0
}
var b = make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(expiresAt))
return b, nil
}
return nil, errors.New("field '" + fieldName + "' not found")
}
func (this *IPItemEncoder[T]) Decode(valueBytes []byte) (value T, err error) {
var item = &pb.IPItem{}
err = proto.Unmarshal(valueBytes, item)
value = item
return
}
// EncodeKey generate key for ip item
func (this *IPItemEncoder[T]) EncodeKey(item *pb.IPItem) string {
var b = make([]byte, 8)
if item.Id < 0 {
item.Id = 0
}
binary.BigEndian.PutUint64(b, uint64(item.Id))
return string(b)
}

View File

@@ -0,0 +1,221 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package iplibrary_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/TeaOSLab/EdgeNode/internal/zero"
"testing"
"time"
)
func TestIPListKV_AddItem(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = kv.Flush()
}()
{
err = kv.AddItem(&pb.IPItem{
Id: 1,
IpFrom: "192.168.1.101",
IpTo: "",
Version: 1,
ExpiredAt: fasttime.NewFastTime().Unix() + 60,
ListId: 1,
IsDeleted: false,
ListType: "white",
})
if err != nil {
t.Fatal(err)
}
}
{
err = kv.AddItem(&pb.IPItem{
Id: 2,
IpFrom: "192.168.1.102",
IpTo: "",
Version: 2,
ExpiredAt: fasttime.NewFastTime().Unix() + 60,
ListId: 1,
IsDeleted: false,
ListType: "white",
})
if err != nil {
t.Fatal(err)
}
}
{
err = kv.AddItem(&pb.IPItem{
Id: 3,
IpFrom: "192.168.1.103",
IpTo: "",
Version: 3,
ExpiredAt: fasttime.NewFastTime().Unix() + 60,
ListId: 1,
IsDeleted: false,
ListType: "white",
})
if err != nil {
t.Fatal(err)
}
}
}
func TestIPListKV_AddItems_Many(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = kv.Flush()
}()
var count = 2
var from = 1
if testutils.IsSingleTesting() {
count = 2_000_000
}
var before = time.Now()
defer func() {
t.Logf("cost: %.2f s", time.Since(before).Seconds())
}()
for i := from; i <= from+count; i++ {
err = kv.AddItem(&pb.IPItem{
Id: int64(i),
IpFrom: testutils.RandIP(),
IpTo: "",
Version: int64(i),
ExpiredAt: fasttime.NewFastTime().Unix() + 86400,
ListId: 1,
IsDeleted: false,
ListType: "white",
})
if err != nil {
t.Fatal(err)
}
}
}
func TestIPListKV_DeleteExpiredItems(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = kv.Flush()
}()
err = kv.DeleteExpiredItems()
if err != nil {
t.Fatal(err)
}
}
func TestIPListKV_UpdateMaxVersion(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = kv.Flush()
}()
err = kv.UpdateMaxVersion(101)
if err != nil {
t.Fatal(err)
}
maxVersion, err := kv.ReadMaxVersion()
if err != nil {
t.Fatal(err)
}
t.Log("version:", maxVersion)
}
func TestIPListKV_ReadMaxVersion(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
maxVersion, err := kv.ReadMaxVersion()
if err != nil {
t.Fatal(err)
}
t.Log("version:", maxVersion)
}
func TestIPListKV_ReadItems(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
for {
items, goNext, readErr := kv.ReadItems(0, 2)
if readErr != nil {
t.Fatal(readErr)
}
t.Log("====")
for _, item := range items {
t.Log(item.Id)
}
if !goNext {
break
}
}
}
func TestIPListKV_CountItems(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
var count int
var m = map[int64]zero.Zero{}
for {
items, goNext, readErr := kv.ReadItems(0, 1000)
if readErr != nil {
t.Fatal(readErr)
}
for _, item := range items {
count++
m[item.Id] = zero.Zero{}
}
if !goNext {
break
}
}
t.Log("count:", count, "len:", len(m))
}
func TestIPListKV_Inspect(t *testing.T) {
kv, err := iplibrary.NewIPListKV()
if err != nil {
t.Fatal(err)
}
err = kv.TestInspect(t)
if err != nil {
t.Fatal(err)
}
}

View File

@@ -0,0 +1,312 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils/dbs"
"github.com/iwind/TeaGo/Tea"
"os"
"path/filepath"
"time"
)
type IPListSQLite struct {
db *dbs.DB
itemTableName string
versionTableName string
deleteExpiredItemsStmt *dbs.Stmt
deleteItemStmt *dbs.Stmt
insertItemStmt *dbs.Stmt
selectItemsStmt *dbs.Stmt
selectMaxItemVersionStmt *dbs.Stmt
selectVersionStmt *dbs.Stmt
updateVersionStmt *dbs.Stmt
cleanTicker *time.Ticker
dir string
isClosed bool
}
func NewIPListSqlite() (*IPListSQLite, error) {
var db = &IPListSQLite{
itemTableName: "ipItems",
versionTableName: "versions",
dir: filepath.Clean(Tea.Root + "/data"),
cleanTicker: time.NewTicker(24 * time.Hour),
}
err := db.init()
return db, err
}
func (this *IPListSQLite) init() error {
// 检查目录是否存在
_, err := os.Stat(this.dir)
if err != nil {
err = os.MkdirAll(this.dir, 0777)
if err != nil {
return err
}
remotelogs.Println("IP_LIST_DB", "create data dir '"+this.dir+"'")
}
var path = this.dir + "/ip_list.db"
db, err := dbs.OpenWriter("file:" + path + "?cache=shared&mode=rwc&_journal_mode=WAL&_sync=" + dbs.SyncMode + "&_locking_mode=EXCLUSIVE")
if err != nil {
return err
}
db.SetMaxOpenConns(1)
//_, err = db.Exec("VACUUM")
//if err != nil {
// return err
//}
this.db = db
// 恢复数据库
var recoverEnv, _ = os.LookupEnv("EdgeRecover")
if len(recoverEnv) > 0 {
for _, indexName := range []string{"ip_list_itemId", "ip_list_expiredAt"} {
_, _ = db.Exec(`REINDEX "` + indexName + `"`)
}
}
// 初始化数据库
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemTableName + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"listId" integer DEFAULT 0,
"listType" varchar(32),
"isGlobal" integer(1) DEFAULT 0,
"type" varchar(16),
"itemId" integer DEFAULT 0,
"ipFrom" varchar(64) DEFAULT 0,
"ipTo" varchar(64) DEFAULT 0,
"expiredAt" integer DEFAULT 0,
"eventLevel" varchar(32),
"isDeleted" integer(1) DEFAULT 0,
"version" integer DEFAULT 0,
"nodeId" integer DEFAULT 0,
"serverId" integer DEFAULT 0
);
CREATE INDEX IF NOT EXISTS "ip_list_itemId"
ON "` + this.itemTableName + `" (
"itemId" ASC
);
CREATE INDEX IF NOT EXISTS "ip_list_expiredAt"
ON "` + this.itemTableName + `" (
"expiredAt" ASC
);
`)
if err != nil {
return err
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.versionTableName + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"version" integer DEFAULT 0
);
`)
if err != nil {
return err
}
// 初始化SQL语句
this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "expiredAt">0 AND "expiredAt"<?`)
if err != nil {
return err
}
this.deleteItemStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "itemId"=?`)
if err != nil {
return err
}
this.insertItemStmt, err = this.db.Prepare(`INSERT INTO "` + this.itemTableName + `" ("listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
if err != nil {
return err
}
this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" WHERE isDeleted=0 ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`)
if err != nil {
return err
}
this.selectMaxItemVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`)
if err != nil {
return err
}
this.selectVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.versionTableName + `" LIMIT 1`)
if err != nil {
return err
}
this.updateVersionStmt, err = this.db.Prepare(`REPLACE INTO "` + this.versionTableName + `" ("id", "version") VALUES (1, ?)`)
if err != nil {
return err
}
this.db = db
goman.New(func() {
events.OnClose(func() {
_ = this.Close()
this.cleanTicker.Stop()
})
for range this.cleanTicker.C {
err := this.DeleteExpiredItems()
if err != nil {
remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+err.Error())
}
}
})
return nil
}
// Name 数据库名称代号
func (this *IPListSQLite) Name() string {
return "sqlite"
}
// DeleteExpiredItems 删除过期的条目
func (this *IPListSQLite) DeleteExpiredItems() error {
if this.isClosed {
return nil
}
_, err := this.deleteExpiredItemsStmt.Exec(time.Now().Unix() - 7*86400)
return err
}
func (this *IPListSQLite) AddItem(item *pb.IPItem) error {
if this.isClosed {
return nil
}
_, err := this.deleteItemStmt.Exec(item.Id)
if err != nil {
return err
}
// 如果是删除,则不再创建新记录
if item.IsDeleted {
return this.UpdateMaxVersion(item.Version)
}
_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
if err != nil {
return err
}
return this.UpdateMaxVersion(item.Version)
}
func (this *IPListSQLite) ReadItems(offset int64, size int64) (items []*pb.IPItem, goNext bool, err error) {
if this.isClosed {
return
}
rows, err := this.selectItemsStmt.Query(offset, size)
if err != nil {
return nil, false, err
}
defer func() {
_ = rows.Close()
}()
for rows.Next() {
// "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId"
var pbItem = &pb.IPItem{}
err = rows.Scan(&pbItem.ListId, &pbItem.ListType, &pbItem.IsGlobal, &pbItem.Type, &pbItem.Id, &pbItem.IpFrom, &pbItem.IpTo, &pbItem.ExpiredAt, &pbItem.EventLevel, &pbItem.IsDeleted, &pbItem.Version, &pbItem.NodeId, &pbItem.ServerId)
if err != nil {
return nil, false, err
}
items = append(items, pbItem)
}
goNext = int64(len(items)) == size
return
}
// ReadMaxVersion 读取当前最大版本号
func (this *IPListSQLite) ReadMaxVersion() (int64, error) {
if this.isClosed {
return 0, nil
}
// from version table
{
var row = this.selectVersionStmt.QueryRow()
if row == nil {
return 0, nil
}
var version int64
err := row.Scan(&version)
if err == nil {
return version, nil
}
}
// from items table
{
var row = this.selectMaxItemVersionStmt.QueryRow()
if row == nil {
return 0, nil
}
var version int64
err := row.Scan(&version)
if err != nil {
return 0, nil
}
return version, nil
}
}
// UpdateMaxVersion 修改版本号
func (this *IPListSQLite) UpdateMaxVersion(version int64) error {
if this.isClosed {
return nil
}
_, err := this.updateVersionStmt.Exec(version)
return err
}
func (this *IPListSQLite) Close() error {
this.isClosed = true
if this.db != nil {
for _, stmt := range []*dbs.Stmt{
this.deleteExpiredItemsStmt,
this.deleteItemStmt,
this.insertItemStmt,
this.selectItemsStmt,
this.selectMaxItemVersionStmt, // ipItems table
this.selectVersionStmt, // versions table
this.updateVersionStmt,
} {
if stmt != nil {
_ = stmt.Close()
}
}
return this.db.Close()
}
return nil
}

View File

@@ -12,7 +12,7 @@ import (
) )
func TestIPListDB_AddItem(t *testing.T) { func TestIPListDB_AddItem(t *testing.T) {
db, err := iplibrary.NewIPListDB() db, err := iplibrary.NewIPListSqlite()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -59,7 +59,7 @@ func TestIPListDB_AddItem(t *testing.T) {
} }
func TestIPListDB_ReadItems(t *testing.T) { func TestIPListDB_ReadItems(t *testing.T) {
db, err := iplibrary.NewIPListDB() db, err := iplibrary.NewIPListSqlite()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -71,15 +71,16 @@ func TestIPListDB_ReadItems(t *testing.T) {
_ = db.Close() _ = db.Close()
}() }()
items, err := db.ReadItems(0, 2) items, goNext, err := db.ReadItems(0, 2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Log("goNext:", goNext)
logs.PrintAsJSON(items, t) logs.PrintAsJSON(items, t)
} }
func TestIPListDB_ReadMaxVersion(t *testing.T) { func TestIPListDB_ReadMaxVersion(t *testing.T) {
db, err := iplibrary.NewIPListDB() db, err := iplibrary.NewIPListSqlite()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -90,7 +91,7 @@ func TestIPListDB_ReadMaxVersion(t *testing.T) {
} }
func TestIPListDB_UpdateMaxVersion(t *testing.T) { func TestIPListDB_UpdateMaxVersion(t *testing.T) {
db, err := iplibrary.NewIPListDB() db, err := iplibrary.NewIPListSqlite()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -8,10 +8,13 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/goman" "github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc" "github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/trackers"
"github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/TeaOSLab/EdgeNode/internal/zero" "github.com/TeaOSLab/EdgeNode/internal/zero"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/types"
"os"
"sync" "sync"
"time" "time"
) )
@@ -45,7 +48,7 @@ func init() {
type IPListManager struct { type IPListManager struct {
ticker *time.Ticker ticker *time.Ticker
db *IPListDB db IPListDB
lastVersion int64 lastVersion int64
fetchPageSize int64 fetchPageSize int64
@@ -83,7 +86,7 @@ func (this *IPListManager) Start() {
case <-this.ticker.C: case <-this.ticker.C:
case <-IPListUpdateNotify: case <-IPListUpdateNotify:
} }
err := this.loop() err = this.loop()
if err != nil { if err != nil {
countErrors++ countErrors++
@@ -110,7 +113,18 @@ func (this *IPListManager) Stop() {
func (this *IPListManager) init() { func (this *IPListManager) init() {
// 从数据库中当中读取数据 // 从数据库中当中读取数据
db, err := NewIPListDB() // 检查sqlite文件是否存在以便决定使用sqlite还是kv
var sqlitePath = Tea.Root + "/data/ip_list.db"
_, sqliteErr := os.Stat(sqlitePath)
var db IPListDB
var err error
if sqliteErr == nil {
db, err = NewIPListSqlite()
} else {
db, err = NewIPListKV()
}
if err != nil { if err != nil {
remotelogs.Error("IP_LIST_MANAGER", "create ip list local database failed: "+err.Error()) remotelogs.Error("IP_LIST_MANAGER", "create ip list local database failed: "+err.Error())
} else { } else {
@@ -120,24 +134,30 @@ func (this *IPListManager) init() {
_ = db.DeleteExpiredItems() _ = db.DeleteExpiredItems()
// 本地数据库中最大版本号 // 本地数据库中最大版本号
this.lastVersion = db.ReadMaxVersion() this.lastVersion, err = db.ReadMaxVersion()
if err != nil {
remotelogs.Error("IP_LIST_MANAGER", "find max version failed: "+err.Error())
this.lastVersion = 0
}
remotelogs.Println("IP_LIST_MANAGER", "starting from '"+db.Name()+"' version '"+types.String(this.lastVersion)+"' ...")
// 从本地数据库中加载 // 从本地数据库中加载
var offset int64 = 0 var offset int64 = 0
var size int64 = 2_000 var size int64 = 2_000
var tr = trackers.Begin("IP_LIST_MANAGER:load")
defer tr.End()
for { for {
items, err := db.ReadItems(offset, size) items, goNext, readErr := db.ReadItems(offset, size)
var l = len(items) var l = len(items)
if err != nil { if readErr != nil {
remotelogs.Error("IP_LIST_MANAGER", "read ip list from local database failed: "+err.Error()) remotelogs.Error("IP_LIST_MANAGER", "read ip list from local database failed: "+readErr.Error())
} else { } else {
if l == 0 { if !goNext {
break break
} }
this.processItems(items, false) this.processItems(items, false)
if int64(l) < size {
break
}
} }
offset += int64(l) offset += int64(l)
} }
@@ -310,9 +330,14 @@ func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) {
// 调试IP信息 // 调试IP信息
func (this *IPListManager) debugItem(item *pb.IPItem) { func (this *IPListManager) debugItem(item *pb.IPItem) {
var ipRange = item.IpFrom
if len(item.IpTo) > 0 {
ipRange += " - " + item.IpTo
}
if item.IsDeleted { if item.IsDeleted {
remotelogs.Debug("IP_ITEM_DEBUG", "delete '"+item.IpFrom+"'") remotelogs.Debug("IP_ITEM_DEBUG", "delete '"+ipRange+"'")
} else { } else {
remotelogs.Debug("IP_ITEM_DEBUG", "add '"+item.IpFrom+"'") remotelogs.Debug("IP_ITEM_DEBUG", "add '"+ipRange+"'")
} }
} }

View File

@@ -0,0 +1,28 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package byteutils
// Copy bytes
func Copy(b []byte) []byte {
var l = len(b)
if l == 0 {
return []byte{}
}
var d = make([]byte, l)
copy(d, b)
return d
}
// Append bytes
func Append(b []byte, b2 ...byte) []byte {
return append(Copy(b), b2...)
}
// Contact bytes
func Contact(b []byte, b2 ...[]byte) []byte {
b = Copy(b)
for _, b3 := range b2 {
b = append(b, b3...)
}
return b
}

View File

@@ -0,0 +1,56 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package byteutils_test
import (
"bytes"
byteutils "github.com/TeaOSLab/EdgeNode/internal/utils/byte"
"github.com/iwind/TeaGo/assert"
"testing"
)
func TestCopy(t *testing.T) {
var a = assert.NewAssertion(t)
var prefix []byte
prefix = append(prefix, 1, 2, 3)
t.Log(prefix, byteutils.Copy(prefix))
a.IsTrue(bytes.Equal(byteutils.Copy(prefix), []byte{1, 2, 3}))
}
func TestAppend(t *testing.T) {
var as = assert.NewAssertion(t)
var prefix []byte
prefix = append(prefix, 1, 2, 3)
// [1 2 3 4 5 6] [1 2 3 7]
var a = byteutils.Append(prefix, 4, 5, 6)
var b = byteutils.Append(prefix, 7)
t.Log(a, b)
as.IsTrue(bytes.Equal(a, []byte{1, 2, 3, 4, 5, 6}))
as.IsTrue(bytes.Equal(b, []byte{1, 2, 3, 7}))
}
func TestConcat(t *testing.T) {
var a = assert.NewAssertion(t)
var prefix []byte
prefix = append(prefix, 1, 2, 3)
var b = byteutils.Contact(prefix, []byte{4, 5, 6}, []byte{7})
t.Log(b)
a.IsTrue(bytes.Equal(b, []byte{1, 2, 3, 4, 5, 6, 7}))
}
func TestAppend_Raw(t *testing.T) {
var prefix []byte
prefix = append(prefix, 1, 2, 3)
// [1 2 3 7 5 6] [1 2 3 7]
var a = append(prefix, 4, 5, 6)
var b = append(prefix, 7)
t.Log(a, b)
}

View File

@@ -9,10 +9,16 @@ import (
var ErrTableNotFound = errors.New("table not found") var ErrTableNotFound = errors.New("table not found")
var ErrKeyTooLong = errors.New("too long key") var ErrKeyTooLong = errors.New("too long key")
var ErrSkip= errors.New("skip") // skip count in iterator
func IsKeyNotFound(err error) bool { func IsNotFound(err error) bool {
if err == nil { return err != nil && errors.Is(err, pebble.ErrNotFound)
return false }
}
return errors.Is(err, pebble.ErrNotFound) func IsSkipError(err error) bool {
return err != nil && errors.Is(err, ErrSkip)
}
func Skip() (bool, error) {
return true, ErrSkip
} }

View File

@@ -7,3 +7,7 @@ import "github.com/cockroachdb/pebble"
var DefaultWriteOptions = &pebble.WriteOptions{ var DefaultWriteOptions = &pebble.WriteOptions{
Sync: false, Sync: false,
} }
var DefaultWriteSyncOptions = &pebble.WriteOptions{
Sync: true,
}

View File

@@ -6,6 +6,7 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
byteutils "github.com/TeaOSLab/EdgeNode/internal/utils/byte"
) )
type DataType = int type DataType = int
@@ -222,11 +223,11 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
var prefix []byte var prefix []byte
switch this.dataType { switch this.dataType {
case DataTypeKey: case DataTypeKey:
prefix = append(this.table.Namespace(), KeyPrefix...) prefix = byteutils.Append(this.table.Namespace(), []byte(KeyPrefix)...)
case DataTypeField: case DataTypeField:
prefix = append(this.table.Namespace(), FieldPrefix...) prefix = byteutils.Append(this.table.Namespace(), []byte(FieldPrefix)...)
default: default:
prefix = append(this.table.Namespace(), KeyPrefix...) prefix = byteutils.Append(this.table.Namespace(), []byte(KeyPrefix)...)
} }
var prefixLen = len(prefix) var prefixLen = len(prefix)
@@ -238,21 +239,21 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
var offsetKey []byte var offsetKey []byte
if this.reverse { if this.reverse {
if len(this.offsetKey) > 0 { if len(this.offsetKey) > 0 {
offsetKey = append(prefix, this.offsetKey...) offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
} else { } else {
offsetKey = append(prefix, 0xFF) offsetKey = byteutils.Append(prefix, 0xFF)
} }
opt.LowerBound = prefix opt.LowerBound = prefix
opt.UpperBound = offsetKey opt.UpperBound = offsetKey
} else { } else {
if len(this.offsetKey) > 0 { if len(this.offsetKey) > 0 {
offsetKey = append(prefix, this.offsetKey...) offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
} else { } else {
offsetKey = prefix offsetKey = prefix
} }
opt.LowerBound = offsetKey opt.LowerBound = offsetKey
opt.UpperBound = append(offsetKey, 0xFF) opt.UpperBound = byteutils.Append(prefix, 0xFF)
} }
var hasOffsetKey = len(this.offsetKey) > 0 var hasOffsetKey = len(this.offsetKey) > 0
@@ -267,7 +268,7 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
var count int var count int
var itemFn = func() (goNext bool, err error) { var itemFn = func() (goNextItem bool, err error) {
var keyBytes = it.Key() var keyBytes = it.Key()
// skip first offset key // skip first offset key
@@ -297,7 +298,11 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
Value: value, Value: value,
}) })
if callbackErr != nil { if callbackErr != nil {
return false, callbackErr if IsSkipError(callbackErr) {
return true, nil
} else {
return false, callbackErr
}
} }
if !goNext { if !goNext {
return false, nil return false, nil
@@ -361,9 +366,9 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
if len(this.fieldOffsetKey) > 0 { if len(this.fieldOffsetKey) > 0 {
offsetKey = this.fieldOffsetKey offsetKey = this.fieldOffsetKey
} else if len(this.offsetKey) > 0 { } else if len(this.offsetKey) > 0 {
offsetKey = append(prefix, this.offsetKey...) offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
} else { } else {
offsetKey = append(prefix, 0xFF) offsetKey = byteutils.Append(prefix, 0xFF)
} }
opt.LowerBound = prefix opt.LowerBound = prefix
opt.UpperBound = offsetKey opt.UpperBound = offsetKey
@@ -371,14 +376,14 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
if len(this.fieldOffsetKey) > 0 { if len(this.fieldOffsetKey) > 0 {
offsetKey = this.fieldOffsetKey offsetKey = this.fieldOffsetKey
} else if len(this.offsetKey) > 0 { } else if len(this.offsetKey) > 0 {
offsetKey = append(prefix, this.offsetKey...) offsetKey = byteutils.Append(prefix, []byte(this.offsetKey)...)
offsetKey = append(offsetKey, 0) offsetKey = append(offsetKey, 0)
} else { } else {
offsetKey = prefix offsetKey = prefix
} }
opt.LowerBound = offsetKey opt.LowerBound = offsetKey
opt.UpperBound = append(prefix, 0xFF) opt.UpperBound = byteutils.Append(prefix, 0xFF)
} }
it, itErr := this.tx.NewIterator(opt) it, itErr := this.tx.NewIterator(opt)
@@ -391,7 +396,7 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
var count int var count int
var itemFn = func() (goNext bool, err error) { var itemFn = func() (goNextItem bool, err error) {
var fieldKeyBytes = it.Key() var fieldKeyBytes = it.Key()
fieldValueBytes, keyBytes, decodeKeyErr := this.table.DecodeFieldKey(this.fieldName, fieldKeyBytes) fieldValueBytes, keyBytes, decodeKeyErr := this.table.DecodeFieldKey(this.fieldName, fieldKeyBytes)
@@ -423,7 +428,7 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
if !this.keysOnly { if !this.keysOnly {
value, getErr := this.table.getWithKeyBytes(this.tx, this.table.FullKeyBytes(keyBytes)) value, getErr := this.table.getWithKeyBytes(this.tx, this.table.FullKeyBytes(keyBytes))
if getErr != nil { if getErr != nil {
if IsKeyNotFound(getErr) { if IsNotFound(getErr) {
return true, nil return true, nil
} }
return false, getErr return false, getErr
@@ -432,11 +437,15 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
resultItem.Value = value resultItem.Value = value
} }
goNext, err = fn(this.tx, resultItem) goNextItem, err = fn(this.tx, resultItem)
if err != nil { if err != nil {
return if IsSkipError(err) {
return true, nil
} else {
return false, err
}
} }
if !goNext { if !goNextItem {
return false, nil return false, nil
} }

View File

@@ -138,6 +138,26 @@ func TestQuery_FindAll_Offset(t *testing.T) {
} }
} }
func TestQuery_FindAll_Skip(t *testing.T) {
var table = testOpenStoreTable[*testCachedItem](t, "cache_items", &testCacheItemEncoder[*testCachedItem]{})
{
err := table.Query().
Offset("a3").
Limit(10).
FindAll(func(tx *kvstore.Tx[*testCachedItem], item kvstore.Item[*testCachedItem]) (goNext bool, err error) {
if item.Key == "a30" || item.Key == "a3000005" {
return kvstore.Skip()
}
t.Log("key:", item.Key, "value:", item.Value)
return true, nil
})
if err != nil {
t.Fatal(err)
}
}
}
func TestQuery_FindAll_Count(t *testing.T) { func TestQuery_FindAll_Count(t *testing.T) {
var table = testOpenStoreTable[*testCachedItem](t, "cache_items", &testCacheItemEncoder[*testCachedItem]{}) var table = testOpenStoreTable[*testCachedItem](t, "cache_items", &testCacheItemEncoder[*testCachedItem]{})

View File

@@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/TeaOSLab/EdgeNode/internal/events" "github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
memutils "github.com/TeaOSLab/EdgeNode/internal/utils/mem" memutils "github.com/TeaOSLab/EdgeNode/internal/utils/mem"
"github.com/cockroachdb/pebble" "github.com/cockroachdb/pebble"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
@@ -85,6 +86,31 @@ func OpenStoreDir(dir string, storeName string) (*Store, error) {
return store, nil return store, nil
} }
var storeOnce = &sync.Once{}
var defaultSore *Store
func DefaultStore() (*Store, error) {
if defaultSore != nil {
return defaultSore, nil
}
storeOnce.Do(func() {
store, err := NewStore("default")
if err != nil {
remotelogs.Error("KV", "create default store failed: "+err.Error())
return
}
err = store.Open()
if err != nil {
remotelogs.Error("KV", "open default store failed: "+err.Error())
return
}
defaultSore = store
})
return defaultSore, nil
}
func (this *Store) Open() error { func (this *Store) Open() error {
var opt = &pebble.Options{ var opt = &pebble.Options{
Logger: NewLogger(), Logger: NewLogger(),
@@ -144,6 +170,10 @@ func (this *Store) RawDB() *pebble.DB {
return this.rawDB return this.rawDB
} }
func (this *Store) Flush() error {
return this.rawDB.Flush()
}
func (this *Store) Close() error { func (this *Store) Close() error {
if this.isClosed { if this.isClosed {
return nil return nil

View File

@@ -6,7 +6,9 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/utils/kvstore" "github.com/TeaOSLab/EdgeNode/internal/utils/kvstore"
"github.com/cockroachdb/pebble" "github.com/cockroachdb/pebble"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/assert"
_ "github.com/iwind/TeaGo/bootstrap" _ "github.com/iwind/TeaGo/bootstrap"
"sync"
"testing" "testing"
"time" "time"
) )
@@ -19,6 +21,44 @@ func TestMain(m *testing.M) {
} }
} }
func TestStore_Default(t *testing.T) {
var a = assert.NewAssertion(t)
store, err := kvstore.DefaultStore()
if err != nil {
t.Fatal(err)
}
a.IsTrue(store != nil)
}
func TestStore_Default_Concurrent(t *testing.T) {
var lastStore *kvstore.Store
const threads = 32
var wg = &sync.WaitGroup{}
wg.Add(threads)
for i := 0; i < threads; i++ {
go func() {
defer wg.Done()
store, err := kvstore.DefaultStore()
if err != nil {
t.Log("ERROR", err)
t.Fail()
}
if lastStore != nil && lastStore != store {
t.Log("ERROR", "should be single instance")
t.Fail()
}
lastStore = store
}()
}
wg.Wait()
}
func TestStore_Open(t *testing.T) { func TestStore_Open(t *testing.T) {
store, err := kvstore.OpenStore("test") store, err := kvstore.OpenStore("test")
if err != nil { if err != nil {
@@ -29,6 +69,7 @@ func TestStore_Open(t *testing.T) {
}() }()
t.Log("opened") t.Log("opened")
_ = store
} }
func TestStore_RawDB(t *testing.T) { func TestStore_RawDB(t *testing.T) {

View File

@@ -64,6 +64,10 @@ func (this *Table[T]) DB() *DB {
return this.db return this.db
} }
func (this *Table[T]) Encoder() ValueEncoder[T] {
return this.encoder
}
func (this *Table[T]) Set(key string, value T) error { func (this *Table[T]) Set(key string, value T) error {
if len(key) > KeyMaxLength { if len(key) > KeyMaxLength {
return ErrKeyTooLong return ErrKeyTooLong
@@ -75,7 +79,22 @@ func (this *Table[T]) Set(key string, value T) error {
} }
return this.WriteTx(func(tx *Tx[T]) error { return this.WriteTx(func(tx *Tx[T]) error {
return this.set(tx, key, valueBytes, value, false) return this.set(tx, key, valueBytes, value, false, false)
})
}
func (this *Table[T]) SetSync(key string, value T) error {
if len(key) > KeyMaxLength {
return ErrKeyTooLong
}
valueBytes, err := this.encoder.Encode(value)
if err != nil {
return err
}
return this.WriteTxSync(func(tx *Tx[T]) error {
return this.set(tx, key, valueBytes, value, false, true)
}) })
} }
@@ -90,7 +109,7 @@ func (this *Table[T]) Insert(key string, value T) error {
} }
return this.WriteTx(func(tx *Tx[T]) error { return this.WriteTx(func(tx *Tx[T]) error {
return this.set(tx, key, valueBytes, value, true) return this.set(tx, key, valueBytes, value, true, false)
}) })
} }
@@ -111,7 +130,7 @@ func (this *Table[T]) ComposeFieldKey(keyBytes []byte, fieldName string, fieldVa
func (this *Table[T]) Exist(key string) (found bool, err error) { func (this *Table[T]) Exist(key string) (found bool, err error) {
_, closer, err := this.db.store.rawDB.Get(this.FullKey(key)) _, closer, err := this.db.store.rawDB.Get(this.FullKey(key))
if err != nil { if err != nil {
if IsKeyNotFound(err) { if IsNotFound(err) {
return false, nil return false, nil
} }
return false, err return false, err
@@ -173,6 +192,20 @@ func (this *Table[T]) WriteTx(fn func(tx *Tx[T]) error) error {
return tx.Commit() return tx.Commit()
} }
func (this *Table[T]) WriteTxSync(fn func(tx *Tx[T]) error) error {
var tx = NewTx[T](this, false)
defer func() {
_ = tx.Close()
}()
err := fn(tx)
if err != nil {
return err
}
return tx.CommitSync()
}
func (this *Table[T]) Truncate() error { func (this *Table[T]) Truncate() error {
this.mu.Lock() this.mu.Lock()
defer this.mu.Unlock() defer this.mu.Unlock()
@@ -256,7 +289,7 @@ func (this *Table[T]) deleteKeys(tx *Tx[T], key ...string) error {
if len(this.fieldNames) > 0 { if len(this.fieldNames) > 0 {
valueBytes, closer, getErr := batch.Get(keyBytes) valueBytes, closer, getErr := batch.Get(keyBytes)
if getErr != nil { if getErr != nil {
if IsKeyNotFound(getErr) { if IsNotFound(getErr) {
return nil return nil
} }
return getErr return getErr
@@ -298,8 +331,12 @@ func (this *Table[T]) deleteKeys(tx *Tx[T], key ...string) error {
return nil return nil
} }
func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, insertOnly bool) error { func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, insertOnly bool, syncMode bool) error {
var keyBytes = this.FullKey(key) var keyBytes = this.FullKey(key)
var writeOptions = DefaultWriteOptions
if syncMode {
writeOptions = DefaultWriteSyncOptions
}
var batch = tx.batch var batch = tx.batch
@@ -312,7 +349,7 @@ func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, ins
if countFields > 0 { if countFields > 0 {
oldValueBytes, closer, getErr := batch.Get(keyBytes) oldValueBytes, closer, getErr := batch.Get(keyBytes)
if getErr != nil { if getErr != nil {
if !IsKeyNotFound(getErr) { if !IsNotFound(getErr) {
return getErr return getErr
} }
} else { } else {
@@ -330,7 +367,7 @@ func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, ins
} }
} }
setErr := batch.Set(keyBytes, valueBytes, DefaultWriteOptions) setErr := batch.Set(keyBytes, valueBytes, writeOptions)
if setErr != nil { if setErr != nil {
return setErr return setErr
} }
@@ -362,14 +399,14 @@ func (this *Table[T]) set(tx *Tx[T], key string, valueBytes []byte, value T, ins
// skip the field // skip the field
continue continue
} }
deleteFieldErr := batch.Delete(oldFieldKeyBytes, DefaultWriteOptions) deleteFieldErr := batch.Delete(oldFieldKeyBytes, writeOptions)
if deleteFieldErr != nil { if deleteFieldErr != nil {
return deleteFieldErr return deleteFieldErr
} }
} }
// set new field key // set new field key
setFieldErr := batch.Set(newFieldKeyBytes, nil, DefaultWriteOptions) setFieldErr := batch.Set(newFieldKeyBytes, nil, writeOptions)
if setFieldErr != nil { if setFieldErr != nil {
return setFieldErr return setFieldErr
} }

View File

@@ -21,7 +21,7 @@ func (this *CounterTable[T]) Increase(key string, delta T) (newValue T, err erro
err = this.Table.WriteTx(func(tx *Tx[T]) error { err = this.Table.WriteTx(func(tx *Tx[T]) error {
value, getErr := tx.Get(key) value, getErr := tx.Get(key)
if getErr != nil { if getErr != nil {
if !IsKeyNotFound(getErr) { if !IsNotFound(getErr) {
return getErr return getErr
} }
} }

View File

@@ -45,7 +45,7 @@ func TestTable_Set(t *testing.T) {
value, err := table.Get("a") value, err := table.Get("a")
if err != nil { if err != nil {
if kvstore.IsKeyNotFound(err) { if kvstore.IsNotFound(err) {
t.Log("not found key") t.Log("not found key")
return return
} }
@@ -81,7 +81,7 @@ func TestTable_Get(t *testing.T) {
for _, key := range []string{"a", "b", "c"} { for _, key := range []string{"a", "b", "c"} {
value, getErr := table.Get(key) value, getErr := table.Get(key)
if getErr != nil { if getErr != nil {
if kvstore.IsKeyNotFound(getErr) { if kvstore.IsNotFound(getErr) {
t.Log("not found key", key) t.Log("not found key", key)
continue continue
} }
@@ -146,7 +146,7 @@ func TestTable_Delete(t *testing.T) {
value, err := table.Get("a123") value, err := table.Get("a123")
if err != nil { if err != nil {
if !kvstore.IsKeyNotFound(err) { if !kvstore.IsNotFound(err) {
t.Fatal(err) t.Fatal(err)
} }
} else { } else {
@@ -173,7 +173,7 @@ func TestTable_Delete(t *testing.T) {
{ {
_, err = table.Get("a123") _, err = table.Get("a123")
a.IsTrue(kvstore.IsKeyNotFound(err)) a.IsTrue(kvstore.IsNotFound(err))
} }
} }
@@ -357,7 +357,7 @@ func BenchmarkTable_Get(b *testing.B) {
for pb.Next() { for pb.Next() {
_, putErr := table.Get(types.String(rand.Int())) _, putErr := table.Get(types.String(rand.Int()))
if putErr != nil { if putErr != nil {
if kvstore.IsKeyNotFound(putErr) { if kvstore.IsNotFound(putErr) {
continue continue
} }
b.Fatal(putErr) b.Fatal(putErr)

View File

@@ -37,7 +37,24 @@ func (this *Tx[T]) Set(key string, value T) error {
return err return err
} }
return this.table.set(this, key, valueBytes, value, false) return this.table.set(this, key, valueBytes, value, false, false)
}
func (this *Tx[T]) SetSync(key string, value T) error {
if this.readOnly {
return errors.New("can not set value in readonly transaction")
}
if len(key) > KeyMaxLength {
return ErrKeyTooLong
}
valueBytes, err := this.table.encoder.Encode(value)
if err != nil {
return err
}
return this.table.set(this, key, valueBytes, value, false, true)
} }
func (this *Tx[T]) Insert(key string, value T) error { func (this *Tx[T]) Insert(key string, value T) error {
@@ -54,7 +71,7 @@ func (this *Tx[T]) Insert(key string, value T) error {
return err return err
} }
return this.table.set(this, key, valueBytes, value, true) return this.table.set(this, key, valueBytes, value, true, false)
} }
func (this *Tx[T]) Get(key string) (value T, err error) { func (this *Tx[T]) Get(key string) (value T, err error) {
@@ -78,6 +95,20 @@ func (this *Tx[T]) Close() error {
} }
func (this *Tx[T]) Commit() (err error) { func (this *Tx[T]) Commit() (err error) {
return this.commit(DefaultWriteOptions)
}
func (this *Tx[T]) CommitSync() (err error) {
return this.commit(DefaultWriteSyncOptions)
}
func (this *Tx[T]) Query() *Query[T] {
var query = NewQuery[T]()
query.SetTx(this)
return query
}
func (this *Tx[T]) commit(opt *pebble.WriteOptions) (err error) {
defer func() { defer func() {
var panicErr = recover() var panicErr = recover()
if panicErr != nil { if panicErr != nil {
@@ -88,11 +119,5 @@ func (this *Tx[T]) Commit() (err error) {
} }
}() }()
return this.batch.Commit(DefaultWriteOptions) return this.batch.Commit(opt)
}
func (this *Tx[T]) Query() *Query[T] {
var query = NewQuery[T]()
query.SetTx(this)
return query
} }

View File

@@ -2,7 +2,11 @@
package testutils package testutils
import "os" import (
"fmt"
"math/rand"
"os"
)
// IsSingleTesting 判断当前测试环境是否为单个函数测试 // IsSingleTesting 判断当前测试环境是否为单个函数测试
func IsSingleTesting() bool { func IsSingleTesting() bool {
@@ -13,3 +17,8 @@ func IsSingleTesting() bool {
} }
return false return false
} }
// RandIP 生成一个随机IP用于测试
func RandIP() string {
return fmt.Sprintf("%d.%d.%d.%d", rand.Int()%255, rand.Int()%255, rand.Int()%255, rand.Int()%255)
}