diff --git a/internal/nodes/http_request_reverse_proxy.go b/internal/nodes/http_request_reverse_proxy.go index 393e78d..7d789ac 100644 --- a/internal/nodes/http_request_reverse_proxy.go +++ b/internal/nodes/http_request_reverse_proxy.go @@ -172,7 +172,7 @@ func (this *HTTPRequest) doReverseProxy() { // 判断是否为Websocket请求 if this.RawReq.Header.Get("Upgrade") == "websocket" { - this.doWebsocket() + this.doWebsocket(requestHost) return } @@ -196,13 +196,13 @@ func (this *HTTPRequest) doReverseProxy() { // 客户端取消请求,则不提示 httpErr, ok := err.(*url.Error) if !ok { - SharedOriginStateManager.Fail(origin, this.reverseProxy, func() { + SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() { this.reverseProxy.ResetScheduling() }) this.write50x(err, http.StatusBadGateway, true) remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+"': "+err.Error()) } else if httpErr.Err != context.Canceled { - SharedOriginStateManager.Fail(origin, this.reverseProxy, func() { + SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() { this.reverseProxy.ResetScheduling() }) if httpErr.Timeout() { diff --git a/internal/nodes/http_request_websocket.go b/internal/nodes/http_request_websocket.go index 3219d71..b159035 100644 --- a/internal/nodes/http_request_websocket.go +++ b/internal/nodes/http_request_websocket.go @@ -9,7 +9,7 @@ import ( ) // 处理Websocket请求 -func (this *HTTPRequest) doWebsocket() { +func (this *HTTPRequest) doWebsocket(requestHost string) { if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn { this.writer.WriteHeader(http.StatusForbidden) this.addError(errors.New("websocket have not been enabled yet")) @@ -41,12 +41,12 @@ func (this *HTTPRequest) doWebsocket() { } // TODO 增加N次错误重试,重试的时候需要尝试不同的源站 - originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr) + originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr, requestHost) if err != nil { this.write50x(err, http.StatusBadGateway, false) // 增加失败次数 - SharedOriginStateManager.Fail(this.origin, this.reverseProxy, func() { + SharedOriginStateManager.Fail(this.origin, requestHost, this.reverseProxy, func() { this.reverseProxy.ResetScheduling() }) diff --git a/internal/nodes/listener_tcp.go b/internal/nodes/listener_tcp.go index 8ed6a89..19c6807 100644 --- a/internal/nodes/listener_tcp.go +++ b/internal/nodes/listener_tcp.go @@ -187,6 +187,7 @@ func (this *TCPListener) Close() error { return this.Listener.Close() } +// 连接源站 func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) { if reverseProxy == nil { return nil, errors.New("no reverse proxy config") @@ -198,7 +199,17 @@ func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfi if origin == nil { continue } - conn, err = OriginConnect(origin, remoteAddr) + + // 回源主机名 + var requestHost = "" + if len(reverseProxy.RequestHost) > 0 { + requestHost = reverseProxy.RequestHost + } + if len(origin.RequestHost) > 0 { + requestHost = origin.RequestHost + } + + conn, err = OriginConnect(origin, remoteAddr, requestHost) if err != nil { remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error(), "", nil) continue diff --git a/internal/nodes/listener_udp.go b/internal/nodes/listener_udp.go index 21ff047..f361054 100644 --- a/internal/nodes/listener_udp.go +++ b/internal/nodes/listener_udp.go @@ -128,7 +128,7 @@ func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfi if origin == nil { continue } - conn, err = OriginConnect(origin, remoteAddr.String()) + conn, err = OriginConnect(origin, remoteAddr.String(), "") if err != nil { remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error(), "", nil) continue diff --git a/internal/nodes/origin_state.go b/internal/nodes/origin_state.go index 17dfa8a..e3180d8 100644 --- a/internal/nodes/origin_state.go +++ b/internal/nodes/origin_state.go @@ -8,5 +8,6 @@ type OriginState struct { CountFails int64 UpdatedAt int64 Config *serverconfigs.OriginConfig + TLSHost string ReverseProxy *serverconfigs.ReverseProxyConfig } diff --git a/internal/nodes/origin_state_manager.go b/internal/nodes/origin_state_manager.go index 12e3971..9f07106 100644 --- a/internal/nodes/origin_state_manager.go +++ b/internal/nodes/origin_state_manager.go @@ -102,7 +102,7 @@ func (this *OriginStateManager) Loop() error { for _, state := range currentStates { go func(state *OriginState) { defer wg.Done() - conn, err := OriginConnect(state.Config, "") + conn, err := OriginConnect(state.Config, "", state.TLSHost) if err == nil { _ = conn.Close() @@ -125,7 +125,7 @@ func (this *OriginStateManager) Loop() error { } // Fail 添加失败的源站 -func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverseProxy *serverconfigs.ReverseProxyConfig, callback func()) { +func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, tlsHost string, reverseProxy *serverconfigs.ReverseProxyConfig, callback func()) { if origin == nil || origin.Id <= 0 { return } @@ -139,6 +139,7 @@ func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverse state.Config.IsOk = true } + state.TLSHost = tlsHost state.CountFails++ state.Config = origin state.ReverseProxy = reverseProxy @@ -157,6 +158,7 @@ func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverse this.stateMap[origin.Id] = &OriginState{ CountFails: 1, Config: origin, + TLSHost: tlsHost, ReverseProxy: reverseProxy, UpdatedAt: timestamp, } diff --git a/internal/nodes/origin_utils.go b/internal/nodes/origin_utils.go index 497fc60..9537012 100644 --- a/internal/nodes/origin_utils.go +++ b/internal/nodes/origin_utils.go @@ -10,7 +10,7 @@ import ( ) // OriginConnect 连接源站 -func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.Conn, error) { +func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string, tlsHost string) (net.Conn, error) { if origin.Addr == nil { return nil, errors.New("origin server address should not be empty") } @@ -58,6 +58,9 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.C } } } + if len(tlsHost) > 0 { + tlsConfig.ServerName = tlsHost + } conn, err = tls.DialWithDialer(&dialer, "tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig) } @@ -95,6 +98,9 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.C } } } + if len(tlsHost) > 0 { + tlsConfig.ServerName = tlsHost + } return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig) case serverconfigs.ProtocolUDP: