diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index 3a1db33a..635b1db4 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -2697,6 +2697,72 @@ func (this *ServerDAO) UpdateServerBandwidth(tx *dbs.Tx, serverId int64, fullTim } } +// UpdateServerUserId 修改服务所属用户 +func (this *ServerDAO) UpdateServerUserId(tx *dbs.Tx, serverId int64, userId int64) error { + if serverId <= 0 { + return nil + } + + serverOne, err := this.Query(tx). + Result("https", "tls"). + Pk(serverId). + State(ServerStateEnabled). + Find() + if err != nil || serverOne == nil { + return err + } + var server = serverOne.(*Server) + + // 修改服务 + err = this.Query(tx). + Pk(serverId). + Set("userId", userId). + UpdateQuickly() + if err != nil { + return err + } + + // 修改证书相关数据 + var sslPolicyIds = []int64{} + var httpsConfig = server.DecodeHTTPS() + if httpsConfig != nil && httpsConfig.SSLPolicyRef != nil && httpsConfig.SSLPolicyRef.SSLPolicyId > 0 { + sslPolicyIds = append(sslPolicyIds, httpsConfig.SSLPolicyRef.SSLPolicyId) + } + + var tlsConfig = server.DecodeTLS() + if tlsConfig != nil && tlsConfig.SSLPolicyRef != nil && tlsConfig.SSLPolicyRef.SSLPolicyId > 0 { + sslPolicyIds = append(sslPolicyIds, tlsConfig.SSLPolicyRef.SSLPolicyId) + } + if len(sslPolicyIds) > 0 { + for _, sslPolicyId := range sslPolicyIds { + policy, err := SharedSSLPolicyDAO.FindEnabledSSLPolicy(tx, sslPolicyId) + if err != nil { + return err + } + if policy != nil { + // 修改策略 + err = SharedSSLPolicyDAO.UpdatePolicyUser(tx, sslPolicyId, userId) + if err != nil { + return err + } + + var certRefs = policy.DecodeCerts() + for _, certRef := range certRefs { + if certRef.CertId > 0 { + // 修改证书 + err = SharedSSLCertDAO.UpdateCertUser(tx, certRef.CertId, userId) + if err != nil { + return err + } + } + } + } + } + } + + return this.NotifyUpdate(tx, serverId) +} + // NotifyUpdate 同步服务所在的集群 func (this *ServerDAO) NotifyUpdate(tx *dbs.Tx, serverId int64) error { // 创建任务 diff --git a/internal/db/models/server_model_ext.go b/internal/db/models/server_model_ext.go index 6f482584..4b3aec19 100644 --- a/internal/db/models/server_model_ext.go +++ b/internal/db/models/server_model_ext.go @@ -42,10 +42,38 @@ func (this *Server) DecodeHTTPPorts() (ports []int) { return } +// DecodeHTTPS 解析HTTPS设置 +func (this *Server) DecodeHTTPS() *serverconfigs.HTTPSProtocolConfig { + if len(this.Https) == 0 { + return nil + } + + var config = &serverconfigs.HTTPSProtocolConfig{} + err := json.Unmarshal(this.Https, config) + if err != nil { + remotelogs.Error("Server_DecodeHTTPS", err.Error()) + } + return config +} + +// DecodeTLS 解析TLS设置 +func (this *Server) DecodeTLS() *serverconfigs.TLSProtocolConfig { + if len(this.Tls) == 0 { + return nil + } + + var config = &serverconfigs.TLSProtocolConfig{} + err := json.Unmarshal(this.Tls, config) + if err != nil { + remotelogs.Error("Server_DecodeTLS", err.Error()) + } + return config +} + // DecodeHTTPSPorts 获取HTTPS所有端口 func (this *Server) DecodeHTTPSPorts() (ports []int) { if len(this.Https) > 0 { - config := &serverconfigs.HTTPSProtocolConfig{} + var config = &serverconfigs.HTTPSProtocolConfig{} err := json.Unmarshal(this.Https, config) if err != nil { return nil diff --git a/internal/db/models/ssl_cert_dao.go b/internal/db/models/ssl_cert_dao.go index 7deb42fc..d2a88265 100644 --- a/internal/db/models/ssl_cert_dao.go +++ b/internal/db/models/ssl_cert_dao.go @@ -399,6 +399,17 @@ func (this *SSLCertDAO) CheckUserCert(tx *dbs.Tx, certId int64, userId int64) er return nil } +// UpdateCertUser 修改证书所属用户 +func (this *SSLCertDAO) UpdateCertUser(tx *dbs.Tx, certId int64, userId int64) error { + if certId <= 0 || userId <= 0 { + return nil + } + return this.Query(tx). + Pk(certId). + Set("userId", userId). + UpdateQuickly() +} + // ListCertsToUpdateOCSP 查找需要更新OCSP的证书 func (this *SSLCertDAO) ListCertsToUpdateOCSP(tx *dbs.Tx, maxTries int, size int64) (result []*SSLCert, err error) { var nowTime = time.Now().Unix() diff --git a/internal/db/models/ssl_policy_dao.go b/internal/db/models/ssl_policy_dao.go index b4132c4d..7ddb272b 100644 --- a/internal/db/models/ssl_policy_dao.go +++ b/internal/db/models/ssl_policy_dao.go @@ -306,6 +306,18 @@ func (this *SSLPolicyDAO) CheckUserPolicy(tx *dbs.Tx, userId int64, policyId int return nil } +// UpdatePolicyUser 修改策略所属用户 +func (this *SSLPolicyDAO) UpdatePolicyUser(tx *dbs.Tx, policyId int64, userId int64) error { + if policyId <= 0 || userId <= 0 { + return nil + } + + return this.Query(tx). + Pk(policyId). + Set("userId", userId). + UpdateQuickly() +} + // NotifyUpdate 通知更新 func (this *SSLPolicyDAO) NotifyUpdate(tx *dbs.Tx, policyId int64) error { serverIds, err := SharedServerDAO.FindAllEnabledServerIdsWithSSLPolicyIds(tx, []int64{policyId}) diff --git a/internal/db/models/ssl_policy_model_ext.go b/internal/db/models/ssl_policy_model_ext.go index 2640e7f9..18592875 100644 --- a/internal/db/models/ssl_policy_model_ext.go +++ b/internal/db/models/ssl_policy_model_ext.go @@ -1 +1,20 @@ package models + +import ( + "encoding/json" + "github.com/TeaOSLab/EdgeAPI/internal/remotelogs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" +) + +func (this *SSLPolicy) DecodeCerts() []*sslconfigs.SSLCertRef { + if len(this.Certs) == 0 { + return nil + } + + var refs = []*sslconfigs.SSLCertRef{} + err := json.Unmarshal(this.Certs, &refs) + if err != nil { + remotelogs.Error("SSLPolicy_DecodeCerts", err.Error()) + } + return refs +} diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index ce134f7e..40ab7078 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -12,6 +12,7 @@ import ( "github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" @@ -770,7 +771,9 @@ func (this *ServerService) ListEnabledServersMatch(ctx context.Context, req *pb. var tx = this.NullTx() + var fromUser = false if userId > 0 { + fromUser = true req.UserId = userId } @@ -817,15 +820,17 @@ func (this *ServerService) ListEnabledServersMatch(ctx context.Context, req *pb. } // 用户 - user, err := models.SharedUserDAO.FindEnabledBasicUser(tx, int64(server.UserId)) - if err != nil { - return nil, err - } var pbUser *pb.User = nil - if user != nil { - pbUser = &pb.User{ - Id: int64(user.Id), - Fullname: user.Fullname, + if !fromUser { + user, err := models.SharedUserDAO.FindEnabledBasicUser(tx, int64(server.UserId)) + if err != nil { + return nil, err + } + if user != nil { + pbUser = &pb.User{ + Id: int64(user.Id), + Fullname: user.Fullname, + } } } @@ -2239,3 +2244,28 @@ func (this *ServerService) ComposeServerConfig(ctx context.Context, req *pb.Comp } return &pb.ComposeServerConfigResponse{ServerConfigJSON: configJSON}, nil } + +// UpdateServerUser 修改服务所属用户 +func (this *ServerService) UpdateServerUser(ctx context.Context, req *pb.UpdateServerUserRequest) (*pb.RPCSuccess, error) { + _, err := this.ValidateAdmin(ctx) + if err != nil { + return nil, err + } + + if req.ServerId <= 0 { + return nil, errors.New("invalid serverId") + } + + if req.UserId <= 0 { + return nil, errors.New("invalid userId") + } + + err = this.RunTx(func(tx *dbs.Tx) error { + return models.SharedServerDAO.UpdateServerUserId(tx, req.ServerId, req.UserId) + }) + if err != nil { + return nil, err + } + + return this.Success() +}