支持自定义HTTP DNS

This commit is contained in:
GoEdgeLab
2021-01-28 11:29:57 +08:00
parent a73a04521d
commit 8eb6538c44
3 changed files with 225 additions and 3 deletions

View File

@@ -0,0 +1,170 @@
package dnsclients
import (
"bytes"
"crypto/sha1"
"crypto/tls"
"encoding/json"
"fmt"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/iwind/TeaGo/maps"
"io/ioutil"
"net/http"
"strconv"
"time"
)
var customHTTPClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
// HTTP自定义DNS
type CustomHTTPProvider struct {
url string
secret string
}
// 认证
// 参数:
// - url
// - secret
func (this *CustomHTTPProvider) Auth(params maps.Map) error {
this.url = params.GetString("url")
if len(this.url) == 0 {
return errors.New("'url' should not be empty")
}
this.secret = params.GetString("secret")
if len(this.secret) == 0 {
return errors.New("'secret' should not be empty")
}
return nil
}
// 获取域名解析记录列表
func (this *CustomHTTPProvider) GetRecords(domain string) (records []*Record, err error) {
resp, err := this.post(maps.Map{
"action": "GetRecords",
"domain": domain,
})
if err != nil {
return nil, err
}
err = json.Unmarshal(resp, &records)
return
}
// 读取域名支持的线路数据
func (this *CustomHTTPProvider) GetRoutes(domain string) (routes []*Route, err error) {
resp, err := this.post(maps.Map{
"action": "GetRoutes",
"domain": domain,
})
if err != nil {
return nil, err
}
err = json.Unmarshal(resp, &routes)
return
}
// 查询单个记录
func (this *CustomHTTPProvider) QueryRecord(domain string, name string, recordType RecordType) (*Record, error) {
resp, err := this.post(maps.Map{
"action": "QueryRecord",
"domain": domain,
"name": name,
"recordType": recordType,
})
if err != nil {
return nil, err
}
if len(resp) == 0 {
return nil, nil
}
record := &Record{}
err = json.Unmarshal(resp, record)
if err != nil {
return nil, err
}
if len(record.Value) == 0 {
return nil, nil
}
return record, nil
}
// 设置记录
func (this *CustomHTTPProvider) AddRecord(domain string, newRecord *Record) error {
_, err := this.post(maps.Map{
"action": "AddRecord",
"domain": domain,
"newRecord": newRecord,
})
return err
}
// 修改记录
func (this *CustomHTTPProvider) UpdateRecord(domain string, record *Record, newRecord *Record) error {
_, err := this.post(maps.Map{
"action": "UpdateRecord",
"domain": domain,
"record": record,
"newRecord": newRecord,
})
return err
}
// 删除记录
func (this *CustomHTTPProvider) DeleteRecord(domain string, record *Record) error {
_, err := this.post(maps.Map{
"action": "DeleteRecord",
"domain": domain,
"record": record,
})
return err
}
// 默认线路
func (this *CustomHTTPProvider) DefaultRoute() string {
resp, err := this.post(maps.Map{
"action": "DefaultRoute",
})
if err != nil {
return ""
}
return string(resp)
}
// 执行操作
func (this *CustomHTTPProvider) post(params maps.Map) (respData []byte, err error) {
data, err := json.Marshal(params)
if err != nil {
return nil, err
}
req, err := http.NewRequest(http.MethodPost, this.url, bytes.NewReader(data))
if err != nil {
return nil, err
}
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
req.Header.Set("Timestamp", timestamp)
req.Header.Set("Token", fmt.Sprintf("%x", sha1.Sum([]byte(this.secret+"@"+timestamp))))
req.Header.Set("User-Agent", "GoEdge/"+teaconst.Version)
resp, err := customHTTPClient.Do(req)
if err != nil {
return nil, err
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != 200 {
return nil, errors.New("status should be 200, but got '" + strconv.Itoa(resp.StatusCode) + "'")
}
return ioutil.ReadAll(resp.Body)
}

View File

@@ -0,0 +1,45 @@
package dnsclients
import (
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
"testing"
)
func TestCustomHTTPProvider_AddRecord(t *testing.T) {
provider := CustomHTTPProvider{}
err := provider.Auth(maps.Map{
"url": "http://127.0.0.1:1234/dns",
"secret": "123456",
})
if err != nil {
t.Fatal(err)
}
err = provider.AddRecord("hello.com", &Record{
Id: "",
Name: "world",
Type: RecordTypeA,
Value: "127.0.0.1",
Route: "default",
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestCustomHTTPProvider_GetRecords(t *testing.T) {
provider := CustomHTTPProvider{}
err := provider.Auth(maps.Map{
"url": "http://127.0.0.1:1234/dns",
"secret": "123456",
})
if err != nil {
t.Fatal(err)
}
records, err := provider.GetRecords("hello.com")
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(records, t)
}

View File

@@ -6,9 +6,10 @@ type ProviderType = string
// 服务商代号 // 服务商代号
const ( const (
ProviderTypeDNSPod ProviderType = "dnspod" ProviderTypeDNSPod ProviderType = "dnspod"
ProviderTypeAliDNS ProviderType = "alidns" ProviderTypeAliDNS ProviderType = "alidns"
ProviderTypeDNSCom ProviderType = "dnscom" ProviderTypeDNSCom ProviderType = "dnscom"
ProviderTypeCustomHTTP ProviderType = "customHTTP"
) )
// 所有的服务商类型 // 所有的服务商类型
@@ -25,6 +26,10 @@ var AllProviderTypes = []maps.Map{
"name": "帝恩思DNS.COM", "name": "帝恩思DNS.COM",
"code": ProviderTypeDNSCom, "code": ProviderTypeDNSCom,
},**/ },**/
{
"name": "自定义HTTP DNS",
"code": ProviderTypeCustomHTTP,
},
} }
// 查找服务商实例 // 查找服务商实例
@@ -34,6 +39,8 @@ func FindProvider(providerType ProviderType) ProviderInterface {
return &DNSPodProvider{} return &DNSPodProvider{}
case ProviderTypeAliDNS: case ProviderTypeAliDNS:
return &AliDNSProvider{} return &AliDNSProvider{}
case ProviderTypeCustomHTTP:
return &CustomHTTPProvider{}
} }
return nil return nil
} }