diff --git a/cmd/edge-api/main.go b/cmd/edge-api/main.go index 476b1070..26a143b3 100644 --- a/cmd/edge-api/main.go +++ b/cmd/edge-api/main.go @@ -1,9 +1,9 @@ package main import ( - "github.com/TeaOSLab/EdgeAPI/internal/apis" "github.com/TeaOSLab/EdgeAPI/internal/apps" teaconst "github.com/TeaOSLab/EdgeAPI/internal/const" + "github.com/TeaOSLab/EdgeAPI/internal/nodes" _ "github.com/TeaOSLab/EdgeAPI/internal/tasks" _ "github.com/iwind/TeaGo/bootstrap" ) @@ -14,6 +14,6 @@ func main() { app.Product(teaconst.ProductName) app.Usage(teaconst.ProcessName + " [start|stop|restart]") app.Run(func() { - apis.NewAPINode().Start() + nodes.NewAPINode().Start() }) } diff --git a/internal/apis/api_node.go b/internal/apis/api_node.go deleted file mode 100644 index acae94dd..00000000 --- a/internal/apis/api_node.go +++ /dev/null @@ -1,83 +0,0 @@ -package apis - -import ( - "errors" - "github.com/TeaOSLab/EdgeAPI/internal/configs" - "github.com/TeaOSLab/EdgeAPI/internal/rpc/services" - "github.com/TeaOSLab/EdgeAPI/internal/utils" - "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" - "github.com/iwind/TeaGo/logs" - "google.golang.org/grpc" - "net" - "os" - "strconv" -) - -var sharedAPIConfig *configs.APIConfig = nil - -type APINode struct { -} - -func NewAPINode() *APINode { - return &APINode{} -} - -func (this *APINode) Start() { - logs.Println("[API]start api node, pid: " + strconv.Itoa(os.Getpid())) - - config, err := configs.SharedAPIConfig() - if err != nil { - logs.Println("[API]start failed: " + err.Error()) - return - } - sharedAPIConfig = config - - // 设置rlimit - _ = utils.SetRLimit(1024 * 1024) - - // 监听RPC服务 - logs.Println("[API]start rpc: " + config.RPC.Listen) - err = this.listenRPC() - if err != nil { - logs.Println(err.Error()) - return - } -} - -// 启动RPC监听 -func (this *APINode) listenRPC() error { - listener, err := net.Listen("tcp", sharedAPIConfig.RPC.Listen) - if err != nil { - return errors.New("[API]listen rpc failed: " + err.Error()) - } - rpcServer := grpc.NewServer() - pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{}) - pb.RegisterNodeGrantServiceServer(rpcServer, &services.NodeGrantService{}) - pb.RegisterServerServiceServer(rpcServer, &services.ServerService{}) - pb.RegisterNodeServiceServer(rpcServer, &services.NodeService{}) - pb.RegisterNodeClusterServiceServer(rpcServer, &services.NodeClusterService{}) - pb.RegisterNodeIPAddressServiceServer(rpcServer, &services.NodeIPAddressService{}) - pb.RegisterAPINodeServiceServer(rpcServer, &services.APINodeService{}) - pb.RegisterOriginServiceServer(rpcServer, &services.OriginService{}) - pb.RegisterHTTPWebServiceServer(rpcServer, &services.HTTPWebService{}) - pb.RegisterReverseProxyServiceServer(rpcServer, &services.ReverseProxyService{}) - pb.RegisterHTTPGzipServiceServer(rpcServer, &services.HTTPGzipService{}) - pb.RegisterHTTPHeaderPolicyServiceServer(rpcServer, &services.HTTPHeaderPolicyService{}) - pb.RegisterHTTPHeaderServiceServer(rpcServer, &services.HTTPHeaderService{}) - pb.RegisterHTTPPageServiceServer(rpcServer, &services.HTTPPageService{}) - pb.RegisterHTTPAccessLogPolicyServiceServer(rpcServer, &services.HTTPAccessLogPolicyService{}) - pb.RegisterHTTPCachePolicyServiceServer(rpcServer, &services.HTTPCachePolicyService{}) - pb.RegisterHTTPFirewallPolicyServiceServer(rpcServer, &services.HTTPFirewallPolicyService{}) - pb.RegisterHTTPLocationServiceServer(rpcServer, &services.HTTPLocationService{}) - pb.RegisterHTTPWebsocketServiceServer(rpcServer, &services.HTTPWebsocketService{}) - pb.RegisterHTTPRewriteRuleServiceServer(rpcServer, &services.HTTPRewriteRuleService{}) - pb.RegisterSSLCertServiceServer(rpcServer, &services.SSLCertService{}) - pb.RegisterSSLPolicyServiceServer(rpcServer, &services.SSLPolicyService{}) - pb.RegisterSysSettingServiceServer(rpcServer, &services.SysSettingService{}) - err = rpcServer.Serve(listener) - if err != nil { - return errors.New("[API]start rpc failed: " + err.Error()) - } - - return nil -} diff --git a/internal/configs/api_config.go b/internal/configs/api_config.go index d5fee40b..2387f75e 100644 --- a/internal/configs/api_config.go +++ b/internal/configs/api_config.go @@ -8,12 +8,15 @@ import ( var sharedAPIConfig *APIConfig = nil +// API节点配置 type APIConfig struct { - RPC struct { - Listen string `yaml:"listen"` - } `yaml:"rpc"` + NodeId string `yaml:"nodeId" json:"nodeId"` + Secret string `yaml:"secret" json:"secret"` + + numberId int64 // 数字ID } +// 获取共享配置 func SharedAPIConfig() (*APIConfig, error) { sharedLocker.Lock() defer sharedLocker.Unlock() @@ -36,3 +39,13 @@ func SharedAPIConfig() (*APIConfig, error) { sharedAPIConfig = config return config, nil } + +// 设置数字ID +func (this *APIConfig) SetNumberId(numberId int64) { + this.numberId = numberId +} + +// 获取数字ID +func (this *APIConfig) NumberId() int64 { + return this.numberId +} diff --git a/internal/db/models/api_node_dao.go b/internal/db/models/api_node_dao.go index b7fcad2d..49769d3e 100644 --- a/internal/db/models/api_node_dao.go +++ b/internal/db/models/api_node_dao.go @@ -59,6 +59,19 @@ func (this *APINodeDAO) FindEnabledAPINode(id int64) (*APINode, error) { return result.(*APINode), err } +// 根据ID和Secret查找节点 +func (this *APINodeDAO) FindEnabledAPINodeWithUniqueIdAndSecret(uniqueId string, secret string) (*APINode, error) { + one, err := this.Query(). + State(APINodeStateEnabled). + Attr("uniqueId", uniqueId). + Attr("secret", secret). + Find() + if err != nil || one == nil { + return nil, err + } + return one.(*APINode), nil +} + // 根据主键查找名称 func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) { return this.Query(). @@ -68,7 +81,7 @@ func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) { } // 创建API节点 -func (this *APINodeDAO) CreateAPINode(name string, description string, host string, port int) (nodeId int64, err error) { +func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) { uniqueId, err := this.genUniqueId() if err != nil { return 0, err @@ -80,13 +93,22 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, host stri } op := NewAPINodeOperator() - op.IsOn = true + op.IsOn = isOn op.UniqueId = uniqueId op.Secret = secret op.Name = name op.Description = description - op.Host = host - op.Port = port + + if len(httpJSON) > 0 { + op.Http = httpJSON + } + if len(httpsJSON) > 0 { + op.Https = httpsJSON + } + if len(accessAddrsJSON) > 0 { + op.AccessAddrs = accessAddrsJSON + } + op.State = NodeStateEnabled _, err = this.Save(op) if err != nil { @@ -97,7 +119,7 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, host stri } // 修改API节点 -func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, host string, port int) error { +func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) error { if nodeId <= 0 { return errors.New("invalid nodeId") } @@ -106,8 +128,24 @@ func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description str op.Id = nodeId op.Name = name op.Description = description - op.Host = host - op.Port = port + op.IsOn = isOn + + if len(httpJSON) > 0 { + op.Http = httpJSON + } else { + op.Http = "null" + } + if len(httpsJSON) > 0 { + op.Https = httpsJSON + } else { + op.Https = "null" + } + if len(accessAddrsJSON) > 0 { + op.AccessAddrs = accessAddrsJSON + } else { + op.AccessAddrs = "null" + } + _, err := this.Save(op) return err } diff --git a/internal/db/models/api_node_model.go b/internal/db/models/api_node_model.go index 03baccc0..b96b1309 100644 --- a/internal/db/models/api_node_model.go +++ b/internal/db/models/api_node_model.go @@ -9,8 +9,9 @@ type APINode struct { Secret string `field:"secret"` // 密钥 Name string `field:"name"` // 名称 Description string `field:"description"` // 描述 - Host string `field:"host"` // 主机 - Port uint32 `field:"port"` // 端口 + Http string `field:"http"` // 监听的HTTP配置 + Https string `field:"https"` // 监听的HTTPS配置 + AccessAddrs string `field:"accessAddrs"` // 外部访问地址 Order uint32 `field:"order"` // 排序 State uint8 `field:"state"` // 状态 CreatedAt uint64 `field:"createdAt"` // 创建时间 @@ -26,8 +27,9 @@ type APINodeOperator struct { Secret interface{} // 密钥 Name interface{} // 名称 Description interface{} // 描述 - Host interface{} // 主机 - Port interface{} // 端口 + Http interface{} // 监听的HTTP配置 + Https interface{} // 监听的HTTPS配置 + AccessAddrs interface{} // 外部访问地址 Order interface{} // 排序 State interface{} // 状态 CreatedAt interface{} // 创建时间 diff --git a/internal/db/models/api_node_model_ext.go b/internal/db/models/api_node_model_ext.go index cb50b153..a8b60566 100644 --- a/internal/db/models/api_node_model_ext.go +++ b/internal/db/models/api_node_model_ext.go @@ -1,8 +1,95 @@ package models -import "strconv" +import ( + "encoding/json" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" +) -// 地址 -func (this *APINode) Address() string { - return this.Host + ":" + strconv.Itoa(int(this.Port)) +// 解析HTTP配置 +func (this *APINode) DecodeHTTP() (*serverconfigs.HTTPProtocolConfig, error) { + if !IsNotNull(this.Http) { + return nil, nil + } + config := &serverconfigs.HTTPProtocolConfig{} + err := json.Unmarshal([]byte(this.Http), config) + if err != nil { + return nil, err + } + + err = config.Init() + if err != nil { + return nil, err + } + + return config, nil +} + +// 解析HTTPS配置 +func (this *APINode) DecodeHTTPS() (*serverconfigs.HTTPSProtocolConfig, error) { + if !IsNotNull(this.Https) { + return nil, nil + } + config := &serverconfigs.HTTPSProtocolConfig{} + err := json.Unmarshal([]byte(this.Https), config) + if err != nil { + return nil, err + } + + err = config.Init() + if err != nil { + return nil, err + } + + if config.SSLPolicyRef != nil { + policyId := config.SSLPolicyRef.SSLPolicyId + if policyId > 0 { + sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(policyId) + if err != nil { + return nil, err + } + if sslPolicy != nil { + config.SSLPolicy = sslPolicy + } + } + } + + err = config.Init() + if err != nil { + return nil, err + } + + return config, nil +} + +// 解析访问地址 +func (this *APINode) DecodeAccessAddrs() ([]*serverconfigs.NetworkAddressConfig, error) { + if !IsNotNull(this.AccessAddrs) { + return nil, nil + } + + addrConfigs := []*serverconfigs.NetworkAddressConfig{} + err := json.Unmarshal([]byte(this.AccessAddrs), &addrConfigs) + if err != nil { + return nil, err + } + for _, addrConfig := range addrConfigs { + err = addrConfig.Init() + if err != nil { + return nil, err + } + } + return addrConfigs, nil +} + +// 解析访问地址,并返回字符串形式 +func (this *APINode) DecodeAccessAddrStrings() ([]string, error) { + addrs, err := this.DecodeAccessAddrs() + if err != nil { + return nil, err + } + result := []string{} + for _, addr := range addrs { + result = append(result, addr.FullAddresses()...) + } + return result, nil } diff --git a/internal/db/models/node_dao.go b/internal/db/models/node_dao.go index fb37e0de..9abc6146 100644 --- a/internal/db/models/node_dao.go +++ b/internal/db/models/node_dao.go @@ -2,7 +2,7 @@ package models import ( "encoding/json" - "errors" + "github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeCommon/pkg/configutils" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" @@ -275,6 +275,17 @@ func (this *NodeDAO) FindAllNodeIdsMatch(clusterId int64) (result []int64, err e return } +// 获取一个集群的所有节点 +func (this *NodeDAO) FindAllEnabledNodesWithClusterId(clusterId int64) (result []*Node, err error) { + _, err = this.Query(). + State(NodeStateEnabled). + Attr("clusterId", clusterId). + DescPk(). + Slice(&result). + FindAll() + return +} + // 计算节点数量 func (this *NodeDAO) CountAllEnabledNodesMatch(clusterId int64, installState configutils.BoolState, activeState configutils.BoolState) (int64, error) { query := this.Query() @@ -422,6 +433,28 @@ func (this *NodeDAO) ComposeNodeConfig(nodeId int64) (*nodeconfigs.NodeConfig, e return config, nil } +// 修改当前连接的API节点 +func (this *NodeDAO) UpdateNodeConnectedAPINodes(nodeId int64, apiNodeIds []int64) error { + if nodeId <= 0 { + return errors.New("invalid nodeId") + } + + op := NewNodeOperator() + op.Id = nodeId + + if len(apiNodeIds) > 0 { + apiNodeIdsJSON, err := json.Marshal(apiNodeIds) + if err != nil { + return errors.Wrap(err) + } + op.ConnectedAPINodes = apiNodeIdsJSON + } else { + op.ConnectedAPINodes = "[]" + } + _, err := this.Save(op) + return err +} + // 生成唯一ID func (this *NodeDAO) genUniqueId() (string, error) { for { diff --git a/internal/db/models/node_model.go b/internal/db/models/node_model.go index eea6cec6..3dde6665 100644 --- a/internal/db/models/node_model.go +++ b/internal/db/models/node_model.go @@ -2,47 +2,49 @@ package models // 节点 type Node struct { - Id uint32 `field:"id"` // ID - AdminId uint32 `field:"adminId"` // 管理员ID - UserId uint32 `field:"userId"` // 用户ID - IsOn uint8 `field:"isOn"` // 是否启用 - UniqueId string `field:"uniqueId"` // 节点ID - Secret string `field:"secret"` // 密钥 - Name string `field:"name"` // 节点名 - Code string `field:"code"` // 代号 - ClusterId uint32 `field:"clusterId"` // 集群ID - RegionId uint32 `field:"regionId"` // 区域ID - GroupId uint32 `field:"groupId"` // 分组ID - CreatedAt uint64 `field:"createdAt"` // 创建时间 - Status string `field:"status"` // 最新的状态 - Version uint32 `field:"version"` // 当前版本号 - LatestVersion uint32 `field:"latestVersion"` // 最后版本号 - InstallDir string `field:"installDir"` // 安装目录 - IsInstalled uint8 `field:"isInstalled"` // 是否已安装 - InstallStatus string `field:"installStatus"` // 安装状态 - State uint8 `field:"state"` // 状态 + Id uint32 `field:"id"` // ID + AdminId uint32 `field:"adminId"` // 管理员ID + UserId uint32 `field:"userId"` // 用户ID + IsOn uint8 `field:"isOn"` // 是否启用 + UniqueId string `field:"uniqueId"` // 节点ID + Secret string `field:"secret"` // 密钥 + Name string `field:"name"` // 节点名 + Code string `field:"code"` // 代号 + ClusterId uint32 `field:"clusterId"` // 集群ID + RegionId uint32 `field:"regionId"` // 区域ID + GroupId uint32 `field:"groupId"` // 分组ID + CreatedAt uint64 `field:"createdAt"` // 创建时间 + Status string `field:"status"` // 最新的状态 + Version uint32 `field:"version"` // 当前版本号 + LatestVersion uint32 `field:"latestVersion"` // 最后版本号 + InstallDir string `field:"installDir"` // 安装目录 + IsInstalled uint8 `field:"isInstalled"` // 是否已安装 + InstallStatus string `field:"installStatus"` // 安装状态 + State uint8 `field:"state"` // 状态 + ConnectedAPINodes string `field:"connectedAPINodes"` // 当前连接的API节点 } type NodeOperator struct { - Id interface{} // ID - AdminId interface{} // 管理员ID - UserId interface{} // 用户ID - IsOn interface{} // 是否启用 - UniqueId interface{} // 节点ID - Secret interface{} // 密钥 - Name interface{} // 节点名 - Code interface{} // 代号 - ClusterId interface{} // 集群ID - RegionId interface{} // 区域ID - GroupId interface{} // 分组ID - CreatedAt interface{} // 创建时间 - Status interface{} // 最新的状态 - Version interface{} // 当前版本号 - LatestVersion interface{} // 最后版本号 - InstallDir interface{} // 安装目录 - IsInstalled interface{} // 是否已安装 - InstallStatus interface{} // 安装状态 - State interface{} // 状态 + Id interface{} // ID + AdminId interface{} // 管理员ID + UserId interface{} // 用户ID + IsOn interface{} // 是否启用 + UniqueId interface{} // 节点ID + Secret interface{} // 密钥 + Name interface{} // 节点名 + Code interface{} // 代号 + ClusterId interface{} // 集群ID + RegionId interface{} // 区域ID + GroupId interface{} // 分组ID + CreatedAt interface{} // 创建时间 + Status interface{} // 最新的状态 + Version interface{} // 当前版本号 + LatestVersion interface{} // 最后版本号 + InstallDir interface{} // 安装目录 + IsInstalled interface{} // 是否已安装 + InstallStatus interface{} // 安装状态 + State interface{} // 状态 + ConnectedAPINodes interface{} // 当前连接的API节点 } func NewNodeOperator() *NodeOperator { diff --git a/internal/installers/queue.go b/internal/installers/queue.go index a61b0ee7..d3e3af10 100644 --- a/internal/installers/queue.go +++ b/internal/installers/queue.go @@ -136,7 +136,13 @@ func (this *Queue) InstallNode(nodeId int64) error { apiEndpoints := []string{} for _, apiNode := range apiNodes { - apiEndpoints = append(apiEndpoints, apiNode.Host+":"+strconv.Itoa(int(apiNode.Port))) + addrConfigs, err := apiNode.DecodeAccessAddrs() + if err != nil { + return errors.New("decode api node access addresses failed: " + err.Error()) + } + for _, addrConfig := range addrConfigs { + apiEndpoints = append(apiEndpoints, addrConfig.FullAddresses()...) + } } params := &NodeParams{ diff --git a/internal/nodes/api_node.go b/internal/nodes/api_node.go new file mode 100644 index 00000000..1858a732 --- /dev/null +++ b/internal/nodes/api_node.go @@ -0,0 +1,170 @@ +package nodes + +import ( + "crypto/tls" + "errors" + "github.com/TeaOSLab/EdgeAPI/internal/configs" + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + "github.com/TeaOSLab/EdgeAPI/internal/rpc/services" + "github.com/TeaOSLab/EdgeAPI/internal/utils" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/iwind/TeaGo/logs" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "net" + "os" + "strconv" +) + +var sharedAPIConfig *configs.APIConfig = nil + +type APINode struct { +} + +func NewAPINode() *APINode { + return &APINode{} +} + +func (this *APINode) Start() { + logs.Println("[API]start api node, pid: " + strconv.Itoa(os.Getpid())) + + // 读取配置 + config, err := configs.SharedAPIConfig() + if err != nil { + logs.Println("[API]start failed: " + err.Error()) + return + } + sharedAPIConfig = config + + // 校验 + apiNode, err := models.SharedAPINodeDAO.FindEnabledAPINodeWithUniqueIdAndSecret(config.NodeId, config.Secret) + if err != nil { + logs.Println("[API]start failed: read api node from database failed: " + err.Error()) + return + } + if apiNode == nil { + logs.Println("[API]can not start node, wrong 'nodeId' or 'secret'") + return + } + config.SetNumberId(int64(apiNode.Id)) + + // 设置rlimit + _ = utils.SetRLimit(1024 * 1024) + + // 监听RPC服务 + logs.Println("[API]starting rpc ...") + + // HTTP + httpConfig, err := apiNode.DecodeHTTP() + if err != nil { + logs.Println("[API]decode http config: " + err.Error()) + return + } + isListening := false + if httpConfig != nil && httpConfig.IsOn && len(httpConfig.Listen) > 0 { + for _, listen := range httpConfig.Listen { + for _, addr := range listen.Addresses() { + listener, err := net.Listen("tcp", addr) + if err != nil { + logs.Println("[API]listening '" + addr + "' failed: " + err.Error()) + continue + } + go func() { + err := this.listenRPC(listener, nil) + if err != nil { + logs.Println("[API]listening '" + addr + "' rpc: " + err.Error()) + return + } + }() + isListening = true + } + } + } + + // HTTPS + httpsConfig, err := apiNode.DecodeHTTPS() + if err != nil { + logs.Println("[API]decode https config: " + err.Error()) + return + } + if httpsConfig != nil && + httpsConfig.IsOn && + len(httpsConfig.Listen) > 0 && + httpsConfig.SSLPolicy != nil && + httpsConfig.SSLPolicy.IsOn && + len(httpsConfig.SSLPolicy.Certs) > 0 { + certs := []tls.Certificate{} + for _, cert := range httpsConfig.SSLPolicy.Certs { + certs = append(certs, *cert.CertObject()) + } + + for _, listen := range httpsConfig.Listen { + for _, addr := range listen.Addresses() { + listener, err := net.Listen("tcp", addr) + if err != nil { + logs.Println("[API]listening '" + addr + "' failed: " + err.Error()) + continue + } + go func() { + err := this.listenRPC(listener, &tls.Config{ + Certificates: certs, + }) + if err != nil { + logs.Println("[API]listening '" + addr + "' rpc: " + err.Error()) + return + } + }() + isListening = true + } + } + } + + if !isListening { + logs.Println("[API]the api node does have a listening address") + return + } + + // 保持进程 + select {} +} + +// 启动RPC监听 +func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) error { + var rpcServer *grpc.Server + if tlsConfig == nil { + logs.Println("[API]listening http://" + listener.Addr().String() + " ...") + rpcServer = grpc.NewServer() + } else { + logs.Println("[API]listening https://" + listener.Addr().String() + " ...") + rpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) + } + pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{}) + pb.RegisterNodeGrantServiceServer(rpcServer, &services.NodeGrantService{}) + pb.RegisterServerServiceServer(rpcServer, &services.ServerService{}) + pb.RegisterNodeServiceServer(rpcServer, &services.NodeService{}) + pb.RegisterNodeClusterServiceServer(rpcServer, &services.NodeClusterService{}) + pb.RegisterNodeIPAddressServiceServer(rpcServer, &services.NodeIPAddressService{}) + pb.RegisterAPINodeServiceServer(rpcServer, &services.APINodeService{}) + pb.RegisterOriginServiceServer(rpcServer, &services.OriginService{}) + pb.RegisterHTTPWebServiceServer(rpcServer, &services.HTTPWebService{}) + pb.RegisterReverseProxyServiceServer(rpcServer, &services.ReverseProxyService{}) + pb.RegisterHTTPGzipServiceServer(rpcServer, &services.HTTPGzipService{}) + pb.RegisterHTTPHeaderPolicyServiceServer(rpcServer, &services.HTTPHeaderPolicyService{}) + pb.RegisterHTTPHeaderServiceServer(rpcServer, &services.HTTPHeaderService{}) + pb.RegisterHTTPPageServiceServer(rpcServer, &services.HTTPPageService{}) + pb.RegisterHTTPAccessLogPolicyServiceServer(rpcServer, &services.HTTPAccessLogPolicyService{}) + pb.RegisterHTTPCachePolicyServiceServer(rpcServer, &services.HTTPCachePolicyService{}) + pb.RegisterHTTPFirewallPolicyServiceServer(rpcServer, &services.HTTPFirewallPolicyService{}) + pb.RegisterHTTPLocationServiceServer(rpcServer, &services.HTTPLocationService{}) + pb.RegisterHTTPWebsocketServiceServer(rpcServer, &services.HTTPWebsocketService{}) + pb.RegisterHTTPRewriteRuleServiceServer(rpcServer, &services.HTTPRewriteRuleService{}) + pb.RegisterSSLCertServiceServer(rpcServer, &services.SSLCertService{}) + pb.RegisterSSLPolicyServiceServer(rpcServer, &services.SSLPolicyService{}) + pb.RegisterSysSettingServiceServer(rpcServer, &services.SysSettingService{}) + err := rpcServer.Serve(listener) + if err != nil { + return errors.New("[API]start rpc failed: " + err.Error()) + } + + return nil +} diff --git a/internal/rpc/services/service_api_node.go b/internal/rpc/services/service_api_node.go index 911940cf..e412a535 100644 --- a/internal/rpc/services/service_api_node.go +++ b/internal/rpc/services/service_api_node.go @@ -17,7 +17,7 @@ func (this *APINodeService) CreateAPINode(ctx context.Context, req *pb.CreateAPI return nil, err } - nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.Host, int(req.Port)) + nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn) if err != nil { return nil, err } @@ -32,7 +32,7 @@ func (this *APINodeService) UpdateAPINode(ctx context.Context, req *pb.UpdateAPI return nil, err } - err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.Host, int(req.Port)) + err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn) if err != nil { return nil, err } @@ -69,17 +69,23 @@ func (this *APINodeService) FindAllEnabledAPINodes(ctx context.Context, req *pb. result := []*pb.APINode{} for _, node := range nodes { + accessAddrs, err := node.DecodeAccessAddrStrings() + if err != nil { + return nil, err + } + result = append(result, &pb.APINode{ - Id: int64(node.Id), - IsOn: node.IsOn == 1, - ClusterId: int64(node.ClusterId), - UniqueId: node.UniqueId, - Secret: node.Secret, - Name: node.Name, - Description: node.Description, - Host: node.Host, - Port: int32(node.Port), - Address: node.Address(), + Id: int64(node.Id), + IsOn: node.IsOn == 1, + ClusterId: int64(node.ClusterId), + UniqueId: node.UniqueId, + Secret: node.Secret, + Name: node.Name, + Description: node.Description, + HttpJSON: []byte(node.Http), + HttpsJSON: []byte(node.Https), + AccessAddrsJSON: []byte(node.AccessAddrs), + AccessAddrs: accessAddrs, }) } @@ -115,17 +121,23 @@ func (this *APINodeService) ListEnabledAPINodes(ctx context.Context, req *pb.Lis result := []*pb.APINode{} for _, node := range nodes { + accessAddrs, err := node.DecodeAccessAddrStrings() + if err != nil { + return nil, err + } + result = append(result, &pb.APINode{ - Id: int64(node.Id), - IsOn: node.IsOn == 1, - ClusterId: int64(node.ClusterId), - UniqueId: node.UniqueId, - Secret: node.Secret, - Name: node.Name, - Description: node.Description, - Host: node.Host, - Port: int32(node.Port), - Address: node.Address(), + Id: int64(node.Id), + IsOn: node.IsOn == 1, + ClusterId: int64(node.ClusterId), + UniqueId: node.UniqueId, + Secret: node.Secret, + Name: node.Name, + Description: node.Description, + HttpJSON: []byte(node.Http), + HttpsJSON: []byte(node.Https), + AccessAddrsJSON: []byte(node.AccessAddrs), + AccessAddrs: accessAddrs, }) } @@ -148,17 +160,23 @@ func (this *APINodeService) FindEnabledAPINode(ctx context.Context, req *pb.Find return &pb.FindEnabledAPINodeResponse{Node: nil}, nil } + accessAddrs, err := node.DecodeAccessAddrStrings() + if err != nil { + return nil, err + } + result := &pb.APINode{ - Id: int64(node.Id), - IsOn: node.IsOn == 1, - ClusterId: int64(node.ClusterId), - UniqueId: node.UniqueId, - Secret: node.Secret, - Name: node.Name, - Description: node.Description, - Host: node.Host, - Port: int32(node.Port), - Address: node.Address(), + Id: int64(node.Id), + IsOn: node.IsOn == 1, + ClusterId: int64(node.ClusterId), + UniqueId: node.UniqueId, + Secret: node.Secret, + Name: node.Name, + Description: node.Description, + HttpJSON: []byte(node.Http), + HttpsJSON: []byte(node.Https), + AccessAddrsJSON: []byte(node.AccessAddrs), + AccessAddrs: accessAddrs, } return &pb.FindEnabledAPINodeResponse{Node: result}, nil } diff --git a/internal/rpc/services/service_node.go b/internal/rpc/services/service_node.go index cd18312b..9a0c9959 100644 --- a/internal/rpc/services/service_node.go +++ b/internal/rpc/services/service_node.go @@ -3,8 +3,8 @@ package services import ( "context" "encoding/json" - "errors" "github.com/TeaOSLab/EdgeAPI/internal/db/models" + "github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeAPI/internal/installers" rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" "github.com/TeaOSLab/EdgeCommon/pkg/configutils" @@ -12,6 +12,7 @@ import ( "github.com/iwind/TeaGo/logs" ) +// 边缘节点相关服务 type NodeService struct { } @@ -121,6 +122,38 @@ func (this *NodeService) ListEnabledNodesMatch(ctx context.Context, req *pb.List }, nil } +// 查找一个集群下的所有节点 +func (this *NodeService) FindAllEnabledNodesWithClusterId(ctx context.Context, req *pb.FindAllEnabledNodesWithClusterIdRequest) (*pb.FindAllEnabledNodesWithClusterIdResponse, error) { + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + nodes, err := models.SharedNodeDAO.FindAllEnabledNodesWithClusterId(req.ClusterId) + if err != nil { + return nil, err + } + result := []*pb.Node{} + for _, node := range nodes { + apiNodeIds := []int64{} + if models.IsNotNull(node.ConnectedAPINodes) { + err = json.Unmarshal([]byte(node.ConnectedAPINodes), &apiNodeIds) + if err != nil { + return nil, err + } + } + + result = append(result, &pb.Node{ + Id: int64(node.Id), + Name: node.Name, + UniqueId: node.UniqueId, + Secret: node.Secret, + ConnectedAPINodeIds: apiNodeIds, + }) + } + return &pb.FindAllEnabledNodesWithClusterIdResponse{Nodes: result}, nil +} + // 禁用节点 func (this *NodeService) DisableNode(ctx context.Context, req *pb.DisableNodeRequest) (*pb.DisableNodeResponse, error) { _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) @@ -263,27 +296,6 @@ func (this *NodeService) ComposeNodeConfig(ctx context.Context, req *pb.ComposeN return &pb.ComposeNodeConfigResponse{NodeJSON: data}, nil } -// 节点stream -func (this *NodeService) NodeStream(server pb.NodeService_NodeStreamServer) error { - // TODO 使用此stream快速通知边缘节点更新 - // 校验节点 - _, nodeId, err := rpcutils.ValidateRequest(server.Context(), rpcutils.UserTypeNode) - if err != nil { - return err - } - logs.Println("nodeId:", nodeId) - - _ = server.Send(&pb.NodeStreamResponse{}) - - for { - req, err := server.Recv() - if err != nil { - return err - } - logs.Println("received:", req) - } -} - // 更新节点状态 func (this *NodeService) UpdateNodeStatus(ctx context.Context, req *pb.UpdateNodeStatusRequest) (*pb.RPCUpdateSuccess, error) { // 校验节点 @@ -354,3 +366,19 @@ func (this *NodeService) InstallNode(ctx context.Context, req *pb.InstallNodeReq return &pb.InstallNodeResponse{}, nil } + +// 更改节点连接的API节点信息 +func (this *NodeService) UpdateNodeConnectedAPINodes(ctx context.Context, req *pb.UpdateNodeConnectedAPINodesRequest) (*pb.RPCUpdateSuccess, error) { + // 校验节点 + _, nodeId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeNode) + if err != nil { + return nil, err + } + + err = models.SharedNodeDAO.UpdateNodeConnectedAPINodes(nodeId, req.ApiNodeIds) + if err != nil { + return nil, errors.Wrap(err) + } + + return rpcutils.RPCUpdateSuccess() +} diff --git a/internal/rpc/services/service_node_stream.go b/internal/rpc/services/service_node_stream.go new file mode 100644 index 00000000..8c3a9fdc --- /dev/null +++ b/internal/rpc/services/service_node_stream.go @@ -0,0 +1,251 @@ +package services + +import ( + "context" + "encoding/json" + "fmt" + "github.com/TeaOSLab/EdgeAPI/internal/configs" + "github.com/TeaOSLab/EdgeAPI/internal/errors" + rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/iwind/TeaGo/logs" + "strconv" + "sync" + "sync/atomic" + "time" +) + +// 命令请求相关 +type CommandRequest struct { + Id int64 + Code string + CommandJSON []byte +} + +type CommandRequestWaiting struct { + Timestamp int64 + Chan chan *pb.NodeStreamMessage +} + +func (this *CommandRequestWaiting) Close() { + defer func() { + recover() + }() + + close(this.Chan) +} + +var responseChanMap = map[int64]*CommandRequestWaiting{} // request id => response +var commandRequestId = int64(0) + +var nodeLocker = &sync.Mutex{} +var requestChanMap = map[int64]chan *CommandRequest{} // node id => chan + +func NextCommandRequestId() int64 { + return atomic.AddInt64(&commandRequestId, 1) +} + +func init() { + // 清理WaitingChannelMap + ticker := time.NewTicker(30 * time.Second) + go func() { + for range ticker.C { + nodeLocker.Lock() + for requestId, request := range responseChanMap { + if time.Now().Unix()-request.Timestamp > 3600 { + responseChanMap[requestId].Close() + delete(responseChanMap, requestId) + } + } + nodeLocker.Unlock() + } + }() +} + +// 节点stream +func (this *NodeService) NodeStream(server pb.NodeService_NodeStreamServer) error { + // TODO 使用此stream快速通知边缘节点更新 + // 校验节点 + _, nodeId, err := rpcutils.ValidateRequest(server.Context(), rpcutils.UserTypeNode) + if err != nil { + return err + } + + // 返回连接成功 + { + apiConfig, err := configs.SharedAPIConfig() + if err != nil { + return err + } + connectedMessage := &messageconfigs.ConnectedAPINodeMessage{APINodeId: apiConfig.NumberId()} + connectedMessageJSON, err := json.Marshal(connectedMessage) + if err != nil { + return errors.Wrap(err) + } + err = server.Send(&pb.NodeStreamMessage{ + Code: messageconfigs.MessageCodeConnectedAPINode, + DataJSON: connectedMessageJSON, + }) + if err != nil { + return err + } + } + + logs.Println("[RPC]accepted node '" + strconv.FormatInt(nodeId, 10) + "' connection") + + nodeLocker.Lock() + requestChan, ok := requestChanMap[nodeId] + if !ok { + requestChan = make(chan *CommandRequest, 1024) + requestChanMap[nodeId] = requestChan + } + nodeLocker.Unlock() + + // 发送请求 + go func() { + for { + select { + case <-server.Context().Done(): + return + case commandRequest := <-requestChan: + logs.Println("[RPC]sending command '" + commandRequest.Code + "' to node '" + strconv.FormatInt(nodeId, 10) + "'") + retries := 3 // 错误重试次数 + for i := 0; i < retries; i++ { + err := server.Send(&pb.NodeStreamMessage{ + RequestId: commandRequest.Id, + Code: commandRequest.Code, + DataJSON: commandRequest.CommandJSON, + }) + if err != nil { + if i == retries-1 { + logs.Println("[RPC]send command '" + commandRequest.Code + "' failed: " + err.Error()) + } else { + time.Sleep(1 * time.Second) + } + } else { + break + } + } + } + } + }() + + // 接受请求 + for { + req, err := server.Recv() + if err != nil { + return err + } + + func(req *pb.NodeStreamMessage) { + // 因为 responseChan.Chan 有被关闭的风险,所以我们使用recover防止panic + defer func() { + recover() + }() + + nodeLocker.Lock() + responseChan, ok := responseChanMap[req.RequestId] + if ok { + select { + case responseChan.Chan <- req: + default: + + } + } + nodeLocker.Unlock() + }(req) + } +} + +// 向节点发送命令 +func (this *NodeService) SendCommandToNode(ctx context.Context, req *pb.NodeStreamMessage) (*pb.NodeStreamMessage, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + nodeId := req.NodeId + if nodeId <= 0 { + return nil, errors.New("node id should not be less than 0") + } + + nodeLocker.Lock() + requestChan, ok := requestChanMap[nodeId] + nodeLocker.Unlock() + + if !ok { + return &pb.NodeStreamMessage{ + RequestId: req.RequestId, + IsOk: false, + Message: "node '" + strconv.FormatInt(nodeId, 10) + "' not connected yet", + }, nil + } + + req.RequestId = NextCommandRequestId() + + select { + case requestChan <- &CommandRequest{ + Id: req.RequestId, + Code: req.Code, + CommandJSON: req.DataJSON, + }: + // 加入到等待队列中 + respChan := make(chan *pb.NodeStreamMessage, 1) + waiting := &CommandRequestWaiting{ + Timestamp: time.Now().Unix(), + Chan: respChan, + } + + nodeLocker.Lock() + responseChanMap[req.RequestId] = waiting + nodeLocker.Unlock() + + // 等待响应 + timeoutSeconds := req.TimeoutSeconds + if timeoutSeconds <= 0 { + timeoutSeconds = 10 + } + timeout := time.NewTimer(time.Duration(timeoutSeconds) * time.Second) + select { + case resp := <-respChan: + // 从队列中删除 + nodeLocker.Lock() + delete(responseChanMap, req.RequestId) + waiting.Close() + nodeLocker.Unlock() + + if resp == nil { + return &pb.NodeStreamMessage{ + RequestId: req.RequestId, + Code: req.Code, + Message: "response timeout", + IsOk: false, + }, nil + } + + return resp, nil + case <-timeout.C: + // 从队列中删除 + nodeLocker.Lock() + delete(responseChanMap, req.RequestId) + waiting.Close() + nodeLocker.Unlock() + + return &pb.NodeStreamMessage{ + RequestId: req.RequestId, + Code: req.Code, + Message: "response timeout over " + fmt.Sprintf("%d", timeoutSeconds) + " seconds", + IsOk: false, + }, nil + } + default: + return &pb.NodeStreamMessage{ + RequestId: req.RequestId, + Code: req.Code, + Message: "command queue is full over " + strconv.Itoa(len(requestChan)), + IsOk: false, + }, nil + } +}