diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index a737050..f4374b3 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -102,6 +102,8 @@ type HTTPRequest struct { disableLog bool // 是否在当前请求中关闭Log forceLog bool // 是否强制记录日志 + isHijacked bool + // script相关操作 isDone bool } diff --git a/internal/nodes/http_writer.go b/internal/nodes/http_writer.go index ac14905..a91870f 100644 --- a/internal/nodes/http_writer.go +++ b/internal/nodes/http_writer.go @@ -339,7 +339,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) { cacheWriter, err := storage.OpenWriter(cacheKey, expiresAt, this.StatusCode(), this.calculateHeaderLength(), totalSize, cacheRef.MaxSizeBytes(), this.isPartial) if err != nil { - if err == caches.ErrEntityTooLarge && addStatusHeader { + if errors.Is(err, caches.ErrEntityTooLarge) && addStatusHeader { this.Header().Set("X-Cache", "BYPASS, entity too large") } @@ -968,6 +968,7 @@ func (this *HTTPWriter) Close() { func (this *HTTPWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) { hijack, ok := this.rawWriter.(http.Hijacker) if ok { + this.req.isHijacked = true return hijack.Hijack() } return diff --git a/internal/nodes/listener_http.go b/internal/nodes/listener_http.go index 1db1936..22e1853 100644 --- a/internal/nodes/listener_http.go +++ b/internal/nodes/listener_http.go @@ -3,6 +3,7 @@ package nodes import ( "context" "crypto/tls" + "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/iwind/TeaGo/Tea" "io" @@ -52,6 +53,8 @@ func (this *HTTPListener) Serve() error { atomic.AddInt64(&this.countActiveConnections, 1) case http.StateClosed: atomic.AddInt64(&this.countActiveConnections, -1) + default: + // do nothing } }, ConnContext: func(ctx context.Context, conn net.Conn) context.Context { @@ -74,7 +77,7 @@ func (this *HTTPListener) Serve() error { // HTTP协议 if this.isHTTP { err := this.httpServer.Serve(this.Listener) - if err != nil && err != http.ErrServerClosed { + if err != nil && !errors.Is(err, http.ErrServerClosed) { return err } } @@ -84,7 +87,7 @@ func (this *HTTPListener) Serve() error { this.httpServer.TLSConfig = this.buildTLSConfig() err := this.httpServer.ServeTLS(this.Listener, "", "") - if err != nil && err != http.ErrServerClosed { + if err != nil && !errors.Is(err, http.ErrServerClosed) { return err } } @@ -180,10 +183,12 @@ func (this *HTTPListener) ServeHTTPWithAddr(rawWriter http.ResponseWriter, rawRe } // 绑定连接 + var clientConn ClientConnInterface if server != nil && server.Id > 0 { var requestConn = rawReq.Context().Value(HTTPConnContextKey) if requestConn != nil { - clientConn, ok := requestConn.(ClientConnInterface) + var ok bool + clientConn, ok = requestConn.(ClientConnInterface) if ok { var goNext = clientConn.SetServerId(server.Id) if !goNext { @@ -224,6 +229,14 @@ func (this *HTTPListener) ServeHTTPWithAddr(rawWriter http.ResponseWriter, rawRe nodeConfig: sharedNodeConfig, } req.Do() + + // fix hijacked connection state + if req.isHijacked && clientConn != nil && this.httpServer.ConnState != nil { + netConn, ok := clientConn.(net.Conn) + if ok { + this.httpServer.ConnState(netConn, http.StateClosed) + } + } } // 检查host是否为IP