mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-18 11:40:24 +08:00
支持UDP代理
This commit is contained in:
@@ -43,8 +43,19 @@ func (this *Listener) Listen() error {
|
||||
return nil
|
||||
}
|
||||
protocol := this.group.Protocol()
|
||||
if protocol.IsUDPFamily() {
|
||||
return this.listenUDP()
|
||||
}
|
||||
return this.listenTCP()
|
||||
}
|
||||
|
||||
netListener, err := this.createListener()
|
||||
func (this *Listener) listenTCP() error {
|
||||
if this.group == nil {
|
||||
return nil
|
||||
}
|
||||
protocol := this.group.Protocol()
|
||||
|
||||
netListener, err := this.createTCPListener()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -80,11 +91,6 @@ func (this *Listener) Listen() error {
|
||||
BaseListener: BaseListener{Group: this.group},
|
||||
Listener: netListener,
|
||||
}
|
||||
case serverconfigs.ProtocolUDP:
|
||||
this.listener = &UDPListener{
|
||||
BaseListener: BaseListener{Group: this.group},
|
||||
Listener: netListener,
|
||||
}
|
||||
default:
|
||||
return errors.New("unknown protocol '" + protocol.String() + "'")
|
||||
}
|
||||
@@ -108,6 +114,31 @@ func (this *Listener) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Listener) listenUDP() error {
|
||||
listener, err := this.createUDPListener()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
events.On(events.EventQuit, func() {
|
||||
remotelogs.Println("LISTENER", "quit "+this.group.FullAddr())
|
||||
_ = listener.Close()
|
||||
})
|
||||
|
||||
this.listener = &UDPListener{
|
||||
BaseListener: BaseListener{Group: this.group},
|
||||
Listener: listener,
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := this.listener.Serve()
|
||||
if err != nil {
|
||||
remotelogs.Error("LISTENER", err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Listener) Close() error {
|
||||
if this.listener == nil {
|
||||
return nil
|
||||
@@ -115,8 +146,8 @@ func (this *Listener) Close() error {
|
||||
return this.listener.Close()
|
||||
}
|
||||
|
||||
// 创建监听器
|
||||
func (this *Listener) createListener() (net.Listener, error) {
|
||||
// 创建TCP监听器
|
||||
func (this *Listener) createTCPListener() (net.Listener, error) {
|
||||
listenConfig := net.ListenConfig{
|
||||
Control: nil,
|
||||
KeepAlive: 0,
|
||||
@@ -131,3 +162,13 @@ func (this *Listener) createListener() (net.Listener, error) {
|
||||
|
||||
return listenConfig.Listen(context.Background(), "tcp", this.group.Addr())
|
||||
}
|
||||
|
||||
// 创建UDP监听器
|
||||
func (this *Listener) createUDPListener() (*net.UDPConn, error) {
|
||||
// TODO 将来支持udp4/udp6
|
||||
addr, err := net.ResolveUDPAddr("udp", this.group.Addr())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.ListenUDP("udp", addr)
|
||||
}
|
||||
|
||||
@@ -1,28 +1,201 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UDPListener struct {
|
||||
BaseListener
|
||||
|
||||
Listener net.Listener
|
||||
Listener *net.UDPConn
|
||||
|
||||
connMap map[string]*UDPConn
|
||||
connLocker sync.Mutex
|
||||
connTicker *utils.Ticker
|
||||
}
|
||||
|
||||
func (this *UDPListener) Serve() error {
|
||||
// TODO
|
||||
// TODO 注意管理 CountActiveConnections
|
||||
return nil
|
||||
firstServer := this.Group.FirstServer()
|
||||
if firstServer == nil {
|
||||
return errors.New("no server available")
|
||||
}
|
||||
if firstServer.ReverseProxy == nil {
|
||||
return errors.New("no ReverseProxy configured for the server")
|
||||
}
|
||||
|
||||
this.connMap = map[string]*UDPConn{}
|
||||
this.connTicker = utils.NewTicker(1 * time.Minute)
|
||||
go func() {
|
||||
for this.connTicker.Next() {
|
||||
this.gcConns()
|
||||
}
|
||||
}()
|
||||
|
||||
var buffer = make([]byte, 4*1024)
|
||||
for {
|
||||
n, addr, _ := this.Listener.ReadFrom(buffer)
|
||||
if n > 0 {
|
||||
this.connLocker.Lock()
|
||||
conn, ok := this.connMap[addr.String()]
|
||||
this.connLocker.Unlock()
|
||||
if ok && !conn.isOk {
|
||||
_ = conn.Close()
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
originConn, err := this.connectOrigin(firstServer.ReverseProxy, "")
|
||||
if err != nil {
|
||||
remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error())
|
||||
continue
|
||||
}
|
||||
if originConn == nil {
|
||||
remotelogs.Error("UDP_LISTENER", "unable to find a origin server")
|
||||
continue
|
||||
}
|
||||
conn = NewUDPConn(firstServer.Id, addr, this.Listener, originConn.(*net.UDPConn))
|
||||
this.connLocker.Lock()
|
||||
this.connMap[addr.String()] = conn
|
||||
this.connLocker.Unlock()
|
||||
}
|
||||
_, _ = conn.Write(buffer[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *UDPListener) Close() error {
|
||||
// TODO
|
||||
return nil
|
||||
if this.connTicker != nil {
|
||||
this.connTicker.Stop()
|
||||
}
|
||||
|
||||
// 关闭所有连接
|
||||
this.connLocker.Lock()
|
||||
for _, conn := range this.connMap {
|
||||
_ = conn.Close()
|
||||
}
|
||||
this.connLocker.Unlock()
|
||||
|
||||
return this.Listener.Close()
|
||||
}
|
||||
|
||||
func (this *UDPListener) Reload(group *serverconfigs.ServerGroup) {
|
||||
this.Group = group
|
||||
this.Reset()
|
||||
}
|
||||
|
||||
func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
|
||||
if reverseProxy == nil {
|
||||
return nil, errors.New("no reverse proxy config")
|
||||
}
|
||||
|
||||
retries := 3
|
||||
for i := 0; i < retries; i++ {
|
||||
origin := reverseProxy.NextOrigin(nil)
|
||||
if origin == nil {
|
||||
continue
|
||||
}
|
||||
conn, err = OriginConnect(origin, remoteAddr)
|
||||
if err != nil {
|
||||
remotelogs.Error("UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
|
||||
continue
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
err = errors.New("no origin can be used")
|
||||
return
|
||||
}
|
||||
|
||||
// 回收连接
|
||||
func (this *UDPListener) gcConns() {
|
||||
this.connLocker.Lock()
|
||||
closingConns := []*UDPConn{}
|
||||
for addr, conn := range this.connMap {
|
||||
if !conn.IsOk() {
|
||||
closingConns = append(closingConns, conn)
|
||||
delete(this.connMap, addr)
|
||||
}
|
||||
}
|
||||
this.connLocker.Unlock()
|
||||
|
||||
for _, conn := range closingConns {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// UDPConn 自定义的UDP连接管理
|
||||
type UDPConn struct {
|
||||
addr net.Addr
|
||||
proxyConn net.Conn
|
||||
serverConn net.Conn
|
||||
activatedAt int64
|
||||
isOk bool
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
func NewUDPConn(serverId int64, addr net.Addr, proxyConn *net.UDPConn, serverConn *net.UDPConn) *UDPConn {
|
||||
conn := &UDPConn{
|
||||
addr: addr,
|
||||
proxyConn: proxyConn,
|
||||
serverConn: serverConn,
|
||||
activatedAt: time.Now().Unix(),
|
||||
isOk: true,
|
||||
}
|
||||
go func() {
|
||||
buffer := bytePool32k.Get()
|
||||
defer func() {
|
||||
bytePool32k.Put(buffer)
|
||||
}()
|
||||
|
||||
for {
|
||||
n, err := serverConn.Read(buffer)
|
||||
if n > 0 {
|
||||
conn.activatedAt = time.Now().Unix()
|
||||
_, writingErr := proxyConn.WriteTo(buffer[:n], addr)
|
||||
if writingErr != nil {
|
||||
conn.isOk = false
|
||||
break
|
||||
}
|
||||
|
||||
// 记录流量
|
||||
stats.SharedTrafficStatManager.Add(serverId, int64(n))
|
||||
}
|
||||
if err != nil {
|
||||
conn.isOk = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
return conn
|
||||
}
|
||||
|
||||
func (this *UDPConn) Write(b []byte) (n int, err error) {
|
||||
this.activatedAt = time.Now().Unix()
|
||||
n, err = this.serverConn.Write(b)
|
||||
if err != nil {
|
||||
this.isOk = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (this *UDPConn) Close() error {
|
||||
this.isOk = false
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
this.isClosed = true
|
||||
return this.serverConn.Close()
|
||||
}
|
||||
|
||||
func (this *UDPConn) IsOk() bool {
|
||||
if !this.isOk {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix()-this.activatedAt < 30 // 如果超过 N 秒没有活动我们认为是超时
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// 连接源站
|
||||
// OriginConnect 连接源站
|
||||
func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.Conn, error) {
|
||||
if origin.Addr == nil {
|
||||
return nil, errors.New("origin server address should not be empty")
|
||||
@@ -70,9 +70,15 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.C
|
||||
return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
case serverconfigs.ProtocolUDP:
|
||||
addr, err := net.ResolveUDPAddr("udp", origin.Addr.Host+":"+origin.Addr.PortRange)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.DialUDP("udp", nil, addr)
|
||||
}
|
||||
|
||||
// TODO 支持从Unix、Pipe、HTTP、HTTPS中读取数据
|
||||
|
||||
return nil, errors.New("invalid scheme '" + origin.Addr.Protocol.String() + "'")
|
||||
return nil, errors.New("invalid origin scheme '" + origin.Addr.Protocol.String() + "'")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user