diff --git a/internal/firewalls/firewall_nftables.go b/internal/firewalls/firewall_nftables.go index 96cb1b9..8061b55 100644 --- a/internal/firewalls/firewall_nftables.go +++ b/internal/firewalls/firewall_nftables.go @@ -88,11 +88,15 @@ type blockIPItem struct { } func NewNFTablesFirewall() (*NFTablesFirewall, error) { + conn, err := nftables.NewConn() + if err != nil { + return nil, err + } var firewall = &NFTablesFirewall{ - conn: nftables.NewConn(), + conn: conn, dropIPQueue: make(chan *blockIPItem, 4096), } - err := firewall.init() + err = firewall.init() if err != nil { return nil, err } diff --git a/internal/firewalls/nftables/chain.go b/internal/firewalls/nftables/chain.go index 005b585..5f0ac6d 100644 --- a/internal/firewalls/nftables/chain.go +++ b/internal/firewalls/nftables/chain.go @@ -1,6 +1,5 @@ // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. //go:build linux -// +build linux package nftables diff --git a/internal/firewalls/nftables/chain_policy.go b/internal/firewalls/nftables/chain_policy.go index 677c573..86a0854 100644 --- a/internal/firewalls/nftables/chain_policy.go +++ b/internal/firewalls/nftables/chain_policy.go @@ -1,4 +1,5 @@ // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build linux package nftables diff --git a/internal/firewalls/nftables/chain_test.go b/internal/firewalls/nftables/chain_test.go index b75341f..8a658da 100644 --- a/internal/firewalls/nftables/chain_test.go +++ b/internal/firewalls/nftables/chain_test.go @@ -1,6 +1,5 @@ // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. //go:build linux -// +build linux package nftables_test @@ -11,7 +10,10 @@ import ( ) func getIPv4Chain(t *testing.T) *nftables.Chain { - var conn = nftables.NewConn() + conn, err := nftables.NewConn() + if err != nil { + t.Fatal(err) + } table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4) if err != nil { if err == nftables.ErrTableNotFound { diff --git a/internal/firewalls/nftables/conn.go b/internal/firewalls/nftables/conn.go index d2e63e7..013890a 100644 --- a/internal/firewalls/nftables/conn.go +++ b/internal/firewalls/nftables/conn.go @@ -15,10 +15,14 @@ type Conn struct { rawConn *nft.Conn } -func NewConn() *Conn { - return &Conn{ - rawConn: &nft.Conn{}, +func NewConn() (*Conn, error) { + conn, err := nft.New() + if err != nil { + return nil, err } + return &Conn{ + rawConn: conn, + }, nil } func (this *Conn) Raw() *nft.Conn { diff --git a/internal/firewalls/nftables/errors.go b/internal/firewalls/nftables/errors.go index 7eba54c..8c6e80a 100644 --- a/internal/firewalls/nftables/errors.go +++ b/internal/firewalls/nftables/errors.go @@ -4,7 +4,10 @@ package nftables -import "errors" +import ( + "errors" + "strings" +) var ErrTableNotFound = errors.New("table not found") var ErrChainNotFound = errors.New("chain not found") @@ -15,5 +18,5 @@ func IsNotFound(err error) bool { if err == nil { return false } - return err == ErrTableNotFound || err == ErrChainNotFound || err == ErrSetNotFound || err == ErrRuleNotFound + return err == ErrTableNotFound || err == ErrChainNotFound || err == ErrSetNotFound || err == ErrRuleNotFound || strings.Contains(err.Error(), "no such file or directory") } diff --git a/internal/firewalls/nftables/expration.go b/internal/firewalls/nftables/expration.go new file mode 100644 index 0000000..e85ec62 --- /dev/null +++ b/internal/firewalls/nftables/expration.go @@ -0,0 +1,62 @@ +// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package nftables + +import ( + "sync" + "time" +) + +type Expiration struct { + m map[string]time.Time // key => expires time + + lastGCAt int64 + + locker sync.RWMutex +} + +func NewExpiration() *Expiration { + return &Expiration{ + m: map[string]time.Time{}, + } +} + +func (this *Expiration) AddUnsafe(key []byte, expires time.Time) { + this.m[string(key)] = expires +} + +func (this *Expiration) Add(key []byte, expires time.Time) { + this.locker.Lock() + this.m[string(key)] = expires + this.gc() + this.locker.Unlock() +} + +func (this *Expiration) Remove(key []byte) { + this.locker.Lock() + delete(this.m, string(key)) + this.locker.Unlock() +} + +func (this *Expiration) Contains(key []byte) bool { + this.locker.RLock() + _, ok := this.m[string(key)] + this.locker.RUnlock() + return ok +} + +func (this *Expiration) gc() { + // we won't gc too frequently + var currentTime = time.Now().Unix() + if this.lastGCAt >= currentTime { + return + } + this.lastGCAt = currentTime + + var now = time.Now().Add(-10 * time.Second) // gc elements expired before 10 seconds ago + for key, expires := range this.m { + if expires.Year() >= 2000 && now.After(expires) { + delete(this.m, key) + } + } +} diff --git a/internal/firewalls/nftables/expration_test.go b/internal/firewalls/nftables/expration_test.go new file mode 100644 index 0000000..cd894a4 --- /dev/null +++ b/internal/firewalls/nftables/expration_test.go @@ -0,0 +1,51 @@ +// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package nftables_test + +import ( + "github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables" + "github.com/iwind/TeaGo/rands" + "github.com/iwind/TeaGo/types" + "net" + "testing" + "time" +) + +func TestExpiration_Add(t *testing.T) { + var expiration = nftables.NewExpiration() + { + expiration.Add([]byte{'a', 'b', 'c'}, time.Now()) + t.Log(expiration.Contains([]byte{'a', 'b', 'c'})) + } + { + expiration.Add([]byte{'a', 'b', 'c'}, time.Now().Add(-1*time.Second)) + t.Log(expiration.Contains([]byte{'a', 'b', 'c'})) + } + { + expiration.Add([]byte{'a', 'b', 'c'}, time.Now().Add(-10*time.Second)) + t.Log(expiration.Contains([]byte{'a', 'b', 'c'})) + } + { + expiration.Add([]byte{'a', 'b', 'c'}, time.Now().Add(1*time.Second)) + expiration.Remove([]byte{'a', 'b', 'c'}) + t.Log(expiration.Contains([]byte{'a', 'b', 'c'})) + } + { + expiration.Add(net.ParseIP("10.254.0.75").To4(), time.Now()) + t.Log(expiration.Contains(net.ParseIP("10.254.0.75").To4())) + } +} + +func BenchmarkNewExpiration(b *testing.B) { + var expiration = nftables.NewExpiration() + for i := 0; i < 10_000; i++ { + expiration.Add([]byte(types.String(types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255)))), time.Now().Add(3600*time.Second)) + } + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + expiration.Add([]byte(types.String(types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255)))), time.Now().Add(3600*time.Second)) + } + }) +} diff --git a/internal/firewalls/nftables/family.go b/internal/firewalls/nftables/family.go index 42fa71a..1643a13 100644 --- a/internal/firewalls/nftables/family.go +++ b/internal/firewalls/nftables/family.go @@ -1,4 +1,5 @@ // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build linux package nftables diff --git a/internal/firewalls/nftables/installer.go b/internal/firewalls/nftables/installer.go index 4c1e636..eb93775 100644 --- a/internal/firewalls/nftables/installer.go +++ b/internal/firewalls/nftables/installer.go @@ -1,4 +1,5 @@ // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn . +//go:build linux package nftables diff --git a/internal/firewalls/nftables/rule.go b/internal/firewalls/nftables/rule.go index 75b3874..77a7c69 100644 --- a/internal/firewalls/nftables/rule.go +++ b/internal/firewalls/nftables/rule.go @@ -1,4 +1,5 @@ // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build linux package nftables diff --git a/internal/firewalls/nftables/set.go b/internal/firewalls/nftables/set.go index 2204c53..6c25a8c 100644 --- a/internal/firewalls/nftables/set.go +++ b/internal/firewalls/nftables/set.go @@ -1,6 +1,5 @@ // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. //go:build linux -// +build linux package nftables @@ -35,17 +34,25 @@ type Set struct { conn *Conn rawSet *nft.Set batch *SetBatch + + expiration *Expiration } func NewSet(conn *Conn, rawSet *nft.Set) *Set { - return &Set{ - conn: conn, - rawSet: rawSet, + var set = &Set{ + conn: conn, + rawSet: rawSet, + expiration: nil, batch: &SetBatch{ conn: conn, rawSet: rawSet, }, } + + // retrieve set elements to improve "delete" speed + set.initElements() + + return set } func (this *Set) Raw() *nft.Set { @@ -57,11 +64,21 @@ func (this *Set) Name() string { } func (this *Set) AddElement(key []byte, options *ElementOptions, overwrite bool) error { + // check if already exists + if this.expiration != nil && !overwrite && this.expiration.Contains(key) { + return nil + } + + var expiresTime = time.Time{} var rawElement = nft.SetElement{ Key: key, } if options != nil { rawElement.Timeout = options.Timeout + + if options.Timeout > 0 { + expiresTime = time.UnixMilli(time.Now().UnixMilli() + options.Timeout.Milliseconds()) + } } err := this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{ rawElement, @@ -71,9 +88,19 @@ func (this *Set) AddElement(key []byte, options *ElementOptions, overwrite bool) } err = this.conn.Commit() - if err != nil { + if err == nil { + if this.expiration != nil { + this.expiration.Add(key, expiresTime) + } + } else { + var isFileExistsErr = strings.Contains(err.Error(), "file exists") + if !overwrite && isFileExistsErr { + // ignore file exists error + return nil + } + // retry if exists - if overwrite && strings.Contains(err.Error(), "file exists") { + if overwrite && isFileExistsErr { deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{ { Key: key, @@ -85,6 +112,11 @@ func (this *Set) AddElement(key []byte, options *ElementOptions, overwrite bool) }) if err == nil { err = this.conn.Commit() + if err == nil { + if this.expiration != nil { + this.expiration.Add(key, expiresTime) + } + } } } } @@ -107,6 +139,11 @@ func (this *Set) AddIPElement(ip string, options *ElementOptions, overwrite bool } func (this *Set) DeleteElement(key []byte) error { + // if set element does not exist, we return immediately + if this.expiration != nil && !this.expiration.Contains(key) { + return nil + } + err := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{ { Key: key, @@ -116,9 +153,17 @@ func (this *Set) DeleteElement(key []byte) error { return err } err = this.conn.Commit() - if err != nil { + if err == nil { + if this.expiration != nil { + this.expiration.Remove(key) + } + } else { if strings.Contains(err.Error(), "no such file or directory") { err = nil + + if this.expiration != nil { + this.expiration.Remove(key) + } } } return err diff --git a/internal/firewalls/nftables/set_data_type.go b/internal/firewalls/nftables/set_data_type.go index 6fb80c8..8ef764c 100644 --- a/internal/firewalls/nftables/set_data_type.go +++ b/internal/firewalls/nftables/set_data_type.go @@ -1,4 +1,5 @@ // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build linux package nftables diff --git a/internal/firewalls/nftables/set_ext.go b/internal/firewalls/nftables/set_ext.go new file mode 100644 index 0000000..b12deca --- /dev/null +++ b/internal/firewalls/nftables/set_ext.go @@ -0,0 +1,8 @@ +// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn . +//go:build linux && !plus + +package nftables + +func (this *Set) initElements() { + // NOT IMPLEMENTED +} diff --git a/internal/firewalls/nftables/set_test.go b/internal/firewalls/nftables/set_test.go index 47b969d..568709f 100644 --- a/internal/firewalls/nftables/set_test.go +++ b/internal/firewalls/nftables/set_test.go @@ -34,7 +34,7 @@ func getIPv4Set(t *testing.T) *nftables.Set { func TestSet_AddElement(t *testing.T) { var set = getIPv4Set(t) - err := set.AddElement(net.ParseIP("192.168.2.31").To4(), &nftables.ElementOptions{Timeout: 86400 * time.Second}) + err := set.AddElement(net.ParseIP("192.168.2.31").To4(), &nftables.ElementOptions{Timeout: 86400 * time.Second}, false) if err != nil { t.Fatal(err) } diff --git a/internal/firewalls/nftables/table.go b/internal/firewalls/nftables/table.go index cfbe766..6a8e44e 100644 --- a/internal/firewalls/nftables/table.go +++ b/internal/firewalls/nftables/table.go @@ -1,6 +1,5 @@ // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. //go:build linux -// +build linux package nftables diff --git a/internal/firewalls/nftables/table_test.go b/internal/firewalls/nftables/table_test.go index e414a29..74b0546 100644 --- a/internal/firewalls/nftables/table_test.go +++ b/internal/firewalls/nftables/table_test.go @@ -1,6 +1,5 @@ // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. //go:build linux -// +build linux package nftables_test @@ -10,7 +9,10 @@ import ( ) func getIPv4Table(t *testing.T) *nftables.Table { - var conn = nftables.NewConn() + conn, err := nftables.NewConn() + if err != nil { + t.Fatal(err) + } table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4) if err != nil { if err == nftables.ErrTableNotFound { diff --git a/internal/iplibrary/ip_list_db.go b/internal/iplibrary/ip_list_db.go index 4278b85..30d486f 100644 --- a/internal/iplibrary/ip_list_db.go +++ b/internal/iplibrary/ip_list_db.go @@ -17,13 +17,17 @@ import ( type IPListDB struct { db *dbs.DB - itemTableName string + itemTableName string + versionTableName string - deleteExpiredItemsStmt *dbs.Stmt - deleteItemStmt *dbs.Stmt - insertItemStmt *dbs.Stmt - selectItemsStmt *dbs.Stmt - selectMaxVersionStmt *dbs.Stmt + deleteExpiredItemsStmt *dbs.Stmt + deleteItemStmt *dbs.Stmt + insertItemStmt *dbs.Stmt + selectItemsStmt *dbs.Stmt + selectMaxItemVersionStmt *dbs.Stmt + + selectVersionStmt *dbs.Stmt + updateVersionStmt *dbs.Stmt cleanTicker *time.Ticker @@ -34,9 +38,10 @@ type IPListDB struct { func NewIPListDB() (*IPListDB, error) { var db = &IPListDB{ - itemTableName: "ipItems", - dir: filepath.Clean(Tea.Root + "/data"), - cleanTicker: time.NewTicker(24 * time.Hour), + itemTableName: "ipItems", + versionTableName: "versions", + dir: filepath.Clean(Tea.Root + "/data"), + cleanTicker: time.NewTicker(24 * time.Hour), } err := db.init() return db, err @@ -108,6 +113,15 @@ ON "` + this.itemTableName + `" ( return err } + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.versionTableName + `" ( + "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, + "version" integer DEFAULT 0 +); +`) + if err != nil { + return err + } + // 初始化SQL语句 this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "expiredAt">0 AND "expiredAt" this.version { - this.version = latestVersion + if latestVersion > this.lastVersion { + this.lastVersion = latestVersion } } }