mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 16:00:25 +08:00 
			
		
		
		
	回源TLS/HTTPS携带ServerName信息
This commit is contained in:
		@@ -172,7 +172,7 @@ func (this *HTTPRequest) doReverseProxy() {
 | 
			
		||||
 | 
			
		||||
	// 判断是否为Websocket请求
 | 
			
		||||
	if this.RawReq.Header.Get("Upgrade") == "websocket" {
 | 
			
		||||
		this.doWebsocket()
 | 
			
		||||
		this.doWebsocket(requestHost)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -196,13 +196,13 @@ func (this *HTTPRequest) doReverseProxy() {
 | 
			
		||||
		// 客户端取消请求,则不提示
 | 
			
		||||
		httpErr, ok := err.(*url.Error)
 | 
			
		||||
		if !ok {
 | 
			
		||||
			SharedOriginStateManager.Fail(origin, this.reverseProxy, func() {
 | 
			
		||||
			SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
 | 
			
		||||
				this.reverseProxy.ResetScheduling()
 | 
			
		||||
			})
 | 
			
		||||
			this.write50x(err, http.StatusBadGateway, true)
 | 
			
		||||
			remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+"': "+err.Error())
 | 
			
		||||
		} else if httpErr.Err != context.Canceled {
 | 
			
		||||
			SharedOriginStateManager.Fail(origin, this.reverseProxy, func() {
 | 
			
		||||
			SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
 | 
			
		||||
				this.reverseProxy.ResetScheduling()
 | 
			
		||||
			})
 | 
			
		||||
			if httpErr.Timeout() {
 | 
			
		||||
 
 | 
			
		||||
@@ -9,7 +9,7 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 处理Websocket请求
 | 
			
		||||
func (this *HTTPRequest) doWebsocket() {
 | 
			
		||||
func (this *HTTPRequest) doWebsocket(requestHost string) {
 | 
			
		||||
	if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
 | 
			
		||||
		this.writer.WriteHeader(http.StatusForbidden)
 | 
			
		||||
		this.addError(errors.New("websocket have not been enabled yet"))
 | 
			
		||||
@@ -41,12 +41,12 @@ func (this *HTTPRequest) doWebsocket() {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO 增加N次错误重试,重试的时候需要尝试不同的源站
 | 
			
		||||
	originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr)
 | 
			
		||||
	originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr, requestHost)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		this.write50x(err, http.StatusBadGateway, false)
 | 
			
		||||
 | 
			
		||||
		// 增加失败次数
 | 
			
		||||
		SharedOriginStateManager.Fail(this.origin, this.reverseProxy, func() {
 | 
			
		||||
		SharedOriginStateManager.Fail(this.origin, requestHost, this.reverseProxy, func() {
 | 
			
		||||
			this.reverseProxy.ResetScheduling()
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -187,6 +187,7 @@ func (this *TCPListener) Close() error {
 | 
			
		||||
	return this.Listener.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 连接源站
 | 
			
		||||
func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
 | 
			
		||||
	if reverseProxy == nil {
 | 
			
		||||
		return nil, errors.New("no reverse proxy config")
 | 
			
		||||
@@ -198,7 +199,17 @@ func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
 | 
			
		||||
		if origin == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		conn, err = OriginConnect(origin, remoteAddr)
 | 
			
		||||
 | 
			
		||||
		// 回源主机名
 | 
			
		||||
		var requestHost = ""
 | 
			
		||||
		if len(reverseProxy.RequestHost) > 0 {
 | 
			
		||||
			requestHost = reverseProxy.RequestHost
 | 
			
		||||
		}
 | 
			
		||||
		if len(origin.RequestHost) > 0 {
 | 
			
		||||
			requestHost = origin.RequestHost
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		conn, err = OriginConnect(origin, remoteAddr, requestHost)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error(), "", nil)
 | 
			
		||||
			continue
 | 
			
		||||
 
 | 
			
		||||
@@ -128,7 +128,7 @@ func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
 | 
			
		||||
		if origin == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		conn, err = OriginConnect(origin, remoteAddr.String())
 | 
			
		||||
		conn, err = OriginConnect(origin, remoteAddr.String(), "")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error(), "", nil)
 | 
			
		||||
			continue
 | 
			
		||||
 
 | 
			
		||||
@@ -8,5 +8,6 @@ type OriginState struct {
 | 
			
		||||
	CountFails   int64
 | 
			
		||||
	UpdatedAt    int64
 | 
			
		||||
	Config       *serverconfigs.OriginConfig
 | 
			
		||||
	TLSHost      string
 | 
			
		||||
	ReverseProxy *serverconfigs.ReverseProxyConfig
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -102,7 +102,7 @@ func (this *OriginStateManager) Loop() error {
 | 
			
		||||
	for _, state := range currentStates {
 | 
			
		||||
		go func(state *OriginState) {
 | 
			
		||||
			defer wg.Done()
 | 
			
		||||
			conn, err := OriginConnect(state.Config, "")
 | 
			
		||||
			conn, err := OriginConnect(state.Config, "", state.TLSHost)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				_ = conn.Close()
 | 
			
		||||
 | 
			
		||||
@@ -125,7 +125,7 @@ func (this *OriginStateManager) Loop() error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Fail 添加失败的源站
 | 
			
		||||
func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverseProxy *serverconfigs.ReverseProxyConfig, callback func()) {
 | 
			
		||||
func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, tlsHost string, reverseProxy *serverconfigs.ReverseProxyConfig, callback func()) {
 | 
			
		||||
	if origin == nil || origin.Id <= 0 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
@@ -139,6 +139,7 @@ func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverse
 | 
			
		||||
			state.Config.IsOk = true
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		state.TLSHost = tlsHost
 | 
			
		||||
		state.CountFails++
 | 
			
		||||
		state.Config = origin
 | 
			
		||||
		state.ReverseProxy = reverseProxy
 | 
			
		||||
@@ -157,6 +158,7 @@ func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverse
 | 
			
		||||
		this.stateMap[origin.Id] = &OriginState{
 | 
			
		||||
			CountFails:   1,
 | 
			
		||||
			Config:       origin,
 | 
			
		||||
			TLSHost:      tlsHost,
 | 
			
		||||
			ReverseProxy: reverseProxy,
 | 
			
		||||
			UpdatedAt:    timestamp,
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -10,7 +10,7 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// OriginConnect 连接源站
 | 
			
		||||
func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.Conn, error) {
 | 
			
		||||
func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string, tlsHost string) (net.Conn, error) {
 | 
			
		||||
	if origin.Addr == nil {
 | 
			
		||||
		return nil, errors.New("origin server address should not be empty")
 | 
			
		||||
	}
 | 
			
		||||
@@ -58,6 +58,9 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.C
 | 
			
		||||
								}
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
						if len(tlsHost) > 0 {
 | 
			
		||||
							tlsConfig.ServerName = tlsHost
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
						conn, err = tls.DialWithDialer(&dialer, "tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig)
 | 
			
		||||
					}
 | 
			
		||||
@@ -95,6 +98,9 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.C
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if len(tlsHost) > 0 {
 | 
			
		||||
			tlsConfig.ServerName = tlsHost
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig)
 | 
			
		||||
	case serverconfigs.ProtocolUDP:
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user