From 0877fd5632002a09c6a06ea7b9e6882b7e85fd06 Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Tue, 12 Oct 2021 20:20:06 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81PROXY=20Protocol/=E4=BF=AE?= =?UTF-8?q?=E5=A4=8DUDP=E6=BA=90=E7=AB=99=E6=97=A0=E6=B3=95=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 1 + go.sum | 2 + internal/nodes/http_client_pool.go | 41 ++++++++++++++++-- internal/nodes/http_client_pool_test.go | 8 ++-- internal/nodes/http_request_reverse_proxy.go | 2 +- internal/nodes/listener_tcp.go | 28 ++++++++++++ internal/nodes/listener_test.go | 2 +- internal/nodes/listener_udp.go | 45 +++++++++++++++++--- 8 files changed, 115 insertions(+), 14 deletions(-) diff --git a/go.mod b/go.mod index afecc99..1c6eba0 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/lionsoul2014/ip2region v2.2.0-release+incompatible github.com/mattn/go-sqlite3 v2.0.3+incompatible github.com/mssola/user_agent v0.5.2 + github.com/pires/go-proxyproto v0.6.1 github.com/shirou/gopsutil v3.21.5+incompatible github.com/tklauser/go-sysconf v0.3.6 // indirect golang.org/x/image v0.0.0-20190802002840-cff245a6509b diff --git a/go.sum b/go.sum index 2169f29..017e22d 100644 --- a/go.sum +++ b/go.sum @@ -110,6 +110,8 @@ github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1Cpa github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/opentracing/opentracing-go v1.1.1-0.20190913142402-a7454ce5950e/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/pires/go-proxyproto v0.6.1 h1:EBupykFmo22SDjv4fQVQd2J9NOoLPmyZA/15ldOGkPw= +github.com/pires/go-proxyproto v0.6.1/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= diff --git a/internal/nodes/http_client_pool.go b/internal/nodes/http_client_pool.go index 3222262..5952724 100644 --- a/internal/nodes/http_client_pool.go +++ b/internal/nodes/http_client_pool.go @@ -6,10 +6,12 @@ import ( "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" + "github.com/pires/go-proxyproto" "net" "net/http" "runtime" "strconv" + "strings" "sync" "time" ) @@ -37,7 +39,7 @@ func NewHTTPClientPool() *HTTPClientPool { } // Client 根据地址获取客户端 -func (this *HTTPClientPool) Client(req *http.Request, origin *serverconfigs.OriginConfig, originAddr string) (rawClient *http.Client, err error) { +func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.OriginConfig, originAddr string, proxyProtocol *serverconfigs.ProxyProtocolConfig) (rawClient *http.Client, err error) { if origin.Addr == nil { return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")") } @@ -105,7 +107,7 @@ func (this *HTTPClientPool) Client(req *http.Request, origin *serverconfigs.Orig for i := 1; i <= retries; i++ { port := int(toaConfig.RandLocalPort()) // TODO 思考是否支持X-Real-IP/X-Forwarded-IP - err := sharedTOAManager.SendMsg("add:" + strconv.Itoa(port) + ":" + req.RemoteAddr) + err := sharedTOAManager.SendMsg("add:" + strconv.Itoa(port) + ":" + req.requestRemoteAddr(true)) if err != nil { remotelogs.Error("TOA", "add failed: "+err.Error()) } else { @@ -126,10 +128,43 @@ func (this *HTTPClientPool) Client(req *http.Request, origin *serverconfigs.Orig } // 普通的连接 - return (&net.Dialer{ + conn, err := (&net.Dialer{ Timeout: connectionTimeout, KeepAlive: 1 * time.Minute, }).DialContext(ctx, network, originAddr) + if err != nil { + return nil, err + } + + if proxyProtocol != nil && proxyProtocol.IsOn && (proxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || proxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) { + var remoteAddr = req.requestRemoteAddr(true) + var transportProtocol = proxyproto.TCPv4 + if strings.Contains(remoteAddr, ":") { + transportProtocol = proxyproto.TCPv6 + } + var destAddr = conn.RemoteAddr() + var reqConn = req.RawReq.Context().Value(HTTPConnContextKey) + if reqConn != nil { + destAddr = reqConn.(net.Conn).LocalAddr() + } + header := proxyproto.Header{ + Version: byte(proxyProtocol.Version), + Command: proxyproto.PROXY, + TransportProtocol: transportProtocol, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP(remoteAddr), + Port: req.requestRemotePort(), + }, + DestinationAddr: destAddr, + } + _, err = header.WriteTo(conn) + if err != nil { + _ = conn.Close() + return nil, err + } + } + + return conn, nil }, MaxIdleConns: 0, MaxIdleConnsPerHost: idleConns, diff --git a/internal/nodes/http_client_pool_test.go b/internal/nodes/http_client_pool_test.go index cafaaef..556dd0b 100644 --- a/internal/nodes/http_client_pool_test.go +++ b/internal/nodes/http_client_pool_test.go @@ -21,14 +21,14 @@ func TestHTTPClientPool_Client(t *testing.T) { t.Fatal(err) } { - client, err := pool.Client(nil, origin, origin.Addr.PickAddress()) + client, err := pool.Client(nil, origin, origin.Addr.PickAddress(), nil) if err != nil { t.Fatal(err) } t.Log("client:", client) } for i := 0; i < 10; i++ { - client, err := pool.Client(nil, origin, origin.Addr.PickAddress()) + client, err := pool.Client(nil, origin, origin.Addr.PickAddress(), nil) if err != nil { t.Fatal(err) } @@ -53,7 +53,7 @@ func TestHTTPClientPool_cleanClients(t *testing.T) { for i := 0; i < 10; i++ { t.Log("get", i) - _, _ = pool.Client(nil, origin, origin.Addr.PickAddress()) + _, _ = pool.Client(nil, origin, origin.Addr.PickAddress(), nil) time.Sleep(1 * time.Second) } } @@ -73,6 +73,6 @@ func BenchmarkHTTPClientPool_Client(b *testing.B) { pool := NewHTTPClientPool() for i := 0; i < b.N; i++ { - _, _ = pool.Client(nil, origin, origin.Addr.PickAddress()) + _, _ = pool.Client(nil, origin, origin.Addr.PickAddress(), nil) } } diff --git a/internal/nodes/http_request_reverse_proxy.go b/internal/nodes/http_request_reverse_proxy.go index d41a642..5c85cb7 100644 --- a/internal/nodes/http_request_reverse_proxy.go +++ b/internal/nodes/http_request_reverse_proxy.go @@ -145,7 +145,7 @@ func (this *HTTPRequest) doReverseProxy() { } // 获取请求客户端 - client, err := SharedHTTPClientPool.Client(this.RawReq, origin, originAddr) + client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol) if err != nil { remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", err.Error()) this.write50x(err, http.StatusBadGateway) diff --git a/internal/nodes/listener_tcp.go b/internal/nodes/listener_tcp.go index ba2c46f..d7a3ed0 100644 --- a/internal/nodes/listener_tcp.go +++ b/internal/nodes/listener_tcp.go @@ -6,7 +6,9 @@ import ( "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/stats" + "github.com/pires/go-proxyproto" "net" + "strings" "sync/atomic" ) @@ -83,6 +85,31 @@ func (this *TCPListener) handleConn(conn net.Conn) error { _ = originConn.Close() } + // PROXY Protocol + if firstServer.ReverseProxy != nil && + firstServer.ReverseProxy.ProxyProtocol != nil && + firstServer.ReverseProxy.ProxyProtocol.IsOn && + (firstServer.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || firstServer.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) { + var remoteAddr = conn.RemoteAddr() + var transportProtocol = proxyproto.TCPv4 + if strings.Contains(remoteAddr.String(), "[") { + transportProtocol = proxyproto.TCPv6 + } + header := proxyproto.Header{ + Version: byte(firstServer.ReverseProxy.ProxyProtocol.Version), + Command: proxyproto.PROXY, + TransportProtocol: transportProtocol, + SourceAddr: remoteAddr, + DestinationAddr: conn.LocalAddr(), + } + _, err = header.WriteTo(originConn) + if err != nil { + closer() + return err + } + } + + // 从源站读取 go func() { originBuffer := bytePool32k.Get() defer func() { @@ -107,6 +134,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error { } }() + // 从客户端读取 clientBuffer := bytePool32k.Get() defer func() { bytePool32k.Put(clientBuffer) diff --git a/internal/nodes/listener_test.go b/internal/nodes/listener_test.go index 54d8e9d..545d347 100644 --- a/internal/nodes/listener_test.go +++ b/internal/nodes/listener_test.go @@ -8,7 +8,7 @@ import ( func TestListener_Listen(t *testing.T) { listener := NewListener() - group := serverconfigs.NewServerGroup("http://:1234") + group := serverconfigs.NewServerAddressGroup("https://:1234") listener.Reload(group) err := listener.Listen() diff --git a/internal/nodes/listener_udp.go b/internal/nodes/listener_udp.go index 2644e7b..963e774 100644 --- a/internal/nodes/listener_udp.go +++ b/internal/nodes/listener_udp.go @@ -6,7 +6,9 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/stats" "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/pires/go-proxyproto" "net" + "strings" "sync" "time" ) @@ -19,6 +21,8 @@ type UDPListener struct { connMap map[string]*UDPConn connLocker sync.Mutex connTicker *utils.Ticker + + reverseProxy *serverconfigs.ReverseProxyConfig } func (this *UDPListener) Serve() error { @@ -26,8 +30,9 @@ func (this *UDPListener) Serve() error { if firstServer == nil { return errors.New("no server available") } - if firstServer.ReverseProxy == nil { - return errors.New("no ReverseProxy configured for the server") + this.reverseProxy = firstServer.ReverseProxy + if this.reverseProxy == nil { + return errors.New("no ReverseProxy configured for the server '" + firstServer.Name + "'") } this.connMap = map[string]*UDPConn{} @@ -50,7 +55,7 @@ func (this *UDPListener) Serve() error { ok = false } if !ok { - originConn, err := this.connectOrigin(firstServer.ReverseProxy, "") + originConn, err := this.connectOrigin(this.reverseProxy, addr) if err != nil { remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error()) continue @@ -87,9 +92,16 @@ func (this *UDPListener) Close() error { func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) { this.Group = group this.Reset() + + // 重置配置 + firstServer := this.Group.FirstServer() + if firstServer == nil { + return + } + this.reverseProxy = firstServer.ReverseProxy } -func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) { +func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) { if reverseProxy == nil { return nil, errors.New("no reverse proxy config") } @@ -100,11 +112,34 @@ func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC if origin == nil { continue } - conn, err = OriginConnect(origin, remoteAddr) + conn, err = OriginConnect(origin, remoteAddr.String()) if err != nil { remotelogs.Error("UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error()) continue } else { + // PROXY Protocol + if reverseProxy != nil && + reverseProxy.ProxyProtocol != nil && + reverseProxy.ProxyProtocol.IsOn && + (reverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || reverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) { + var transportProtocol = proxyproto.UDPv4 + if strings.Contains(remoteAddr.String(), "[") { + transportProtocol = proxyproto.UDPv6 + } + header := proxyproto.Header{ + Version: byte(reverseProxy.ProxyProtocol.Version), + Command: proxyproto.PROXY, + TransportProtocol: transportProtocol, + SourceAddr: remoteAddr, + DestinationAddr: this.Listener.LocalAddr(), + } + _, err = header.WriteTo(conn) + if err != nil { + _ = conn.Close() + return nil, err + } + } + return } }