diff --git a/internal/dnsclients/provider_custom_http.go b/internal/dnsclients/provider_custom_http.go new file mode 100644 index 00000000..ef40e28d --- /dev/null +++ b/internal/dnsclients/provider_custom_http.go @@ -0,0 +1,170 @@ +package dnsclients + +import ( + "bytes" + "crypto/sha1" + "crypto/tls" + "encoding/json" + "fmt" + teaconst "github.com/TeaOSLab/EdgeAPI/internal/const" + "github.com/TeaOSLab/EdgeAPI/internal/errors" + "github.com/iwind/TeaGo/maps" + "io/ioutil" + "net/http" + "strconv" + "time" +) + +var customHTTPClient = &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, +} + +// HTTP自定义DNS +type CustomHTTPProvider struct { + url string + secret string +} + +// 认证 +// 参数: +// - url +// - secret +func (this *CustomHTTPProvider) Auth(params maps.Map) error { + this.url = params.GetString("url") + if len(this.url) == 0 { + return errors.New("'url' should not be empty") + } + + this.secret = params.GetString("secret") + if len(this.secret) == 0 { + return errors.New("'secret' should not be empty") + } + + return nil +} + +// 获取域名解析记录列表 +func (this *CustomHTTPProvider) GetRecords(domain string) (records []*Record, err error) { + resp, err := this.post(maps.Map{ + "action": "GetRecords", + "domain": domain, + }) + if err != nil { + return nil, err + } + err = json.Unmarshal(resp, &records) + return +} + +// 读取域名支持的线路数据 +func (this *CustomHTTPProvider) GetRoutes(domain string) (routes []*Route, err error) { + resp, err := this.post(maps.Map{ + "action": "GetRoutes", + "domain": domain, + }) + if err != nil { + return nil, err + } + err = json.Unmarshal(resp, &routes) + return +} + +// 查询单个记录 +func (this *CustomHTTPProvider) QueryRecord(domain string, name string, recordType RecordType) (*Record, error) { + resp, err := this.post(maps.Map{ + "action": "QueryRecord", + "domain": domain, + "name": name, + "recordType": recordType, + }) + if err != nil { + return nil, err + } + if len(resp) == 0 { + return nil, nil + } + record := &Record{} + err = json.Unmarshal(resp, record) + if err != nil { + return nil, err + } + if len(record.Value) == 0 { + return nil, nil + } + return record, nil +} + +// 设置记录 +func (this *CustomHTTPProvider) AddRecord(domain string, newRecord *Record) error { + _, err := this.post(maps.Map{ + "action": "AddRecord", + "domain": domain, + "newRecord": newRecord, + }) + return err +} + +// 修改记录 +func (this *CustomHTTPProvider) UpdateRecord(domain string, record *Record, newRecord *Record) error { + _, err := this.post(maps.Map{ + "action": "UpdateRecord", + "domain": domain, + "record": record, + "newRecord": newRecord, + }) + return err +} + +// 删除记录 +func (this *CustomHTTPProvider) DeleteRecord(domain string, record *Record) error { + _, err := this.post(maps.Map{ + "action": "DeleteRecord", + "domain": domain, + "record": record, + }) + return err +} + +// 默认线路 +func (this *CustomHTTPProvider) DefaultRoute() string { + resp, err := this.post(maps.Map{ + "action": "DefaultRoute", + }) + if err != nil { + return "" + } + return string(resp) +} + +// 执行操作 +func (this *CustomHTTPProvider) post(params maps.Map) (respData []byte, err error) { + data, err := json.Marshal(params) + if err != nil { + return nil, err + } + req, err := http.NewRequest(http.MethodPost, this.url, bytes.NewReader(data)) + if err != nil { + return nil, err + } + timestamp := strconv.FormatInt(time.Now().Unix(), 10) + req.Header.Set("Timestamp", timestamp) + req.Header.Set("Token", fmt.Sprintf("%x", sha1.Sum([]byte(this.secret+"@"+timestamp)))) + req.Header.Set("User-Agent", "GoEdge/"+teaconst.Version) + + resp, err := customHTTPClient.Do(req) + if err != nil { + return nil, err + } + defer func() { + _ = resp.Body.Close() + }() + if resp.StatusCode != 200 { + return nil, errors.New("status should be 200, but got '" + strconv.Itoa(resp.StatusCode) + "'") + } + return ioutil.ReadAll(resp.Body) +} diff --git a/internal/dnsclients/provider_custom_http_test.go b/internal/dnsclients/provider_custom_http_test.go new file mode 100644 index 00000000..7d02bb17 --- /dev/null +++ b/internal/dnsclients/provider_custom_http_test.go @@ -0,0 +1,45 @@ +package dnsclients + +import ( + "github.com/iwind/TeaGo/logs" + "github.com/iwind/TeaGo/maps" + "testing" +) + +func TestCustomHTTPProvider_AddRecord(t *testing.T) { + provider := CustomHTTPProvider{} + err := provider.Auth(maps.Map{ + "url": "http://127.0.0.1:1234/dns", + "secret": "123456", + }) + if err != nil { + t.Fatal(err) + } + err = provider.AddRecord("hello.com", &Record{ + Id: "", + Name: "world", + Type: RecordTypeA, + Value: "127.0.0.1", + Route: "default", + }) + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} + +func TestCustomHTTPProvider_GetRecords(t *testing.T) { + provider := CustomHTTPProvider{} + err := provider.Auth(maps.Map{ + "url": "http://127.0.0.1:1234/dns", + "secret": "123456", + }) + if err != nil { + t.Fatal(err) + } + records, err := provider.GetRecords("hello.com") + if err != nil { + t.Fatal(err) + } + logs.PrintAsJSON(records, t) +} diff --git a/internal/dnsclients/types.go b/internal/dnsclients/types.go index 1a342ab1..da445a3f 100644 --- a/internal/dnsclients/types.go +++ b/internal/dnsclients/types.go @@ -6,9 +6,10 @@ type ProviderType = string // 服务商代号 const ( - ProviderTypeDNSPod ProviderType = "dnspod" - ProviderTypeAliDNS ProviderType = "alidns" - ProviderTypeDNSCom ProviderType = "dnscom" + ProviderTypeDNSPod ProviderType = "dnspod" + ProviderTypeAliDNS ProviderType = "alidns" + ProviderTypeDNSCom ProviderType = "dnscom" + ProviderTypeCustomHTTP ProviderType = "customHTTP" ) // 所有的服务商类型 @@ -25,6 +26,10 @@ var AllProviderTypes = []maps.Map{ "name": "帝恩思DNS.COM", "code": ProviderTypeDNSCom, },**/ + { + "name": "自定义HTTP DNS", + "code": ProviderTypeCustomHTTP, + }, } // 查找服务商实例 @@ -34,6 +39,8 @@ func FindProvider(providerType ProviderType) ProviderInterface { return &DNSPodProvider{} case ProviderTypeAliDNS: return &AliDNSProvider{} + case ProviderTypeCustomHTTP: + return &CustomHTTPProvider{} } return nil }