diff --git a/internal/nodes/http_request_websocket.go b/internal/nodes/http_request_websocket.go index 68bc58f..3219d71 100644 --- a/internal/nodes/http_request_websocket.go +++ b/internal/nodes/http_request_websocket.go @@ -2,7 +2,6 @@ package nodes import ( "errors" - "github.com/TeaOSLab/EdgeNode/internal/goman" "github.com/TeaOSLab/EdgeNode/internal/utils" "io" "net/http" @@ -20,7 +19,7 @@ func (this *HTTPRequest) doWebsocket() { // TODO 实现handshakeTimeout // 校验来源 - requestOrigin := this.RawReq.Header.Get("Origin") + var requestOrigin = this.RawReq.Header.Get("Origin") if len(requestOrigin) > 0 { u, err := url.Parse(requestOrigin) if err == nil { @@ -34,7 +33,7 @@ func (this *HTTPRequest) doWebsocket() { // 设置指定的来源域 if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 { - newRequestOrigin := this.web.Websocket.RequestOrigin + var newRequestOrigin = this.web.Websocket.RequestOrigin if this.web.Websocket.RequestOriginHasVariables() { newRequestOrigin = this.Format(newRequestOrigin) } @@ -45,8 +44,21 @@ func (this *HTTPRequest) doWebsocket() { originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr) if err != nil { this.write50x(err, http.StatusBadGateway, false) + + // 增加失败次数 + SharedOriginStateManager.Fail(this.origin, this.reverseProxy, func() { + this.reverseProxy.ResetScheduling() + }) + return } + + if !this.origin.IsOk { + SharedOriginStateManager.Success(this.origin, func() { + this.reverseProxy.ResetScheduling() + }) + } + defer func() { _ = originConn.Close() }() @@ -66,7 +78,7 @@ func (this *HTTPRequest) doWebsocket() { _ = clientConn.Close() }() - goman.New(func() { + go func() { var buf = utils.BytePool4k.Get() defer utils.BytePool4k.Put(buf) for { @@ -84,6 +96,6 @@ func (this *HTTPRequest) doWebsocket() { } _ = clientConn.Close() _ = originConn.Close() - }) + }() _, _ = io.Copy(originConn, clientConn) } diff --git a/internal/nodes/origin_state_manager.go b/internal/nodes/origin_state_manager.go index b4d7fd8..12e3971 100644 --- a/internal/nodes/origin_state_manager.go +++ b/internal/nodes/origin_state_manager.go @@ -45,6 +45,8 @@ func NewOriginStateManager() *OriginStateManager { // Start 启动 func (this *OriginStateManager) Start() { events.OnKey(events.EventReload, this, func() { + // TODO 检查源站是否有变化 + this.locker.Lock() this.stateMap = map[int64]*OriginState{} this.locker.Unlock() @@ -143,7 +145,7 @@ func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverse state.UpdatedAt = timestamp if origin.IsOk { - origin.IsOk = state.CountFails > 5 // 超过 N 次之后认为是异常 + origin.IsOk = state.CountFails < 5 // 超过 N 次之后认为是异常 if !origin.IsOk { if callback != nil {