实现缓存策略的部分功能

This commit is contained in:
GoEdgeLab
2020-10-04 14:27:14 +08:00
parent 9f4119b892
commit 9f8c705d12
13 changed files with 762 additions and 197 deletions

View File

@@ -1,9 +1,9 @@
package main package main
import ( import (
"github.com/TeaOSLab/EdgeAPI/internal/apis"
"github.com/TeaOSLab/EdgeAPI/internal/apps" "github.com/TeaOSLab/EdgeAPI/internal/apps"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const" teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/nodes"
_ "github.com/TeaOSLab/EdgeAPI/internal/tasks" _ "github.com/TeaOSLab/EdgeAPI/internal/tasks"
_ "github.com/iwind/TeaGo/bootstrap" _ "github.com/iwind/TeaGo/bootstrap"
) )
@@ -14,6 +14,6 @@ func main() {
app.Product(teaconst.ProductName) app.Product(teaconst.ProductName)
app.Usage(teaconst.ProcessName + " [start|stop|restart]") app.Usage(teaconst.ProcessName + " [start|stop|restart]")
app.Run(func() { app.Run(func() {
apis.NewAPINode().Start() nodes.NewAPINode().Start()
}) })
} }

View File

@@ -1,83 +0,0 @@
package apis
import (
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/configs"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/logs"
"google.golang.org/grpc"
"net"
"os"
"strconv"
)
var sharedAPIConfig *configs.APIConfig = nil
type APINode struct {
}
func NewAPINode() *APINode {
return &APINode{}
}
func (this *APINode) Start() {
logs.Println("[API]start api node, pid: " + strconv.Itoa(os.Getpid()))
config, err := configs.SharedAPIConfig()
if err != nil {
logs.Println("[API]start failed: " + err.Error())
return
}
sharedAPIConfig = config
// 设置rlimit
_ = utils.SetRLimit(1024 * 1024)
// 监听RPC服务
logs.Println("[API]start rpc: " + config.RPC.Listen)
err = this.listenRPC()
if err != nil {
logs.Println(err.Error())
return
}
}
// 启动RPC监听
func (this *APINode) listenRPC() error {
listener, err := net.Listen("tcp", sharedAPIConfig.RPC.Listen)
if err != nil {
return errors.New("[API]listen rpc failed: " + err.Error())
}
rpcServer := grpc.NewServer()
pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{})
pb.RegisterNodeGrantServiceServer(rpcServer, &services.NodeGrantService{})
pb.RegisterServerServiceServer(rpcServer, &services.ServerService{})
pb.RegisterNodeServiceServer(rpcServer, &services.NodeService{})
pb.RegisterNodeClusterServiceServer(rpcServer, &services.NodeClusterService{})
pb.RegisterNodeIPAddressServiceServer(rpcServer, &services.NodeIPAddressService{})
pb.RegisterAPINodeServiceServer(rpcServer, &services.APINodeService{})
pb.RegisterOriginServiceServer(rpcServer, &services.OriginService{})
pb.RegisterHTTPWebServiceServer(rpcServer, &services.HTTPWebService{})
pb.RegisterReverseProxyServiceServer(rpcServer, &services.ReverseProxyService{})
pb.RegisterHTTPGzipServiceServer(rpcServer, &services.HTTPGzipService{})
pb.RegisterHTTPHeaderPolicyServiceServer(rpcServer, &services.HTTPHeaderPolicyService{})
pb.RegisterHTTPHeaderServiceServer(rpcServer, &services.HTTPHeaderService{})
pb.RegisterHTTPPageServiceServer(rpcServer, &services.HTTPPageService{})
pb.RegisterHTTPAccessLogPolicyServiceServer(rpcServer, &services.HTTPAccessLogPolicyService{})
pb.RegisterHTTPCachePolicyServiceServer(rpcServer, &services.HTTPCachePolicyService{})
pb.RegisterHTTPFirewallPolicyServiceServer(rpcServer, &services.HTTPFirewallPolicyService{})
pb.RegisterHTTPLocationServiceServer(rpcServer, &services.HTTPLocationService{})
pb.RegisterHTTPWebsocketServiceServer(rpcServer, &services.HTTPWebsocketService{})
pb.RegisterHTTPRewriteRuleServiceServer(rpcServer, &services.HTTPRewriteRuleService{})
pb.RegisterSSLCertServiceServer(rpcServer, &services.SSLCertService{})
pb.RegisterSSLPolicyServiceServer(rpcServer, &services.SSLPolicyService{})
pb.RegisterSysSettingServiceServer(rpcServer, &services.SysSettingService{})
err = rpcServer.Serve(listener)
if err != nil {
return errors.New("[API]start rpc failed: " + err.Error())
}
return nil
}

View File

@@ -8,12 +8,15 @@ import (
var sharedAPIConfig *APIConfig = nil var sharedAPIConfig *APIConfig = nil
// API节点配置
type APIConfig struct { type APIConfig struct {
RPC struct { NodeId string `yaml:"nodeId" json:"nodeId"`
Listen string `yaml:"listen"` Secret string `yaml:"secret" json:"secret"`
} `yaml:"rpc"`
numberId int64 // 数字ID
} }
// 获取共享配置
func SharedAPIConfig() (*APIConfig, error) { func SharedAPIConfig() (*APIConfig, error) {
sharedLocker.Lock() sharedLocker.Lock()
defer sharedLocker.Unlock() defer sharedLocker.Unlock()
@@ -36,3 +39,13 @@ func SharedAPIConfig() (*APIConfig, error) {
sharedAPIConfig = config sharedAPIConfig = config
return config, nil return config, nil
} }
// 设置数字ID
func (this *APIConfig) SetNumberId(numberId int64) {
this.numberId = numberId
}
// 获取数字ID
func (this *APIConfig) NumberId() int64 {
return this.numberId
}

View File

@@ -59,6 +59,19 @@ func (this *APINodeDAO) FindEnabledAPINode(id int64) (*APINode, error) {
return result.(*APINode), err return result.(*APINode), err
} }
// 根据ID和Secret查找节点
func (this *APINodeDAO) FindEnabledAPINodeWithUniqueIdAndSecret(uniqueId string, secret string) (*APINode, error) {
one, err := this.Query().
State(APINodeStateEnabled).
Attr("uniqueId", uniqueId).
Attr("secret", secret).
Find()
if err != nil || one == nil {
return nil, err
}
return one.(*APINode), nil
}
// 根据主键查找名称 // 根据主键查找名称
func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) { func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) {
return this.Query(). return this.Query().
@@ -68,7 +81,7 @@ func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) {
} }
// 创建API节点 // 创建API节点
func (this *APINodeDAO) CreateAPINode(name string, description string, host string, port int) (nodeId int64, err error) { func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) {
uniqueId, err := this.genUniqueId() uniqueId, err := this.genUniqueId()
if err != nil { if err != nil {
return 0, err return 0, err
@@ -80,13 +93,22 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, host stri
} }
op := NewAPINodeOperator() op := NewAPINodeOperator()
op.IsOn = true op.IsOn = isOn
op.UniqueId = uniqueId op.UniqueId = uniqueId
op.Secret = secret op.Secret = secret
op.Name = name op.Name = name
op.Description = description op.Description = description
op.Host = host
op.Port = port if len(httpJSON) > 0 {
op.Http = httpJSON
}
if len(httpsJSON) > 0 {
op.Https = httpsJSON
}
if len(accessAddrsJSON) > 0 {
op.AccessAddrs = accessAddrsJSON
}
op.State = NodeStateEnabled op.State = NodeStateEnabled
_, err = this.Save(op) _, err = this.Save(op)
if err != nil { if err != nil {
@@ -97,7 +119,7 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, host stri
} }
// 修改API节点 // 修改API节点
func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, host string, port int) error { func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) error {
if nodeId <= 0 { if nodeId <= 0 {
return errors.New("invalid nodeId") return errors.New("invalid nodeId")
} }
@@ -106,8 +128,24 @@ func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description str
op.Id = nodeId op.Id = nodeId
op.Name = name op.Name = name
op.Description = description op.Description = description
op.Host = host op.IsOn = isOn
op.Port = port
if len(httpJSON) > 0 {
op.Http = httpJSON
} else {
op.Http = "null"
}
if len(httpsJSON) > 0 {
op.Https = httpsJSON
} else {
op.Https = "null"
}
if len(accessAddrsJSON) > 0 {
op.AccessAddrs = accessAddrsJSON
} else {
op.AccessAddrs = "null"
}
_, err := this.Save(op) _, err := this.Save(op)
return err return err
} }

View File

@@ -9,8 +9,9 @@ type APINode struct {
Secret string `field:"secret"` // 密钥 Secret string `field:"secret"` // 密钥
Name string `field:"name"` // 名称 Name string `field:"name"` // 名称
Description string `field:"description"` // 描述 Description string `field:"description"` // 描述
Host string `field:"host"` // 主机 Http string `field:"http"` // 监听的HTTP配置
Port uint32 `field:"port"` // 端口 Https string `field:"https"` // 监听的HTTPS配置
AccessAddrs string `field:"accessAddrs"` // 外部访问地址
Order uint32 `field:"order"` // 排序 Order uint32 `field:"order"` // 排序
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
CreatedAt uint64 `field:"createdAt"` // 创建时间 CreatedAt uint64 `field:"createdAt"` // 创建时间
@@ -26,8 +27,9 @@ type APINodeOperator struct {
Secret interface{} // 密钥 Secret interface{} // 密钥
Name interface{} // 名称 Name interface{} // 名称
Description interface{} // 描述 Description interface{} // 描述
Host interface{} // 主机 Http interface{} // 监听的HTTP配置
Port interface{} // 端口 Https interface{} // 监听的HTTPS配置
AccessAddrs interface{} // 外部访问地址
Order interface{} // 排序 Order interface{} // 排序
State interface{} // 状态 State interface{} // 状态
CreatedAt interface{} // 创建时间 CreatedAt interface{} // 创建时间

View File

@@ -1,8 +1,95 @@
package models package models
import "strconv" import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
)
// 地址 // 解析HTTP配置
func (this *APINode) Address() string { func (this *APINode) DecodeHTTP() (*serverconfigs.HTTPProtocolConfig, error) {
return this.Host + ":" + strconv.Itoa(int(this.Port)) if !IsNotNull(this.Http) {
return nil, nil
}
config := &serverconfigs.HTTPProtocolConfig{}
err := json.Unmarshal([]byte(this.Http), config)
if err != nil {
return nil, err
}
err = config.Init()
if err != nil {
return nil, err
}
return config, nil
}
// 解析HTTPS配置
func (this *APINode) DecodeHTTPS() (*serverconfigs.HTTPSProtocolConfig, error) {
if !IsNotNull(this.Https) {
return nil, nil
}
config := &serverconfigs.HTTPSProtocolConfig{}
err := json.Unmarshal([]byte(this.Https), config)
if err != nil {
return nil, err
}
err = config.Init()
if err != nil {
return nil, err
}
if config.SSLPolicyRef != nil {
policyId := config.SSLPolicyRef.SSLPolicyId
if policyId > 0 {
sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(policyId)
if err != nil {
return nil, err
}
if sslPolicy != nil {
config.SSLPolicy = sslPolicy
}
}
}
err = config.Init()
if err != nil {
return nil, err
}
return config, nil
}
// 解析访问地址
func (this *APINode) DecodeAccessAddrs() ([]*serverconfigs.NetworkAddressConfig, error) {
if !IsNotNull(this.AccessAddrs) {
return nil, nil
}
addrConfigs := []*serverconfigs.NetworkAddressConfig{}
err := json.Unmarshal([]byte(this.AccessAddrs), &addrConfigs)
if err != nil {
return nil, err
}
for _, addrConfig := range addrConfigs {
err = addrConfig.Init()
if err != nil {
return nil, err
}
}
return addrConfigs, nil
}
// 解析访问地址,并返回字符串形式
func (this *APINode) DecodeAccessAddrStrings() ([]string, error) {
addrs, err := this.DecodeAccessAddrs()
if err != nil {
return nil, err
}
result := []string{}
for _, addr := range addrs {
result = append(result, addr.FullAddresses()...)
}
return result, nil
} }

View File

@@ -2,7 +2,7 @@ package models
import ( import (
"encoding/json" "encoding/json"
"errors" "github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils" "github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
@@ -275,6 +275,17 @@ func (this *NodeDAO) FindAllNodeIdsMatch(clusterId int64) (result []int64, err e
return return
} }
// 获取一个集群的所有节点
func (this *NodeDAO) FindAllEnabledNodesWithClusterId(clusterId int64) (result []*Node, err error) {
_, err = this.Query().
State(NodeStateEnabled).
Attr("clusterId", clusterId).
DescPk().
Slice(&result).
FindAll()
return
}
// 计算节点数量 // 计算节点数量
func (this *NodeDAO) CountAllEnabledNodesMatch(clusterId int64, installState configutils.BoolState, activeState configutils.BoolState) (int64, error) { func (this *NodeDAO) CountAllEnabledNodesMatch(clusterId int64, installState configutils.BoolState, activeState configutils.BoolState) (int64, error) {
query := this.Query() query := this.Query()
@@ -422,6 +433,28 @@ func (this *NodeDAO) ComposeNodeConfig(nodeId int64) (*nodeconfigs.NodeConfig, e
return config, nil return config, nil
} }
// 修改当前连接的API节点
func (this *NodeDAO) UpdateNodeConnectedAPINodes(nodeId int64, apiNodeIds []int64) error {
if nodeId <= 0 {
return errors.New("invalid nodeId")
}
op := NewNodeOperator()
op.Id = nodeId
if len(apiNodeIds) > 0 {
apiNodeIdsJSON, err := json.Marshal(apiNodeIds)
if err != nil {
return errors.Wrap(err)
}
op.ConnectedAPINodes = apiNodeIdsJSON
} else {
op.ConnectedAPINodes = "[]"
}
_, err := this.Save(op)
return err
}
// 生成唯一ID // 生成唯一ID
func (this *NodeDAO) genUniqueId() (string, error) { func (this *NodeDAO) genUniqueId() (string, error) {
for { for {

View File

@@ -2,47 +2,49 @@ package models
// 节点 // 节点
type Node struct { type Node struct {
Id uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
AdminId uint32 `field:"adminId"` // 管理员ID AdminId uint32 `field:"adminId"` // 管理员ID
UserId uint32 `field:"userId"` // 用户ID UserId uint32 `field:"userId"` // 用户ID
IsOn uint8 `field:"isOn"` // 是否启用 IsOn uint8 `field:"isOn"` // 是否启用
UniqueId string `field:"uniqueId"` // 节点ID UniqueId string `field:"uniqueId"` // 节点ID
Secret string `field:"secret"` // 密钥 Secret string `field:"secret"` // 密钥
Name string `field:"name"` // 节点名 Name string `field:"name"` // 节点名
Code string `field:"code"` // 代号 Code string `field:"code"` // 代号
ClusterId uint32 `field:"clusterId"` // 集群ID ClusterId uint32 `field:"clusterId"` // 集群ID
RegionId uint32 `field:"regionId"` // 区域ID RegionId uint32 `field:"regionId"` // 区域ID
GroupId uint32 `field:"groupId"` // 分组ID GroupId uint32 `field:"groupId"` // 分组ID
CreatedAt uint64 `field:"createdAt"` // 创建时间 CreatedAt uint64 `field:"createdAt"` // 创建时间
Status string `field:"status"` // 最新的状态 Status string `field:"status"` // 最新的状态
Version uint32 `field:"version"` // 当前版本号 Version uint32 `field:"version"` // 当前版本号
LatestVersion uint32 `field:"latestVersion"` // 最后版本号 LatestVersion uint32 `field:"latestVersion"` // 最后版本号
InstallDir string `field:"installDir"` // 安装目录 InstallDir string `field:"installDir"` // 安装目录
IsInstalled uint8 `field:"isInstalled"` // 是否已安装 IsInstalled uint8 `field:"isInstalled"` // 是否已安装
InstallStatus string `field:"installStatus"` // 安装状态 InstallStatus string `field:"installStatus"` // 安装状态
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
ConnectedAPINodes string `field:"connectedAPINodes"` // 当前连接的API节点
} }
type NodeOperator struct { type NodeOperator struct {
Id interface{} // ID Id interface{} // ID
AdminId interface{} // 管理员ID AdminId interface{} // 管理员ID
UserId interface{} // 用户ID UserId interface{} // 用户ID
IsOn interface{} // 是否启用 IsOn interface{} // 是否启用
UniqueId interface{} // 节点ID UniqueId interface{} // 节点ID
Secret interface{} // 密钥 Secret interface{} // 密钥
Name interface{} // 节点名 Name interface{} // 节点名
Code interface{} // 代号 Code interface{} // 代号
ClusterId interface{} // 集群ID ClusterId interface{} // 集群ID
RegionId interface{} // 区域ID RegionId interface{} // 区域ID
GroupId interface{} // 分组ID GroupId interface{} // 分组ID
CreatedAt interface{} // 创建时间 CreatedAt interface{} // 创建时间
Status interface{} // 最新的状态 Status interface{} // 最新的状态
Version interface{} // 当前版本号 Version interface{} // 当前版本号
LatestVersion interface{} // 最后版本号 LatestVersion interface{} // 最后版本号
InstallDir interface{} // 安装目录 InstallDir interface{} // 安装目录
IsInstalled interface{} // 是否已安装 IsInstalled interface{} // 是否已安装
InstallStatus interface{} // 安装状态 InstallStatus interface{} // 安装状态
State interface{} // 状态 State interface{} // 状态
ConnectedAPINodes interface{} // 当前连接的API节点
} }
func NewNodeOperator() *NodeOperator { func NewNodeOperator() *NodeOperator {

View File

@@ -136,7 +136,13 @@ func (this *Queue) InstallNode(nodeId int64) error {
apiEndpoints := []string{} apiEndpoints := []string{}
for _, apiNode := range apiNodes { for _, apiNode := range apiNodes {
apiEndpoints = append(apiEndpoints, apiNode.Host+":"+strconv.Itoa(int(apiNode.Port))) addrConfigs, err := apiNode.DecodeAccessAddrs()
if err != nil {
return errors.New("decode api node access addresses failed: " + err.Error())
}
for _, addrConfig := range addrConfigs {
apiEndpoints = append(apiEndpoints, addrConfig.FullAddresses()...)
}
} }
params := &NodeParams{ params := &NodeParams{

170
internal/nodes/api_node.go Normal file
View File

@@ -0,0 +1,170 @@
package nodes
import (
"crypto/tls"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/configs"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/logs"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"net"
"os"
"strconv"
)
var sharedAPIConfig *configs.APIConfig = nil
type APINode struct {
}
func NewAPINode() *APINode {
return &APINode{}
}
func (this *APINode) Start() {
logs.Println("[API]start api node, pid: " + strconv.Itoa(os.Getpid()))
// 读取配置
config, err := configs.SharedAPIConfig()
if err != nil {
logs.Println("[API]start failed: " + err.Error())
return
}
sharedAPIConfig = config
// 校验
apiNode, err := models.SharedAPINodeDAO.FindEnabledAPINodeWithUniqueIdAndSecret(config.NodeId, config.Secret)
if err != nil {
logs.Println("[API]start failed: read api node from database failed: " + err.Error())
return
}
if apiNode == nil {
logs.Println("[API]can not start node, wrong 'nodeId' or 'secret'")
return
}
config.SetNumberId(int64(apiNode.Id))
// 设置rlimit
_ = utils.SetRLimit(1024 * 1024)
// 监听RPC服务
logs.Println("[API]starting rpc ...")
// HTTP
httpConfig, err := apiNode.DecodeHTTP()
if err != nil {
logs.Println("[API]decode http config: " + err.Error())
return
}
isListening := false
if httpConfig != nil && httpConfig.IsOn && len(httpConfig.Listen) > 0 {
for _, listen := range httpConfig.Listen {
for _, addr := range listen.Addresses() {
listener, err := net.Listen("tcp", addr)
if err != nil {
logs.Println("[API]listening '" + addr + "' failed: " + err.Error())
continue
}
go func() {
err := this.listenRPC(listener, nil)
if err != nil {
logs.Println("[API]listening '" + addr + "' rpc: " + err.Error())
return
}
}()
isListening = true
}
}
}
// HTTPS
httpsConfig, err := apiNode.DecodeHTTPS()
if err != nil {
logs.Println("[API]decode https config: " + err.Error())
return
}
if httpsConfig != nil &&
httpsConfig.IsOn &&
len(httpsConfig.Listen) > 0 &&
httpsConfig.SSLPolicy != nil &&
httpsConfig.SSLPolicy.IsOn &&
len(httpsConfig.SSLPolicy.Certs) > 0 {
certs := []tls.Certificate{}
for _, cert := range httpsConfig.SSLPolicy.Certs {
certs = append(certs, *cert.CertObject())
}
for _, listen := range httpsConfig.Listen {
for _, addr := range listen.Addresses() {
listener, err := net.Listen("tcp", addr)
if err != nil {
logs.Println("[API]listening '" + addr + "' failed: " + err.Error())
continue
}
go func() {
err := this.listenRPC(listener, &tls.Config{
Certificates: certs,
})
if err != nil {
logs.Println("[API]listening '" + addr + "' rpc: " + err.Error())
return
}
}()
isListening = true
}
}
}
if !isListening {
logs.Println("[API]the api node does have a listening address")
return
}
// 保持进程
select {}
}
// 启动RPC监听
func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) error {
var rpcServer *grpc.Server
if tlsConfig == nil {
logs.Println("[API]listening http://" + listener.Addr().String() + " ...")
rpcServer = grpc.NewServer()
} else {
logs.Println("[API]listening https://" + listener.Addr().String() + " ...")
rpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
}
pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{})
pb.RegisterNodeGrantServiceServer(rpcServer, &services.NodeGrantService{})
pb.RegisterServerServiceServer(rpcServer, &services.ServerService{})
pb.RegisterNodeServiceServer(rpcServer, &services.NodeService{})
pb.RegisterNodeClusterServiceServer(rpcServer, &services.NodeClusterService{})
pb.RegisterNodeIPAddressServiceServer(rpcServer, &services.NodeIPAddressService{})
pb.RegisterAPINodeServiceServer(rpcServer, &services.APINodeService{})
pb.RegisterOriginServiceServer(rpcServer, &services.OriginService{})
pb.RegisterHTTPWebServiceServer(rpcServer, &services.HTTPWebService{})
pb.RegisterReverseProxyServiceServer(rpcServer, &services.ReverseProxyService{})
pb.RegisterHTTPGzipServiceServer(rpcServer, &services.HTTPGzipService{})
pb.RegisterHTTPHeaderPolicyServiceServer(rpcServer, &services.HTTPHeaderPolicyService{})
pb.RegisterHTTPHeaderServiceServer(rpcServer, &services.HTTPHeaderService{})
pb.RegisterHTTPPageServiceServer(rpcServer, &services.HTTPPageService{})
pb.RegisterHTTPAccessLogPolicyServiceServer(rpcServer, &services.HTTPAccessLogPolicyService{})
pb.RegisterHTTPCachePolicyServiceServer(rpcServer, &services.HTTPCachePolicyService{})
pb.RegisterHTTPFirewallPolicyServiceServer(rpcServer, &services.HTTPFirewallPolicyService{})
pb.RegisterHTTPLocationServiceServer(rpcServer, &services.HTTPLocationService{})
pb.RegisterHTTPWebsocketServiceServer(rpcServer, &services.HTTPWebsocketService{})
pb.RegisterHTTPRewriteRuleServiceServer(rpcServer, &services.HTTPRewriteRuleService{})
pb.RegisterSSLCertServiceServer(rpcServer, &services.SSLCertService{})
pb.RegisterSSLPolicyServiceServer(rpcServer, &services.SSLPolicyService{})
pb.RegisterSysSettingServiceServer(rpcServer, &services.SysSettingService{})
err := rpcServer.Serve(listener)
if err != nil {
return errors.New("[API]start rpc failed: " + err.Error())
}
return nil
}

View File

@@ -17,7 +17,7 @@ func (this *APINodeService) CreateAPINode(ctx context.Context, req *pb.CreateAPI
return nil, err return nil, err
} }
nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.Host, int(req.Port)) nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -32,7 +32,7 @@ func (this *APINodeService) UpdateAPINode(ctx context.Context, req *pb.UpdateAPI
return nil, err return nil, err
} }
err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.Host, int(req.Port)) err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -69,17 +69,23 @@ func (this *APINodeService) FindAllEnabledAPINodes(ctx context.Context, req *pb.
result := []*pb.APINode{} result := []*pb.APINode{}
for _, node := range nodes { for _, node := range nodes {
accessAddrs, err := node.DecodeAccessAddrStrings()
if err != nil {
return nil, err
}
result = append(result, &pb.APINode{ result = append(result, &pb.APINode{
Id: int64(node.Id), Id: int64(node.Id),
IsOn: node.IsOn == 1, IsOn: node.IsOn == 1,
ClusterId: int64(node.ClusterId), ClusterId: int64(node.ClusterId),
UniqueId: node.UniqueId, UniqueId: node.UniqueId,
Secret: node.Secret, Secret: node.Secret,
Name: node.Name, Name: node.Name,
Description: node.Description, Description: node.Description,
Host: node.Host, HttpJSON: []byte(node.Http),
Port: int32(node.Port), HttpsJSON: []byte(node.Https),
Address: node.Address(), AccessAddrsJSON: []byte(node.AccessAddrs),
AccessAddrs: accessAddrs,
}) })
} }
@@ -115,17 +121,23 @@ func (this *APINodeService) ListEnabledAPINodes(ctx context.Context, req *pb.Lis
result := []*pb.APINode{} result := []*pb.APINode{}
for _, node := range nodes { for _, node := range nodes {
accessAddrs, err := node.DecodeAccessAddrStrings()
if err != nil {
return nil, err
}
result = append(result, &pb.APINode{ result = append(result, &pb.APINode{
Id: int64(node.Id), Id: int64(node.Id),
IsOn: node.IsOn == 1, IsOn: node.IsOn == 1,
ClusterId: int64(node.ClusterId), ClusterId: int64(node.ClusterId),
UniqueId: node.UniqueId, UniqueId: node.UniqueId,
Secret: node.Secret, Secret: node.Secret,
Name: node.Name, Name: node.Name,
Description: node.Description, Description: node.Description,
Host: node.Host, HttpJSON: []byte(node.Http),
Port: int32(node.Port), HttpsJSON: []byte(node.Https),
Address: node.Address(), AccessAddrsJSON: []byte(node.AccessAddrs),
AccessAddrs: accessAddrs,
}) })
} }
@@ -148,17 +160,23 @@ func (this *APINodeService) FindEnabledAPINode(ctx context.Context, req *pb.Find
return &pb.FindEnabledAPINodeResponse{Node: nil}, nil return &pb.FindEnabledAPINodeResponse{Node: nil}, nil
} }
accessAddrs, err := node.DecodeAccessAddrStrings()
if err != nil {
return nil, err
}
result := &pb.APINode{ result := &pb.APINode{
Id: int64(node.Id), Id: int64(node.Id),
IsOn: node.IsOn == 1, IsOn: node.IsOn == 1,
ClusterId: int64(node.ClusterId), ClusterId: int64(node.ClusterId),
UniqueId: node.UniqueId, UniqueId: node.UniqueId,
Secret: node.Secret, Secret: node.Secret,
Name: node.Name, Name: node.Name,
Description: node.Description, Description: node.Description,
Host: node.Host, HttpJSON: []byte(node.Http),
Port: int32(node.Port), HttpsJSON: []byte(node.Https),
Address: node.Address(), AccessAddrsJSON: []byte(node.AccessAddrs),
AccessAddrs: accessAddrs,
} }
return &pb.FindEnabledAPINodeResponse{Node: result}, nil return &pb.FindEnabledAPINodeResponse{Node: result}, nil
} }

View File

@@ -3,8 +3,8 @@ package services
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/installers" "github.com/TeaOSLab/EdgeAPI/internal/installers"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils" "github.com/TeaOSLab/EdgeCommon/pkg/configutils"
@@ -12,6 +12,7 @@ import (
"github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/logs"
) )
// 边缘节点相关服务
type NodeService struct { type NodeService struct {
} }
@@ -121,6 +122,38 @@ func (this *NodeService) ListEnabledNodesMatch(ctx context.Context, req *pb.List
}, nil }, nil
} }
// 查找一个集群下的所有节点
func (this *NodeService) FindAllEnabledNodesWithClusterId(ctx context.Context, req *pb.FindAllEnabledNodesWithClusterIdRequest) (*pb.FindAllEnabledNodesWithClusterIdResponse, error) {
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
nodes, err := models.SharedNodeDAO.FindAllEnabledNodesWithClusterId(req.ClusterId)
if err != nil {
return nil, err
}
result := []*pb.Node{}
for _, node := range nodes {
apiNodeIds := []int64{}
if models.IsNotNull(node.ConnectedAPINodes) {
err = json.Unmarshal([]byte(node.ConnectedAPINodes), &apiNodeIds)
if err != nil {
return nil, err
}
}
result = append(result, &pb.Node{
Id: int64(node.Id),
Name: node.Name,
UniqueId: node.UniqueId,
Secret: node.Secret,
ConnectedAPINodeIds: apiNodeIds,
})
}
return &pb.FindAllEnabledNodesWithClusterIdResponse{Nodes: result}, nil
}
// 禁用节点 // 禁用节点
func (this *NodeService) DisableNode(ctx context.Context, req *pb.DisableNodeRequest) (*pb.DisableNodeResponse, error) { func (this *NodeService) DisableNode(ctx context.Context, req *pb.DisableNodeRequest) (*pb.DisableNodeResponse, error) {
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
@@ -263,27 +296,6 @@ func (this *NodeService) ComposeNodeConfig(ctx context.Context, req *pb.ComposeN
return &pb.ComposeNodeConfigResponse{NodeJSON: data}, nil return &pb.ComposeNodeConfigResponse{NodeJSON: data}, nil
} }
// 节点stream
func (this *NodeService) NodeStream(server pb.NodeService_NodeStreamServer) error {
// TODO 使用此stream快速通知边缘节点更新
// 校验节点
_, nodeId, err := rpcutils.ValidateRequest(server.Context(), rpcutils.UserTypeNode)
if err != nil {
return err
}
logs.Println("nodeId:", nodeId)
_ = server.Send(&pb.NodeStreamResponse{})
for {
req, err := server.Recv()
if err != nil {
return err
}
logs.Println("received:", req)
}
}
// 更新节点状态 // 更新节点状态
func (this *NodeService) UpdateNodeStatus(ctx context.Context, req *pb.UpdateNodeStatusRequest) (*pb.RPCUpdateSuccess, error) { func (this *NodeService) UpdateNodeStatus(ctx context.Context, req *pb.UpdateNodeStatusRequest) (*pb.RPCUpdateSuccess, error) {
// 校验节点 // 校验节点
@@ -354,3 +366,19 @@ func (this *NodeService) InstallNode(ctx context.Context, req *pb.InstallNodeReq
return &pb.InstallNodeResponse{}, nil return &pb.InstallNodeResponse{}, nil
} }
// 更改节点连接的API节点信息
func (this *NodeService) UpdateNodeConnectedAPINodes(ctx context.Context, req *pb.UpdateNodeConnectedAPINodesRequest) (*pb.RPCUpdateSuccess, error) {
// 校验节点
_, nodeId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeNode)
if err != nil {
return nil, err
}
err = models.SharedNodeDAO.UpdateNodeConnectedAPINodes(nodeId, req.ApiNodeIds)
if err != nil {
return nil, errors.Wrap(err)
}
return rpcutils.RPCUpdateSuccess()
}

View File

@@ -0,0 +1,251 @@
package services
import (
"context"
"encoding/json"
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/configs"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/logs"
"strconv"
"sync"
"sync/atomic"
"time"
)
// 命令请求相关
type CommandRequest struct {
Id int64
Code string
CommandJSON []byte
}
type CommandRequestWaiting struct {
Timestamp int64
Chan chan *pb.NodeStreamMessage
}
func (this *CommandRequestWaiting) Close() {
defer func() {
recover()
}()
close(this.Chan)
}
var responseChanMap = map[int64]*CommandRequestWaiting{} // request id => response
var commandRequestId = int64(0)
var nodeLocker = &sync.Mutex{}
var requestChanMap = map[int64]chan *CommandRequest{} // node id => chan
func NextCommandRequestId() int64 {
return atomic.AddInt64(&commandRequestId, 1)
}
func init() {
// 清理WaitingChannelMap
ticker := time.NewTicker(30 * time.Second)
go func() {
for range ticker.C {
nodeLocker.Lock()
for requestId, request := range responseChanMap {
if time.Now().Unix()-request.Timestamp > 3600 {
responseChanMap[requestId].Close()
delete(responseChanMap, requestId)
}
}
nodeLocker.Unlock()
}
}()
}
// 节点stream
func (this *NodeService) NodeStream(server pb.NodeService_NodeStreamServer) error {
// TODO 使用此stream快速通知边缘节点更新
// 校验节点
_, nodeId, err := rpcutils.ValidateRequest(server.Context(), rpcutils.UserTypeNode)
if err != nil {
return err
}
// 返回连接成功
{
apiConfig, err := configs.SharedAPIConfig()
if err != nil {
return err
}
connectedMessage := &messageconfigs.ConnectedAPINodeMessage{APINodeId: apiConfig.NumberId()}
connectedMessageJSON, err := json.Marshal(connectedMessage)
if err != nil {
return errors.Wrap(err)
}
err = server.Send(&pb.NodeStreamMessage{
Code: messageconfigs.MessageCodeConnectedAPINode,
DataJSON: connectedMessageJSON,
})
if err != nil {
return err
}
}
logs.Println("[RPC]accepted node '" + strconv.FormatInt(nodeId, 10) + "' connection")
nodeLocker.Lock()
requestChan, ok := requestChanMap[nodeId]
if !ok {
requestChan = make(chan *CommandRequest, 1024)
requestChanMap[nodeId] = requestChan
}
nodeLocker.Unlock()
// 发送请求
go func() {
for {
select {
case <-server.Context().Done():
return
case commandRequest := <-requestChan:
logs.Println("[RPC]sending command '" + commandRequest.Code + "' to node '" + strconv.FormatInt(nodeId, 10) + "'")
retries := 3 // 错误重试次数
for i := 0; i < retries; i++ {
err := server.Send(&pb.NodeStreamMessage{
RequestId: commandRequest.Id,
Code: commandRequest.Code,
DataJSON: commandRequest.CommandJSON,
})
if err != nil {
if i == retries-1 {
logs.Println("[RPC]send command '" + commandRequest.Code + "' failed: " + err.Error())
} else {
time.Sleep(1 * time.Second)
}
} else {
break
}
}
}
}
}()
// 接受请求
for {
req, err := server.Recv()
if err != nil {
return err
}
func(req *pb.NodeStreamMessage) {
// 因为 responseChan.Chan 有被关闭的风险所以我们使用recover防止panic
defer func() {
recover()
}()
nodeLocker.Lock()
responseChan, ok := responseChanMap[req.RequestId]
if ok {
select {
case responseChan.Chan <- req:
default:
}
}
nodeLocker.Unlock()
}(req)
}
}
// 向节点发送命令
func (this *NodeService) SendCommandToNode(ctx context.Context, req *pb.NodeStreamMessage) (*pb.NodeStreamMessage, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
nodeId := req.NodeId
if nodeId <= 0 {
return nil, errors.New("node id should not be less than 0")
}
nodeLocker.Lock()
requestChan, ok := requestChanMap[nodeId]
nodeLocker.Unlock()
if !ok {
return &pb.NodeStreamMessage{
RequestId: req.RequestId,
IsOk: false,
Message: "node '" + strconv.FormatInt(nodeId, 10) + "' not connected yet",
}, nil
}
req.RequestId = NextCommandRequestId()
select {
case requestChan <- &CommandRequest{
Id: req.RequestId,
Code: req.Code,
CommandJSON: req.DataJSON,
}:
// 加入到等待队列中
respChan := make(chan *pb.NodeStreamMessage, 1)
waiting := &CommandRequestWaiting{
Timestamp: time.Now().Unix(),
Chan: respChan,
}
nodeLocker.Lock()
responseChanMap[req.RequestId] = waiting
nodeLocker.Unlock()
// 等待响应
timeoutSeconds := req.TimeoutSeconds
if timeoutSeconds <= 0 {
timeoutSeconds = 10
}
timeout := time.NewTimer(time.Duration(timeoutSeconds) * time.Second)
select {
case resp := <-respChan:
// 从队列中删除
nodeLocker.Lock()
delete(responseChanMap, req.RequestId)
waiting.Close()
nodeLocker.Unlock()
if resp == nil {
return &pb.NodeStreamMessage{
RequestId: req.RequestId,
Code: req.Code,
Message: "response timeout",
IsOk: false,
}, nil
}
return resp, nil
case <-timeout.C:
// 从队列中删除
nodeLocker.Lock()
delete(responseChanMap, req.RequestId)
waiting.Close()
nodeLocker.Unlock()
return &pb.NodeStreamMessage{
RequestId: req.RequestId,
Code: req.Code,
Message: "response timeout over " + fmt.Sprintf("%d", timeoutSeconds) + " seconds",
IsOk: false,
}, nil
}
default:
return &pb.NodeStreamMessage{
RequestId: req.RequestId,
Code: req.Code,
Message: "command queue is full over " + strconv.Itoa(len(requestChan)),
IsOk: false,
}, nil
}
}