实现源站端口跟随功能

This commit is contained in:
GoEdgeLab
2022-06-29 21:58:41 +08:00
parent 18ebd8c712
commit 146a947d0b
10 changed files with 96 additions and 41 deletions

3
dist/.gitignore vendored
View File

@@ -1 +1,2 @@
*.zip *.zip
edge-node

View File

@@ -7,6 +7,7 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/types"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
@@ -124,6 +125,18 @@ func (this *HTTPRequest) doReverseProxy() {
if origin.Addr.HostHasVariables() { if origin.Addr.HostHasVariables() {
originAddr = this.Format(originAddr) originAddr = this.Format(originAddr)
} }
// 端口跟随
if origin.FollowPort {
var originHostIndex = strings.Index(originAddr, ":")
if originHostIndex < 0 {
var originErr = errors.New("invalid origin address '" + originAddr + "', lacking port")
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", originErr.Error())
this.write50x(originErr, http.StatusBadGateway, true)
return
}
originAddr = originAddr[:originHostIndex+1] + types.String(this.requestServerPort())
}
this.originAddr = originAddr this.originAddr = originAddr
// RequestHost // RequestHost

View File

@@ -41,7 +41,7 @@ func (this *HTTPRequest) doWebsocket(requestHost string) {
} }
// TODO 增加N次错误重试重试的时候需要尝试不同的源站 // TODO 增加N次错误重试重试的时候需要尝试不同的源站
originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr, requestHost) originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost)
if err != nil { if err != nil {
this.write50x(err, http.StatusBadGateway, false) this.write50x(err, http.StatusBadGateway, false)

View File

@@ -153,7 +153,7 @@ func (this *Listener) Close() error {
// 创建TCP监听器 // 创建TCP监听器
func (this *Listener) createTCPListener() (net.Listener, error) { func (this *Listener) createTCPListener() (net.Listener, error) {
listenConfig := net.ListenConfig{ var listenConfig = net.ListenConfig{
Control: nil, Control: nil,
KeepAlive: 0, KeepAlive: 0,
} }

View File

@@ -107,7 +107,7 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
} }
// 证书是否匹配 // 证书是否匹配
sslConfig := server.SSLPolicy() var sslConfig = server.SSLPolicy()
cert, ok := sslConfig.MatchDomain(domain) cert, ok := sslConfig.MatchDomain(domain)
if ok { if ok {
return sslConfig, cert, nil return sslConfig, cert, nil
@@ -127,7 +127,7 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
return return
} }
matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly var matchDomainStrictly = sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
if sharedNodeConfig.GlobalConfig != nil && if sharedNodeConfig.GlobalConfig != nil &&
len(sharedNodeConfig.GlobalConfig.HTTPAll.DefaultDomain) > 0 && len(sharedNodeConfig.GlobalConfig.HTTPAll.DefaultDomain) > 0 &&
@@ -144,9 +144,9 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
} }
// 如果没有找到,则匹配到第一个 // 如果没有找到,则匹配到第一个
group := this.Group var group = this.Group
currentServers := group.Servers() var currentServers = group.Servers()
countServers := len(currentServers) var countServers = len(currentServers)
if countServers == 0 { if countServers == 0 {
return nil, "" return nil, ""
} }

View File

@@ -8,6 +8,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/stats" "github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/types"
"github.com/pires/go-proxyproto" "github.com/pires/go-proxyproto"
"net" "net"
"strings" "strings"
@@ -18,6 +19,8 @@ type TCPListener struct {
BaseListener BaseListener
Listener net.Listener Listener net.Listener
port int
} }
func (this *TCPListener) Serve() error { func (this *TCPListener) Serve() error {
@@ -26,6 +29,14 @@ func (this *TCPListener) Serve() error {
listener = tls.NewListener(listener, this.buildTLSConfig()) listener = tls.NewListener(listener, this.buildTLSConfig())
} }
// 获取分组端口
var groupAddr = this.Group.Addr()
var portIndex = strings.LastIndex(groupAddr, ":")
if portIndex >= 0 {
var port = groupAddr[portIndex+1:]
this.port = types.Int(port)
}
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
@@ -52,6 +63,7 @@ func (this *TCPListener) Reload(group *serverconfigs.ServerAddressGroup) {
} }
func (this *TCPListener) handleConn(conn net.Conn) error { func (this *TCPListener) handleConn(conn net.Conn) error {
var server = this.Group.FirstServer() var server = this.Group.FirstServer()
if server == nil { if server == nil {
return errors.New("no server available") return errors.New("no server available")
@@ -193,9 +205,10 @@ func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
return nil, errors.New("no reverse proxy config") return nil, errors.New("no reverse proxy config")
} }
retries := 3 var retries = 3
var addr string
for i := 0; i < retries; i++ { for i := 0; i < retries; i++ {
origin := reverseProxy.NextOrigin(nil) var origin = reverseProxy.NextOrigin(nil)
if origin == nil { if origin == nil {
continue continue
} }
@@ -209,15 +222,15 @@ func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
requestHost = origin.RequestHost requestHost = origin.RequestHost
} }
conn, err = OriginConnect(origin, remoteAddr, requestHost) conn, addr, err = OriginConnect(origin, this.port, remoteAddr, requestHost)
if err != nil { if err != nil {
remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error(), "", nil) remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin server: "+addr+": "+err.Error(), "", nil)
continue continue
} else { } else {
return return
} }
} }
err = errors.New("no origin can be used") err = errors.New("no origin server can be used")
return return
} }

View File

@@ -7,6 +7,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/stats" "github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/types"
"github.com/pires/go-proxyproto" "github.com/pires/go-proxyproto"
"net" "net"
"strings" "strings"
@@ -25,11 +26,21 @@ type UDPListener struct {
reverseProxy *serverconfigs.ReverseProxyConfig reverseProxy *serverconfigs.ReverseProxyConfig
port int
isClosed bool isClosed bool
} }
func (this *UDPListener) Serve() error { func (this *UDPListener) Serve() error {
firstServer := this.Group.FirstServer() // 获取分组端口
var groupAddr = this.Group.Addr()
var portIndex = strings.LastIndex(groupAddr, ":")
if portIndex >= 0 {
var port = groupAddr[portIndex+1:]
this.port = types.Int(port)
}
var firstServer = this.Group.FirstServer()
if firstServer == nil { if firstServer == nil {
return errors.New("no server available") return errors.New("no server available")
} }
@@ -110,7 +121,7 @@ func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) {
this.Reset() this.Reset()
// 重置配置 // 重置配置
firstServer := this.Group.FirstServer() var firstServer = this.Group.FirstServer()
if firstServer == nil { if firstServer == nil {
return return
} }
@@ -122,15 +133,16 @@ func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
return nil, errors.New("no reverse proxy config") return nil, errors.New("no reverse proxy config")
} }
retries := 3 var retries = 3
var addr string
for i := 0; i < retries; i++ { for i := 0; i < retries; i++ {
origin := reverseProxy.NextOrigin(nil) var origin = reverseProxy.NextOrigin(nil)
if origin == nil { if origin == nil {
continue continue
} }
conn, err = OriginConnect(origin, remoteAddr.String(), "") conn, addr, err = OriginConnect(origin, this.port, remoteAddr.String(), "")
if err != nil { if err != nil {
remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error(), "", nil) remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin server: "+addr+": "+err.Error(), "", nil)
continue continue
} else { } else {
// PROXY Protocol // PROXY Protocol
@@ -159,7 +171,7 @@ func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
return return
} }
} }
err = errors.New("no origin can be used") err = errors.New("no origin server can be used")
return return
} }

View File

@@ -8,6 +8,7 @@ type OriginState struct {
CountFails int64 CountFails int64
UpdatedAt int64 UpdatedAt int64
Config *serverconfigs.OriginConfig Config *serverconfigs.OriginConfig
Addr string
TLSHost string TLSHost string
ReverseProxy *serverconfigs.ReverseProxyConfig ReverseProxy *serverconfigs.ReverseProxyConfig
} }

View File

@@ -99,7 +99,7 @@ func (this *OriginStateManager) Loop() error {
for _, state := range currentStates { for _, state := range currentStates {
go func(state *OriginState) { go func(state *OriginState) {
defer wg.Done() defer wg.Done()
conn, err := OriginConnect(state.Config, "", state.TLSHost) conn, _, err := OriginConnect(state.Config, 0, "", state.TLSHost)
if err == nil { if err == nil {
_ = conn.Close() _ = conn.Close()

View File

@@ -3,47 +3,54 @@ package nodes
import ( import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/types"
"net" "net"
"strconv" "strconv"
) )
// OriginConnect 连接源站 // OriginConnect 连接源站
func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string, tlsHost string) (net.Conn, error) { func OriginConnect(origin *serverconfigs.OriginConfig, serverPort int, remoteAddr string, tlsHost string) (originConn net.Conn, originAddr string, err error) {
if origin.Addr == nil { if origin.Addr == nil {
return nil, errors.New("origin server address should not be empty") return nil, "", errors.New("origin server address should not be empty")
} }
// 支持TOA的连接 // 支持TOA的连接
// 这个条件很重要如果没有传递remoteAddr表示不使用TOA // 这个条件很重要如果没有传递remoteAddr表示不使用TOA
if len(remoteAddr) > 0 { if len(remoteAddr) > 0 {
toaConfig := sharedTOAManager.Config() var toaConfig = sharedTOAManager.Config()
if toaConfig != nil && toaConfig.IsOn { if toaConfig != nil && toaConfig.IsOn {
retries := 3 var retries = 3
for i := 1; i <= retries; i++ { for i := 1; i <= retries; i++ {
port := int(toaConfig.RandLocalPort()) var port = int(toaConfig.RandLocalPort())
err := sharedTOAManager.SendMsg("add:" + strconv.Itoa(port) + ":" + remoteAddr) err = sharedTOAManager.SendMsg("add:" + strconv.Itoa(port) + ":" + remoteAddr)
if err != nil { if err != nil {
remotelogs.Error("TOA", "add failed: "+err.Error()) remotelogs.Error("TOA", "add failed: "+err.Error())
} else { } else {
dialer := net.Dialer{ var dialer = net.Dialer{
Timeout: origin.ConnTimeoutDuration(), Timeout: origin.ConnTimeoutDuration(),
LocalAddr: &net.TCPAddr{ LocalAddr: &net.TCPAddr{
Port: port, Port: port,
}, },
} }
originAddr = origin.Addr.PickAddress()
// 端口跟随
if origin.FollowPort && serverPort > 0 {
originAddr = configutils.QuoteIP(origin.Addr.Host) + ":" + types.String(serverPort)
}
var conn net.Conn var conn net.Conn
switch origin.Addr.Protocol { switch origin.Addr.Protocol {
case "", serverconfigs.ProtocolTCP, serverconfigs.ProtocolHTTP: case "", serverconfigs.ProtocolTCP, serverconfigs.ProtocolHTTP:
// TODO 支持TCP4/TCP6 // TODO 支持TCP4/TCP6
// TODO 支持指定特定网卡 // TODO 支持指定特定网卡
// TODO Addr支持端口范围如果有多个端口时随机一个端口使用 conn, err = dialer.Dial("tcp", originAddr)
conn, err = dialer.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange)
case serverconfigs.ProtocolTLS, serverconfigs.ProtocolHTTPS: case serverconfigs.ProtocolTLS, serverconfigs.ProtocolHTTPS:
// TODO 支持TCP4/TCP6 // TODO 支持TCP4/TCP6
// TODO 支持指定特定网卡 // TODO 支持指定特定网卡
// TODO Addr支持端口范围如果有多个端口时随机一个端口使用
var tlsConfig = &tls.Config{ var tlsConfig = &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
@@ -62,28 +69,34 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string, tlsHos
tlsConfig.ServerName = tlsHost tlsConfig.ServerName = tlsHost
} }
conn, err = tls.DialWithDialer(&dialer, "tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig) conn, err = tls.DialWithDialer(&dialer, "tcp", originAddr, tlsConfig)
} }
// TODO 需要在合适的时机删除TOA记录 // TODO 需要在合适的时机删除TOA记录
if err == nil || i == retries { if err == nil || i == retries {
return conn, err return conn, originAddr, err
} }
} }
} }
} }
} }
originAddr = origin.Addr.PickAddress()
// 端口跟随
if origin.FollowPort && serverPort > 0 {
originAddr = configutils.QuoteIP(origin.Addr.Host) + ":" + types.String(serverPort)
}
switch origin.Addr.Protocol { switch origin.Addr.Protocol {
case "", serverconfigs.ProtocolTCP, serverconfigs.ProtocolHTTP: case "", serverconfigs.ProtocolTCP, serverconfigs.ProtocolHTTP:
// TODO 支持TCP4/TCP6 // TODO 支持TCP4/TCP6
// TODO 支持指定特定网卡 // TODO 支持指定特定网卡
// TODO Addr支持端口范围如果有多个端口时随机一个端口使用 originConn, err = net.DialTimeout("tcp", originAddr, origin.ConnTimeoutDuration())
return net.DialTimeout("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, origin.ConnTimeoutDuration()) return originConn, originAddr, err
case serverconfigs.ProtocolTLS, serverconfigs.ProtocolHTTPS: case serverconfigs.ProtocolTLS, serverconfigs.ProtocolHTTPS:
// TODO 支持TCP4/TCP6 // TODO 支持TCP4/TCP6
// TODO 支持指定特定网卡 // TODO 支持指定特定网卡
// TODO Addr支持端口范围如果有多个端口时随机一个端口使用
var tlsConfig = &tls.Config{ var tlsConfig = &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
@@ -102,16 +115,18 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string, tlsHos
tlsConfig.ServerName = tlsHost tlsConfig.ServerName = tlsHost
} }
return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig) originConn, err = tls.Dial("tcp", originAddr, tlsConfig)
return originConn, originAddr, err
case serverconfigs.ProtocolUDP: case serverconfigs.ProtocolUDP:
addr, err := net.ResolveUDPAddr("udp", origin.Addr.Host+":"+origin.Addr.PortRange) addr, err := net.ResolveUDPAddr("udp", originAddr)
if err != nil { if err != nil {
return nil, err return nil, originAddr, err
} }
return net.DialUDP("udp", nil, addr) originConn, err = net.DialUDP("udp", nil, addr)
return originConn, originAddr, err
} }
// TODO 支持从Unix、Pipe、HTTP、HTTPS中读取数据 // TODO 支持从Unix、Pipe、HTTP、HTTPS中读取数据
return nil, errors.New("invalid origin scheme '" + origin.Addr.Protocol.String() + "'") return nil, originAddr, errors.New("invalid origin scheme '" + origin.Addr.Protocol.String() + "'")
} }