diff --git a/internal/db/models/dns/dns_domain_dao.go b/internal/db/models/dns/dns_domain_dao.go index f35d3dba..6673703d 100644 --- a/internal/db/models/dns/dns_domain_dao.go +++ b/internal/db/models/dns/dns_domain_dao.go @@ -86,6 +86,7 @@ func (this *DNSDomainDAO) CreateDomain(tx *dbs.Tx, adminId int64, userId int64, op.Name = name op.State = DNSDomainStateEnabled op.IsOn = true + op.IsUp = true err := this.Save(tx, op) if err != nil { return 0, err @@ -247,3 +248,25 @@ func (this *DNSDomainDAO) ExistDomainRecord(tx *dbs.Tx, domainId int64, recordNa Param("query", query.AsJSON()). Exist() } + +// FindEnabledDomainWithName 根据名称查找某个域名 +func (this *DNSDomainDAO) FindEnabledDomainWithName(tx *dbs.Tx, providerId int64, domainName string) (*DNSDomain, error) { + one, err := this.Query(tx). + State(DNSDomainStateEnabled). + Attr("isOn", true). + Attr("providerId", providerId). + Attr("name", domainName). + Find() + if one != nil { + return one.(*DNSDomain), nil + } + return nil, err +} + +// UpdateDomainIsUp 设置是否在线 +func (this *DNSDomainDAO) UpdateDomainIsUp(tx *dbs.Tx, domainId int64, isUp bool) error { + return this.Query(tx). + Pk(domainId). + Set("isUp", isUp). + UpdateQuickly() +} diff --git a/internal/db/models/dns/dns_domain_model.go b/internal/db/models/dns/dns_domain_model.go index cfd5cbe1..f4519f4d 100644 --- a/internal/db/models/dns/dns_domain_model.go +++ b/internal/db/models/dns/dns_domain_model.go @@ -1,6 +1,6 @@ package dns -// 管理的域名 +// DNSDomain 管理的域名 type DNSDomain struct { Id uint32 `field:"id"` // ID AdminId uint32 `field:"adminId"` // 管理员ID @@ -14,6 +14,7 @@ type DNSDomain struct { Data string `field:"data"` // 原始数据信息 Records string `field:"records"` // 所有解析记录 Routes string `field:"routes"` // 线路数据 + IsUp uint8 `field:"isUp"` // 是否在线 State uint8 `field:"state"` // 状态 } @@ -30,6 +31,7 @@ type DNSDomainOperator struct { Data interface{} // 原始数据信息 Records interface{} // 所有解析记录 Routes interface{} // 线路数据 + IsUp interface{} // 是否在线 State interface{} // 状态 } diff --git a/internal/dnsclients/provider_alidns.go b/internal/dnsclients/provider_alidns.go index 77e67c80..b1b09cc5 100644 --- a/internal/dnsclients/provider_alidns.go +++ b/internal/dnsclients/provider_alidns.go @@ -31,6 +31,34 @@ func (this *AliDNSProvider) Auth(params maps.Map) error { return nil } +// GetDomains 获取所有域名列表 +func (this *AliDNSProvider) GetDomains() (domains []string, err error) { + pageNumber := 1 + size := 100 + + for { + req := alidns.CreateDescribeDomainsRequest() + req.PageNumber = requests.NewInteger(pageNumber) + req.PageSize = requests.NewInteger(size) + resp := alidns.CreateDescribeDomainsResponse() + err = this.doAPI(req, resp) + if err != nil { + return nil, err + } + + for _, domain := range resp.Domains.Domain { + domains = append(domains, domain.DomainName) + } + + pageNumber++ + if int64((pageNumber-1)*size) >= resp.TotalCount { + break + } + } + + return +} + // GetRecords 获取域名列表 func (this *AliDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { pageNumber := 1 diff --git a/internal/dnsclients/provider_alidns_test.go b/internal/dnsclients/provider_alidns_test.go index cd6bd8ba..938542c7 100644 --- a/internal/dnsclients/provider_alidns_test.go +++ b/internal/dnsclients/provider_alidns_test.go @@ -11,6 +11,14 @@ import ( "testing" ) +func TestAliDNSProvider_GetDomains(t *testing.T) { + provider, err := testAliDNSProvider() + if err != nil { + t.Fatal(err) + } + t.Log(provider.GetDomains()) +} + func TestAliDNSProvider_GetRecords(t *testing.T) { provider, err := testAliDNSProvider() if err != nil { diff --git a/internal/dnsclients/provider_cloud_flare.go b/internal/dnsclients/provider_cloud_flare.go index f10c4dbe..898f9738 100644 --- a/internal/dnsclients/provider_cloud_flare.go +++ b/internal/dnsclients/provider_cloud_flare.go @@ -59,6 +59,21 @@ func (this *CloudFlareProvider) Auth(params maps.Map) error { return nil } +// GetDomains 获取所有域名列表 +func (this *CloudFlareProvider) GetDomains() (domains []string, err error) { + resp := new(cloudflare.ZonesResponse) + err = this.doAPI(http.MethodGet, "zones", map[string]string{}, nil, resp) + if err != nil { + return nil, err + } + + for _, zone := range resp.Result { + domains = append(domains, zone.Name) + } + + return +} + // GetRecords 获取域名解析记录列表 func (this *CloudFlareProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { zoneId, err := this.findZoneIdWithDomain(domain) diff --git a/internal/dnsclients/provider_cloud_flare_test.go b/internal/dnsclients/provider_cloud_flare_test.go index 1e1f11e7..50ef5d2c 100644 --- a/internal/dnsclients/provider_cloud_flare_test.go +++ b/internal/dnsclients/provider_cloud_flare_test.go @@ -12,6 +12,14 @@ import ( "testing" ) +func TestCloudFlareProvider_GetDomains(t *testing.T) { + provider, err := testCloudFlareProvider() + if err != nil { + t.Fatal(err) + } + t.Log(provider.GetDomains()) +} + func TestCloudFlareProvider_GetRecords(t *testing.T) { provider, err := testCloudFlareProvider() if err != nil { diff --git a/internal/dnsclients/provider_custom_http.go b/internal/dnsclients/provider_custom_http.go index 9821cc42..ccff06ba 100644 --- a/internal/dnsclients/provider_custom_http.go +++ b/internal/dnsclients/provider_custom_http.go @@ -49,6 +49,16 @@ func (this *CustomHTTPProvider) Auth(params maps.Map) error { return nil } +// GetDomains 获取所有域名列表 +func (this *CustomHTTPProvider) GetDomains() (domains []string, err error) { + resp, err := this.post(maps.Map{}) + if err != nil { + return nil, err + } + err = json.Unmarshal(resp, &domains) + return +} + // GetRecords 获取域名解析记录列表 func (this *CustomHTTPProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { resp, err := this.post(maps.Map{ diff --git a/internal/dnsclients/provider_custom_http_test.go b/internal/dnsclients/provider_custom_http_test.go index 7a7e62af..2efd0a39 100644 --- a/internal/dnsclients/provider_custom_http_test.go +++ b/internal/dnsclients/provider_custom_http_test.go @@ -7,6 +7,22 @@ import ( "testing" ) +func TestCustomHTTPProvider_GetDomains(t *testing.T) { + provider := CustomHTTPProvider{} + err := provider.Auth(maps.Map{ + "url": "http://127.0.0.1:2345/dns", + "secret": "123456", + }) + if err != nil { + t.Fatal(err) + } + domains, err := provider.GetDomains() + if err != nil { + t.Fatal(err) + } + t.Log(domains) +} + func TestCustomHTTPProvider_AddRecord(t *testing.T) { provider := CustomHTTPProvider{} err := provider.Auth(maps.Map{ diff --git a/internal/dnsclients/provider_dnspod.go b/internal/dnsclients/provider_dnspod.go index abe231f9..06ee5bb0 100644 --- a/internal/dnsclients/provider_dnspod.go +++ b/internal/dnsclients/provider_dnspod.go @@ -35,6 +35,40 @@ func (this *DNSPodProvider) Auth(params maps.Map) error { return nil } +// GetDomains 获取所有域名列表 +func (this *DNSPodProvider) GetDomains() (domains []string, err error) { + offset := 0 + size := 100 + for { + domainsResp, err := this.post("/Domain.list", map[string]string{ + "offset": numberutils.FormatInt(offset), + "length": numberutils.FormatInt(size), + }) + if err != nil { + return nil, err + } + offset += size + + domainsSlice := domainsResp.GetSlice("domains") + if len(domainsSlice) == 0 { + break + } + + for _, domain := range domainsSlice { + domainMap := maps.NewMap(domain) + domains = append(domains, domainMap.GetString("name")) + } + + // 检查是否到头 + info := domainsResp.GetMap("info") + recordTotal := info.GetInt("record_total") + if offset >= recordTotal { + break + } + } + return +} + // GetRecords 获取域名列表 func (this *DNSPodProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { offset := 0 diff --git a/internal/dnsclients/provider_dnspod_test.go b/internal/dnsclients/provider_dnspod_test.go index 33f5a246..3e415a5e 100644 --- a/internal/dnsclients/provider_dnspod_test.go +++ b/internal/dnsclients/provider_dnspod_test.go @@ -9,6 +9,18 @@ import ( "testing" ) +func TestDNSPodProvider_GetDomains(t *testing.T) { + provider, err := testDNSPodProvider() + if err != nil { + t.Fatal(err) + } + domains, err := provider.GetDomains() + if err != nil { + t.Fatal(err) + } + t.Log(domains) +} + func TestDNSPodProvider_GetRoutes(t *testing.T) { provider, err := testDNSPodProvider() if err != nil { diff --git a/internal/dnsclients/provider_huawei_dns.go b/internal/dnsclients/provider_huawei_dns.go index ba9d1105..e914c19d 100644 --- a/internal/dnsclients/provider_huawei_dns.go +++ b/internal/dnsclients/provider_huawei_dns.go @@ -55,6 +55,21 @@ func (this *HuaweiDNSProvider) Auth(params maps.Map) error { return nil } +// GetDomains 获取所有域名列表 +func (this *HuaweiDNSProvider) GetDomains() (domains []string, err error) { + var resp = new(huaweidns.ZonesResponse) + err = this.doAPI(http.MethodGet, "/v2/zones", map[string]string{}, nil, resp) + if err != nil { + return nil, err + } + + for _, zone := range resp.Zones { + zone.Name = strings.TrimSuffix(zone.Name, ".") + domains = append(domains, zone.Name) + } + return +} + // GetRecords 获取域名解析记录列表 func (this *HuaweiDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { zoneId, err := this.findZoneIdWithDomain(domain) diff --git a/internal/dnsclients/provider_huawei_dns_test.go b/internal/dnsclients/provider_huawei_dns_test.go index 1e11707f..11ce71cf 100644 --- a/internal/dnsclients/provider_huawei_dns_test.go +++ b/internal/dnsclients/provider_huawei_dns_test.go @@ -11,6 +11,18 @@ import ( "testing" ) +func TestHuaweiDNSProvider_GetDomains(t *testing.T) { + provider, err := testHuaweiDNSProvider() + if err != nil { + t.Fatal(err) + } + domains, err := provider.GetDomains() + if err != nil { + t.Fatal(err) + } + t.Log("domains:", domains) +} + func TestHuaweiDNSProvider_GetRecords(t *testing.T) { provider, err := testHuaweiDNSProvider() if err != nil { diff --git a/internal/dnsclients/provider_interface.go b/internal/dnsclients/provider_interface.go index 8700660c..cd3dd438 100644 --- a/internal/dnsclients/provider_interface.go +++ b/internal/dnsclients/provider_interface.go @@ -10,6 +10,9 @@ type ProviderInterface interface { // Auth 认证 Auth(params maps.Map) error + // GetDomains 获取所有域名列表 + GetDomains() (domains []string, err error) + // GetRecords 获取域名解析记录列表 GetRecords(domain string) (records []*dnstypes.Record, err error) diff --git a/internal/dnsclients/provider_local_edge_dns.go b/internal/dnsclients/provider_local_edge_dns.go index b3829463..e471fc69 100644 --- a/internal/dnsclients/provider_local_edge_dns.go +++ b/internal/dnsclients/provider_local_edge_dns.go @@ -34,6 +34,19 @@ func (this *LocalEdgeDNSProvider) Auth(params maps.Map) error { return nil } +// GetDomains 获取所有域名列表 +func (this *LocalEdgeDNSProvider) GetDomains() (domains []string, err error) { + var tx *dbs.Tx + domainOnes, err := nameservers.SharedNSDomainDAO.ListEnabledDomains(tx, this.clusterId, 0, "", 0, 1000) + if err != nil { + return nil, err + } + for _, domain := range domainOnes { + domains = append(domains, domain.Name) + } + return +} + // GetRecords 获取域名解析记录列表 func (this *LocalEdgeDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { var tx *dbs.Tx diff --git a/internal/dnsclients/provider_local_edge_dns_test.go b/internal/dnsclients/provider_local_edge_dns_test.go index b8c1a292..ed6575b9 100644 --- a/internal/dnsclients/provider_local_edge_dns_test.go +++ b/internal/dnsclients/provider_local_edge_dns_test.go @@ -13,6 +13,24 @@ import ( const testClusterId = 7 +func TestLocalEdgeDNSProvider_GetDomains(t *testing.T) { + dbs.NotifyReady() + + provider := &dnsclients.LocalEdgeDNSProvider{} + err := provider.Auth(maps.Map{ + "clusterId": testClusterId, + }) + if err != nil { + t.Fatal(err) + } + + domains, err := provider.GetDomains() + if err != nil { + t.Fatal(err) + } + t.Log("domains:", domains) +} + func TestLocalEdgeDNSProvider_GetRecords(t *testing.T) { dbs.NotifyReady() diff --git a/internal/dnsclients/provider_user_edge_dns.go b/internal/dnsclients/provider_user_edge_dns.go index 364e477f..79a0ad82 100644 --- a/internal/dnsclients/provider_user_edge_dns.go +++ b/internal/dnsclients/provider_user_edge_dns.go @@ -16,6 +16,12 @@ func (this *UserEdgeDNSProvider) Auth(params maps.Map) error { return nil } +// GetDomains 获取所有域名列表 +func (this *UserEdgeDNSProvider) GetDomains() (domains []string, err error) { + // TODO + return +} + // GetRecords 获取域名解析记录列表 func (this *UserEdgeDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { // TODO diff --git a/internal/rpc/services/service_dns_domain.go b/internal/rpc/services/service_dns_domain.go index fbe3034a..b2b83dab 100644 --- a/internal/rpc/services/service_dns_domain.go +++ b/internal/rpc/services/service_dns_domain.go @@ -364,6 +364,7 @@ func (this *DNSDomainService) convertDomainToPB(domain *dns.DNSDomain) (*pb.DNSD ProviderId: int64(domain.ProviderId), Name: domain.Name, IsOn: domain.IsOn == 1, + IsUp: domain.IsUp == 1, DataUpdatedAt: int64(domain.DataUpdatedAt), CountNodeRecords: int64(countNodeRecords), NodesChanged: nodesChanged, @@ -693,7 +694,7 @@ func (this *DNSDomainService) syncClusterDNS(req *pb.SyncDNSDomainDataRequest) ( }, nil } -// 检查域名是否在记录中 +// ExistDNSDomainRecord 检查域名是否在记录中 func (this *DNSDomainService) ExistDNSDomainRecord(ctx context.Context, req *pb.ExistDNSDomainRecordRequest) (*pb.ExistDNSDomainRecordResponse, error) { _, _, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { @@ -708,3 +709,84 @@ func (this *DNSDomainService) ExistDNSDomainRecord(ctx context.Context, req *pb. } return &pb.ExistDNSDomainRecordResponse{IsOk: isOk}, nil } + +// SyncDNSDomainsFromProvider 从服务商同步域名 +func (this *DNSDomainService) SyncDNSDomainsFromProvider(ctx context.Context, req *pb.SyncDNSDomainsFromProviderRequest) (*pb.SyncDNSDomainsFromProviderResponse, error) { + _, _, err := this.ValidateAdminAndUser(ctx, 0, 0) + if err != nil { + return nil, err + } + + tx := this.NullTx() + provider, err := dns.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, req.DnsProviderId) + if err != nil { + return nil, err + } + if provider == nil { + return nil, errors.New("can not find provider") + } + + // 下线不存在的域名 + oldDomains, err := dns.SharedDNSDomainDAO.FindAllEnabledDomainsWithProviderId(tx, req.DnsProviderId) + if err != nil { + return nil, err + } + + dnsProvider := dnsclients.FindProvider(provider.Type) + if dnsProvider == nil { + return nil, errors.New("provider type '" + provider.Type + "' is not supported yet") + } + + params, err := provider.DecodeAPIParams() + if err != nil { + return nil, errors.New("decode params failed: " + err.Error()) + } + err = dnsProvider.Auth(params) + if err != nil { + return nil, errors.New("auth failed: " + err.Error()) + } + + domainNames, err := dnsProvider.GetDomains() + if err != nil { + return nil, err + } + + var hasChanges = false + + // 创建或上线域名 + for _, domainName := range domainNames { + domain, err := dns.SharedDNSDomainDAO.FindEnabledDomainWithName(tx, req.DnsProviderId, domainName) + if err != nil { + return nil, err + } + if domain == nil { + _, err = dns.SharedDNSDomainDAO.CreateDomain(tx, 0, 0, req.DnsProviderId, domainName) + if err != nil { + return nil, err + } + hasChanges = true + } else if domain.IsUp == 0 { + err = dns.SharedDNSDomainDAO.UpdateDomainIsUp(tx, int64(domain.Id), true) + if err != nil { + return nil, err + } + hasChanges = true + } + } + + // 将老的域名置为下线 + for _, oldDomain := range oldDomains { + var domainName = oldDomain.Name + if oldDomain.IsUp == 1 && !lists.ContainsString(domainNames, domainName) { + err = dns.SharedDNSDomainDAO.UpdateDomainIsUp(tx, int64(oldDomain.Id), false) + if err != nil { + return nil, err + } + hasChanges = true + } + } + + return &pb.SyncDNSDomainsFromProviderResponse{ + HasChanges: hasChanges, + }, nil +} diff --git a/internal/rpc/services/service_dns_provider.go b/internal/rpc/services/service_dns_provider.go index e4a66b9e..67198b7f 100644 --- a/internal/rpc/services/service_dns_provider.go +++ b/internal/rpc/services/service_dns_provider.go @@ -26,7 +26,7 @@ func (this *DNSProviderService) CreateDNSProvider(ctx context.Context, req *pb.C if err != nil { return nil, err } - + return &pb.CreateDNSProviderResponse{DnsProviderId: providerId}, nil }