Files
EdgeNode/internal/waf/ip_list.go
2022-08-04 11:01:16 +08:00

169 lines
3.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package waf
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
"github.com/iwind/TeaGo/types"
"sync"
"sync/atomic"
"time"
)
// Package-level lists shared by the whole WAF package.
var (
	// SharedIPWhiteList is the shared list created with IPListTypeAllow.
	SharedIPWhiteList = NewIPList(IPListTypeAllow)

	// SharedIPBlackList is the shared list created with IPListTypeDeny.
	SharedIPBlackList = NewIPList(IPListTypeDeny)
)
// IPListType names the kind of an IPList. It is an alias of string, so plain
// string values can be used wherever an IPListType is expected.
type IPListType = string

// Supported list kinds.
const (
	IPListTypeAllow = IPListType("allow")
	IPListTypeDeny  = IPListType("deny")
)

// IPTypeAll is the ipType value used by RemoveIP when it builds the keys to
// delete.
const IPTypeAll = "*"
// IPList manages a list of IPs of a single kind (allow or deny).
//
// Every entry is stored under a composite key and is tied to a numeric entry
// id that is also registered in expireList together with its expiration time.
type IPList struct {
	// id is the last entry id handed out by nextId. It is modified with
	// atomic.AddUint64 and therefore must remain the FIRST field of the
	// struct: on 32-bit platforms only the first word of an allocated struct
	// is guaranteed to be 64-bit aligned, and a misaligned 64-bit atomic
	// operation panics at runtime.
	id uint64

	expireList *expires.List
	ipMap      map[string]uint64 // composite key => entry id
	idMap      map[uint64]string // entry id => composite key
	listType   IPListType
	locker     sync.RWMutex // guards ipMap and idMap
}
// NewIPList creates an empty IPList of the given type.
//
// The list gets its own expiration list, and the list's remove method is
// registered as that expiration list's OnGC callback.
func NewIPList(listType IPListType) *IPList {
	list := &IPList{
		expireList: expires.NewList(),
		ipMap:      make(map[string]uint64),
		idMap:      make(map[uint64]string),
		listType:   listType,
	}
	list.expireList.OnGC(list.remove)
	return list
}
// Add registers an IP in the list until expiresAt.
//
// The entry is stored under the key "<owner>@<ip>@<ipType>", where <owner> is
// the server id for service-scoped entries and "*" for every other scope.
// Adding a key that already exists re-points it to a freshly allocated id.
func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string, expiresAt int64) {
	// Only the service scope is bound to a specific server.
	var owner = "*"
	if scope == firewallconfigs.FirewallScopeService {
		owner = types.String(serverId)
	}
	var key = owner + "@" + ip + "@" + ipType

	// Register the expiration first, then publish the entry in both maps.
	var entryId = this.nextId()
	this.expireList.Add(entryId, expiresAt)

	this.locker.Lock()
	defer this.locker.Unlock()
	this.idMap[entryId] = key
	this.ipMap[key] = entryId
}
// RecordIP adds an IP to the list and, for deny lists only, also queues it
// for upload and optionally blocks it in the local firewall.
//
// The upload task is sent without blocking: when recordIPTaskChan cannot
// accept it right away, the task is dropped. The local firewall block lasts
// until expiresAt but never longer than 3600 seconds, and is skipped when
// expiresAt is not in the future.
func (this *IPList) RecordIP(ipType string,
	scope firewallconfigs.FirewallScope,
	serverId int64,
	ip string,
	expiresAt int64,
	policyId int64,
	useLocalFirewall bool,
	groupId int64,
	setId int64,
	reason string) {
	this.Add(ipType, scope, serverId, ip, expiresAt)

	// Everything below applies to deny lists only.
	if this.listType != IPListTypeDeny {
		return
	}

	// Queue the record for upload.
	var task = &recordIPTask{
		ip:                            ip,
		listId:                        firewallconfigs.GlobalListId,
		expiredAt:                     expiresAt,
		level:                         firewallconfigs.DefaultEventLevel,
		serverId:                      serverId,
		sourceServerId:                serverId,
		sourceHTTPFirewallPolicyId:    policyId,
		sourceHTTPFirewallRuleGroupId: groupId,
		sourceHTTPFirewallRuleSetId:   setId,
		reason:                        reason,
	}
	select {
	case recordIPTaskChan <- task:
	default:
	}

	// Use the local firewall.
	if !useLocalFirewall {
		return
	}
	var ttl = expiresAt - time.Now().Unix()
	if ttl <= 0 {
		return
	}
	// Cap at 3600 seconds so a mistaken block does not last too long.
	if ttl > 3600 {
		ttl = 3600
	}
	_ = firewalls.Firewall().DropSourceIP(ip, int(ttl), true)
}
// Contains reports whether the list currently holds an entry for the given
// IP, IP type and scope.
//
// The lookup key is "<owner>@<ip>@<ipType>", where <owner> is the server id
// for service-scoped lookups and "*" for every other scope.
func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) bool {
	var owner = "*"
	if scope == firewallconfigs.FirewallScopeService {
		owner = types.String(serverId)
	}
	var key = owner + "@" + ip + "@" + ipType

	this.locker.RLock()
	defer this.locker.RUnlock()
	_, found := this.ipMap[key]
	return found
}
// RemoveIP deletes an IP from the list.
//
// Only keys built with IPTypeAll are deleted: the "*"-owned key always, and
// the key owned by serverId when serverId is positive. Entries are deleted
// from ipMap only; idMap is not touched here. When shouldExecute is true the
// IP is also removed from the local firewall, ignoring any error.
func (this *IPList) RemoveIP(ip string, serverId int64, shouldExecute bool) {
	var suffix = "@" + ip + "@" + IPTypeAll
	var keys = []string{"*" + suffix}
	if serverId > 0 {
		keys = append(keys, types.String(serverId)+suffix)
	}

	this.locker.Lock()
	for _, key := range keys {
		delete(this.ipMap, key)
	}
	this.locker.Unlock()

	// Remove from the local firewall.
	if !shouldExecute {
		return
	}
	_ = firewalls.Firewall().RemoveSourceIP(ip)
}
// remove drops the entry with the given id. It is the function NewIPList
// registers as the expiration list's OnGC callback.
//
// The id is always deleted from idMap. The matching key is deleted from ipMap
// only while it still points at this id, so a key that was re-added under a
// newer id is left in place.
func (this *IPList) remove(id uint64) {
	this.locker.Lock()
	defer this.locker.Unlock()

	key, known := this.idMap[id]
	if !known {
		return
	}
	delete(this.idMap, id)

	if currentId, exists := this.ipMap[key]; exists && currentId == id {
		delete(this.ipMap, key)
	}
}
// nextId atomically increments the list's id counter and returns the new
// value, so concurrent callers never receive the same id.
func (this *IPList) nextId() uint64 {
	var next = atomic.AddUint64(&this.id, 1)
	return next
}