支持UDP代理

This commit is contained in:
刘祥超
2021-06-07 15:45:47 +08:00
parent a49b724745
commit f461760158
3 changed files with 236 additions and 16 deletions

View File

@@ -43,8 +43,19 @@ func (this *Listener) Listen() error {
return nil return nil
} }
protocol := this.group.Protocol() protocol := this.group.Protocol()
if protocol.IsUDPFamily() {
return this.listenUDP()
}
return this.listenTCP()
}
netListener, err := this.createListener() func (this *Listener) listenTCP() error {
if this.group == nil {
return nil
}
protocol := this.group.Protocol()
netListener, err := this.createTCPListener()
if err != nil { if err != nil {
return err return err
} }
@@ -80,11 +91,6 @@ func (this *Listener) Listen() error {
BaseListener: BaseListener{Group: this.group}, BaseListener: BaseListener{Group: this.group},
Listener: netListener, Listener: netListener,
} }
case serverconfigs.ProtocolUDP:
this.listener = &UDPListener{
BaseListener: BaseListener{Group: this.group},
Listener: netListener,
}
default: default:
return errors.New("unknown protocol '" + protocol.String() + "'") return errors.New("unknown protocol '" + protocol.String() + "'")
} }
@@ -108,6 +114,31 @@ func (this *Listener) Listen() error {
return nil return nil
} }
func (this *Listener) listenUDP() error {
listener, err := this.createUDPListener()
if err != nil {
return err
}
events.On(events.EventQuit, func() {
remotelogs.Println("LISTENER", "quit "+this.group.FullAddr())
_ = listener.Close()
})
this.listener = &UDPListener{
BaseListener: BaseListener{Group: this.group},
Listener: listener,
}
go func() {
err := this.listener.Serve()
if err != nil {
remotelogs.Error("LISTENER", err.Error())
}
}()
return nil
}
func (this *Listener) Close() error { func (this *Listener) Close() error {
if this.listener == nil { if this.listener == nil {
return nil return nil
@@ -115,8 +146,8 @@ func (this *Listener) Close() error {
return this.listener.Close() return this.listener.Close()
} }
// 创建监听器 // 创建TCP监听器
func (this *Listener) createListener() (net.Listener, error) { func (this *Listener) createTCPListener() (net.Listener, error) {
listenConfig := net.ListenConfig{ listenConfig := net.ListenConfig{
Control: nil, Control: nil,
KeepAlive: 0, KeepAlive: 0,
@@ -131,3 +162,13 @@ func (this *Listener) createListener() (net.Listener, error) {
return listenConfig.Listen(context.Background(), "tcp", this.group.Addr()) return listenConfig.Listen(context.Background(), "tcp", this.group.Addr())
} }
// 创建UDP监听器
func (this *Listener) createUDPListener() (*net.UDPConn, error) {
// TODO 将来支持udp4/udp6
addr, err := net.ResolveUDPAddr("udp", this.group.Addr())
if err != nil {
return nil, err
}
return net.ListenUDP("udp", addr)
}

View File

@@ -1,28 +1,201 @@
package nodes package nodes
import ( import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"net" "net"
"sync"
"time"
) )
type UDPListener struct { type UDPListener struct {
BaseListener BaseListener
Listener net.Listener Listener *net.UDPConn
connMap map[string]*UDPConn
connLocker sync.Mutex
connTicker *utils.Ticker
} }
func (this *UDPListener) Serve() error { func (this *UDPListener) Serve() error {
// TODO firstServer := this.Group.FirstServer()
// TODO 注意管理 CountActiveConnections if firstServer == nil {
return nil return errors.New("no server available")
}
if firstServer.ReverseProxy == nil {
return errors.New("no ReverseProxy configured for the server")
}
this.connMap = map[string]*UDPConn{}
this.connTicker = utils.NewTicker(1 * time.Minute)
go func() {
for this.connTicker.Next() {
this.gcConns()
}
}()
var buffer = make([]byte, 4*1024)
for {
n, addr, _ := this.Listener.ReadFrom(buffer)
if n > 0 {
this.connLocker.Lock()
conn, ok := this.connMap[addr.String()]
this.connLocker.Unlock()
if ok && !conn.isOk {
_ = conn.Close()
ok = false
}
if !ok {
originConn, err := this.connectOrigin(firstServer.ReverseProxy, "")
if err != nil {
remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error())
continue
}
if originConn == nil {
remotelogs.Error("UDP_LISTENER", "unable to find a origin server")
continue
}
conn = NewUDPConn(firstServer.Id, addr, this.Listener, originConn.(*net.UDPConn))
this.connLocker.Lock()
this.connMap[addr.String()] = conn
this.connLocker.Unlock()
}
_, _ = conn.Write(buffer[:n])
}
}
} }
func (this *UDPListener) Close() error { func (this *UDPListener) Close() error {
// TODO if this.connTicker != nil {
return nil this.connTicker.Stop()
}
// 关闭所有连接
this.connLocker.Lock()
for _, conn := range this.connMap {
_ = conn.Close()
}
this.connLocker.Unlock()
return this.Listener.Close()
} }
func (this *UDPListener) Reload(group *serverconfigs.ServerGroup) { func (this *UDPListener) Reload(group *serverconfigs.ServerGroup) {
this.Group = group this.Group = group
this.Reset() this.Reset()
} }
func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
if reverseProxy == nil {
return nil, errors.New("no reverse proxy config")
}
retries := 3
for i := 0; i < retries; i++ {
origin := reverseProxy.NextOrigin(nil)
if origin == nil {
continue
}
conn, err = OriginConnect(origin, remoteAddr)
if err != nil {
remotelogs.Error("UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
continue
} else {
return
}
}
err = errors.New("no origin can be used")
return
}
// 回收连接
func (this *UDPListener) gcConns() {
this.connLocker.Lock()
closingConns := []*UDPConn{}
for addr, conn := range this.connMap {
if !conn.IsOk() {
closingConns = append(closingConns, conn)
delete(this.connMap, addr)
}
}
this.connLocker.Unlock()
for _, conn := range closingConns {
_ = conn.Close()
}
}
// UDPConn 自定义的UDP连接管理
type UDPConn struct {
addr net.Addr
proxyConn net.Conn
serverConn net.Conn
activatedAt int64
isOk bool
isClosed bool
}
func NewUDPConn(serverId int64, addr net.Addr, proxyConn *net.UDPConn, serverConn *net.UDPConn) *UDPConn {
conn := &UDPConn{
addr: addr,
proxyConn: proxyConn,
serverConn: serverConn,
activatedAt: time.Now().Unix(),
isOk: true,
}
go func() {
buffer := bytePool32k.Get()
defer func() {
bytePool32k.Put(buffer)
}()
for {
n, err := serverConn.Read(buffer)
if n > 0 {
conn.activatedAt = time.Now().Unix()
_, writingErr := proxyConn.WriteTo(buffer[:n], addr)
if writingErr != nil {
conn.isOk = false
break
}
// 记录流量
stats.SharedTrafficStatManager.Add(serverId, int64(n))
}
if err != nil {
conn.isOk = false
break
}
}
}()
return conn
}
func (this *UDPConn) Write(b []byte) (n int, err error) {
this.activatedAt = time.Now().Unix()
n, err = this.serverConn.Write(b)
if err != nil {
this.isOk = false
}
return
}
func (this *UDPConn) Close() error {
this.isOk = false
if this.isClosed {
return nil
}
this.isClosed = true
return this.serverConn.Close()
}
func (this *UDPConn) IsOk() bool {
if !this.isOk {
return false
}
return time.Now().Unix()-this.activatedAt < 30 // 如果超过 N 秒没有活动我们认为是超时
}

View File

@@ -9,7 +9,7 @@ import (
"strconv" "strconv"
) )
// 连接源站 // OriginConnect 连接源站
func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.Conn, error) { func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.Conn, error) {
if origin.Addr == nil { if origin.Addr == nil {
return nil, errors.New("origin server address should not be empty") return nil, errors.New("origin server address should not be empty")
@@ -70,9 +70,15 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.C
return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, &tls.Config{ return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
}) })
case serverconfigs.ProtocolUDP:
addr, err := net.ResolveUDPAddr("udp", origin.Addr.Host+":"+origin.Addr.PortRange)
if err != nil {
return nil, err
}
return net.DialUDP("udp", nil, addr)
} }
// TODO 支持从Unix、Pipe、HTTP、HTTPS中读取数据 // TODO 支持从Unix、Pipe、HTTP、HTTPS中读取数据
return nil, errors.New("invalid scheme '" + origin.Addr.Protocol.String() + "'") return nil, errors.New("invalid origin scheme '" + origin.Addr.Protocol.String() + "'")
} }