mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-18 03:35:10 +08:00
支持UDP代理
This commit is contained in:
@@ -43,8 +43,19 @@ func (this *Listener) Listen() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
protocol := this.group.Protocol()
|
protocol := this.group.Protocol()
|
||||||
|
if protocol.IsUDPFamily() {
|
||||||
|
return this.listenUDP()
|
||||||
|
}
|
||||||
|
return this.listenTCP()
|
||||||
|
}
|
||||||
|
|
||||||
netListener, err := this.createListener()
|
func (this *Listener) listenTCP() error {
|
||||||
|
if this.group == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
protocol := this.group.Protocol()
|
||||||
|
|
||||||
|
netListener, err := this.createTCPListener()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -80,11 +91,6 @@ func (this *Listener) Listen() error {
|
|||||||
BaseListener: BaseListener{Group: this.group},
|
BaseListener: BaseListener{Group: this.group},
|
||||||
Listener: netListener,
|
Listener: netListener,
|
||||||
}
|
}
|
||||||
case serverconfigs.ProtocolUDP:
|
|
||||||
this.listener = &UDPListener{
|
|
||||||
BaseListener: BaseListener{Group: this.group},
|
|
||||||
Listener: netListener,
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
return errors.New("unknown protocol '" + protocol.String() + "'")
|
return errors.New("unknown protocol '" + protocol.String() + "'")
|
||||||
}
|
}
|
||||||
@@ -108,6 +114,31 @@ func (this *Listener) Listen() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (this *Listener) listenUDP() error {
|
||||||
|
listener, err := this.createUDPListener()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
events.On(events.EventQuit, func() {
|
||||||
|
remotelogs.Println("LISTENER", "quit "+this.group.FullAddr())
|
||||||
|
_ = listener.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
this.listener = &UDPListener{
|
||||||
|
BaseListener: BaseListener{Group: this.group},
|
||||||
|
Listener: listener,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := this.listener.Serve()
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("LISTENER", err.Error())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (this *Listener) Close() error {
|
func (this *Listener) Close() error {
|
||||||
if this.listener == nil {
|
if this.listener == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -115,8 +146,8 @@ func (this *Listener) Close() error {
|
|||||||
return this.listener.Close()
|
return this.listener.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建监听器
|
// 创建TCP监听器
|
||||||
func (this *Listener) createListener() (net.Listener, error) {
|
func (this *Listener) createTCPListener() (net.Listener, error) {
|
||||||
listenConfig := net.ListenConfig{
|
listenConfig := net.ListenConfig{
|
||||||
Control: nil,
|
Control: nil,
|
||||||
KeepAlive: 0,
|
KeepAlive: 0,
|
||||||
@@ -131,3 +162,13 @@ func (this *Listener) createListener() (net.Listener, error) {
|
|||||||
|
|
||||||
return listenConfig.Listen(context.Background(), "tcp", this.group.Addr())
|
return listenConfig.Listen(context.Background(), "tcp", this.group.Addr())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 创建UDP监听器
|
||||||
|
func (this *Listener) createUDPListener() (*net.UDPConn, error) {
|
||||||
|
// TODO 将来支持udp4/udp6
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", this.group.Addr())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return net.ListenUDP("udp", addr)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,28 +1,201 @@
|
|||||||
package nodes
|
package nodes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UDPListener struct {
|
type UDPListener struct {
|
||||||
BaseListener
|
BaseListener
|
||||||
|
|
||||||
Listener net.Listener
|
Listener *net.UDPConn
|
||||||
|
|
||||||
|
connMap map[string]*UDPConn
|
||||||
|
connLocker sync.Mutex
|
||||||
|
connTicker *utils.Ticker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (this *UDPListener) Serve() error {
|
func (this *UDPListener) Serve() error {
|
||||||
// TODO
|
firstServer := this.Group.FirstServer()
|
||||||
// TODO 注意管理 CountActiveConnections
|
if firstServer == nil {
|
||||||
return nil
|
return errors.New("no server available")
|
||||||
|
}
|
||||||
|
if firstServer.ReverseProxy == nil {
|
||||||
|
return errors.New("no ReverseProxy configured for the server")
|
||||||
|
}
|
||||||
|
|
||||||
|
this.connMap = map[string]*UDPConn{}
|
||||||
|
this.connTicker = utils.NewTicker(1 * time.Minute)
|
||||||
|
go func() {
|
||||||
|
for this.connTicker.Next() {
|
||||||
|
this.gcConns()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var buffer = make([]byte, 4*1024)
|
||||||
|
for {
|
||||||
|
n, addr, _ := this.Listener.ReadFrom(buffer)
|
||||||
|
if n > 0 {
|
||||||
|
this.connLocker.Lock()
|
||||||
|
conn, ok := this.connMap[addr.String()]
|
||||||
|
this.connLocker.Unlock()
|
||||||
|
if ok && !conn.isOk {
|
||||||
|
_ = conn.Close()
|
||||||
|
ok = false
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
originConn, err := this.connectOrigin(firstServer.ReverseProxy, "")
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if originConn == nil {
|
||||||
|
remotelogs.Error("UDP_LISTENER", "unable to find a origin server")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
conn = NewUDPConn(firstServer.Id, addr, this.Listener, originConn.(*net.UDPConn))
|
||||||
|
this.connLocker.Lock()
|
||||||
|
this.connMap[addr.String()] = conn
|
||||||
|
this.connLocker.Unlock()
|
||||||
|
}
|
||||||
|
_, _ = conn.Write(buffer[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (this *UDPListener) Close() error {
|
func (this *UDPListener) Close() error {
|
||||||
// TODO
|
if this.connTicker != nil {
|
||||||
return nil
|
this.connTicker.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关闭所有连接
|
||||||
|
this.connLocker.Lock()
|
||||||
|
for _, conn := range this.connMap {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
this.connLocker.Unlock()
|
||||||
|
|
||||||
|
return this.Listener.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (this *UDPListener) Reload(group *serverconfigs.ServerGroup) {
|
func (this *UDPListener) Reload(group *serverconfigs.ServerGroup) {
|
||||||
this.Group = group
|
this.Group = group
|
||||||
this.Reset()
|
this.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
|
||||||
|
if reverseProxy == nil {
|
||||||
|
return nil, errors.New("no reverse proxy config")
|
||||||
|
}
|
||||||
|
|
||||||
|
retries := 3
|
||||||
|
for i := 0; i < retries; i++ {
|
||||||
|
origin := reverseProxy.NextOrigin(nil)
|
||||||
|
if origin == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
conn, err = OriginConnect(origin, remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = errors.New("no origin can be used")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 回收连接
|
||||||
|
func (this *UDPListener) gcConns() {
|
||||||
|
this.connLocker.Lock()
|
||||||
|
closingConns := []*UDPConn{}
|
||||||
|
for addr, conn := range this.connMap {
|
||||||
|
if !conn.IsOk() {
|
||||||
|
closingConns = append(closingConns, conn)
|
||||||
|
delete(this.connMap, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.connLocker.Unlock()
|
||||||
|
|
||||||
|
for _, conn := range closingConns {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UDPConn 自定义的UDP连接管理
|
||||||
|
type UDPConn struct {
|
||||||
|
addr net.Addr
|
||||||
|
proxyConn net.Conn
|
||||||
|
serverConn net.Conn
|
||||||
|
activatedAt int64
|
||||||
|
isOk bool
|
||||||
|
isClosed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUDPConn(serverId int64, addr net.Addr, proxyConn *net.UDPConn, serverConn *net.UDPConn) *UDPConn {
|
||||||
|
conn := &UDPConn{
|
||||||
|
addr: addr,
|
||||||
|
proxyConn: proxyConn,
|
||||||
|
serverConn: serverConn,
|
||||||
|
activatedAt: time.Now().Unix(),
|
||||||
|
isOk: true,
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
buffer := bytePool32k.Get()
|
||||||
|
defer func() {
|
||||||
|
bytePool32k.Put(buffer)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err := serverConn.Read(buffer)
|
||||||
|
if n > 0 {
|
||||||
|
conn.activatedAt = time.Now().Unix()
|
||||||
|
_, writingErr := proxyConn.WriteTo(buffer[:n], addr)
|
||||||
|
if writingErr != nil {
|
||||||
|
conn.isOk = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录流量
|
||||||
|
stats.SharedTrafficStatManager.Add(serverId, int64(n))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
conn.isOk = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *UDPConn) Write(b []byte) (n int, err error) {
|
||||||
|
this.activatedAt = time.Now().Unix()
|
||||||
|
n, err = this.serverConn.Write(b)
|
||||||
|
if err != nil {
|
||||||
|
this.isOk = false
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *UDPConn) Close() error {
|
||||||
|
this.isOk = false
|
||||||
|
if this.isClosed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
this.isClosed = true
|
||||||
|
return this.serverConn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *UDPConn) IsOk() bool {
|
||||||
|
if !this.isOk {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return time.Now().Unix()-this.activatedAt < 30 // 如果超过 N 秒没有活动我们认为是超时
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 连接源站
|
// OriginConnect 连接源站
|
||||||
func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.Conn, error) {
|
func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.Conn, error) {
|
||||||
if origin.Addr == nil {
|
if origin.Addr == nil {
|
||||||
return nil, errors.New("origin server address should not be empty")
|
return nil, errors.New("origin server address should not be empty")
|
||||||
@@ -70,9 +70,15 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.C
|
|||||||
return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, &tls.Config{
|
return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, &tls.Config{
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
})
|
})
|
||||||
|
case serverconfigs.ProtocolUDP:
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", origin.Addr.Host+":"+origin.Addr.PortRange)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return net.DialUDP("udp", nil, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO 支持从Unix、Pipe、HTTP、HTTPS中读取数据
|
// TODO 支持从Unix、Pipe、HTTP、HTTPS中读取数据
|
||||||
|
|
||||||
return nil, errors.New("invalid scheme '" + origin.Addr.Protocol.String() + "'")
|
return nil, errors.New("invalid origin scheme '" + origin.Addr.Protocol.String() + "'")
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user