diff --git a/internal/db/models/api_node_dao.go b/internal/db/models/api_node_dao.go index 526c3bf7..98e6b0f6 100644 --- a/internal/db/models/api_node_dao.go +++ b/internal/db/models/api_node_dao.go @@ -8,9 +8,11 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/rands" "github.com/iwind/TeaGo/types" + "net" "strconv" "strings" ) @@ -321,3 +323,51 @@ func (this *APINodeDAO) CountAllEnabledAPINodesWithSSLPolicyIds(tx *dbs.Tx, sslP Param("policyIds", strings.Join(policyStringIds, ",")). Count() } + +// FindAllEnabledAPIAccessIPs 获取所有的API可访问IP地址 +func (this *APINodeDAO) FindAllEnabledAPIAccessIPs(tx *dbs.Tx, cacheMap *utils.CacheMap) ([]string, error) { + var cacheKey = this.Table + ":FindAllEnabledAPIAccessIPs" + if cacheMap != nil { + cache, ok := cacheMap.Get(cacheKey) + if ok { + return cache.([]string), nil + } + } + + ones, _, err := this.Query(tx). + State(APINodeStateEnabled). + Result("JSON_EXTRACT(accessAddrs, '$[*].host') AS host"). + FindOnes() + if err != nil { + return nil, err + } + var result = []string{} + for _, one := range ones { + var host = one.GetString("host") + if len(host) == 0 { + continue + } + + var ips = []string{} + err = json.Unmarshal([]byte(host), &ips) + if err != nil { + continue + } + + for _, ip := range ips { + if !lists.ContainsString(result, ip) { + if net.ParseIP(ip) == nil { + continue + } + + result = append(result, ip) + } + } + } + + if cacheMap != nil { + cacheMap.Put(cacheKey, result) + } + + return result, nil +} diff --git a/internal/db/models/api_node_dao_test.go b/internal/db/models/api_node_dao_test.go index 30018a9e..1c2a97a6 100644 --- a/internal/db/models/api_node_dao_test.go +++ b/internal/db/models/api_node_dao_test.go @@ -1,6 +1,7 @@ package models import ( + "github.com/TeaOSLab/EdgeAPI/internal/utils" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/dbs" "runtime" @@ -27,6 +28,12 @@ func TestAPINodeDAO_FindEnabledAPINodeIdWithAddr(t *testing.T) { } } +func TestAPINodeDAO_FindAllEnabledAPIAccessIPs(t *testing.T) { + var cacheMap = utils.NewCacheMap() + t.Log(NewAPINodeDAO().FindAllEnabledAPIAccessIPs(nil, cacheMap)) + t.Log(NewAPINodeDAO().FindAllEnabledAPIAccessIPs(nil, cacheMap)) +} + func BenchmarkAPINodeDAO_New(b *testing.B) { runtime.GOMAXPROCS(1) for i := 0; i < b.N; i++ { diff --git a/internal/db/models/node_dao.go b/internal/db/models/node_dao.go index 6b20a582..7b4da25d 100644 --- a/internal/db/models/node_dao.go +++ b/internal/db/models/node_dao.go @@ -519,9 +519,9 @@ func (this *NodeDAO) FindAllInactiveNodesWithClusterId(tx *dbs.Tx, clusterId int _, err = this.Query(tx). State(NodeStateEnabled). Attr("clusterId", clusterId). - Attr("isOn", true). // 只监控启用的节点 + Attr("isOn", true). // 只监控启用的节点 Attr("isInstalled", true). // 只监控已经安装的节点 - Attr("isActive", true). // 当前已经在线的 + Attr("isActive", true). // 当前已经在线的 Where("(status IS NULL OR (JSON_EXTRACT(status, '$.isActive')=false AND UNIX_TIMESTAMP()-JSON_EXTRACT(status, '$.updatedAt')>10) OR UNIX_TIMESTAMP()-JSON_EXTRACT(status, '$.updatedAt')>120)"). Result("id", "name"). Slice(&result). @@ -727,6 +727,13 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64, cacheMap *utils RegionId: int64(node.RegionId), } + // API节点IP + apiNodeIPs, err := SharedAPINodeDAO.FindAllEnabledAPIAccessIPs(tx, cacheMap) + if err != nil { + return nil, err + } + config.AllowedIPs = append(config.AllowedIPs, apiNodeIPs...) + // 获取所有的服务 servers, err := SharedServerDAO.FindAllEnabledServersWithNode(tx, int64(node.Id)) if err != nil {