diff --git a/internal/utils/lookup.go b/internal/utils/lookup.go index f35f996e..ea7c8456 100644 --- a/internal/utils/lookup.go +++ b/internal/utils/lookup.go @@ -3,8 +3,10 @@ package utils import ( teaconst "github.com/TeaOSLab/EdgeAdmin/internal/const" "github.com/TeaOSLab/EdgeCommon/pkg/configutils" + "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/logs" "github.com/miekg/dns" + "sync" ) var sharedDNSClient *dns.Client @@ -33,17 +35,47 @@ func LookupCNAME(host string) (string, error) { m.RecursionDesired = true var lastErr error - for _, serverAddr := range sharedDNSConfig.Servers { - r, _, err := sharedDNSClient.Exchange(m, configutils.QuoteIP(serverAddr)+":"+sharedDNSConfig.Port) - if err != nil { - lastErr = err - continue - } - if len(r.Answer) == 0 { - continue - } + var success = false + var result = "" - return r.Answer[0].(*dns.CNAME).Target, nil + var serverAddrs = sharedDNSConfig.Servers + + { + var publicDNSHosts = []string{"8.8.8.8" /** Google **/, "8.8.4.4" /** Google **/} + for _, publicDNSHost := range publicDNSHosts { + if !lists.ContainsString(serverAddrs, publicDNSHost) { + serverAddrs = append(serverAddrs, publicDNSHost) + } + } } + + var wg = &sync.WaitGroup{} + + for _, serverAddr := range serverAddrs { + wg.Add(1) + + go func(serverAddr string) { + defer wg.Done() + r, _, err := sharedDNSClient.Exchange(m, configutils.QuoteIP(serverAddr)+":"+sharedDNSConfig.Port) + if err != nil { + lastErr = err + return + } + + success = true + + if len(r.Answer) == 0 { + return + } + + result = r.Answer[0].(*dns.CNAME).Target + }(serverAddr) + } + wg.Wait() + + if success { + return result, nil + } + return "", lastErr } diff --git a/internal/utils/lookup_test.go b/internal/utils/lookup_test.go index 54d1fd00..21af54dc 100644 --- a/internal/utils/lookup_test.go +++ b/internal/utils/lookup_test.go @@ -8,5 +8,8 @@ import ( ) func TestLookupCNAME(t *testing.T) { - t.Log(utils.LookupCNAME("www.yun4s.cn")) + for _, domain := range []string{"www.yun4s.cn", "example.com"} { + result, err := utils.LookupCNAME(domain) + t.Log(domain, "=>", result, err) + } }