diff --git a/internal/nodes/client_conn.go b/internal/nodes/client_conn.go index 5ce86c9..2e2215f 100644 --- a/internal/nodes/client_conn.go +++ b/internal/nodes/client_conn.go @@ -17,10 +17,13 @@ type ClientConn struct { once sync.Once globalLimiter *ratelimit.Counter + isTLS bool + hasRead bool + BaseClientConn } -func NewClientConn(conn net.Conn, quickClose bool, globalLimiter *ratelimit.Counter) net.Conn { +func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ratelimit.Counter) net.Conn { if quickClose { // TCP tcpConn, ok := conn.(*net.TCPConn) @@ -30,11 +33,22 @@ func NewClientConn(conn net.Conn, quickClose bool, globalLimiter *ratelimit.Coun } } - return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, globalLimiter: globalLimiter} + return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS, globalLimiter: globalLimiter} } func (this *ClientConn) Read(b []byte) (n int, err error) { + if this.isTLS { + if !this.hasRead { + _ = this.rawConn.SetReadDeadline(time.Now().Add(5 * time.Second)) // TODO 握手超时时间可以设置 + this.hasRead = true + defer func() { + _ = this.rawConn.SetReadDeadline(time.Time{}) + }() + } + } + n, err = this.rawConn.Read(b) + if n > 0 { atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n)) } diff --git a/internal/nodes/client_listener.go b/internal/nodes/client_listener.go index f607ffc..8aa0d1a 100644 --- a/internal/nodes/client_listener.go +++ b/internal/nodes/client_listener.go @@ -16,16 +16,25 @@ var sharedConnectionsLimiter = ratelimit.NewCounter(nodeconfigs.DefaultTCPMaxCon // ClientListener 客户端网络监听 type ClientListener struct { rawListener net.Listener + isTLS bool quickClose bool } -func NewClientListener(listener net.Listener, quickClose bool) net.Listener { +func NewClientListener1(listener net.Listener, quickClose bool) *ClientListener { return &ClientListener{ rawListener: listener, quickClose: quickClose, } } +func (this *ClientListener) SetIsTLS(isTLS bool) { + this.isTLS = isTLS +} + +func (this *ClientListener) IsTLS() bool { + return this.isTLS +} + func (this *ClientListener) Accept() (net.Conn, error) { // 限制并发连接数 var isOk = false @@ -58,7 +67,7 @@ func (this *ClientListener) Accept() (net.Conn, error) { } isOk = true - return NewClientConn(conn, this.quickClose, limiter), nil + return NewClientConn(conn, this.isTLS, this.quickClose, limiter), nil } func (this *ClientListener) Close() error { diff --git a/internal/nodes/http_client_pool.go b/internal/nodes/http_client_pool.go index 8b442b3..06cd64e 100644 --- a/internal/nodes/http_client_pool.go +++ b/internal/nodes/http_client_pool.go @@ -174,7 +174,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.Origi MaxConnsPerHost: maxConnections, IdleConnTimeout: idleTimeout, ExpectContinueTimeout: 1 * time.Second, - TLSHandshakeTimeout: 0, // 不限 + TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: tlsConfig, Proxy: nil, } diff --git a/internal/nodes/http_request_utils.go b/internal/nodes/http_request_utils.go index 5a86b1d..fa44f1b 100644 --- a/internal/nodes/http_request_utils.go +++ b/internal/nodes/http_request_utils.go @@ -2,10 +2,12 @@ package nodes import ( "crypto/rand" + "crypto/tls" "fmt" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" "github.com/TeaOSLab/EdgeNode/internal/utils" "io" + "net" "net/http" "strconv" "strings" @@ -153,3 +155,12 @@ func httpRequestNextId() string { // timestamp + requestId + nodeId return strconv.FormatInt(unixTime, 10) + strconv.Itoa(int(atomic.AddInt32(&httpRequestId, 1))) + teaconst.NodeIdString } + +// 检查连接是否为TLS连接 +func httpIsTLSConn(conn net.Conn) bool { + if conn == nil { + return false + } + _, ok := conn.(*tls.Conn) + return ok +} diff --git a/internal/nodes/listener.go b/internal/nodes/listener.go index 9d57f81..b74dc81 100644 --- a/internal/nodes/listener.go +++ b/internal/nodes/listener.go @@ -56,11 +56,11 @@ func (this *Listener) listenTCP() error { } protocol := this.group.Protocol() - netListener, err := this.createTCPListener() + tcpListener, err := this.createTCPListener() if err != nil { return err } - netListener = NewClientListener(netListener, protocol.IsHTTPFamily() || protocol.IsHTTPSFamily()) + var netListener = NewClientListener1(tcpListener, protocol.IsHTTPFamily() || protocol.IsHTTPSFamily()) events.On(events.EventQuit, func() { remotelogs.Println("LISTENER", "quit "+this.group.FullAddr()) _ = netListener.Close() @@ -73,6 +73,7 @@ func (this *Listener) listenTCP() error { Listener: netListener, } case serverconfigs.ProtocolHTTPS, serverconfigs.ProtocolHTTPS4, serverconfigs.ProtocolHTTPS6: + netListener.SetIsTLS(true) this.listener = &HTTPListener{ BaseListener: BaseListener{Group: this.group}, Listener: netListener, @@ -83,6 +84,7 @@ func (this *Listener) listenTCP() error { Listener: netListener, } case serverconfigs.ProtocolTLS, serverconfigs.ProtocolTLS4, serverconfigs.ProtocolTLS6: + netListener.SetIsTLS(true) this.listener = &TCPListener{ BaseListener: BaseListener{Group: this.group}, Listener: netListener, diff --git a/internal/nodes/listener_http.go b/internal/nodes/listener_http.go index b53be84..0a5e036 100644 --- a/internal/nodes/listener_http.go +++ b/internal/nodes/listener_http.go @@ -47,7 +47,9 @@ func (this *HTTPListener) Serve() error { this.httpServer = &http.Server{ Addr: this.addr, Handler: this, + ReadTimeout: 1 * time.Hour, // TODO 改成可以配置 ReadHeaderTimeout: 3 * time.Second, // TODO 改成可以配置 + WriteTimeout: 1 * time.Hour, // TODO 改成可以配置 IdleTimeout: 75 * time.Second, // TODO 改成可以配置 ConnState: func(conn net.Conn, state http.ConnState) { switch state { @@ -60,21 +62,26 @@ func (this *HTTPListener) Serve() error { metricNewConnMap[conn.RemoteAddr().String()] = zero.New() metricNewConnMapLocker.Unlock() } + case http.StateActive, http.StateIdle, http.StateHijacked: + // Nothing to do case http.StateClosed: atomic.AddInt64(&this.countActiveConnections, -1) // 移除指标存储连接信息 + // 因为中途配置可能有改变,所以暂时不添加条件 metricNewConnMapLocker.Lock() delete(metricNewConnMap, conn.RemoteAddr().String()) metricNewConnMapLocker.Unlock() } }, - ConnContext: func(ctx context.Context, c net.Conn) context.Context { - tlsConn, ok := c.(*tls.Conn) + ConnContext: func(ctx context.Context, conn net.Conn) context.Context { + tlsConn, ok := conn.(*tls.Conn) + if ok { - c = NewClientTLSConn(tlsConn) + conn = NewClientTLSConn(tlsConn) } - return context.WithValue(ctx, HTTPConnContextKey, c) + + return context.WithValue(ctx, HTTPConnContextKey, conn) }, } diff --git a/internal/nodes/task_sync_api_nodes.go b/internal/nodes/task_sync_api_nodes.go index 4ce8c52..448c3ba 100644 --- a/internal/nodes/task_sync_api_nodes.go +++ b/internal/nodes/task_sync_api_nodes.go @@ -134,9 +134,9 @@ func (this *SyncAPINodesTask) testEndpoints(endpoints []string) bool { return } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second) defer func() { - cancel() + cancelFunc() }() var conn *grpc.ClientConn if u.Scheme == "http" { diff --git a/internal/utils/http.go b/internal/utils/http.go index d581952..c7253fd 100644 --- a/internal/utils/http.go +++ b/internal/utils/http.go @@ -13,7 +13,7 @@ import ( var timeoutClientMap = map[time.Duration]*http.Client{} // timeout => Client var timeoutClientLocker = sync.Mutex{} -// 导出响应 +// DumpResponse 导出响应 func DumpResponse(resp *http.Response) (header []byte, body []byte, err error) { header, err = httputil.DumpResponse(resp, false) if err != nil { @@ -23,7 +23,7 @@ func DumpResponse(resp *http.Response) (header []byte, body []byte, err error) { return } -// 获取一个新的Client +// NewHTTPClient 获取一个新的Client func NewHTTPClient(timeout time.Duration) *http.Client { return &http.Client{ Timeout: timeout, @@ -33,7 +33,7 @@ func NewHTTPClient(timeout time.Duration) *http.Client { MaxConnsPerHost: 32, IdleConnTimeout: 2 * time.Minute, ExpectContinueTimeout: 1 * time.Second, - TLSHandshakeTimeout: 0, + TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, @@ -41,7 +41,7 @@ func NewHTTPClient(timeout time.Duration) *http.Client { } } -// 获取一个公用的Client +// SharedHttpClient 获取一个公用的Client func SharedHttpClient(timeout time.Duration) *http.Client { timeoutClientLocker.Lock() defer timeoutClientLocker.Unlock()