找不到匹配的域名时可以指定默认域名、指定不匹配时的处理动作

This commit is contained in:
刘祥超
2020-11-18 12:17:19 +08:00
parent 58cfb281f9
commit a5f3b98bb0
2 changed files with 56 additions and 14 deletions

View File

@@ -5,6 +5,7 @@ import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
http2 "golang.org/x/net/http2"
"sync"
@@ -150,6 +151,42 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
// 根据域名来查找匹配的域名
func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) {
serverConfig, serverName = this.findNamedServerMatched(name)
if serverConfig != nil {
return
}
matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
if sharedNodeConfig.GlobalConfig != nil &&
len(sharedNodeConfig.GlobalConfig.HTTPAll.DefaultDomain) > 0 &&
(!matchDomainStrictly || lists.ContainsString(sharedNodeConfig.GlobalConfig.HTTPAll.AllowMismatchDomains, name)) {
defaultDomain := sharedNodeConfig.GlobalConfig.HTTPAll.DefaultDomain
serverConfig, serverName = this.findNamedServerMatched(defaultDomain)
if serverConfig != nil {
return
}
}
if matchDomainStrictly && !lists.ContainsString(sharedNodeConfig.GlobalConfig.HTTPAll.AllowMismatchDomains, name) {
return
}
// 如果没有找到,则匹配到第一个
this.serversLocker.RLock()
defer this.serversLocker.RUnlock()
group := this.Group
currentServers := group.Servers
countServers := len(currentServers)
if countServers == 0 {
return nil, ""
}
return currentServers[0], name
}
// 严格查找域名
func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) {
group := this.Group
if group == nil {
return nil, ""
@@ -177,7 +214,7 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
maxNamedServers := 100_0000
// 是否严格匹配域名
matchDomainStrictly := (group.IsHTTP() || group.IsHTTPS()) && sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
// 如果只有一个server则默认为这个
if countServers == 1 && !matchDomainStrictly {
@@ -214,11 +251,5 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
}
}
// 找不到而且域名严格匹配模式下不返回Server
if matchDomainStrictly {
return nil, name
}
// 如果没有找到,则匹配到第一个
return currentServers[0], name
}

View File

@@ -128,6 +128,16 @@ func (this *HTTPListener) handleHTTP(rawWriter http.ResponseWriter, rawReq *http
if server == nil {
// 严格匹配域名模式下,我们拒绝用户访问
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
httpAllConfig := sharedNodeConfig.GlobalConfig.HTTPAll
mismatchAction := httpAllConfig.DomainMismatchAction
if mismatchAction != nil && mismatchAction.Code == "page" {
if mismatchAction.Options != nil {
http.Error(rawWriter, mismatchAction.Options.GetString("contentHTML"), mismatchAction.Options.GetInt("statusCode"))
} else {
http.Error(rawWriter, "404 page not found: '"+rawReq.URL.String()+"'", http.StatusNotFound)
}
return
} else {
hijacker, ok := rawWriter.(http.Hijacker)
if ok {
conn, _, _ := hijacker.Hijack()
@@ -137,6 +147,7 @@ func (this *HTTPListener) handleHTTP(rawWriter http.ResponseWriter, rawReq *http
}
}
}
}
http.Error(rawWriter, "404 page not found: '"+rawReq.URL.String()+"'", http.StatusNotFound)
return