diff --git a/internal/rpc/services/service_node_grant.go b/internal/rpc/services/service_node_grant.go index d94df653..97b82a91 100644 --- a/internal/rpc/services/service_node_grant.go +++ b/internal/rpc/services/service_node_grant.go @@ -3,15 +3,20 @@ package services import ( "context" "errors" + "fmt" "github.com/TeaOSLab/EdgeAPI/internal/db/models" - rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "golang.org/x/crypto/ssh" + "net" + "time" ) type NodeGrantService struct { BaseService } +// CreateNodeGrant 创建认证 func (this *NodeGrantService) CreateNodeGrant(ctx context.Context, req *pb.CreateNodeGrantRequest) (*pb.CreateNodeGrantResponse, error) { adminId, err := this.ValidateAdmin(ctx, 0) if err != nil { @@ -25,40 +30,43 @@ func (this *NodeGrantService) CreateNodeGrant(ctx context.Context, req *pb.Creat return nil, err } return &pb.CreateNodeGrantResponse{ - GrantId: grantId, + NodeGrantId: grantId, }, err } +// UpdateNodeGrant 修改认证 func (this *NodeGrantService) UpdateNodeGrant(ctx context.Context, req *pb.UpdateNodeGrantRequest) (*pb.RPCSuccess, error) { - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, err := this.ValidateAdmin(ctx, 0) if err != nil { return nil, err } - if req.GrantId <= 0 { + if req.NodeGrantId <= 0 { return nil, errors.New("wrong grantId") } tx := this.NullTx() - err = models.SharedNodeGrantDAO.UpdateGrant(tx, req.GrantId, req.Name, req.Method, req.Username, req.Password, req.PrivateKey, req.Description, req.NodeId) + err = models.SharedNodeGrantDAO.UpdateGrant(tx, req.NodeGrantId, req.Name, req.Method, req.Username, req.Password, req.PrivateKey, req.Description, req.NodeId) return this.Success() } +// DisableNodeGrant 禁用认证 func (this *NodeGrantService) DisableNodeGrant(ctx context.Context, req *pb.DisableNodeGrantRequest) (*pb.DisableNodeGrantResponse, error) { - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, err := this.ValidateAdmin(ctx, 0) if err != nil { return nil, err } tx := this.NullTx() - err = models.SharedNodeGrantDAO.DisableNodeGrant(tx, req.GrantId) + err = models.SharedNodeGrantDAO.DisableNodeGrant(tx, req.NodeGrantId) return &pb.DisableNodeGrantResponse{}, err } +// CountAllEnabledNodeGrants 计算认证的数量 func (this *NodeGrantService) CountAllEnabledNodeGrants(ctx context.Context, req *pb.CountAllEnabledNodeGrantsRequest) (*pb.RPCCountResponse, error) { - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, err := this.ValidateAdmin(ctx, 0) if err != nil { return nil, err } @@ -72,8 +80,9 @@ func (this *NodeGrantService) CountAllEnabledNodeGrants(ctx context.Context, req return this.SuccessCount(count) } +// ListEnabledNodeGrants 列出单页认证 func (this *NodeGrantService) ListEnabledNodeGrants(ctx context.Context, req *pb.ListEnabledNodeGrantsRequest) (*pb.ListEnabledNodeGrantsResponse, error) { - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, err := this.ValidateAdmin(ctx, 0) if err != nil { return nil, err } @@ -98,12 +107,12 @@ func (this *NodeGrantService) ListEnabledNodeGrants(ctx context.Context, req *pb }) } - return &pb.ListEnabledNodeGrantsResponse{Grants: result}, nil + return &pb.ListEnabledNodeGrantsResponse{NodeGrants: result}, nil } -// 列出所有认证信息 +// FindAllEnabledNodeGrants 列出所有认证信息 func (this *NodeGrantService) FindAllEnabledNodeGrants(ctx context.Context, req *pb.FindAllEnabledNodeGrantsRequest) (*pb.FindAllEnabledNodeGrantsResponse, error) { - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, err := this.ValidateAdmin(ctx, 0) if err != nil { return nil, err } @@ -125,23 +134,24 @@ func (this *NodeGrantService) FindAllEnabledNodeGrants(ctx context.Context, req }) } - return &pb.FindAllEnabledNodeGrantsResponse{Grants: result}, nil + return &pb.FindAllEnabledNodeGrantsResponse{NodeGrants: result}, nil } -func (this *NodeGrantService) FindEnabledGrant(ctx context.Context, req *pb.FindEnabledGrantRequest) (*pb.FindEnabledGrantResponse, error) { - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) +// FindEnabledNodeGrant 获取单个认证信息 +func (this *NodeGrantService) FindEnabledNodeGrant(ctx context.Context, req *pb.FindEnabledNodeGrantRequest) (*pb.FindEnabledNodeGrantResponse, error) { + _, err := this.ValidateAdmin(ctx, 0) if err != nil { return nil, err } - grant, err := models.SharedNodeGrantDAO.FindEnabledNodeGrant(this.NullTx(), req.GrantId) + grant, err := models.SharedNodeGrantDAO.FindEnabledNodeGrant(this.NullTx(), req.NodeGrantId) if err != nil { return nil, err } if grant == nil { - return &pb.FindEnabledGrantResponse{}, nil + return &pb.FindEnabledNodeGrantResponse{}, nil } - return &pb.FindEnabledGrantResponse{Grant: &pb.NodeGrant{ + return &pb.FindEnabledNodeGrantResponse{NodeGrant: &pb.NodeGrant{ Id: int64(grant.Id), Name: grant.Name, Method: grant.Method, @@ -153,3 +163,97 @@ func (this *NodeGrantService) FindEnabledGrant(ctx context.Context, req *pb.Find NodeId: int64(grant.NodeId), }}, nil } + +// TestNodeGrant 测试连接 +func (this *NodeGrantService) TestNodeGrant(ctx context.Context, req *pb.TestNodeGrantRequest) (*pb.TestNodeGrantResponse, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + var hostKeyCallback ssh.HostKeyCallback = nil + + resp := &pb.TestNodeGrantResponse{ + IsOk: false, + Error: "", + } + + var tx = this.NullTx() + grant, err := models.SharedNodeGrantDAO.FindEnabledNodeGrant(tx, req.NodeGrantId) + if err != nil { + return nil, err + } + if grant == nil { + resp.Error = "can not find grant with id '" + numberutils.FormatInt64(req.NodeGrantId) + "'" + return resp, nil + } + + // 检查参数 + if len(req.Host) == 0 { + resp.Error = "'host' should not be empty" + return resp, nil + } + if req.Port <= 0 { + resp.Error = "'port' should be greater than 0" + return resp, nil + } + + if len(grant.Password) == 0 && len(grant.PrivateKey) == 0 { + resp.Error = "require user 'password' or 'privateKey'" + return resp, nil + } + + // 不使用known_hosts + if hostKeyCallback == nil { + hostKeyCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + } + } + + // 认证 + methods := []ssh.AuthMethod{} + if len(grant.Password) > 0 { + { + authMethod := ssh.Password(grant.Password) + methods = append(methods, authMethod) + } + + { + authMethod := ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) { + if len(questions) == 0 { + return []string{}, nil + } + return []string{grant.Password}, nil + }) + methods = append(methods, authMethod) + } + } else { + signer, err := ssh.ParsePrivateKey([]byte(grant.PrivateKey)) + if err != nil { + resp.Error = "parse private key: " + err.Error() + return resp, nil + } + authMethod := ssh.PublicKeys(signer) + methods = append(methods, authMethod) + } + + // SSH客户端 + config := &ssh.ClientConfig{ + User: grant.Username, + Auth: methods, + HostKeyCallback: hostKeyCallback, + Timeout: 5 * time.Second, // TODO 后期可以设置这个超时时间 + } + + sshClient, err := ssh.Dial("tcp", req.Host+":"+fmt.Sprintf("%d", req.Port), config) + if err != nil { + resp.Error = "connect failed: " + err.Error() + return resp, nil + } + defer func() { + _ = sshClient.Close() + }() + + resp.IsOk = true + return resp, nil +}