diff --git a/internal/db/models/dns_domain_dao.go b/internal/db/models/dns_domain_dao.go index ef415d9b..5b124c7d 100644 --- a/internal/db/models/dns_domain_dao.go +++ b/internal/db/models/dns_domain_dao.go @@ -7,7 +7,9 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" + "strings" "time" ) @@ -213,3 +215,34 @@ func (this *DNSDomainDAO) ExistAvailableDomains() (bool, error) { Where("providerId IN (" + subQuery + ")"). Exist() } + +// 检查域名解析记录是否存在 +func (this *DNSDomainDAO) ExistDomainRecord(domainId int64, recordName string, recordType string, recordRoute string, recordValue string) (bool, error) { + query := maps.Map{ + "name": recordName, + "type": recordType, + } + if len(recordRoute) > 0 { + query["route"] = recordRoute + } + if len(recordValue) > 0 { + query["value"] = recordValue + + // CNAME兼容点(.)符号 + if recordType == "CNAME" && !strings.HasSuffix(recordValue, ".") { + b, err := this.ExistDomainRecord(domainId, recordName, recordType, recordRoute, recordValue+".") + if err != nil { + return false, err + } + if b { + return true, nil + } + } + } + recordType = strings.ToUpper(recordType) + return this.Query(). + Pk(domainId). + Where("JSON_CONTAINS(records, :query)"). + Param("query", query.AsJSON()). + Exist() +} diff --git a/internal/db/models/dns_domain_dao_test.go b/internal/db/models/dns_domain_dao_test.go index 97c24b56..36ae0232 100644 --- a/internal/db/models/dns_domain_dao_test.go +++ b/internal/db/models/dns_domain_dao_test.go @@ -2,4 +2,36 @@ package models import ( _ "github.com/go-sql-driver/mysql" + "testing" ) + +func TestDNSDomainDAO_ExistDomainRecord(t *testing.T) { + { + b, err := NewDNSDomainDAO().ExistDomainRecord(1, "mycluster", "A", "", "") + if err != nil { + t.Fatal(err) + } + t.Log(b) + } + { + b, err := NewDNSDomainDAO().ExistDomainRecord(2, "mycluster", "A", "", "") + if err != nil { + t.Fatal(err) + } + t.Log(b) + } + { + b, err := NewDNSDomainDAO().ExistDomainRecord(2, "mycluster", "MX", "", "") + if err != nil { + t.Fatal(err) + } + t.Log(b) + } + { + b, err := NewDNSDomainDAO().ExistDomainRecord(2, "mycluster123", "A", "", "") + if err != nil { + t.Fatal(err) + } + t.Log(b) + } +} diff --git a/internal/db/models/node_dao.go b/internal/db/models/node_dao.go index 4b57d8ef..d405da53 100644 --- a/internal/db/models/node_dao.go +++ b/internal/db/models/node_dao.go @@ -685,6 +685,19 @@ func (this *NodeDAO) FindAllEnabledNodesDNSWithClusterId(clusterId int64) (resul return } +// 计算一个集群的节点DNS数量 +func (this *NodeDAO) CountAllEnabledNodesDNSWithClusterId(clusterId int64) (result int64, err error) { + return this.Query(). + State(NodeStateEnabled). + Attr("clusterId", clusterId). + Attr("isOn", true). + Attr("isUp", true). + Result("id", "name", "dnsRoutes", "isOn"). + DescPk(). + Slice(&result). + Count() +} + // 获取单个节点的DNS信息 func (this *NodeDAO) FindEnabledNodeDNS(nodeId int64) (*Node, error) { one, err := this.Query(). diff --git a/internal/rpc/services/service_dns_domain.go b/internal/rpc/services/service_dns_domain.go index 3422222d..0ebaf63e 100644 --- a/internal/rpc/services/service_dns_domain.go +++ b/internal/rpc/services/service_dns_domain.go @@ -300,13 +300,18 @@ func (this *DNSDomainService) convertDomainToPB(domain *models.DNSDomain) (*pb.D if err != nil { return nil, err } + countClusters := len(clusters) + countAllNodes1 := int64(0) + countAllServers1 := int64(0) for _, cluster := range clusters { - _, nodeRecords, serverRecords, nodesChanged2, serversChanged2, err := this.findClusterDNSChanges(cluster, records, domain.Name) + _, nodeRecords, serverRecords, countAllNodes, countAllServers, nodesChanged2, serversChanged2, err := this.findClusterDNSChanges(cluster, records, domain.Name) if err != nil { return nil, err } countNodeRecords += len(nodeRecords) countServerRecords += len(serverRecords) + countAllNodes1 += countAllNodes + countAllServers1 += countAllServers if nodesChanged2 { nodesChanged = true } @@ -339,6 +344,9 @@ func (this *DNSDomainService) convertDomainToPB(domain *models.DNSDomain) (*pb.D CountServerRecords: int64(countServerRecords), ServersChanged: serversChanged, Routes: pbRoutes, + CountNodeClusters: int64(countClusters), + CountAllNodes: countAllNodes1, + CountAllServers: countAllServers1, }, nil } @@ -354,7 +362,7 @@ func (this *DNSDomainService) convertRecordToPB(record *dnsclients.Record) *pb.D } // 检查集群节点变化 -func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, records []*dnsclients.Record, domainName string) (result []maps.Map, doneNodeRecords []*dnsclients.Record, doneServerRecords []*dnsclients.Record, nodesChanged bool, serversChanged bool, err error) { +func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, records []*dnsclients.Record, domainName string) (result []maps.Map, doneNodeRecords []*dnsclients.Record, doneServerRecords []*dnsclients.Record, countAllNodes int64, countAllServers int64, nodesChanged bool, serversChanged bool, err error) { clusterId := int64(cluster.Id) clusterDnsName := cluster.DnsName clusterDomain := clusterDnsName + "." + domainName @@ -362,8 +370,9 @@ func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, // 节点域名 nodes, err := models.SharedNodeDAO.FindAllEnabledNodesDNSWithClusterId(clusterId) if err != nil { - return nil, nil, nil, false, false, err + return nil, nil, nil, 0, 0, false, false, err } + countAllNodes = int64(len(nodes)) nodeRecords := []*dnsclients.Record{} // 之所以用数组再存一遍,是因为dnsName可能会重复 nodeRecordMapping := map[string]*dnsclients.Record{} // value_route => *Record for _, record := range records { @@ -378,14 +387,14 @@ func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, for _, node := range nodes { ipAddr, err := models.SharedNodeIPAddressDAO.FindFirstNodeIPAddress(int64(node.Id)) if err != nil { - return nil, nil, nil, false, false, err + return nil, nil, nil, 0, 0, false, false, err } if len(ipAddr) == 0 { continue } routeCodes, err := node.DNSRouteCodesForDomainId(int64(cluster.DnsDomainId)) if err != nil { - return nil, nil, nil, false, false, err + return nil, nil, nil, 0, 0, false, false, err } if len(routeCodes) == 0 { continue @@ -427,8 +436,9 @@ func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, // 服务域名 servers, err := models.SharedServerDAO.FindAllServersDNSWithClusterId(clusterId) if err != nil { - return nil, nil, nil, false, false, err + return nil, nil, nil, 0, 0, false, false, err } + countAllServers = int64(len(servers)) serverRecords := []*dnsclients.Record{} // 之所以用数组再存一遍,是因为dnsName可能会重复 serverRecordsMap := map[string]*dnsclients.Record{} // dnsName => *Record for _, record := range records { @@ -443,7 +453,7 @@ func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, for _, server := range servers { dnsName := server.DnsName if len(dnsName) == 0 { - return nil, nil, nil, false, false, errors.New("server '" + numberutils.FormatInt64(int64(server.Id)) + "' 'dnsName' should not empty") + return nil, nil, nil, 0, 0, false, false, errors.New("server '" + numberutils.FormatInt64(int64(server.Id)) + "' 'dnsName' should not empty") } serverDNSNames = append(serverDNSNames, dnsName) record, ok := serverRecordsMap[dnsName] @@ -589,7 +599,7 @@ func (this *DNSDomainService) syncClusterDNS(req *pb.SyncDNSDomainDataRequest) ( // 对比变化 allChanges := []maps.Map{} for _, cluster := range clusters { - changes, _, _, _, _, err := this.findClusterDNSChanges(cluster, records, domainName) + changes, _, _, _, _, _, _, err := this.findClusterDNSChanges(cluster, records, domainName) if err != nil { return nil, err } @@ -639,3 +649,17 @@ func (this *DNSDomainService) syncClusterDNS(req *pb.SyncDNSDomainDataRequest) ( IsOk: true, }, nil } + +// 检查域名是否在记录中 +func (this *DNSDomainService) ExistDNSDomainRecord(ctx context.Context, req *pb.ExistDNSDomainRecordRequest) (*pb.ExistDNSDomainRecordResponse, error) { + _, _, err := this.ValidateAdminAndUser(ctx, 0, 0) + if err != nil { + return nil, err + } + + isOk, err := models.SharedDNSDomainDAO.ExistDomainRecord(req.DnsDomainId, req.Name, req.Type, req.Route, req.Value) + if err != nil { + return nil, err + } + return &pb.ExistDNSDomainRecordResponse{IsOk: isOk}, nil +} diff --git a/internal/rpc/services/service_node_cluster.go b/internal/rpc/services/service_node_cluster.go index 3592127d..79199c5b 100644 --- a/internal/rpc/services/service_node_cluster.go +++ b/internal/rpc/services/service_node_cluster.go @@ -88,6 +88,8 @@ func (this *NodeClusterService) FindEnabledNodeCluster(ctx context.Context, req Secret: cluster.Secret, HttpCachePolicyId: int64(cluster.CachePolicyId), HttpFirewallPolicyId: int64(cluster.HttpFirewallPolicyId), + DnsName: cluster.DnsName, + DnsDomainId: int64(cluster.DnsDomainId), }}, nil } @@ -463,6 +465,31 @@ func (this *NodeClusterService) CountAllEnabledNodeClustersWithDNSDomainId(ctx c return this.SuccessCount(count) } +// 查找使用某个域名的所有集群 +func (this *NodeClusterService) FindAllEnabledNodeClustersWithDNSDomainId(ctx context.Context, req *pb.FindAllEnabledNodeClustersWithDNSDomainIdRequest) (*pb.FindAllEnabledNodeClustersWithDNSDomainIdResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + clusters, err := models.SharedNodeClusterDAO.FindAllEnabledClustersWithDNSDomainId(req.DnsDomainId) + if err != nil { + return nil, err + } + + result := []*pb.NodeCluster{} + for _, cluster := range clusters { + result = append(result, &pb.NodeCluster{ + Id: int64(cluster.Id), + Name: cluster.Name, + DnsName: cluster.DnsName, + DnsDomainId: int64(cluster.DnsDomainId), + }) + } + return &pb.FindAllEnabledNodeClustersWithDNSDomainIdResponse{NodeClusters: result}, nil +} + // 检查集群域名是否已经被使用 func (this *NodeClusterService) CheckNodeClusterDNSName(ctx context.Context, req *pb.CheckNodeClusterDNSNameRequest) (*pb.CheckNodeClusterDNSNameResponse, error) { // 校验请求 @@ -524,7 +551,7 @@ func (this *NodeClusterService) CheckNodeClusterDNSChanges(ctx context.Context, } service := &DNSDomainService{} - changes, _, _, _, _, err := service.findClusterDNSChanges(cluster, records, domain.Name) + changes, _, _, _, _, _, _, err := service.findClusterDNSChanges(cluster, records, domain.Name) if err != nil { return nil, err }