mirror of
https://github.com/TeaOSLab/EdgeAPI.git
synced 2025-11-03 15:00:27 +08:00
实现缓存策略的部分功能
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/apis"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/apps"
|
||||
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/nodes"
|
||||
_ "github.com/TeaOSLab/EdgeAPI/internal/tasks"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
)
|
||||
@@ -14,6 +14,6 @@ func main() {
|
||||
app.Product(teaconst.ProductName)
|
||||
app.Usage(teaconst.ProcessName + " [start|stop|restart]")
|
||||
app.Run(func() {
|
||||
apis.NewAPINode().Start()
|
||||
nodes.NewAPINode().Start()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/configs"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"google.golang.org/grpc"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var sharedAPIConfig *configs.APIConfig = nil
|
||||
|
||||
type APINode struct {
|
||||
}
|
||||
|
||||
func NewAPINode() *APINode {
|
||||
return &APINode{}
|
||||
}
|
||||
|
||||
func (this *APINode) Start() {
|
||||
logs.Println("[API]start api node, pid: " + strconv.Itoa(os.Getpid()))
|
||||
|
||||
config, err := configs.SharedAPIConfig()
|
||||
if err != nil {
|
||||
logs.Println("[API]start failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
sharedAPIConfig = config
|
||||
|
||||
// 设置rlimit
|
||||
_ = utils.SetRLimit(1024 * 1024)
|
||||
|
||||
// 监听RPC服务
|
||||
logs.Println("[API]start rpc: " + config.RPC.Listen)
|
||||
err = this.listenRPC()
|
||||
if err != nil {
|
||||
logs.Println(err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 启动RPC监听
|
||||
func (this *APINode) listenRPC() error {
|
||||
listener, err := net.Listen("tcp", sharedAPIConfig.RPC.Listen)
|
||||
if err != nil {
|
||||
return errors.New("[API]listen rpc failed: " + err.Error())
|
||||
}
|
||||
rpcServer := grpc.NewServer()
|
||||
pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{})
|
||||
pb.RegisterNodeGrantServiceServer(rpcServer, &services.NodeGrantService{})
|
||||
pb.RegisterServerServiceServer(rpcServer, &services.ServerService{})
|
||||
pb.RegisterNodeServiceServer(rpcServer, &services.NodeService{})
|
||||
pb.RegisterNodeClusterServiceServer(rpcServer, &services.NodeClusterService{})
|
||||
pb.RegisterNodeIPAddressServiceServer(rpcServer, &services.NodeIPAddressService{})
|
||||
pb.RegisterAPINodeServiceServer(rpcServer, &services.APINodeService{})
|
||||
pb.RegisterOriginServiceServer(rpcServer, &services.OriginService{})
|
||||
pb.RegisterHTTPWebServiceServer(rpcServer, &services.HTTPWebService{})
|
||||
pb.RegisterReverseProxyServiceServer(rpcServer, &services.ReverseProxyService{})
|
||||
pb.RegisterHTTPGzipServiceServer(rpcServer, &services.HTTPGzipService{})
|
||||
pb.RegisterHTTPHeaderPolicyServiceServer(rpcServer, &services.HTTPHeaderPolicyService{})
|
||||
pb.RegisterHTTPHeaderServiceServer(rpcServer, &services.HTTPHeaderService{})
|
||||
pb.RegisterHTTPPageServiceServer(rpcServer, &services.HTTPPageService{})
|
||||
pb.RegisterHTTPAccessLogPolicyServiceServer(rpcServer, &services.HTTPAccessLogPolicyService{})
|
||||
pb.RegisterHTTPCachePolicyServiceServer(rpcServer, &services.HTTPCachePolicyService{})
|
||||
pb.RegisterHTTPFirewallPolicyServiceServer(rpcServer, &services.HTTPFirewallPolicyService{})
|
||||
pb.RegisterHTTPLocationServiceServer(rpcServer, &services.HTTPLocationService{})
|
||||
pb.RegisterHTTPWebsocketServiceServer(rpcServer, &services.HTTPWebsocketService{})
|
||||
pb.RegisterHTTPRewriteRuleServiceServer(rpcServer, &services.HTTPRewriteRuleService{})
|
||||
pb.RegisterSSLCertServiceServer(rpcServer, &services.SSLCertService{})
|
||||
pb.RegisterSSLPolicyServiceServer(rpcServer, &services.SSLPolicyService{})
|
||||
pb.RegisterSysSettingServiceServer(rpcServer, &services.SysSettingService{})
|
||||
err = rpcServer.Serve(listener)
|
||||
if err != nil {
|
||||
return errors.New("[API]start rpc failed: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -8,12 +8,15 @@ import (
|
||||
|
||||
var sharedAPIConfig *APIConfig = nil
|
||||
|
||||
// API节点配置
|
||||
type APIConfig struct {
|
||||
RPC struct {
|
||||
Listen string `yaml:"listen"`
|
||||
} `yaml:"rpc"`
|
||||
NodeId string `yaml:"nodeId" json:"nodeId"`
|
||||
Secret string `yaml:"secret" json:"secret"`
|
||||
|
||||
numberId int64 // 数字ID
|
||||
}
|
||||
|
||||
// 获取共享配置
|
||||
func SharedAPIConfig() (*APIConfig, error) {
|
||||
sharedLocker.Lock()
|
||||
defer sharedLocker.Unlock()
|
||||
@@ -36,3 +39,13 @@ func SharedAPIConfig() (*APIConfig, error) {
|
||||
sharedAPIConfig = config
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// 设置数字ID
|
||||
func (this *APIConfig) SetNumberId(numberId int64) {
|
||||
this.numberId = numberId
|
||||
}
|
||||
|
||||
// 获取数字ID
|
||||
func (this *APIConfig) NumberId() int64 {
|
||||
return this.numberId
|
||||
}
|
||||
|
||||
@@ -59,6 +59,19 @@ func (this *APINodeDAO) FindEnabledAPINode(id int64) (*APINode, error) {
|
||||
return result.(*APINode), err
|
||||
}
|
||||
|
||||
// 根据ID和Secret查找节点
|
||||
func (this *APINodeDAO) FindEnabledAPINodeWithUniqueIdAndSecret(uniqueId string, secret string) (*APINode, error) {
|
||||
one, err := this.Query().
|
||||
State(APINodeStateEnabled).
|
||||
Attr("uniqueId", uniqueId).
|
||||
Attr("secret", secret).
|
||||
Find()
|
||||
if err != nil || one == nil {
|
||||
return nil, err
|
||||
}
|
||||
return one.(*APINode), nil
|
||||
}
|
||||
|
||||
// 根据主键查找名称
|
||||
func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) {
|
||||
return this.Query().
|
||||
@@ -68,7 +81,7 @@ func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) {
|
||||
}
|
||||
|
||||
// 创建API节点
|
||||
func (this *APINodeDAO) CreateAPINode(name string, description string, host string, port int) (nodeId int64, err error) {
|
||||
func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) {
|
||||
uniqueId, err := this.genUniqueId()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -80,13 +93,22 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, host stri
|
||||
}
|
||||
|
||||
op := NewAPINodeOperator()
|
||||
op.IsOn = true
|
||||
op.IsOn = isOn
|
||||
op.UniqueId = uniqueId
|
||||
op.Secret = secret
|
||||
op.Name = name
|
||||
op.Description = description
|
||||
op.Host = host
|
||||
op.Port = port
|
||||
|
||||
if len(httpJSON) > 0 {
|
||||
op.Http = httpJSON
|
||||
}
|
||||
if len(httpsJSON) > 0 {
|
||||
op.Https = httpsJSON
|
||||
}
|
||||
if len(accessAddrsJSON) > 0 {
|
||||
op.AccessAddrs = accessAddrsJSON
|
||||
}
|
||||
|
||||
op.State = NodeStateEnabled
|
||||
_, err = this.Save(op)
|
||||
if err != nil {
|
||||
@@ -97,7 +119,7 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, host stri
|
||||
}
|
||||
|
||||
// 修改API节点
|
||||
func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, host string, port int) error {
|
||||
func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) error {
|
||||
if nodeId <= 0 {
|
||||
return errors.New("invalid nodeId")
|
||||
}
|
||||
@@ -106,8 +128,24 @@ func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description str
|
||||
op.Id = nodeId
|
||||
op.Name = name
|
||||
op.Description = description
|
||||
op.Host = host
|
||||
op.Port = port
|
||||
op.IsOn = isOn
|
||||
|
||||
if len(httpJSON) > 0 {
|
||||
op.Http = httpJSON
|
||||
} else {
|
||||
op.Http = "null"
|
||||
}
|
||||
if len(httpsJSON) > 0 {
|
||||
op.Https = httpsJSON
|
||||
} else {
|
||||
op.Https = "null"
|
||||
}
|
||||
if len(accessAddrsJSON) > 0 {
|
||||
op.AccessAddrs = accessAddrsJSON
|
||||
} else {
|
||||
op.AccessAddrs = "null"
|
||||
}
|
||||
|
||||
_, err := this.Save(op)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -9,8 +9,9 @@ type APINode struct {
|
||||
Secret string `field:"secret"` // 密钥
|
||||
Name string `field:"name"` // 名称
|
||||
Description string `field:"description"` // 描述
|
||||
Host string `field:"host"` // 主机
|
||||
Port uint32 `field:"port"` // 端口
|
||||
Http string `field:"http"` // 监听的HTTP配置
|
||||
Https string `field:"https"` // 监听的HTTPS配置
|
||||
AccessAddrs string `field:"accessAddrs"` // 外部访问地址
|
||||
Order uint32 `field:"order"` // 排序
|
||||
State uint8 `field:"state"` // 状态
|
||||
CreatedAt uint64 `field:"createdAt"` // 创建时间
|
||||
@@ -26,8 +27,9 @@ type APINodeOperator struct {
|
||||
Secret interface{} // 密钥
|
||||
Name interface{} // 名称
|
||||
Description interface{} // 描述
|
||||
Host interface{} // 主机
|
||||
Port interface{} // 端口
|
||||
Http interface{} // 监听的HTTP配置
|
||||
Https interface{} // 监听的HTTPS配置
|
||||
AccessAddrs interface{} // 外部访问地址
|
||||
Order interface{} // 排序
|
||||
State interface{} // 状态
|
||||
CreatedAt interface{} // 创建时间
|
||||
|
||||
@@ -1,8 +1,95 @@
|
||||
package models
|
||||
|
||||
import "strconv"
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
)
|
||||
|
||||
// 地址
|
||||
func (this *APINode) Address() string {
|
||||
return this.Host + ":" + strconv.Itoa(int(this.Port))
|
||||
// 解析HTTP配置
|
||||
func (this *APINode) DecodeHTTP() (*serverconfigs.HTTPProtocolConfig, error) {
|
||||
if !IsNotNull(this.Http) {
|
||||
return nil, nil
|
||||
}
|
||||
config := &serverconfigs.HTTPProtocolConfig{}
|
||||
err := json.Unmarshal([]byte(this.Http), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = config.Init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// 解析HTTPS配置
|
||||
func (this *APINode) DecodeHTTPS() (*serverconfigs.HTTPSProtocolConfig, error) {
|
||||
if !IsNotNull(this.Https) {
|
||||
return nil, nil
|
||||
}
|
||||
config := &serverconfigs.HTTPSProtocolConfig{}
|
||||
err := json.Unmarshal([]byte(this.Https), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = config.Init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.SSLPolicyRef != nil {
|
||||
policyId := config.SSLPolicyRef.SSLPolicyId
|
||||
if policyId > 0 {
|
||||
sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(policyId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sslPolicy != nil {
|
||||
config.SSLPolicy = sslPolicy
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = config.Init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// 解析访问地址
|
||||
func (this *APINode) DecodeAccessAddrs() ([]*serverconfigs.NetworkAddressConfig, error) {
|
||||
if !IsNotNull(this.AccessAddrs) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
addrConfigs := []*serverconfigs.NetworkAddressConfig{}
|
||||
err := json.Unmarshal([]byte(this.AccessAddrs), &addrConfigs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, addrConfig := range addrConfigs {
|
||||
err = addrConfig.Init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return addrConfigs, nil
|
||||
}
|
||||
|
||||
// 解析访问地址,并返回字符串形式
|
||||
func (this *APINode) DecodeAccessAddrStrings() ([]string, error) {
|
||||
addrs, err := this.DecodeAccessAddrs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := []string{}
|
||||
for _, addr := range addrs {
|
||||
result = append(result, addr.FullAddresses()...)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
@@ -275,6 +275,17 @@ func (this *NodeDAO) FindAllNodeIdsMatch(clusterId int64) (result []int64, err e
|
||||
return
|
||||
}
|
||||
|
||||
// 获取一个集群的所有节点
|
||||
func (this *NodeDAO) FindAllEnabledNodesWithClusterId(clusterId int64) (result []*Node, err error) {
|
||||
_, err = this.Query().
|
||||
State(NodeStateEnabled).
|
||||
Attr("clusterId", clusterId).
|
||||
DescPk().
|
||||
Slice(&result).
|
||||
FindAll()
|
||||
return
|
||||
}
|
||||
|
||||
// 计算节点数量
|
||||
func (this *NodeDAO) CountAllEnabledNodesMatch(clusterId int64, installState configutils.BoolState, activeState configutils.BoolState) (int64, error) {
|
||||
query := this.Query()
|
||||
@@ -422,6 +433,28 @@ func (this *NodeDAO) ComposeNodeConfig(nodeId int64) (*nodeconfigs.NodeConfig, e
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// 修改当前连接的API节点
|
||||
func (this *NodeDAO) UpdateNodeConnectedAPINodes(nodeId int64, apiNodeIds []int64) error {
|
||||
if nodeId <= 0 {
|
||||
return errors.New("invalid nodeId")
|
||||
}
|
||||
|
||||
op := NewNodeOperator()
|
||||
op.Id = nodeId
|
||||
|
||||
if len(apiNodeIds) > 0 {
|
||||
apiNodeIdsJSON, err := json.Marshal(apiNodeIds)
|
||||
if err != nil {
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
op.ConnectedAPINodes = apiNodeIdsJSON
|
||||
} else {
|
||||
op.ConnectedAPINodes = "[]"
|
||||
}
|
||||
_, err := this.Save(op)
|
||||
return err
|
||||
}
|
||||
|
||||
// 生成唯一ID
|
||||
func (this *NodeDAO) genUniqueId() (string, error) {
|
||||
for {
|
||||
|
||||
@@ -2,47 +2,49 @@ package models
|
||||
|
||||
// 节点
|
||||
type Node struct {
|
||||
Id uint32 `field:"id"` // ID
|
||||
AdminId uint32 `field:"adminId"` // 管理员ID
|
||||
UserId uint32 `field:"userId"` // 用户ID
|
||||
IsOn uint8 `field:"isOn"` // 是否启用
|
||||
UniqueId string `field:"uniqueId"` // 节点ID
|
||||
Secret string `field:"secret"` // 密钥
|
||||
Name string `field:"name"` // 节点名
|
||||
Code string `field:"code"` // 代号
|
||||
ClusterId uint32 `field:"clusterId"` // 集群ID
|
||||
RegionId uint32 `field:"regionId"` // 区域ID
|
||||
GroupId uint32 `field:"groupId"` // 分组ID
|
||||
CreatedAt uint64 `field:"createdAt"` // 创建时间
|
||||
Status string `field:"status"` // 最新的状态
|
||||
Version uint32 `field:"version"` // 当前版本号
|
||||
LatestVersion uint32 `field:"latestVersion"` // 最后版本号
|
||||
InstallDir string `field:"installDir"` // 安装目录
|
||||
IsInstalled uint8 `field:"isInstalled"` // 是否已安装
|
||||
InstallStatus string `field:"installStatus"` // 安装状态
|
||||
State uint8 `field:"state"` // 状态
|
||||
Id uint32 `field:"id"` // ID
|
||||
AdminId uint32 `field:"adminId"` // 管理员ID
|
||||
UserId uint32 `field:"userId"` // 用户ID
|
||||
IsOn uint8 `field:"isOn"` // 是否启用
|
||||
UniqueId string `field:"uniqueId"` // 节点ID
|
||||
Secret string `field:"secret"` // 密钥
|
||||
Name string `field:"name"` // 节点名
|
||||
Code string `field:"code"` // 代号
|
||||
ClusterId uint32 `field:"clusterId"` // 集群ID
|
||||
RegionId uint32 `field:"regionId"` // 区域ID
|
||||
GroupId uint32 `field:"groupId"` // 分组ID
|
||||
CreatedAt uint64 `field:"createdAt"` // 创建时间
|
||||
Status string `field:"status"` // 最新的状态
|
||||
Version uint32 `field:"version"` // 当前版本号
|
||||
LatestVersion uint32 `field:"latestVersion"` // 最后版本号
|
||||
InstallDir string `field:"installDir"` // 安装目录
|
||||
IsInstalled uint8 `field:"isInstalled"` // 是否已安装
|
||||
InstallStatus string `field:"installStatus"` // 安装状态
|
||||
State uint8 `field:"state"` // 状态
|
||||
ConnectedAPINodes string `field:"connectedAPINodes"` // 当前连接的API节点
|
||||
}
|
||||
|
||||
type NodeOperator struct {
|
||||
Id interface{} // ID
|
||||
AdminId interface{} // 管理员ID
|
||||
UserId interface{} // 用户ID
|
||||
IsOn interface{} // 是否启用
|
||||
UniqueId interface{} // 节点ID
|
||||
Secret interface{} // 密钥
|
||||
Name interface{} // 节点名
|
||||
Code interface{} // 代号
|
||||
ClusterId interface{} // 集群ID
|
||||
RegionId interface{} // 区域ID
|
||||
GroupId interface{} // 分组ID
|
||||
CreatedAt interface{} // 创建时间
|
||||
Status interface{} // 最新的状态
|
||||
Version interface{} // 当前版本号
|
||||
LatestVersion interface{} // 最后版本号
|
||||
InstallDir interface{} // 安装目录
|
||||
IsInstalled interface{} // 是否已安装
|
||||
InstallStatus interface{} // 安装状态
|
||||
State interface{} // 状态
|
||||
Id interface{} // ID
|
||||
AdminId interface{} // 管理员ID
|
||||
UserId interface{} // 用户ID
|
||||
IsOn interface{} // 是否启用
|
||||
UniqueId interface{} // 节点ID
|
||||
Secret interface{} // 密钥
|
||||
Name interface{} // 节点名
|
||||
Code interface{} // 代号
|
||||
ClusterId interface{} // 集群ID
|
||||
RegionId interface{} // 区域ID
|
||||
GroupId interface{} // 分组ID
|
||||
CreatedAt interface{} // 创建时间
|
||||
Status interface{} // 最新的状态
|
||||
Version interface{} // 当前版本号
|
||||
LatestVersion interface{} // 最后版本号
|
||||
InstallDir interface{} // 安装目录
|
||||
IsInstalled interface{} // 是否已安装
|
||||
InstallStatus interface{} // 安装状态
|
||||
State interface{} // 状态
|
||||
ConnectedAPINodes interface{} // 当前连接的API节点
|
||||
}
|
||||
|
||||
func NewNodeOperator() *NodeOperator {
|
||||
|
||||
@@ -136,7 +136,13 @@ func (this *Queue) InstallNode(nodeId int64) error {
|
||||
|
||||
apiEndpoints := []string{}
|
||||
for _, apiNode := range apiNodes {
|
||||
apiEndpoints = append(apiEndpoints, apiNode.Host+":"+strconv.Itoa(int(apiNode.Port)))
|
||||
addrConfigs, err := apiNode.DecodeAccessAddrs()
|
||||
if err != nil {
|
||||
return errors.New("decode api node access addresses failed: " + err.Error())
|
||||
}
|
||||
for _, addrConfig := range addrConfigs {
|
||||
apiEndpoints = append(apiEndpoints, addrConfig.FullAddresses()...)
|
||||
}
|
||||
}
|
||||
|
||||
params := &NodeParams{
|
||||
|
||||
170
internal/nodes/api_node.go
Normal file
170
internal/nodes/api_node.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/configs"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var sharedAPIConfig *configs.APIConfig = nil
|
||||
|
||||
type APINode struct {
|
||||
}
|
||||
|
||||
func NewAPINode() *APINode {
|
||||
return &APINode{}
|
||||
}
|
||||
|
||||
func (this *APINode) Start() {
|
||||
logs.Println("[API]start api node, pid: " + strconv.Itoa(os.Getpid()))
|
||||
|
||||
// 读取配置
|
||||
config, err := configs.SharedAPIConfig()
|
||||
if err != nil {
|
||||
logs.Println("[API]start failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
sharedAPIConfig = config
|
||||
|
||||
// 校验
|
||||
apiNode, err := models.SharedAPINodeDAO.FindEnabledAPINodeWithUniqueIdAndSecret(config.NodeId, config.Secret)
|
||||
if err != nil {
|
||||
logs.Println("[API]start failed: read api node from database failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
if apiNode == nil {
|
||||
logs.Println("[API]can not start node, wrong 'nodeId' or 'secret'")
|
||||
return
|
||||
}
|
||||
config.SetNumberId(int64(apiNode.Id))
|
||||
|
||||
// 设置rlimit
|
||||
_ = utils.SetRLimit(1024 * 1024)
|
||||
|
||||
// 监听RPC服务
|
||||
logs.Println("[API]starting rpc ...")
|
||||
|
||||
// HTTP
|
||||
httpConfig, err := apiNode.DecodeHTTP()
|
||||
if err != nil {
|
||||
logs.Println("[API]decode http config: " + err.Error())
|
||||
return
|
||||
}
|
||||
isListening := false
|
||||
if httpConfig != nil && httpConfig.IsOn && len(httpConfig.Listen) > 0 {
|
||||
for _, listen := range httpConfig.Listen {
|
||||
for _, addr := range listen.Addresses() {
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
logs.Println("[API]listening '" + addr + "' failed: " + err.Error())
|
||||
continue
|
||||
}
|
||||
go func() {
|
||||
err := this.listenRPC(listener, nil)
|
||||
if err != nil {
|
||||
logs.Println("[API]listening '" + addr + "' rpc: " + err.Error())
|
||||
return
|
||||
}
|
||||
}()
|
||||
isListening = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPS
|
||||
httpsConfig, err := apiNode.DecodeHTTPS()
|
||||
if err != nil {
|
||||
logs.Println("[API]decode https config: " + err.Error())
|
||||
return
|
||||
}
|
||||
if httpsConfig != nil &&
|
||||
httpsConfig.IsOn &&
|
||||
len(httpsConfig.Listen) > 0 &&
|
||||
httpsConfig.SSLPolicy != nil &&
|
||||
httpsConfig.SSLPolicy.IsOn &&
|
||||
len(httpsConfig.SSLPolicy.Certs) > 0 {
|
||||
certs := []tls.Certificate{}
|
||||
for _, cert := range httpsConfig.SSLPolicy.Certs {
|
||||
certs = append(certs, *cert.CertObject())
|
||||
}
|
||||
|
||||
for _, listen := range httpsConfig.Listen {
|
||||
for _, addr := range listen.Addresses() {
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
logs.Println("[API]listening '" + addr + "' failed: " + err.Error())
|
||||
continue
|
||||
}
|
||||
go func() {
|
||||
err := this.listenRPC(listener, &tls.Config{
|
||||
Certificates: certs,
|
||||
})
|
||||
if err != nil {
|
||||
logs.Println("[API]listening '" + addr + "' rpc: " + err.Error())
|
||||
return
|
||||
}
|
||||
}()
|
||||
isListening = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !isListening {
|
||||
logs.Println("[API]the api node does have a listening address")
|
||||
return
|
||||
}
|
||||
|
||||
// 保持进程
|
||||
select {}
|
||||
}
|
||||
|
||||
// 启动RPC监听
|
||||
func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) error {
|
||||
var rpcServer *grpc.Server
|
||||
if tlsConfig == nil {
|
||||
logs.Println("[API]listening http://" + listener.Addr().String() + " ...")
|
||||
rpcServer = grpc.NewServer()
|
||||
} else {
|
||||
logs.Println("[API]listening https://" + listener.Addr().String() + " ...")
|
||||
rpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||
}
|
||||
pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{})
|
||||
pb.RegisterNodeGrantServiceServer(rpcServer, &services.NodeGrantService{})
|
||||
pb.RegisterServerServiceServer(rpcServer, &services.ServerService{})
|
||||
pb.RegisterNodeServiceServer(rpcServer, &services.NodeService{})
|
||||
pb.RegisterNodeClusterServiceServer(rpcServer, &services.NodeClusterService{})
|
||||
pb.RegisterNodeIPAddressServiceServer(rpcServer, &services.NodeIPAddressService{})
|
||||
pb.RegisterAPINodeServiceServer(rpcServer, &services.APINodeService{})
|
||||
pb.RegisterOriginServiceServer(rpcServer, &services.OriginService{})
|
||||
pb.RegisterHTTPWebServiceServer(rpcServer, &services.HTTPWebService{})
|
||||
pb.RegisterReverseProxyServiceServer(rpcServer, &services.ReverseProxyService{})
|
||||
pb.RegisterHTTPGzipServiceServer(rpcServer, &services.HTTPGzipService{})
|
||||
pb.RegisterHTTPHeaderPolicyServiceServer(rpcServer, &services.HTTPHeaderPolicyService{})
|
||||
pb.RegisterHTTPHeaderServiceServer(rpcServer, &services.HTTPHeaderService{})
|
||||
pb.RegisterHTTPPageServiceServer(rpcServer, &services.HTTPPageService{})
|
||||
pb.RegisterHTTPAccessLogPolicyServiceServer(rpcServer, &services.HTTPAccessLogPolicyService{})
|
||||
pb.RegisterHTTPCachePolicyServiceServer(rpcServer, &services.HTTPCachePolicyService{})
|
||||
pb.RegisterHTTPFirewallPolicyServiceServer(rpcServer, &services.HTTPFirewallPolicyService{})
|
||||
pb.RegisterHTTPLocationServiceServer(rpcServer, &services.HTTPLocationService{})
|
||||
pb.RegisterHTTPWebsocketServiceServer(rpcServer, &services.HTTPWebsocketService{})
|
||||
pb.RegisterHTTPRewriteRuleServiceServer(rpcServer, &services.HTTPRewriteRuleService{})
|
||||
pb.RegisterSSLCertServiceServer(rpcServer, &services.SSLCertService{})
|
||||
pb.RegisterSSLPolicyServiceServer(rpcServer, &services.SSLPolicyService{})
|
||||
pb.RegisterSysSettingServiceServer(rpcServer, &services.SysSettingService{})
|
||||
err := rpcServer.Serve(listener)
|
||||
if err != nil {
|
||||
return errors.New("[API]start rpc failed: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -17,7 +17,7 @@ func (this *APINodeService) CreateAPINode(ctx context.Context, req *pb.CreateAPI
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.Host, int(req.Port))
|
||||
nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -32,7 +32,7 @@ func (this *APINodeService) UpdateAPINode(ctx context.Context, req *pb.UpdateAPI
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.Host, int(req.Port))
|
||||
err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -69,17 +69,23 @@ func (this *APINodeService) FindAllEnabledAPINodes(ctx context.Context, req *pb.
|
||||
|
||||
result := []*pb.APINode{}
|
||||
for _, node := range nodes {
|
||||
accessAddrs, err := node.DecodeAccessAddrStrings()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = append(result, &pb.APINode{
|
||||
Id: int64(node.Id),
|
||||
IsOn: node.IsOn == 1,
|
||||
ClusterId: int64(node.ClusterId),
|
||||
UniqueId: node.UniqueId,
|
||||
Secret: node.Secret,
|
||||
Name: node.Name,
|
||||
Description: node.Description,
|
||||
Host: node.Host,
|
||||
Port: int32(node.Port),
|
||||
Address: node.Address(),
|
||||
Id: int64(node.Id),
|
||||
IsOn: node.IsOn == 1,
|
||||
ClusterId: int64(node.ClusterId),
|
||||
UniqueId: node.UniqueId,
|
||||
Secret: node.Secret,
|
||||
Name: node.Name,
|
||||
Description: node.Description,
|
||||
HttpJSON: []byte(node.Http),
|
||||
HttpsJSON: []byte(node.Https),
|
||||
AccessAddrsJSON: []byte(node.AccessAddrs),
|
||||
AccessAddrs: accessAddrs,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -115,17 +121,23 @@ func (this *APINodeService) ListEnabledAPINodes(ctx context.Context, req *pb.Lis
|
||||
|
||||
result := []*pb.APINode{}
|
||||
for _, node := range nodes {
|
||||
accessAddrs, err := node.DecodeAccessAddrStrings()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = append(result, &pb.APINode{
|
||||
Id: int64(node.Id),
|
||||
IsOn: node.IsOn == 1,
|
||||
ClusterId: int64(node.ClusterId),
|
||||
UniqueId: node.UniqueId,
|
||||
Secret: node.Secret,
|
||||
Name: node.Name,
|
||||
Description: node.Description,
|
||||
Host: node.Host,
|
||||
Port: int32(node.Port),
|
||||
Address: node.Address(),
|
||||
Id: int64(node.Id),
|
||||
IsOn: node.IsOn == 1,
|
||||
ClusterId: int64(node.ClusterId),
|
||||
UniqueId: node.UniqueId,
|
||||
Secret: node.Secret,
|
||||
Name: node.Name,
|
||||
Description: node.Description,
|
||||
HttpJSON: []byte(node.Http),
|
||||
HttpsJSON: []byte(node.Https),
|
||||
AccessAddrsJSON: []byte(node.AccessAddrs),
|
||||
AccessAddrs: accessAddrs,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -148,17 +160,23 @@ func (this *APINodeService) FindEnabledAPINode(ctx context.Context, req *pb.Find
|
||||
return &pb.FindEnabledAPINodeResponse{Node: nil}, nil
|
||||
}
|
||||
|
||||
accessAddrs, err := node.DecodeAccessAddrStrings()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &pb.APINode{
|
||||
Id: int64(node.Id),
|
||||
IsOn: node.IsOn == 1,
|
||||
ClusterId: int64(node.ClusterId),
|
||||
UniqueId: node.UniqueId,
|
||||
Secret: node.Secret,
|
||||
Name: node.Name,
|
||||
Description: node.Description,
|
||||
Host: node.Host,
|
||||
Port: int32(node.Port),
|
||||
Address: node.Address(),
|
||||
Id: int64(node.Id),
|
||||
IsOn: node.IsOn == 1,
|
||||
ClusterId: int64(node.ClusterId),
|
||||
UniqueId: node.UniqueId,
|
||||
Secret: node.Secret,
|
||||
Name: node.Name,
|
||||
Description: node.Description,
|
||||
HttpJSON: []byte(node.Http),
|
||||
HttpsJSON: []byte(node.Https),
|
||||
AccessAddrsJSON: []byte(node.AccessAddrs),
|
||||
AccessAddrs: accessAddrs,
|
||||
}
|
||||
return &pb.FindEnabledAPINodeResponse{Node: result}, nil
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@ package services
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/errors"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/installers"
|
||||
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
)
|
||||
|
||||
// 边缘节点相关服务
|
||||
type NodeService struct {
|
||||
}
|
||||
|
||||
@@ -121,6 +122,38 @@ func (this *NodeService) ListEnabledNodesMatch(ctx context.Context, req *pb.List
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 查找一个集群下的所有节点
|
||||
func (this *NodeService) FindAllEnabledNodesWithClusterId(ctx context.Context, req *pb.FindAllEnabledNodesWithClusterIdRequest) (*pb.FindAllEnabledNodesWithClusterIdResponse, error) {
|
||||
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes, err := models.SharedNodeDAO.FindAllEnabledNodesWithClusterId(req.ClusterId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := []*pb.Node{}
|
||||
for _, node := range nodes {
|
||||
apiNodeIds := []int64{}
|
||||
if models.IsNotNull(node.ConnectedAPINodes) {
|
||||
err = json.Unmarshal([]byte(node.ConnectedAPINodes), &apiNodeIds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, &pb.Node{
|
||||
Id: int64(node.Id),
|
||||
Name: node.Name,
|
||||
UniqueId: node.UniqueId,
|
||||
Secret: node.Secret,
|
||||
ConnectedAPINodeIds: apiNodeIds,
|
||||
})
|
||||
}
|
||||
return &pb.FindAllEnabledNodesWithClusterIdResponse{Nodes: result}, nil
|
||||
}
|
||||
|
||||
// 禁用节点
|
||||
func (this *NodeService) DisableNode(ctx context.Context, req *pb.DisableNodeRequest) (*pb.DisableNodeResponse, error) {
|
||||
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
|
||||
@@ -263,27 +296,6 @@ func (this *NodeService) ComposeNodeConfig(ctx context.Context, req *pb.ComposeN
|
||||
return &pb.ComposeNodeConfigResponse{NodeJSON: data}, nil
|
||||
}
|
||||
|
||||
// 节点stream
|
||||
func (this *NodeService) NodeStream(server pb.NodeService_NodeStreamServer) error {
|
||||
// TODO 使用此stream快速通知边缘节点更新
|
||||
// 校验节点
|
||||
_, nodeId, err := rpcutils.ValidateRequest(server.Context(), rpcutils.UserTypeNode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logs.Println("nodeId:", nodeId)
|
||||
|
||||
_ = server.Send(&pb.NodeStreamResponse{})
|
||||
|
||||
for {
|
||||
req, err := server.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logs.Println("received:", req)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新节点状态
|
||||
func (this *NodeService) UpdateNodeStatus(ctx context.Context, req *pb.UpdateNodeStatusRequest) (*pb.RPCUpdateSuccess, error) {
|
||||
// 校验节点
|
||||
@@ -354,3 +366,19 @@ func (this *NodeService) InstallNode(ctx context.Context, req *pb.InstallNodeReq
|
||||
|
||||
return &pb.InstallNodeResponse{}, nil
|
||||
}
|
||||
|
||||
// 更改节点连接的API节点信息
|
||||
func (this *NodeService) UpdateNodeConnectedAPINodes(ctx context.Context, req *pb.UpdateNodeConnectedAPINodesRequest) (*pb.RPCUpdateSuccess, error) {
|
||||
// 校验节点
|
||||
_, nodeId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = models.SharedNodeDAO.UpdateNodeConnectedAPINodes(nodeId, req.ApiNodeIds)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err)
|
||||
}
|
||||
|
||||
return rpcutils.RPCUpdateSuccess()
|
||||
}
|
||||
|
||||
251
internal/rpc/services/service_node_stream.go
Normal file
251
internal/rpc/services/service_node_stream.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/configs"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/errors"
|
||||
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 命令请求相关
|
||||
type CommandRequest struct {
|
||||
Id int64
|
||||
Code string
|
||||
CommandJSON []byte
|
||||
}
|
||||
|
||||
type CommandRequestWaiting struct {
|
||||
Timestamp int64
|
||||
Chan chan *pb.NodeStreamMessage
|
||||
}
|
||||
|
||||
func (this *CommandRequestWaiting) Close() {
|
||||
defer func() {
|
||||
recover()
|
||||
}()
|
||||
|
||||
close(this.Chan)
|
||||
}
|
||||
|
||||
var responseChanMap = map[int64]*CommandRequestWaiting{} // request id => response
|
||||
var commandRequestId = int64(0)
|
||||
|
||||
var nodeLocker = &sync.Mutex{}
|
||||
var requestChanMap = map[int64]chan *CommandRequest{} // node id => chan
|
||||
|
||||
func NextCommandRequestId() int64 {
|
||||
return atomic.AddInt64(&commandRequestId, 1)
|
||||
}
|
||||
|
||||
func init() {
|
||||
// 清理WaitingChannelMap
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
nodeLocker.Lock()
|
||||
for requestId, request := range responseChanMap {
|
||||
if time.Now().Unix()-request.Timestamp > 3600 {
|
||||
responseChanMap[requestId].Close()
|
||||
delete(responseChanMap, requestId)
|
||||
}
|
||||
}
|
||||
nodeLocker.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 节点stream
|
||||
func (this *NodeService) NodeStream(server pb.NodeService_NodeStreamServer) error {
|
||||
// TODO 使用此stream快速通知边缘节点更新
|
||||
// 校验节点
|
||||
_, nodeId, err := rpcutils.ValidateRequest(server.Context(), rpcutils.UserTypeNode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 返回连接成功
|
||||
{
|
||||
apiConfig, err := configs.SharedAPIConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
connectedMessage := &messageconfigs.ConnectedAPINodeMessage{APINodeId: apiConfig.NumberId()}
|
||||
connectedMessageJSON, err := json.Marshal(connectedMessage)
|
||||
if err != nil {
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
err = server.Send(&pb.NodeStreamMessage{
|
||||
Code: messageconfigs.MessageCodeConnectedAPINode,
|
||||
DataJSON: connectedMessageJSON,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
logs.Println("[RPC]accepted node '" + strconv.FormatInt(nodeId, 10) + "' connection")
|
||||
|
||||
nodeLocker.Lock()
|
||||
requestChan, ok := requestChanMap[nodeId]
|
||||
if !ok {
|
||||
requestChan = make(chan *CommandRequest, 1024)
|
||||
requestChanMap[nodeId] = requestChan
|
||||
}
|
||||
nodeLocker.Unlock()
|
||||
|
||||
// 发送请求
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-server.Context().Done():
|
||||
return
|
||||
case commandRequest := <-requestChan:
|
||||
logs.Println("[RPC]sending command '" + commandRequest.Code + "' to node '" + strconv.FormatInt(nodeId, 10) + "'")
|
||||
retries := 3 // 错误重试次数
|
||||
for i := 0; i < retries; i++ {
|
||||
err := server.Send(&pb.NodeStreamMessage{
|
||||
RequestId: commandRequest.Id,
|
||||
Code: commandRequest.Code,
|
||||
DataJSON: commandRequest.CommandJSON,
|
||||
})
|
||||
if err != nil {
|
||||
if i == retries-1 {
|
||||
logs.Println("[RPC]send command '" + commandRequest.Code + "' failed: " + err.Error())
|
||||
} else {
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 接受请求
|
||||
for {
|
||||
req, err := server.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func(req *pb.NodeStreamMessage) {
|
||||
// 因为 responseChan.Chan 有被关闭的风险,所以我们使用recover防止panic
|
||||
defer func() {
|
||||
recover()
|
||||
}()
|
||||
|
||||
nodeLocker.Lock()
|
||||
responseChan, ok := responseChanMap[req.RequestId]
|
||||
if ok {
|
||||
select {
|
||||
case responseChan.Chan <- req:
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
nodeLocker.Unlock()
|
||||
}(req)
|
||||
}
|
||||
}
|
||||
|
||||
// 向节点发送命令
|
||||
func (this *NodeService) SendCommandToNode(ctx context.Context, req *pb.NodeStreamMessage) (*pb.NodeStreamMessage, error) {
|
||||
// 校验请求
|
||||
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodeId := req.NodeId
|
||||
if nodeId <= 0 {
|
||||
return nil, errors.New("node id should not be less than 0")
|
||||
}
|
||||
|
||||
nodeLocker.Lock()
|
||||
requestChan, ok := requestChanMap[nodeId]
|
||||
nodeLocker.Unlock()
|
||||
|
||||
if !ok {
|
||||
return &pb.NodeStreamMessage{
|
||||
RequestId: req.RequestId,
|
||||
IsOk: false,
|
||||
Message: "node '" + strconv.FormatInt(nodeId, 10) + "' not connected yet",
|
||||
}, nil
|
||||
}
|
||||
|
||||
req.RequestId = NextCommandRequestId()
|
||||
|
||||
select {
|
||||
case requestChan <- &CommandRequest{
|
||||
Id: req.RequestId,
|
||||
Code: req.Code,
|
||||
CommandJSON: req.DataJSON,
|
||||
}:
|
||||
// 加入到等待队列中
|
||||
respChan := make(chan *pb.NodeStreamMessage, 1)
|
||||
waiting := &CommandRequestWaiting{
|
||||
Timestamp: time.Now().Unix(),
|
||||
Chan: respChan,
|
||||
}
|
||||
|
||||
nodeLocker.Lock()
|
||||
responseChanMap[req.RequestId] = waiting
|
||||
nodeLocker.Unlock()
|
||||
|
||||
// 等待响应
|
||||
timeoutSeconds := req.TimeoutSeconds
|
||||
if timeoutSeconds <= 0 {
|
||||
timeoutSeconds = 10
|
||||
}
|
||||
timeout := time.NewTimer(time.Duration(timeoutSeconds) * time.Second)
|
||||
select {
|
||||
case resp := <-respChan:
|
||||
// 从队列中删除
|
||||
nodeLocker.Lock()
|
||||
delete(responseChanMap, req.RequestId)
|
||||
waiting.Close()
|
||||
nodeLocker.Unlock()
|
||||
|
||||
if resp == nil {
|
||||
return &pb.NodeStreamMessage{
|
||||
RequestId: req.RequestId,
|
||||
Code: req.Code,
|
||||
Message: "response timeout",
|
||||
IsOk: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
case <-timeout.C:
|
||||
// 从队列中删除
|
||||
nodeLocker.Lock()
|
||||
delete(responseChanMap, req.RequestId)
|
||||
waiting.Close()
|
||||
nodeLocker.Unlock()
|
||||
|
||||
return &pb.NodeStreamMessage{
|
||||
RequestId: req.RequestId,
|
||||
Code: req.Code,
|
||||
Message: "response timeout over " + fmt.Sprintf("%d", timeoutSeconds) + " seconds",
|
||||
IsOk: false,
|
||||
}, nil
|
||||
}
|
||||
default:
|
||||
return &pb.NodeStreamMessage{
|
||||
RequestId: req.RequestId,
|
||||
Code: req.Code,
|
||||
Message: "command queue is full over " + strconv.Itoa(len(requestChan)),
|
||||
IsOk: false,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user