diff --git a/internal/nodes/listener_base.go b/internal/nodes/listener_base.go index dedbae6..910cf7e 100644 --- a/internal/nodes/listener_base.go +++ b/internal/nodes/listener_base.go @@ -5,6 +5,7 @@ import ( "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" + "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/types" http2 "golang.org/x/net/http2" "sync" @@ -150,6 +151,42 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C // 根据域名来查找匹配的域名 func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) { + serverConfig, serverName = this.findNamedServerMatched(name) + if serverConfig != nil { + return + } + + matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly + + if sharedNodeConfig.GlobalConfig != nil && + len(sharedNodeConfig.GlobalConfig.HTTPAll.DefaultDomain) > 0 && + (!matchDomainStrictly || lists.ContainsString(sharedNodeConfig.GlobalConfig.HTTPAll.AllowMismatchDomains, name)) { + defaultDomain := sharedNodeConfig.GlobalConfig.HTTPAll.DefaultDomain + serverConfig, serverName = this.findNamedServerMatched(defaultDomain) + if serverConfig != nil { + return + } + } + + if matchDomainStrictly && !lists.ContainsString(sharedNodeConfig.GlobalConfig.HTTPAll.AllowMismatchDomains, name) { + return + } + + // 如果没有找到,则匹配到第一个 + this.serversLocker.RLock() + defer this.serversLocker.RUnlock() + + group := this.Group + currentServers := group.Servers + countServers := len(currentServers) + if countServers == 0 { + return nil, "" + } + return currentServers[0], name +} + +// 严格查找域名 +func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) { group := this.Group if group == nil { return nil, "" @@ -177,7 +214,7 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf maxNamedServers := 100_0000 // 是否严格匹配域名 - matchDomainStrictly := (group.IsHTTP() || group.IsHTTPS()) && sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly + matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly // 如果只有一个server,则默认为这个 if countServers == 1 && !matchDomainStrictly { @@ -214,11 +251,5 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf } } - // 找不到而且域名严格匹配模式下不返回Server - if matchDomainStrictly { - return nil, name - } - - // 如果没有找到,则匹配到第一个 - return currentServers[0], name + return nil, name } diff --git a/internal/nodes/listener_http.go b/internal/nodes/listener_http.go index 19d1b07..96e2e5b 100644 --- a/internal/nodes/listener_http.go +++ b/internal/nodes/listener_http.go @@ -128,12 +128,23 @@ func (this *HTTPListener) handleHTTP(rawWriter http.ResponseWriter, rawReq *http if server == nil { // 严格匹配域名模式下,我们拒绝用户访问 if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly { - hijacker, ok := rawWriter.(http.Hijacker) - if ok { - conn, _, _ := hijacker.Hijack() - if conn != nil { - _ = conn.Close() - return + httpAllConfig := sharedNodeConfig.GlobalConfig.HTTPAll + mismatchAction := httpAllConfig.DomainMismatchAction + if mismatchAction != nil && mismatchAction.Code == "page" { + if mismatchAction.Options != nil { + http.Error(rawWriter, mismatchAction.Options.GetString("contentHTML"), mismatchAction.Options.GetInt("statusCode")) + } else { + http.Error(rawWriter, "404 page not found: '"+rawReq.URL.String()+"'", http.StatusNotFound) + } + return + } else { + hijacker, ok := rawWriter.(http.Hijacker) + if ok { + conn, _, _ := hijacker.Hijack() + if conn != nil { + _ = conn.Close() + return + } } } }