diff --git a/internal/nodes/listener.go b/internal/nodes/listener.go index 22bc49c..de1021d 100644 --- a/internal/nodes/listener.go +++ b/internal/nodes/listener.go @@ -7,7 +7,10 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/events" "github.com/TeaOSLab/EdgeNode/internal/goman" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" "net" + "strings" "sync" ) @@ -118,18 +121,64 @@ func (this *Listener) listenTCP() error { } func (this *Listener) listenUDP() error { - listener, err := this.createUDPListener() + var addr = this.group.Addr() + + var ipv4PacketListener *ipv4.PacketConn + var ipv6PacketListener *ipv6.PacketConn + + host, _, err := net.SplitHostPort(addr) if err != nil { return err } + + if len(host) == 0 { + // ipv4 + ipv4Listener, err := this.createUDPIPv4Listener() + if err == nil { + ipv4PacketListener = ipv4.NewPacketConn(ipv4Listener) + } else { + remotelogs.Error("LISTENER", "create udp ipv4 listener '"+addr+"': "+err.Error()) + } + + // ipv6 + ipv6Listener, err := this.createUDPIPv6Listener() + if err == nil { + ipv6PacketListener = ipv6.NewPacketConn(ipv6Listener) + } else { + remotelogs.Error("LISTENER", "create udp ipv6 listener '"+addr+"': "+err.Error()) + } + } else if strings.Contains(host, ":") { // ipv6 + ipv6Listener, err := this.createUDPIPv6Listener() + if err == nil { + ipv6PacketListener = ipv6.NewPacketConn(ipv6Listener) + } else { + remotelogs.Error("LISTENER", "create udp ipv6 listener '"+addr+"': "+err.Error()) + } + } else { // ipv4 + ipv4Listener, err := this.createUDPIPv4Listener() + if err == nil { + ipv4PacketListener = ipv4.NewPacketConn(ipv4Listener) + } else { + remotelogs.Error("LISTENER", "create udp ipv4 listener '"+addr+"': "+err.Error()) + } + } + events.OnKey(events.EventQuit, this, func() { remotelogs.Println("LISTENER", "quit "+this.group.FullAddr()) - _ = listener.Close() + + if ipv4PacketListener != nil { + _ = ipv4PacketListener.Close() + } + + if ipv6PacketListener != nil { + _ = ipv6PacketListener.Close() + } }) this.listener = &UDPListener{ BaseListener: BaseListener{Group: this.group}, - Listener: listener, + IPv4Listener: ipv4PacketListener, + IPv6Listener: ipv6PacketListener, } goman.New(func() { @@ -168,12 +217,20 @@ func (this *Listener) createTCPListener() (net.Listener, error) { return listenConfig.Listen(context.Background(), "tcp", this.group.Addr()) } -// 创建UDP监听器 -func (this *Listener) createUDPListener() (*net.UDPConn, error) { - // TODO 将来支持udp4/udp6 +// 创建UDP IPv4监听器 +func (this *Listener) createUDPIPv4Listener() (*net.UDPConn, error) { addr, err := net.ResolveUDPAddr("udp", this.group.Addr()) if err != nil { return nil, err } - return net.ListenUDP("udp", addr) + return net.ListenUDP("udp4", addr) +} + +// 创建UDP监听器 +func (this *Listener) createUDPIPv6Listener() (*net.UDPConn, error) { + addr, err := net.ResolveUDPAddr("udp", this.group.Addr()) + if err != nil { + return nil, err + } + return net.ListenUDP("udp6", addr) } diff --git a/internal/nodes/listener_udp.go b/internal/nodes/listener_udp.go index aedf025..89a46a7 100644 --- a/internal/nodes/listener_udp.go +++ b/internal/nodes/listener_udp.go @@ -9,6 +9,8 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/iwind/TeaGo/types" "github.com/pires/go-proxyproto" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" "net" "strings" "sync" @@ -19,10 +21,57 @@ const ( UDPConnLifeSeconds = 30 ) +type UDPPacketListener interface { + ReadFrom(b []byte) (n int, cm any, src net.Addr, err error) + WriteTo(b []byte, cm any, dst net.Addr) (n int, err error) + LocalAddr() net.Addr +} + +type UDPIPv4Listener struct { + rawListener *ipv4.PacketConn +} + +func NewUDPIPv4Listener(rawListener *ipv4.PacketConn) *UDPIPv4Listener { + return &UDPIPv4Listener{rawListener: rawListener} +} + +func (this *UDPIPv4Listener) ReadFrom(b []byte) (n int, cm any, src net.Addr, err error) { + return this.rawListener.ReadFrom(b) +} + +func (this *UDPIPv4Listener) WriteTo(b []byte, cm any, dst net.Addr) (n int, err error) { + return this.rawListener.WriteTo(b, cm.(*ipv4.ControlMessage), dst) +} + +func (this *UDPIPv4Listener) LocalAddr() net.Addr { + return this.rawListener.LocalAddr() +} + +type UDPIPv6Listener struct { + rawListener *ipv6.PacketConn +} + +func NewUDPIPv6Listener(rawListener *ipv6.PacketConn) *UDPIPv6Listener { + return &UDPIPv6Listener{rawListener: rawListener} +} + +func (this *UDPIPv6Listener) ReadFrom(b []byte) (n int, cm any, src net.Addr, err error) { + return this.rawListener.ReadFrom(b) +} + +func (this *UDPIPv6Listener) WriteTo(b []byte, cm any, dst net.Addr) (n int, err error) { + return this.rawListener.WriteTo(b, cm.(*ipv6.ControlMessage), dst) +} + +func (this *UDPIPv6Listener) LocalAddr() net.Addr { + return this.rawListener.LocalAddr() +} + type UDPListener struct { BaseListener - Listener *net.UDPConn + IPv4Listener *ipv4.PacketConn + IPv6Listener *ipv6.PacketConn connMap map[string]*UDPConn connLocker sync.Mutex @@ -36,6 +85,60 @@ type UDPListener struct { } func (this *UDPListener) Serve() error { + if this.Group == nil { + return nil + } + var server = this.Group.FirstServer() + if server == nil { + return nil + } + var serverId = server.Id + + var wg = &sync.WaitGroup{} + wg.Add(2) // 2 = ipv4 + ipv6 + + go func() { + defer wg.Done() + + if this.IPv4Listener != nil { + err := this.IPv4Listener.SetControlMessage(ipv4.FlagDst, true) + if err != nil { + remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv4 listener: "+err.Error(), "", nil) + return + } + + err = this.servePacketListener(NewUDPIPv4Listener(this.IPv4Listener)) + if err != nil { + remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv4 listener: "+err.Error(), "", nil) + return + } + } + }() + + go func() { + defer wg.Done() + + if this.IPv6Listener != nil { + err := this.IPv6Listener.SetControlMessage(ipv6.FlagDst, true) + if err != nil { + remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv6 listener: "+err.Error(), "", nil) + return + } + + err = this.servePacketListener(NewUDPIPv6Listener(this.IPv6Listener)) + if err != nil { + remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv6 listener: "+err.Error(), "", nil) + return + } + } + }() + + wg.Wait() + + return nil +} + +func (this *UDPListener) servePacketListener(listener UDPPacketListener) error { // 获取分组端口 var groupAddr = this.Group.Addr() var portIndex = strings.LastIndex(groupAddr, ":") @@ -67,7 +170,7 @@ func (this *UDPListener) Serve() error { return nil } - n, addr, err := this.Listener.ReadFrom(buffer) + n, cm, clientAddr, err := listener.ReadFrom(buffer) if err != nil { if this.isClosed { return nil @@ -77,14 +180,14 @@ func (this *UDPListener) Serve() error { if n > 0 { this.connLocker.Lock() - conn, ok := this.connMap[addr.String()] + conn, ok := this.connMap[clientAddr.String()] this.connLocker.Unlock() if ok && !conn.IsOk() { _ = conn.Close() ok = false } if !ok { - originConn, err := this.connectOrigin(firstServer.Id, this.reverseProxy, addr) + originConn, err := this.connectOrigin(firstServer.Id, this.reverseProxy, listener.LocalAddr(), clientAddr) if err != nil { remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error()) continue @@ -93,9 +196,9 @@ func (this *UDPListener) Serve() error { remotelogs.Error("UDP_LISTENER", "unable to find a origin server") continue } - conn = NewUDPConn(firstServer, addr, this.Listener, originConn.(*net.UDPConn)) + conn = NewUDPConn(firstServer, clientAddr, listener, cm, originConn.(*net.UDPConn)) this.connLocker.Lock() - this.connMap[addr.String()] = conn + this.connMap[clientAddr.String()] = conn this.connLocker.Unlock() } _, _ = conn.Write(buffer[:n]) @@ -117,7 +220,26 @@ func (this *UDPListener) Close() error { } this.connLocker.Unlock() - return this.Listener.Close() + var errorStrings = []string{} + if this.IPv4Listener != nil { + err := this.IPv4Listener.Close() + if err != nil { + errorStrings = append(errorStrings, err.Error()) + } + } + + if this.IPv6Listener != nil { + err := this.IPv6Listener.Close() + if err != nil { + errorStrings = append(errorStrings, err.Error()) + } + } + + if len(errorStrings) > 0 { + return errors.New(errorStrings[0]) + } + + return nil } func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) { @@ -132,7 +254,7 @@ func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) { this.reverseProxy = firstServer.ReverseProxy } -func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) { +func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, localAddr net.Addr, remoteAddr net.Addr) (conn net.Conn, err error) { if reverseProxy == nil { return nil, errors.New("no reverse proxy config") } @@ -181,12 +303,12 @@ func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfi if strings.Contains(remoteAddr.String(), "[") { transportProtocol = proxyproto.UDPv6 } - header := proxyproto.Header{ + var header = proxyproto.Header{ Version: byte(reverseProxy.ProxyProtocol.Version), Command: proxyproto.PROXY, TransportProtocol: transportProtocol, SourceAddr: remoteAddr, - DestinationAddr: this.Listener.LocalAddr(), + DestinationAddr: localAddr, } _, err = header.WriteTo(conn) if err != nil { @@ -224,21 +346,21 @@ func (this *UDPListener) gcConns() { // UDPConn 自定义的UDP连接管理 type UDPConn struct { - addr net.Addr - proxyConn net.Conn - serverConn net.Conn - activatedAt int64 - isOk bool - isClosed bool + addr net.Addr + proxyListener UDPPacketListener + serverConn net.Conn + activatedAt int64 + isOk bool + isClosed bool } -func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyConn *net.UDPConn, serverConn *net.UDPConn) *UDPConn { +func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener UDPPacketListener, cm any, serverConn *net.UDPConn) *UDPConn { var conn = &UDPConn{ - addr: addr, - proxyConn: proxyConn, - serverConn: serverConn, - activatedAt: time.Now().Unix(), - isOk: true, + addr: addr, + proxyListener: proxyListener, + serverConn: serverConn, + activatedAt: time.Now().Unix(), + isOk: true, } // 统计 @@ -246,6 +368,14 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyConn *ne stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) } + // 处理ControlMessage + switch controlMessage := cm.(type) { + case *ipv4.ControlMessage: + controlMessage.Src = controlMessage.Dst + case *ipv6.ControlMessage: + controlMessage.Src = controlMessage.Dst + } + goman.New(func() { var buffer = utils.BytePool4k.Get() defer func() { @@ -256,7 +386,8 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyConn *ne n, err := serverConn.Read(buffer) if n > 0 { conn.activatedAt = time.Now().Unix() - _, writingErr := proxyConn.WriteTo(buffer[:n], addr) + + _, writingErr := proxyListener.WriteTo(buffer[:n], cm, addr) if writingErr != nil { conn.isOk = false break