优化连接相关代码

This commit is contained in:
刘祥超
2022-12-21 15:59:07 +08:00
parent 1a200918a8
commit c45f7adf04
6 changed files with 108 additions and 60 deletions

View File

@@ -3,6 +3,7 @@
package nodes
import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/conns"
@@ -25,14 +26,19 @@ import (
type ClientConn struct {
BaseClientConn
isTLS bool
hasDeadline bool
hasRead bool
createdAt int64
isTLS bool
hasRead bool
isLO bool // 是否为环路
isInAllowList bool
hasResetSYNFlood bool
lastReadAt int64
lastWriteAt int64
lastErr error
}
func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList bool) net.Conn {
@@ -45,6 +51,7 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList
isTLS: isTLS,
isLO: isLO,
isInAllowList: isInAllowList,
createdAt: time.Now().Unix(),
}
if quickClose {
@@ -59,6 +66,14 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList
}
func (this *ClientConn) Read(b []byte) (n int, err error) {
this.lastReadAt = time.Now().Unix()
defer func() {
if err != nil {
this.lastErr = errors.New("read error: " + err.Error())
}
}()
// 环路直接读取
if this.isLO {
n, err = this.rawConn.Read(b)
@@ -68,25 +83,11 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
return
}
// TLS
// TODO L1 -> L2 时不计算synflood
if this.isTLS {
if !this.hasDeadline {
_ = this.rawConn.SetReadDeadline(time.Now().Add(time.Duration(nodeconfigs.DefaultTLSHandshakeTimeout) * time.Second)) // TODO 握手超时时间可以设置
this.hasDeadline = true
defer func() {
_ = this.rawConn.SetReadDeadline(time.Time{})
}()
}
}
// 开始读取
n, err = this.rawConn.Read(b)
if n > 0 {
atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n))
if !this.hasRead {
this.hasRead = true
}
this.hasRead = true
}
// 检测是否为握手错误
@@ -115,6 +116,14 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
}
func (this *ClientConn) Write(b []byte) (n int, err error) {
this.lastWriteAt = time.Now().Unix()
defer func() {
if err != nil {
this.lastErr = errors.New("write error: " + err.Error())
}
}()
// 设置超时时间
// TODO L2 -> L1 写入时不限制时间
var timeoutSeconds = len(b) / 4096
@@ -136,8 +145,6 @@ func (this *ClientConn) Write(b []byte) (n int, err error) {
// 如果是写入超时,则立即关闭连接
if err != nil && os.IsTimeout(err) {
//logs.Println(this.RemoteAddr(), timeoutSeconds, "seconds", n, "bytes")
// TODO 考虑对多次慢连接的IP做出惩罚
conn, ok := this.rawConn.(LingerConn)
if ok {
@@ -183,6 +190,22 @@ func (this *ClientConn) SetWriteDeadline(t time.Time) error {
return this.rawConn.SetWriteDeadline(t)
}
func (this *ClientConn) CreatedAt() int64 {
return this.createdAt
}
func (this *ClientConn) LastReadAt() int64 {
return this.lastReadAt
}
func (this *ClientConn) LastWriteAt() int64 {
return this.lastWriteAt
}
func (this *ClientConn) LastErr() error {
return this.lastErr
}
func (this *ClientConn) resetSYNFlood() {
ttlcache.SharedCache.Delete("SYN_FLOOD:" + this.RawIP())
}

View File

@@ -36,7 +36,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
return &tls.Config{
Certificates: nil,
GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) {
tlsPolicy, _, err := this.matchSSL(clientInfo.ServerName)
tlsPolicy, _, err := this.matchSSL(this.helloServerName(clientInfo))
if err != nil {
return nil, err
}
@@ -50,7 +50,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
return tlsPolicy.TLSConfig(), nil
},
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
tlsPolicy, cert, err := this.matchSSL(clientInfo.ServerName)
tlsPolicy, cert, err := this.matchSSL(this.helloServerName(clientInfo))
if err != nil {
return nil, err
}
@@ -182,3 +182,18 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
return nil, name
}
// 从Hello信息中获取服务名称
func (this *BaseListener) helloServerName(clientInfo *tls.ClientHelloInfo) string {
var serverName = clientInfo.ServerName
if len(serverName) == 0 {
var localAddr = clientInfo.Conn.LocalAddr()
if localAddr != nil {
tcpAddr, ok := localAddr.(*net.TCPAddr)
if ok {
serverName = tcpAddr.IP.String()
}
}
}
return serverName
}

View File

@@ -51,8 +51,6 @@ func (this *HTTPListener) Serve() error {
switch state {
case http.StateNew:
atomic.AddInt64(&this.countActiveConnections, 1)
case http.StateActive, http.StateIdle, http.StateHijacked:
// Nothing to do
case http.StateClosed:
atomic.AddInt64(&this.countActiveConnections, -1)
}

View File

@@ -881,12 +881,49 @@ func (this *Node) listenSock() error {
case "conns":
var connMaps = []maps.Map{}
var connMap = conns.SharedMap.AllConns()
for _, connInfo := range connMap {
for _, conn := range connMap {
var createdAt int64
var lastReadAt int64
var lastWriteAt int64
var lastErrString = ""
clientConn, ok := conn.(*ClientConn)
if ok {
createdAt = clientConn.CreatedAt()
lastReadAt = clientConn.LastReadAt()
lastWriteAt = clientConn.LastWriteAt()
var lastErr = clientConn.LastErr()
if lastErr != nil {
lastErrString = lastErr.Error()
}
}
var age int64
var lastReadAge int64
var lastWriteAge int64
var currentTime = time.Now().Unix()
if createdAt > 0 {
age = currentTime - createdAt
}
if lastReadAt > 0 {
lastReadAge = currentTime - lastReadAt
}
if lastWriteAt > 0 {
lastWriteAge = currentTime - lastWriteAt
}
connMaps = append(connMaps, maps.Map{
"addr": connInfo.Conn.RemoteAddr().String(),
"age": time.Now().Unix() - connInfo.CreatedAt,
"addr": conn.RemoteAddr().String(),
"age": age,
"readAge": lastReadAge,
"writeAge": lastWriteAge,
"lastErr": lastErrString,
})
}
sort.Slice(connMaps, func(i, j int) bool {
var m1 = connMaps[i]
var m2 = connMaps[j]
return m1.GetInt64("age") < m2.GetInt64("age")
})
_ = cmd.Reply(&gosock.Command{
Params: map[string]interface{}{