diff --git a/internal/db/models/node_cluster_dao.go b/internal/db/models/node_cluster_dao.go index ae9e0d10..84c47e61 100644 --- a/internal/db/models/node_cluster_dao.go +++ b/internal/db/models/node_cluster_dao.go @@ -198,7 +198,7 @@ func (this *NodeClusterDAO) CreateCluster(tx *dbs.Tx, adminId int64, name string } // UpdateCluster 修改集群 -func (this *NodeClusterDAO) UpdateCluster(tx *dbs.Tx, clusterId int64, name string, grantId int64, installDir string, timezone string, nodeMaxThreads int32, autoOpenPorts bool, clockConfig *nodeconfigs.ClockConfig, autoRemoteStart bool, autoInstallTables bool) error { +func (this *NodeClusterDAO) UpdateCluster(tx *dbs.Tx, clusterId int64, name string, grantId int64, installDir string, timezone string, nodeMaxThreads int32, autoOpenPorts bool, clockConfig *nodeconfigs.ClockConfig, autoRemoteStart bool, autoInstallTables bool, sshParams *nodeconfigs.SSHParams) error { if clusterId <= 0 { return errors.New("invalid clusterId") } @@ -226,6 +226,14 @@ func (this *NodeClusterDAO) UpdateCluster(tx *dbs.Tx, clusterId int64, name stri op.AutoRemoteStart = autoRemoteStart op.AutoInstallNftables = autoInstallTables + if sshParams != nil { + sshParamsJSON, err := json.Marshal(sshParams) + if err != nil { + return err + } + op.SshParams = sshParamsJSON + } + err := this.Save(tx, op) if err != nil { return err @@ -454,6 +462,27 @@ func (this *NodeClusterDAO) FindClusterGrantId(tx *dbs.Tx, clusterId int64) (int FindInt64Col(0) } +// FindClusterSSHParams 查找集群的SSH默认参数 +func (this *NodeClusterDAO) FindClusterSSHParams(tx *dbs.Tx, clusterId int64) (*nodeconfigs.SSHParams, error) { + sshParamsJSON, err := this.Query(tx). + Pk(clusterId). + Result("sshParams"). + FindJSONCol() + if err != nil { + return nil, err + } + + var params = nodeconfigs.DefaultSSHParams() + if len(sshParamsJSON) == 0 { + return params, nil + } + err = json.Unmarshal(sshParamsJSON, params) + if err != nil { + return nil, err + } + return params, nil +} + // FindClusterDNSInfo 查找DNS信息 func (this *NodeClusterDAO) FindClusterDNSInfo(tx *dbs.Tx, clusterId int64, cacheMap *utils.CacheMap) (*NodeCluster, error) { var cacheKey = this.Table + ":FindClusterDNSInfo:" + types.String(clusterId) diff --git a/internal/db/models/node_cluster_model.go b/internal/db/models/node_cluster_model.go index c334b247..4d519383 100644 --- a/internal/db/models/node_cluster_model.go +++ b/internal/db/models/node_cluster_model.go @@ -15,6 +15,7 @@ type NodeCluster struct { Order uint32 `field:"order"` // 排序 CreatedAt uint64 `field:"createdAt"` // 创建时间 GrantId uint32 `field:"grantId"` // 默认认证方式 + SshParams dbs.JSON `field:"sshParams"` // SSH默认参数 State uint8 `field:"state"` // 状态 AutoRegister uint8 `field:"autoRegister"` // 是否开启自动注册 UniqueId string `field:"uniqueId"` // 唯一ID @@ -53,6 +54,7 @@ type NodeClusterOperator struct { Order any // 排序 CreatedAt any // 创建时间 GrantId any // 默认认证方式 + SshParams any // SSH默认参数 State any // 状态 AutoRegister any // 是否开启自动注册 UniqueId any // 唯一ID diff --git a/internal/installers/queue_node.go b/internal/installers/queue_node.go index cac85b87..5a669846 100644 --- a/internal/installers/queue_node.go +++ b/internal/installers/queue_node.go @@ -27,7 +27,7 @@ func SharedNodeQueue() *NodeQueue { // InstallNodeProcess 安装边缘节点流程控制 func (this *NodeQueue) InstallNodeProcess(nodeId int64, isUpgrading bool) error { - installStatus := models.NewNodeInstallStatus() + var installStatus = models.NewNodeInstallStatus() installStatus.IsRunning = true installStatus.UpdatedAt = time.Now().Unix() @@ -37,7 +37,7 @@ func (this *NodeQueue) InstallNodeProcess(nodeId int64, isUpgrading bool) error } // 更新时间 - ticker := utils.NewTicker(3 * time.Second) + var ticker = utils.NewTicker(3 * time.Second) goman.New(func() { for ticker.Wait() { installStatus.UpdatedAt = time.Now().Unix() @@ -104,13 +104,31 @@ func (this *NodeQueue) InstallNode(nodeId int64, installStatus *models.NodeInsta } if len(loginParams.Host) == 0 { - installStatus.ErrorCode = "EMPTY_SSH_HOST" - return errors.New("ssh host should not be empty") + // 查询节点IP + ip, _, err := models.SharedNodeIPAddressDAO.FindFirstNodeAccessIPAddress(nil, nodeId, false, nodeconfigs.NodeRoleNode) + if err != nil { + return err + } + if len(ip) > 0 { + loginParams.Host = ip + } else { + installStatus.ErrorCode = "EMPTY_SSH_HOST" + return errors.New("ssh host should not be empty") + } } if loginParams.Port <= 0 { - installStatus.ErrorCode = "EMPTY_SSH_PORT" - return errors.New("ssh port is invalid") + // 从集群中读取 + sshParams, err := models.SharedNodeClusterDAO.FindClusterSSHParams(nil, int64(node.ClusterId)) + if err != nil { + return err + } + if sshParams != nil && sshParams.Port > 0 { + loginParams.Port = sshParams.Port + } else { + installStatus.ErrorCode = "EMPTY_SSH_PORT" + return errors.New("ssh port is invalid") + } } if loginParams.GrantId == 0 { @@ -161,7 +179,7 @@ func (this *NodeQueue) InstallNode(nodeId int64, installStatus *models.NodeInsta IsUpgrading: isUpgrading, } - installer := &NodeInstaller{} + var installer = &NodeInstaller{} err = installer.Login(&Credentials{ Host: loginParams.Host, Port: loginParams.Port, @@ -226,11 +244,29 @@ func (this *NodeQueue) StartNode(nodeId int64) error { } if len(loginParams.Host) == 0 { - return newGrantError("ssh host should not be empty") + // 查询节点IP + ip, _, err := models.SharedNodeIPAddressDAO.FindFirstNodeAccessIPAddress(nil, nodeId, false, nodeconfigs.NodeRoleNode) + if err != nil { + return err + } + if len(ip) > 0 { + loginParams.Host = ip + } else { + return newGrantError("ssh host should not be empty") + } } if loginParams.Port <= 0 { - return newGrantError("ssh port is invalid") + // 从集群中读取 + sshParams, err := models.SharedNodeClusterDAO.FindClusterSSHParams(nil, int64(node.ClusterId)) + if err != nil { + return err + } + if sshParams != nil && sshParams.Port > 0 { + loginParams.Port = sshParams.Port + } else { + return newGrantError("ssh port is invalid") + } } if loginParams.GrantId == 0 { @@ -315,11 +351,29 @@ func (this *NodeQueue) StopNode(nodeId int64) error { } if len(loginParams.Host) == 0 { - return errors.New("ssh host should not be empty") + // 查询节点IP + ip, _, err := models.SharedNodeIPAddressDAO.FindFirstNodeAccessIPAddress(nil, nodeId, false, nodeconfigs.NodeRoleNode) + if err != nil { + return err + } + if len(ip) > 0 { + loginParams.Host = ip + } else { + return errors.New("ssh host should not be empty") + } } if loginParams.Port <= 0 { - return errors.New("ssh port is invalid") + // 从集群中读取 + sshParams, err := models.SharedNodeClusterDAO.FindClusterSSHParams(nil, int64(node.ClusterId)) + if err != nil { + return err + } + if sshParams != nil && sshParams.Port > 0 { + loginParams.Port = sshParams.Port + } else { + return errors.New("ssh port is invalid") + } } if loginParams.GrantId == 0 { @@ -341,7 +395,7 @@ func (this *NodeQueue) StopNode(nodeId int64) error { return errors.New("can not find user grant with id '" + numberutils.FormatInt64(loginParams.GrantId) + "'") } - installer := &NodeInstaller{} + var installer = &NodeInstaller{} err = installer.Login(&Credentials{ Host: loginParams.Host, Port: loginParams.Port, @@ -386,7 +440,7 @@ func (this *NodeQueue) lookupNodeExe(node *models.Node, client *SSHClient) (stri if len(node.InstallDir) > 0 { nodeDirs = append(nodeDirs, node.InstallDir) } - clusterId := node.ClusterId + var clusterId = node.ClusterId cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(nil, int64(clusterId)) if err != nil { return "", err diff --git a/internal/rpc/services/service_node_cluster.go b/internal/rpc/services/service_node_cluster.go index ada5c297..88f66e5a 100644 --- a/internal/rpc/services/service_node_cluster.go +++ b/internal/rpc/services/service_node_cluster.go @@ -116,7 +116,17 @@ func (this *NodeClusterService) UpdateNodeCluster(ctx context.Context, req *pb.U } } - err = models.SharedNodeClusterDAO.UpdateCluster(tx, req.NodeClusterId, req.Name, req.NodeGrantId, req.InstallDir, req.TimeZone, req.NodeMaxThreads, req.AutoOpenPorts, clockConfig, req.AutoRemoteStart, req.AutoInstallNftables) + // ssh params + var sshParams *nodeconfigs.SSHParams + if len(req.SshParamsJSON) > 0 { + sshParams = nodeconfigs.DefaultSSHParams() + err = json.Unmarshal(req.SshParamsJSON, sshParams) + if err != nil { + return nil, err + } + } + + err = models.SharedNodeClusterDAO.UpdateCluster(tx, req.NodeClusterId, req.Name, req.NodeGrantId, req.InstallDir, req.TimeZone, req.NodeMaxThreads, req.AutoOpenPorts, clockConfig, req.AutoRemoteStart, req.AutoInstallNftables, sshParams) if err != nil { return nil, err } @@ -208,6 +218,7 @@ func (this *NodeClusterService) FindEnabledNodeCluster(ctx context.Context, req CreatedAt: int64(cluster.CreatedAt), InstallDir: cluster.InstallDir, NodeGrantId: int64(cluster.GrantId), + SshParamsJSON: cluster.SshParams, UniqueId: cluster.UniqueId, Secret: cluster.Secret, HttpCachePolicyId: int64(cluster.CachePolicyId),