mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-03 06:40:25 +08:00
WAF规则匹配后的IP也会上报/实现IP全局名单/将名单存储到本地数据库,提升读写速度
This commit is contained in:
5
internal/iplibrary/README.md
Normal file
5
internal/iplibrary/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# IPList
|
||||
List Check Order:
|
||||
~~~
|
||||
Global List --> Node List--> Server List --> WAF List --> Bind List
|
||||
~~~
|
||||
@@ -7,6 +7,9 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
var GlobalBlackIPList = NewIPList()
|
||||
var GlobalWhiteIPList = NewIPList()
|
||||
|
||||
// IPList IP名单
|
||||
// TODO IP名单可以分片关闭,这样让每一片的数据量减少,查询更快
|
||||
type IPList struct {
|
||||
|
||||
145
internal/iplibrary/ip_list_db.go
Normal file
145
internal/iplibrary/ip_list_db.go
Normal file
@@ -0,0 +1,145 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type IPListDB struct {
|
||||
db *sql.DB
|
||||
|
||||
itemTableName string
|
||||
deleteItemStmt *sql.Stmt
|
||||
insertItemStmt *sql.Stmt
|
||||
selectItemsStmt *sql.Stmt
|
||||
|
||||
dir string
|
||||
}
|
||||
|
||||
func NewIPListDB() (*IPListDB, error) {
|
||||
var db = &IPListDB{
|
||||
itemTableName: "ipItems",
|
||||
dir: filepath.Clean(Tea.Root + "/data"),
|
||||
}
|
||||
err := db.init()
|
||||
return db, err
|
||||
}
|
||||
|
||||
func (this *IPListDB) init() error {
|
||||
// 检查目录是否存在
|
||||
_, err := os.Stat(this.dir)
|
||||
if err != nil {
|
||||
err = os.MkdirAll(this.dir, 0777)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
remotelogs.Println("CACHE", "create cache dir '"+this.dir+"'")
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:"+this.dir+"/ip_list.db?cache=shared&mode=rwc&_journal_mode=WAL")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.SetMaxOpenConns(1)
|
||||
this.db = db
|
||||
|
||||
// 初始化数据库
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemTableName + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"listId" integer DEFAULT 0,
|
||||
"listType" varchar(32),
|
||||
"isGlobal" integer(1) DEFAULT 0,
|
||||
"type" varchar(16),
|
||||
"itemId" integer DEFAULT 0,
|
||||
"ipFrom" varchar(64) DEFAULT 0,
|
||||
"ipTo" varchar(64) DEFAULT 0,
|
||||
"expiredAt" integer DEFAULT 0,
|
||||
"eventLevel" varchar(32),
|
||||
"isDeleted" integer(1) DEFAULT 0,
|
||||
"version" integer DEFAULT 0,
|
||||
"nodeId" integer DEFAULT 0,
|
||||
"serverId" integer DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "ip_list_itemId"
|
||||
ON "` + this.itemTableName + `" (
|
||||
"itemId" ASC
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "ip_list_expiredAt"
|
||||
ON "` + this.itemTableName + `" (
|
||||
"expiredAt" ASC
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 初始化SQL语句
|
||||
this.deleteItemStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "itemId"=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.insertItemStmt, err = this.db.Prepare(`INSERT INTO "` + this.itemTableName + `" ("listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.db = db
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *IPListDB) AddItem(item *pb.IPItem) error {
|
||||
_, err := this.deleteItemStmt.Exec(item.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *IPListDB) ReadItems(offset int64, size int64) (items []*pb.IPItem, err error) {
|
||||
rows, err := this.selectItemsStmt.Query(offset, size)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
// "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId"
|
||||
var pbItem = &pb.IPItem{}
|
||||
err = rows.Scan(&pbItem.ListId, &pbItem.ListType, &pbItem.IsGlobal, &pbItem.Type, &pbItem.Id, &pbItem.IpFrom, &pbItem.IpTo, &pbItem.ExpiredAt, &pbItem.EventLevel, &pbItem.IsDeleted, &pbItem.Version, &pbItem.NodeId, &pbItem.ServerId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, pbItem)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (this *IPListDB) Close() error {
|
||||
if this.db != nil {
|
||||
_ = this.deleteItemStmt.Close()
|
||||
_ = this.insertItemStmt.Close()
|
||||
_ = this.selectItemsStmt.Close()
|
||||
|
||||
return this.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
60
internal/iplibrary/ip_list_db_test.go
Normal file
60
internal/iplibrary/ip_list_db_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIPListDB_AddItem(t *testing.T) {
|
||||
db, err := NewIPListDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.AddItem(&pb.IPItem{
|
||||
Id: 1,
|
||||
IpFrom: "192.168.1.101",
|
||||
IpTo: "",
|
||||
Version: 1024,
|
||||
ExpiredAt: time.Now().Unix(),
|
||||
Reason: "",
|
||||
ListId: 2,
|
||||
IsDeleted: true,
|
||||
Type: "ipv4",
|
||||
EventLevel: "error",
|
||||
ListType: "black",
|
||||
IsGlobal: true,
|
||||
CreatedAt: 0,
|
||||
NodeId: 11,
|
||||
ServerId: 22,
|
||||
SourceNodeId: 0,
|
||||
SourceServerId: 0,
|
||||
SourceHTTPFirewallPolicyId: 0,
|
||||
SourceHTTPFirewallRuleGroupId: 0,
|
||||
SourceHTTPFirewallRuleSetId: 0,
|
||||
SourceServer: nil,
|
||||
SourceHTTPFirewallPolicy: nil,
|
||||
SourceHTTPFirewallRuleGroup: nil,
|
||||
SourceHTTPFirewallRuleSet: nil,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestIPListDB_ReadItems(t *testing.T) {
|
||||
db, err := NewIPListDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
items, err := db.ReadItems(0, 2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
logs.PrintAsJSON(items, t)
|
||||
}
|
||||
55
internal/iplibrary/list_utils.go
Normal file
55
internal/iplibrary/list_utils.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
)
|
||||
|
||||
// AllowIP 检查IP是否被允许访问
|
||||
func AllowIP(ip string, serverId int64) bool {
|
||||
var ipLong = utils.IP2Long(ip)
|
||||
if ipLong == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// check white lists
|
||||
if GlobalWhiteIPList.Contains(ipLong) {
|
||||
return true
|
||||
}
|
||||
|
||||
if serverId > 0 {
|
||||
var list = SharedServerListManager.FindWhiteList(serverId, false)
|
||||
if list != nil && list.Contains(ipLong) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// check black lists
|
||||
if GlobalBlackIPList.Contains(ipLong) {
|
||||
return false
|
||||
}
|
||||
|
||||
if serverId > 0 {
|
||||
var list = SharedServerListManager.FindBlackList(serverId, false)
|
||||
if list != nil && list.Contains(ipLong) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AllowIPStrings 检查一组IP是否被允许访问
|
||||
func AllowIPStrings(ipStrings []string, serverId int64) bool {
|
||||
if len(ipStrings) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, ip := range ipStrings {
|
||||
isAllowed := AllowIP(ip, serverId)
|
||||
if !isAllowed {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
20
internal/iplibrary/list_utils_test.go
Normal file
20
internal/iplibrary/list_utils_test.go
Normal file
@@ -0,0 +1,20 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIPIsAllowed(t *testing.T) {
|
||||
manager := NewIPListManager()
|
||||
manager.init()
|
||||
|
||||
var before = time.Now()
|
||||
defer func() {
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
}()
|
||||
t.Log(AllowIP("127.0.0.1", 0))
|
||||
t.Log(AllowIP("127.0.0.1", 23))
|
||||
}
|
||||
@@ -8,9 +8,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -24,13 +21,9 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
var versionCacheFile = "ip_list_version.cache"
|
||||
|
||||
// IPListManager IP名单管理
|
||||
type IPListManager struct {
|
||||
// 缓存文件
|
||||
// 每行一个数据:id|from|to|expiredAt
|
||||
cacheFile string
|
||||
db *IPListDB
|
||||
|
||||
version int64
|
||||
pageSize int64
|
||||
@@ -41,17 +34,13 @@ type IPListManager struct {
|
||||
|
||||
func NewIPListManager() *IPListManager {
|
||||
return &IPListManager{
|
||||
cacheFile: Tea.Root + "/configs/ip_list.cache",
|
||||
pageSize: 500,
|
||||
listMap: map[int64]*IPList{},
|
||||
pageSize: 500,
|
||||
listMap: map[int64]*IPList{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *IPListManager) Start() {
|
||||
// TODO 从缓存当中读取数据
|
||||
|
||||
// 从缓存中读取位置
|
||||
this.version = this.readLocalVersion()
|
||||
this.init()
|
||||
|
||||
// 第一次读取
|
||||
err := this.loop()
|
||||
@@ -60,6 +49,9 @@ func (this *IPListManager) Start() {
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(60 * time.Second)
|
||||
if Tea.IsTesting() {
|
||||
ticker = time.NewTicker(10 * time.Second)
|
||||
}
|
||||
events.On(events.EventQuit, func() {
|
||||
ticker.Stop()
|
||||
})
|
||||
@@ -88,6 +80,31 @@ func (this *IPListManager) Start() {
|
||||
}
|
||||
}
|
||||
|
||||
func (this *IPListManager) init() {
|
||||
// 从数据库中当中读取数据
|
||||
db, err := NewIPListDB()
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIST_MANAGER", "create ip list local database failed: "+err.Error())
|
||||
} else {
|
||||
this.db = db
|
||||
|
||||
var offset int64 = 0
|
||||
var size int64 = 1000
|
||||
for {
|
||||
items, err := db.ReadItems(offset, size)
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIST_MANAGER", "read ip list from local database failed: "+err.Error())
|
||||
} else {
|
||||
if len(items) == 0 {
|
||||
break
|
||||
}
|
||||
this.processItems(items, false)
|
||||
}
|
||||
offset += int64(len(items))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *IPListManager) loop() error {
|
||||
for {
|
||||
hasNext, err := this.fetch()
|
||||
@@ -119,11 +136,53 @@ func (this *IPListManager) fetch() (hasNext bool, err error) {
|
||||
if len(items) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 保存到本地数据库
|
||||
if this.db != nil {
|
||||
for _, item := range items {
|
||||
err = this.db.AddItem(item)
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIST_MANAGER", "insert item to local database failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.processItems(items, true)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (this *IPListManager) FindList(listId int64) *IPList {
|
||||
this.locker.Lock()
|
||||
list, _ := this.listMap[listId]
|
||||
this.locker.Unlock()
|
||||
return list
|
||||
}
|
||||
|
||||
func (this *IPListManager) processItems(items []*pb.IPItem, shouldExecute bool) {
|
||||
this.locker.Lock()
|
||||
var changedLists = map[*IPList]bool{}
|
||||
for _, item := range items {
|
||||
list, ok := this.listMap[item.ListId]
|
||||
if !ok {
|
||||
var list *IPList
|
||||
// TODO 实现节点专有List
|
||||
if item.ServerId > 0 { // 服务专有List
|
||||
switch item.ListType {
|
||||
case "black":
|
||||
list = SharedServerListManager.FindBlackList(item.ServerId, true)
|
||||
case "white":
|
||||
list = SharedServerListManager.FindWhiteList(item.ServerId, true)
|
||||
}
|
||||
} else if item.IsGlobal { // 全局List
|
||||
switch item.ListType {
|
||||
case "black":
|
||||
list = GlobalBlackIPList
|
||||
case "white":
|
||||
list = GlobalWhiteIPList
|
||||
}
|
||||
} else { // 其他List
|
||||
list = this.listMap[item.ListId]
|
||||
}
|
||||
if list == nil {
|
||||
list = NewIPList()
|
||||
this.listMap[item.ListId] = list
|
||||
}
|
||||
@@ -133,18 +192,13 @@ func (this *IPListManager) fetch() (hasNext bool, err error) {
|
||||
if item.IsDeleted {
|
||||
list.Delete(item.Id)
|
||||
|
||||
// 从临时名单中删除
|
||||
if len(item.IpFrom) > 0 && len(item.IpTo) == 0 {
|
||||
switch item.ListType {
|
||||
case "black":
|
||||
waf.SharedIPBlackList.RemoveIP(item.IpFrom)
|
||||
case "white":
|
||||
waf.SharedIPWhiteList.RemoveIP(item.IpFrom)
|
||||
}
|
||||
}
|
||||
// 从WAF名单中删除
|
||||
waf.SharedIPBlackList.RemoveIP(item.IpFrom, item.ServerId)
|
||||
|
||||
// 操作事件
|
||||
SharedActionManager.DeleteItem(item.ListType, item)
|
||||
if shouldExecute {
|
||||
SharedActionManager.DeleteItem(item.ListType, item)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
@@ -159,8 +213,10 @@ func (this *IPListManager) fetch() (hasNext bool, err error) {
|
||||
})
|
||||
|
||||
// 事件操作
|
||||
SharedActionManager.DeleteItem(item.ListType, item)
|
||||
SharedActionManager.AddItem(item.ListType, item)
|
||||
if shouldExecute {
|
||||
SharedActionManager.DeleteItem(item.ListType, item)
|
||||
SharedActionManager.AddItem(item.ListType, item)
|
||||
}
|
||||
}
|
||||
|
||||
for changedList := range changedLists {
|
||||
@@ -169,38 +225,4 @@ func (this *IPListManager) fetch() (hasNext bool, err error) {
|
||||
|
||||
this.locker.Unlock()
|
||||
this.version = items[len(items)-1].Version
|
||||
|
||||
// 写入版本号到缓存当中
|
||||
this.updateLocalVersion(this.version)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (this *IPListManager) FindList(listId int64) *IPList {
|
||||
this.locker.Lock()
|
||||
list, _ := this.listMap[listId]
|
||||
this.locker.Unlock()
|
||||
return list
|
||||
}
|
||||
|
||||
func (this *IPListManager) readLocalVersion() int64 {
|
||||
data, err := ioutil.ReadFile(Tea.ConfigFile(versionCacheFile))
|
||||
if err != nil || len(data) == 0 {
|
||||
return 0
|
||||
}
|
||||
return types.Int64(string(data))
|
||||
}
|
||||
|
||||
func (this *IPListManager) updateLocalVersion(version int64) {
|
||||
fp, err := os.OpenFile(Tea.ConfigFile(versionCacheFile), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
|
||||
if err != nil {
|
||||
remotelogs.Warn("IP_LIST", "write local version cache failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
_, err = fp.WriteString(types.String(version))
|
||||
if err != nil {
|
||||
remotelogs.Warn("IP_LIST", "write local version cache failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
_ = fp.Close()
|
||||
}
|
||||
|
||||
@@ -1,10 +1,36 @@
|
||||
package iplibrary
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIPListManager_init(t *testing.T) {
|
||||
manager := NewIPListManager()
|
||||
manager.init()
|
||||
t.Log(manager.listMap)
|
||||
t.Log(SharedServerListManager.blackMap)
|
||||
logs.PrintAsJSON(GlobalBlackIPList.sortedItems, t)
|
||||
}
|
||||
|
||||
func TestIPListManager_check(t *testing.T) {
|
||||
manager := NewIPListManager()
|
||||
manager.init()
|
||||
|
||||
var before = time.Now()
|
||||
defer func() {
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
}()
|
||||
t.Log(SharedServerListManager.FindBlackList(23, true).Contains(utils.IP2Long("127.0.0.2")))
|
||||
t.Log(GlobalBlackIPList.Contains(utils.IP2Long("127.0.0.6")))
|
||||
}
|
||||
|
||||
func TestIPListManager_loop(t *testing.T) {
|
||||
manager := NewIPListManager()
|
||||
manager.pageSize = 2
|
||||
manager.Start()
|
||||
manager.pageSize = 10
|
||||
err := manager.loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
61
internal/iplibrary/server_list_manager.go
Normal file
61
internal/iplibrary/server_list_manager.go
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package iplibrary
|
||||
|
||||
import "sync"
|
||||
|
||||
var SharedServerListManager = NewServerListManager()
|
||||
|
||||
// ServerListManager 服务相关名单
|
||||
type ServerListManager struct {
|
||||
whiteMap map[int64]*IPList // serverId => *List
|
||||
blackMap map[int64]*IPList // serverId => *List
|
||||
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
func NewServerListManager() *ServerListManager {
|
||||
return &ServerListManager{
|
||||
whiteMap: map[int64]*IPList{},
|
||||
blackMap: map[int64]*IPList{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *ServerListManager) FindWhiteList(serverId int64, autoCreate bool) *IPList {
|
||||
this.locker.RLock()
|
||||
list, ok := this.whiteMap[serverId]
|
||||
this.locker.RUnlock()
|
||||
if ok {
|
||||
return list
|
||||
}
|
||||
|
||||
if autoCreate {
|
||||
list = NewIPList()
|
||||
this.locker.Lock()
|
||||
this.whiteMap[serverId] = list
|
||||
this.locker.Unlock()
|
||||
|
||||
return list
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *ServerListManager) FindBlackList(serverId int64, autoCreate bool) *IPList {
|
||||
this.locker.RLock()
|
||||
list, ok := this.blackMap[serverId]
|
||||
this.locker.RUnlock()
|
||||
if ok {
|
||||
return list
|
||||
}
|
||||
|
||||
if autoCreate {
|
||||
list = NewIPList()
|
||||
this.locker.Lock()
|
||||
this.blackMap[serverId] = list
|
||||
this.locker.Unlock()
|
||||
|
||||
return list
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -4,6 +4,7 @@ package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"net"
|
||||
)
|
||||
@@ -29,9 +30,8 @@ func (this *ClientListener) Accept() (net.Conn, error) {
|
||||
// 是否在WAF名单中
|
||||
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if err == nil {
|
||||
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) &&
|
||||
waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
|
||||
|
||||
if !iplibrary.AllowIP(ip, 0) || (!waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) &&
|
||||
waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)) {
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
if ok {
|
||||
_ = tcpConn.SetLinger(0)
|
||||
|
||||
@@ -26,8 +26,17 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// 是否在全局名单中
|
||||
var remoteAddr = this.requestRemoteAddr(true)
|
||||
if !iplibrary.AllowIP(remoteAddr, this.Server.Id) {
|
||||
this.disableLog = true
|
||||
if conn != nil {
|
||||
_ = conn.(net.Conn).Close()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查是否在临时黑名单中
|
||||
var remoteAddr = this.WAFRemoteIP()
|
||||
if waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeService, this.Server.Id, remoteAddr) || waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteAddr) {
|
||||
this.disableLog = true
|
||||
if conn != nil {
|
||||
|
||||
@@ -63,7 +63,9 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
|
||||
if timeout <= 0 {
|
||||
timeout = 60 // 默认封锁60秒
|
||||
}
|
||||
SharedIPBlackList.Add(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(timeout))
|
||||
|
||||
|
||||
SharedIPBlackList.RecordIP(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(timeout), waf.Id, group.Id, set.Id)
|
||||
|
||||
if writer != nil {
|
||||
// close the connection
|
||||
|
||||
@@ -6,15 +6,12 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var captchaSalt = stringutil.Rand(32)
|
||||
|
||||
const (
|
||||
CaptchaSeconds = 600 // 10 minutes
|
||||
CaptchaPath = "/WAF/VERIFY/CAPTCHA"
|
||||
@@ -66,6 +63,8 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
|
||||
"action": this,
|
||||
"timestamp": time.Now().Unix(),
|
||||
"url": refURL,
|
||||
"policyId": waf.Id,
|
||||
"groupId": group.Id,
|
||||
"setId": set.Id,
|
||||
}
|
||||
info, err := utils.SimpleEncryptMap(captchaConfig)
|
||||
|
||||
@@ -57,6 +57,8 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
|
||||
"timestamp": time.Now().Unix(),
|
||||
"life": this.Life,
|
||||
"scope": this.Scope,
|
||||
"policyId": waf.Id,
|
||||
"groupId": group.Id,
|
||||
"setId": set.Id,
|
||||
}
|
||||
info, err := utils.SimpleEncryptMap(m)
|
||||
|
||||
@@ -56,7 +56,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
|
||||
life = 600 // 默认10分钟
|
||||
}
|
||||
var setId = m.GetString("setId")
|
||||
SharedIPWhiteList.Add("set:"+setId, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life)
|
||||
SharedIPWhiteList.RecordIP("set:"+setId, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, m.GetInt64("policyId"), m.GetInt64("groupId"), m.GetInt64("setId"))
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -65,6 +65,8 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
|
||||
"timestamp": time.Now().Unix(),
|
||||
"life": this.Life,
|
||||
"scope": this.Scope,
|
||||
"policyId": waf.Id,
|
||||
"groupId": group.Id,
|
||||
"setId": set.Id,
|
||||
"remoteIP": request.WAFRemoteIP(),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
@@ -18,6 +19,7 @@ type recordIPTask struct {
|
||||
listId int64
|
||||
expiredAt int64
|
||||
level string
|
||||
serverId int64
|
||||
|
||||
sourceServerId int64
|
||||
sourceHTTPFirewallPolicyId int64
|
||||
@@ -49,6 +51,7 @@ func init() {
|
||||
Reason: "触发WAF规则自动加入",
|
||||
Type: ipType,
|
||||
EventLevel: task.level,
|
||||
ServerId: task.serverId,
|
||||
SourceNodeId: teaconst.NodeId,
|
||||
SourceServerId: task.sourceServerId,
|
||||
SourceHTTPFirewallPolicyId: task.sourceHTTPFirewallPolicyId,
|
||||
@@ -115,12 +118,18 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
|
||||
|
||||
// 上报
|
||||
if this.IPListId > 0 {
|
||||
var serverId int64
|
||||
if this.Scope == firewallconfigs.FirewallScopeService {
|
||||
serverId = request.WAFServerId()
|
||||
}
|
||||
|
||||
select {
|
||||
case recordIPTaskChan <- &recordIPTask{
|
||||
ip: request.WAFRemoteIP(),
|
||||
listId: this.IPListId,
|
||||
expiredAt: expiredAt,
|
||||
level: this.Level,
|
||||
serverId: serverId,
|
||||
sourceServerId: request.WAFServerId(),
|
||||
sourceHTTPFirewallPolicyId: waf.Id,
|
||||
sourceHTTPFirewallRuleGroupId: group.Id,
|
||||
|
||||
@@ -54,7 +54,7 @@ func (this *CaptchaValidator) Run(request requests.Request, writer http.Response
|
||||
var originURL = m.GetString("url")
|
||||
|
||||
if request.WAFRaw().Method == http.MethodPost && len(request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")) > 0 {
|
||||
this.validate(actionConfig, setId, originURL, request, writer)
|
||||
this.validate(actionConfig, m.GetInt64("policyId"), m.GetInt64("groupId"), setId, originURL, request, writer)
|
||||
} else {
|
||||
this.show(actionConfig, request, writer)
|
||||
}
|
||||
@@ -132,7 +132,7 @@ func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests
|
||||
</html>`))
|
||||
}
|
||||
|
||||
func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, setId int64, originURL string, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, originURL string, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
captchaId := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")
|
||||
if len(captchaId) > 0 {
|
||||
captchaCode := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_CODE")
|
||||
@@ -143,7 +143,7 @@ func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, setId int64,
|
||||
}
|
||||
|
||||
// 加入到白名单
|
||||
SharedIPWhiteList.Add("set:"+strconv.FormatInt(setId, 10), actionConfig.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(life))
|
||||
SharedIPWhiteList.RecordIP("set:"+strconv.FormatInt(setId, 10), actionConfig.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(life), policyId, groupId, setId)
|
||||
|
||||
http.Redirect(writer, request.WAFRaw(), originURL, http.StatusSeeOther)
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ func (this *Get302Validator) Run(request requests.Request, writer http.ResponseW
|
||||
life = 600 // 默认10分钟
|
||||
}
|
||||
setId := m.GetString("setId")
|
||||
SharedIPWhiteList.Add("set:"+setId, m.GetString("scope"), request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life)
|
||||
SharedIPWhiteList.RecordIP("set:"+setId, m.GetString("scope"), request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, m.GetInt64("policyId"), m.GetInt64("groupId"), m.GetInt64("setId"))
|
||||
|
||||
// 返回原始URL
|
||||
var url = m.GetString("url")
|
||||
|
||||
@@ -63,6 +63,26 @@ func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serv
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
// RecordIP 记录IP
|
||||
func (this *IPList) RecordIP(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string, expiresAt int64, policyId int64, groupId int64, setId int64) {
|
||||
this.Add(ipType, scope, serverId, ip, expiresAt)
|
||||
|
||||
select {
|
||||
case recordIPTaskChan <- &recordIPTask{
|
||||
ip: ip,
|
||||
listId: firewallconfigs.GlobalListId,
|
||||
expiredAt: expiresAt,
|
||||
level: firewallconfigs.DefaultEventLevel,
|
||||
sourceServerId: serverId,
|
||||
sourceHTTPFirewallPolicyId: policyId,
|
||||
sourceHTTPFirewallRuleGroupId: groupId,
|
||||
sourceHTTPFirewallRuleSetId: setId,
|
||||
}:
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Contains 判断是否有某个IP
|
||||
func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) bool {
|
||||
switch scope {
|
||||
@@ -81,10 +101,12 @@ func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope,
|
||||
}
|
||||
|
||||
// RemoveIP 删除IP
|
||||
// 暂时没办法清除某个服务相关的IP
|
||||
func (this *IPList) RemoveIP(ip string) {
|
||||
func (this *IPList) RemoveIP(ip string, serverId int64) {
|
||||
this.locker.Lock()
|
||||
delete(this.ipMap, "*@"+ip+"@"+IPTypeAll)
|
||||
if serverId > 0 {
|
||||
delete(this.ipMap, types.String(serverId)+"@"+ip+"@"+IPTypeAll)
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user