Files
EdgeNode/internal/waf/ip_list.go

332 lines
7.3 KiB
Go
Raw Normal View History

2024-05-17 18:30:33 +08:00
// Copyright 2021 GoEdge goedge.cdn@gmail.com. All rights reserved.
2021-07-18 15:51:49 +08:00
package waf
import (
"encoding/json"
2024-07-27 15:42:50 +08:00
"os"
"sync"
"sync/atomic"
2021-10-18 20:08:43 +08:00
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/conns"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
2021-07-18 15:51:49 +08:00
"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/maps"
2021-10-18 20:08:43 +08:00
"github.com/iwind/TeaGo/types"
2021-07-18 15:51:49 +08:00
)
// Process-wide IP lists shared by the WAF.
var (
	// SharedIPWhiteList holds IPs that have been explicitly allowed.
	SharedIPWhiteList = NewIPList(IPListTypeAllow)

	// SharedIPBlackList holds IPs that have been denied.
	SharedIPBlackList = NewIPList(IPListTypeDeny)
)
// IPListType tells whether the IPs in a list are allowed or denied.
// It is an alias of string, so plain string values can be used directly.
type IPListType = string

// Supported IPListType values.
const (
	IPListTypeAllow IPListType = "allow" // IPs in the list are allowed
	IPListTypeDeny  IPListType = "deny"  // IPs in the list are denied
)
2021-07-18 15:51:49 +08:00
// IPTypeAll is the wildcard IP type. RemoveIP only removes entries that
// were stored with this type.
const (
	IPTypeAll = "*"
)
// init persists the shared white list across restarts in the main process.
// The list is written to a cache file when the process terminates. At startup
// the cache is loaded asynchronously and then deleted, so it is used only once.
func init() {
	if !teaconst.IsMain {
		return
	}

	cacheFile := Tea.Root + "/data/waf_white_list.cache"

	// Write the list on shutdown. This is best effort: a failed save only
	// loses the cache.
	events.On(events.EventTerminated, func() {
		_ = SharedIPWhiteList.Save(cacheFile)
	})

	// Restore the list in the background (skipped under tests). Errors are
	// ignored on purpose: a missing or broken cache just means an empty list.
	go func() {
		if Tea.IsTesting() {
			return
		}
		_ = SharedIPWhiteList.Load(cacheFile)
		_ = os.Remove(cacheFile)
	}()
}
2021-07-18 15:51:49 +08:00
// IPList IP列表管理
type IPList struct {
expireList *expires.List
ipMap map[string]uint64 // ip info => id
idMap map[uint64]string // id => ip info
listType IPListType
2021-07-18 15:51:49 +08:00
2022-04-09 18:28:22 +08:00
id uint64
2021-07-18 15:51:49 +08:00
locker sync.RWMutex
2023-04-01 20:51:49 +08:00
lastIPInfo string // 加入到 recordIPTaskChan 之前尽可能去重
lastTime int64
2021-07-18 15:51:49 +08:00
}
// NewIPList 获取新对象
func NewIPList(listType IPListType) *IPList {
2021-07-18 15:51:49 +08:00
var list = &IPList{
2022-04-09 18:28:22 +08:00
ipMap: map[string]uint64{},
idMap: map[uint64]string{},
listType: listType,
2021-07-18 15:51:49 +08:00
}
var e = expires.NewList()
2021-07-18 15:51:49 +08:00
list.expireList = e
2022-04-09 18:28:22 +08:00
e.OnGC(func(itemId uint64) {
list.remove(itemId) // TODO 使用异步防止阻塞GC
})
2021-07-18 15:51:49 +08:00
return list
}
// Add 添加IP
2021-10-18 20:08:43 +08:00
func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string, expiresAt int64) {
switch scope {
case firewallconfigs.FirewallScopeGlobal:
ip = "*@" + ip + "@" + ipType
case firewallconfigs.FirewallScopeServer:
2021-10-18 20:08:43 +08:00
ip = types.String(serverId) + "@" + ip + "@" + ipType
default:
2021-10-21 09:31:31 +08:00
ip = "*@" + ip + "@" + ipType
2021-10-18 20:08:43 +08:00
}
2021-07-18 15:51:49 +08:00
var id = this.nextId()
this.expireList.Add(id, expiresAt)
this.locker.Lock()
2022-09-03 09:54:25 +08:00
// 删除以前
oldId, ok := this.ipMap[ip]
if ok {
delete(this.idMap, oldId)
this.expireList.Remove(oldId)
}
2021-07-18 15:51:49 +08:00
this.ipMap[ip] = id
this.idMap[id] = ip
this.locker.Unlock()
}
// RecordIP 记录IP
func (this *IPList) RecordIP(ipType string,
scope firewallconfigs.FirewallScope,
serverId int64,
ip string,
expiresAt int64,
policyId int64,
useLocalFirewall bool,
groupId int64,
2022-01-10 19:54:10 +08:00
setId int64,
reason string) {
this.Add(ipType, scope, serverId, ip, expiresAt)
if this.listType == IPListTypeDeny {
2023-03-31 21:37:15 +08:00
// 作用域
var scopeServerId int64
if scope == firewallconfigs.FirewallScopeServer {
2023-03-31 21:37:15 +08:00
scopeServerId = serverId
}
// 加入队列等待上传
if this.lastIPInfo != ip+"@"+ipType || fasttime.Now().Unix()-this.lastTime > 3 /** 3秒外才允许重复添加 **/ {
2023-04-01 20:51:49 +08:00
select {
case recordIPTaskChan <- &recordIPTask{
ip: ip,
listId: firewallconfigs.GlobalBlackListId,
2023-04-01 20:51:49 +08:00
expiresAt: expiresAt,
level: firewallconfigs.DefaultEventLevel,
serverId: scopeServerId,
sourceServerId: serverId,
sourceHTTPFirewallPolicyId: policyId,
sourceHTTPFirewallRuleGroupId: groupId,
sourceHTTPFirewallRuleSetId: setId,
reason: reason,
}:
this.lastIPInfo = ip + "@" + ipType
this.lastTime = fasttime.Now().Unix()
2023-04-01 20:51:49 +08:00
default:
}
// 使用本地防火墙
2023-04-05 09:25:33 +08:00
if useLocalFirewall {
2023-04-01 20:51:49 +08:00
firewalls.DropTemporaryTo(ip, expiresAt)
}
}
2022-09-03 09:54:25 +08:00
// 关闭此IP相关连接
conns.SharedMap.CloseIPConns(ip)
}
}
2021-07-18 15:51:49 +08:00
// Contains 判断是否有某个IP
2021-10-18 20:08:43 +08:00
func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) bool {
switch scope {
case firewallconfigs.FirewallScopeGlobal:
ip = "*@" + ip + "@" + ipType
case firewallconfigs.FirewallScopeServer:
2021-10-18 20:08:43 +08:00
ip = types.String(serverId) + "@" + ip + "@" + ipType
default:
2021-10-21 09:31:31 +08:00
ip = "*@" + ip + "@" + ipType
2021-10-18 20:08:43 +08:00
}
2021-07-18 15:51:49 +08:00
this.locker.RLock()
_, ok := this.ipMap[ip]
this.locker.RUnlock()
2021-07-18 15:51:49 +08:00
return ok
}
2022-09-03 09:54:25 +08:00
// ContainsExpires 判断是否有某个IP并返回过期时间
func (this *IPList) ContainsExpires(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) (expiresAt int64, ok bool) {
switch scope {
case firewallconfigs.FirewallScopeGlobal:
ip = "*@" + ip + "@" + ipType
case firewallconfigs.FirewallScopeServer:
2022-09-03 09:54:25 +08:00
ip = types.String(serverId) + "@" + ip + "@" + ipType
default:
ip = "*@" + ip + "@" + ipType
}
this.locker.RLock()
id, ok := this.ipMap[ip]
if ok {
expiresAt = this.expireList.ExpiresAt(id)
}
this.locker.RUnlock()
return expiresAt, ok
}
// RemoveIP 删除IP
func (this *IPList) RemoveIP(ip string, serverId int64, shouldExecute bool) {
this.locker.Lock()
2022-09-03 09:54:25 +08:00
{
var key = "*@" + ip + "@" + IPTypeAll
id, ok := this.ipMap[key]
if ok {
delete(this.ipMap, key)
delete(this.idMap, id)
this.expireList.Remove(id)
}
}
if serverId > 0 {
2022-09-03 09:54:25 +08:00
var key = types.String(serverId) + "@" + ip + "@" + IPTypeAll
id, ok := this.ipMap[key]
if ok {
delete(this.ipMap, key)
delete(this.idMap, id)
this.expireList.Remove(id)
}
}
2022-09-03 09:54:25 +08:00
this.locker.Unlock()
// 从本地防火墙中删除
if shouldExecute {
_ = firewalls.Firewall().RemoveSourceIP(ip)
}
}
// Save writes the entries that have an expiration time to path as a JSON
// array of {"ip": <ip info>, "expiresAt": <unix seconds>} objects. If the
// list holds more than 100,000 entries, nothing is written and nil is
// returned. Encoding and write errors are returned.
func (this *IPList) Save(path string) error {
	// Saving does not modify the list, so a read lock is enough. The lock
	// is released before encoding and disk I/O, so lookups and updates are
	// not blocked while the file is written.
	this.locker.RLock()

	// prevent too many items
	if len(this.ipMap) > 100_000 {
		this.locker.RUnlock()
		return nil
	}

	var itemMaps = make([]maps.Map, 0, len(this.ipMap)) // [ {ip info, expiresAt }, ... ]
	for ipInfo, id := range this.ipMap {
		var expiresAt = this.expireList.ExpiresAt(id)
		if expiresAt <= 0 {
			continue
		}
		itemMaps = append(itemMaps, maps.Map{
			"ip":        ipInfo,
			"expiresAt": expiresAt,
		})
	}

	this.locker.RUnlock()

	itemMapsJSON, err := json.Marshal(itemMaps)
	if err != nil {
		return err
	}
	return os.WriteFile(path, itemMapsJSON, 0666)
}
// Load restores the entries written by Save from path.
//
// Entries with an empty ip or that expire within the next 10 seconds are
// skipped. A loaded entry overwrites an existing entry with the same key, and
// the previous entry's id is discarded the same way Add discards it. An empty
// file is not an error; read and JSON decoding errors are returned.
func (this *IPList) Load(path string) error {
	data, err := os.ReadFile(path)
	if err != nil {
		return err
	}
	if len(data) == 0 {
		return nil
	}

	var itemMaps = []maps.Map{}
	err = json.Unmarshal(data, &itemMaps)
	if err != nil {
		return err
	}

	// entries expiring before this moment are not worth restoring
	var minExpiresAt = fasttime.Now().Unix() + 10 /** seconds **/

	this.locker.Lock()
	defer this.locker.Unlock()

	for _, itemMap := range itemMaps {
		var ip = itemMap.GetString("ip")
		var expiresAt = itemMap.GetInt64("expiresAt")
		if len(ip) == 0 || expiresAt < minExpiresAt {
			continue
		}

		// The key may already exist, e.g. if Add ran while Load was in
		// progress. Drop the old id so it does not linger in idMap and the
		// expiration list.
		if oldId, ok := this.ipMap[ip]; ok {
			delete(this.idMap, oldId)
			this.expireList.Remove(oldId)
		}

		var id = this.nextId()
		this.ipMap[ip] = id
		this.idMap[id] = ip
		this.expireList.Add(id, expiresAt)
	}
	return nil
}
// IPMap returns a snapshot copy of the ip info => id map.
//
// A copy is returned because the internal map may only be read while the
// list lock is held. Returning the live map would let callers read it
// concurrently with Add/remove, which is a data race.
func (this *IPList) IPMap() map[string]uint64 {
	this.locker.RLock()
	defer this.locker.RUnlock()

	var result = make(map[string]uint64, len(this.ipMap))
	for ipInfo, id := range this.ipMap {
		result[ipInfo] = id
	}
	return result
}
// IdMap returns a snapshot copy of the id => ip info map.
//
// A copy is returned because the internal map may only be read while the
// list lock is held. Returning the live map would let callers read it
// concurrently with Add/remove, which is a data race.
func (this *IPList) IdMap() map[uint64]string {
	this.locker.RLock()
	defer this.locker.RUnlock()

	var result = make(map[uint64]string, len(this.idMap))
	for id, ipInfo := range this.idMap {
		result[id] = ipInfo
	}
	return result
}
2022-04-09 18:28:22 +08:00
func (this *IPList) remove(id uint64) {
2021-07-18 15:51:49 +08:00
this.locker.Lock()
ip, ok := this.idMap[id]
if ok {
ipId, ok := this.ipMap[ip]
if ok && ipId == id {
delete(this.ipMap, ip)
}
delete(this.idMap, id)
}
this.locker.Unlock()
}
2022-04-09 18:28:22 +08:00
func (this *IPList) nextId() uint64 {
return atomic.AddUint64(&this.id, 1)
2021-07-18 15:51:49 +08:00
}