修复Websocket连接无法报告连接关闭的问题

This commit is contained in:
刘祥超
2024-01-16 09:24:51 +08:00
parent d694319191
commit c94e93859f
3 changed files with 20 additions and 4 deletions

View File

@@ -102,6 +102,8 @@ type HTTPRequest struct {
disableLog bool // 是否在当前请求中关闭Log disableLog bool // 是否在当前请求中关闭Log
forceLog bool // 是否强制记录日志 forceLog bool // 是否强制记录日志
isHijacked bool
// script相关操作 // script相关操作
isDone bool isDone bool
} }

View File

@@ -339,7 +339,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
cacheWriter, err := storage.OpenWriter(cacheKey, expiresAt, this.StatusCode(), this.calculateHeaderLength(), totalSize, cacheRef.MaxSizeBytes(), this.isPartial) cacheWriter, err := storage.OpenWriter(cacheKey, expiresAt, this.StatusCode(), this.calculateHeaderLength(), totalSize, cacheRef.MaxSizeBytes(), this.isPartial)
if err != nil { if err != nil {
if err == caches.ErrEntityTooLarge && addStatusHeader { if errors.Is(err, caches.ErrEntityTooLarge) && addStatusHeader {
this.Header().Set("X-Cache", "BYPASS, entity too large") this.Header().Set("X-Cache", "BYPASS, entity too large")
} }
@@ -968,6 +968,7 @@ func (this *HTTPWriter) Close() {
func (this *HTTPWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) { func (this *HTTPWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) {
hijack, ok := this.rawWriter.(http.Hijacker) hijack, ok := this.rawWriter.(http.Hijacker)
if ok { if ok {
this.req.isHijacked = true
return hijack.Hijack() return hijack.Hijack()
} }
return return

View File

@@ -3,6 +3,7 @@ package nodes
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"io" "io"
@@ -52,6 +53,8 @@ func (this *HTTPListener) Serve() error {
atomic.AddInt64(&this.countActiveConnections, 1) atomic.AddInt64(&this.countActiveConnections, 1)
case http.StateClosed: case http.StateClosed:
atomic.AddInt64(&this.countActiveConnections, -1) atomic.AddInt64(&this.countActiveConnections, -1)
default:
// do nothing
} }
}, },
ConnContext: func(ctx context.Context, conn net.Conn) context.Context { ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
@@ -74,7 +77,7 @@ func (this *HTTPListener) Serve() error {
// HTTP协议 // HTTP协议
if this.isHTTP { if this.isHTTP {
err := this.httpServer.Serve(this.Listener) err := this.httpServer.Serve(this.Listener)
if err != nil && err != http.ErrServerClosed { if err != nil && !errors.Is(err, http.ErrServerClosed) {
return err return err
} }
} }
@@ -84,7 +87,7 @@ func (this *HTTPListener) Serve() error {
this.httpServer.TLSConfig = this.buildTLSConfig() this.httpServer.TLSConfig = this.buildTLSConfig()
err := this.httpServer.ServeTLS(this.Listener, "", "") err := this.httpServer.ServeTLS(this.Listener, "", "")
if err != nil && err != http.ErrServerClosed { if err != nil && !errors.Is(err, http.ErrServerClosed) {
return err return err
} }
} }
@@ -180,10 +183,12 @@ func (this *HTTPListener) ServeHTTPWithAddr(rawWriter http.ResponseWriter, rawRe
} }
// 绑定连接 // 绑定连接
var clientConn ClientConnInterface
if server != nil && server.Id > 0 { if server != nil && server.Id > 0 {
var requestConn = rawReq.Context().Value(HTTPConnContextKey) var requestConn = rawReq.Context().Value(HTTPConnContextKey)
if requestConn != nil { if requestConn != nil {
clientConn, ok := requestConn.(ClientConnInterface) var ok bool
clientConn, ok = requestConn.(ClientConnInterface)
if ok { if ok {
var goNext = clientConn.SetServerId(server.Id) var goNext = clientConn.SetServerId(server.Id)
if !goNext { if !goNext {
@@ -224,6 +229,14 @@ func (this *HTTPListener) ServeHTTPWithAddr(rawWriter http.ResponseWriter, rawRe
nodeConfig: sharedNodeConfig, nodeConfig: sharedNodeConfig,
} }
req.Do() req.Do()
// fix hijacked connection state
if req.isHijacked && clientConn != nil && this.httpServer.ConnState != nil {
netConn, ok := clientConn.(net.Conn)
if ok {
this.httpServer.ConnState(netConn, http.StateClosed)
}
}
} }
// 检查host是否为IP // 检查host是否为IP