mirror of
				https://github.com/TeaOSLab/EdgeAPI.git
				synced 2025-11-04 16:00:24 +08:00 
			
		
		
		
	实现缓存策略的部分功能
This commit is contained in:
		@@ -1,9 +1,9 @@
 | 
				
			|||||||
package main
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeAPI/internal/apis"
 | 
					 | 
				
			||||||
	"github.com/TeaOSLab/EdgeAPI/internal/apps"
 | 
						"github.com/TeaOSLab/EdgeAPI/internal/apps"
 | 
				
			||||||
	teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
 | 
						teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeAPI/internal/nodes"
 | 
				
			||||||
	_ "github.com/TeaOSLab/EdgeAPI/internal/tasks"
 | 
						_ "github.com/TeaOSLab/EdgeAPI/internal/tasks"
 | 
				
			||||||
	_ "github.com/iwind/TeaGo/bootstrap"
 | 
						_ "github.com/iwind/TeaGo/bootstrap"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -14,6 +14,6 @@ func main() {
 | 
				
			|||||||
	app.Product(teaconst.ProductName)
 | 
						app.Product(teaconst.ProductName)
 | 
				
			||||||
	app.Usage(teaconst.ProcessName + " [start|stop|restart]")
 | 
						app.Usage(teaconst.ProcessName + " [start|stop|restart]")
 | 
				
			||||||
	app.Run(func() {
 | 
						app.Run(func() {
 | 
				
			||||||
		apis.NewAPINode().Start()
 | 
							nodes.NewAPINode().Start()
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,83 +0,0 @@
 | 
				
			|||||||
package apis
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"errors"
 | 
					 | 
				
			||||||
	"github.com/TeaOSLab/EdgeAPI/internal/configs"
 | 
					 | 
				
			||||||
	"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
 | 
					 | 
				
			||||||
	"github.com/TeaOSLab/EdgeAPI/internal/utils"
 | 
					 | 
				
			||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
					 | 
				
			||||||
	"github.com/iwind/TeaGo/logs"
 | 
					 | 
				
			||||||
	"google.golang.org/grpc"
 | 
					 | 
				
			||||||
	"net"
 | 
					 | 
				
			||||||
	"os"
 | 
					 | 
				
			||||||
	"strconv"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
var sharedAPIConfig *configs.APIConfig = nil
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type APINode struct {
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func NewAPINode() *APINode {
 | 
					 | 
				
			||||||
	return &APINode{}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (this *APINode) Start() {
 | 
					 | 
				
			||||||
	logs.Println("[API]start api node, pid: " + strconv.Itoa(os.Getpid()))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	config, err := configs.SharedAPIConfig()
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		logs.Println("[API]start failed: " + err.Error())
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	sharedAPIConfig = config
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// 设置rlimit
 | 
					 | 
				
			||||||
	_ = utils.SetRLimit(1024 * 1024)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// 监听RPC服务
 | 
					 | 
				
			||||||
	logs.Println("[API]start rpc: " + config.RPC.Listen)
 | 
					 | 
				
			||||||
	err = this.listenRPC()
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		logs.Println(err.Error())
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// 启动RPC监听
 | 
					 | 
				
			||||||
func (this *APINode) listenRPC() error {
 | 
					 | 
				
			||||||
	listener, err := net.Listen("tcp", sharedAPIConfig.RPC.Listen)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return errors.New("[API]listen rpc failed: " + err.Error())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	rpcServer := grpc.NewServer()
 | 
					 | 
				
			||||||
	pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{})
 | 
					 | 
				
			||||||
	pb.RegisterNodeGrantServiceServer(rpcServer, &services.NodeGrantService{})
 | 
					 | 
				
			||||||
	pb.RegisterServerServiceServer(rpcServer, &services.ServerService{})
 | 
					 | 
				
			||||||
	pb.RegisterNodeServiceServer(rpcServer, &services.NodeService{})
 | 
					 | 
				
			||||||
	pb.RegisterNodeClusterServiceServer(rpcServer, &services.NodeClusterService{})
 | 
					 | 
				
			||||||
	pb.RegisterNodeIPAddressServiceServer(rpcServer, &services.NodeIPAddressService{})
 | 
					 | 
				
			||||||
	pb.RegisterAPINodeServiceServer(rpcServer, &services.APINodeService{})
 | 
					 | 
				
			||||||
	pb.RegisterOriginServiceServer(rpcServer, &services.OriginService{})
 | 
					 | 
				
			||||||
	pb.RegisterHTTPWebServiceServer(rpcServer, &services.HTTPWebService{})
 | 
					 | 
				
			||||||
	pb.RegisterReverseProxyServiceServer(rpcServer, &services.ReverseProxyService{})
 | 
					 | 
				
			||||||
	pb.RegisterHTTPGzipServiceServer(rpcServer, &services.HTTPGzipService{})
 | 
					 | 
				
			||||||
	pb.RegisterHTTPHeaderPolicyServiceServer(rpcServer, &services.HTTPHeaderPolicyService{})
 | 
					 | 
				
			||||||
	pb.RegisterHTTPHeaderServiceServer(rpcServer, &services.HTTPHeaderService{})
 | 
					 | 
				
			||||||
	pb.RegisterHTTPPageServiceServer(rpcServer, &services.HTTPPageService{})
 | 
					 | 
				
			||||||
	pb.RegisterHTTPAccessLogPolicyServiceServer(rpcServer, &services.HTTPAccessLogPolicyService{})
 | 
					 | 
				
			||||||
	pb.RegisterHTTPCachePolicyServiceServer(rpcServer, &services.HTTPCachePolicyService{})
 | 
					 | 
				
			||||||
	pb.RegisterHTTPFirewallPolicyServiceServer(rpcServer, &services.HTTPFirewallPolicyService{})
 | 
					 | 
				
			||||||
	pb.RegisterHTTPLocationServiceServer(rpcServer, &services.HTTPLocationService{})
 | 
					 | 
				
			||||||
	pb.RegisterHTTPWebsocketServiceServer(rpcServer, &services.HTTPWebsocketService{})
 | 
					 | 
				
			||||||
	pb.RegisterHTTPRewriteRuleServiceServer(rpcServer, &services.HTTPRewriteRuleService{})
 | 
					 | 
				
			||||||
	pb.RegisterSSLCertServiceServer(rpcServer, &services.SSLCertService{})
 | 
					 | 
				
			||||||
	pb.RegisterSSLPolicyServiceServer(rpcServer, &services.SSLPolicyService{})
 | 
					 | 
				
			||||||
	pb.RegisterSysSettingServiceServer(rpcServer, &services.SysSettingService{})
 | 
					 | 
				
			||||||
	err = rpcServer.Serve(listener)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return errors.New("[API]start rpc failed: " + err.Error())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@@ -8,12 +8,15 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
var sharedAPIConfig *APIConfig = nil
 | 
					var sharedAPIConfig *APIConfig = nil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// API节点配置
 | 
				
			||||||
type APIConfig struct {
 | 
					type APIConfig struct {
 | 
				
			||||||
	RPC struct {
 | 
						NodeId string `yaml:"nodeId" json:"nodeId"`
 | 
				
			||||||
		Listen string `yaml:"listen"`
 | 
						Secret string `yaml:"secret" json:"secret"`
 | 
				
			||||||
	} `yaml:"rpc"`
 | 
					
 | 
				
			||||||
 | 
						numberId int64 // 数字ID
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 获取共享配置
 | 
				
			||||||
func SharedAPIConfig() (*APIConfig, error) {
 | 
					func SharedAPIConfig() (*APIConfig, error) {
 | 
				
			||||||
	sharedLocker.Lock()
 | 
						sharedLocker.Lock()
 | 
				
			||||||
	defer sharedLocker.Unlock()
 | 
						defer sharedLocker.Unlock()
 | 
				
			||||||
@@ -36,3 +39,13 @@ func SharedAPIConfig() (*APIConfig, error) {
 | 
				
			|||||||
	sharedAPIConfig = config
 | 
						sharedAPIConfig = config
 | 
				
			||||||
	return config, nil
 | 
						return config, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 设置数字ID
 | 
				
			||||||
 | 
					func (this *APIConfig) SetNumberId(numberId int64) {
 | 
				
			||||||
 | 
						this.numberId = numberId
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 获取数字ID
 | 
				
			||||||
 | 
					func (this *APIConfig) NumberId() int64 {
 | 
				
			||||||
 | 
						return this.numberId
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -59,6 +59,19 @@ func (this *APINodeDAO) FindEnabledAPINode(id int64) (*APINode, error) {
 | 
				
			|||||||
	return result.(*APINode), err
 | 
						return result.(*APINode), err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 根据ID和Secret查找节点
 | 
				
			||||||
 | 
					func (this *APINodeDAO) FindEnabledAPINodeWithUniqueIdAndSecret(uniqueId string, secret string) (*APINode, error) {
 | 
				
			||||||
 | 
						one, err := this.Query().
 | 
				
			||||||
 | 
							State(APINodeStateEnabled).
 | 
				
			||||||
 | 
							Attr("uniqueId", uniqueId).
 | 
				
			||||||
 | 
							Attr("secret", secret).
 | 
				
			||||||
 | 
							Find()
 | 
				
			||||||
 | 
						if err != nil || one == nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return one.(*APINode), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 根据主键查找名称
 | 
					// 根据主键查找名称
 | 
				
			||||||
func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) {
 | 
					func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) {
 | 
				
			||||||
	return this.Query().
 | 
						return this.Query().
 | 
				
			||||||
@@ -68,7 +81,7 @@ func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 创建API节点
 | 
					// 创建API节点
 | 
				
			||||||
func (this *APINodeDAO) CreateAPINode(name string, description string, host string, port int) (nodeId int64, err error) {
 | 
					func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) {
 | 
				
			||||||
	uniqueId, err := this.genUniqueId()
 | 
						uniqueId, err := this.genUniqueId()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return 0, err
 | 
							return 0, err
 | 
				
			||||||
@@ -80,13 +93,22 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, host stri
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	op := NewAPINodeOperator()
 | 
						op := NewAPINodeOperator()
 | 
				
			||||||
	op.IsOn = true
 | 
						op.IsOn = isOn
 | 
				
			||||||
	op.UniqueId = uniqueId
 | 
						op.UniqueId = uniqueId
 | 
				
			||||||
	op.Secret = secret
 | 
						op.Secret = secret
 | 
				
			||||||
	op.Name = name
 | 
						op.Name = name
 | 
				
			||||||
	op.Description = description
 | 
						op.Description = description
 | 
				
			||||||
	op.Host = host
 | 
					
 | 
				
			||||||
	op.Port = port
 | 
						if len(httpJSON) > 0 {
 | 
				
			||||||
 | 
							op.Http = httpJSON
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if len(httpsJSON) > 0 {
 | 
				
			||||||
 | 
							op.Https = httpsJSON
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if len(accessAddrsJSON) > 0 {
 | 
				
			||||||
 | 
							op.AccessAddrs = accessAddrsJSON
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	op.State = NodeStateEnabled
 | 
						op.State = NodeStateEnabled
 | 
				
			||||||
	_, err = this.Save(op)
 | 
						_, err = this.Save(op)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@@ -97,7 +119,7 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, host stri
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 修改API节点
 | 
					// 修改API节点
 | 
				
			||||||
func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, host string, port int) error {
 | 
					func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) error {
 | 
				
			||||||
	if nodeId <= 0 {
 | 
						if nodeId <= 0 {
 | 
				
			||||||
		return errors.New("invalid nodeId")
 | 
							return errors.New("invalid nodeId")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -106,8 +128,24 @@ func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description str
 | 
				
			|||||||
	op.Id = nodeId
 | 
						op.Id = nodeId
 | 
				
			||||||
	op.Name = name
 | 
						op.Name = name
 | 
				
			||||||
	op.Description = description
 | 
						op.Description = description
 | 
				
			||||||
	op.Host = host
 | 
						op.IsOn = isOn
 | 
				
			||||||
	op.Port = port
 | 
					
 | 
				
			||||||
 | 
						if len(httpJSON) > 0 {
 | 
				
			||||||
 | 
							op.Http = httpJSON
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							op.Http = "null"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if len(httpsJSON) > 0 {
 | 
				
			||||||
 | 
							op.Https = httpsJSON
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							op.Https = "null"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if len(accessAddrsJSON) > 0 {
 | 
				
			||||||
 | 
							op.AccessAddrs = accessAddrsJSON
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							op.AccessAddrs = "null"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_, err := this.Save(op)
 | 
						_, err := this.Save(op)
 | 
				
			||||||
	return err
 | 
						return err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,8 +9,9 @@ type APINode struct {
 | 
				
			|||||||
	Secret      string `field:"secret"`      // 密钥
 | 
						Secret      string `field:"secret"`      // 密钥
 | 
				
			||||||
	Name        string `field:"name"`        // 名称
 | 
						Name        string `field:"name"`        // 名称
 | 
				
			||||||
	Description string `field:"description"` // 描述
 | 
						Description string `field:"description"` // 描述
 | 
				
			||||||
	Host        string `field:"host"`        // 主机
 | 
						Http        string `field:"http"`        // 监听的HTTP配置
 | 
				
			||||||
	Port        uint32 `field:"port"`        // 端口
 | 
						Https       string `field:"https"`       // 监听的HTTPS配置
 | 
				
			||||||
 | 
						AccessAddrs string `field:"accessAddrs"` // 外部访问地址
 | 
				
			||||||
	Order       uint32 `field:"order"`       // 排序
 | 
						Order       uint32 `field:"order"`       // 排序
 | 
				
			||||||
	State       uint8  `field:"state"`       // 状态
 | 
						State       uint8  `field:"state"`       // 状态
 | 
				
			||||||
	CreatedAt   uint64 `field:"createdAt"`   // 创建时间
 | 
						CreatedAt   uint64 `field:"createdAt"`   // 创建时间
 | 
				
			||||||
@@ -26,8 +27,9 @@ type APINodeOperator struct {
 | 
				
			|||||||
	Secret      interface{} // 密钥
 | 
						Secret      interface{} // 密钥
 | 
				
			||||||
	Name        interface{} // 名称
 | 
						Name        interface{} // 名称
 | 
				
			||||||
	Description interface{} // 描述
 | 
						Description interface{} // 描述
 | 
				
			||||||
	Host        interface{} // 主机
 | 
						Http        interface{} // 监听的HTTP配置
 | 
				
			||||||
	Port        interface{} // 端口
 | 
						Https       interface{} // 监听的HTTPS配置
 | 
				
			||||||
 | 
						AccessAddrs interface{} // 外部访问地址
 | 
				
			||||||
	Order       interface{} // 排序
 | 
						Order       interface{} // 排序
 | 
				
			||||||
	State       interface{} // 状态
 | 
						State       interface{} // 状态
 | 
				
			||||||
	CreatedAt   interface{} // 创建时间
 | 
						CreatedAt   interface{} // 创建时间
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,8 +1,95 @@
 | 
				
			|||||||
package models
 | 
					package models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import "strconv"
 | 
					import (
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 地址
 | 
					// 解析HTTP配置
 | 
				
			||||||
func (this *APINode) Address() string {
 | 
					func (this *APINode) DecodeHTTP() (*serverconfigs.HTTPProtocolConfig, error) {
 | 
				
			||||||
	return this.Host + ":" + strconv.Itoa(int(this.Port))
 | 
						if !IsNotNull(this.Http) {
 | 
				
			||||||
 | 
							return nil, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						config := &serverconfigs.HTTPProtocolConfig{}
 | 
				
			||||||
 | 
						err := json.Unmarshal([]byte(this.Http), config)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = config.Init()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return config, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 解析HTTPS配置
 | 
				
			||||||
 | 
					func (this *APINode) DecodeHTTPS() (*serverconfigs.HTTPSProtocolConfig, error) {
 | 
				
			||||||
 | 
						if !IsNotNull(this.Https) {
 | 
				
			||||||
 | 
							return nil, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						config := &serverconfigs.HTTPSProtocolConfig{}
 | 
				
			||||||
 | 
						err := json.Unmarshal([]byte(this.Https), config)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = config.Init()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if config.SSLPolicyRef != nil {
 | 
				
			||||||
 | 
							policyId := config.SSLPolicyRef.SSLPolicyId
 | 
				
			||||||
 | 
							if policyId > 0 {
 | 
				
			||||||
 | 
								sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(policyId)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return nil, err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if sslPolicy != nil {
 | 
				
			||||||
 | 
									config.SSLPolicy = sslPolicy
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = config.Init()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return config, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 解析访问地址
 | 
				
			||||||
 | 
					func (this *APINode) DecodeAccessAddrs() ([]*serverconfigs.NetworkAddressConfig, error) {
 | 
				
			||||||
 | 
						if !IsNotNull(this.AccessAddrs) {
 | 
				
			||||||
 | 
							return nil, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						addrConfigs := []*serverconfigs.NetworkAddressConfig{}
 | 
				
			||||||
 | 
						err := json.Unmarshal([]byte(this.AccessAddrs), &addrConfigs)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, addrConfig := range addrConfigs {
 | 
				
			||||||
 | 
							err = addrConfig.Init()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return addrConfigs, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 解析访问地址,并返回字符串形式
 | 
				
			||||||
 | 
					func (this *APINode) DecodeAccessAddrStrings() ([]string, error) {
 | 
				
			||||||
 | 
						addrs, err := this.DecodeAccessAddrs()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						result := []string{}
 | 
				
			||||||
 | 
						for _, addr := range addrs {
 | 
				
			||||||
 | 
							result = append(result, addr.FullAddresses()...)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return result, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,7 +2,7 @@ package models
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"errors"
 | 
						"github.com/TeaOSLab/EdgeAPI/internal/errors"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
 | 
				
			||||||
@@ -275,6 +275,17 @@ func (this *NodeDAO) FindAllNodeIdsMatch(clusterId int64) (result []int64, err e
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 获取一个集群的所有节点
 | 
				
			||||||
 | 
					func (this *NodeDAO) FindAllEnabledNodesWithClusterId(clusterId int64) (result []*Node, err error) {
 | 
				
			||||||
 | 
						_, err = this.Query().
 | 
				
			||||||
 | 
							State(NodeStateEnabled).
 | 
				
			||||||
 | 
							Attr("clusterId", clusterId).
 | 
				
			||||||
 | 
							DescPk().
 | 
				
			||||||
 | 
							Slice(&result).
 | 
				
			||||||
 | 
							FindAll()
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 计算节点数量
 | 
					// 计算节点数量
 | 
				
			||||||
func (this *NodeDAO) CountAllEnabledNodesMatch(clusterId int64, installState configutils.BoolState, activeState configutils.BoolState) (int64, error) {
 | 
					func (this *NodeDAO) CountAllEnabledNodesMatch(clusterId int64, installState configutils.BoolState, activeState configutils.BoolState) (int64, error) {
 | 
				
			||||||
	query := this.Query()
 | 
						query := this.Query()
 | 
				
			||||||
@@ -422,6 +433,28 @@ func (this *NodeDAO) ComposeNodeConfig(nodeId int64) (*nodeconfigs.NodeConfig, e
 | 
				
			|||||||
	return config, nil
 | 
						return config, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 修改当前连接的API节点
 | 
				
			||||||
 | 
					func (this *NodeDAO) UpdateNodeConnectedAPINodes(nodeId int64, apiNodeIds []int64) error {
 | 
				
			||||||
 | 
						if nodeId <= 0 {
 | 
				
			||||||
 | 
							return errors.New("invalid nodeId")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						op := NewNodeOperator()
 | 
				
			||||||
 | 
						op.Id = nodeId
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(apiNodeIds) > 0 {
 | 
				
			||||||
 | 
							apiNodeIdsJSON, err := json.Marshal(apiNodeIds)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return errors.Wrap(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							op.ConnectedAPINodes = apiNodeIdsJSON
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							op.ConnectedAPINodes = "[]"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						_, err := this.Save(op)
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 生成唯一ID
 | 
					// 生成唯一ID
 | 
				
			||||||
func (this *NodeDAO) genUniqueId() (string, error) {
 | 
					func (this *NodeDAO) genUniqueId() (string, error) {
 | 
				
			||||||
	for {
 | 
						for {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -21,6 +21,7 @@ type Node struct {
 | 
				
			|||||||
	IsInstalled       uint8  `field:"isInstalled"`       // 是否已安装
 | 
						IsInstalled       uint8  `field:"isInstalled"`       // 是否已安装
 | 
				
			||||||
	InstallStatus     string `field:"installStatus"`     // 安装状态
 | 
						InstallStatus     string `field:"installStatus"`     // 安装状态
 | 
				
			||||||
	State             uint8  `field:"state"`             // 状态
 | 
						State             uint8  `field:"state"`             // 状态
 | 
				
			||||||
 | 
						ConnectedAPINodes string `field:"connectedAPINodes"` // 当前连接的API节点
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type NodeOperator struct {
 | 
					type NodeOperator struct {
 | 
				
			||||||
@@ -43,6 +44,7 @@ type NodeOperator struct {
 | 
				
			|||||||
	IsInstalled       interface{} // 是否已安装
 | 
						IsInstalled       interface{} // 是否已安装
 | 
				
			||||||
	InstallStatus     interface{} // 安装状态
 | 
						InstallStatus     interface{} // 安装状态
 | 
				
			||||||
	State             interface{} // 状态
 | 
						State             interface{} // 状态
 | 
				
			||||||
 | 
						ConnectedAPINodes interface{} // 当前连接的API节点
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewNodeOperator() *NodeOperator {
 | 
					func NewNodeOperator() *NodeOperator {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -136,7 +136,13 @@ func (this *Queue) InstallNode(nodeId int64) error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	apiEndpoints := []string{}
 | 
						apiEndpoints := []string{}
 | 
				
			||||||
	for _, apiNode := range apiNodes {
 | 
						for _, apiNode := range apiNodes {
 | 
				
			||||||
		apiEndpoints = append(apiEndpoints, apiNode.Host+":"+strconv.Itoa(int(apiNode.Port)))
 | 
							addrConfigs, err := apiNode.DecodeAccessAddrs()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return errors.New("decode api node access addresses failed: " + err.Error())
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							for _, addrConfig := range addrConfigs {
 | 
				
			||||||
 | 
								apiEndpoints = append(apiEndpoints, addrConfig.FullAddresses()...)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	params := &NodeParams{
 | 
						params := &NodeParams{
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										170
									
								
								internal/nodes/api_node.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										170
									
								
								internal/nodes/api_node.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,170 @@
 | 
				
			|||||||
 | 
					package nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"crypto/tls"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeAPI/internal/configs"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeAPI/internal/db/models"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeAPI/internal/utils"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						"google.golang.org/grpc"
 | 
				
			||||||
 | 
						"google.golang.org/grpc/credentials"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var sharedAPIConfig *configs.APIConfig = nil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type APINode struct {
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewAPINode() *APINode {
 | 
				
			||||||
 | 
						return &APINode{}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *APINode) Start() {
 | 
				
			||||||
 | 
						logs.Println("[API]start api node, pid: " + strconv.Itoa(os.Getpid()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 读取配置
 | 
				
			||||||
 | 
						config, err := configs.SharedAPIConfig()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logs.Println("[API]start failed: " + err.Error())
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						sharedAPIConfig = config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 校验
 | 
				
			||||||
 | 
						apiNode, err := models.SharedAPINodeDAO.FindEnabledAPINodeWithUniqueIdAndSecret(config.NodeId, config.Secret)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logs.Println("[API]start failed: read api node from database failed: " + err.Error())
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if apiNode == nil {
 | 
				
			||||||
 | 
							logs.Println("[API]can not start node, wrong 'nodeId' or 'secret'")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						config.SetNumberId(int64(apiNode.Id))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 设置rlimit
 | 
				
			||||||
 | 
						_ = utils.SetRLimit(1024 * 1024)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 监听RPC服务
 | 
				
			||||||
 | 
						logs.Println("[API]starting rpc ...")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// HTTP
 | 
				
			||||||
 | 
						httpConfig, err := apiNode.DecodeHTTP()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logs.Println("[API]decode http config: " + err.Error())
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						isListening := false
 | 
				
			||||||
 | 
						if httpConfig != nil && httpConfig.IsOn && len(httpConfig.Listen) > 0 {
 | 
				
			||||||
 | 
							for _, listen := range httpConfig.Listen {
 | 
				
			||||||
 | 
								for _, addr := range listen.Addresses() {
 | 
				
			||||||
 | 
									listener, err := net.Listen("tcp", addr)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										logs.Println("[API]listening '" + addr + "' failed: " + err.Error())
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									go func() {
 | 
				
			||||||
 | 
										err := this.listenRPC(listener, nil)
 | 
				
			||||||
 | 
										if err != nil {
 | 
				
			||||||
 | 
											logs.Println("[API]listening '" + addr + "' rpc: " + err.Error())
 | 
				
			||||||
 | 
											return
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}()
 | 
				
			||||||
 | 
									isListening = true
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// HTTPS
 | 
				
			||||||
 | 
						httpsConfig, err := apiNode.DecodeHTTPS()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logs.Println("[API]decode https config: " + err.Error())
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if httpsConfig != nil &&
 | 
				
			||||||
 | 
							httpsConfig.IsOn &&
 | 
				
			||||||
 | 
							len(httpsConfig.Listen) > 0 &&
 | 
				
			||||||
 | 
							httpsConfig.SSLPolicy != nil &&
 | 
				
			||||||
 | 
							httpsConfig.SSLPolicy.IsOn &&
 | 
				
			||||||
 | 
							len(httpsConfig.SSLPolicy.Certs) > 0 {
 | 
				
			||||||
 | 
							certs := []tls.Certificate{}
 | 
				
			||||||
 | 
							for _, cert := range httpsConfig.SSLPolicy.Certs {
 | 
				
			||||||
 | 
								certs = append(certs, *cert.CertObject())
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for _, listen := range httpsConfig.Listen {
 | 
				
			||||||
 | 
								for _, addr := range listen.Addresses() {
 | 
				
			||||||
 | 
									listener, err := net.Listen("tcp", addr)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										logs.Println("[API]listening '" + addr + "' failed: " + err.Error())
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									go func() {
 | 
				
			||||||
 | 
										err := this.listenRPC(listener, &tls.Config{
 | 
				
			||||||
 | 
											Certificates: certs,
 | 
				
			||||||
 | 
										})
 | 
				
			||||||
 | 
										if err != nil {
 | 
				
			||||||
 | 
											logs.Println("[API]listening '" + addr + "' rpc: " + err.Error())
 | 
				
			||||||
 | 
											return
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}()
 | 
				
			||||||
 | 
									isListening = true
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if !isListening {
 | 
				
			||||||
 | 
							logs.Println("[API]the api node does have a listening address")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 保持进程
 | 
				
			||||||
 | 
						select {}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 启动RPC监听
 | 
				
			||||||
 | 
					func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) error {
 | 
				
			||||||
 | 
						var rpcServer *grpc.Server
 | 
				
			||||||
 | 
						if tlsConfig == nil {
 | 
				
			||||||
 | 
							logs.Println("[API]listening http://" + listener.Addr().String() + " ...")
 | 
				
			||||||
 | 
							rpcServer = grpc.NewServer()
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							logs.Println("[API]listening https://" + listener.Addr().String() + " ...")
 | 
				
			||||||
 | 
							rpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{})
 | 
				
			||||||
 | 
						pb.RegisterNodeGrantServiceServer(rpcServer, &services.NodeGrantService{})
 | 
				
			||||||
 | 
						pb.RegisterServerServiceServer(rpcServer, &services.ServerService{})
 | 
				
			||||||
 | 
						pb.RegisterNodeServiceServer(rpcServer, &services.NodeService{})
 | 
				
			||||||
 | 
						pb.RegisterNodeClusterServiceServer(rpcServer, &services.NodeClusterService{})
 | 
				
			||||||
 | 
						pb.RegisterNodeIPAddressServiceServer(rpcServer, &services.NodeIPAddressService{})
 | 
				
			||||||
 | 
						pb.RegisterAPINodeServiceServer(rpcServer, &services.APINodeService{})
 | 
				
			||||||
 | 
						pb.RegisterOriginServiceServer(rpcServer, &services.OriginService{})
 | 
				
			||||||
 | 
						pb.RegisterHTTPWebServiceServer(rpcServer, &services.HTTPWebService{})
 | 
				
			||||||
 | 
						pb.RegisterReverseProxyServiceServer(rpcServer, &services.ReverseProxyService{})
 | 
				
			||||||
 | 
						pb.RegisterHTTPGzipServiceServer(rpcServer, &services.HTTPGzipService{})
 | 
				
			||||||
 | 
						pb.RegisterHTTPHeaderPolicyServiceServer(rpcServer, &services.HTTPHeaderPolicyService{})
 | 
				
			||||||
 | 
						pb.RegisterHTTPHeaderServiceServer(rpcServer, &services.HTTPHeaderService{})
 | 
				
			||||||
 | 
						pb.RegisterHTTPPageServiceServer(rpcServer, &services.HTTPPageService{})
 | 
				
			||||||
 | 
						pb.RegisterHTTPAccessLogPolicyServiceServer(rpcServer, &services.HTTPAccessLogPolicyService{})
 | 
				
			||||||
 | 
						pb.RegisterHTTPCachePolicyServiceServer(rpcServer, &services.HTTPCachePolicyService{})
 | 
				
			||||||
 | 
						pb.RegisterHTTPFirewallPolicyServiceServer(rpcServer, &services.HTTPFirewallPolicyService{})
 | 
				
			||||||
 | 
						pb.RegisterHTTPLocationServiceServer(rpcServer, &services.HTTPLocationService{})
 | 
				
			||||||
 | 
						pb.RegisterHTTPWebsocketServiceServer(rpcServer, &services.HTTPWebsocketService{})
 | 
				
			||||||
 | 
						pb.RegisterHTTPRewriteRuleServiceServer(rpcServer, &services.HTTPRewriteRuleService{})
 | 
				
			||||||
 | 
						pb.RegisterSSLCertServiceServer(rpcServer, &services.SSLCertService{})
 | 
				
			||||||
 | 
						pb.RegisterSSLPolicyServiceServer(rpcServer, &services.SSLPolicyService{})
 | 
				
			||||||
 | 
						pb.RegisterSysSettingServiceServer(rpcServer, &services.SysSettingService{})
 | 
				
			||||||
 | 
						err := rpcServer.Serve(listener)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errors.New("[API]start rpc failed: " + err.Error())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -17,7 +17,7 @@ func (this *APINodeService) CreateAPINode(ctx context.Context, req *pb.CreateAPI
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.Host, int(req.Port))
 | 
						nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -32,7 +32,7 @@ func (this *APINodeService) UpdateAPINode(ctx context.Context, req *pb.UpdateAPI
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.Host, int(req.Port))
 | 
						err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -69,6 +69,11 @@ func (this *APINodeService) FindAllEnabledAPINodes(ctx context.Context, req *pb.
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	result := []*pb.APINode{}
 | 
						result := []*pb.APINode{}
 | 
				
			||||||
	for _, node := range nodes {
 | 
						for _, node := range nodes {
 | 
				
			||||||
 | 
							accessAddrs, err := node.DecodeAccessAddrStrings()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		result = append(result, &pb.APINode{
 | 
							result = append(result, &pb.APINode{
 | 
				
			||||||
			Id:              int64(node.Id),
 | 
								Id:              int64(node.Id),
 | 
				
			||||||
			IsOn:            node.IsOn == 1,
 | 
								IsOn:            node.IsOn == 1,
 | 
				
			||||||
@@ -77,9 +82,10 @@ func (this *APINodeService) FindAllEnabledAPINodes(ctx context.Context, req *pb.
 | 
				
			|||||||
			Secret:          node.Secret,
 | 
								Secret:          node.Secret,
 | 
				
			||||||
			Name:            node.Name,
 | 
								Name:            node.Name,
 | 
				
			||||||
			Description:     node.Description,
 | 
								Description:     node.Description,
 | 
				
			||||||
			Host:        node.Host,
 | 
								HttpJSON:        []byte(node.Http),
 | 
				
			||||||
			Port:        int32(node.Port),
 | 
								HttpsJSON:       []byte(node.Https),
 | 
				
			||||||
			Address:     node.Address(),
 | 
								AccessAddrsJSON: []byte(node.AccessAddrs),
 | 
				
			||||||
 | 
								AccessAddrs:     accessAddrs,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -115,6 +121,11 @@ func (this *APINodeService) ListEnabledAPINodes(ctx context.Context, req *pb.Lis
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	result := []*pb.APINode{}
 | 
						result := []*pb.APINode{}
 | 
				
			||||||
	for _, node := range nodes {
 | 
						for _, node := range nodes {
 | 
				
			||||||
 | 
							accessAddrs, err := node.DecodeAccessAddrStrings()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		result = append(result, &pb.APINode{
 | 
							result = append(result, &pb.APINode{
 | 
				
			||||||
			Id:              int64(node.Id),
 | 
								Id:              int64(node.Id),
 | 
				
			||||||
			IsOn:            node.IsOn == 1,
 | 
								IsOn:            node.IsOn == 1,
 | 
				
			||||||
@@ -123,9 +134,10 @@ func (this *APINodeService) ListEnabledAPINodes(ctx context.Context, req *pb.Lis
 | 
				
			|||||||
			Secret:          node.Secret,
 | 
								Secret:          node.Secret,
 | 
				
			||||||
			Name:            node.Name,
 | 
								Name:            node.Name,
 | 
				
			||||||
			Description:     node.Description,
 | 
								Description:     node.Description,
 | 
				
			||||||
			Host:        node.Host,
 | 
								HttpJSON:        []byte(node.Http),
 | 
				
			||||||
			Port:        int32(node.Port),
 | 
								HttpsJSON:       []byte(node.Https),
 | 
				
			||||||
			Address:     node.Address(),
 | 
								AccessAddrsJSON: []byte(node.AccessAddrs),
 | 
				
			||||||
 | 
								AccessAddrs:     accessAddrs,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -148,6 +160,11 @@ func (this *APINodeService) FindEnabledAPINode(ctx context.Context, req *pb.Find
 | 
				
			|||||||
		return &pb.FindEnabledAPINodeResponse{Node: nil}, nil
 | 
							return &pb.FindEnabledAPINodeResponse{Node: nil}, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						accessAddrs, err := node.DecodeAccessAddrStrings()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	result := &pb.APINode{
 | 
						result := &pb.APINode{
 | 
				
			||||||
		Id:              int64(node.Id),
 | 
							Id:              int64(node.Id),
 | 
				
			||||||
		IsOn:            node.IsOn == 1,
 | 
							IsOn:            node.IsOn == 1,
 | 
				
			||||||
@@ -156,9 +173,10 @@ func (this *APINodeService) FindEnabledAPINode(ctx context.Context, req *pb.Find
 | 
				
			|||||||
		Secret:          node.Secret,
 | 
							Secret:          node.Secret,
 | 
				
			||||||
		Name:            node.Name,
 | 
							Name:            node.Name,
 | 
				
			||||||
		Description:     node.Description,
 | 
							Description:     node.Description,
 | 
				
			||||||
		Host:        node.Host,
 | 
							HttpJSON:        []byte(node.Http),
 | 
				
			||||||
		Port:        int32(node.Port),
 | 
							HttpsJSON:       []byte(node.Https),
 | 
				
			||||||
		Address:     node.Address(),
 | 
							AccessAddrsJSON: []byte(node.AccessAddrs),
 | 
				
			||||||
 | 
							AccessAddrs:     accessAddrs,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &pb.FindEnabledAPINodeResponse{Node: result}, nil
 | 
						return &pb.FindEnabledAPINodeResponse{Node: result}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,8 +3,8 @@ package services
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"errors"
 | 
					 | 
				
			||||||
	"github.com/TeaOSLab/EdgeAPI/internal/db/models"
 | 
						"github.com/TeaOSLab/EdgeAPI/internal/db/models"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeAPI/internal/errors"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeAPI/internal/installers"
 | 
						"github.com/TeaOSLab/EdgeAPI/internal/installers"
 | 
				
			||||||
	rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
 | 
						rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
 | 
				
			||||||
@@ -12,6 +12,7 @@ import (
 | 
				
			|||||||
	"github.com/iwind/TeaGo/logs"
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 边缘节点相关服务
 | 
				
			||||||
type NodeService struct {
 | 
					type NodeService struct {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -121,6 +122,38 @@ func (this *NodeService) ListEnabledNodesMatch(ctx context.Context, req *pb.List
 | 
				
			|||||||
	}, nil
 | 
						}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 查找一个集群下的所有节点
 | 
				
			||||||
 | 
					func (this *NodeService) FindAllEnabledNodesWithClusterId(ctx context.Context, req *pb.FindAllEnabledNodesWithClusterIdRequest) (*pb.FindAllEnabledNodesWithClusterIdResponse, error) {
 | 
				
			||||||
 | 
						_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						nodes, err := models.SharedNodeDAO.FindAllEnabledNodesWithClusterId(req.ClusterId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						result := []*pb.Node{}
 | 
				
			||||||
 | 
						for _, node := range nodes {
 | 
				
			||||||
 | 
							apiNodeIds := []int64{}
 | 
				
			||||||
 | 
							if models.IsNotNull(node.ConnectedAPINodes) {
 | 
				
			||||||
 | 
								err = json.Unmarshal([]byte(node.ConnectedAPINodes), &apiNodeIds)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return nil, err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							result = append(result, &pb.Node{
 | 
				
			||||||
 | 
								Id:                  int64(node.Id),
 | 
				
			||||||
 | 
								Name:                node.Name,
 | 
				
			||||||
 | 
								UniqueId:            node.UniqueId,
 | 
				
			||||||
 | 
								Secret:              node.Secret,
 | 
				
			||||||
 | 
								ConnectedAPINodeIds: apiNodeIds,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &pb.FindAllEnabledNodesWithClusterIdResponse{Nodes: result}, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 禁用节点
 | 
					// 禁用节点
 | 
				
			||||||
func (this *NodeService) DisableNode(ctx context.Context, req *pb.DisableNodeRequest) (*pb.DisableNodeResponse, error) {
 | 
					func (this *NodeService) DisableNode(ctx context.Context, req *pb.DisableNodeRequest) (*pb.DisableNodeResponse, error) {
 | 
				
			||||||
	_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
 | 
						_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
 | 
				
			||||||
@@ -263,27 +296,6 @@ func (this *NodeService) ComposeNodeConfig(ctx context.Context, req *pb.ComposeN
 | 
				
			|||||||
	return &pb.ComposeNodeConfigResponse{NodeJSON: data}, nil
 | 
						return &pb.ComposeNodeConfigResponse{NodeJSON: data}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 节点stream
 | 
					 | 
				
			||||||
func (this *NodeService) NodeStream(server pb.NodeService_NodeStreamServer) error {
 | 
					 | 
				
			||||||
	// TODO 使用此stream快速通知边缘节点更新
 | 
					 | 
				
			||||||
	// 校验节点
 | 
					 | 
				
			||||||
	_, nodeId, err := rpcutils.ValidateRequest(server.Context(), rpcutils.UserTypeNode)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	logs.Println("nodeId:", nodeId)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_ = server.Send(&pb.NodeStreamResponse{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for {
 | 
					 | 
				
			||||||
		req, err := server.Recv()
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		logs.Println("received:", req)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// 更新节点状态
 | 
					// 更新节点状态
 | 
				
			||||||
func (this *NodeService) UpdateNodeStatus(ctx context.Context, req *pb.UpdateNodeStatusRequest) (*pb.RPCUpdateSuccess, error) {
 | 
					func (this *NodeService) UpdateNodeStatus(ctx context.Context, req *pb.UpdateNodeStatusRequest) (*pb.RPCUpdateSuccess, error) {
 | 
				
			||||||
	// 校验节点
 | 
						// 校验节点
 | 
				
			||||||
@@ -354,3 +366,19 @@ func (this *NodeService) InstallNode(ctx context.Context, req *pb.InstallNodeReq
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	return &pb.InstallNodeResponse{}, nil
 | 
						return &pb.InstallNodeResponse{}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 更改节点连接的API节点信息
 | 
				
			||||||
 | 
					func (this *NodeService) UpdateNodeConnectedAPINodes(ctx context.Context, req *pb.UpdateNodeConnectedAPINodesRequest) (*pb.RPCUpdateSuccess, error) {
 | 
				
			||||||
 | 
						// 校验节点
 | 
				
			||||||
 | 
						_, nodeId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeNode)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = models.SharedNodeDAO.UpdateNodeConnectedAPINodes(nodeId, req.ApiNodeIds)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, errors.Wrap(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return rpcutils.RPCUpdateSuccess()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										251
									
								
								internal/rpc/services/service_node_stream.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										251
									
								
								internal/rpc/services/service_node_stream.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,251 @@
 | 
				
			|||||||
 | 
					package services
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeAPI/internal/configs"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeAPI/internal/errors"
 | 
				
			||||||
 | 
						rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"sync/atomic"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 命令请求相关
 | 
				
			||||||
 | 
					type CommandRequest struct {
 | 
				
			||||||
 | 
						Id          int64
 | 
				
			||||||
 | 
						Code        string
 | 
				
			||||||
 | 
						CommandJSON []byte
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type CommandRequestWaiting struct {
 | 
				
			||||||
 | 
						Timestamp int64
 | 
				
			||||||
 | 
						Chan      chan *pb.NodeStreamMessage
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CommandRequestWaiting) Close() {
 | 
				
			||||||
 | 
						defer func() {
 | 
				
			||||||
 | 
							recover()
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						close(this.Chan)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var responseChanMap = map[int64]*CommandRequestWaiting{} // request id => response
 | 
				
			||||||
 | 
					var commandRequestId = int64(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var nodeLocker = &sync.Mutex{}
 | 
				
			||||||
 | 
					var requestChanMap = map[int64]chan *CommandRequest{} // node id => chan
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NextCommandRequestId() int64 {
 | 
				
			||||||
 | 
						return atomic.AddInt64(&commandRequestId, 1)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						// 清理WaitingChannelMap
 | 
				
			||||||
 | 
						ticker := time.NewTicker(30 * time.Second)
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							for range ticker.C {
 | 
				
			||||||
 | 
								nodeLocker.Lock()
 | 
				
			||||||
 | 
								for requestId, request := range responseChanMap {
 | 
				
			||||||
 | 
									if time.Now().Unix()-request.Timestamp > 3600 {
 | 
				
			||||||
 | 
										responseChanMap[requestId].Close()
 | 
				
			||||||
 | 
										delete(responseChanMap, requestId)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								nodeLocker.Unlock()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 节点stream
 | 
				
			||||||
 | 
					func (this *NodeService) NodeStream(server pb.NodeService_NodeStreamServer) error {
 | 
				
			||||||
 | 
						// TODO 使用此stream快速通知边缘节点更新
 | 
				
			||||||
 | 
						// 校验节点
 | 
				
			||||||
 | 
						_, nodeId, err := rpcutils.ValidateRequest(server.Context(), rpcutils.UserTypeNode)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 返回连接成功
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							apiConfig, err := configs.SharedAPIConfig()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							connectedMessage := &messageconfigs.ConnectedAPINodeMessage{APINodeId: apiConfig.NumberId()}
 | 
				
			||||||
 | 
							connectedMessageJSON, err := json.Marshal(connectedMessage)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return errors.Wrap(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							err = server.Send(&pb.NodeStreamMessage{
 | 
				
			||||||
 | 
								Code:     messageconfigs.MessageCodeConnectedAPINode,
 | 
				
			||||||
 | 
								DataJSON: connectedMessageJSON,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						logs.Println("[RPC]accepted node '" + strconv.FormatInt(nodeId, 10) + "' connection")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						nodeLocker.Lock()
 | 
				
			||||||
 | 
						requestChan, ok := requestChanMap[nodeId]
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							requestChan = make(chan *CommandRequest, 1024)
 | 
				
			||||||
 | 
							requestChanMap[nodeId] = requestChan
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						nodeLocker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 发送请求
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							for {
 | 
				
			||||||
 | 
								select {
 | 
				
			||||||
 | 
								case <-server.Context().Done():
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								case commandRequest := <-requestChan:
 | 
				
			||||||
 | 
									logs.Println("[RPC]sending command '" + commandRequest.Code + "' to node '" + strconv.FormatInt(nodeId, 10) + "'")
 | 
				
			||||||
 | 
									retries := 3 // 错误重试次数
 | 
				
			||||||
 | 
									for i := 0; i < retries; i++ {
 | 
				
			||||||
 | 
										err := server.Send(&pb.NodeStreamMessage{
 | 
				
			||||||
 | 
											RequestId: commandRequest.Id,
 | 
				
			||||||
 | 
											Code:      commandRequest.Code,
 | 
				
			||||||
 | 
											DataJSON:  commandRequest.CommandJSON,
 | 
				
			||||||
 | 
										})
 | 
				
			||||||
 | 
										if err != nil {
 | 
				
			||||||
 | 
											if i == retries-1 {
 | 
				
			||||||
 | 
												logs.Println("[RPC]send command '" + commandRequest.Code + "' failed: " + err.Error())
 | 
				
			||||||
 | 
											} else {
 | 
				
			||||||
 | 
												time.Sleep(1 * time.Second)
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
 | 
										} else {
 | 
				
			||||||
 | 
											break
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 接受请求
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							req, err := server.Recv()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							func(req *pb.NodeStreamMessage) {
 | 
				
			||||||
 | 
								// 因为 responseChan.Chan 有被关闭的风险,所以我们使用recover防止panic
 | 
				
			||||||
 | 
								defer func() {
 | 
				
			||||||
 | 
									recover()
 | 
				
			||||||
 | 
								}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								nodeLocker.Lock()
 | 
				
			||||||
 | 
								responseChan, ok := responseChanMap[req.RequestId]
 | 
				
			||||||
 | 
								if ok {
 | 
				
			||||||
 | 
									select {
 | 
				
			||||||
 | 
									case responseChan.Chan <- req:
 | 
				
			||||||
 | 
									default:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								nodeLocker.Unlock()
 | 
				
			||||||
 | 
							}(req)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 向节点发送命令
 | 
				
			||||||
 | 
					func (this *NodeService) SendCommandToNode(ctx context.Context, req *pb.NodeStreamMessage) (*pb.NodeStreamMessage, error) {
 | 
				
			||||||
 | 
						// 校验请求
 | 
				
			||||||
 | 
						_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						nodeId := req.NodeId
 | 
				
			||||||
 | 
						if nodeId <= 0 {
 | 
				
			||||||
 | 
							return nil, errors.New("node id should not be less than 0")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						nodeLocker.Lock()
 | 
				
			||||||
 | 
						requestChan, ok := requestChanMap[nodeId]
 | 
				
			||||||
 | 
						nodeLocker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return &pb.NodeStreamMessage{
 | 
				
			||||||
 | 
								RequestId: req.RequestId,
 | 
				
			||||||
 | 
								IsOk:      false,
 | 
				
			||||||
 | 
								Message:   "node '" + strconv.FormatInt(nodeId, 10) + "' not connected yet",
 | 
				
			||||||
 | 
							}, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req.RequestId = NextCommandRequestId()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case requestChan <- &CommandRequest{
 | 
				
			||||||
 | 
							Id:          req.RequestId,
 | 
				
			||||||
 | 
							Code:        req.Code,
 | 
				
			||||||
 | 
							CommandJSON: req.DataJSON,
 | 
				
			||||||
 | 
						}:
 | 
				
			||||||
 | 
							// 加入到等待队列中
 | 
				
			||||||
 | 
							respChan := make(chan *pb.NodeStreamMessage, 1)
 | 
				
			||||||
 | 
							waiting := &CommandRequestWaiting{
 | 
				
			||||||
 | 
								Timestamp: time.Now().Unix(),
 | 
				
			||||||
 | 
								Chan:      respChan,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							nodeLocker.Lock()
 | 
				
			||||||
 | 
							responseChanMap[req.RequestId] = waiting
 | 
				
			||||||
 | 
							nodeLocker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// 等待响应
 | 
				
			||||||
 | 
							timeoutSeconds := req.TimeoutSeconds
 | 
				
			||||||
 | 
							if timeoutSeconds <= 0 {
 | 
				
			||||||
 | 
								timeoutSeconds = 10
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							timeout := time.NewTimer(time.Duration(timeoutSeconds) * time.Second)
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case resp := <-respChan:
 | 
				
			||||||
 | 
								// 从队列中删除
 | 
				
			||||||
 | 
								nodeLocker.Lock()
 | 
				
			||||||
 | 
								delete(responseChanMap, req.RequestId)
 | 
				
			||||||
 | 
								waiting.Close()
 | 
				
			||||||
 | 
								nodeLocker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if resp == nil {
 | 
				
			||||||
 | 
									return &pb.NodeStreamMessage{
 | 
				
			||||||
 | 
										RequestId: req.RequestId,
 | 
				
			||||||
 | 
										Code:      req.Code,
 | 
				
			||||||
 | 
										Message:   "response timeout",
 | 
				
			||||||
 | 
										IsOk:      false,
 | 
				
			||||||
 | 
									}, nil
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								return resp, nil
 | 
				
			||||||
 | 
							case <-timeout.C:
 | 
				
			||||||
 | 
								// 从队列中删除
 | 
				
			||||||
 | 
								nodeLocker.Lock()
 | 
				
			||||||
 | 
								delete(responseChanMap, req.RequestId)
 | 
				
			||||||
 | 
								waiting.Close()
 | 
				
			||||||
 | 
								nodeLocker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								return &pb.NodeStreamMessage{
 | 
				
			||||||
 | 
									RequestId: req.RequestId,
 | 
				
			||||||
 | 
									Code:      req.Code,
 | 
				
			||||||
 | 
									Message:   "response timeout over " + fmt.Sprintf("%d", timeoutSeconds) + " seconds",
 | 
				
			||||||
 | 
									IsOk:      false,
 | 
				
			||||||
 | 
								}, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							return &pb.NodeStreamMessage{
 | 
				
			||||||
 | 
								RequestId: req.RequestId,
 | 
				
			||||||
 | 
								Code:      req.Code,
 | 
				
			||||||
 | 
								Message:   "command queue is full over " + strconv.Itoa(len(requestChan)),
 | 
				
			||||||
 | 
								IsOk:      false,
 | 
				
			||||||
 | 
							}, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Reference in New Issue
	
	Block a user