Files
EdgeNode/internal/waf/ip_list.go

129 lines
3.0 KiB
Go

// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package waf
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
"github.com/iwind/TeaGo/types"
"sync"
"sync/atomic"
)
var SharedIPWhiteList = NewIPList()
var SharedIPBlackList = NewIPList()
const IPTypeAll = "*"
// IPList IP列表管理
type IPList struct {
expireList *expires.List
ipMap map[string]int64 // ip => id
idMap map[int64]string // id => ip
id int64
locker sync.RWMutex
}
// NewIPList 获取新对象
func NewIPList() *IPList {
var list = &IPList{
ipMap: map[string]int64{},
idMap: map[int64]string{},
}
e := expires.NewList()
list.expireList = e
go func() {
e.StartGC(func(itemId int64) {
list.remove(itemId)
})
}()
return list
}
// Add 添加IP
func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string, expiresAt int64) {
switch scope {
case firewallconfigs.FirewallScopeGlobal:
ip = "*@" + ip + "@" + ipType
case firewallconfigs.FirewallScopeService:
ip = types.String(serverId) + "@" + ip + "@" + ipType
default:
ip = "*@" + ip + "@" + ipType
}
var id = this.nextId()
this.expireList.Add(id, expiresAt)
this.locker.Lock()
this.ipMap[ip] = id
this.idMap[id] = ip
this.locker.Unlock()
}
// RecordIP 记录IP
func (this *IPList) RecordIP(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string, expiresAt int64, policyId int64, groupId int64, setId int64) {
this.Add(ipType, scope, serverId, ip, expiresAt)
select {
case recordIPTaskChan <- &recordIPTask{
ip: ip,
listId: firewallconfigs.GlobalListId,
expiredAt: expiresAt,
level: firewallconfigs.DefaultEventLevel,
sourceServerId: serverId,
sourceHTTPFirewallPolicyId: policyId,
sourceHTTPFirewallRuleGroupId: groupId,
sourceHTTPFirewallRuleSetId: setId,
}:
default:
}
}
// Contains 判断是否有某个IP
func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) bool {
switch scope {
case firewallconfigs.FirewallScopeGlobal:
ip = "*@" + ip + "@" + ipType
case firewallconfigs.FirewallScopeService:
ip = types.String(serverId) + "@" + ip + "@" + ipType
default:
ip = "*@" + ip + "@" + ipType
}
this.locker.RLock()
defer this.locker.RUnlock()
_, ok := this.ipMap[ip]
return ok
}
// RemoveIP 删除IP
func (this *IPList) RemoveIP(ip string, serverId int64) {
this.locker.Lock()
delete(this.ipMap, "*@"+ip+"@"+IPTypeAll)
if serverId > 0 {
delete(this.ipMap, types.String(serverId)+"@"+ip+"@"+IPTypeAll)
}
this.locker.Unlock()
}
func (this *IPList) remove(id int64) {
this.locker.Lock()
ip, ok := this.idMap[id]
if ok {
ipId, ok := this.ipMap[ip]
if ok && ipId == id {
delete(this.ipMap, ip)
}
delete(this.idMap, id)
}
this.locker.Unlock()
}
func (this *IPList) nextId() int64 {
return atomic.AddInt64(&this.id, 1)
}