diff --git a/internal/nodes/client_conn.go b/internal/nodes/client_conn.go index 197f284..10b222c 100644 --- a/internal/nodes/client_conn.go +++ b/internal/nodes/client_conn.go @@ -94,7 +94,7 @@ func (this *ClientConn) Read(b []byte) (n int, err error) { // 设置读超时时间 var autoReadTimeout = globalServerConfig != nil && globalServerConfig.Performance.AutoReadTimeout - if this.isHTTP && !this.isShortReading && autoReadTimeout { + if this.isHTTP && !this.isWebsocket && !this.isShortReading && autoReadTimeout { this.setHTTPReadTimeout() } @@ -157,7 +157,7 @@ func (this *ClientConn) Write(b []byte) (n int, err error) { } // 延长读超时时间 - if this.isHTTP { + if this.isHTTP && !this.isWebsocket && globalServerConfig != nil && globalServerConfig.Performance.AutoReadTimeout { this.setHTTPReadTimeout() } @@ -217,7 +217,7 @@ func (this *ClientConn) SetDeadline(t time.Time) error { func (this *ClientConn) SetReadDeadline(t time.Time) error { // 如果开启了HTTP自动读超时选项,则自动控制超时时间 var globalServerConfig = sharedNodeConfig.GlobalServerConfig - if this.isHTTP && globalServerConfig != nil && globalServerConfig.Performance.AutoReadTimeout { + if this.isHTTP && !this.isWebsocket && globalServerConfig != nil && globalServerConfig.Performance.AutoReadTimeout { this.isShortReading = false var unixTime = t.Unix() diff --git a/internal/nodes/client_conn_base.go b/internal/nodes/client_conn_base.go index e0cb250..62a0c87 100644 --- a/internal/nodes/client_conn_base.go +++ b/internal/nodes/client_conn_base.go @@ -16,6 +16,8 @@ type BaseClientConn struct { remoteAddr string hasLimit bool + isWebsocket bool + isClosed bool rawIP string @@ -122,3 +124,7 @@ func (this *BaseClientConn) SetLinger(seconds int) error { } return nil } + +func (this *BaseClientConn) SetIsWebsocket(isWebsocket bool) { + this.isWebsocket = isWebsocket +} diff --git a/internal/nodes/client_conn_interface.go b/internal/nodes/client_conn_interface.go index b92e426..72a25b9 100644 --- a/internal/nodes/client_conn_interface.go +++ b/internal/nodes/client_conn_interface.go @@ -23,4 +23,7 @@ type ClientConnInterface interface { // UserId 获取当前连接所属服务的用户ID UserId() int64 + + // SetIsWebsocket 设置是否为Websocket + SetIsWebsocket(isWebsocket bool) } diff --git a/internal/nodes/http_request_websocket.go b/internal/nodes/http_request_websocket.go index c2fe464..ea04f7f 100644 --- a/internal/nodes/http_request_websocket.go +++ b/internal/nodes/http_request_websocket.go @@ -70,6 +70,13 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou this.RawReq.Header.Set("Origin", newRequestOrigin) } + // 获取当前连接 + var requestConn = this.RawReq.Context().Value(HTTPConnContextKey) + if requestConn == nil { + return + } + + // 连接源站 // TODO 增加N次错误重试,重试的时候需要尝试不同的源站 originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost) if err != nil { @@ -102,6 +109,11 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou return } + requestClientConn, ok := requestConn.(ClientConnInterface) + if ok { + requestClientConn.SetIsWebsocket(true) + } + clientConn, _, err := this.writer.Hijack() if err != nil || clientConn == nil { this.write50x(err, http.StatusInternalServerError, "Failed to get origin site connection", "获取源站连接失败", false)