diff --git a/pkg/configutils/domain.go b/pkg/configutils/domain.go index 44bf654..900d263 100644 --- a/pkg/configutils/domain.go +++ b/pkg/configutils/domain.go @@ -27,6 +27,10 @@ func MatchDomain(pattern string, domain string) (isMatched bool) { return } + if pattern == domain { + return true + } + if pattern == "*" { return true } @@ -61,3 +65,19 @@ func MatchDomain(pattern string, domain string) (isMatched bool) { } return isMatched } + +// IsFuzzyDomain 判断是否为特殊域名 +func IsFuzzyDomain(domain string) bool { + if len(domain) == 0 { + return true + } + if domain[0] == '.' || domain[0] == '~' { + return true + } + for _, c := range domain { + if c == '*' { + return true + } + } + return false +} diff --git a/pkg/configutils/domain_test.go b/pkg/configutils/domain_test.go index fd9aa77..b137c3f 100644 --- a/pkg/configutils/domain_test.go +++ b/pkg/configutils/domain_test.go @@ -81,3 +81,14 @@ func TestMatchDomain(t *testing.T) { a.IsTrue(ok) } } + +func TestIsSpecialDomain(t *testing.T) { + var a = assert.NewAssertion(t) + + a.IsTrue(IsFuzzyDomain("")) + a.IsTrue(IsFuzzyDomain(".hello.com")) + a.IsTrue(IsFuzzyDomain("*.hello.com")) + a.IsTrue(IsFuzzyDomain("hello.*.com")) + a.IsTrue(IsFuzzyDomain("~^hello\\.com")) + a.IsFalse(IsFuzzyDomain("hello.com")) +} diff --git a/pkg/serverconfigs/server_address_group.go b/pkg/serverconfigs/server_address_group.go index a90b1da..0d69a20 100644 --- a/pkg/serverconfigs/server_address_group.go +++ b/pkg/serverconfigs/server_address_group.go @@ -1,19 +1,71 @@ package serverconfigs -import "strings" +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/configutils" + "strings" + "sync" +) type ServerAddressGroup struct { fullAddr string - Servers []*ServerConfig + servers []*ServerConfig + + // 域名和服务映射 + strictDomainMap map[string]map[string]*ServerConfig // domain[:2] => {domain => *ServerConfig} + fuzzyDomainMap map[string]*ServerConfig // special domain => *ServerConfig + + cacheLocker sync.RWMutex + cacheDomainMap map[string]map[string]*ServerConfig // domain[:2] => {domain => *ServerConfig} + countCacheDomains int + + // 支持CNAME的服务 + cnameDomainMap map[string]map[string]*ServerConfig // domain[:2] => {domain => *ServerConfig} + + // 第一个TLS Server + firstTLSServer *ServerConfig } func NewServerAddressGroup(fullAddr string) *ServerAddressGroup { - return &ServerAddressGroup{fullAddr: fullAddr} + return &ServerAddressGroup{ + fullAddr: fullAddr, + strictDomainMap: map[string]map[string]*ServerConfig{}, + fuzzyDomainMap: map[string]*ServerConfig{}, + cacheDomainMap: map[string]map[string]*ServerConfig{}, + cnameDomainMap: map[string]map[string]*ServerConfig{}, + } } // Add 添加服务 func (this *ServerAddressGroup) Add(server *ServerConfig) { - this.Servers = append(this.Servers, server) + for _, serverName := range server.AllStrictNames() { + var prefix = this.domainPrefix(serverName) + domainsMap, ok := this.strictDomainMap[prefix] + if ok { + domainsMap[serverName] = server + } else { + this.strictDomainMap[prefix] = map[string]*ServerConfig{serverName: server} + } + + // CNAME + if server.SupportCNAME { + cnameDomainsMap, ok := this.cnameDomainMap[prefix] + if ok { + cnameDomainsMap[serverName] = server + } else { + this.cnameDomainMap[prefix] = map[string]*ServerConfig{serverName: server} + } + } + } + for _, serverName := range server.AllFuzzyNames() { + this.fuzzyDomainMap[serverName] = server + } + + this.servers = append(this.servers, server) + + // 第一个TLS Server + if this.firstTLSServer == nil && server.SSLPolicy() != nil && server.SSLPolicy().IsOn { + this.firstTLSServer = server + } } // FullAddr 获取完整的地址 @@ -40,6 +92,11 @@ func (this *ServerAddressGroup) Addr() string { return strings.TrimPrefix(this.fullAddr, protocol.String()+"://") } +// Servers 读取所有服务 +func (this *ServerAddressGroup) Servers() []*ServerConfig { + return this.servers +} + // IsHTTP 判断当前分组是否为HTTP func (this *ServerAddressGroup) IsHTTP() bool { p := this.Protocol() @@ -78,8 +135,82 @@ func (this *ServerAddressGroup) IsUDP() bool { // FirstServer 获取第一个Server func (this *ServerAddressGroup) FirstServer() *ServerConfig { - if len(this.Servers) > 0 { - return this.Servers[0] + if len(this.servers) > 0 { + return this.servers[0] } return nil } + +// FirstTLSServer 获取第一个TLS Server +func (this *ServerAddressGroup) FirstTLSServer() *ServerConfig { + return this.firstTLSServer +} + +// MatchServerName 使用域名查找服务 +func (this *ServerAddressGroup) MatchServerName(serverName string) *ServerConfig { + var prefix = this.domainPrefix(serverName) + + // 试图从缓存中读取 + this.cacheLocker.RLock() + if len(this.cacheDomainMap) > 0 { + domainMap, ok := this.cacheDomainMap[prefix] + if ok { + server, ok := domainMap[serverName] + if ok { + return server + } + } + } + this.cacheLocker.RUnlock() + + domainMap, ok := this.strictDomainMap[prefix] + if ok { + server, ok := domainMap[serverName] + if ok { + return server + } + } + for pattern, server := range this.fuzzyDomainMap { + if configutils.MatchDomain(pattern, serverName) { + // 加入到缓存 + this.cacheLocker.Lock() + + // 限制缓存的最大尺寸,防止内存耗尽 + if this.countCacheDomains < 1_000_000 { + domainMap, ok := this.cacheDomainMap[prefix] + if ok { + domainMap[serverName] = server + } else { + this.cacheDomainMap[prefix] = map[string]*ServerConfig{serverName: server} + } + this.countCacheDomains++ + } + this.cacheLocker.Unlock() + + return server + } + } + return nil +} + +// MatchServerCNAME 使用CNAME查找服务 +func (this *ServerAddressGroup) MatchServerCNAME(serverName string) *ServerConfig { + var prefix = this.domainPrefix(serverName) + + domainMap, ok := this.cnameDomainMap[prefix] + if ok { + server, ok := domainMap[serverName] + if ok { + return server + } + } + + return nil +} + +func (this *ServerAddressGroup) domainPrefix(domain string) string { + if len(domain) < 2 { + return domain + } + return domain[:2] +} diff --git a/pkg/serverconfigs/server_address_group_test.go b/pkg/serverconfigs/server_address_group_test.go index 5e23322..7c2f655 100644 --- a/pkg/serverconfigs/server_address_group_test.go +++ b/pkg/serverconfigs/server_address_group_test.go @@ -2,7 +2,9 @@ package serverconfigs import ( "github.com/iwind/TeaGo/assert" + "github.com/iwind/TeaGo/types" "testing" + "time" ) func TestServerAddressGroup_Protocol(t *testing.T) { @@ -32,3 +34,79 @@ func TestServerAddressGroup_Protocol(t *testing.T) { a.IsTrue(group.Addr() == "/tmp/my.sock") } } + +func TestServerAddressGroup_MatchServerName(t *testing.T) { + var group = NewServerAddressGroup("") + for i := 0; i < 1_000_000; i++ { + group.Add(&ServerConfig{ + ServerNames: []*ServerNameConfig{ + { + Name: "hello" + types.String(i) + ".com", + SubNames: []string{}, + }, + }, + }) + } + group.Add(&ServerConfig{ + ServerNames: []*ServerNameConfig{ + { + Name: "hello.com", + SubNames: []string{}, + }, + }, + }) + group.Add(&ServerConfig{ + ServerNames: []*ServerNameConfig{ + { + Name: "*.hello.com", + SubNames: []string{}, + }, + }, + }) + + var before = time.Now() + defer func() { + t.Log(time.Since(before).Seconds()*1000, "ms") + }() + + t.Log(group.MatchServerName("hello99999.com").AllStrictNames()) + t.Log(group.MatchServerName("hello.com").AllStrictNames()) + t.Log(group.MatchServerName("world.hello.com").AllFuzzyNames()) + for i := 0; i < 100_000; i++ { + _ = group.MatchServerName("world.hello.com") + } +} + +func TestServerAddressGroup_MatchServerCNAME(t *testing.T) { + var group = NewServerAddressGroup("") + group.Add(&ServerConfig{ + ServerNames: []*ServerNameConfig{ + { + Name: "hello.com", + SubNames: []string{}, + }, + }, + SupportCNAME: true, + }) + group.Add(&ServerConfig{ + ServerNames: []*ServerNameConfig{ + { + Name: "*.hello.com", + SubNames: []string{}, + }, + }, + }) + + var before = time.Now() + defer func() { + t.Log(time.Since(before).Seconds()*1000, "ms") + }() + + server := group.MatchServerCNAME("hello.com") + if server != nil { + t.Log(server.AllStrictNames()) + } else { + t.Log(server) + } + t.Log(group.MatchServerCNAME("world.hello.com")) +} diff --git a/pkg/serverconfigs/server_config.go b/pkg/serverconfigs/server_config.go index 3bea6a5..925517c 100644 --- a/pkg/serverconfigs/server_config.go +++ b/pkg/serverconfigs/server_config.go @@ -315,35 +315,60 @@ func (this *ServerConfig) IsUDPFamily() bool { return this.UDP != nil } -// MatchName 判断是否和域名匹配 -func (this *ServerConfig) MatchName(name string) bool { - if len(name) == 0 { - return false - } - if len(this.AliasServerNames) > 0 && configutils.MatchDomains(this.AliasServerNames, name) { - return true - } - for _, serverName := range this.ServerNames { - if serverName.Match(name) { - return true +// AllStrictNames 所有严格域名 +func (this *ServerConfig) AllStrictNames() []string { + var result = []string{} + for _, name := range this.AliasServerNames { + if len(name) > 0 { + if !configutils.IsFuzzyDomain(name) { + result = append(result, name) + } } } - return false + for _, serverName := range this.ServerNames { + var name = serverName.Name + if len(name) > 0 { + if !configutils.IsFuzzyDomain(name) { + result = append(result, name) + } + } + for _, name := range serverName.SubNames { + if len(name) > 0 { + if !configutils.IsFuzzyDomain(name) { + result = append(result, name) + } + } + } + } + return result } -// MatchNameStrictly 判断是否严格匹配 -func (this *ServerConfig) MatchNameStrictly(name string) bool { - for _, serverName := range this.AliasServerNames { - if serverName == name { - return true +// AllFuzzyNames 所有模糊域名 +func (this *ServerConfig) AllFuzzyNames() []string { + var result = []string{} + for _, name := range this.AliasServerNames { + if len(name) > 0 { + if configutils.IsFuzzyDomain(name) { + result = append(result, name) + } } } for _, serverName := range this.ServerNames { - if serverName.Name == name { - return true + var name = serverName.Name + if len(name) > 0 { + if configutils.IsFuzzyDomain(name) { + result = append(result, name) + } + } + for _, name := range serverName.SubNames { + if len(name) > 0 { + if configutils.IsFuzzyDomain(name) { + result = append(result, name) + } + } } } - return false + return result } // SSLPolicy SSL信息 diff --git a/pkg/serverconfigs/server_config_test.go b/pkg/serverconfigs/server_config_test.go index 33da9e1..77b9b68 100644 --- a/pkg/serverconfigs/server_config_test.go +++ b/pkg/serverconfigs/server_config_test.go @@ -72,3 +72,22 @@ func TestServerConfig_Protocols(t *testing.T) { t.Log(server.FullAddresses()) } } + +func TestServerConfig_AllStrictNames(t *testing.T) { + var config = &ServerConfig{ + AliasServerNames: []string{"hello.com", ".hello.com"}, + ServerNames: []*ServerNameConfig{ + { + Name: "hello2.com", + }, + { + SubNames: []string{"hello3.com", "hello4.com", "*.hello5.com"}, + }, + { + Name: "~hello.com", + }, + }, + } + t.Log(config.AllStrictNames()) + t.Log(config.AllFuzzyNames()) +}