mirror of
https://github.com/TeaOSLab/EdgeAPI.git
synced 2025-11-01 21:30:27 +08:00
299 lines
7.7 KiB
Go
299 lines
7.7 KiB
Go
//go:build plus
|
|
|
|
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
|
|
|
package dnsclients
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/TeaOSLab/EdgeAPI/internal/db/models/nameservers"
|
|
"github.com/TeaOSLab/EdgeAPI/internal/dnsclients/dnstypes"
|
|
"github.com/TeaOSLab/EdgeAPI/internal/errors"
|
|
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
|
"github.com/iwind/TeaGo/dbs"
|
|
"github.com/iwind/TeaGo/maps"
|
|
"github.com/iwind/TeaGo/types"
|
|
"strings"
|
|
)
|
|
|
|
type LocalEdgeDNSProvider struct {
|
|
clusterId int64 // 集群ID
|
|
ttl int32 // TTL
|
|
|
|
BaseProvider
|
|
}
|
|
|
|
// Auth 认证
|
|
func (this *LocalEdgeDNSProvider) Auth(params maps.Map) error {
|
|
this.clusterId = params.GetInt64("clusterId")
|
|
if this.clusterId <= 0 {
|
|
return errors.New("'clusterId' should be greater than 0")
|
|
}
|
|
|
|
this.ttl = params.GetInt32("ttl")
|
|
if this.ttl <= 0 {
|
|
this.ttl = 3600
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetDomains 获取所有域名列表
|
|
func (this *LocalEdgeDNSProvider) GetDomains() (domains []string, err error) {
|
|
var tx *dbs.Tx
|
|
domainOnes, err := nameservers.SharedNSDomainDAO.ListEnabledDomains(tx, this.clusterId, 0, 0, "", 0, 10000)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, domain := range domainOnes {
|
|
domains = append(domains, domain.Name)
|
|
}
|
|
return
|
|
}
|
|
|
|
// GetRecords 获取域名解析记录列表
|
|
func (this *LocalEdgeDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) {
|
|
var tx *dbs.Tx
|
|
domainId, err := nameservers.SharedNSDomainDAO.FindDomainIdWithName(tx, this.clusterId, domain)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if domainId == 0 {
|
|
return nil, errors.New("can not find domain '" + domain + "'")
|
|
}
|
|
|
|
offset := int64(0)
|
|
size := int64(1000)
|
|
for {
|
|
result, err := nameservers.SharedNSRecordDAO.ListEnabledRecords(tx, domainId, "", "", "", offset, size)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(result) == 0 {
|
|
break
|
|
}
|
|
for _, record := range result {
|
|
if record.Type == dnstypes.RecordTypeCNAME && !strings.HasSuffix(record.Value, ".") {
|
|
record.Value += "."
|
|
}
|
|
|
|
routeIds := record.DecodeRouteIds()
|
|
if len(routeIds) == 0 {
|
|
routeIds = []string{dnsconfigs.DefaultRouteCode}
|
|
}
|
|
records = append(records, &dnstypes.Record{
|
|
Id: fmt.Sprintf("%d", record.Id),
|
|
Name: record.Name,
|
|
Type: record.Type,
|
|
Value: record.Value,
|
|
Route: routeIds[0],
|
|
TTL: types.Int32(record.Ttl),
|
|
})
|
|
}
|
|
|
|
offset += size
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// GetRoutes 读取域名支持的线路数据
|
|
func (this *LocalEdgeDNSProvider) GetRoutes(domain string) (routes []*dnstypes.Route, err error) {
|
|
var tx *dbs.Tx
|
|
domainId, err := nameservers.SharedNSDomainDAO.FindDomainIdWithName(tx, this.clusterId, domain)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if domainId == 0 {
|
|
return nil, errors.New("can not find domain '" + domain + "'")
|
|
}
|
|
|
|
// 默认线路
|
|
for _, route := range dnsconfigs.AllDefaultRoutes {
|
|
routes = append(routes, &dnstypes.Route{
|
|
Name: route.Name,
|
|
Code: route.Code,
|
|
})
|
|
}
|
|
|
|
// 自定义线路
|
|
result, err := nameservers.SharedNSRouteDAO.FindAllEnabledRoutes(tx, 0, 0, 0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, route := range result {
|
|
routes = append(routes, &dnstypes.Route{
|
|
Name: route.Name,
|
|
Code: "id:" + types.String(route.Id),
|
|
})
|
|
}
|
|
|
|
// 默认ISP
|
|
for _, route := range dnsconfigs.AllDefaultISPRoutes {
|
|
routes = append(routes, &dnstypes.Route{
|
|
Name: route.Name,
|
|
Code: route.Code,
|
|
})
|
|
}
|
|
|
|
// 默认中国省份
|
|
for _, route := range dnsconfigs.AllDefaultChinaProvinceRoutes {
|
|
routes = append(routes, &dnstypes.Route{
|
|
Name: route.Name,
|
|
Code: route.Code,
|
|
})
|
|
}
|
|
|
|
// 默认全球国家/地区
|
|
for _, route := range dnsconfigs.AllDefaultWorldRegionRoutes {
|
|
routes = append(routes, &dnstypes.Route{
|
|
Name: route.Name,
|
|
Code: route.Code,
|
|
})
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// QueryRecord 查询单个记录
|
|
func (this *LocalEdgeDNSProvider) QueryRecord(domain string, name string, recordType dnstypes.RecordType) (*dnstypes.Record, error) {
|
|
var tx *dbs.Tx
|
|
domainId, err := nameservers.SharedNSDomainDAO.FindDomainIdWithName(tx, this.clusterId, domain)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if domainId == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
record, err := nameservers.SharedNSRecordDAO.FindEnabledRecordWithName(tx, domainId, name, recordType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if record == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
routeIds := record.DecodeRouteIds()
|
|
var routeIdString = ""
|
|
if len(routeIds) > 0 {
|
|
routeIdString = routeIds[0]
|
|
} else {
|
|
routeIdString = dnsconfigs.DefaultRouteCode
|
|
}
|
|
|
|
return &dnstypes.Record{
|
|
Id: fmt.Sprintf("%d", record.Id),
|
|
Name: record.Name,
|
|
Type: record.Type,
|
|
Value: record.Value,
|
|
Route: routeIdString,
|
|
TTL: types.Int32(record.Ttl),
|
|
}, nil
|
|
}
|
|
|
|
// AddRecord 设置记录
|
|
func (this *LocalEdgeDNSProvider) AddRecord(domain string, newRecord *dnstypes.Record) error {
|
|
var tx *dbs.Tx
|
|
domainId, err := nameservers.SharedNSDomainDAO.FindDomainIdWithName(tx, this.clusterId, domain)
|
|
if err != nil {
|
|
return this.WrapError(err, domain, newRecord)
|
|
}
|
|
if domainId == 0 {
|
|
return this.WrapError(errors.New("can not find domain '"+domain+"'"), domain, newRecord)
|
|
}
|
|
|
|
var routeIds = []string{}
|
|
if len(newRecord.Route) > 0 {
|
|
routeIds = append(routeIds, newRecord.Route)
|
|
}
|
|
|
|
if newRecord.TTL <= 0 {
|
|
newRecord.TTL = this.ttl
|
|
}
|
|
_, err = nameservers.SharedNSRecordDAO.CreateRecord(tx, domainId, "", newRecord.Name, newRecord.Type, newRecord.Value, newRecord.TTL, routeIds)
|
|
if err != nil {
|
|
return this.WrapError(err, domain, newRecord)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdateRecord 修改记录
|
|
func (this *LocalEdgeDNSProvider) UpdateRecord(domain string, record *dnstypes.Record, newRecord *dnstypes.Record) error {
|
|
var tx *dbs.Tx
|
|
domainId, err := nameservers.SharedNSDomainDAO.FindDomainIdWithName(tx, this.clusterId, domain)
|
|
if err != nil {
|
|
return this.WrapError(err, domain, newRecord)
|
|
}
|
|
if domainId == 0 {
|
|
return errors.New("can not find domain '" + domain + "'")
|
|
}
|
|
|
|
var routeIds []string
|
|
if len(newRecord.Route) > 0 {
|
|
routeIds = append(routeIds, newRecord.Route)
|
|
}
|
|
|
|
if newRecord.TTL <= 0 {
|
|
newRecord.TTL = this.ttl
|
|
}
|
|
|
|
if len(record.Id) > 0 {
|
|
err = nameservers.SharedNSRecordDAO.UpdateRecord(tx, types.Int64(record.Id), "", newRecord.Name, newRecord.Type, newRecord.Value, newRecord.TTL, routeIds, true)
|
|
if err != nil {
|
|
return this.WrapError(err, domain, newRecord)
|
|
}
|
|
} else {
|
|
realRecord, err := nameservers.SharedNSRecordDAO.FindEnabledRecordWithName(tx, domainId, record.Name, record.Type)
|
|
if err != nil {
|
|
return this.WrapError(err, domain, newRecord)
|
|
}
|
|
if realRecord != nil {
|
|
err = nameservers.SharedNSRecordDAO.UpdateRecord(tx, types.Int64(realRecord.Id), "", newRecord.Name, newRecord.Type, newRecord.Value, newRecord.TTL, routeIds, true)
|
|
if err != nil {
|
|
return this.WrapError(err, domain, newRecord)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteRecord 删除记录
|
|
func (this *LocalEdgeDNSProvider) DeleteRecord(domain string, record *dnstypes.Record) error {
|
|
var tx *dbs.Tx
|
|
domainId, err := nameservers.SharedNSDomainDAO.FindDomainIdWithName(tx, this.clusterId, domain)
|
|
if err != nil {
|
|
return this.WrapError(err, domain, record)
|
|
}
|
|
if domainId == 0 {
|
|
return errors.New("can not find domain '" + domain + "'")
|
|
}
|
|
|
|
if len(record.Id) > 0 {
|
|
err = nameservers.SharedNSRecordDAO.DisableNSRecord(tx, types.Int64(record.Id))
|
|
if err != nil {
|
|
return this.WrapError(err, domain, record)
|
|
}
|
|
} else {
|
|
realRecord, err := nameservers.SharedNSRecordDAO.FindEnabledRecordWithName(tx, domainId, record.Name, record.Type)
|
|
if err != nil {
|
|
return this.WrapError(err, domain, record)
|
|
}
|
|
if realRecord != nil {
|
|
err = nameservers.SharedNSRecordDAO.DisableNSRecord(tx, types.Int64(realRecord.Id))
|
|
if err != nil {
|
|
return this.WrapError(err, domain, record)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DefaultRoute 默认线路
|
|
func (this *LocalEdgeDNSProvider) DefaultRoute() string {
|
|
return "default"
|
|
}
|