回源TLS/HTTPS携带ServerName信息

This commit is contained in:
刘祥超
2022-06-27 12:01:33 +08:00
parent f8e155887f
commit b254cfc1a7
7 changed files with 31 additions and 11 deletions

View File

@@ -172,7 +172,7 @@ func (this *HTTPRequest) doReverseProxy() {
// 判断是否为Websocket请求 // 判断是否为Websocket请求
if this.RawReq.Header.Get("Upgrade") == "websocket" { if this.RawReq.Header.Get("Upgrade") == "websocket" {
this.doWebsocket() this.doWebsocket(requestHost)
return return
} }
@@ -196,13 +196,13 @@ func (this *HTTPRequest) doReverseProxy() {
// 客户端取消请求,则不提示 // 客户端取消请求,则不提示
httpErr, ok := err.(*url.Error) httpErr, ok := err.(*url.Error)
if !ok { if !ok {
SharedOriginStateManager.Fail(origin, this.reverseProxy, func() { SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
this.reverseProxy.ResetScheduling() this.reverseProxy.ResetScheduling()
}) })
this.write50x(err, http.StatusBadGateway, true) this.write50x(err, http.StatusBadGateway, true)
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+"': "+err.Error()) remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+"': "+err.Error())
} else if httpErr.Err != context.Canceled { } else if httpErr.Err != context.Canceled {
SharedOriginStateManager.Fail(origin, this.reverseProxy, func() { SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
this.reverseProxy.ResetScheduling() this.reverseProxy.ResetScheduling()
}) })
if httpErr.Timeout() { if httpErr.Timeout() {

View File

@@ -9,7 +9,7 @@ import (
) )
// 处理Websocket请求 // 处理Websocket请求
func (this *HTTPRequest) doWebsocket() { func (this *HTTPRequest) doWebsocket(requestHost string) {
if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn { if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
this.writer.WriteHeader(http.StatusForbidden) this.writer.WriteHeader(http.StatusForbidden)
this.addError(errors.New("websocket have not been enabled yet")) this.addError(errors.New("websocket have not been enabled yet"))
@@ -41,12 +41,12 @@ func (this *HTTPRequest) doWebsocket() {
} }
// TODO 增加N次错误重试重试的时候需要尝试不同的源站 // TODO 增加N次错误重试重试的时候需要尝试不同的源站
originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr) originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr, requestHost)
if err != nil { if err != nil {
this.write50x(err, http.StatusBadGateway, false) this.write50x(err, http.StatusBadGateway, false)
// 增加失败次数 // 增加失败次数
SharedOriginStateManager.Fail(this.origin, this.reverseProxy, func() { SharedOriginStateManager.Fail(this.origin, requestHost, this.reverseProxy, func() {
this.reverseProxy.ResetScheduling() this.reverseProxy.ResetScheduling()
}) })

View File

@@ -187,6 +187,7 @@ func (this *TCPListener) Close() error {
return this.Listener.Close() return this.Listener.Close()
} }
// 连接源站
func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) { func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
if reverseProxy == nil { if reverseProxy == nil {
return nil, errors.New("no reverse proxy config") return nil, errors.New("no reverse proxy config")
@@ -198,7 +199,17 @@ func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
if origin == nil { if origin == nil {
continue continue
} }
conn, err = OriginConnect(origin, remoteAddr)
// 回源主机名
var requestHost = ""
if len(reverseProxy.RequestHost) > 0 {
requestHost = reverseProxy.RequestHost
}
if len(origin.RequestHost) > 0 {
requestHost = origin.RequestHost
}
conn, err = OriginConnect(origin, remoteAddr, requestHost)
if err != nil { if err != nil {
remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error(), "", nil) remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error(), "", nil)
continue continue

View File

@@ -128,7 +128,7 @@ func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
if origin == nil { if origin == nil {
continue continue
} }
conn, err = OriginConnect(origin, remoteAddr.String()) conn, err = OriginConnect(origin, remoteAddr.String(), "")
if err != nil { if err != nil {
remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error(), "", nil) remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error(), "", nil)
continue continue

View File

@@ -8,5 +8,6 @@ type OriginState struct {
CountFails int64 CountFails int64
UpdatedAt int64 UpdatedAt int64
Config *serverconfigs.OriginConfig Config *serverconfigs.OriginConfig
TLSHost string
ReverseProxy *serverconfigs.ReverseProxyConfig ReverseProxy *serverconfigs.ReverseProxyConfig
} }

View File

@@ -102,7 +102,7 @@ func (this *OriginStateManager) Loop() error {
for _, state := range currentStates { for _, state := range currentStates {
go func(state *OriginState) { go func(state *OriginState) {
defer wg.Done() defer wg.Done()
conn, err := OriginConnect(state.Config, "") conn, err := OriginConnect(state.Config, "", state.TLSHost)
if err == nil { if err == nil {
_ = conn.Close() _ = conn.Close()
@@ -125,7 +125,7 @@ func (this *OriginStateManager) Loop() error {
} }
// Fail 添加失败的源站 // Fail 添加失败的源站
func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverseProxy *serverconfigs.ReverseProxyConfig, callback func()) { func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, tlsHost string, reverseProxy *serverconfigs.ReverseProxyConfig, callback func()) {
if origin == nil || origin.Id <= 0 { if origin == nil || origin.Id <= 0 {
return return
} }
@@ -139,6 +139,7 @@ func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverse
state.Config.IsOk = true state.Config.IsOk = true
} }
state.TLSHost = tlsHost
state.CountFails++ state.CountFails++
state.Config = origin state.Config = origin
state.ReverseProxy = reverseProxy state.ReverseProxy = reverseProxy
@@ -157,6 +158,7 @@ func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverse
this.stateMap[origin.Id] = &OriginState{ this.stateMap[origin.Id] = &OriginState{
CountFails: 1, CountFails: 1,
Config: origin, Config: origin,
TLSHost: tlsHost,
ReverseProxy: reverseProxy, ReverseProxy: reverseProxy,
UpdatedAt: timestamp, UpdatedAt: timestamp,
} }

View File

@@ -10,7 +10,7 @@ import (
) )
// OriginConnect 连接源站 // OriginConnect 连接源站
func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.Conn, error) { func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string, tlsHost string) (net.Conn, error) {
if origin.Addr == nil { if origin.Addr == nil {
return nil, errors.New("origin server address should not be empty") return nil, errors.New("origin server address should not be empty")
} }
@@ -58,6 +58,9 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.C
} }
} }
} }
if len(tlsHost) > 0 {
tlsConfig.ServerName = tlsHost
}
conn, err = tls.DialWithDialer(&dialer, "tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig) conn, err = tls.DialWithDialer(&dialer, "tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig)
} }
@@ -95,6 +98,9 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.C
} }
} }
} }
if len(tlsHost) > 0 {
tlsConfig.ServerName = tlsHost
}
return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig) return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig)
case serverconfigs.ProtocolUDP: case serverconfigs.ProtocolUDP: