TLS连接增加握手超时检查

This commit is contained in:
GoEdgeLab
2021-12-18 19:17:40 +08:00
parent 32c5639452
commit 413c0851a8
8 changed files with 60 additions and 17 deletions

View File

@@ -17,10 +17,13 @@ type ClientConn struct {
once sync.Once
globalLimiter *ratelimit.Counter
isTLS bool
hasRead bool
BaseClientConn
}
func NewClientConn(conn net.Conn, quickClose bool, globalLimiter *ratelimit.Counter) net.Conn {
func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ratelimit.Counter) net.Conn {
if quickClose {
// TCP
tcpConn, ok := conn.(*net.TCPConn)
@@ -30,11 +33,22 @@ func NewClientConn(conn net.Conn, quickClose bool, globalLimiter *ratelimit.Coun
}
}
return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, globalLimiter: globalLimiter}
return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS, globalLimiter: globalLimiter}
}
func (this *ClientConn) Read(b []byte) (n int, err error) {
if this.isTLS {
if !this.hasRead {
_ = this.rawConn.SetReadDeadline(time.Now().Add(5 * time.Second)) // TODO 握手超时时间可以设置
this.hasRead = true
defer func() {
_ = this.rawConn.SetReadDeadline(time.Time{})
}()
}
}
n, err = this.rawConn.Read(b)
if n > 0 {
atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n))
}

View File

@@ -16,16 +16,25 @@ var sharedConnectionsLimiter = ratelimit.NewCounter(nodeconfigs.DefaultTCPMaxCon
// ClientListener 客户端网络监听
type ClientListener struct {
rawListener net.Listener
isTLS bool
quickClose bool
}
func NewClientListener(listener net.Listener, quickClose bool) net.Listener {
func NewClientListener1(listener net.Listener, quickClose bool) *ClientListener {
return &ClientListener{
rawListener: listener,
quickClose: quickClose,
}
}
func (this *ClientListener) SetIsTLS(isTLS bool) {
this.isTLS = isTLS
}
func (this *ClientListener) IsTLS() bool {
return this.isTLS
}
func (this *ClientListener) Accept() (net.Conn, error) {
// 限制并发连接数
var isOk = false
@@ -58,7 +67,7 @@ func (this *ClientListener) Accept() (net.Conn, error) {
}
isOk = true
return NewClientConn(conn, this.quickClose, limiter), nil
return NewClientConn(conn, this.isTLS, this.quickClose, limiter), nil
}
func (this *ClientListener) Close() error {

View File

@@ -174,7 +174,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.Origi
MaxConnsPerHost: maxConnections,
IdleConnTimeout: idleTimeout,
ExpectContinueTimeout: 1 * time.Second,
TLSHandshakeTimeout: 0, // 不限
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig,
Proxy: nil,
}

View File

@@ -2,10 +2,12 @@ package nodes
import (
"crypto/rand"
"crypto/tls"
"fmt"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"io"
"net"
"net/http"
"strconv"
"strings"
@@ -153,3 +155,12 @@ func httpRequestNextId() string {
// timestamp + requestId + nodeId
return strconv.FormatInt(unixTime, 10) + strconv.Itoa(int(atomic.AddInt32(&httpRequestId, 1))) + teaconst.NodeIdString
}
// 检查连接是否为TLS连接
func httpIsTLSConn(conn net.Conn) bool {
if conn == nil {
return false
}
_, ok := conn.(*tls.Conn)
return ok
}

View File

@@ -56,11 +56,11 @@ func (this *Listener) listenTCP() error {
}
protocol := this.group.Protocol()
netListener, err := this.createTCPListener()
tcpListener, err := this.createTCPListener()
if err != nil {
return err
}
netListener = NewClientListener(netListener, protocol.IsHTTPFamily() || protocol.IsHTTPSFamily())
var netListener = NewClientListener1(tcpListener, protocol.IsHTTPFamily() || protocol.IsHTTPSFamily())
events.On(events.EventQuit, func() {
remotelogs.Println("LISTENER", "quit "+this.group.FullAddr())
_ = netListener.Close()
@@ -73,6 +73,7 @@ func (this *Listener) listenTCP() error {
Listener: netListener,
}
case serverconfigs.ProtocolHTTPS, serverconfigs.ProtocolHTTPS4, serverconfigs.ProtocolHTTPS6:
netListener.SetIsTLS(true)
this.listener = &HTTPListener{
BaseListener: BaseListener{Group: this.group},
Listener: netListener,
@@ -83,6 +84,7 @@ func (this *Listener) listenTCP() error {
Listener: netListener,
}
case serverconfigs.ProtocolTLS, serverconfigs.ProtocolTLS4, serverconfigs.ProtocolTLS6:
netListener.SetIsTLS(true)
this.listener = &TCPListener{
BaseListener: BaseListener{Group: this.group},
Listener: netListener,

View File

@@ -47,7 +47,9 @@ func (this *HTTPListener) Serve() error {
this.httpServer = &http.Server{
Addr: this.addr,
Handler: this,
ReadTimeout: 1 * time.Hour, // TODO 改成可以配置
ReadHeaderTimeout: 3 * time.Second, // TODO 改成可以配置
WriteTimeout: 1 * time.Hour, // TODO 改成可以配置
IdleTimeout: 75 * time.Second, // TODO 改成可以配置
ConnState: func(conn net.Conn, state http.ConnState) {
switch state {
@@ -60,21 +62,26 @@ func (this *HTTPListener) Serve() error {
metricNewConnMap[conn.RemoteAddr().String()] = zero.New()
metricNewConnMapLocker.Unlock()
}
case http.StateActive, http.StateIdle, http.StateHijacked:
// Nothing to do
case http.StateClosed:
atomic.AddInt64(&this.countActiveConnections, -1)
// 移除指标存储连接信息
// 因为中途配置可能有改变,所以暂时不添加条件
metricNewConnMapLocker.Lock()
delete(metricNewConnMap, conn.RemoteAddr().String())
metricNewConnMapLocker.Unlock()
}
},
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
tlsConn, ok := c.(*tls.Conn)
ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
tlsConn, ok := conn.(*tls.Conn)
if ok {
c = NewClientTLSConn(tlsConn)
conn = NewClientTLSConn(tlsConn)
}
return context.WithValue(ctx, HTTPConnContextKey, c)
return context.WithValue(ctx, HTTPConnContextKey, conn)
},
}

View File

@@ -134,9 +134,9 @@ func (this *SyncAPINodesTask) testEndpoints(endpoints []string) bool {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second)
defer func() {
cancel()
cancelFunc()
}()
var conn *grpc.ClientConn
if u.Scheme == "http" {

View File

@@ -13,7 +13,7 @@ import (
var timeoutClientMap = map[time.Duration]*http.Client{} // timeout => Client
var timeoutClientLocker = sync.Mutex{}
// 导出响应
// DumpResponse 导出响应
func DumpResponse(resp *http.Response) (header []byte, body []byte, err error) {
header, err = httputil.DumpResponse(resp, false)
if err != nil {
@@ -23,7 +23,7 @@ func DumpResponse(resp *http.Response) (header []byte, body []byte, err error) {
return
}
// 获取一个新的Client
// NewHTTPClient 获取一个新的Client
func NewHTTPClient(timeout time.Duration) *http.Client {
return &http.Client{
Timeout: timeout,
@@ -33,7 +33,7 @@ func NewHTTPClient(timeout time.Duration) *http.Client {
MaxConnsPerHost: 32,
IdleConnTimeout: 2 * time.Minute,
ExpectContinueTimeout: 1 * time.Second,
TLSHandshakeTimeout: 0,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
@@ -41,7 +41,7 @@ func NewHTTPClient(timeout time.Duration) *http.Client {
}
}
// 获取一个公用的Client
// SharedHttpClient 获取一个公用的Client
func SharedHttpClient(timeout time.Duration) *http.Client {
timeoutClientLocker.Lock()
defer timeoutClientLocker.Unlock()