diff --git a/internal/waf/ip_list.go b/internal/waf/ip_list.go index 518c288..6a1a6f9 100644 --- a/internal/waf/ip_list.go +++ b/internal/waf/ip_list.go @@ -3,12 +3,18 @@ package waf import ( + "encoding/json" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeNode/internal/conns" + teaconst "github.com/TeaOSLab/EdgeNode/internal/const" + "github.com/TeaOSLab/EdgeNode/internal/events" "github.com/TeaOSLab/EdgeNode/internal/firewalls" "github.com/TeaOSLab/EdgeNode/internal/utils/expires" "github.com/TeaOSLab/EdgeNode/internal/utils/fasttime" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" + "os" "sync" "sync/atomic" ) @@ -25,11 +31,30 @@ const ( const IPTypeAll = "*" +func init() { + if !teaconst.IsMain { + return + } + + var cacheFile = Tea.Root + "/data/waf_white_list.cache" + + // save + events.On(events.EventTerminated, func() { + _ = SharedIPWhiteList.Save(cacheFile) + }) + + // load + go func() { + _ = SharedIPWhiteList.Load(cacheFile) + _ = os.Remove(cacheFile) + }() +} + // IPList IP列表管理 type IPList struct { expireList *expires.List - ipMap map[string]uint64 // ip => id - idMap map[uint64]string // id => ip + ipMap map[string]uint64 // ip info => id + idMap map[uint64]string // id => ip info listType IPListType id uint64 @@ -47,7 +72,7 @@ func NewIPList(listType IPListType) *IPList { listType: listType, } - e := expires.NewList() + var e = expires.NewList() list.expireList = e e.OnGC(func(itemId uint64) { @@ -206,6 +231,85 @@ func (this *IPList) RemoveIP(ip string, serverId int64, shouldExecute bool) { } } +// Save to local file +func (this *IPList) Save(path string) error { + var itemMaps = []maps.Map{} // [ {ip info, expiresAt }, ... ] + this.locker.Lock() + defer this.locker.Unlock() + + // prevent too many items + if len(this.ipMap) > 100_000 { + return nil + } + + for ipInfo, id := range this.ipMap { + var expiresAt = this.expireList.ExpiresAt(id) + if expiresAt <= 0 { + continue + } + itemMaps = append(itemMaps, maps.Map{ + "ip": ipInfo, + "expiresAt": expiresAt, + }) + } + + itemMapsJSON, err := json.Marshal(itemMaps) + if err != nil { + return err + } + return os.WriteFile(path, itemMapsJSON, 0666) +} + +// Load from local file +func (this *IPList) Load(path string) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + if len(data) == 0 { + return nil + } + + var itemMaps = []maps.Map{} + err = json.Unmarshal(data, &itemMaps) + if err != nil { + return err + } + + this.locker.Lock() + defer this.locker.Unlock() + + for _, itemMap := range itemMaps { + var ip = itemMap.GetString("ip") + var expiresAt = itemMap.GetInt64("expiresAt") + if len(ip) == 0 || expiresAt < fasttime.Now().Unix()+10 /** seconds **/ { + continue + } + + var id = this.nextId() + this.expireList.Add(id, expiresAt) + + this.ipMap[ip] = id + this.idMap[id] = ip + } + + return nil +} + +// IPMap get ipMap +func (this *IPList) IPMap() map[string]uint64 { + this.locker.RLock() + defer this.locker.RUnlock() + return this.ipMap +} + +// IdMap get idMap +func (this *IPList) IdMap() map[uint64]string { + this.locker.RLock() + defer this.locker.RUnlock() + return this.idMap +} + func (this *IPList) remove(id uint64) { this.locker.Lock() ip, ok := this.idMap[id] diff --git a/internal/waf/ip_list_test.go b/internal/waf/ip_list_test.go index 7133a82..8df5f54 100644 --- a/internal/waf/ip_list_test.go +++ b/internal/waf/ip_list_test.go @@ -1,12 +1,16 @@ // Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. -package waf +package waf_test import ( "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" + "github.com/TeaOSLab/EdgeNode/internal/waf" + "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/assert" + _ "github.com/iwind/TeaGo/bootstrap" "github.com/iwind/TeaGo/logs" timeutil "github.com/iwind/TeaGo/utils/time" + "os" "runtime" "strconv" "testing" @@ -14,35 +18,33 @@ import ( ) func TestNewIPList(t *testing.T) { - var list = NewIPList(IPListTypeDeny) - list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()) - list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1) - list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2) - list.Add(IPTypeAll, firewallconfigs.FirewallScopeService, 1, "127.0.0.3", time.Now().Unix()+3) - list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10) + var list = waf.NewIPList(waf.IPListTypeDeny) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeService, 1, "127.0.0.3", time.Now().Unix()+3) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10) list.RemoveIP("127.0.0.1", 1, false) - logs.PrintAsJSON(list.ipMap, t) - logs.PrintAsJSON(list.idMap, t) + logs.PrintAsJSON(list.IPMap(), t) + logs.PrintAsJSON(list.IdMap(), t) } func TestIPList_Expire(t *testing.T) { - var list = NewIPList(IPListTypeDeny) - list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()) - list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1) - list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2) - list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.3", time.Now().Unix()+3) - list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+6) + var list = waf.NewIPList(waf.IPListTypeDeny) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.3", time.Now().Unix()+3) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+6) var ticker = time.NewTicker(1 * time.Second) for range ticker.C { t.Log("====") - list.locker.Lock() - logs.PrintAsJSON(list.ipMap, t) - logs.PrintAsJSON(list.idMap, t) - list.locker.Unlock() - if len(list.idMap) == 0 { + logs.PrintAsJSON(list.IPMap(), t) + logs.PrintAsJSON(list.IdMap(), t) + if len(list.IdMap()) == 0 { break } } @@ -51,54 +53,78 @@ func TestIPList_Expire(t *testing.T) { func TestIPList_Contains(t *testing.T) { var a = assert.NewAssertion(t) - var list = NewIPList(IPListTypeDeny) + var list = waf.NewIPList(waf.IPListTypeDeny) for i := 0; i < 1_0000; i++ { - list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) } //list.RemoveIP("192.168.1.100") { - a.IsTrue(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100")) + a.IsTrue(list.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100")) } { - a.IsFalse(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100")) + a.IsFalse(list.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100")) } } func TestIPList_ContainsExpires(t *testing.T) { - var list = NewIPList(IPListTypeDeny) + var list = waf.NewIPList(waf.IPListTypeDeny) for i := 0; i < 1_0000; i++ { - list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) } // list.RemoveIP("192.168.1.100", 1, false) for _, ip := range []string{"192.168.1.100", "192.168.2.100"} { - expiresAt, ok := list.ContainsExpires(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, ip) + expiresAt, ok := list.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, ip) t.Log(ok, expiresAt, timeutil.FormatTime("Y-m-d H:i:s", expiresAt)) } } +func TestIPList_Save(t *testing.T) { + var a = assert.NewAssertion(t) + + var list = waf.NewIPList(waf.IPListTypeAllow) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100", time.Now().Unix()+3600) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 2, "192.168.1.101", time.Now().Unix()+3600) + + var file = Tea.Root + "/data/waf.iplist.json" + err := list.Save(file) + if err != nil { + t.Fatal(err) + } + + var newList = waf.NewIPList(waf.IPListTypeAllow) + err = newList.Load(file) + if err != nil { + t.Fatal(err) + } + + a.IsTrue(newList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100")) + + _ = os.Remove(file) +} + func BenchmarkIPList_Add(b *testing.B) { runtime.GOMAXPROCS(1) - var list = NewIPList(IPListTypeDeny) + var list = waf.NewIPList(waf.IPListTypeDeny) for i := 0; i < b.N; i++ { - list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) } - b.Log(len(list.ipMap)) + b.Log(len(list.IPMap())) } func BenchmarkIPList_Has(b *testing.B) { runtime.GOMAXPROCS(1) - var list = NewIPList(IPListTypeDeny) + var list = waf.NewIPList(waf.IPListTypeDeny) b.ResetTimer() for i := 0; i < 1_0000; i++ { - list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) } for i := 0; i < b.N; i++ { - list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100") + list.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100") } }