diff --git a/go.mod b/go.mod index a86d2910..97579918 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,11 @@ go 1.15 replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon -replace github.com/TeaOSLab/EdgePlus => ../EdgePlus + require ( github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000 - github.com/TeaOSLab/EdgePlus v0.0.0-00010101000000-000000000000 github.com/aliyun/alibaba-cloud-sdk-go v1.61.641 github.com/cespare/xxhash/v2 v2.1.1 github.com/go-acme/lego/v4 v4.1.2 diff --git a/internal/db/models/node_dao.go b/internal/db/models/node_dao.go index 6c22c9a8..e95cd91e 100644 --- a/internal/db/models/node_dao.go +++ b/internal/db/models/node_dao.go @@ -646,11 +646,13 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64) (*nodeconfigs.N continue } - serverConfig := &serverconfigs.ServerConfig{} - err = json.Unmarshal([]byte(server.Config), serverConfig) + serverConfig, err := SharedServerDAO.ComposeServerConfig(tx, server) if err != nil { return nil, err } + if serverConfig == nil { + continue + } config.Servers = append(config.Servers, serverConfig) } @@ -669,34 +671,37 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64) (*nodeconfigs.N config.GlobalConfig = globalConfig } - // WAF - clusterId := int64(node.ClusterId) - httpFirewallPolicyId, err := SharedNodeClusterDAO.FindClusterHTTPFirewallPolicyId(tx, clusterId) - if err != nil { - return nil, err - } - if httpFirewallPolicyId > 0 { - firewallPolicy, err := SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, httpFirewallPolicyId) + var primaryClusterId = int64(node.ClusterId) + var clusterIds = []int64{primaryClusterId} + clusterIds = append(clusterIds, node.DecodeSecondaryClusterIds()...) + for _, clusterId := range clusterIds { + httpFirewallPolicyId, err := SharedNodeClusterDAO.FindClusterHTTPFirewallPolicyId(tx, clusterId) if err != nil { return nil, err } - if firewallPolicy != nil { - config.HTTPFirewallPolicy = firewallPolicy + if httpFirewallPolicyId > 0 { + firewallPolicy, err := SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, httpFirewallPolicyId) + if err != nil { + return nil, err + } + if firewallPolicy != nil { + config.HTTPFirewallPolicies = append(config.HTTPFirewallPolicies, firewallPolicy) + } } - } - // 缓存策略 - httpCachePolicyId, err := SharedNodeClusterDAO.FindClusterHTTPCachePolicyId(tx, clusterId) - if err != nil { - return nil, err - } - if httpCachePolicyId > 0 { - cachePolicy, err := SharedHTTPCachePolicyDAO.ComposeCachePolicy(tx, httpCachePolicyId) + // 缓存策略 + httpCachePolicyId, err := SharedNodeClusterDAO.FindClusterHTTPCachePolicyId(tx, clusterId) if err != nil { return nil, err } - if cachePolicy != nil { - config.HTTPCachePolicy = cachePolicy + if httpCachePolicyId > 0 { + cachePolicy, err := SharedHTTPCachePolicyDAO.ComposeCachePolicy(tx, httpCachePolicyId) + if err != nil { + return nil, err + } + if cachePolicy != nil { + config.HTTPCachePolicies = append(config.HTTPCachePolicies, cachePolicy) + } } } @@ -724,14 +729,14 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64) (*nodeconfigs.N } // TOA - toaConfig, err := SharedNodeClusterDAO.FindClusterTOAConfig(tx, clusterId) + toaConfig, err := SharedNodeClusterDAO.FindClusterTOAConfig(tx, primaryClusterId) if err != nil { return nil, err } config.TOA = toaConfig // 系统服务 - services, err := SharedNodeClusterDAO.FindNodeClusterSystemServices(tx, clusterId) + services, err := SharedNodeClusterDAO.FindNodeClusterSystemServices(tx, primaryClusterId) if err != nil { return nil, err } @@ -740,7 +745,7 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64) (*nodeconfigs.N } // 防火墙动作 - actions, err := SharedNodeClusterFirewallActionDAO.FindAllEnabledFirewallActions(tx, clusterId) + actions, err := SharedNodeClusterFirewallActionDAO.FindAllEnabledFirewallActions(tx, primaryClusterId) if err != nil { return nil, err } diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index baf09855..6eeaa829 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -775,8 +775,8 @@ func (this *ServerDAO) FindServerNodeFilters(tx *dbs.Tx, serverId int64) (isOk b return true, int64(server.ClusterId), nil } -// ComposeServerConfig 构造服务的Config -func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, serverId int64) (*serverconfigs.ServerConfig, error) { +// ComposeServerConfigWithServerId 构造服务的Config +func (this *ServerDAO) ComposeServerConfigWithServerId(tx *dbs.Tx, serverId int64) (*serverconfigs.ServerConfig, error) { server, err := this.FindEnabledServer(tx, serverId) if err != nil { return nil, err @@ -784,9 +784,17 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, serverId int64) (*serverc if server == nil { return nil, ErrNotFound } + return this.ComposeServerConfig(tx, server) +} + +// ComposeServerConfig 构造服务的Config +func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server) (*serverconfigs.ServerConfig, error) { + if server == nil { + return nil, ErrNotFound + } config := &serverconfigs.ServerConfig{} - config.Id = serverId + config.Id = int64(server.Id) config.ClusterId = int64(server.ClusterId) config.Type = server.Type config.IsOn = server.IsOn == 1 @@ -796,7 +804,7 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, serverId int64) (*serverc // ServerNames if len(server.ServerNames) > 0 && server.ServerNames != "null" { serverNames := []*serverconfigs.ServerNameConfig{} - err = json.Unmarshal([]byte(server.ServerNames), &serverNames) + err := json.Unmarshal([]byte(server.ServerNames), &serverNames) if err != nil { return nil, err } @@ -824,7 +832,7 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, serverId int64) (*serverc // HTTP if len(server.Http) > 0 && server.Http != "null" { httpConfig := &serverconfigs.HTTPProtocolConfig{} - err = json.Unmarshal([]byte(server.Http), httpConfig) + err := json.Unmarshal([]byte(server.Http), httpConfig) if err != nil { return nil, err } @@ -834,7 +842,7 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, serverId int64) (*serverc // HTTPS if len(server.Https) > 0 && server.Https != "null" { httpsConfig := &serverconfigs.HTTPSProtocolConfig{} - err = json.Unmarshal([]byte(server.Https), httpsConfig) + err := json.Unmarshal([]byte(server.Https), httpsConfig) if err != nil { return nil, err } @@ -856,7 +864,7 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, serverId int64) (*serverc // TCP if len(server.Tcp) > 0 && server.Tcp != "null" { tcpConfig := &serverconfigs.TCPProtocolConfig{} - err = json.Unmarshal([]byte(server.Tcp), tcpConfig) + err := json.Unmarshal([]byte(server.Tcp), tcpConfig) if err != nil { return nil, err } @@ -866,7 +874,7 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, serverId int64) (*serverc // TLS if len(server.Tls) > 0 && server.Tls != "null" { tlsConfig := &serverconfigs.TLSProtocolConfig{} - err = json.Unmarshal([]byte(server.Tls), tlsConfig) + err := json.Unmarshal([]byte(server.Tls), tlsConfig) if err != nil { return nil, err } @@ -888,7 +896,7 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, serverId int64) (*serverc // Unix if len(server.Unix) > 0 && server.Unix != "null" { unixConfig := &serverconfigs.UnixProtocolConfig{} - err = json.Unmarshal([]byte(server.Unix), unixConfig) + err := json.Unmarshal([]byte(server.Unix), unixConfig) if err != nil { return nil, err } @@ -898,7 +906,7 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, serverId int64) (*serverc // UDP if len(server.Udp) > 0 && server.Udp != "null" { udpConfig := &serverconfigs.UDPProtocolConfig{} - err = json.Unmarshal([]byte(server.Udp), udpConfig) + err := json.Unmarshal([]byte(server.Udp), udpConfig) if err != nil { return nil, err } @@ -919,7 +927,7 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, serverId int64) (*serverc // ReverseProxy if IsNotNull(server.ReverseProxy) { reverseProxyRef := &serverconfigs.ReverseProxyRef{} - err = json.Unmarshal([]byte(server.ReverseProxy), reverseProxyRef) + err := json.Unmarshal([]byte(server.ReverseProxy), reverseProxyRef) if err != nil { return nil, err } @@ -934,12 +942,31 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, serverId int64) (*serverc } } + // WAF策略 + clusterId := int64(server.ClusterId) + httpFirewallPolicyId, err := SharedNodeClusterDAO.FindClusterHTTPFirewallPolicyId(tx, clusterId) + if err != nil { + return nil, err + } + if httpFirewallPolicyId > 0 { + config.HTTPFirewallPolicyId = httpFirewallPolicyId + } + + // 缓存策略 + httpCachePolicyId, err := SharedNodeClusterDAO.FindClusterHTTPCachePolicyId(tx, clusterId) + if err != nil { + return nil, err + } + if httpCachePolicyId > 0 { + config.HTTPCachePolicyId = httpCachePolicyId + } + return config, nil } // RenewServerConfig 更新服务的Config配置 func (this *ServerDAO) RenewServerConfig(tx *dbs.Tx, serverId int64, updateMd5 bool) (isChanged bool, err error) { - serverConfig, err := this.ComposeServerConfig(tx, serverId) + serverConfig, err := this.ComposeServerConfigWithServerId(tx, serverId) if err != nil { return false, err } diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index 0b734ebb..10bd03d6 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -733,7 +733,7 @@ func (this *ServerService) FindEnabledServerConfig(ctx context.Context, req *pb. } } - config, err := models.SharedServerDAO.ComposeServerConfig(tx, req.ServerId) + config, err := models.SharedServerDAO.ComposeServerConfigWithServerId(tx, req.ServerId) if err != nil { return nil, err }