mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-03 15:00:26 +08:00
[反向代理]实验性添加TOA支持
This commit is contained in:
@@ -40,6 +40,12 @@ function build() {
|
||||
cp -R $ROOT/pages $DIST/
|
||||
cp -R $ROOT/resources $DIST/
|
||||
|
||||
# we support TOA on linux/amd64 only
|
||||
if [ $OS == "linux" -a $ARCH == "amd64" ]
|
||||
then
|
||||
cp -R $ROOT/edge-toa $DIST
|
||||
fi
|
||||
|
||||
echo "building ..."
|
||||
env GOOS=${OS} GOARCH=${ARCH} go build -o $DIST/bin/${NAME} -ldflags="-s -w" $ROOT/../cmd/edge-node/main.go
|
||||
|
||||
|
||||
1
build/edge-toa/.gitignore
vendored
Normal file
1
build/edge-toa/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
edge-toa
|
||||
@@ -1,7 +1,7 @@
|
||||
package teaconst
|
||||
|
||||
const (
|
||||
Version = "0.0.3"
|
||||
Version = "0.0.4"
|
||||
|
||||
ProductName = "Edge Node"
|
||||
ProcessName = "edge-node"
|
||||
|
||||
@@ -3,6 +3,8 @@ package events
|
||||
type Event = string
|
||||
|
||||
const (
|
||||
EventStart Event = "start" // start loading
|
||||
EventQuit Event = "quit" // quit node gracefully
|
||||
EventStart Event = "start" // start loading
|
||||
EventLoaded Event = "loaded" // first load
|
||||
EventQuit Event = "quit" // quit node gracefully
|
||||
EventReload Event = "reload" // reload config
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ var SharedManager = NewManager()
|
||||
var SharedLibrary LibraryInterface
|
||||
|
||||
func init() {
|
||||
events.On(events.EventStart, func() {
|
||||
events.On(events.EventLoaded, func() {
|
||||
// 初始化
|
||||
library, err := SharedManager.Load()
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/logs"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
@@ -36,7 +37,7 @@ func NewHTTPClientPool() *HTTPClientPool {
|
||||
}
|
||||
|
||||
// 根据地址获取客户端
|
||||
func (this *HTTPClientPool) Client(origin *serverconfigs.OriginConfig, originAddr string) (rawClient *http.Client, err error) {
|
||||
func (this *HTTPClientPool) Client(req *http.Request, origin *serverconfigs.OriginConfig, originAddr string) (rawClient *http.Client, err error) {
|
||||
if origin.Addr == nil {
|
||||
return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")")
|
||||
}
|
||||
@@ -97,7 +98,34 @@ func (this *HTTPClientPool) Client(origin *serverconfigs.OriginConfig, originAdd
|
||||
|
||||
transport := &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// 握手配置
|
||||
// 支持TOA的连接
|
||||
toaConfig := sharedTOAManager.Config()
|
||||
if toaConfig != nil && toaConfig.IsOn {
|
||||
retries := 3
|
||||
for i := 1; i <= retries; i++ {
|
||||
port := int(toaConfig.RandLocalPort())
|
||||
// TODO 思考是否支持X-Real-IP/X-Forwarded-IP
|
||||
err := sharedTOAManager.SendMsg("add:" + strconv.Itoa(port) + ":" + req.RemoteAddr)
|
||||
if err != nil {
|
||||
logs.Error("TOA", "add failed: "+err.Error())
|
||||
} else {
|
||||
dialer := net.Dialer{
|
||||
Timeout: connectionTimeout,
|
||||
KeepAlive: 1 * time.Minute,
|
||||
LocalAddr: &net.TCPAddr{
|
||||
Port: port,
|
||||
},
|
||||
}
|
||||
conn, err := dialer.DialContext(ctx, network, originAddr)
|
||||
// TODO 需要在合适的时机删除TOA记录
|
||||
if err == nil || i == retries {
|
||||
return conn, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 普通的连接
|
||||
return (&net.Dialer{
|
||||
Timeout: connectionTimeout,
|
||||
KeepAlive: 1 * time.Minute,
|
||||
|
||||
@@ -21,14 +21,14 @@ func TestHTTPClientPool_Client(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
{
|
||||
client, err := pool.Client(origin, origin.Addr.PickAddress())
|
||||
client, err := pool.Client(nil, origin, origin.Addr.PickAddress())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("client:", client)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
client, err := pool.Client(origin, origin.Addr.PickAddress())
|
||||
client, err := pool.Client(nil, origin, origin.Addr.PickAddress())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -53,7 +53,7 @@ func TestHTTPClientPool_cleanClients(t *testing.T) {
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
t.Log("get", i)
|
||||
_, _ = pool.Client(origin, origin.Addr.PickAddress())
|
||||
_, _ = pool.Client(nil, origin, origin.Addr.PickAddress())
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
@@ -73,6 +73,6 @@ func BenchmarkHTTPClientPool_Client(b *testing.B) {
|
||||
|
||||
pool := NewHTTPClientPool()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = pool.Client(origin, origin.Addr.PickAddress())
|
||||
_, _ = pool.Client(nil, origin, origin.Addr.PickAddress())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,5 +42,5 @@ func (this *HTTPRequest) write502(err error) {
|
||||
}
|
||||
this.processResponseHeaders(statusCode)
|
||||
this.writer.WriteHeader(statusCode)
|
||||
_, _ = this.writer.Write([]byte(http.StatusText(statusCode)))
|
||||
_, _ = this.writer.Write([]byte("502 Bad Gateway"))
|
||||
}
|
||||
|
||||
@@ -140,7 +140,7 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
}
|
||||
|
||||
// 获取请求客户端
|
||||
client, err := SharedHTTPClientPool.Client(origin, originAddr)
|
||||
client, err := SharedHTTPClientPool.Client(this.RawReq, origin, originAddr)
|
||||
if err != nil {
|
||||
logs.Error("REQUEST_REVERSE_PROXY", err.Error())
|
||||
this.write502(err)
|
||||
|
||||
@@ -41,7 +41,7 @@ func (this *HTTPRequest) doWebsocket() {
|
||||
}
|
||||
|
||||
// TODO 增加N次错误重试,重试的时候需要尝试不同的源站
|
||||
originConn, err := OriginConnect(this.origin)
|
||||
originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
this.write500(err)
|
||||
|
||||
@@ -54,7 +54,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
if firstServer.ReverseProxy == nil {
|
||||
return errors.New("no ReverseProxy configured for the server")
|
||||
}
|
||||
originConn, err := this.connectOrigin(firstServer.ReverseProxy)
|
||||
originConn, err := this.connectOrigin(firstServer.ReverseProxy, conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -106,7 +106,7 @@ func (this *TCPListener) Close() error {
|
||||
return this.Listener.Close()
|
||||
}
|
||||
|
||||
func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig) (conn net.Conn, err error) {
|
||||
func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
|
||||
if reverseProxy == nil {
|
||||
return nil, errors.New("no reverse proxy config")
|
||||
}
|
||||
@@ -117,7 +117,7 @@ func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC
|
||||
if origin == nil {
|
||||
continue
|
||||
}
|
||||
conn, err = OriginConnect(origin)
|
||||
conn, err = OriginConnect(origin, remoteAddr)
|
||||
if err != nil {
|
||||
logs.Error("TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
|
||||
continue
|
||||
|
||||
@@ -92,6 +92,9 @@ func (this *Node) Start() {
|
||||
}
|
||||
sharedNodeConfig = nodeConfig
|
||||
|
||||
// 发送事件
|
||||
events.Notify(events.EventLoaded)
|
||||
|
||||
// 设置rlimit
|
||||
_ = utils.SetRLimit(1024 * 1024)
|
||||
|
||||
@@ -166,10 +169,16 @@ func (this *Node) syncConfig(isFirstTime bool) error {
|
||||
return errors.New("create rpc client failed: " + err.Error())
|
||||
}
|
||||
// TODO 这里考虑只同步版本号有变更的
|
||||
configResp, err := rpcClient.NodeRPC().FindCurrentNodeConfig(rpcClient.Context(), &pb.FindCurrentNodeConfigRequest{})
|
||||
configResp, err := rpcClient.NodeRPC().FindCurrentNodeConfig(rpcClient.Context(), &pb.FindCurrentNodeConfigRequest{
|
||||
Version: lastVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.New("read config from rpc failed: " + err.Error())
|
||||
}
|
||||
if !configResp.IsChanged {
|
||||
return nil
|
||||
}
|
||||
|
||||
configJSON := configResp.NodeJSON
|
||||
nodeConfig := &nodeconfigs.NodeConfig{}
|
||||
err = json.Unmarshal(configJSON, nodeConfig)
|
||||
@@ -213,6 +222,9 @@ func (this *Node) syncConfig(isFirstTime bool) error {
|
||||
sharedWAFManager.UpdatePolicies(nodeConfig.AllHTTPFirewallPolicies())
|
||||
sharedNodeConfig = nodeConfig
|
||||
|
||||
// 发送事件
|
||||
events.Notify(events.EventReload)
|
||||
|
||||
if !isFirstTime {
|
||||
return sharedListenerManager.Start(nodeConfig)
|
||||
}
|
||||
|
||||
@@ -4,15 +4,58 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/logs"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// 连接源站
|
||||
func OriginConnect(origin *serverconfigs.OriginConfig) (net.Conn, error) {
|
||||
func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.Conn, error) {
|
||||
if origin.Addr == nil {
|
||||
return nil, errors.New("origin server address should not be empty")
|
||||
}
|
||||
|
||||
// 支持TOA的连接
|
||||
toaConfig := sharedTOAManager.Config()
|
||||
if toaConfig != nil && toaConfig.IsOn {
|
||||
retries := 3
|
||||
for i := 1; i <= retries; i++ {
|
||||
port := int(toaConfig.RandLocalPort())
|
||||
err := sharedTOAManager.SendMsg("add:" + strconv.Itoa(port) + ":" + remoteAddr)
|
||||
if err != nil {
|
||||
logs.Error("TOA", "add failed: "+err.Error())
|
||||
} else {
|
||||
dialer := net.Dialer{
|
||||
Timeout: origin.ConnTimeoutDuration(),
|
||||
LocalAddr: &net.TCPAddr{
|
||||
Port: port,
|
||||
},
|
||||
}
|
||||
var conn net.Conn
|
||||
switch origin.Addr.Protocol {
|
||||
case "", serverconfigs.ProtocolTCP, serverconfigs.ProtocolHTTP:
|
||||
// TODO 支持TCP4/TCP6
|
||||
// TODO 支持指定特定网卡
|
||||
// TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用
|
||||
conn, err = dialer.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange)
|
||||
case serverconfigs.ProtocolTLS, serverconfigs.ProtocolHTTPS:
|
||||
// TODO 支持TCP4/TCP6
|
||||
// TODO 支持指定特定网卡
|
||||
// TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用
|
||||
// TODO 支持使用证书
|
||||
conn, err = tls.DialWithDialer(&dialer, "tcp", origin.Addr.Host+":"+origin.Addr.PortRange, &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
}
|
||||
|
||||
// TODO 需要在合适的时机删除TOA记录
|
||||
if err == nil || i == retries {
|
||||
return conn, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch origin.Addr.Protocol {
|
||||
case "", serverconfigs.ProtocolTCP, serverconfigs.ProtocolHTTP:
|
||||
// TODO 支持TCP4/TCP6
|
||||
|
||||
101
internal/nodes/toa_manager.go
Normal file
101
internal/nodes/toa_manager.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/logs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var sharedTOAManager = NewTOAManager()
|
||||
|
||||
func init() {
|
||||
events.On(events.EventReload, func() {
|
||||
err := sharedTOAManager.Run(sharedNodeConfig.TOA)
|
||||
if err != nil {
|
||||
logs.Error("TOA", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type TOAManager struct {
|
||||
config *nodeconfigs.TOAConfig
|
||||
pid int
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func NewTOAManager() *TOAManager {
|
||||
return &TOAManager{}
|
||||
}
|
||||
|
||||
func (this *TOAManager) Run(config *nodeconfigs.TOAConfig) error {
|
||||
this.config = config
|
||||
|
||||
if this.pid > 0 {
|
||||
logs.Println("TOA", "stopping ...")
|
||||
err := this.Quit()
|
||||
if err != nil {
|
||||
logs.Error("TOA", "quit error: "+err.Error())
|
||||
}
|
||||
_ = this.conn.Close()
|
||||
this.conn = nil
|
||||
this.pid = 0
|
||||
}
|
||||
|
||||
if !config.IsOn {
|
||||
return nil
|
||||
}
|
||||
|
||||
binPath := Tea.Root + "/edge-toa/edge-toa" // TODO 可以做成配置
|
||||
_, err := os.Stat(binPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logs.Println("TOA", "starting ...")
|
||||
logs.Println("TOA", "args: "+strings.Join(config.AsArgs(), " "))
|
||||
cmd := exec.Command(binPath, config.AsArgs()...)
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.pid = cmd.Process.Pid
|
||||
|
||||
go func() { _ = cmd.Wait() }()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *TOAManager) Config() *nodeconfigs.TOAConfig {
|
||||
return this.config
|
||||
}
|
||||
|
||||
func (this *TOAManager) Quit() error {
|
||||
return this.SendMsg("quit:0")
|
||||
}
|
||||
|
||||
func (this *TOAManager) SendMsg(msg string) error {
|
||||
if this.config == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if this.conn != nil {
|
||||
_, err := this.conn.Write([]byte(msg + "\n"))
|
||||
if err != nil {
|
||||
this.conn = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("unix", this.config.SockFile(), 1*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.conn = conn
|
||||
_, err = this.conn.Write([]byte(msg + "\n"))
|
||||
return err
|
||||
}
|
||||
17
internal/nodes/toa_manager_test.go
Normal file
17
internal/nodes/toa_manager_test.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTOAManager_Run(t *testing.T) {
|
||||
manager := NewTOAManager()
|
||||
err := manager.Run(&nodeconfigs.TOAConfig{
|
||||
IsOn: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
Reference in New Issue
Block a user