改进在自动读超时模式下的Websocket连接

This commit is contained in:
GoEdgeLab
2023-01-09 12:36:33 +08:00
parent 6e852a167a
commit 24bf452ea5
4 changed files with 24 additions and 3 deletions

View File

@@ -94,7 +94,7 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
// 设置读超时时间 // 设置读超时时间
var autoReadTimeout = globalServerConfig != nil && globalServerConfig.Performance.AutoReadTimeout var autoReadTimeout = globalServerConfig != nil && globalServerConfig.Performance.AutoReadTimeout
if this.isHTTP && !this.isShortReading && autoReadTimeout { if this.isHTTP && !this.isWebsocket && !this.isShortReading && autoReadTimeout {
this.setHTTPReadTimeout() this.setHTTPReadTimeout()
} }
@@ -157,7 +157,7 @@ func (this *ClientConn) Write(b []byte) (n int, err error) {
} }
// 延长读超时时间 // 延长读超时时间
if this.isHTTP { if this.isHTTP && !this.isWebsocket && globalServerConfig != nil && globalServerConfig.Performance.AutoReadTimeout {
this.setHTTPReadTimeout() this.setHTTPReadTimeout()
} }
@@ -217,7 +217,7 @@ func (this *ClientConn) SetDeadline(t time.Time) error {
func (this *ClientConn) SetReadDeadline(t time.Time) error { func (this *ClientConn) SetReadDeadline(t time.Time) error {
// 如果开启了HTTP自动读超时选项则自动控制超时时间 // 如果开启了HTTP自动读超时选项则自动控制超时时间
var globalServerConfig = sharedNodeConfig.GlobalServerConfig var globalServerConfig = sharedNodeConfig.GlobalServerConfig
if this.isHTTP && globalServerConfig != nil && globalServerConfig.Performance.AutoReadTimeout { if this.isHTTP && !this.isWebsocket && globalServerConfig != nil && globalServerConfig.Performance.AutoReadTimeout {
this.isShortReading = false this.isShortReading = false
var unixTime = t.Unix() var unixTime = t.Unix()

View File

@@ -16,6 +16,8 @@ type BaseClientConn struct {
remoteAddr string remoteAddr string
hasLimit bool hasLimit bool
isWebsocket bool
isClosed bool isClosed bool
rawIP string rawIP string
@@ -122,3 +124,7 @@ func (this *BaseClientConn) SetLinger(seconds int) error {
} }
return nil return nil
} }
func (this *BaseClientConn) SetIsWebsocket(isWebsocket bool) {
this.isWebsocket = isWebsocket
}

View File

@@ -23,4 +23,7 @@ type ClientConnInterface interface {
// UserId 获取当前连接所属服务的用户ID // UserId 获取当前连接所属服务的用户ID
UserId() int64 UserId() int64
// SetIsWebsocket 设置是否为Websocket
SetIsWebsocket(isWebsocket bool)
} }

View File

@@ -70,6 +70,13 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
this.RawReq.Header.Set("Origin", newRequestOrigin) this.RawReq.Header.Set("Origin", newRequestOrigin)
} }
// 获取当前连接
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
if requestConn == nil {
return
}
// 连接源站
// TODO 增加N次错误重试重试的时候需要尝试不同的源站 // TODO 增加N次错误重试重试的时候需要尝试不同的源站
originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost) originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost)
if err != nil { if err != nil {
@@ -102,6 +109,11 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
return return
} }
requestClientConn, ok := requestConn.(ClientConnInterface)
if ok {
requestClientConn.SetIsWebsocket(true)
}
clientConn, _, err := this.writer.Hijack() clientConn, _, err := this.writer.Hijack()
if err != nil || clientConn == nil { if err != nil || clientConn == nil {
this.write50x(err, http.StatusInternalServerError, "Failed to get origin site connection", "获取源站连接失败", false) this.write50x(err, http.StatusInternalServerError, "Failed to get origin site connection", "获取源站连接失败", false)