mirror of
https://github.com/TeaOSLab/EdgeAPI.git
synced 2026-01-01 02:56:34 +08:00
改进DNS域名解析相关函数
This commit is contained in:
@@ -3,20 +3,27 @@ package utils
|
||||
import (
|
||||
"errors"
|
||||
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/utils/taskutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/miekg/dns"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var sharedDNSClient *dns.Client
|
||||
var sharedDNSConfig *dns.ClientConfig
|
||||
var sharedDNSLocker = &sync.RWMutex{}
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
config, err := dns.ClientConfigFromFile("/etc/resolv.conf")
|
||||
var resolvConfFile = "/etc/resolv.conf"
|
||||
config, err := dns.ClientConfigFromFile(resolvConfFile)
|
||||
if err != nil {
|
||||
logs.Println("ERROR: configure dns client failed: " + err.Error())
|
||||
return
|
||||
@@ -25,6 +32,21 @@ func init() {
|
||||
sharedDNSConfig = config
|
||||
sharedDNSClient = &dns.Client{}
|
||||
|
||||
// 监视文件变化,以便及时更新配置
|
||||
go func() {
|
||||
watcher, watcherErr := fsnotify.NewWatcher()
|
||||
if watcherErr == nil {
|
||||
err = watcher.Add(resolvConfFile)
|
||||
for range watcher.Events {
|
||||
newConfig, err := dns.ClientConfigFromFile(resolvConfFile)
|
||||
if err == nil && newConfig != nil {
|
||||
sharedDNSLocker.Lock()
|
||||
sharedDNSConfig = newConfig
|
||||
sharedDNSLocker.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// LookupCNAME 查询CNAME记录
|
||||
@@ -40,8 +62,10 @@ func LookupCNAME(host string) (string, error) {
|
||||
m.RecursionDesired = true
|
||||
|
||||
var lastErr error
|
||||
for _, serverAddr := range sharedDNSConfig.Servers {
|
||||
r, _, err := sharedDNSClient.Exchange(m, configutils.QuoteIP(serverAddr)+":"+sharedDNSConfig.Port)
|
||||
var serverAddrs = composeDNSResolverAddrs(nil)
|
||||
|
||||
for _, serverAddr := range serverAddrs {
|
||||
r, _, err := sharedDNSClient.Exchange(m, serverAddr)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
@@ -56,8 +80,7 @@ func LookupCNAME(host string) (string, error) {
|
||||
}
|
||||
|
||||
// LookupNS 查询NS记录
|
||||
// TODO 可以设置使用的DNS主机地址
|
||||
func LookupNS(host string) ([]string, error) {
|
||||
func LookupNS(host string, extraResolvers []*dnsconfigs.DNSResolver) ([]string, error) {
|
||||
var m = new(dns.Msg)
|
||||
|
||||
m.SetQuestion(host+".", dns.TypeNS)
|
||||
@@ -67,23 +90,36 @@ func LookupNS(host string) ([]string, error) {
|
||||
|
||||
var lastErr error
|
||||
var hasValidServer = false
|
||||
for _, serverAddr := range sharedDNSConfig.Servers {
|
||||
r, _, err := sharedDNSClient.Exchange(m, configutils.QuoteIP(serverAddr)+":"+sharedDNSConfig.Port)
|
||||
var serverAddrs = composeDNSResolverAddrs(extraResolvers)
|
||||
if len(serverAddrs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
taskErr := taskutils.RunConcurrent(serverAddrs, taskutils.DefaultConcurrent, func(task any, locker *sync.RWMutex) {
|
||||
var serverAddr = task.(string)
|
||||
r, _, err := sharedDNSClient.Exchange(m, serverAddr)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
return
|
||||
}
|
||||
|
||||
hasValidServer = true
|
||||
|
||||
if len(r.Answer) == 0 {
|
||||
continue
|
||||
return
|
||||
}
|
||||
|
||||
for _, answer := range r.Answer {
|
||||
result = append(result, answer.(*dns.NS).Ns)
|
||||
var value = answer.(*dns.NS).Ns
|
||||
locker.Lock()
|
||||
if len(value) > 0 && !lists.ContainsString(result, value) {
|
||||
result = append(result, value)
|
||||
}
|
||||
locker.Unlock()
|
||||
}
|
||||
break
|
||||
})
|
||||
if taskErr != nil {
|
||||
return result, taskErr
|
||||
}
|
||||
|
||||
if hasValidServer {
|
||||
@@ -94,8 +130,7 @@ func LookupNS(host string) ([]string, error) {
|
||||
}
|
||||
|
||||
// LookupTXT 获取CNAME
|
||||
// TODO 可以设置使用的DNS主机地址
|
||||
func LookupTXT(host string) ([]string, error) {
|
||||
func LookupTXT(host string, extraResolvers []*dnsconfigs.DNSResolver) ([]string, error) {
|
||||
var m = new(dns.Msg)
|
||||
|
||||
m.SetQuestion(host+".", dns.TypeTXT)
|
||||
@@ -104,23 +139,36 @@ func LookupTXT(host string) ([]string, error) {
|
||||
var lastErr error
|
||||
var result = []string{}
|
||||
var hasValidServer = false
|
||||
for _, serverAddr := range sharedDNSConfig.Servers {
|
||||
r, _, err := sharedDNSClient.Exchange(m, configutils.QuoteIP(serverAddr)+":"+sharedDNSConfig.Port)
|
||||
var serverAddrs = composeDNSResolverAddrs(extraResolvers)
|
||||
if len(serverAddrs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
taskErr := taskutils.RunConcurrent(serverAddrs, taskutils.DefaultConcurrent, func(task any, locker *sync.RWMutex) {
|
||||
var serverAddr = task.(string)
|
||||
r, _, err := sharedDNSClient.Exchange(m, serverAddr)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
return
|
||||
}
|
||||
hasValidServer = true
|
||||
|
||||
if len(r.Answer) == 0 {
|
||||
continue
|
||||
return
|
||||
}
|
||||
|
||||
for _, answer := range r.Answer {
|
||||
result = append(result, answer.(*dns.TXT).Txt...)
|
||||
for _, txt := range answer.(*dns.TXT).Txt {
|
||||
locker.Lock()
|
||||
if len(txt) > 0 && !lists.ContainsString(result, txt) {
|
||||
result = append(result, txt)
|
||||
}
|
||||
locker.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
break
|
||||
})
|
||||
if taskErr != nil {
|
||||
return result, taskErr
|
||||
}
|
||||
|
||||
if hasValidServer {
|
||||
@@ -129,3 +177,22 @@ func LookupTXT(host string) ([]string, error) {
|
||||
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// 组合DNS解析服务器地址
|
||||
func composeDNSResolverAddrs(extraResolvers []*dnsconfigs.DNSResolver) []string {
|
||||
sharedDNSLocker.RLock()
|
||||
defer sharedDNSLocker.RUnlock()
|
||||
|
||||
// 这里不处理重复,方便我们可以多次重试
|
||||
var servers = sharedDNSConfig.Servers
|
||||
var port = sharedDNSConfig.Port
|
||||
|
||||
var serverAddrs = []string{}
|
||||
for _, serverAddr := range servers {
|
||||
serverAddrs = append(serverAddrs, configutils.QuoteIP(serverAddr)+":"+port)
|
||||
}
|
||||
for _, resolver := range extraResolvers {
|
||||
serverAddrs = append(serverAddrs, resolver.Addr())
|
||||
}
|
||||
return serverAddrs
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package utils_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -12,9 +13,25 @@ func TestLookupCNAME(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLookupNS(t *testing.T) {
|
||||
t.Log(utils.LookupNS("goedge.cn"))
|
||||
t.Log(utils.LookupNS("goedge.cn", nil))
|
||||
}
|
||||
|
||||
func TestLookupNSExtra(t *testing.T) {
|
||||
t.Log(utils.LookupNS("goedge.cn", []*dnsconfigs.DNSResolver{
|
||||
{
|
||||
Host: "192.168.2.2",
|
||||
},
|
||||
{
|
||||
Host: "192.168.2.2",
|
||||
Port: 58,
|
||||
},
|
||||
{
|
||||
Host: "8.8.8.8",
|
||||
Port: 53,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
func TestLookupTXT(t *testing.T) {
|
||||
t.Log(utils.LookupTXT("yanzheng.goedge.cn"))
|
||||
t.Log(utils.LookupTXT("yanzheng.goedge.cn", nil))
|
||||
}
|
||||
|
||||
61
internal/utils/taskutils/concurrent.go
Normal file
61
internal/utils/taskutils/concurrent.go
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package taskutils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const DefaultConcurrent = 16
|
||||
|
||||
func RunConcurrent(tasks any, concurrent int, f func(task any, locker *sync.RWMutex)) error {
|
||||
if tasks == nil {
|
||||
return nil
|
||||
}
|
||||
var tasksValue = reflect.ValueOf(tasks)
|
||||
if tasksValue.Type().Kind() != reflect.Slice {
|
||||
return errors.New("ony works for slice")
|
||||
}
|
||||
|
||||
var countTasks = tasksValue.Len()
|
||||
if countTasks == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if concurrent <= 0 {
|
||||
concurrent = 8
|
||||
}
|
||||
if concurrent > countTasks {
|
||||
concurrent = countTasks
|
||||
}
|
||||
|
||||
var taskChan = make(chan any, countTasks)
|
||||
for i := 0; i < countTasks; i++ {
|
||||
taskChan <- tasksValue.Index(i).Interface()
|
||||
}
|
||||
|
||||
var wg = &sync.WaitGroup{}
|
||||
wg.Add(concurrent)
|
||||
|
||||
var locker = &sync.RWMutex{}
|
||||
|
||||
for i := 0; i < concurrent; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case task := <-taskChan:
|
||||
f(task, locker)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
18
internal/utils/taskutils/concurrent_test.go
Normal file
18
internal/utils/taskutils/concurrent_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package taskutils_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/utils/taskutils"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRunConcurrent(t *testing.T) {
|
||||
err := taskutils.RunConcurrent([]string{"a", "b", "c", "d", "e"}, 3, func(task any, locker *sync.RWMutex) {
|
||||
t.Log("run", task)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user