diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index 46c07f24..48de096c 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -759,7 +759,8 @@ func (this *ServerDAO) CountAllEnabledServers(tx *dbs.Tx) (int64, error) { // CountAllEnabledServersMatch 计算所有可用服务数量 // 参数: -// groupId 分组ID,如果为-1,则搜索没有分组的服务 +// +// groupId 分组ID,如果为-1,则搜索没有分组的服务 func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag configutils.BoolState, protocolFamilies []string) (int64, error) { query := this.Query(tx). State(ServerStateEnabled) @@ -810,7 +811,8 @@ func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, ke // ListEnabledServersMatch 列出单页的服务 // 参数: -// groupId 分组ID,如果为-1,则搜索没有分组的服务 +// +// groupId 分组ID,如果为-1,则搜索没有分组的服务 func (this *ServerDAO) ListEnabledServersMatch(tx *dbs.Tx, offset int64, size int64, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag int32, protocolFamilies []string, order string) (result []*Server, err error) { query := this.Query(tx). State(ServerStateEnabled). @@ -1629,6 +1631,33 @@ func (this *ServerDAO) GenerateServerDNSName(tx *dbs.Tx, serverId int64) (string return dnsName, nil } +// UpdateServerDNSName 设置CNAME +func (this *ServerDAO) UpdateServerDNSName(tx *dbs.Tx, serverId int64, dnsName string) error { + if serverId <= 0 || len(dnsName) == 0 { + return nil + } + dnsName = strings.ToLower(dnsName) + err := this.Query(tx). + Pk(serverId). + Set("dnsName", dnsName). + UpdateQuickly() + if err != nil { + return err + } + + return this.NotifyDNSUpdate(tx, serverId) +} + +// FindServerIdWithDNSName 根据CNAME查询服务ID +func (this *ServerDAO) FindServerIdWithDNSName(tx *dbs.Tx, clusterId int64, dnsName string) (int64, error) { + return this.Query(tx). + ResultPk(). + State(ServerStateEnabled). + Attr("clusterId", clusterId). + Attr("dnsName", dnsName). + FindInt64Col(0) +} + // FindServerClusterId 查询当前服务的集群ID func (this *ServerDAO) FindServerClusterId(tx *dbs.Tx, serverId int64) (int64, error) { return this.Query(tx). diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index 4223497d..d09808a7 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -15,6 +15,8 @@ import ( "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" timeutil "github.com/iwind/TeaGo/utils/time" + "regexp" + "strings" ) type ServerService struct { @@ -643,8 +645,8 @@ func (this *ServerService) UpdateServerDNS(ctx context.Context, req *pb.UpdateSe return this.Success() } -// RegenerateServerCNAME 重新生成CNAME -func (this *ServerService) RegenerateServerCNAME(ctx context.Context, req *pb.RegenerateServerCNAMERequest) (*pb.RPCSuccess, error) { +// RegenerateServerDNSName 重新生成CNAME +func (this *ServerService) RegenerateServerDNSName(ctx context.Context, req *pb.RegenerateServerDNSNameRequest) (*pb.RPCSuccess, error) { _, err := this.ValidateAdmin(ctx) if err != nil { return nil, err @@ -658,6 +660,81 @@ func (this *ServerService) RegenerateServerCNAME(ctx context.Context, req *pb.Re return this.Success() } +// UpdateServerDNSName 修改服务的CNAME +func (this *ServerService) UpdateServerDNSName(ctx context.Context, req *pb.UpdateServerDNSNameRequest) (*pb.RPCSuccess, error) { + _, err := this.ValidateAdmin(ctx) + if err != nil { + return nil, err + } + + var tx = this.NullTx() + var dnsName = req.DnsName + + if req.ServerId <= 0 { + return nil, errors.New("invalid 'serverId'") + } + + if len(dnsName) == 0 { + return nil, errors.New("'dnsName' must not be empty") + } + + // 处理格式 + dnsName = strings.ToLower(dnsName) + const maxLen = 30 + if len(dnsName) > maxLen { + return nil, errors.New("'dnsName' too long than " + types.String(maxLen)) + } + if !regexp.MustCompile(`^[a-z0-9]{1,` + types.String(maxLen) + `}$`).MatchString(dnsName) { + return nil, errors.New("invalid 'dnsName': contains invalid character(s)") + } + + // 检查是否被使用 + clusterId, err := models.SharedServerDAO.FindServerClusterId(tx, req.ServerId) + if err != nil { + return nil, err + } + if clusterId <= 0 { + return nil, errors.New("the server is not belong to any cluster") + } + + serverId, err := models.SharedServerDAO.FindServerIdWithDNSName(tx, clusterId, dnsName) + if err != nil { + return nil, err + } + if serverId > 0 && serverId != req.ServerId { + return nil, errors.New("the 'dnsName': " + dnsName + " has already been used") + } + + err = models.SharedServerDAO.UpdateServerDNSName(tx, req.ServerId, dnsName) + if err != nil { + return nil, err + } + + return this.Success() +} + +// FindServerIdWithDNSName 使用CNAME查找服务 +func (this *ServerService) FindServerIdWithDNSName(ctx context.Context, req *pb.FindServerIdWithDNSNameRequest) (*pb.FindServerIdWithDNSNameResponse, error) { + _, err := this.ValidateAdmin(ctx) + if err != nil { + return nil, err + } + + if len(req.DnsName) == 0 { + return nil, errors.New("'dnsName' must not be empty") + } + + var tx = this.NullTx() + serverId, err := models.SharedServerDAO.FindServerIdWithDNSName(tx, req.NodeClusterId, req.DnsName) + if err != nil { + return nil, err + } + + return &pb.FindServerIdWithDNSNameResponse{ + ServerId: serverId, + }, nil +} + // CountAllEnabledServersMatch 计算服务数量 func (this *ServerService) CountAllEnabledServersMatch(ctx context.Context, req *pb.CountAllEnabledServersMatchRequest) (*pb.RPCCountResponse, error) { // 校验请求