自动清理本地IP名单过期条目/修复白名单可能不起作用的Bug

This commit is contained in:
刘祥超
2022-03-06 19:40:26 +08:00
parent 577a5618a1
commit 0d2d7591e5
6 changed files with 101 additions and 24 deletions

View File

@@ -5,20 +5,27 @@ package iplibrary
import (
"database/sql"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/Tea"
_ "github.com/mattn/go-sqlite3"
"os"
"path/filepath"
"time"
)
type IPListDB struct {
db *sql.DB
itemTableName string
deleteItemStmt *sql.Stmt
insertItemStmt *sql.Stmt
selectItemsStmt *sql.Stmt
itemTableName string
deleteExpiredItemsStmt *sql.Stmt
deleteItemStmt *sql.Stmt
insertItemStmt *sql.Stmt
selectItemsStmt *sql.Stmt
selectMaxVersionStmt *sql.Stmt
cleanTicker *time.Ticker
dir string
}
@@ -27,6 +34,7 @@ func NewIPListDB() (*IPListDB, error) {
var db = &IPListDB{
itemTableName: "ipItems",
dir: filepath.Clean(Tea.Root + "/data"),
cleanTicker: time.NewTicker(24 * time.Hour),
}
err := db.init()
return db, err
@@ -83,6 +91,11 @@ ON "` + this.itemTableName + `" (
}
// 初始化SQL语句
this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "expiredAt">0 AND "expiredAt"<?`)
if err != nil {
return err
}
this.deleteItemStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "itemId"=?`)
if err != nil {
return err
@@ -93,16 +106,37 @@ ON "` + this.itemTableName + `" (
return err
}
this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`)
this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" WHERE isDeleted=0 ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`)
if err != nil {
return err
}
this.selectMaxVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`)
this.db = db
goman.New(func() {
events.On(events.EventQuit, func() {
this.cleanTicker.Stop()
})
for range this.cleanTicker.C {
err := this.DeleteExpiredItems()
if err != nil {
remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+err.Error())
}
}
})
return nil
}
// DeleteExpiredItems 删除过期的条目
func (this *IPListDB) DeleteExpiredItems() error {
_, err := this.deleteExpiredItemsStmt.Exec(time.Now().Unix() - 7*86400)
return err
}
func (this *IPListDB) AddItem(item *pb.IPItem) error {
_, err := this.deleteItemStmt.Exec(item.Id)
if err != nil {
@@ -133,11 +167,27 @@ func (this *IPListDB) ReadItems(offset int64, size int64) (items []*pb.IPItem, e
return
}
// ReadMaxVersion 读取当前最大版本号
func (this *IPListDB) ReadMaxVersion() int64 {
row := this.selectMaxVersionStmt.QueryRow()
if row == nil {
return 0
}
var version int64
err = row.Scan(&version)
if err != nil {
return 0
}
return version
}
func (this *IPListDB) Close() error {
if this.db != nil {
_ = this.deleteExpiredItemsStmt.Close()
_ = this.deleteItemStmt.Close()
_ = this.insertItemStmt.Close()
_ = this.selectItemsStmt.Close()
_ = this.selectMaxVersionStmt.Close()
return this.db.Close()
}

View File

@@ -1,9 +1,10 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package iplibrary
package iplibrary_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/logs"
"testing"
@@ -11,7 +12,7 @@ import (
)
func TestIPListDB_AddItem(t *testing.T) {
db, err := NewIPListDB()
db, err := iplibrary.NewIPListDB()
if err != nil {
t.Fatal(err)
}
@@ -48,7 +49,7 @@ func TestIPListDB_AddItem(t *testing.T) {
}
func TestIPListDB_ReadItems(t *testing.T) {
db, err := NewIPListDB()
db, err := iplibrary.NewIPListDB()
if err != nil {
t.Fatal(err)
}
@@ -58,3 +59,11 @@ func TestIPListDB_ReadItems(t *testing.T) {
}
logs.PrintAsJSON(items, t)
}
func TestIPListDB_ReadMaxVersion(t *testing.T) {
db, err := iplibrary.NewIPListDB()
if err != nil {
t.Fatal(err)
}
t.Log(db.ReadMaxVersion())
}

View File

@@ -8,37 +8,37 @@ import (
// AllowIP 检查IP是否被允许访问
// 如果一个IP不在任何名单中则允许访问
func AllowIP(ip string, serverId int64) bool {
func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool) {
var ipLong = utils.IP2Long(ip)
if ipLong == 0 {
return false
return false, false
}
// check white lists
if GlobalWhiteIPList.Contains(ipLong) {
return true
return true, true
}
if serverId > 0 {
var list = SharedServerListManager.FindWhiteList(serverId, false)
if list != nil && list.Contains(ipLong) {
return true
return true, true
}
}
// check black lists
if GlobalBlackIPList.Contains(ipLong) {
return false
return false, false
}
if serverId > 0 {
var list = SharedServerListManager.FindBlackList(serverId, false)
if list != nil && list.Contains(ipLong) {
return false
return false, false
}
}
return true
return true, false
}
// IsInWhiteList 检查IP是否在白名单中
@@ -58,7 +58,7 @@ func AllowIPStrings(ipStrings []string, serverId int64) bool {
return true
}
for _, ip := range ipStrings {
isAllowed := AllowIP(ip, serverId)
isAllowed, _ := AllowIP(ip, serverId)
if !isAllowed {
return false
}

View File

@@ -61,7 +61,7 @@ func (this *IPListManager) Start() {
if Tea.IsTesting() {
this.ticker = time.NewTicker(10 * time.Second)
}
countErrors := 0
var countErrors = 0
for {
select {
case <-this.ticker.C:
@@ -100,6 +100,13 @@ func (this *IPListManager) init() {
} else {
this.db = db
// 删除本地数据库中过期的条目
_ = db.DeleteExpiredItems()
// 本地数据库中最大版本号
this.version = db.ReadMaxVersion()
// 从本地数据库中加载
var offset int64 = 0
var size int64 = 1000
for {
@@ -171,7 +178,7 @@ func (this *IPListManager) FindList(listId int64) *IPList {
return list
}
func (this *IPListManager) processItems(items []*pb.IPItem, shouldExecute bool) {
func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) {
this.locker.Lock()
var changedLists = map[*IPList]zero.Zero{}
for _, item := range items {
@@ -205,10 +212,10 @@ func (this *IPListManager) processItems(items []*pb.IPItem, shouldExecute bool)
list.Delete(item.Id)
// 从WAF名单中删除
waf.SharedIPBlackList.RemoveIP(item.IpFrom, item.ServerId, shouldExecute)
waf.SharedIPBlackList.RemoveIP(item.IpFrom, item.ServerId, fromRemote)
// 操作事件
if shouldExecute {
if fromRemote {
SharedActionManager.DeleteItem(item.ListType, item)
}
@@ -225,7 +232,7 @@ func (this *IPListManager) processItems(items []*pb.IPItem, shouldExecute bool)
})
// 事件操作
if shouldExecute {
if fromRemote {
SharedActionManager.DeleteItem(item.ListType, item)
SharedActionManager.AddItem(item.ListType, item)
}
@@ -236,5 +243,11 @@ func (this *IPListManager) processItems(items []*pb.IPItem, shouldExecute bool)
}
this.locker.Unlock()
this.version = items[len(items)-1].Version
if fromRemote {
var latestVersion = items[len(items)-1].Version
if latestVersion > this.version {
this.version = latestVersion
}
}
}

View File

@@ -49,7 +49,8 @@ func (this *ClientListener) Accept() (net.Conn, error) {
// 是否在WAF名单中
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err == nil {
if !iplibrary.AllowIP(ip, 0) || (!waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) &&
canGoNext, _ := iplibrary.AllowIP(ip, 0)
if !canGoNext || (!waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) &&
waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)) {
tcpConn, ok := conn.(*net.TCPConn)
if ok {

View File

@@ -35,11 +35,15 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
}
// 是否在全局名单中
if !iplibrary.AllowIP(remoteAddr, this.ReqServer.Id) {
canGoNext, isInAllowedList := iplibrary.AllowIP(remoteAddr, this.ReqServer.Id)
if !canGoNext {
this.disableLog = true
this.Close()
return true
}
if isInAllowedList {
return false
}
// 检查是否在临时黑名单中
if waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeService, this.ReqServer.Id, remoteAddr) || waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteAddr) {