diff --git a/internal/nodes/http_request_page.go b/internal/nodes/http_request_page.go index 1e3ce9e..e578a5f 100644 --- a/internal/nodes/http_request_page.go +++ b/internal/nodes/http_request_page.go @@ -9,11 +9,8 @@ import ( "github.com/iwind/TeaGo/logs" "net/http" "os" - "regexp" ) -var urlPrefixRegexp = regexp.MustCompile("^(?i)(http|https|ftp)://") - // 请求特殊页面 func (this *HTTPRequest) doPage(status int) (shouldStop bool) { if len(this.web.Pages) == 0 { @@ -49,7 +46,7 @@ func (this *HTTPRequest) doPageLookup(pages []*serverconfigs.HTTPPageConfig, sta for _, page := range pages { if page.Match(status) { if len(page.BodyType) == 0 || page.BodyType == shared.BodyTypeURL { - if urlPrefixRegexp.MatchString(page.URL) { + if urlSchemeRegexp.MatchString(page.URL) { var newStatus = page.NewStatus if newStatus <= 0 { newStatus = status diff --git a/internal/nodes/http_request_referers.go b/internal/nodes/http_request_referers.go index cc18c1f..9fc77d7 100644 --- a/internal/nodes/http_request_referers.go +++ b/internal/nodes/http_request_referers.go @@ -12,13 +12,29 @@ func (this *HTTPRequest) doCheckReferers() (shouldStop bool) { return } + var origin = this.RawReq.Header.Get("Origin") + const cacheSeconds = "3600" // 时间不能过长,防止修改设置后长期无法生效 + // 处理用到Origin的特殊功能 + if this.web.Referers.CheckOrigin && len(origin) > 0 { + // 处理Websocket + if this.web.Websocket != nil && this.web.Websocket.IsOn && this.RawReq.Header.Get("Upgrade") == "websocket" { + originHost, _ := httpParseHost(origin) + if len(originHost) > 0 && this.web.Websocket.MatchOrigin(originHost) { + return + } + } + } + var refererURL = this.RawReq.Header.Get("Referer") if len(refererURL) == 0 && this.web.Referers.CheckOrigin { - var origin = this.RawReq.Header.Get("Origin") if len(origin) > 0 && origin != "null" { - refererURL = "https://" + origin // 因为Origin都只有域名部分,所以为了下面的URL 分析需要加上https:// + if urlSchemeRegexp.MatchString(origin) { + refererURL = origin + } else { + refererURL = "https://" + origin + } } } diff --git a/internal/nodes/http_request_shutdown.go b/internal/nodes/http_request_shutdown.go index 27219e3..693aff8 100644 --- a/internal/nodes/http_request_shutdown.go +++ b/internal/nodes/http_request_shutdown.go @@ -19,7 +19,7 @@ func (this *HTTPRequest) doShutdown() { if len(shutdown.BodyType) == 0 || shutdown.BodyType == shared.BodyTypeURL { // URL - if urlPrefixRegexp.MatchString(shutdown.URL) { + if urlSchemeRegexp.MatchString(shutdown.URL) { this.doURL(http.MethodGet, shutdown.URL, "", shutdown.Status, true) return } diff --git a/internal/nodes/http_request_utils.go b/internal/nodes/http_request_utils.go index 4c7b6e1..192d706 100644 --- a/internal/nodes/http_request_utils.go +++ b/internal/nodes/http_request_utils.go @@ -9,6 +9,7 @@ import ( "github.com/iwind/TeaGo/types" "io" "net/http" + "net/url" "regexp" "strconv" "strings" @@ -22,6 +23,9 @@ var spiderRegexp = regexp.MustCompile(`(?i)(python|pycurl|http-client|httpclient // 内容范围正则,其中的每个括号里的内容都在被引用,不能轻易修改 var contentRangeRegexp = regexp.MustCompile(`^bytes (\d+)-(\d+)/(\d+|\*)`) +// URL协议前缀 +var urlSchemeRegexp = regexp.MustCompile("^(?i)(http|https|ftp)://") + // 分解Range func httpRequestParseRangeHeader(rangeValue string) (result []rangeutils.Range, ok bool) { // 参考RFC:https://tools.ietf.org/html/rfc7233 @@ -222,3 +226,16 @@ func httpRedirect(writer http.ResponseWriter, req *http.Request, url string, cod http.Redirect(writer, req, url, code) } + +// 分析URL中的Host部分 +func httpParseHost(urlString string) (host string, err error) { + if !urlSchemeRegexp.MatchString(urlString) { + urlString = "https://" + urlString + } + + u, err := url.Parse(urlString) + if err != nil && u != nil { + return "", err + } + return u.Host, nil +} diff --git a/internal/nodes/http_request_utils_test.go b/internal/nodes/http_request_utils_test.go index ec337ba..33b785a 100644 --- a/internal/nodes/http_request_utils_test.go +++ b/internal/nodes/http_request_utils_test.go @@ -145,6 +145,23 @@ func TestHTTPRequest_httpRequestNextId_Concurrent(t *testing.T) { a.IsTrue(countDuplicated == 0) } +func TestHTTPParseURL(t *testing.T) { + for _, s := range []string{ + "", + "null", + "example.com", + "https://example.com", + "https://example.com/hello", + } { + host, err := httpParseHost(s) + if err == nil { + t.Log(s, "=>", host) + } else { + t.Log(s, "=>") + } + } +} + func BenchmarkHTTPRequest_httpRequestNextId(b *testing.B) { runtime.GOMAXPROCS(1)