diff --git a/build/build.sh b/build/build.sh index cd11575..f19dbbd 100755 --- a/build/build.sh +++ b/build/build.sh @@ -40,6 +40,12 @@ function build() { cp -R $ROOT/pages $DIST/ cp -R $ROOT/resources $DIST/ + # we support TOA on linux/amd64 only + if [ $OS == "linux" -a $ARCH == "amd64" ] + then + cp -R $ROOT/edge-toa $DIST + fi + echo "building ..." env GOOS=${OS} GOARCH=${ARCH} go build -o $DIST/bin/${NAME} -ldflags="-s -w" $ROOT/../cmd/edge-node/main.go diff --git a/build/edge-toa/.gitignore b/build/edge-toa/.gitignore new file mode 100644 index 0000000..988b10c --- /dev/null +++ b/build/edge-toa/.gitignore @@ -0,0 +1 @@ +edge-toa \ No newline at end of file diff --git a/internal/const/const.go b/internal/const/const.go index 3884e4d..e57bc85 100644 --- a/internal/const/const.go +++ b/internal/const/const.go @@ -1,7 +1,7 @@ package teaconst const ( - Version = "0.0.3" + Version = "0.0.4" ProductName = "Edge Node" ProcessName = "edge-node" diff --git a/internal/events/events.go b/internal/events/events.go index 2b2f827..166b027 100644 --- a/internal/events/events.go +++ b/internal/events/events.go @@ -3,6 +3,8 @@ package events type Event = string const ( - EventStart Event = "start" // start loading - EventQuit Event = "quit" // quit node gracefully + EventStart Event = "start" // start loading + EventLoaded Event = "loaded" // first load + EventQuit Event = "quit" // quit node gracefully + EventReload Event = "reload" // reload config ) diff --git a/internal/iplibrary/manager.go b/internal/iplibrary/manager.go index 0e7e797..1c90897 100644 --- a/internal/iplibrary/manager.go +++ b/internal/iplibrary/manager.go @@ -17,7 +17,7 @@ var SharedManager = NewManager() var SharedLibrary LibraryInterface func init() { - events.On(events.EventStart, func() { + events.On(events.EventLoaded, func() { // 初始化 library, err := SharedManager.Load() if err != nil { diff --git a/internal/nodes/http_client_pool.go b/internal/nodes/http_client_pool.go index b780288..0db12cc 100644 --- a/internal/nodes/http_client_pool.go +++ b/internal/nodes/http_client_pool.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/TeaOSLab/EdgeNode/internal/logs" "net" "net/http" "runtime" @@ -36,7 +37,7 @@ func NewHTTPClientPool() *HTTPClientPool { } // 根据地址获取客户端 -func (this *HTTPClientPool) Client(origin *serverconfigs.OriginConfig, originAddr string) (rawClient *http.Client, err error) { +func (this *HTTPClientPool) Client(req *http.Request, origin *serverconfigs.OriginConfig, originAddr string) (rawClient *http.Client, err error) { if origin.Addr == nil { return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")") } @@ -97,7 +98,34 @@ func (this *HTTPClientPool) Client(origin *serverconfigs.OriginConfig, originAdd transport := &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - // 握手配置 + // 支持TOA的连接 + toaConfig := sharedTOAManager.Config() + if toaConfig != nil && toaConfig.IsOn { + retries := 3 + for i := 1; i <= retries; i++ { + port := int(toaConfig.RandLocalPort()) + // TODO 思考是否支持X-Real-IP/X-Forwarded-IP + err := sharedTOAManager.SendMsg("add:" + strconv.Itoa(port) + ":" + req.RemoteAddr) + if err != nil { + logs.Error("TOA", "add failed: "+err.Error()) + } else { + dialer := net.Dialer{ + Timeout: connectionTimeout, + KeepAlive: 1 * time.Minute, + LocalAddr: &net.TCPAddr{ + Port: port, + }, + } + conn, err := dialer.DialContext(ctx, network, originAddr) + // TODO 需要在合适的时机删除TOA记录 + if err == nil || i == retries { + return conn, err + } + } + } + } + + // 普通的连接 return (&net.Dialer{ Timeout: connectionTimeout, KeepAlive: 1 * time.Minute, diff --git a/internal/nodes/http_client_pool_test.go b/internal/nodes/http_client_pool_test.go index 2207513..cafaaef 100644 --- a/internal/nodes/http_client_pool_test.go +++ b/internal/nodes/http_client_pool_test.go @@ -21,14 +21,14 @@ func TestHTTPClientPool_Client(t *testing.T) { t.Fatal(err) } { - client, err := pool.Client(origin, origin.Addr.PickAddress()) + client, err := pool.Client(nil, origin, origin.Addr.PickAddress()) if err != nil { t.Fatal(err) } t.Log("client:", client) } for i := 0; i < 10; i++ { - client, err := pool.Client(origin, origin.Addr.PickAddress()) + client, err := pool.Client(nil, origin, origin.Addr.PickAddress()) if err != nil { t.Fatal(err) } @@ -53,7 +53,7 @@ func TestHTTPClientPool_cleanClients(t *testing.T) { for i := 0; i < 10; i++ { t.Log("get", i) - _, _ = pool.Client(origin, origin.Addr.PickAddress()) + _, _ = pool.Client(nil, origin, origin.Addr.PickAddress()) time.Sleep(1 * time.Second) } } @@ -73,6 +73,6 @@ func BenchmarkHTTPClientPool_Client(b *testing.B) { pool := NewHTTPClientPool() for i := 0; i < b.N; i++ { - _, _ = pool.Client(origin, origin.Addr.PickAddress()) + _, _ = pool.Client(nil, origin, origin.Addr.PickAddress()) } } diff --git a/internal/nodes/http_request_error.go b/internal/nodes/http_request_error.go index bd8c77f..3266b67 100644 --- a/internal/nodes/http_request_error.go +++ b/internal/nodes/http_request_error.go @@ -42,5 +42,5 @@ func (this *HTTPRequest) write502(err error) { } this.processResponseHeaders(statusCode) this.writer.WriteHeader(statusCode) - _, _ = this.writer.Write([]byte(http.StatusText(statusCode))) + _, _ = this.writer.Write([]byte("502 Bad Gateway")) } diff --git a/internal/nodes/http_request_reverse_proxy.go b/internal/nodes/http_request_reverse_proxy.go index b4385a4..9dd3a5d 100644 --- a/internal/nodes/http_request_reverse_proxy.go +++ b/internal/nodes/http_request_reverse_proxy.go @@ -140,7 +140,7 @@ func (this *HTTPRequest) doReverseProxy() { } // 获取请求客户端 - client, err := SharedHTTPClientPool.Client(origin, originAddr) + client, err := SharedHTTPClientPool.Client(this.RawReq, origin, originAddr) if err != nil { logs.Error("REQUEST_REVERSE_PROXY", err.Error()) this.write502(err) diff --git a/internal/nodes/http_request_websocket.go b/internal/nodes/http_request_websocket.go index 7b209b1..0bc316b 100644 --- a/internal/nodes/http_request_websocket.go +++ b/internal/nodes/http_request_websocket.go @@ -41,7 +41,7 @@ func (this *HTTPRequest) doWebsocket() { } // TODO 增加N次错误重试,重试的时候需要尝试不同的源站 - originConn, err := OriginConnect(this.origin) + originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr) if err != nil { logs.Error(err) this.write500(err) diff --git a/internal/nodes/listener_tcp.go b/internal/nodes/listener_tcp.go index feec68d..ac814bf 100644 --- a/internal/nodes/listener_tcp.go +++ b/internal/nodes/listener_tcp.go @@ -54,7 +54,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error { if firstServer.ReverseProxy == nil { return errors.New("no ReverseProxy configured for the server") } - originConn, err := this.connectOrigin(firstServer.ReverseProxy) + originConn, err := this.connectOrigin(firstServer.ReverseProxy, conn.RemoteAddr().String()) if err != nil { return err } @@ -106,7 +106,7 @@ func (this *TCPListener) Close() error { return this.Listener.Close() } -func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig) (conn net.Conn, err error) { +func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) { if reverseProxy == nil { return nil, errors.New("no reverse proxy config") } @@ -117,7 +117,7 @@ func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC if origin == nil { continue } - conn, err = OriginConnect(origin) + conn, err = OriginConnect(origin, remoteAddr) if err != nil { logs.Error("TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error()) continue diff --git a/internal/nodes/node.go b/internal/nodes/node.go index 808f4db..a1796b6 100644 --- a/internal/nodes/node.go +++ b/internal/nodes/node.go @@ -92,6 +92,9 @@ func (this *Node) Start() { } sharedNodeConfig = nodeConfig + // 发送事件 + events.Notify(events.EventLoaded) + // 设置rlimit _ = utils.SetRLimit(1024 * 1024) @@ -166,10 +169,16 @@ func (this *Node) syncConfig(isFirstTime bool) error { return errors.New("create rpc client failed: " + err.Error()) } // TODO 这里考虑只同步版本号有变更的 - configResp, err := rpcClient.NodeRPC().FindCurrentNodeConfig(rpcClient.Context(), &pb.FindCurrentNodeConfigRequest{}) + configResp, err := rpcClient.NodeRPC().FindCurrentNodeConfig(rpcClient.Context(), &pb.FindCurrentNodeConfigRequest{ + Version: lastVersion, + }) if err != nil { return errors.New("read config from rpc failed: " + err.Error()) } + if !configResp.IsChanged { + return nil + } + configJSON := configResp.NodeJSON nodeConfig := &nodeconfigs.NodeConfig{} err = json.Unmarshal(configJSON, nodeConfig) @@ -213,6 +222,9 @@ func (this *Node) syncConfig(isFirstTime bool) error { sharedWAFManager.UpdatePolicies(nodeConfig.AllHTTPFirewallPolicies()) sharedNodeConfig = nodeConfig + // 发送事件 + events.Notify(events.EventReload) + if !isFirstTime { return sharedListenerManager.Start(nodeConfig) } diff --git a/internal/nodes/origin_utils.go b/internal/nodes/origin_utils.go index 7a29934..e5083d1 100644 --- a/internal/nodes/origin_utils.go +++ b/internal/nodes/origin_utils.go @@ -4,15 +4,58 @@ import ( "crypto/tls" "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/TeaOSLab/EdgeNode/internal/logs" "net" + "strconv" ) // 连接源站 -func OriginConnect(origin *serverconfigs.OriginConfig) (net.Conn, error) { +func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.Conn, error) { if origin.Addr == nil { return nil, errors.New("origin server address should not be empty") } + // 支持TOA的连接 + toaConfig := sharedTOAManager.Config() + if toaConfig != nil && toaConfig.IsOn { + retries := 3 + for i := 1; i <= retries; i++ { + port := int(toaConfig.RandLocalPort()) + err := sharedTOAManager.SendMsg("add:" + strconv.Itoa(port) + ":" + remoteAddr) + if err != nil { + logs.Error("TOA", "add failed: "+err.Error()) + } else { + dialer := net.Dialer{ + Timeout: origin.ConnTimeoutDuration(), + LocalAddr: &net.TCPAddr{ + Port: port, + }, + } + var conn net.Conn + switch origin.Addr.Protocol { + case "", serverconfigs.ProtocolTCP, serverconfigs.ProtocolHTTP: + // TODO 支持TCP4/TCP6 + // TODO 支持指定特定网卡 + // TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用 + conn, err = dialer.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange) + case serverconfigs.ProtocolTLS, serverconfigs.ProtocolHTTPS: + // TODO 支持TCP4/TCP6 + // TODO 支持指定特定网卡 + // TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用 + // TODO 支持使用证书 + conn, err = tls.DialWithDialer(&dialer, "tcp", origin.Addr.Host+":"+origin.Addr.PortRange, &tls.Config{ + InsecureSkipVerify: true, + }) + } + + // TODO 需要在合适的时机删除TOA记录 + if err == nil || i == retries { + return conn, err + } + } + } + } + switch origin.Addr.Protocol { case "", serverconfigs.ProtocolTCP, serverconfigs.ProtocolHTTP: // TODO 支持TCP4/TCP6 diff --git a/internal/nodes/toa_manager.go b/internal/nodes/toa_manager.go new file mode 100644 index 0000000..38f61ea --- /dev/null +++ b/internal/nodes/toa_manager.go @@ -0,0 +1,101 @@ +package nodes + +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" + "github.com/TeaOSLab/EdgeNode/internal/events" + "github.com/TeaOSLab/EdgeNode/internal/logs" + "github.com/iwind/TeaGo/Tea" + "net" + "os" + "os/exec" + "strings" + "time" +) + +var sharedTOAManager = NewTOAManager() + +func init() { + events.On(events.EventReload, func() { + err := sharedTOAManager.Run(sharedNodeConfig.TOA) + if err != nil { + logs.Error("TOA", err.Error()) + } + }) +} + +type TOAManager struct { + config *nodeconfigs.TOAConfig + pid int + conn net.Conn +} + +func NewTOAManager() *TOAManager { + return &TOAManager{} +} + +func (this *TOAManager) Run(config *nodeconfigs.TOAConfig) error { + this.config = config + + if this.pid > 0 { + logs.Println("TOA", "stopping ...") + err := this.Quit() + if err != nil { + logs.Error("TOA", "quit error: "+err.Error()) + } + _ = this.conn.Close() + this.conn = nil + this.pid = 0 + } + + if !config.IsOn { + return nil + } + + binPath := Tea.Root + "/edge-toa/edge-toa" // TODO 可以做成配置 + _, err := os.Stat(binPath) + if err != nil { + return err + } + logs.Println("TOA", "starting ...") + logs.Println("TOA", "args: "+strings.Join(config.AsArgs(), " ")) + cmd := exec.Command(binPath, config.AsArgs()...) + err = cmd.Start() + if err != nil { + return err + } + this.pid = cmd.Process.Pid + + go func() { _ = cmd.Wait() }() + + return nil +} + +func (this *TOAManager) Config() *nodeconfigs.TOAConfig { + return this.config +} + +func (this *TOAManager) Quit() error { + return this.SendMsg("quit:0") +} + +func (this *TOAManager) SendMsg(msg string) error { + if this.config == nil { + return nil + } + + if this.conn != nil { + _, err := this.conn.Write([]byte(msg + "\n")) + if err != nil { + this.conn = nil + } + return err + } + + conn, err := net.DialTimeout("unix", this.config.SockFile(), 1*time.Second) + if err != nil { + return err + } + this.conn = conn + _, err = this.conn.Write([]byte(msg + "\n")) + return err +} diff --git a/internal/nodes/toa_manager_test.go b/internal/nodes/toa_manager_test.go new file mode 100644 index 0000000..26c9359 --- /dev/null +++ b/internal/nodes/toa_manager_test.go @@ -0,0 +1,17 @@ +package nodes + +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" + "testing" +) + +func TestTOAManager_Run(t *testing.T) { + manager := NewTOAManager() + err := manager.Run(&nodeconfigs.TOAConfig{ + IsOn: true, + }) + if err != nil { + t.Fatal(err) + } + t.Log("ok") +}