mirror of
				https://github.com/TeaOSLab/EdgeAPI.git
				synced 2025-11-04 16:00:24 +08:00 
			
		
		
		
	实现缓存策略的部分功能
This commit is contained in:
		@@ -1,9 +1,9 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/apis"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/apps"
 | 
			
		||||
	teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/nodes"
 | 
			
		||||
	_ "github.com/TeaOSLab/EdgeAPI/internal/tasks"
 | 
			
		||||
	_ "github.com/iwind/TeaGo/bootstrap"
 | 
			
		||||
)
 | 
			
		||||
@@ -14,6 +14,6 @@ func main() {
 | 
			
		||||
	app.Product(teaconst.ProductName)
 | 
			
		||||
	app.Usage(teaconst.ProcessName + " [start|stop|restart]")
 | 
			
		||||
	app.Run(func() {
 | 
			
		||||
		apis.NewAPINode().Start()
 | 
			
		||||
		nodes.NewAPINode().Start()
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,83 +0,0 @@
 | 
			
		||||
package apis
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/configs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/utils"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
			
		||||
	"github.com/iwind/TeaGo/logs"
 | 
			
		||||
	"google.golang.org/grpc"
 | 
			
		||||
	"net"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var sharedAPIConfig *configs.APIConfig = nil
 | 
			
		||||
 | 
			
		||||
type APINode struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewAPINode() *APINode {
 | 
			
		||||
	return &APINode{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *APINode) Start() {
 | 
			
		||||
	logs.Println("[API]start api node, pid: " + strconv.Itoa(os.Getpid()))
 | 
			
		||||
 | 
			
		||||
	config, err := configs.SharedAPIConfig()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Println("[API]start failed: " + err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	sharedAPIConfig = config
 | 
			
		||||
 | 
			
		||||
	// 设置rlimit
 | 
			
		||||
	_ = utils.SetRLimit(1024 * 1024)
 | 
			
		||||
 | 
			
		||||
	// 监听RPC服务
 | 
			
		||||
	logs.Println("[API]start rpc: " + config.RPC.Listen)
 | 
			
		||||
	err = this.listenRPC()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Println(err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 启动RPC监听
 | 
			
		||||
func (this *APINode) listenRPC() error {
 | 
			
		||||
	listener, err := net.Listen("tcp", sharedAPIConfig.RPC.Listen)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("[API]listen rpc failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	rpcServer := grpc.NewServer()
 | 
			
		||||
	pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{})
 | 
			
		||||
	pb.RegisterNodeGrantServiceServer(rpcServer, &services.NodeGrantService{})
 | 
			
		||||
	pb.RegisterServerServiceServer(rpcServer, &services.ServerService{})
 | 
			
		||||
	pb.RegisterNodeServiceServer(rpcServer, &services.NodeService{})
 | 
			
		||||
	pb.RegisterNodeClusterServiceServer(rpcServer, &services.NodeClusterService{})
 | 
			
		||||
	pb.RegisterNodeIPAddressServiceServer(rpcServer, &services.NodeIPAddressService{})
 | 
			
		||||
	pb.RegisterAPINodeServiceServer(rpcServer, &services.APINodeService{})
 | 
			
		||||
	pb.RegisterOriginServiceServer(rpcServer, &services.OriginService{})
 | 
			
		||||
	pb.RegisterHTTPWebServiceServer(rpcServer, &services.HTTPWebService{})
 | 
			
		||||
	pb.RegisterReverseProxyServiceServer(rpcServer, &services.ReverseProxyService{})
 | 
			
		||||
	pb.RegisterHTTPGzipServiceServer(rpcServer, &services.HTTPGzipService{})
 | 
			
		||||
	pb.RegisterHTTPHeaderPolicyServiceServer(rpcServer, &services.HTTPHeaderPolicyService{})
 | 
			
		||||
	pb.RegisterHTTPHeaderServiceServer(rpcServer, &services.HTTPHeaderService{})
 | 
			
		||||
	pb.RegisterHTTPPageServiceServer(rpcServer, &services.HTTPPageService{})
 | 
			
		||||
	pb.RegisterHTTPAccessLogPolicyServiceServer(rpcServer, &services.HTTPAccessLogPolicyService{})
 | 
			
		||||
	pb.RegisterHTTPCachePolicyServiceServer(rpcServer, &services.HTTPCachePolicyService{})
 | 
			
		||||
	pb.RegisterHTTPFirewallPolicyServiceServer(rpcServer, &services.HTTPFirewallPolicyService{})
 | 
			
		||||
	pb.RegisterHTTPLocationServiceServer(rpcServer, &services.HTTPLocationService{})
 | 
			
		||||
	pb.RegisterHTTPWebsocketServiceServer(rpcServer, &services.HTTPWebsocketService{})
 | 
			
		||||
	pb.RegisterHTTPRewriteRuleServiceServer(rpcServer, &services.HTTPRewriteRuleService{})
 | 
			
		||||
	pb.RegisterSSLCertServiceServer(rpcServer, &services.SSLCertService{})
 | 
			
		||||
	pb.RegisterSSLPolicyServiceServer(rpcServer, &services.SSLPolicyService{})
 | 
			
		||||
	pb.RegisterSysSettingServiceServer(rpcServer, &services.SysSettingService{})
 | 
			
		||||
	err = rpcServer.Serve(listener)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("[API]start rpc failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -8,12 +8,15 @@ import (
 | 
			
		||||
 | 
			
		||||
var sharedAPIConfig *APIConfig = nil
 | 
			
		||||
 | 
			
		||||
// API节点配置
 | 
			
		||||
type APIConfig struct {
 | 
			
		||||
	RPC struct {
 | 
			
		||||
		Listen string `yaml:"listen"`
 | 
			
		||||
	} `yaml:"rpc"`
 | 
			
		||||
	NodeId string `yaml:"nodeId" json:"nodeId"`
 | 
			
		||||
	Secret string `yaml:"secret" json:"secret"`
 | 
			
		||||
 | 
			
		||||
	numberId int64 // 数字ID
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取共享配置
 | 
			
		||||
func SharedAPIConfig() (*APIConfig, error) {
 | 
			
		||||
	sharedLocker.Lock()
 | 
			
		||||
	defer sharedLocker.Unlock()
 | 
			
		||||
@@ -36,3 +39,13 @@ func SharedAPIConfig() (*APIConfig, error) {
 | 
			
		||||
	sharedAPIConfig = config
 | 
			
		||||
	return config, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 设置数字ID
 | 
			
		||||
func (this *APIConfig) SetNumberId(numberId int64) {
 | 
			
		||||
	this.numberId = numberId
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取数字ID
 | 
			
		||||
func (this *APIConfig) NumberId() int64 {
 | 
			
		||||
	return this.numberId
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -59,6 +59,19 @@ func (this *APINodeDAO) FindEnabledAPINode(id int64) (*APINode, error) {
 | 
			
		||||
	return result.(*APINode), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 根据ID和Secret查找节点
 | 
			
		||||
func (this *APINodeDAO) FindEnabledAPINodeWithUniqueIdAndSecret(uniqueId string, secret string) (*APINode, error) {
 | 
			
		||||
	one, err := this.Query().
 | 
			
		||||
		State(APINodeStateEnabled).
 | 
			
		||||
		Attr("uniqueId", uniqueId).
 | 
			
		||||
		Attr("secret", secret).
 | 
			
		||||
		Find()
 | 
			
		||||
	if err != nil || one == nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return one.(*APINode), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 根据主键查找名称
 | 
			
		||||
func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) {
 | 
			
		||||
	return this.Query().
 | 
			
		||||
@@ -68,7 +81,7 @@ func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 创建API节点
 | 
			
		||||
func (this *APINodeDAO) CreateAPINode(name string, description string, host string, port int) (nodeId int64, err error) {
 | 
			
		||||
func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) {
 | 
			
		||||
	uniqueId, err := this.genUniqueId()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
@@ -80,13 +93,22 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, host stri
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	op := NewAPINodeOperator()
 | 
			
		||||
	op.IsOn = true
 | 
			
		||||
	op.IsOn = isOn
 | 
			
		||||
	op.UniqueId = uniqueId
 | 
			
		||||
	op.Secret = secret
 | 
			
		||||
	op.Name = name
 | 
			
		||||
	op.Description = description
 | 
			
		||||
	op.Host = host
 | 
			
		||||
	op.Port = port
 | 
			
		||||
 | 
			
		||||
	if len(httpJSON) > 0 {
 | 
			
		||||
		op.Http = httpJSON
 | 
			
		||||
	}
 | 
			
		||||
	if len(httpsJSON) > 0 {
 | 
			
		||||
		op.Https = httpsJSON
 | 
			
		||||
	}
 | 
			
		||||
	if len(accessAddrsJSON) > 0 {
 | 
			
		||||
		op.AccessAddrs = accessAddrsJSON
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	op.State = NodeStateEnabled
 | 
			
		||||
	_, err = this.Save(op)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -97,7 +119,7 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, host stri
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 修改API节点
 | 
			
		||||
func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, host string, port int) error {
 | 
			
		||||
func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) error {
 | 
			
		||||
	if nodeId <= 0 {
 | 
			
		||||
		return errors.New("invalid nodeId")
 | 
			
		||||
	}
 | 
			
		||||
@@ -106,8 +128,24 @@ func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description str
 | 
			
		||||
	op.Id = nodeId
 | 
			
		||||
	op.Name = name
 | 
			
		||||
	op.Description = description
 | 
			
		||||
	op.Host = host
 | 
			
		||||
	op.Port = port
 | 
			
		||||
	op.IsOn = isOn
 | 
			
		||||
 | 
			
		||||
	if len(httpJSON) > 0 {
 | 
			
		||||
		op.Http = httpJSON
 | 
			
		||||
	} else {
 | 
			
		||||
		op.Http = "null"
 | 
			
		||||
	}
 | 
			
		||||
	if len(httpsJSON) > 0 {
 | 
			
		||||
		op.Https = httpsJSON
 | 
			
		||||
	} else {
 | 
			
		||||
		op.Https = "null"
 | 
			
		||||
	}
 | 
			
		||||
	if len(accessAddrsJSON) > 0 {
 | 
			
		||||
		op.AccessAddrs = accessAddrsJSON
 | 
			
		||||
	} else {
 | 
			
		||||
		op.AccessAddrs = "null"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err := this.Save(op)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -9,8 +9,9 @@ type APINode struct {
 | 
			
		||||
	Secret      string `field:"secret"`      // 密钥
 | 
			
		||||
	Name        string `field:"name"`        // 名称
 | 
			
		||||
	Description string `field:"description"` // 描述
 | 
			
		||||
	Host        string `field:"host"`        // 主机
 | 
			
		||||
	Port        uint32 `field:"port"`        // 端口
 | 
			
		||||
	Http        string `field:"http"`        // 监听的HTTP配置
 | 
			
		||||
	Https       string `field:"https"`       // 监听的HTTPS配置
 | 
			
		||||
	AccessAddrs string `field:"accessAddrs"` // 外部访问地址
 | 
			
		||||
	Order       uint32 `field:"order"`       // 排序
 | 
			
		||||
	State       uint8  `field:"state"`       // 状态
 | 
			
		||||
	CreatedAt   uint64 `field:"createdAt"`   // 创建时间
 | 
			
		||||
@@ -26,8 +27,9 @@ type APINodeOperator struct {
 | 
			
		||||
	Secret      interface{} // 密钥
 | 
			
		||||
	Name        interface{} // 名称
 | 
			
		||||
	Description interface{} // 描述
 | 
			
		||||
	Host        interface{} // 主机
 | 
			
		||||
	Port        interface{} // 端口
 | 
			
		||||
	Http        interface{} // 监听的HTTP配置
 | 
			
		||||
	Https       interface{} // 监听的HTTPS配置
 | 
			
		||||
	AccessAddrs interface{} // 外部访问地址
 | 
			
		||||
	Order       interface{} // 排序
 | 
			
		||||
	State       interface{} // 状态
 | 
			
		||||
	CreatedAt   interface{} // 创建时间
 | 
			
		||||
 
 | 
			
		||||
@@ -1,8 +1,95 @@
 | 
			
		||||
package models
 | 
			
		||||
 | 
			
		||||
import "strconv"
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 地址
 | 
			
		||||
func (this *APINode) Address() string {
 | 
			
		||||
	return this.Host + ":" + strconv.Itoa(int(this.Port))
 | 
			
		||||
// 解析HTTP配置
 | 
			
		||||
func (this *APINode) DecodeHTTP() (*serverconfigs.HTTPProtocolConfig, error) {
 | 
			
		||||
	if !IsNotNull(this.Http) {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	config := &serverconfigs.HTTPProtocolConfig{}
 | 
			
		||||
	err := json.Unmarshal([]byte(this.Http), config)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = config.Init()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return config, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 解析HTTPS配置
 | 
			
		||||
func (this *APINode) DecodeHTTPS() (*serverconfigs.HTTPSProtocolConfig, error) {
 | 
			
		||||
	if !IsNotNull(this.Https) {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	config := &serverconfigs.HTTPSProtocolConfig{}
 | 
			
		||||
	err := json.Unmarshal([]byte(this.Https), config)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = config.Init()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if config.SSLPolicyRef != nil {
 | 
			
		||||
		policyId := config.SSLPolicyRef.SSLPolicyId
 | 
			
		||||
		if policyId > 0 {
 | 
			
		||||
			sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(policyId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
			if sslPolicy != nil {
 | 
			
		||||
				config.SSLPolicy = sslPolicy
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = config.Init()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return config, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 解析访问地址
 | 
			
		||||
func (this *APINode) DecodeAccessAddrs() ([]*serverconfigs.NetworkAddressConfig, error) {
 | 
			
		||||
	if !IsNotNull(this.AccessAddrs) {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	addrConfigs := []*serverconfigs.NetworkAddressConfig{}
 | 
			
		||||
	err := json.Unmarshal([]byte(this.AccessAddrs), &addrConfigs)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	for _, addrConfig := range addrConfigs {
 | 
			
		||||
		err = addrConfig.Init()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return addrConfigs, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 解析访问地址,并返回字符串形式
 | 
			
		||||
func (this *APINode) DecodeAccessAddrStrings() ([]string, error) {
 | 
			
		||||
	addrs, err := this.DecodeAccessAddrs()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	result := []string{}
 | 
			
		||||
	for _, addr := range addrs {
 | 
			
		||||
		result = append(result, addr.FullAddresses()...)
 | 
			
		||||
	}
 | 
			
		||||
	return result, nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,7 @@ package models
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
 | 
			
		||||
@@ -275,6 +275,17 @@ func (this *NodeDAO) FindAllNodeIdsMatch(clusterId int64) (result []int64, err e
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取一个集群的所有节点
 | 
			
		||||
func (this *NodeDAO) FindAllEnabledNodesWithClusterId(clusterId int64) (result []*Node, err error) {
 | 
			
		||||
	_, err = this.Query().
 | 
			
		||||
		State(NodeStateEnabled).
 | 
			
		||||
		Attr("clusterId", clusterId).
 | 
			
		||||
		DescPk().
 | 
			
		||||
		Slice(&result).
 | 
			
		||||
		FindAll()
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 计算节点数量
 | 
			
		||||
func (this *NodeDAO) CountAllEnabledNodesMatch(clusterId int64, installState configutils.BoolState, activeState configutils.BoolState) (int64, error) {
 | 
			
		||||
	query := this.Query()
 | 
			
		||||
@@ -422,6 +433,28 @@ func (this *NodeDAO) ComposeNodeConfig(nodeId int64) (*nodeconfigs.NodeConfig, e
 | 
			
		||||
	return config, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 修改当前连接的API节点
 | 
			
		||||
func (this *NodeDAO) UpdateNodeConnectedAPINodes(nodeId int64, apiNodeIds []int64) error {
 | 
			
		||||
	if nodeId <= 0 {
 | 
			
		||||
		return errors.New("invalid nodeId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	op := NewNodeOperator()
 | 
			
		||||
	op.Id = nodeId
 | 
			
		||||
 | 
			
		||||
	if len(apiNodeIds) > 0 {
 | 
			
		||||
		apiNodeIdsJSON, err := json.Marshal(apiNodeIds)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errors.Wrap(err)
 | 
			
		||||
		}
 | 
			
		||||
		op.ConnectedAPINodes = apiNodeIdsJSON
 | 
			
		||||
	} else {
 | 
			
		||||
		op.ConnectedAPINodes = "[]"
 | 
			
		||||
	}
 | 
			
		||||
	_, err := this.Save(op)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 生成唯一ID
 | 
			
		||||
func (this *NodeDAO) genUniqueId() (string, error) {
 | 
			
		||||
	for {
 | 
			
		||||
 
 | 
			
		||||
@@ -21,6 +21,7 @@ type Node struct {
 | 
			
		||||
	IsInstalled       uint8  `field:"isInstalled"`       // 是否已安装
 | 
			
		||||
	InstallStatus     string `field:"installStatus"`     // 安装状态
 | 
			
		||||
	State             uint8  `field:"state"`             // 状态
 | 
			
		||||
	ConnectedAPINodes string `field:"connectedAPINodes"` // 当前连接的API节点
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NodeOperator struct {
 | 
			
		||||
@@ -43,6 +44,7 @@ type NodeOperator struct {
 | 
			
		||||
	IsInstalled       interface{} // 是否已安装
 | 
			
		||||
	InstallStatus     interface{} // 安装状态
 | 
			
		||||
	State             interface{} // 状态
 | 
			
		||||
	ConnectedAPINodes interface{} // 当前连接的API节点
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewNodeOperator() *NodeOperator {
 | 
			
		||||
 
 | 
			
		||||
@@ -136,7 +136,13 @@ func (this *Queue) InstallNode(nodeId int64) error {
 | 
			
		||||
 | 
			
		||||
	apiEndpoints := []string{}
 | 
			
		||||
	for _, apiNode := range apiNodes {
 | 
			
		||||
		apiEndpoints = append(apiEndpoints, apiNode.Host+":"+strconv.Itoa(int(apiNode.Port)))
 | 
			
		||||
		addrConfigs, err := apiNode.DecodeAccessAddrs()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errors.New("decode api node access addresses failed: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
		for _, addrConfig := range addrConfigs {
 | 
			
		||||
			apiEndpoints = append(apiEndpoints, addrConfig.FullAddresses()...)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	params := &NodeParams{
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										170
									
								
								internal/nodes/api_node.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										170
									
								
								internal/nodes/api_node.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,170 @@
 | 
			
		||||
package nodes
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/configs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/db/models"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/utils"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
			
		||||
	"github.com/iwind/TeaGo/logs"
 | 
			
		||||
	"google.golang.org/grpc"
 | 
			
		||||
	"google.golang.org/grpc/credentials"
 | 
			
		||||
	"net"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var sharedAPIConfig *configs.APIConfig = nil
 | 
			
		||||
 | 
			
		||||
type APINode struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewAPINode() *APINode {
 | 
			
		||||
	return &APINode{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *APINode) Start() {
 | 
			
		||||
	logs.Println("[API]start api node, pid: " + strconv.Itoa(os.Getpid()))
 | 
			
		||||
 | 
			
		||||
	// 读取配置
 | 
			
		||||
	config, err := configs.SharedAPIConfig()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Println("[API]start failed: " + err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	sharedAPIConfig = config
 | 
			
		||||
 | 
			
		||||
	// 校验
 | 
			
		||||
	apiNode, err := models.SharedAPINodeDAO.FindEnabledAPINodeWithUniqueIdAndSecret(config.NodeId, config.Secret)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Println("[API]start failed: read api node from database failed: " + err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if apiNode == nil {
 | 
			
		||||
		logs.Println("[API]can not start node, wrong 'nodeId' or 'secret'")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	config.SetNumberId(int64(apiNode.Id))
 | 
			
		||||
 | 
			
		||||
	// 设置rlimit
 | 
			
		||||
	_ = utils.SetRLimit(1024 * 1024)
 | 
			
		||||
 | 
			
		||||
	// 监听RPC服务
 | 
			
		||||
	logs.Println("[API]starting rpc ...")
 | 
			
		||||
 | 
			
		||||
	// HTTP
 | 
			
		||||
	httpConfig, err := apiNode.DecodeHTTP()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Println("[API]decode http config: " + err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	isListening := false
 | 
			
		||||
	if httpConfig != nil && httpConfig.IsOn && len(httpConfig.Listen) > 0 {
 | 
			
		||||
		for _, listen := range httpConfig.Listen {
 | 
			
		||||
			for _, addr := range listen.Addresses() {
 | 
			
		||||
				listener, err := net.Listen("tcp", addr)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logs.Println("[API]listening '" + addr + "' failed: " + err.Error())
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				go func() {
 | 
			
		||||
					err := this.listenRPC(listener, nil)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						logs.Println("[API]listening '" + addr + "' rpc: " + err.Error())
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
				}()
 | 
			
		||||
				isListening = true
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// HTTPS
 | 
			
		||||
	httpsConfig, err := apiNode.DecodeHTTPS()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Println("[API]decode https config: " + err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if httpsConfig != nil &&
 | 
			
		||||
		httpsConfig.IsOn &&
 | 
			
		||||
		len(httpsConfig.Listen) > 0 &&
 | 
			
		||||
		httpsConfig.SSLPolicy != nil &&
 | 
			
		||||
		httpsConfig.SSLPolicy.IsOn &&
 | 
			
		||||
		len(httpsConfig.SSLPolicy.Certs) > 0 {
 | 
			
		||||
		certs := []tls.Certificate{}
 | 
			
		||||
		for _, cert := range httpsConfig.SSLPolicy.Certs {
 | 
			
		||||
			certs = append(certs, *cert.CertObject())
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, listen := range httpsConfig.Listen {
 | 
			
		||||
			for _, addr := range listen.Addresses() {
 | 
			
		||||
				listener, err := net.Listen("tcp", addr)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logs.Println("[API]listening '" + addr + "' failed: " + err.Error())
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				go func() {
 | 
			
		||||
					err := this.listenRPC(listener, &tls.Config{
 | 
			
		||||
						Certificates: certs,
 | 
			
		||||
					})
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						logs.Println("[API]listening '" + addr + "' rpc: " + err.Error())
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
				}()
 | 
			
		||||
				isListening = true
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !isListening {
 | 
			
		||||
		logs.Println("[API]the api node does have a listening address")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 保持进程
 | 
			
		||||
	select {}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 启动RPC监听
 | 
			
		||||
func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) error {
 | 
			
		||||
	var rpcServer *grpc.Server
 | 
			
		||||
	if tlsConfig == nil {
 | 
			
		||||
		logs.Println("[API]listening http://" + listener.Addr().String() + " ...")
 | 
			
		||||
		rpcServer = grpc.NewServer()
 | 
			
		||||
	} else {
 | 
			
		||||
		logs.Println("[API]listening https://" + listener.Addr().String() + " ...")
 | 
			
		||||
		rpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
 | 
			
		||||
	}
 | 
			
		||||
	pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{})
 | 
			
		||||
	pb.RegisterNodeGrantServiceServer(rpcServer, &services.NodeGrantService{})
 | 
			
		||||
	pb.RegisterServerServiceServer(rpcServer, &services.ServerService{})
 | 
			
		||||
	pb.RegisterNodeServiceServer(rpcServer, &services.NodeService{})
 | 
			
		||||
	pb.RegisterNodeClusterServiceServer(rpcServer, &services.NodeClusterService{})
 | 
			
		||||
	pb.RegisterNodeIPAddressServiceServer(rpcServer, &services.NodeIPAddressService{})
 | 
			
		||||
	pb.RegisterAPINodeServiceServer(rpcServer, &services.APINodeService{})
 | 
			
		||||
	pb.RegisterOriginServiceServer(rpcServer, &services.OriginService{})
 | 
			
		||||
	pb.RegisterHTTPWebServiceServer(rpcServer, &services.HTTPWebService{})
 | 
			
		||||
	pb.RegisterReverseProxyServiceServer(rpcServer, &services.ReverseProxyService{})
 | 
			
		||||
	pb.RegisterHTTPGzipServiceServer(rpcServer, &services.HTTPGzipService{})
 | 
			
		||||
	pb.RegisterHTTPHeaderPolicyServiceServer(rpcServer, &services.HTTPHeaderPolicyService{})
 | 
			
		||||
	pb.RegisterHTTPHeaderServiceServer(rpcServer, &services.HTTPHeaderService{})
 | 
			
		||||
	pb.RegisterHTTPPageServiceServer(rpcServer, &services.HTTPPageService{})
 | 
			
		||||
	pb.RegisterHTTPAccessLogPolicyServiceServer(rpcServer, &services.HTTPAccessLogPolicyService{})
 | 
			
		||||
	pb.RegisterHTTPCachePolicyServiceServer(rpcServer, &services.HTTPCachePolicyService{})
 | 
			
		||||
	pb.RegisterHTTPFirewallPolicyServiceServer(rpcServer, &services.HTTPFirewallPolicyService{})
 | 
			
		||||
	pb.RegisterHTTPLocationServiceServer(rpcServer, &services.HTTPLocationService{})
 | 
			
		||||
	pb.RegisterHTTPWebsocketServiceServer(rpcServer, &services.HTTPWebsocketService{})
 | 
			
		||||
	pb.RegisterHTTPRewriteRuleServiceServer(rpcServer, &services.HTTPRewriteRuleService{})
 | 
			
		||||
	pb.RegisterSSLCertServiceServer(rpcServer, &services.SSLCertService{})
 | 
			
		||||
	pb.RegisterSSLPolicyServiceServer(rpcServer, &services.SSLPolicyService{})
 | 
			
		||||
	pb.RegisterSysSettingServiceServer(rpcServer, &services.SysSettingService{})
 | 
			
		||||
	err := rpcServer.Serve(listener)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("[API]start rpc failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -17,7 +17,7 @@ func (this *APINodeService) CreateAPINode(ctx context.Context, req *pb.CreateAPI
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.Host, int(req.Port))
 | 
			
		||||
	nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -32,7 +32,7 @@ func (this *APINodeService) UpdateAPINode(ctx context.Context, req *pb.UpdateAPI
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.Host, int(req.Port))
 | 
			
		||||
	err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -69,6 +69,11 @@ func (this *APINodeService) FindAllEnabledAPINodes(ctx context.Context, req *pb.
 | 
			
		||||
 | 
			
		||||
	result := []*pb.APINode{}
 | 
			
		||||
	for _, node := range nodes {
 | 
			
		||||
		accessAddrs, err := node.DecodeAccessAddrStrings()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		result = append(result, &pb.APINode{
 | 
			
		||||
			Id:              int64(node.Id),
 | 
			
		||||
			IsOn:            node.IsOn == 1,
 | 
			
		||||
@@ -77,9 +82,10 @@ func (this *APINodeService) FindAllEnabledAPINodes(ctx context.Context, req *pb.
 | 
			
		||||
			Secret:          node.Secret,
 | 
			
		||||
			Name:            node.Name,
 | 
			
		||||
			Description:     node.Description,
 | 
			
		||||
			Host:        node.Host,
 | 
			
		||||
			Port:        int32(node.Port),
 | 
			
		||||
			Address:     node.Address(),
 | 
			
		||||
			HttpJSON:        []byte(node.Http),
 | 
			
		||||
			HttpsJSON:       []byte(node.Https),
 | 
			
		||||
			AccessAddrsJSON: []byte(node.AccessAddrs),
 | 
			
		||||
			AccessAddrs:     accessAddrs,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -115,6 +121,11 @@ func (this *APINodeService) ListEnabledAPINodes(ctx context.Context, req *pb.Lis
 | 
			
		||||
 | 
			
		||||
	result := []*pb.APINode{}
 | 
			
		||||
	for _, node := range nodes {
 | 
			
		||||
		accessAddrs, err := node.DecodeAccessAddrStrings()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		result = append(result, &pb.APINode{
 | 
			
		||||
			Id:              int64(node.Id),
 | 
			
		||||
			IsOn:            node.IsOn == 1,
 | 
			
		||||
@@ -123,9 +134,10 @@ func (this *APINodeService) ListEnabledAPINodes(ctx context.Context, req *pb.Lis
 | 
			
		||||
			Secret:          node.Secret,
 | 
			
		||||
			Name:            node.Name,
 | 
			
		||||
			Description:     node.Description,
 | 
			
		||||
			Host:        node.Host,
 | 
			
		||||
			Port:        int32(node.Port),
 | 
			
		||||
			Address:     node.Address(),
 | 
			
		||||
			HttpJSON:        []byte(node.Http),
 | 
			
		||||
			HttpsJSON:       []byte(node.Https),
 | 
			
		||||
			AccessAddrsJSON: []byte(node.AccessAddrs),
 | 
			
		||||
			AccessAddrs:     accessAddrs,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -148,6 +160,11 @@ func (this *APINodeService) FindEnabledAPINode(ctx context.Context, req *pb.Find
 | 
			
		||||
		return &pb.FindEnabledAPINodeResponse{Node: nil}, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	accessAddrs, err := node.DecodeAccessAddrStrings()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := &pb.APINode{
 | 
			
		||||
		Id:              int64(node.Id),
 | 
			
		||||
		IsOn:            node.IsOn == 1,
 | 
			
		||||
@@ -156,9 +173,10 @@ func (this *APINodeService) FindEnabledAPINode(ctx context.Context, req *pb.Find
 | 
			
		||||
		Secret:          node.Secret,
 | 
			
		||||
		Name:            node.Name,
 | 
			
		||||
		Description:     node.Description,
 | 
			
		||||
		Host:        node.Host,
 | 
			
		||||
		Port:        int32(node.Port),
 | 
			
		||||
		Address:     node.Address(),
 | 
			
		||||
		HttpJSON:        []byte(node.Http),
 | 
			
		||||
		HttpsJSON:       []byte(node.Https),
 | 
			
		||||
		AccessAddrsJSON: []byte(node.AccessAddrs),
 | 
			
		||||
		AccessAddrs:     accessAddrs,
 | 
			
		||||
	}
 | 
			
		||||
	return &pb.FindEnabledAPINodeResponse{Node: result}, nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,8 +3,8 @@ package services
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/db/models"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/installers"
 | 
			
		||||
	rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
 | 
			
		||||
@@ -12,6 +12,7 @@ import (
 | 
			
		||||
	"github.com/iwind/TeaGo/logs"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 边缘节点相关服务
 | 
			
		||||
type NodeService struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -121,6 +122,38 @@ func (this *NodeService) ListEnabledNodesMatch(ctx context.Context, req *pb.List
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 查找一个集群下的所有节点
 | 
			
		||||
func (this *NodeService) FindAllEnabledNodesWithClusterId(ctx context.Context, req *pb.FindAllEnabledNodesWithClusterIdRequest) (*pb.FindAllEnabledNodesWithClusterIdResponse, error) {
 | 
			
		||||
	_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nodes, err := models.SharedNodeDAO.FindAllEnabledNodesWithClusterId(req.ClusterId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	result := []*pb.Node{}
 | 
			
		||||
	for _, node := range nodes {
 | 
			
		||||
		apiNodeIds := []int64{}
 | 
			
		||||
		if models.IsNotNull(node.ConnectedAPINodes) {
 | 
			
		||||
			err = json.Unmarshal([]byte(node.ConnectedAPINodes), &apiNodeIds)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		result = append(result, &pb.Node{
 | 
			
		||||
			Id:                  int64(node.Id),
 | 
			
		||||
			Name:                node.Name,
 | 
			
		||||
			UniqueId:            node.UniqueId,
 | 
			
		||||
			Secret:              node.Secret,
 | 
			
		||||
			ConnectedAPINodeIds: apiNodeIds,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	return &pb.FindAllEnabledNodesWithClusterIdResponse{Nodes: result}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 禁用节点
 | 
			
		||||
func (this *NodeService) DisableNode(ctx context.Context, req *pb.DisableNodeRequest) (*pb.DisableNodeResponse, error) {
 | 
			
		||||
	_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
 | 
			
		||||
@@ -263,27 +296,6 @@ func (this *NodeService) ComposeNodeConfig(ctx context.Context, req *pb.ComposeN
 | 
			
		||||
	return &pb.ComposeNodeConfigResponse{NodeJSON: data}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 节点stream
 | 
			
		||||
func (this *NodeService) NodeStream(server pb.NodeService_NodeStreamServer) error {
 | 
			
		||||
	// TODO 使用此stream快速通知边缘节点更新
 | 
			
		||||
	// 校验节点
 | 
			
		||||
	_, nodeId, err := rpcutils.ValidateRequest(server.Context(), rpcutils.UserTypeNode)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	logs.Println("nodeId:", nodeId)
 | 
			
		||||
 | 
			
		||||
	_ = server.Send(&pb.NodeStreamResponse{})
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		req, err := server.Recv()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		logs.Println("received:", req)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 更新节点状态
 | 
			
		||||
func (this *NodeService) UpdateNodeStatus(ctx context.Context, req *pb.UpdateNodeStatusRequest) (*pb.RPCUpdateSuccess, error) {
 | 
			
		||||
	// 校验节点
 | 
			
		||||
@@ -354,3 +366,19 @@ func (this *NodeService) InstallNode(ctx context.Context, req *pb.InstallNodeReq
 | 
			
		||||
 | 
			
		||||
	return &pb.InstallNodeResponse{}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 更改节点连接的API节点信息
 | 
			
		||||
func (this *NodeService) UpdateNodeConnectedAPINodes(ctx context.Context, req *pb.UpdateNodeConnectedAPINodesRequest) (*pb.RPCUpdateSuccess, error) {
 | 
			
		||||
	// 校验节点
 | 
			
		||||
	_, nodeId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeNode)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = models.SharedNodeDAO.UpdateNodeConnectedAPINodes(nodeId, req.ApiNodeIds)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return rpcutils.RPCUpdateSuccess()
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										251
									
								
								internal/rpc/services/service_node_stream.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										251
									
								
								internal/rpc/services/service_node_stream.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,251 @@
 | 
			
		||||
package services
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/configs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/errors"
 | 
			
		||||
	rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
			
		||||
	"github.com/iwind/TeaGo/logs"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 命令请求相关
 | 
			
		||||
type CommandRequest struct {
 | 
			
		||||
	Id          int64
 | 
			
		||||
	Code        string
 | 
			
		||||
	CommandJSON []byte
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CommandRequestWaiting struct {
 | 
			
		||||
	Timestamp int64
 | 
			
		||||
	Chan      chan *pb.NodeStreamMessage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *CommandRequestWaiting) Close() {
 | 
			
		||||
	defer func() {
 | 
			
		||||
		recover()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	close(this.Chan)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var responseChanMap = map[int64]*CommandRequestWaiting{} // request id => response
 | 
			
		||||
var commandRequestId = int64(0)
 | 
			
		||||
 | 
			
		||||
var nodeLocker = &sync.Mutex{}
 | 
			
		||||
var requestChanMap = map[int64]chan *CommandRequest{} // node id => chan
 | 
			
		||||
 | 
			
		||||
func NextCommandRequestId() int64 {
 | 
			
		||||
	return atomic.AddInt64(&commandRequestId, 1)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	// 清理WaitingChannelMap
 | 
			
		||||
	ticker := time.NewTicker(30 * time.Second)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for range ticker.C {
 | 
			
		||||
			nodeLocker.Lock()
 | 
			
		||||
			for requestId, request := range responseChanMap {
 | 
			
		||||
				if time.Now().Unix()-request.Timestamp > 3600 {
 | 
			
		||||
					responseChanMap[requestId].Close()
 | 
			
		||||
					delete(responseChanMap, requestId)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			nodeLocker.Unlock()
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 节点stream
 | 
			
		||||
func (this *NodeService) NodeStream(server pb.NodeService_NodeStreamServer) error {
 | 
			
		||||
	// TODO 使用此stream快速通知边缘节点更新
 | 
			
		||||
	// 校验节点
 | 
			
		||||
	_, nodeId, err := rpcutils.ValidateRequest(server.Context(), rpcutils.UserTypeNode)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 返回连接成功
 | 
			
		||||
	{
 | 
			
		||||
		apiConfig, err := configs.SharedAPIConfig()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		connectedMessage := &messageconfigs.ConnectedAPINodeMessage{APINodeId: apiConfig.NumberId()}
 | 
			
		||||
		connectedMessageJSON, err := json.Marshal(connectedMessage)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errors.Wrap(err)
 | 
			
		||||
		}
 | 
			
		||||
		err = server.Send(&pb.NodeStreamMessage{
 | 
			
		||||
			Code:     messageconfigs.MessageCodeConnectedAPINode,
 | 
			
		||||
			DataJSON: connectedMessageJSON,
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logs.Println("[RPC]accepted node '" + strconv.FormatInt(nodeId, 10) + "' connection")
 | 
			
		||||
 | 
			
		||||
	nodeLocker.Lock()
 | 
			
		||||
	requestChan, ok := requestChanMap[nodeId]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		requestChan = make(chan *CommandRequest, 1024)
 | 
			
		||||
		requestChanMap[nodeId] = requestChan
 | 
			
		||||
	}
 | 
			
		||||
	nodeLocker.Unlock()
 | 
			
		||||
 | 
			
		||||
	// 发送请求
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			select {
 | 
			
		||||
			case <-server.Context().Done():
 | 
			
		||||
				return
 | 
			
		||||
			case commandRequest := <-requestChan:
 | 
			
		||||
				logs.Println("[RPC]sending command '" + commandRequest.Code + "' to node '" + strconv.FormatInt(nodeId, 10) + "'")
 | 
			
		||||
				retries := 3 // 错误重试次数
 | 
			
		||||
				for i := 0; i < retries; i++ {
 | 
			
		||||
					err := server.Send(&pb.NodeStreamMessage{
 | 
			
		||||
						RequestId: commandRequest.Id,
 | 
			
		||||
						Code:      commandRequest.Code,
 | 
			
		||||
						DataJSON:  commandRequest.CommandJSON,
 | 
			
		||||
					})
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						if i == retries-1 {
 | 
			
		||||
							logs.Println("[RPC]send command '" + commandRequest.Code + "' failed: " + err.Error())
 | 
			
		||||
						} else {
 | 
			
		||||
							time.Sleep(1 * time.Second)
 | 
			
		||||
						}
 | 
			
		||||
					} else {
 | 
			
		||||
						break
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// 接受请求
 | 
			
		||||
	for {
 | 
			
		||||
		req, err := server.Recv()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		func(req *pb.NodeStreamMessage) {
 | 
			
		||||
			// 因为 responseChan.Chan 有被关闭的风险,所以我们使用recover防止panic
 | 
			
		||||
			defer func() {
 | 
			
		||||
				recover()
 | 
			
		||||
			}()
 | 
			
		||||
 | 
			
		||||
			nodeLocker.Lock()
 | 
			
		||||
			responseChan, ok := responseChanMap[req.RequestId]
 | 
			
		||||
			if ok {
 | 
			
		||||
				select {
 | 
			
		||||
				case responseChan.Chan <- req:
 | 
			
		||||
				default:
 | 
			
		||||
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			nodeLocker.Unlock()
 | 
			
		||||
		}(req)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 向节点发送命令
 | 
			
		||||
func (this *NodeService) SendCommandToNode(ctx context.Context, req *pb.NodeStreamMessage) (*pb.NodeStreamMessage, error) {
 | 
			
		||||
	// 校验请求
 | 
			
		||||
	_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nodeId := req.NodeId
 | 
			
		||||
	if nodeId <= 0 {
 | 
			
		||||
		return nil, errors.New("node id should not be less than 0")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nodeLocker.Lock()
 | 
			
		||||
	requestChan, ok := requestChanMap[nodeId]
 | 
			
		||||
	nodeLocker.Unlock()
 | 
			
		||||
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return &pb.NodeStreamMessage{
 | 
			
		||||
			RequestId: req.RequestId,
 | 
			
		||||
			IsOk:      false,
 | 
			
		||||
			Message:   "node '" + strconv.FormatInt(nodeId, 10) + "' not connected yet",
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req.RequestId = NextCommandRequestId()
 | 
			
		||||
 | 
			
		||||
	select {
 | 
			
		||||
	case requestChan <- &CommandRequest{
 | 
			
		||||
		Id:          req.RequestId,
 | 
			
		||||
		Code:        req.Code,
 | 
			
		||||
		CommandJSON: req.DataJSON,
 | 
			
		||||
	}:
 | 
			
		||||
		// 加入到等待队列中
 | 
			
		||||
		respChan := make(chan *pb.NodeStreamMessage, 1)
 | 
			
		||||
		waiting := &CommandRequestWaiting{
 | 
			
		||||
			Timestamp: time.Now().Unix(),
 | 
			
		||||
			Chan:      respChan,
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		nodeLocker.Lock()
 | 
			
		||||
		responseChanMap[req.RequestId] = waiting
 | 
			
		||||
		nodeLocker.Unlock()
 | 
			
		||||
 | 
			
		||||
		// 等待响应
 | 
			
		||||
		timeoutSeconds := req.TimeoutSeconds
 | 
			
		||||
		if timeoutSeconds <= 0 {
 | 
			
		||||
			timeoutSeconds = 10
 | 
			
		||||
		}
 | 
			
		||||
		timeout := time.NewTimer(time.Duration(timeoutSeconds) * time.Second)
 | 
			
		||||
		select {
 | 
			
		||||
		case resp := <-respChan:
 | 
			
		||||
			// 从队列中删除
 | 
			
		||||
			nodeLocker.Lock()
 | 
			
		||||
			delete(responseChanMap, req.RequestId)
 | 
			
		||||
			waiting.Close()
 | 
			
		||||
			nodeLocker.Unlock()
 | 
			
		||||
 | 
			
		||||
			if resp == nil {
 | 
			
		||||
				return &pb.NodeStreamMessage{
 | 
			
		||||
					RequestId: req.RequestId,
 | 
			
		||||
					Code:      req.Code,
 | 
			
		||||
					Message:   "response timeout",
 | 
			
		||||
					IsOk:      false,
 | 
			
		||||
				}, nil
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			return resp, nil
 | 
			
		||||
		case <-timeout.C:
 | 
			
		||||
			// 从队列中删除
 | 
			
		||||
			nodeLocker.Lock()
 | 
			
		||||
			delete(responseChanMap, req.RequestId)
 | 
			
		||||
			waiting.Close()
 | 
			
		||||
			nodeLocker.Unlock()
 | 
			
		||||
 | 
			
		||||
			return &pb.NodeStreamMessage{
 | 
			
		||||
				RequestId: req.RequestId,
 | 
			
		||||
				Code:      req.Code,
 | 
			
		||||
				Message:   "response timeout over " + fmt.Sprintf("%d", timeoutSeconds) + " seconds",
 | 
			
		||||
				IsOk:      false,
 | 
			
		||||
			}, nil
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		return &pb.NodeStreamMessage{
 | 
			
		||||
			RequestId: req.RequestId,
 | 
			
		||||
			Code:      req.Code,
 | 
			
		||||
			Message:   "command queue is full over " + strconv.Itoa(len(requestChan)),
 | 
			
		||||
			IsOk:      false,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user