// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. package waf import ( "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeNode/internal/conns" "github.com/TeaOSLab/EdgeNode/internal/firewalls" "github.com/TeaOSLab/EdgeNode/internal/utils/expires" "github.com/iwind/TeaGo/types" "sync" "sync/atomic" "time" ) var SharedIPWhiteList = NewIPList(IPListTypeAllow) var SharedIPBlackList = NewIPList(IPListTypeDeny) type IPListType = string const ( IPListTypeAllow IPListType = "allow" IPListTypeDeny IPListType = "deny" ) const IPTypeAll = "*" // IPList IP列表管理 type IPList struct { expireList *expires.List ipMap map[string]uint64 // ip => id idMap map[uint64]string // id => ip listType IPListType id uint64 locker sync.RWMutex } // NewIPList 获取新对象 func NewIPList(listType IPListType) *IPList { var list = &IPList{ ipMap: map[string]uint64{}, idMap: map[uint64]string{}, listType: listType, } e := expires.NewList() list.expireList = e e.OnGC(func(itemId uint64) { list.remove(itemId) // TODO 使用异步,防止阻塞GC }) return list } // Add 添加IP func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string, expiresAt int64) { switch scope { case firewallconfigs.FirewallScopeGlobal: ip = "*@" + ip + "@" + ipType case firewallconfigs.FirewallScopeService: ip = types.String(serverId) + "@" + ip + "@" + ipType default: ip = "*@" + ip + "@" + ipType } var id = this.nextId() this.expireList.Add(id, expiresAt) this.locker.Lock() // 删除以前 oldId, ok := this.ipMap[ip] if ok { delete(this.idMap, oldId) this.expireList.Remove(oldId) } this.ipMap[ip] = id this.idMap[id] = ip this.locker.Unlock() } // RecordIP 记录IP func (this *IPList) RecordIP(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string, expiresAt int64, policyId int64, useLocalFirewall bool, groupId int64, setId int64, reason string) { this.Add(ipType, scope, serverId, ip, expiresAt) if this.listType == IPListTypeDeny { // 加入队列等待上传 select { case recordIPTaskChan <- &recordIPTask{ ip: ip, listId: firewallconfigs.GlobalListId, expiresAt: expiresAt, level: firewallconfigs.DefaultEventLevel, serverId: serverId, sourceServerId: serverId, sourceHTTPFirewallPolicyId: policyId, sourceHTTPFirewallRuleGroupId: groupId, sourceHTTPFirewallRuleSetId: setId, reason: reason, }: default: } // 使用本地防火墙 if useLocalFirewall { var seconds = expiresAt - time.Now().Unix() if seconds > 0 { // 最大3600,防止误封时间过长 if seconds > 3600 { seconds = 3600 } _ = firewalls.Firewall().DropSourceIP(ip, int(seconds), true) } } // 关闭此IP相关连接 conns.SharedMap.CloseIPConns(ip) } } // Contains 判断是否有某个IP func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) bool { switch scope { case firewallconfigs.FirewallScopeGlobal: ip = "*@" + ip + "@" + ipType case firewallconfigs.FirewallScopeService: ip = types.String(serverId) + "@" + ip + "@" + ipType default: ip = "*@" + ip + "@" + ipType } this.locker.RLock() _, ok := this.ipMap[ip] this.locker.RUnlock() return ok } // ContainsExpires 判断是否有某个IP,并返回过期时间 func (this *IPList) ContainsExpires(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) (expiresAt int64, ok bool) { switch scope { case firewallconfigs.FirewallScopeGlobal: ip = "*@" + ip + "@" + ipType case firewallconfigs.FirewallScopeService: ip = types.String(serverId) + "@" + ip + "@" + ipType default: ip = "*@" + ip + "@" + ipType } this.locker.RLock() id, ok := this.ipMap[ip] if ok { expiresAt = this.expireList.ExpiresAt(id) } this.locker.RUnlock() return expiresAt, ok } // RemoveIP 删除IP func (this *IPList) RemoveIP(ip string, serverId int64, shouldExecute bool) { this.locker.Lock() { var key = "*@" + ip + "@" + IPTypeAll id, ok := this.ipMap[key] if ok { delete(this.ipMap, key) delete(this.idMap, id) this.expireList.Remove(id) } } if serverId > 0 { var key = types.String(serverId) + "@" + ip + "@" + IPTypeAll id, ok := this.ipMap[key] if ok { delete(this.ipMap, key) delete(this.idMap, id) this.expireList.Remove(id) } } this.locker.Unlock() // 从本地防火墙中删除 if shouldExecute { _ = firewalls.Firewall().RemoveSourceIP(ip) } } func (this *IPList) remove(id uint64) { this.locker.Lock() ip, ok := this.idMap[id] if ok { ipId, ok := this.ipMap[ip] if ok && ipId == id { delete(this.ipMap, ip) } delete(this.idMap, id) } this.locker.Unlock() } func (this *IPList) nextId() uint64 { return atomic.AddUint64(&this.id, 1) }