优化IP名单同步速度

This commit is contained in:
刘祥超
2023-04-19 12:01:02 +08:00
parent 4dc25fb71e
commit 37ddff86f1
20 changed files with 334 additions and 56 deletions

View File

@@ -88,11 +88,15 @@ type blockIPItem struct {
} }
func NewNFTablesFirewall() (*NFTablesFirewall, error) { func NewNFTablesFirewall() (*NFTablesFirewall, error) {
conn, err := nftables.NewConn()
if err != nil {
return nil, err
}
var firewall = &NFTablesFirewall{ var firewall = &NFTablesFirewall{
conn: nftables.NewConn(), conn: conn,
dropIPQueue: make(chan *blockIPItem, 4096), dropIPQueue: make(chan *blockIPItem, 4096),
} }
err := firewall.init() err = firewall.init()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1,6 +1,5 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux //go:build linux
// +build linux
package nftables package nftables

View File

@@ -1,4 +1,5 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables package nftables

View File

@@ -1,6 +1,5 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux //go:build linux
// +build linux
package nftables_test package nftables_test
@@ -11,7 +10,10 @@ import (
) )
func getIPv4Chain(t *testing.T) *nftables.Chain { func getIPv4Chain(t *testing.T) *nftables.Chain {
var conn = nftables.NewConn() conn, err := nftables.NewConn()
if err != nil {
t.Fatal(err)
}
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4) table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
if err != nil { if err != nil {
if err == nftables.ErrTableNotFound { if err == nftables.ErrTableNotFound {

View File

@@ -15,10 +15,14 @@ type Conn struct {
rawConn *nft.Conn rawConn *nft.Conn
} }
func NewConn() *Conn { func NewConn() (*Conn, error) {
return &Conn{ conn, err := nft.New()
rawConn: &nft.Conn{}, if err != nil {
return nil, err
} }
return &Conn{
rawConn: conn,
}, nil
} }
func (this *Conn) Raw() *nft.Conn { func (this *Conn) Raw() *nft.Conn {

View File

@@ -4,7 +4,10 @@
package nftables package nftables
import "errors" import (
"errors"
"strings"
)
var ErrTableNotFound = errors.New("table not found") var ErrTableNotFound = errors.New("table not found")
var ErrChainNotFound = errors.New("chain not found") var ErrChainNotFound = errors.New("chain not found")
@@ -15,5 +18,5 @@ func IsNotFound(err error) bool {
if err == nil { if err == nil {
return false return false
} }
return err == ErrTableNotFound || err == ErrChainNotFound || err == ErrSetNotFound || err == ErrRuleNotFound return err == ErrTableNotFound || err == ErrChainNotFound || err == ErrSetNotFound || err == ErrRuleNotFound || strings.Contains(err.Error(), "no such file or directory")
} }

View File

@@ -0,0 +1,62 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nftables
import (
"sync"
"time"
)
type Expiration struct {
m map[string]time.Time // key => expires time
lastGCAt int64
locker sync.RWMutex
}
func NewExpiration() *Expiration {
return &Expiration{
m: map[string]time.Time{},
}
}
func (this *Expiration) AddUnsafe(key []byte, expires time.Time) {
this.m[string(key)] = expires
}
func (this *Expiration) Add(key []byte, expires time.Time) {
this.locker.Lock()
this.m[string(key)] = expires
this.gc()
this.locker.Unlock()
}
func (this *Expiration) Remove(key []byte) {
this.locker.Lock()
delete(this.m, string(key))
this.locker.Unlock()
}
func (this *Expiration) Contains(key []byte) bool {
this.locker.RLock()
_, ok := this.m[string(key)]
this.locker.RUnlock()
return ok
}
func (this *Expiration) gc() {
// we won't gc too frequently
var currentTime = time.Now().Unix()
if this.lastGCAt >= currentTime {
return
}
this.lastGCAt = currentTime
var now = time.Now().Add(-10 * time.Second) // gc elements expired before 10 seconds ago
for key, expires := range this.m {
if expires.Year() >= 2000 && now.After(expires) {
delete(this.m, key)
}
}
}

View File

@@ -0,0 +1,51 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nftables_test
import (
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
"net"
"testing"
"time"
)
func TestExpiration_Add(t *testing.T) {
var expiration = nftables.NewExpiration()
{
expiration.Add([]byte{'a', 'b', 'c'}, time.Now())
t.Log(expiration.Contains([]byte{'a', 'b', 'c'}))
}
{
expiration.Add([]byte{'a', 'b', 'c'}, time.Now().Add(-1*time.Second))
t.Log(expiration.Contains([]byte{'a', 'b', 'c'}))
}
{
expiration.Add([]byte{'a', 'b', 'c'}, time.Now().Add(-10*time.Second))
t.Log(expiration.Contains([]byte{'a', 'b', 'c'}))
}
{
expiration.Add([]byte{'a', 'b', 'c'}, time.Now().Add(1*time.Second))
expiration.Remove([]byte{'a', 'b', 'c'})
t.Log(expiration.Contains([]byte{'a', 'b', 'c'}))
}
{
expiration.Add(net.ParseIP("10.254.0.75").To4(), time.Now())
t.Log(expiration.Contains(net.ParseIP("10.254.0.75").To4()))
}
}
func BenchmarkNewExpiration(b *testing.B) {
var expiration = nftables.NewExpiration()
for i := 0; i < 10_000; i++ {
expiration.Add([]byte(types.String(types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255)))), time.Now().Add(3600*time.Second))
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
expiration.Add([]byte(types.String(types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255)))), time.Now().Add(3600*time.Second))
}
})
}

View File

@@ -1,4 +1,5 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables package nftables

View File

@@ -1,4 +1,5 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn . // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build linux
package nftables package nftables

View File

@@ -1,4 +1,5 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables package nftables

View File

@@ -1,6 +1,5 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux //go:build linux
// +build linux
package nftables package nftables
@@ -35,17 +34,25 @@ type Set struct {
conn *Conn conn *Conn
rawSet *nft.Set rawSet *nft.Set
batch *SetBatch batch *SetBatch
expiration *Expiration
} }
func NewSet(conn *Conn, rawSet *nft.Set) *Set { func NewSet(conn *Conn, rawSet *nft.Set) *Set {
return &Set{ var set = &Set{
conn: conn, conn: conn,
rawSet: rawSet, rawSet: rawSet,
expiration: nil,
batch: &SetBatch{ batch: &SetBatch{
conn: conn, conn: conn,
rawSet: rawSet, rawSet: rawSet,
}, },
} }
// retrieve set elements to improve "delete" speed
set.initElements()
return set
} }
func (this *Set) Raw() *nft.Set { func (this *Set) Raw() *nft.Set {
@@ -57,11 +64,21 @@ func (this *Set) Name() string {
} }
func (this *Set) AddElement(key []byte, options *ElementOptions, overwrite bool) error { func (this *Set) AddElement(key []byte, options *ElementOptions, overwrite bool) error {
// check if already exists
if this.expiration != nil && !overwrite && this.expiration.Contains(key) {
return nil
}
var expiresTime = time.Time{}
var rawElement = nft.SetElement{ var rawElement = nft.SetElement{
Key: key, Key: key,
} }
if options != nil { if options != nil {
rawElement.Timeout = options.Timeout rawElement.Timeout = options.Timeout
if options.Timeout > 0 {
expiresTime = time.UnixMilli(time.Now().UnixMilli() + options.Timeout.Milliseconds())
}
} }
err := this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{ err := this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
rawElement, rawElement,
@@ -71,9 +88,19 @@ func (this *Set) AddElement(key []byte, options *ElementOptions, overwrite bool)
} }
err = this.conn.Commit() err = this.conn.Commit()
if err != nil { if err == nil {
if this.expiration != nil {
this.expiration.Add(key, expiresTime)
}
} else {
var isFileExistsErr = strings.Contains(err.Error(), "file exists")
if !overwrite && isFileExistsErr {
// ignore file exists error
return nil
}
// retry if exists // retry if exists
if overwrite && strings.Contains(err.Error(), "file exists") { if overwrite && isFileExistsErr {
deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{ deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
{ {
Key: key, Key: key,
@@ -85,6 +112,11 @@ func (this *Set) AddElement(key []byte, options *ElementOptions, overwrite bool)
}) })
if err == nil { if err == nil {
err = this.conn.Commit() err = this.conn.Commit()
if err == nil {
if this.expiration != nil {
this.expiration.Add(key, expiresTime)
}
}
} }
} }
} }
@@ -107,6 +139,11 @@ func (this *Set) AddIPElement(ip string, options *ElementOptions, overwrite bool
} }
func (this *Set) DeleteElement(key []byte) error { func (this *Set) DeleteElement(key []byte) error {
// if set element does not exist, we return immediately
if this.expiration != nil && !this.expiration.Contains(key) {
return nil
}
err := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{ err := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
{ {
Key: key, Key: key,
@@ -116,9 +153,17 @@ func (this *Set) DeleteElement(key []byte) error {
return err return err
} }
err = this.conn.Commit() err = this.conn.Commit()
if err != nil { if err == nil {
if this.expiration != nil {
this.expiration.Remove(key)
}
} else {
if strings.Contains(err.Error(), "no such file or directory") { if strings.Contains(err.Error(), "no such file or directory") {
err = nil err = nil
if this.expiration != nil {
this.expiration.Remove(key)
}
} }
} }
return err return err

View File

@@ -1,4 +1,5 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables package nftables

View File

@@ -0,0 +1,8 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build linux && !plus
package nftables
func (this *Set) initElements() {
// NOT IMPLEMENTED
}

View File

@@ -34,7 +34,7 @@ func getIPv4Set(t *testing.T) *nftables.Set {
func TestSet_AddElement(t *testing.T) { func TestSet_AddElement(t *testing.T) {
var set = getIPv4Set(t) var set = getIPv4Set(t)
err := set.AddElement(net.ParseIP("192.168.2.31").To4(), &nftables.ElementOptions{Timeout: 86400 * time.Second}) err := set.AddElement(net.ParseIP("192.168.2.31").To4(), &nftables.ElementOptions{Timeout: 86400 * time.Second}, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -1,6 +1,5 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux //go:build linux
// +build linux
package nftables package nftables

View File

@@ -1,6 +1,5 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux //go:build linux
// +build linux
package nftables_test package nftables_test
@@ -10,7 +9,10 @@ import (
) )
func getIPv4Table(t *testing.T) *nftables.Table { func getIPv4Table(t *testing.T) *nftables.Table {
var conn = nftables.NewConn() conn, err := nftables.NewConn()
if err != nil {
t.Fatal(err)
}
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4) table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
if err != nil { if err != nil {
if err == nftables.ErrTableNotFound { if err == nftables.ErrTableNotFound {

View File

@@ -18,12 +18,16 @@ type IPListDB struct {
db *dbs.DB db *dbs.DB
itemTableName string itemTableName string
versionTableName string
deleteExpiredItemsStmt *dbs.Stmt deleteExpiredItemsStmt *dbs.Stmt
deleteItemStmt *dbs.Stmt deleteItemStmt *dbs.Stmt
insertItemStmt *dbs.Stmt insertItemStmt *dbs.Stmt
selectItemsStmt *dbs.Stmt selectItemsStmt *dbs.Stmt
selectMaxVersionStmt *dbs.Stmt selectMaxItemVersionStmt *dbs.Stmt
selectVersionStmt *dbs.Stmt
updateVersionStmt *dbs.Stmt
cleanTicker *time.Ticker cleanTicker *time.Ticker
@@ -35,6 +39,7 @@ type IPListDB struct {
func NewIPListDB() (*IPListDB, error) { func NewIPListDB() (*IPListDB, error) {
var db = &IPListDB{ var db = &IPListDB{
itemTableName: "ipItems", itemTableName: "ipItems",
versionTableName: "versions",
dir: filepath.Clean(Tea.Root + "/data"), dir: filepath.Clean(Tea.Root + "/data"),
cleanTicker: time.NewTicker(24 * time.Hour), cleanTicker: time.NewTicker(24 * time.Hour),
} }
@@ -108,6 +113,15 @@ ON "` + this.itemTableName + `" (
return err return err
} }
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.versionTableName + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"version" integer DEFAULT 0
);
`)
if err != nil {
return err
}
// 初始化SQL语句 // 初始化SQL语句
this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "expiredAt">0 AND "expiredAt"<?`) this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "expiredAt">0 AND "expiredAt"<?`)
if err != nil { if err != nil {
@@ -129,7 +143,20 @@ ON "` + this.itemTableName + `" (
return err return err
} }
this.selectMaxVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`) this.selectMaxItemVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`)
if err != nil {
return err
}
this.selectVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.versionTableName + `" LIMIT 1`)
if err != nil {
return err
}
this.updateVersionStmt, err = this.db.Prepare(`REPLACE INTO "` + this.versionTableName + `" ("id", "version") VALUES (1, ?)`)
if err != nil {
return err
}
this.db = db this.db = db
@@ -172,11 +199,15 @@ func (this *IPListDB) AddItem(item *pb.IPItem) error {
// 如果是删除,则不再创建新记录 // 如果是删除,则不再创建新记录
if item.IsDeleted { if item.IsDeleted {
return nil return this.UpdateMaxVersion(item.Version)
} }
_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId) _, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
if err != nil {
return err return err
}
return this.UpdateMaxVersion(item.Version)
} }
func (this *IPListDB) ReadItems(offset int64, size int64) (items []*pb.IPItem, err error) { func (this *IPListDB) ReadItems(offset int64, size int64) (items []*pb.IPItem, err error) {
@@ -210,7 +241,22 @@ func (this *IPListDB) ReadMaxVersion() int64 {
return 0 return 0
} }
var row = this.selectMaxVersionStmt.QueryRow() // from version table
{
var row = this.selectVersionStmt.QueryRow()
if row == nil {
return 0
}
var version int64
err := row.Scan(&version)
if err == nil {
return version
}
}
// from items table
{
var row = this.selectMaxItemVersionStmt.QueryRow()
if row == nil { if row == nil {
return 0 return 0
} }
@@ -219,18 +265,39 @@ func (this *IPListDB) ReadMaxVersion() int64 {
if err != nil { if err != nil {
return 0 return 0
} }
return version return version
}
}
// UpdateMaxVersion 修改版本号
func (this *IPListDB) UpdateMaxVersion(version int64) error {
if this.isClosed {
return nil
}
_, err := this.updateVersionStmt.Exec(version)
return err
} }
func (this *IPListDB) Close() error { func (this *IPListDB) Close() error {
this.isClosed = true this.isClosed = true
if this.db != nil { if this.db != nil {
_ = this.deleteExpiredItemsStmt.Close() for _, stmt := range []*dbs.Stmt{
_ = this.deleteItemStmt.Close() this.deleteExpiredItemsStmt,
_ = this.insertItemStmt.Close() this.deleteItemStmt,
_ = this.selectItemsStmt.Close() this.insertItemStmt,
_ = this.selectMaxVersionStmt.Close() this.selectItemsStmt,
this.selectMaxItemVersionStmt, // ipItems table
this.selectVersionStmt, // versions table
this.updateVersionStmt,
} {
if stmt != nil {
_ = stmt.Close()
}
}
return this.db.Close() return this.db.Close()
} }

View File

@@ -79,3 +79,15 @@ func TestIPListDB_ReadMaxVersion(t *testing.T) {
} }
t.Log(db.ReadMaxVersion()) t.Log(db.ReadMaxVersion())
} }
func TestIPListDB_UpdateMaxVersion(t *testing.T) {
db, err := iplibrary.NewIPListDB()
if err != nil {
t.Fatal(err)
}
err = db.UpdateMaxVersion(1027)
if err != nil {
t.Fatal(err)
}
t.Log(db.ReadMaxVersion())
}

View File

@@ -47,17 +47,20 @@ type IPListManager struct {
db *IPListDB db *IPListDB
version int64 lastVersion int64
pageSize int64 fetchPageSize int64
listMap map[int64]*IPList listMap map[int64]*IPList
locker sync.Mutex locker sync.Mutex
isFirstTime bool
} }
func NewIPListManager() *IPListManager { func NewIPListManager() *IPListManager {
return &IPListManager{ return &IPListManager{
pageSize: 1000, fetchPageSize: 5_000,
listMap: map[int64]*IPList{}, listMap: map[int64]*IPList{},
isFirstTime: true,
} }
} }
@@ -117,11 +120,11 @@ func (this *IPListManager) init() {
_ = db.DeleteExpiredItems() _ = db.DeleteExpiredItems()
// 本地数据库中最大版本号 // 本地数据库中最大版本号
this.version = db.ReadMaxVersion() this.lastVersion = db.ReadMaxVersion()
// 从本地数据库中加载 // 从本地数据库中加载
var offset int64 = 0 var offset int64 = 0
var size int64 = 1000 var size int64 = 2_000
for { for {
items, err := db.ReadItems(offset, size) items, err := db.ReadItems(offset, size)
var l = len(items) var l = len(items)
@@ -148,6 +151,11 @@ func (this *IPListManager) loop() error {
return nil return nil
} }
// 第一次同步则打印信息
if this.isFirstTime {
remotelogs.Println("IP_LIST_MANAGER", "initializing ip items ...")
}
for { for {
hasNext, err := this.fetch() hasNext, err := this.fetch()
if err != nil { if err != nil {
@@ -159,6 +167,12 @@ func (this *IPListManager) loop() error {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
// 第一次同步则打印信息
if this.isFirstTime {
this.isFirstTime = false
remotelogs.Println("IP_LIST_MANAGER", "finished initializing ip items")
}
return nil return nil
} }
@@ -168,8 +182,8 @@ func (this *IPListManager) fetch() (hasNext bool, err error) {
return false, err return false, err
} }
itemsResp, err := rpcClient.IPItemRPC.ListIPItemsAfterVersion(rpcClient.Context(), &pb.ListIPItemsAfterVersionRequest{ itemsResp, err := rpcClient.IPItemRPC.ListIPItemsAfterVersion(rpcClient.Context(), &pb.ListIPItemsAfterVersionRequest{
Version: this.version, Version: this.lastVersion,
Size: this.pageSize, Size: this.fetchPageSize,
}) })
if err != nil { if err != nil {
if rpc.IsConnError(err) { if rpc.IsConnError(err) {
@@ -211,6 +225,7 @@ func (this *IPListManager) DeleteExpiredItems() {
} }
} }
// 处理IP条目
func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) { func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) {
var changedLists = map[*IPList]zero.Zero{} var changedLists = map[*IPList]zero.Zero{}
for _, item := range items { for _, item := range items {
@@ -280,8 +295,8 @@ func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) {
if fromRemote { if fromRemote {
var latestVersion = items[len(items)-1].Version var latestVersion = items[len(items)-1].Version
if latestVersion > this.version { if latestVersion > this.lastVersion {
this.version = latestVersion this.lastVersion = latestVersion
} }
} }
} }