自动清理本地IP名单过期条目/修复白名单可能不起作用的Bug

This commit is contained in:
GoEdgeLab
2022-03-06 19:40:26 +08:00
parent ca243fa739
commit e94ea8386d
6 changed files with 101 additions and 24 deletions

View File

@@ -5,20 +5,27 @@ package iplibrary
import ( import (
"database/sql" "database/sql"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
"path/filepath" "path/filepath"
"time"
) )
type IPListDB struct { type IPListDB struct {
db *sql.DB db *sql.DB
itemTableName string itemTableName string
deleteExpiredItemsStmt *sql.Stmt
deleteItemStmt *sql.Stmt deleteItemStmt *sql.Stmt
insertItemStmt *sql.Stmt insertItemStmt *sql.Stmt
selectItemsStmt *sql.Stmt selectItemsStmt *sql.Stmt
selectMaxVersionStmt *sql.Stmt
cleanTicker *time.Ticker
dir string dir string
} }
@@ -27,6 +34,7 @@ func NewIPListDB() (*IPListDB, error) {
var db = &IPListDB{ var db = &IPListDB{
itemTableName: "ipItems", itemTableName: "ipItems",
dir: filepath.Clean(Tea.Root + "/data"), dir: filepath.Clean(Tea.Root + "/data"),
cleanTicker: time.NewTicker(24 * time.Hour),
} }
err := db.init() err := db.init()
return db, err return db, err
@@ -83,6 +91,11 @@ ON "` + this.itemTableName + `" (
} }
// 初始化SQL语句 // 初始化SQL语句
this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "expiredAt">0 AND "expiredAt"<?`)
if err != nil {
return err
}
this.deleteItemStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "itemId"=?`) this.deleteItemStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "itemId"=?`)
if err != nil { if err != nil {
return err return err
@@ -93,16 +106,37 @@ ON "` + this.itemTableName + `" (
return err return err
} }
this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`) this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" WHERE isDeleted=0 ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`)
if err != nil { if err != nil {
return err return err
} }
this.selectMaxVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`)
this.db = db this.db = db
goman.New(func() {
events.On(events.EventQuit, func() {
this.cleanTicker.Stop()
})
for range this.cleanTicker.C {
err := this.DeleteExpiredItems()
if err != nil {
remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+err.Error())
}
}
})
return nil return nil
} }
// DeleteExpiredItems 删除过期的条目
func (this *IPListDB) DeleteExpiredItems() error {
_, err := this.deleteExpiredItemsStmt.Exec(time.Now().Unix() - 7*86400)
return err
}
func (this *IPListDB) AddItem(item *pb.IPItem) error { func (this *IPListDB) AddItem(item *pb.IPItem) error {
_, err := this.deleteItemStmt.Exec(item.Id) _, err := this.deleteItemStmt.Exec(item.Id)
if err != nil { if err != nil {
@@ -133,11 +167,27 @@ func (this *IPListDB) ReadItems(offset int64, size int64) (items []*pb.IPItem, e
return return
} }
// ReadMaxVersion 读取当前最大版本号
func (this *IPListDB) ReadMaxVersion() int64 {
row := this.selectMaxVersionStmt.QueryRow()
if row == nil {
return 0
}
var version int64
err = row.Scan(&version)
if err != nil {
return 0
}
return version
}
func (this *IPListDB) Close() error { func (this *IPListDB) Close() error {
if this.db != nil { if this.db != nil {
_ = this.deleteExpiredItemsStmt.Close()
_ = this.deleteItemStmt.Close() _ = this.deleteItemStmt.Close()
_ = this.insertItemStmt.Close() _ = this.insertItemStmt.Close()
_ = this.selectItemsStmt.Close() _ = this.selectItemsStmt.Close()
_ = this.selectMaxVersionStmt.Close()
return this.db.Close() return this.db.Close()
} }

View File

@@ -1,9 +1,10 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. // Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package iplibrary package iplibrary_test
import ( import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
_ "github.com/iwind/TeaGo/bootstrap" _ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/logs"
"testing" "testing"
@@ -11,7 +12,7 @@ import (
) )
func TestIPListDB_AddItem(t *testing.T) { func TestIPListDB_AddItem(t *testing.T) {
db, err := NewIPListDB() db, err := iplibrary.NewIPListDB()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -48,7 +49,7 @@ func TestIPListDB_AddItem(t *testing.T) {
} }
func TestIPListDB_ReadItems(t *testing.T) { func TestIPListDB_ReadItems(t *testing.T) {
db, err := NewIPListDB() db, err := iplibrary.NewIPListDB()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -58,3 +59,11 @@ func TestIPListDB_ReadItems(t *testing.T) {
} }
logs.PrintAsJSON(items, t) logs.PrintAsJSON(items, t)
} }
func TestIPListDB_ReadMaxVersion(t *testing.T) {
db, err := iplibrary.NewIPListDB()
if err != nil {
t.Fatal(err)
}
t.Log(db.ReadMaxVersion())
}

View File

@@ -8,37 +8,37 @@ import (
// AllowIP 检查IP是否被允许访问 // AllowIP 检查IP是否被允许访问
// 如果一个IP不在任何名单中则允许访问 // 如果一个IP不在任何名单中则允许访问
func AllowIP(ip string, serverId int64) bool { func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool) {
var ipLong = utils.IP2Long(ip) var ipLong = utils.IP2Long(ip)
if ipLong == 0 { if ipLong == 0 {
return false return false, false
} }
// check white lists // check white lists
if GlobalWhiteIPList.Contains(ipLong) { if GlobalWhiteIPList.Contains(ipLong) {
return true return true, true
} }
if serverId > 0 { if serverId > 0 {
var list = SharedServerListManager.FindWhiteList(serverId, false) var list = SharedServerListManager.FindWhiteList(serverId, false)
if list != nil && list.Contains(ipLong) { if list != nil && list.Contains(ipLong) {
return true return true, true
} }
} }
// check black lists // check black lists
if GlobalBlackIPList.Contains(ipLong) { if GlobalBlackIPList.Contains(ipLong) {
return false return false, false
} }
if serverId > 0 { if serverId > 0 {
var list = SharedServerListManager.FindBlackList(serverId, false) var list = SharedServerListManager.FindBlackList(serverId, false)
if list != nil && list.Contains(ipLong) { if list != nil && list.Contains(ipLong) {
return false return false, false
} }
} }
return true return true, false
} }
// IsInWhiteList 检查IP是否在白名单中 // IsInWhiteList 检查IP是否在白名单中
@@ -58,7 +58,7 @@ func AllowIPStrings(ipStrings []string, serverId int64) bool {
return true return true
} }
for _, ip := range ipStrings { for _, ip := range ipStrings {
isAllowed := AllowIP(ip, serverId) isAllowed, _ := AllowIP(ip, serverId)
if !isAllowed { if !isAllowed {
return false return false
} }

View File

@@ -61,7 +61,7 @@ func (this *IPListManager) Start() {
if Tea.IsTesting() { if Tea.IsTesting() {
this.ticker = time.NewTicker(10 * time.Second) this.ticker = time.NewTicker(10 * time.Second)
} }
countErrors := 0 var countErrors = 0
for { for {
select { select {
case <-this.ticker.C: case <-this.ticker.C:
@@ -100,6 +100,13 @@ func (this *IPListManager) init() {
} else { } else {
this.db = db this.db = db
// 删除本地数据库中过期的条目
_ = db.DeleteExpiredItems()
// 本地数据库中最大版本号
this.version = db.ReadMaxVersion()
// 从本地数据库中加载
var offset int64 = 0 var offset int64 = 0
var size int64 = 1000 var size int64 = 1000
for { for {
@@ -171,7 +178,7 @@ func (this *IPListManager) FindList(listId int64) *IPList {
return list return list
} }
func (this *IPListManager) processItems(items []*pb.IPItem, shouldExecute bool) { func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) {
this.locker.Lock() this.locker.Lock()
var changedLists = map[*IPList]zero.Zero{} var changedLists = map[*IPList]zero.Zero{}
for _, item := range items { for _, item := range items {
@@ -205,10 +212,10 @@ func (this *IPListManager) processItems(items []*pb.IPItem, shouldExecute bool)
list.Delete(item.Id) list.Delete(item.Id)
// 从WAF名单中删除 // 从WAF名单中删除
waf.SharedIPBlackList.RemoveIP(item.IpFrom, item.ServerId, shouldExecute) waf.SharedIPBlackList.RemoveIP(item.IpFrom, item.ServerId, fromRemote)
// 操作事件 // 操作事件
if shouldExecute { if fromRemote {
SharedActionManager.DeleteItem(item.ListType, item) SharedActionManager.DeleteItem(item.ListType, item)
} }
@@ -225,7 +232,7 @@ func (this *IPListManager) processItems(items []*pb.IPItem, shouldExecute bool)
}) })
// 事件操作 // 事件操作
if shouldExecute { if fromRemote {
SharedActionManager.DeleteItem(item.ListType, item) SharedActionManager.DeleteItem(item.ListType, item)
SharedActionManager.AddItem(item.ListType, item) SharedActionManager.AddItem(item.ListType, item)
} }
@@ -236,5 +243,11 @@ func (this *IPListManager) processItems(items []*pb.IPItem, shouldExecute bool)
} }
this.locker.Unlock() this.locker.Unlock()
this.version = items[len(items)-1].Version
if fromRemote {
var latestVersion = items[len(items)-1].Version
if latestVersion > this.version {
this.version = latestVersion
}
}
} }

View File

@@ -49,7 +49,8 @@ func (this *ClientListener) Accept() (net.Conn, error) {
// 是否在WAF名单中 // 是否在WAF名单中
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String()) ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err == nil { if err == nil {
if !iplibrary.AllowIP(ip, 0) || (!waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) && canGoNext, _ := iplibrary.AllowIP(ip, 0)
if !canGoNext || (!waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) &&
waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)) { waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)) {
tcpConn, ok := conn.(*net.TCPConn) tcpConn, ok := conn.(*net.TCPConn)
if ok { if ok {

View File

@@ -35,11 +35,15 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
} }
// 是否在全局名单中 // 是否在全局名单中
if !iplibrary.AllowIP(remoteAddr, this.ReqServer.Id) { canGoNext, isInAllowedList := iplibrary.AllowIP(remoteAddr, this.ReqServer.Id)
if !canGoNext {
this.disableLog = true this.disableLog = true
this.Close() this.Close()
return true return true
} }
if isInAllowedList {
return false
}
// 检查是否在临时黑名单中 // 检查是否在临时黑名单中
if waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeService, this.ReqServer.Id, remoteAddr) || waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteAddr) { if waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeService, this.ReqServer.Id, remoteAddr) || waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteAddr) {