TLS连接增加握手超时检查

This commit is contained in:
GoEdgeLab
2021-12-18 19:17:40 +08:00
parent 32c5639452
commit 413c0851a8
8 changed files with 60 additions and 17 deletions

View File

@@ -17,10 +17,13 @@ type ClientConn struct {
once sync.Once once sync.Once
globalLimiter *ratelimit.Counter globalLimiter *ratelimit.Counter
isTLS bool
hasRead bool
BaseClientConn BaseClientConn
} }
func NewClientConn(conn net.Conn, quickClose bool, globalLimiter *ratelimit.Counter) net.Conn { func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ratelimit.Counter) net.Conn {
if quickClose { if quickClose {
// TCP // TCP
tcpConn, ok := conn.(*net.TCPConn) tcpConn, ok := conn.(*net.TCPConn)
@@ -30,11 +33,22 @@ func NewClientConn(conn net.Conn, quickClose bool, globalLimiter *ratelimit.Coun
} }
} }
return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, globalLimiter: globalLimiter} return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS, globalLimiter: globalLimiter}
} }
func (this *ClientConn) Read(b []byte) (n int, err error) { func (this *ClientConn) Read(b []byte) (n int, err error) {
if this.isTLS {
if !this.hasRead {
_ = this.rawConn.SetReadDeadline(time.Now().Add(5 * time.Second)) // TODO 握手超时时间可以设置
this.hasRead = true
defer func() {
_ = this.rawConn.SetReadDeadline(time.Time{})
}()
}
}
n, err = this.rawConn.Read(b) n, err = this.rawConn.Read(b)
if n > 0 { if n > 0 {
atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n)) atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n))
} }

View File

@@ -16,16 +16,25 @@ var sharedConnectionsLimiter = ratelimit.NewCounter(nodeconfigs.DefaultTCPMaxCon
// ClientListener 客户端网络监听 // ClientListener 客户端网络监听
type ClientListener struct { type ClientListener struct {
rawListener net.Listener rawListener net.Listener
isTLS bool
quickClose bool quickClose bool
} }
func NewClientListener(listener net.Listener, quickClose bool) net.Listener { func NewClientListener1(listener net.Listener, quickClose bool) *ClientListener {
return &ClientListener{ return &ClientListener{
rawListener: listener, rawListener: listener,
quickClose: quickClose, quickClose: quickClose,
} }
} }
func (this *ClientListener) SetIsTLS(isTLS bool) {
this.isTLS = isTLS
}
func (this *ClientListener) IsTLS() bool {
return this.isTLS
}
func (this *ClientListener) Accept() (net.Conn, error) { func (this *ClientListener) Accept() (net.Conn, error) {
// 限制并发连接数 // 限制并发连接数
var isOk = false var isOk = false
@@ -58,7 +67,7 @@ func (this *ClientListener) Accept() (net.Conn, error) {
} }
isOk = true isOk = true
return NewClientConn(conn, this.quickClose, limiter), nil return NewClientConn(conn, this.isTLS, this.quickClose, limiter), nil
} }
func (this *ClientListener) Close() error { func (this *ClientListener) Close() error {

View File

@@ -174,7 +174,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.Origi
MaxConnsPerHost: maxConnections, MaxConnsPerHost: maxConnections,
IdleConnTimeout: idleTimeout, IdleConnTimeout: idleTimeout,
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
TLSHandshakeTimeout: 0, // 不限 TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
Proxy: nil, Proxy: nil,
} }

View File

@@ -2,10 +2,12 @@ package nodes
import ( import (
"crypto/rand" "crypto/rand"
"crypto/tls"
"fmt" "fmt"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const" teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/utils"
"io" "io"
"net"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@@ -153,3 +155,12 @@ func httpRequestNextId() string {
// timestamp + requestId + nodeId // timestamp + requestId + nodeId
return strconv.FormatInt(unixTime, 10) + strconv.Itoa(int(atomic.AddInt32(&httpRequestId, 1))) + teaconst.NodeIdString return strconv.FormatInt(unixTime, 10) + strconv.Itoa(int(atomic.AddInt32(&httpRequestId, 1))) + teaconst.NodeIdString
} }
// 检查连接是否为TLS连接
func httpIsTLSConn(conn net.Conn) bool {
if conn == nil {
return false
}
_, ok := conn.(*tls.Conn)
return ok
}

View File

@@ -56,11 +56,11 @@ func (this *Listener) listenTCP() error {
} }
protocol := this.group.Protocol() protocol := this.group.Protocol()
netListener, err := this.createTCPListener() tcpListener, err := this.createTCPListener()
if err != nil { if err != nil {
return err return err
} }
netListener = NewClientListener(netListener, protocol.IsHTTPFamily() || protocol.IsHTTPSFamily()) var netListener = NewClientListener1(tcpListener, protocol.IsHTTPFamily() || protocol.IsHTTPSFamily())
events.On(events.EventQuit, func() { events.On(events.EventQuit, func() {
remotelogs.Println("LISTENER", "quit "+this.group.FullAddr()) remotelogs.Println("LISTENER", "quit "+this.group.FullAddr())
_ = netListener.Close() _ = netListener.Close()
@@ -73,6 +73,7 @@ func (this *Listener) listenTCP() error {
Listener: netListener, Listener: netListener,
} }
case serverconfigs.ProtocolHTTPS, serverconfigs.ProtocolHTTPS4, serverconfigs.ProtocolHTTPS6: case serverconfigs.ProtocolHTTPS, serverconfigs.ProtocolHTTPS4, serverconfigs.ProtocolHTTPS6:
netListener.SetIsTLS(true)
this.listener = &HTTPListener{ this.listener = &HTTPListener{
BaseListener: BaseListener{Group: this.group}, BaseListener: BaseListener{Group: this.group},
Listener: netListener, Listener: netListener,
@@ -83,6 +84,7 @@ func (this *Listener) listenTCP() error {
Listener: netListener, Listener: netListener,
} }
case serverconfigs.ProtocolTLS, serverconfigs.ProtocolTLS4, serverconfigs.ProtocolTLS6: case serverconfigs.ProtocolTLS, serverconfigs.ProtocolTLS4, serverconfigs.ProtocolTLS6:
netListener.SetIsTLS(true)
this.listener = &TCPListener{ this.listener = &TCPListener{
BaseListener: BaseListener{Group: this.group}, BaseListener: BaseListener{Group: this.group},
Listener: netListener, Listener: netListener,

View File

@@ -47,7 +47,9 @@ func (this *HTTPListener) Serve() error {
this.httpServer = &http.Server{ this.httpServer = &http.Server{
Addr: this.addr, Addr: this.addr,
Handler: this, Handler: this,
ReadTimeout: 1 * time.Hour, // TODO 改成可以配置
ReadHeaderTimeout: 3 * time.Second, // TODO 改成可以配置 ReadHeaderTimeout: 3 * time.Second, // TODO 改成可以配置
WriteTimeout: 1 * time.Hour, // TODO 改成可以配置
IdleTimeout: 75 * time.Second, // TODO 改成可以配置 IdleTimeout: 75 * time.Second, // TODO 改成可以配置
ConnState: func(conn net.Conn, state http.ConnState) { ConnState: func(conn net.Conn, state http.ConnState) {
switch state { switch state {
@@ -60,21 +62,26 @@ func (this *HTTPListener) Serve() error {
metricNewConnMap[conn.RemoteAddr().String()] = zero.New() metricNewConnMap[conn.RemoteAddr().String()] = zero.New()
metricNewConnMapLocker.Unlock() metricNewConnMapLocker.Unlock()
} }
case http.StateActive, http.StateIdle, http.StateHijacked:
// Nothing to do
case http.StateClosed: case http.StateClosed:
atomic.AddInt64(&this.countActiveConnections, -1) atomic.AddInt64(&this.countActiveConnections, -1)
// 移除指标存储连接信息 // 移除指标存储连接信息
// 因为中途配置可能有改变,所以暂时不添加条件
metricNewConnMapLocker.Lock() metricNewConnMapLocker.Lock()
delete(metricNewConnMap, conn.RemoteAddr().String()) delete(metricNewConnMap, conn.RemoteAddr().String())
metricNewConnMapLocker.Unlock() metricNewConnMapLocker.Unlock()
} }
}, },
ConnContext: func(ctx context.Context, c net.Conn) context.Context { ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
tlsConn, ok := c.(*tls.Conn) tlsConn, ok := conn.(*tls.Conn)
if ok { if ok {
c = NewClientTLSConn(tlsConn) conn = NewClientTLSConn(tlsConn)
} }
return context.WithValue(ctx, HTTPConnContextKey, c)
return context.WithValue(ctx, HTTPConnContextKey, conn)
}, },
} }

View File

@@ -134,9 +134,9 @@ func (this *SyncAPINodesTask) testEndpoints(endpoints []string) bool {
return return
} }
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second)
defer func() { defer func() {
cancel() cancelFunc()
}() }()
var conn *grpc.ClientConn var conn *grpc.ClientConn
if u.Scheme == "http" { if u.Scheme == "http" {

View File

@@ -13,7 +13,7 @@ import (
var timeoutClientMap = map[time.Duration]*http.Client{} // timeout => Client var timeoutClientMap = map[time.Duration]*http.Client{} // timeout => Client
var timeoutClientLocker = sync.Mutex{} var timeoutClientLocker = sync.Mutex{}
// 导出响应 // DumpResponse 导出响应
func DumpResponse(resp *http.Response) (header []byte, body []byte, err error) { func DumpResponse(resp *http.Response) (header []byte, body []byte, err error) {
header, err = httputil.DumpResponse(resp, false) header, err = httputil.DumpResponse(resp, false)
if err != nil { if err != nil {
@@ -23,7 +23,7 @@ func DumpResponse(resp *http.Response) (header []byte, body []byte, err error) {
return return
} }
// 获取一个新的Client // NewHTTPClient 获取一个新的Client
func NewHTTPClient(timeout time.Duration) *http.Client { func NewHTTPClient(timeout time.Duration) *http.Client {
return &http.Client{ return &http.Client{
Timeout: timeout, Timeout: timeout,
@@ -33,7 +33,7 @@ func NewHTTPClient(timeout time.Duration) *http.Client {
MaxConnsPerHost: 32, MaxConnsPerHost: 32,
IdleConnTimeout: 2 * time.Minute, IdleConnTimeout: 2 * time.Minute,
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
TLSHandshakeTimeout: 0, TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{ TLSClientConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
}, },
@@ -41,7 +41,7 @@ func NewHTTPClient(timeout time.Duration) *http.Client {
} }
} }
// 获取一个公用的Client // SharedHttpClient 获取一个公用的Client
func SharedHttpClient(timeout time.Duration) *http.Client { func SharedHttpClient(timeout time.Duration) *http.Client {
timeoutClientLocker.Lock() timeoutClientLocker.Lock()
defer timeoutClientLocker.Unlock() defer timeoutClientLocker.Unlock()