diff --git a/internal/firewalls/firewall_nftables.go b/internal/firewalls/firewall_nftables.go index 59a88f5..90d0d56 100644 --- a/internal/firewalls/firewall_nftables.go +++ b/internal/firewalls/firewall_nftables.go @@ -5,6 +5,7 @@ package firewalls import ( "errors" + "github.com/TeaOSLab/EdgeCommon/pkg/configutils" "github.com/TeaOSLab/EdgeNode/internal/conns" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" "github.com/TeaOSLab/EdgeNode/internal/events" @@ -130,8 +131,8 @@ type NFTablesFirewall struct { allowIPv4Set *nftables.Set allowIPv6Set *nftables.Set - denyIPv4Set *nftables.Set - denyIPv6Set *nftables.Set + denyIPv4Sets []*nftables.Set + denyIPv6Sets []*nftables.Set firewalld *Firewalld @@ -206,7 +207,7 @@ func (this *NFTablesFirewall) init() error { // allow set // "allow" should be always first - for _, setAction := range []string{"allow", "deny"} { + for _, setAction := range []string{"allow", "deny", "deny1", "deny2", "deny3", "deny4"} { var setName = setAction + "_set" set, err := table.GetSet(setName) @@ -236,13 +237,13 @@ func (this *NFTablesFirewall) init() error { if setAction == "allow" { this.allowIPv4Set = set } else { - this.denyIPv4Set = set + this.denyIPv4Sets = append(this.denyIPv4Sets, set) } } else if tableDef.IsIPv6 { if setAction == "allow" { this.allowIPv6Set = set } else { - this.denyIPv6Set = set + this.denyIPv6Sets = append(this.denyIPv6Sets, set) } } @@ -401,20 +402,21 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async // 再次尝试关闭连接 defer conns.SharedMap.CloseIPConns(ip) + var ipLong = configutils.IPString2Long(ip) if strings.Contains(ip, ":") { // ipv6 - if this.denyIPv6Set == nil { - return errors.New("ipv6 ip set is nil") + if len(this.denyIPv6Sets) == 0 { + return errors.New("ipv6 ip set not found") } - return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{ + return this.denyIPv6Sets[ipLong%uint64(len(this.denyIPv6Sets))].AddElement(data.To16(), &nftables.ElementOptions{ Timeout: time.Duration(timeoutSeconds) * time.Second, }, false) } // ipv4 - if this.denyIPv4Set == nil { - return errors.New("ipv4 ip set is nil") + if len(this.denyIPv4Sets) == 0 { + return errors.New("ipv4 ip set not found") } - return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{ + return this.denyIPv4Sets[ipLong%uint64(len(this.denyIPv4Sets))].AddElement(data.To4(), &nftables.ElementOptions{ Timeout: time.Duration(timeoutSeconds) * time.Second, }, false) } @@ -426,9 +428,10 @@ func (this *NFTablesFirewall) RemoveSourceIP(ip string) error { return errors.New("invalid ip '" + ip + "'") } + var ipLong = configutils.IPString2Long(ip) if strings.Contains(ip, ":") { // ipv6 - if this.denyIPv6Set != nil { - err := this.denyIPv6Set.DeleteElement(data.To16()) + if len(this.denyIPv6Sets) > 0 { + err := this.denyIPv6Sets[ipLong%uint64(len(this.denyIPv6Sets))].DeleteElement(data.To16()) if err != nil { return err } @@ -445,13 +448,14 @@ func (this *NFTablesFirewall) RemoveSourceIP(ip string) error { } // ipv4 - if this.allowIPv4Set != nil { - err := this.denyIPv4Set.DeleteElement(data.To4()) + if len(this.denyIPv4Sets) > 0 { + err := this.denyIPv4Sets[ipLong%uint64(len(this.denyIPv4Sets))].DeleteElement(data.To4()) if err != nil { return err } - - err = this.allowIPv4Set.DeleteElement(data.To4()) + } + if this.allowIPv4Set != nil { + err := this.allowIPv4Set.DeleteElement(data.To4()) if err != nil { return err }