diff --git a/internal/nodes/http_client.go b/internal/nodes/http_client.go new file mode 100644 index 0000000..ab4848b --- /dev/null +++ b/internal/nodes/http_client.go @@ -0,0 +1,40 @@ +package nodes + +import ( + "github.com/TeaOSLab/EdgeNode/internal/utils" + "net/http" +) + +// HTTP客户端 +type HTTPClient struct { + rawClient *http.Client + accessAt int64 +} + +// 获取新客户端对象 +func NewHTTPClient(rawClient *http.Client) *HTTPClient { + return &HTTPClient{ + rawClient: rawClient, + accessAt: utils.UnixTime(), + } +} + +// 获取原始客户端对象 +func (this *HTTPClient) RawClient() *http.Client { + return this.rawClient +} + +// 更新访问时间 +func (this *HTTPClient) UpdateAccessTime() { + this.accessAt = utils.UnixTime() +} + +// 获取访问时间 +func (this *HTTPClient) AccessTime() int64 { + return this.accessAt +} + +// 关闭 +func (this *HTTPClient) Close() { + this.rawClient.CloseIdleConnections() +} diff --git a/internal/nodes/http_client_pool.go b/internal/nodes/http_client_pool.go new file mode 100644 index 0000000..b2ae636 --- /dev/null +++ b/internal/nodes/http_client_pool.go @@ -0,0 +1,149 @@ +package nodes + +import ( + "context" + "crypto/tls" + "errors" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "net" + "net/http" + "runtime" + "strconv" + "sync" + "time" +) + +// HTTP客户端池单例 +var SharedHTTPClientPool = NewHTTPClientPool() + +// 客户端池 +type HTTPClientPool struct { + clientExpiredDuration time.Duration + clientsMap map[string]*HTTPClient // backend key => client + locker sync.Mutex +} + +// 获取新对象 +func NewHTTPClientPool() *HTTPClientPool { + pool := &HTTPClientPool{ + clientExpiredDuration: 3600 * time.Second, + clientsMap: map[string]*HTTPClient{}, + } + + go pool.cleanClients() + + return pool +} + +// 根据地址获取客户端 +func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.OriginConfig) (rawClient *http.Client, realAddr string, err error) { + if origin.Addr == nil { + return nil, "", errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")") + } + + key := origin.UniqueKey() + originAddr := origin.Addr.PickAddress() + if origin.Addr.HostHasVariables() { + originAddr = req.Format(originAddr) + } + key += "@" + originAddr + + this.locker.Lock() + defer this.locker.Unlock() + + client, found := this.clientsMap[key] + if found { + client.UpdateAccessTime() + return client.RawClient(), originAddr, nil + } + + maxConnections := origin.MaxConns + connectionTimeout := origin.ConnTimeoutDuration() + readTimeout := origin.ReadTimeoutDuration() + idleTimeout := origin.IdleTimeoutDuration() + idleConns := origin.MaxIdleConns + + // 超时时间 + if connectionTimeout <= 0 { + connectionTimeout = 15 * time.Second + } + + if idleTimeout <= 0 { + idleTimeout = 2 * time.Minute + } + + numberCPU := runtime.NumCPU() + if numberCPU < 8 { + numberCPU = 8 + } + if maxConnections <= 0 { + maxConnections = numberCPU * 2 + } + + if idleConns <= 0 { + idleConns = numberCPU + } + //logs.Println("[ORIGIN]max connections:", maxConnections) + + // TLS通讯 + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + if origin.Cert != nil { + obj := origin.Cert.CertObject() + if obj != nil { + tlsConfig.InsecureSkipVerify = false + tlsConfig.Certificates = []tls.Certificate{*obj} + if len(origin.Cert.ServerName) > 0 { + tlsConfig.ServerName = origin.Cert.ServerName + } + } + } + + transport := &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + // 握手配置 + return (&net.Dialer{ + Timeout: connectionTimeout, + KeepAlive: 1 * time.Minute, + }).DialContext(ctx, network, originAddr) + }, + MaxIdleConns: 0, + MaxIdleConnsPerHost: idleConns, + MaxConnsPerHost: maxConnections, + IdleConnTimeout: idleTimeout, + ExpectContinueTimeout: 1 * time.Second, + TLSHandshakeTimeout: 0, // 不限 + TLSClientConfig: tlsConfig, + Proxy: nil, + } + + rawClient = &http.Client{ + Timeout: readTimeout, + Transport: transport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + this.clientsMap[key] = NewHTTPClient(rawClient) + + return rawClient, originAddr, nil +} + +// 清理不使用的Client +func (this *HTTPClientPool) cleanClients() { + ticker := time.NewTicker(this.clientExpiredDuration) + for range ticker.C { + currentAt := time.Now().Unix() + + this.locker.Lock() + for k, client := range this.clientsMap { + if client.AccessTime() < currentAt+86400 { // 超过 N 秒没有调用就关闭 + delete(this.clientsMap, k) + client.Close() + } + } + this.locker.Unlock() + } +} diff --git a/internal/nodes/http_client_pool_test.go b/internal/nodes/http_client_pool_test.go new file mode 100644 index 0000000..01d8bd9 --- /dev/null +++ b/internal/nodes/http_client_pool_test.go @@ -0,0 +1,78 @@ +package nodes + +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "runtime" + "testing" + "time" +) + +func TestHTTPClientPool_Client(t *testing.T) { + pool := NewHTTPClientPool() + + { + origin := &serverconfigs.OriginConfig{ + Id: 1, + Version: 2, + Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"}, + } + err := origin.Init() + if err != nil { + t.Fatal(err) + } + { + client, addr, err := pool.Client(nil, origin) + if err != nil { + t.Fatal(err) + } + t.Log("addr:", addr, "client:", client) + } + for i := 0; i < 10; i++ { + client, addr, err := pool.Client(nil, origin) + if err != nil { + t.Fatal(err) + } + t.Log("addr:", addr, "client:", client) + } + } +} + +func TestHTTPClientPool_cleanClients(t *testing.T) { + origin := &serverconfigs.OriginConfig{ + Id: 1, + Version: 2, + Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"}, + } + err := origin.Init() + if err != nil { + t.Fatal(err) + } + + pool := NewHTTPClientPool() + pool.clientExpiredDuration = 2 * time.Second + + for i := 0; i < 10; i++ { + t.Log("get", i) + _, _, _ = pool.Client(nil, origin) + time.Sleep(1 * time.Second) + } +} + +func BenchmarkHTTPClientPool_Client(b *testing.B) { + runtime.GOMAXPROCS(1) + + origin := &serverconfigs.OriginConfig{ + Id: 1, + Version: 2, + Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"}, + } + err := origin.Init() + if err != nil { + b.Fatal(err) + } + + pool := NewHTTPClientPool() + for i := 0; i < b.N; i++ { + _, _, _ = pool.Client(nil, origin) + } +} diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index acca511..7d5c496 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -7,6 +7,7 @@ import ( "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/types" "net" "net/http" @@ -51,6 +52,7 @@ type HTTPRequest struct { requestCost float64 // 请求耗时 filePath string // 请求的文件名,仅在读取Root目录下的内容时不为空 origin *serverconfigs.OriginConfig // 源站 + originAddr string // 源站实际地址 errors []string // 错误信息 } @@ -78,7 +80,7 @@ func (this *HTTPRequest) Do() { // Web配置 err := this.configureWeb(this.Server.Web, true, 0) if err != nil { - this.write500() + this.write500(err) this.doEnd() return } @@ -232,6 +234,7 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo if !location.IsOn { continue } + logs.Println("rawPath:", rawPath, "location:", location.Pattern) // TODO if varMapping, isMatched := location.Match(rawPath, this.Format); isMatched { if len(varMapping) > 0 { this.addVarMapping(varMapping) @@ -398,9 +401,9 @@ func (this *HTTPRequest) Format(source string) string { if this.origin != nil { switch suffix { case "address", "addr": - return this.origin.RealAddr() + return this.originAddr case "host": - addr := this.origin.RealAddr() + addr := this.originAddr index := strings.Index(addr, ":") if index > -1 { return addr[:index] @@ -674,7 +677,9 @@ func (this *HTTPRequest) requestServerPort() int { // 设置代理相关头部信息 // 参考:https://tools.ietf.org/html/rfc7239 func (this *HTTPRequest) setForwardHeaders(header http.Header) { - delete(header, "Connection") + if this.RawReq.Header.Get("Connection") == "close" { + this.RawReq.Header.Set("Connection", "keep-alive") + } remoteAddr := this.RawReq.RemoteAddr host, _, err := net.SplitHostPort(remoteAddr) @@ -728,6 +733,8 @@ func (this *HTTPRequest) setForwardHeaders(header http.Header) { // 处理自定义Request Header func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) { + this.fixRequestHeader(reqHeader) + if this.web.RequestHeaderPolicy != nil && this.web.RequestHeaderPolicy.IsOn { // 删除某些Header for name := range reqHeader { @@ -742,12 +749,17 @@ func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) { continue } oldValues, _ := this.RawReq.Header[header.Name] + newHeaderValue := header.Value // 因为我们不能修改header,所以在这里使用新变量 if header.HasVariables() { - oldValues = append(oldValues, this.Format(header.Value)) - } else { - oldValues = append(oldValues, header.Value) + newHeaderValue = this.Format(header.Value) } + oldValues = append(oldValues, newHeaderValue) reqHeader[header.Name] = oldValues + + // 支持修改Host + if header.Name == "Host" && len(header.Value) > 0 { + this.RawReq.Host = newHeaderValue + } } // Set @@ -755,10 +767,15 @@ func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) { if !header.IsOn { continue } + newHeaderValue := header.Value // 因为我们不能修改header,所以在这里使用新变量 if header.HasVariables() { - reqHeader[header.Name] = []string{this.Format(header.Value)} - } else { - reqHeader[header.Name] = []string{header.Value} + newHeaderValue = this.Format(header.Value) + } + reqHeader[header.Name] = []string{newHeaderValue} + + // 支持修改Host + if header.Name == "Host" && len(header.Value) > 0 { + this.RawReq.Host = newHeaderValue } } diff --git a/internal/nodes/http_request_error.go b/internal/nodes/http_request_error.go index 2a98c39..0e32343 100644 --- a/internal/nodes/http_request_error.go +++ b/internal/nodes/http_request_error.go @@ -17,7 +17,11 @@ func (this *HTTPRequest) write404() { _, _ = this.writer.Write([]byte(msg)) } -func (this *HTTPRequest) write500() { +func (this *HTTPRequest) write500(err error) { + if err != nil { + this.addError(err) + } + statusCode := http.StatusInternalServerError if this.doPage(statusCode) { return diff --git a/internal/nodes/http_request_reverse_proxy.go b/internal/nodes/http_request_reverse_proxy.go index ff30b41..a475b42 100644 --- a/internal/nodes/http_request_reverse_proxy.go +++ b/internal/nodes/http_request_reverse_proxy.go @@ -1,7 +1,14 @@ package nodes import ( + "context" + "errors" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/iwind/TeaGo/logs" + "io" + "net/url" + "strconv" "strings" ) @@ -11,9 +18,46 @@ func (this *HTTPRequest) doReverseProxy() { return } + // 对URL的处理 + stripPrefix := this.reverseProxy.StripPrefix + requestURI := this.reverseProxy.RequestURI + requestURIHasVariables := this.reverseProxy.RequestURIHasVariables() + requestHost := this.reverseProxy.RequestHost + requestHostHasVariables := this.reverseProxy.RequestHostHasVariables() + + // 源站 + requestCall := shared.NewRequestCall() + origin := this.reverseProxy.NextOrigin(requestCall) + if origin == nil { + err := errors.New(this.requestPath() + ": no available backends for reverse proxy") + logs.Error(err) + this.write500(err) + return + } + this.origin = origin // 设置全局变量是为了日志等处理 + if len(origin.StripPrefix) > 0 { + stripPrefix = origin.StripPrefix + } + if len(origin.RequestURI) > 0 { + requestURI = origin.RequestURI + requestURIHasVariables = origin.RequestURIHasVariables() + } + if len(origin.RequestHost) > 0 { + requestHost = origin.RequestHost + requestHostHasVariables = origin.RequestHostHasVariables() + } + + // 处理Scheme + if origin.Addr == nil { + err := errors.New(this.requestPath() + ": origin '" + strconv.FormatInt(origin.Id, 10) + "' does not has a address") + logs.Error(err) + this.write500(err) + return + } + this.RawReq.URL.Scheme = origin.Addr.Protocol.Primary().Scheme() + // StripPrefix - if len(this.reverseProxy.StripPrefix) > 0 { - stripPrefix := this.reverseProxy.StripPrefix + if len(stripPrefix) > 0 { if stripPrefix[0] != '/' { stripPrefix = "/" + stripPrefix } @@ -24,11 +68,11 @@ func (this *HTTPRequest) doReverseProxy() { } // RequestURI - if len(this.reverseProxy.RequestURI) > 0 { - if this.reverseProxy.RequestURIHasVariables() { - this.uri = this.Format(this.reverseProxy.RequestURI) + if len(requestURI) > 0 { + if requestURIHasVariables { + this.uri = this.Format(requestURI) } else { - this.uri = this.reverseProxy.RequestURI + this.uri = requestURI } if len(this.uri) == 0 || this.uri[0] != '/' { this.uri = "/" + this.uri @@ -47,6 +91,18 @@ func (this *HTTPRequest) doReverseProxy() { this.uri = utils.CleanPath(this.uri) } + // RequestHost + if len(requestHost) > 0 { + if requestHostHasVariables { + this.RawReq.Host = this.Format(requestHost) + } else { + this.RawReq.Host = this.reverseProxy.RequestHost + } + this.RawReq.URL.Host = this.RawReq.Host + } else { + this.RawReq.URL.Host = this.Host + } + // 重组请求URL questionMark := strings.Index(this.uri, "?") if questionMark > -1 { @@ -56,16 +112,11 @@ func (this *HTTPRequest) doReverseProxy() { this.RawReq.URL.Path = this.uri this.RawReq.URL.RawQuery = "" } + this.RawReq.RequestURI = "" - // RequestHost - if len(this.reverseProxy.RequestHost) > 0 { - if this.reverseProxy.RequestHostHasVariables() { - this.RawReq.Host = this.Format(this.reverseProxy.RequestHost) - } else { - this.RawReq.Host = this.reverseProxy.RequestHost - } - this.RawReq.URL.Host = this.RawReq.Host - } + // 处理Header + this.setForwardHeaders(this.RawReq.Header) + this.processRequestHeaders(this.RawReq.Header) // 判断是否为Websocket请求 if this.RawReq.Header.Get("Upgrade") == "websocket" { @@ -73,6 +124,110 @@ func (this *HTTPRequest) doReverseProxy() { return } - // 普通HTTP请求 + // 获取请求客户端 + client, addr, err := SharedHTTPClientPool.Client(this, origin) + if err != nil { + logs.Error(err) + this.write500(err) + return + } + + this.originAddr = addr + + // 开始请求 + resp, err := client.Do(this.RawReq) + if err != nil { + // 客户端取消请求,则不提示 + httpErr, ok := err.(*url.Error) + if !ok || httpErr.Err != context.Canceled { + // TODO 如果超过最大失败次数,则下线 + + this.write500(err) + logs.Println("[proxy]'" + this.RawReq.URL.String() + "': " + err.Error()) + } else { + // 是否为客户端方面的错误 + isClientError := false + if ok { + if httpErr.Err == context.Canceled { + isClientError = true + this.addError(errors.New(httpErr.Op + " " + httpErr.URL + ": client closed the connection")) + this.writer.WriteHeader(499) // 仿照nginx + } + } + + if !isClientError { + this.write500(err) + } + } + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + return + } + + // WAF对出站进行检查 // TODO + + // TODO 清除源站错误次数 + + // 特殊页面 + // TODO + + // 设置Charset + // TODO 这里应该可以设置文本类型的列表,以及是否强制覆盖所有文本类型的字符集 + if this.web.Charset != nil && this.web.Charset.IsOn && len(this.web.Charset.Charset) > 0 { + contentTypes, ok := resp.Header["Content-Type"] + if ok && len(contentTypes) > 0 { + contentType := contentTypes[0] + if _, found := textMimeMap[contentType]; found { + resp.Header["Content-Type"][0] = contentType + "; charset=" + this.web.Charset.Charset + } + } + } + + // 响应Header + this.writer.AddHeaders(resp.Header) + this.processResponseHeaders(resp.StatusCode) + + // 是否需要刷新 + shouldFlush := this.RawReq.Header.Get("Accept") == "text/event-stream" + + // 准备 + this.writer.Prepare(resp.ContentLength) + + // 设置响应代码 + this.writer.WriteHeader(resp.StatusCode) + + // 输出到客户端 + pool := this.bytePool(resp.ContentLength) + buf := pool.Get() + if shouldFlush { + for { + n, readErr := resp.Body.Read(buf) + if n > 0 { + _, err = this.writer.Write(buf[:n]) + this.writer.Flush() + if err != nil { + break + } + } + if readErr != nil { + err = readErr + break + } + } + } else { + _, err = io.CopyBuffer(this.writer, resp.Body, buf) + } + pool.Put(buf) + + err1 := resp.Body.Close() + if err1 != nil { + logs.Error(err1) + } + + if err != nil { + logs.Error(err) + this.addError(err) + } } diff --git a/internal/nodes/http_request_root.go b/internal/nodes/http_request_root.go index 0b6c1af..faa6125 100644 --- a/internal/nodes/http_request_root.go +++ b/internal/nodes/http_request_root.go @@ -107,9 +107,8 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { } return } else { - this.write500() + this.write500(err) logs.Error(err) - this.addError(err) return true } } @@ -137,9 +136,8 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { } return } else { - this.write500() + this.write500(err) logs.Error(err) - this.addError(err) return true } } @@ -220,9 +218,8 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { reader, err := os.OpenFile(filePath, os.O_RDONLY, 0444) if err != nil { - this.write500() + this.write500(err) logs.Error(err) - this.addError(err) return true } diff --git a/internal/nodes/http_request_test.go b/internal/nodes/http_request_test.go index 93e8fda..e5669a2 100644 --- a/internal/nodes/http_request_test.go +++ b/internal/nodes/http_request_test.go @@ -16,7 +16,7 @@ func TestHTTPRequest_RedirectToHTTPS(t *testing.T) { }, }, } - req.Run() + req.Do() a.IsBool(req.web.RedirectToHttps.IsOn == false) } { @@ -29,7 +29,7 @@ func TestHTTPRequest_RedirectToHTTPS(t *testing.T) { }, }, } - req.Run() + req.Do() a.IsBool(req.web.RedirectToHttps.IsOn == true) } } diff --git a/internal/nodes/http_request_url.go b/internal/nodes/http_request_url.go index 7e44ad2..29e0fcc 100644 --- a/internal/nodes/http_request_url.go +++ b/internal/nodes/http_request_url.go @@ -35,8 +35,7 @@ func (this *HTTPRequest) doURL(method string, url string, host string, statusCod resp, err := client.Do(req) if err != nil { logs.Error(errors.New(req.URL.String() + ": " + err.Error())) - this.addError(err) - this.write500() + this.write500(err) return } defer func() { diff --git a/internal/nodes/http_request_websocket.go b/internal/nodes/http_request_websocket.go index b217d44..d6a3991 100644 --- a/internal/nodes/http_request_websocket.go +++ b/internal/nodes/http_request_websocket.go @@ -1,8 +1,6 @@ package nodes import ( - "errors" - "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" "github.com/iwind/TeaGo/logs" "io" "net/http" @@ -30,20 +28,6 @@ func (this *HTTPRequest) doWebsocket() { } } - requestCall := shared.NewRequestCall() - origin := this.reverseProxy.NextOrigin(requestCall) - if origin == nil { - err := errors.New(this.requestPath() + ": no available backends for websocket") - logs.Error(err) - this.addError(err) - this.write500() - return - } - - // 处理Header - this.processRequestHeaders(this.RawReq.Header) - this.fixRequestHeader(this.RawReq.Header) // 处理 Websocket -> WebSocket - // 设置指定的来源域 if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 { newRequestOrigin := this.web.Websocket.RequestOrigin @@ -54,11 +38,10 @@ func (this *HTTPRequest) doWebsocket() { } // TODO 增加N次错误重试,重试的时候需要尝试不同的源站 - originConn, err := OriginConnect(origin) + originConn, err := OriginConnect(this.origin) if err != nil { logs.Error(err) - this.addError(err) - this.write500() + this.write500(err) return } defer func() { @@ -68,16 +51,14 @@ func (this *HTTPRequest) doWebsocket() { err = this.RawReq.Write(originConn) if err != nil { logs.Error(err) - this.addError(err) - this.write500() + this.write500(err) return } clientConn, _, err := this.writer.Hijack() if err != nil { logs.Error(err) - this.addError(err) - this.write500() + this.write500(err) return } defer func() { diff --git a/internal/nodes/http_writer.go b/internal/nodes/http_writer.go index 767c142..ca76b72 100644 --- a/internal/nodes/http_writer.go +++ b/internal/nodes/http_writer.go @@ -118,6 +118,9 @@ func (this *HTTPWriter) AddHeaders(header http.Header) { return } for key, value := range header { + if key == "Connection" { + continue + } for _, v := range value { this.writer.Header().Add(key, v) } diff --git a/internal/utils/time.go b/internal/utils/time.go new file mode 100644 index 0000000..aa0595e --- /dev/null +++ b/internal/utils/time.go @@ -0,0 +1,26 @@ +package utils + +import ( + "time" +) + +var unixTime = time.Now().Unix() +var unixTimerIsReady = false + +func init() { + ticker := time.NewTicker(500 * time.Millisecond) + go func() { + for range ticker.C { + unixTimerIsReady = true + unixTime = time.Now().Unix() + } + }() +} + +// 最快获取时间戳的方式,通常用在不需要特别精确时间戳的场景 +func UnixTime() int64 { + if unixTimerIsReady { + return unixTime + } + return time.Now().Unix() +} diff --git a/internal/utils/time_test.go b/internal/utils/time_test.go new file mode 100644 index 0000000..7df1064 --- /dev/null +++ b/internal/utils/time_test.go @@ -0,0 +1,13 @@ +package utils + +import ( + "testing" + "time" +) + +func TestUnixTime(t *testing.T) { + for i := 0; i < 5; i++ { + t.Log(UnixTime(), "real:", time.Now().Unix()) + time.Sleep(1 * time.Second) + } +}