From 61b5316a1fcdfe72551db848ea1ca1bd4e80def1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E7=A5=A5=E8=B6=85?= Date: Wed, 23 Nov 2022 20:13:34 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/nodes/node.go | 24 +++++++++--------------- internal/rpc/rpc_client.go | 7 ++++++- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/internal/nodes/node.go b/internal/nodes/node.go index 3faaef5..900d4f0 100644 --- a/internal/nodes/node.go +++ b/internal/nodes/node.go @@ -2,7 +2,6 @@ package nodes import ( "bytes" - "context" "encoding/json" "errors" "github.com/TeaOSLab/EdgeCommon/pkg/configutils" @@ -317,8 +316,7 @@ func (this *Node) loop() error { return errors.New("create rpc client failed: " + err.Error()) } - var nodeCtx = rpcClient.Context() - tasksResp, err := rpcClient.NodeTaskRPC.FindNodeTasks(nodeCtx, &pb.FindNodeTasksRequest{ + tasksResp, err := rpcClient.NodeTaskRPC.FindNodeTasks(rpcClient.Context(), &pb.FindNodeTasksRequest{ Version: this.lastTaskVersion, }) if err != nil { @@ -328,7 +326,7 @@ func (this *Node) loop() error { return errors.New("read node tasks failed: " + err.Error()) } for _, task := range tasksResp.NodeTasks { - err := this.execTask(rpcClient, nodeCtx, task) + err := this.execTask(rpcClient, task) if !this.finishTask(task.Id, task.Version, err) { // 防止失败的任务无法重试 break @@ -339,7 +337,7 @@ func (this *Node) loop() error { } // 执行任务 -func (this *Node) execTask(rpcClient *rpc.RPCClient, nodeCtx context.Context, task *pb.NodeTask) error { +func (this *Node) execTask(rpcClient *rpc.RPCClient, task *pb.NodeTask) error { switch task.Type { case "ipItemChanged": // 防止阻塞 @@ -369,7 +367,7 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, nodeCtx context.Context, ta return errors.New("reload common scripts failed: " + err.Error()) } case "nodeLevelChanged": - levelInfoResp, err := rpcClient.NodeRPC.FindNodeLevelInfo(nodeCtx, &pb.FindNodeLevelInfoRequest{}) + levelInfoResp, err := rpcClient.NodeRPC.FindNodeLevelInfo(rpcClient.Context(), &pb.FindNodeLevelInfoRequest{}) if err != nil { return err } @@ -390,7 +388,7 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, nodeCtx context.Context, ta sharedNodeConfig.ParentNodes = parentNodes } case "ddosProtectionChanged": - resp, err := rpcClient.NodeRPC.FindNodeDDoSProtection(nodeCtx, &pb.FindNodeDDoSProtectionRequest{}) + resp, err := rpcClient.NodeRPC.FindNodeDDoSProtection(rpcClient.Context(), &pb.FindNodeDDoSProtectionRequest{}) if err != nil { return err } @@ -418,7 +416,7 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, nodeCtx context.Context, ta return nil } case "globalServerConfigChanged": - resp, err := rpcClient.NodeRPC.FindNodeGlobalServerConfig(nodeCtx, &pb.FindNodeGlobalServerConfigRequest{}) + resp, err := rpcClient.NodeRPC.FindNodeGlobalServerConfig(rpcClient.Context(), &pb.FindNodeGlobalServerConfigRequest{}) if err != nil { return err } @@ -441,7 +439,7 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, nodeCtx context.Context, ta } case "userServersStateChanged": if task.UserId > 0 { - resp, err := rpcClient.UserRPC.CheckUserServersState(nodeCtx, &pb.CheckUserServersStateRequest{UserId: task.UserId}) + resp, err := rpcClient.UserRPC.CheckUserServersState(rpcClient.Context(), &pb.CheckUserServersStateRequest{UserId: task.UserId}) if err != nil { return err } @@ -474,8 +472,6 @@ func (this *Node) finishTask(taskId int64, taskVersion int64, taskErr error) (su return false } - var nodeCtx = rpcClient.Context() - var isOk = taskErr == nil if isOk && taskVersion > this.lastTaskVersion { this.lastTaskVersion = taskVersion @@ -486,7 +482,7 @@ func (this *Node) finishTask(taskId int64, taskVersion int64, taskErr error) (su errMsg = taskErr.Error() } - _, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{ + _, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(rpcClient.Context(), &pb.ReportNodeTaskDoneRequest{ NodeTaskId: taskId, IsOk: isOk, Error: errMsg, @@ -533,10 +529,8 @@ func (this *Node) syncConfig(taskVersion int64) error { } // 获取同步任务 - var nodeCtx = rpcClient.Context() - // TODO 这里考虑只同步版本号有变更的 - configResp, err := rpcClient.NodeRPC.FindCurrentNodeConfig(nodeCtx, &pb.FindCurrentNodeConfigRequest{ + configResp, err := rpcClient.NodeRPC.FindCurrentNodeConfig(rpcClient.Context(), &pb.FindCurrentNodeConfigRequest{ Version: -1, // 更新所有版本 Compress: true, NodeTaskVersion: taskVersion, diff --git a/internal/rpc/rpc_client.go b/internal/rpc/rpc_client.go index 0975b72..aab1678 100644 --- a/internal/rpc/rpc_client.go +++ b/internal/rpc/rpc_client.go @@ -262,7 +262,12 @@ func (this *RPCClient) pickConn() *grpc.ClientConn { defer this.locker.Unlock() // 检查连接状态 - if len(this.conns) > 0 { + var countConns = len(this.conns) + if countConns > 0 { + if countConns == 1 { + return this.conns[0] + } + for _, stateArray := range [][2]connectivity.State{ {connectivity.Ready, connectivity.Idle}, // 优先Ready和Idle {connectivity.Connecting, connectivity.Connecting},