自动将API节点的IP加入到白名单,防止误封

This commit is contained in:
GoEdgeLab
2021-12-06 10:09:37 +08:00
parent 03a4463fff
commit 4f129fa49f
3 changed files with 66 additions and 2 deletions

View File

@@ -8,9 +8,11 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/rands" "github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
"net"
"strconv" "strconv"
"strings" "strings"
) )
@@ -321,3 +323,51 @@ func (this *APINodeDAO) CountAllEnabledAPINodesWithSSLPolicyIds(tx *dbs.Tx, sslP
Param("policyIds", strings.Join(policyStringIds, ",")). Param("policyIds", strings.Join(policyStringIds, ",")).
Count() Count()
} }
// FindAllEnabledAPIAccessIPs 获取所有的API可访问IP地址
func (this *APINodeDAO) FindAllEnabledAPIAccessIPs(tx *dbs.Tx, cacheMap *utils.CacheMap) ([]string, error) {
var cacheKey = this.Table + ":FindAllEnabledAPIAccessIPs"
if cacheMap != nil {
cache, ok := cacheMap.Get(cacheKey)
if ok {
return cache.([]string), nil
}
}
ones, _, err := this.Query(tx).
State(APINodeStateEnabled).
Result("JSON_EXTRACT(accessAddrs, '$[*].host') AS host").
FindOnes()
if err != nil {
return nil, err
}
var result = []string{}
for _, one := range ones {
var host = one.GetString("host")
if len(host) == 0 {
continue
}
var ips = []string{}
err = json.Unmarshal([]byte(host), &ips)
if err != nil {
continue
}
for _, ip := range ips {
if !lists.ContainsString(result, ip) {
if net.ParseIP(ip) == nil {
continue
}
result = append(result, ip)
}
}
}
if cacheMap != nil {
cacheMap.Put(cacheKey, result)
}
return result, nil
}

View File

@@ -1,6 +1,7 @@
package models package models
import ( import (
"github.com/TeaOSLab/EdgeAPI/internal/utils"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"runtime" "runtime"
@@ -27,6 +28,12 @@ func TestAPINodeDAO_FindEnabledAPINodeIdWithAddr(t *testing.T) {
} }
} }
func TestAPINodeDAO_FindAllEnabledAPIAccessIPs(t *testing.T) {
var cacheMap = utils.NewCacheMap()
t.Log(NewAPINodeDAO().FindAllEnabledAPIAccessIPs(nil, cacheMap))
t.Log(NewAPINodeDAO().FindAllEnabledAPIAccessIPs(nil, cacheMap))
}
func BenchmarkAPINodeDAO_New(b *testing.B) { func BenchmarkAPINodeDAO_New(b *testing.B) {
runtime.GOMAXPROCS(1) runtime.GOMAXPROCS(1)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {

View File

@@ -727,6 +727,13 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64, cacheMap *utils
RegionId: int64(node.RegionId), RegionId: int64(node.RegionId),
} }
// API节点IP
apiNodeIPs, err := SharedAPINodeDAO.FindAllEnabledAPIAccessIPs(tx, cacheMap)
if err != nil {
return nil, err
}
config.AllowedIPs = append(config.AllowedIPs, apiNodeIPs...)
// 获取所有的服务 // 获取所有的服务
servers, err := SharedServerDAO.FindAllEnabledServersWithNode(tx, int64(node.Id)) servers, err := SharedServerDAO.FindAllEnabledServersWithNode(tx, int64(node.Id))
if err != nil { if err != nil {