优化域名查询程序

This commit is contained in:
刘祥超
2023-03-21 11:38:20 +08:00
parent ba1fd07555
commit e45a6cbcb5

View File

@@ -1,27 +1,47 @@
package utils package utils
import ( import (
"errors"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils" "github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/iwind/TeaGo/logs"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
var sharedDNSClient *dns.Client
var sharedDNSConfig *dns.ClientConfig
func init() {
if !teaconst.IsMain {
return
}
config, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
logs.Println("ERROR: configure dns client failed: " + err.Error())
return
}
sharedDNSConfig = config
sharedDNSClient = &dns.Client{}
}
// LookupCNAME 查询CNAME记录 // LookupCNAME 查询CNAME记录
// TODO 可以设置使用的DNS主机地址 // TODO 可以设置使用的DNS主机地址
func LookupCNAME(host string) (string, error) { func LookupCNAME(host string) (string, error) {
config, err := dns.ClientConfigFromFile("/etc/resolv.conf") if sharedDNSClient == nil {
if err != nil { return "", errors.New("could not find dns client")
return "", err
} }
var c = new(dns.Client)
var m = new(dns.Msg) var m = new(dns.Msg)
m.SetQuestion(host+".", dns.TypeCNAME) m.SetQuestion(host+".", dns.TypeCNAME)
m.RecursionDesired = true m.RecursionDesired = true
var lastErr error var lastErr error
for _, serverAddr := range config.Servers { for _, serverAddr := range sharedDNSConfig.Servers {
r, _, err := c.Exchange(m, configutils.QuoteIP(serverAddr)+":"+config.Port) r, _, err := sharedDNSClient.Exchange(m, configutils.QuoteIP(serverAddr)+":"+sharedDNSConfig.Port)
if err != nil { if err != nil {
lastErr = err lastErr = err
continue continue
@@ -38,12 +58,6 @@ func LookupCNAME(host string) (string, error) {
// LookupNS 查询NS记录 // LookupNS 查询NS记录
// TODO 可以设置使用的DNS主机地址 // TODO 可以设置使用的DNS主机地址
func LookupNS(host string) ([]string, error) { func LookupNS(host string) ([]string, error) {
config, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
return nil, err
}
var c = new(dns.Client)
var m = new(dns.Msg) var m = new(dns.Msg)
m.SetQuestion(host+".", dns.TypeNS) m.SetQuestion(host+".", dns.TypeNS)
@@ -53,8 +67,8 @@ func LookupNS(host string) ([]string, error) {
var lastErr error var lastErr error
var hasValidServer = false var hasValidServer = false
for _, serverAddr := range config.Servers { for _, serverAddr := range sharedDNSConfig.Servers {
r, _, err := c.Exchange(m, configutils.QuoteIP(serverAddr)+":"+config.Port) r, _, err := sharedDNSClient.Exchange(m, configutils.QuoteIP(serverAddr)+":"+sharedDNSConfig.Port)
if err != nil { if err != nil {
lastErr = err lastErr = err
continue continue
@@ -82,22 +96,16 @@ func LookupNS(host string) ([]string, error) {
// LookupTXT 获取CNAME // LookupTXT 获取CNAME
// TODO 可以设置使用的DNS主机地址 // TODO 可以设置使用的DNS主机地址
func LookupTXT(host string) ([]string, error) { func LookupTXT(host string) ([]string, error) {
config, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
return nil, err
}
var c = new(dns.Client)
var m = new(dns.Msg) var m = new(dns.Msg)
m.SetQuestion(host + ".", dns.TypeTXT) m.SetQuestion(host+".", dns.TypeTXT)
m.RecursionDesired = true m.RecursionDesired = true
var lastErr error var lastErr error
var result = []string{} var result = []string{}
var hasValidServer = false var hasValidServer = false
for _, serverAddr := range config.Servers { for _, serverAddr := range sharedDNSConfig.Servers {
r, _, err := c.Exchange(m, configutils.QuoteIP(serverAddr)+":"+config.Port) r, _, err := sharedDNSClient.Exchange(m, configutils.QuoteIP(serverAddr)+":"+sharedDNSConfig.Port)
if err != nil { if err != nil {
lastErr = err lastErr = err
continue continue