From a5f3b98bb039d86c3bb3c36f3cf40bfd8e65b16f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E7=A5=A5=E8=B6=85?= Date: Wed, 18 Nov 2020 12:17:19 +0800 Subject: [PATCH] =?UTF-8?q?=E6=89=BE=E4=B8=8D=E5=88=B0=E5=8C=B9=E9=85=8D?= =?UTF-8?q?=E7=9A=84=E5=9F=9F=E5=90=8D=E6=97=B6=E5=8F=AF=E4=BB=A5=E6=8C=87?= =?UTF-8?q?=E5=AE=9A=E9=BB=98=E8=AE=A4=E5=9F=9F=E5=90=8D=E3=80=81=E6=8C=87?= =?UTF-8?q?=E5=AE=9A=E4=B8=8D=E5=8C=B9=E9=85=8D=E6=97=B6=E7=9A=84=E5=A4=84?= =?UTF-8?q?=E7=90=86=E5=8A=A8=E4=BD=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/nodes/listener_base.go | 47 +++++++++++++++++++++++++++------ internal/nodes/listener_http.go | 23 +++++++++++----- 2 files changed, 56 insertions(+), 14 deletions(-) diff --git a/internal/nodes/listener_base.go b/internal/nodes/listener_base.go index dedbae6..910cf7e 100644 --- a/internal/nodes/listener_base.go +++ b/internal/nodes/listener_base.go @@ -5,6 +5,7 @@ import ( "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" + "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/types" http2 "golang.org/x/net/http2" "sync" @@ -150,6 +151,42 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C // 根据域名来查找匹配的域名 func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) { + serverConfig, serverName = this.findNamedServerMatched(name) + if serverConfig != nil { + return + } + + matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly + + if sharedNodeConfig.GlobalConfig != nil && + len(sharedNodeConfig.GlobalConfig.HTTPAll.DefaultDomain) > 0 && + (!matchDomainStrictly || lists.ContainsString(sharedNodeConfig.GlobalConfig.HTTPAll.AllowMismatchDomains, name)) { + defaultDomain := sharedNodeConfig.GlobalConfig.HTTPAll.DefaultDomain + serverConfig, serverName = this.findNamedServerMatched(defaultDomain) + if serverConfig != nil { + return + } + } + + if matchDomainStrictly && !lists.ContainsString(sharedNodeConfig.GlobalConfig.HTTPAll.AllowMismatchDomains, name) { + return + } + + // 如果没有找到,则匹配到第一个 + this.serversLocker.RLock() + defer this.serversLocker.RUnlock() + + group := this.Group + currentServers := group.Servers + countServers := len(currentServers) + if countServers == 0 { + return nil, "" + } + return currentServers[0], name +} + +// 严格查找域名 +func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) { group := this.Group if group == nil { return nil, "" @@ -177,7 +214,7 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf maxNamedServers := 100_0000 // 是否严格匹配域名 - matchDomainStrictly := (group.IsHTTP() || group.IsHTTPS()) && sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly + matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly // 如果只有一个server,则默认为这个 if countServers == 1 && !matchDomainStrictly { @@ -214,11 +251,5 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf } } - // 找不到而且域名严格匹配模式下不返回Server - if matchDomainStrictly { - return nil, name - } - - // 如果没有找到,则匹配到第一个 - return currentServers[0], name + return nil, name } diff --git a/internal/nodes/listener_http.go b/internal/nodes/listener_http.go index 19d1b07..96e2e5b 100644 --- a/internal/nodes/listener_http.go +++ b/internal/nodes/listener_http.go @@ -128,12 +128,23 @@ func (this *HTTPListener) handleHTTP(rawWriter http.ResponseWriter, rawReq *http if server == nil { // 严格匹配域名模式下,我们拒绝用户访问 if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly { - hijacker, ok := rawWriter.(http.Hijacker) - if ok { - conn, _, _ := hijacker.Hijack() - if conn != nil { - _ = conn.Close() - return + httpAllConfig := sharedNodeConfig.GlobalConfig.HTTPAll + mismatchAction := httpAllConfig.DomainMismatchAction + if mismatchAction != nil && mismatchAction.Code == "page" { + if mismatchAction.Options != nil { + http.Error(rawWriter, mismatchAction.Options.GetString("contentHTML"), mismatchAction.Options.GetInt("statusCode")) + } else { + http.Error(rawWriter, "404 page not found: '"+rawReq.URL.String()+"'", http.StatusNotFound) + } + return + } else { + hijacker, ok := rawWriter.(http.Hijacker) + if ok { + conn, _, _ := hijacker.Hijack() + if conn != nil { + _ = conn.Close() + return + } } } }