diff --git a/internal/rpc/rpc_client.go b/internal/rpc/rpc_client.go index 69687d00..a58a4f2a 100644 --- a/internal/rpc/rpc_client.go +++ b/internal/rpc/rpc_client.go @@ -11,6 +11,7 @@ import ( "github.com/TeaOSLab/EdgeAdmin/internal/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/dao" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/rands" "google.golang.org/grpc" @@ -19,7 +20,9 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/encoding/gzip" "google.golang.org/grpc/metadata" + "net" "net/url" + "strings" "sync" "time" ) @@ -38,7 +41,7 @@ func NewRPCClient(apiConfig *configs.APIConfig, isPrimary bool) (*RPCClient, err return nil, errors.New("api config should not be nil") } - client := &RPCClient{ + var client = &RPCClient{ apiConfig: apiConfig, } @@ -486,8 +489,8 @@ func (this *RPCClient) TrafficDailyStatRPC() pb.TrafficDailyStatServiceClient { // Context 构造Admin上下文 func (this *RPCClient) Context(adminId int64) context.Context { - ctx := context.Background() - m := maps.Map{ + var ctx = context.Background() + var m = maps.Map{ "timestamp": time.Now().Unix(), "type": "admin", "userId": adminId, @@ -502,15 +505,15 @@ func (this *RPCClient) Context(adminId int64) context.Context { utils.PrintError(err) return context.Background() } - token := base64.StdEncoding.EncodeToString(data) + var token = base64.StdEncoding.EncodeToString(data) ctx = metadata.AppendToOutgoingContext(ctx, "nodeId", this.apiConfig.NodeId, "token", token) return ctx } // APIContext 构造API上下文 func (this *RPCClient) APIContext(apiNodeId int64) context.Context { - ctx := context.Background() - m := maps.Map{ + var ctx = context.Background() + var m = maps.Map{ "timestamp": time.Now().Unix(), "type": "api", "userId": apiNodeId, @@ -525,7 +528,7 @@ func (this *RPCClient) APIContext(apiNodeId int64) context.Context { utils.PrintError(err) return context.Background() } - token := base64.StdEncoding.EncodeToString(data) + var token = base64.StdEncoding.EncodeToString(data) ctx = metadata.AppendToOutgoingContext(ctx, "nodeId", this.apiConfig.NodeId, "token", token) return ctx } @@ -542,20 +545,39 @@ func (this *RPCClient) UpdateConfig(config *configs.APIConfig) error { // 初始化 func (this *RPCClient) init() error { + // 当前的IP地址 + var localIPAddrs = this.localIPAddrs() + // 重新连接 - conns := []*grpc.ClientConn{} + var conns = []*grpc.ClientConn{} for _, endpoint := range this.apiConfig.RPC.Endpoints { u, err := url.Parse(endpoint) if err != nil { return errors.New("parse endpoint failed: " + err.Error()) } + + var apiHost = u.Host + + // 如果本机,则将地址修改为回路地址 + if lists.ContainsString(localIPAddrs, u.Hostname()) { + if strings.Contains(apiHost, "[") { // IPv6 [host]:port + apiHost = "[::1]" + } else { + apiHost = "127.0.0.1" + } + var port = u.Port() + if len(port) > 0 { + apiHost += ":" + port + } + } + var conn *grpc.ClientConn var callOptions = grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(128*1024*1024), grpc.UseCompressor(gzip.Name)) if u.Scheme == "http" { - conn, err = grpc.Dial(u.Host, grpc.WithTransportCredentials(insecure.NewCredentials()), callOptions) + conn, err = grpc.Dial(apiHost, grpc.WithTransportCredentials(insecure.NewCredentials()), callOptions) } else if u.Scheme == "https" { - conn, err = grpc.Dial(u.Host, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ + conn, err = grpc.Dial(apiHost, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ InsecureSkipVerify: true, })), callOptions) } else { @@ -582,7 +604,7 @@ func (this *RPCClient) pickConn() *grpc.ClientConn { // 检查连接状态 if len(this.conns) > 0 { - availableConns := []*grpc.ClientConn{} + var availableConns = []*grpc.ClientConn{} for _, state := range []connectivity.State{connectivity.Ready, connectivity.Idle, connectivity.Connecting} { for _, conn := range this.conns { if conn.GetState() == state { @@ -634,3 +656,18 @@ func (this *RPCClient) Close() error { return lastErr } + +func (this *RPCClient) localIPAddrs() []string { + localInterfaceAddrs, err := net.InterfaceAddrs() + var localIPAddrs = []string{} + if err == nil { + for _, addr := range localInterfaceAddrs { + var addrString = addr.String() + var index = strings.Index(addrString, "/") + if index > 0 { + localIPAddrs = append(localIPAddrs, addrString[:index]) + } + } + } + return localIPAddrs +}