mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-05 17:40:26 +08:00
回源TLS/HTTPS携带ServerName信息
This commit is contained in:
@@ -172,7 +172,7 @@ func (this *HTTPRequest) doReverseProxy() {
|
|||||||
|
|
||||||
// 判断是否为Websocket请求
|
// 判断是否为Websocket请求
|
||||||
if this.RawReq.Header.Get("Upgrade") == "websocket" {
|
if this.RawReq.Header.Get("Upgrade") == "websocket" {
|
||||||
this.doWebsocket()
|
this.doWebsocket(requestHost)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,13 +196,13 @@ func (this *HTTPRequest) doReverseProxy() {
|
|||||||
// 客户端取消请求,则不提示
|
// 客户端取消请求,则不提示
|
||||||
httpErr, ok := err.(*url.Error)
|
httpErr, ok := err.(*url.Error)
|
||||||
if !ok {
|
if !ok {
|
||||||
SharedOriginStateManager.Fail(origin, this.reverseProxy, func() {
|
SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
|
||||||
this.reverseProxy.ResetScheduling()
|
this.reverseProxy.ResetScheduling()
|
||||||
})
|
})
|
||||||
this.write50x(err, http.StatusBadGateway, true)
|
this.write50x(err, http.StatusBadGateway, true)
|
||||||
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+"': "+err.Error())
|
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+"': "+err.Error())
|
||||||
} else if httpErr.Err != context.Canceled {
|
} else if httpErr.Err != context.Canceled {
|
||||||
SharedOriginStateManager.Fail(origin, this.reverseProxy, func() {
|
SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
|
||||||
this.reverseProxy.ResetScheduling()
|
this.reverseProxy.ResetScheduling()
|
||||||
})
|
})
|
||||||
if httpErr.Timeout() {
|
if httpErr.Timeout() {
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// 处理Websocket请求
|
// 处理Websocket请求
|
||||||
func (this *HTTPRequest) doWebsocket() {
|
func (this *HTTPRequest) doWebsocket(requestHost string) {
|
||||||
if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
|
if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
|
||||||
this.writer.WriteHeader(http.StatusForbidden)
|
this.writer.WriteHeader(http.StatusForbidden)
|
||||||
this.addError(errors.New("websocket have not been enabled yet"))
|
this.addError(errors.New("websocket have not been enabled yet"))
|
||||||
@@ -41,12 +41,12 @@ func (this *HTTPRequest) doWebsocket() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO 增加N次错误重试,重试的时候需要尝试不同的源站
|
// TODO 增加N次错误重试,重试的时候需要尝试不同的源站
|
||||||
originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr)
|
originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr, requestHost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
this.write50x(err, http.StatusBadGateway, false)
|
this.write50x(err, http.StatusBadGateway, false)
|
||||||
|
|
||||||
// 增加失败次数
|
// 增加失败次数
|
||||||
SharedOriginStateManager.Fail(this.origin, this.reverseProxy, func() {
|
SharedOriginStateManager.Fail(this.origin, requestHost, this.reverseProxy, func() {
|
||||||
this.reverseProxy.ResetScheduling()
|
this.reverseProxy.ResetScheduling()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -187,6 +187,7 @@ func (this *TCPListener) Close() error {
|
|||||||
return this.Listener.Close()
|
return this.Listener.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 连接源站
|
||||||
func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
|
func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
|
||||||
if reverseProxy == nil {
|
if reverseProxy == nil {
|
||||||
return nil, errors.New("no reverse proxy config")
|
return nil, errors.New("no reverse proxy config")
|
||||||
@@ -198,7 +199,17 @@ func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
|
|||||||
if origin == nil {
|
if origin == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
conn, err = OriginConnect(origin, remoteAddr)
|
|
||||||
|
// 回源主机名
|
||||||
|
var requestHost = ""
|
||||||
|
if len(reverseProxy.RequestHost) > 0 {
|
||||||
|
requestHost = reverseProxy.RequestHost
|
||||||
|
}
|
||||||
|
if len(origin.RequestHost) > 0 {
|
||||||
|
requestHost = origin.RequestHost
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err = OriginConnect(origin, remoteAddr, requestHost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error(), "", nil)
|
remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error(), "", nil)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
|
|||||||
if origin == nil {
|
if origin == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
conn, err = OriginConnect(origin, remoteAddr.String())
|
conn, err = OriginConnect(origin, remoteAddr.String(), "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error(), "", nil)
|
remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error(), "", nil)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -8,5 +8,6 @@ type OriginState struct {
|
|||||||
CountFails int64
|
CountFails int64
|
||||||
UpdatedAt int64
|
UpdatedAt int64
|
||||||
Config *serverconfigs.OriginConfig
|
Config *serverconfigs.OriginConfig
|
||||||
|
TLSHost string
|
||||||
ReverseProxy *serverconfigs.ReverseProxyConfig
|
ReverseProxy *serverconfigs.ReverseProxyConfig
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ func (this *OriginStateManager) Loop() error {
|
|||||||
for _, state := range currentStates {
|
for _, state := range currentStates {
|
||||||
go func(state *OriginState) {
|
go func(state *OriginState) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
conn, err := OriginConnect(state.Config, "")
|
conn, err := OriginConnect(state.Config, "", state.TLSHost)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
|
|
||||||
@@ -125,7 +125,7 @@ func (this *OriginStateManager) Loop() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fail 添加失败的源站
|
// Fail 添加失败的源站
|
||||||
func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverseProxy *serverconfigs.ReverseProxyConfig, callback func()) {
|
func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, tlsHost string, reverseProxy *serverconfigs.ReverseProxyConfig, callback func()) {
|
||||||
if origin == nil || origin.Id <= 0 {
|
if origin == nil || origin.Id <= 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -139,6 +139,7 @@ func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverse
|
|||||||
state.Config.IsOk = true
|
state.Config.IsOk = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
state.TLSHost = tlsHost
|
||||||
state.CountFails++
|
state.CountFails++
|
||||||
state.Config = origin
|
state.Config = origin
|
||||||
state.ReverseProxy = reverseProxy
|
state.ReverseProxy = reverseProxy
|
||||||
@@ -157,6 +158,7 @@ func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverse
|
|||||||
this.stateMap[origin.Id] = &OriginState{
|
this.stateMap[origin.Id] = &OriginState{
|
||||||
CountFails: 1,
|
CountFails: 1,
|
||||||
Config: origin,
|
Config: origin,
|
||||||
|
TLSHost: tlsHost,
|
||||||
ReverseProxy: reverseProxy,
|
ReverseProxy: reverseProxy,
|
||||||
UpdatedAt: timestamp,
|
UpdatedAt: timestamp,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// OriginConnect 连接源站
|
// OriginConnect 连接源站
|
||||||
func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.Conn, error) {
|
func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string, tlsHost string) (net.Conn, error) {
|
||||||
if origin.Addr == nil {
|
if origin.Addr == nil {
|
||||||
return nil, errors.New("origin server address should not be empty")
|
return nil, errors.New("origin server address should not be empty")
|
||||||
}
|
}
|
||||||
@@ -58,6 +58,9 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.C
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(tlsHost) > 0 {
|
||||||
|
tlsConfig.ServerName = tlsHost
|
||||||
|
}
|
||||||
|
|
||||||
conn, err = tls.DialWithDialer(&dialer, "tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig)
|
conn, err = tls.DialWithDialer(&dialer, "tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig)
|
||||||
}
|
}
|
||||||
@@ -95,6 +98,9 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.C
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(tlsHost) > 0 {
|
||||||
|
tlsConfig.ServerName = tlsHost
|
||||||
|
}
|
||||||
|
|
||||||
return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig)
|
return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig)
|
||||||
case serverconfigs.ProtocolUDP:
|
case serverconfigs.ProtocolUDP:
|
||||||
|
|||||||
Reference in New Issue
Block a user