diff --git a/internal/waf/ip_list.go b/internal/waf/ip_list.go index 6716a91..9c2120b 100644 --- a/internal/waf/ip_list.go +++ b/internal/waf/ip_list.go @@ -10,8 +10,15 @@ import ( "sync/atomic" ) -var SharedIPWhiteList = NewIPList() -var SharedIPBlackList = NewIPList() +var SharedIPWhiteList = NewIPList(IPListTypeAllow) +var SharedIPBlackList = NewIPList(IPListTypeDeny) + +type IPListType = string + +const ( + IPListTypeAllow IPListType = "allow" + IPListTypeDeny IPListType = "deny" +) const IPTypeAll = "*" @@ -20,16 +27,18 @@ type IPList struct { expireList *expires.List ipMap map[string]int64 // ip => id idMap map[int64]string // id => ip + listType IPListType id int64 locker sync.RWMutex } // NewIPList 获取新对象 -func NewIPList() *IPList { +func NewIPList(listType IPListType) *IPList { var list = &IPList{ - ipMap: map[string]int64{}, - idMap: map[int64]string{}, + ipMap: map[string]int64{}, + idMap: map[int64]string{}, + listType: listType, } e := expires.NewList() @@ -67,20 +76,22 @@ func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serv func (this *IPList) RecordIP(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string, expiresAt int64, policyId int64, groupId int64, setId int64) { this.Add(ipType, scope, serverId, ip, expiresAt) - select { - case recordIPTaskChan <- &recordIPTask{ - ip: ip, - listId: firewallconfigs.GlobalListId, - expiredAt: expiresAt, - level: firewallconfigs.DefaultEventLevel, - serverId: serverId, - sourceServerId: serverId, - sourceHTTPFirewallPolicyId: policyId, - sourceHTTPFirewallRuleGroupId: groupId, - sourceHTTPFirewallRuleSetId: setId, - }: - default: + if this.listType == IPListTypeDeny { + select { + case recordIPTaskChan <- &recordIPTask{ + ip: ip, + listId: firewallconfigs.GlobalListId, + expiredAt: expiresAt, + level: firewallconfigs.DefaultEventLevel, + serverId: serverId, + sourceServerId: serverId, + sourceHTTPFirewallPolicyId: policyId, + sourceHTTPFirewallRuleGroupId: groupId, + sourceHTTPFirewallRuleSetId: setId, + }: + default: + } } }