优化连接相关代码

This commit is contained in:
GoEdgeLab
2022-12-21 15:59:07 +08:00
parent a8111c76f6
commit de411e2209
6 changed files with 108 additions and 60 deletions

View File

@@ -1,10 +0,0 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package conns
import "net"
type ConnInfo struct {
Conn net.Conn
CreatedAt int64
}

View File

@@ -4,22 +4,20 @@ package conns
import (
"net"
"sort"
"sync"
"time"
)
var SharedMap = NewMap()
type Map struct {
m map[string]map[int]*ConnInfo // ip => { port => ConnInfo }
m map[string]map[int]net.Conn // ip => { port => Conn }
locker sync.RWMutex
}
func NewMap() *Map {
return &Map{
m: map[string]map[int]*ConnInfo{},
m: map[string]map[int]net.Conn{},
}
}
@@ -35,20 +33,13 @@ func (this *Map) Add(conn net.Conn) {
var ip = tcpAddr.IP.String()
var port = tcpAddr.Port
var connInfo = &ConnInfo{
Conn: conn,
CreatedAt: time.Now().Unix(),
}
this.locker.Lock()
defer this.locker.Unlock()
connMap, ok := this.m[ip]
if !ok {
this.m[ip] = map[int]*ConnInfo{
port: connInfo,
}
this.m[ip] = map[int]net.Conn{port: conn}
} else {
connMap[port] = connInfo
connMap[port] = conn
}
}
@@ -93,8 +84,8 @@ func (this *Map) CloseIPConns(ip string) {
// 复制防止在Close时产生并发冲突
if ok {
for _, connInfo := range connMap {
conns = append(conns, connInfo.Conn)
for _, conn := range connMap {
conns = append(conns, conn)
}
}
@@ -117,22 +108,16 @@ func (this *Map) CloseIPConns(ip string) {
}
}
func (this *Map) AllConns() []*ConnInfo {
func (this *Map) AllConns() []net.Conn {
this.locker.RLock()
defer this.locker.RUnlock()
var result = []*ConnInfo{}
var result = []net.Conn{}
for _, m := range this.m {
for _, connInfo := range m {
result = append(result, connInfo)
}
}
// 按时间排序
sort.Slice(result, func(i, j int) bool {
// 创建时间越大Age越小
return result[i].CreatedAt > result[j].CreatedAt
})
return result
}

View File

@@ -3,6 +3,7 @@
package nodes
import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/conns"
@@ -25,14 +26,19 @@ import (
type ClientConn struct {
BaseClientConn
isTLS bool
hasDeadline bool
hasRead bool
createdAt int64
isTLS bool
hasRead bool
isLO bool // 是否为环路
isInAllowList bool
hasResetSYNFlood bool
lastReadAt int64
lastWriteAt int64
lastErr error
}
func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList bool) net.Conn {
@@ -45,6 +51,7 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList
isTLS: isTLS,
isLO: isLO,
isInAllowList: isInAllowList,
createdAt: time.Now().Unix(),
}
if quickClose {
@@ -59,6 +66,14 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList
}
func (this *ClientConn) Read(b []byte) (n int, err error) {
this.lastReadAt = time.Now().Unix()
defer func() {
if err != nil {
this.lastErr = errors.New("read error: " + err.Error())
}
}()
// 环路直接读取
if this.isLO {
n, err = this.rawConn.Read(b)
@@ -68,25 +83,11 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
return
}
// TLS
// TODO L1 -> L2 时不计算synflood
if this.isTLS {
if !this.hasDeadline {
_ = this.rawConn.SetReadDeadline(time.Now().Add(time.Duration(nodeconfigs.DefaultTLSHandshakeTimeout) * time.Second)) // TODO 握手超时时间可以设置
this.hasDeadline = true
defer func() {
_ = this.rawConn.SetReadDeadline(time.Time{})
}()
}
}
// 开始读取
n, err = this.rawConn.Read(b)
if n > 0 {
atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n))
if !this.hasRead {
this.hasRead = true
}
this.hasRead = true
}
// 检测是否为握手错误
@@ -115,6 +116,14 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
}
func (this *ClientConn) Write(b []byte) (n int, err error) {
this.lastWriteAt = time.Now().Unix()
defer func() {
if err != nil {
this.lastErr = errors.New("write error: " + err.Error())
}
}()
// 设置超时时间
// TODO L2 -> L1 写入时不限制时间
var timeoutSeconds = len(b) / 4096
@@ -136,8 +145,6 @@ func (this *ClientConn) Write(b []byte) (n int, err error) {
// 如果是写入超时,则立即关闭连接
if err != nil && os.IsTimeout(err) {
//logs.Println(this.RemoteAddr(), timeoutSeconds, "seconds", n, "bytes")
// TODO 考虑对多次慢连接的IP做出惩罚
conn, ok := this.rawConn.(LingerConn)
if ok {
@@ -183,6 +190,22 @@ func (this *ClientConn) SetWriteDeadline(t time.Time) error {
return this.rawConn.SetWriteDeadline(t)
}
func (this *ClientConn) CreatedAt() int64 {
return this.createdAt
}
func (this *ClientConn) LastReadAt() int64 {
return this.lastReadAt
}
func (this *ClientConn) LastWriteAt() int64 {
return this.lastWriteAt
}
func (this *ClientConn) LastErr() error {
return this.lastErr
}
func (this *ClientConn) resetSYNFlood() {
ttlcache.SharedCache.Delete("SYN_FLOOD:" + this.RawIP())
}

View File

@@ -36,7 +36,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
return &tls.Config{
Certificates: nil,
GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) {
tlsPolicy, _, err := this.matchSSL(clientInfo.ServerName)
tlsPolicy, _, err := this.matchSSL(this.helloServerName(clientInfo))
if err != nil {
return nil, err
}
@@ -50,7 +50,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
return tlsPolicy.TLSConfig(), nil
},
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
tlsPolicy, cert, err := this.matchSSL(clientInfo.ServerName)
tlsPolicy, cert, err := this.matchSSL(this.helloServerName(clientInfo))
if err != nil {
return nil, err
}
@@ -182,3 +182,18 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
return nil, name
}
// 从Hello信息中获取服务名称
func (this *BaseListener) helloServerName(clientInfo *tls.ClientHelloInfo) string {
var serverName = clientInfo.ServerName
if len(serverName) == 0 {
var localAddr = clientInfo.Conn.LocalAddr()
if localAddr != nil {
tcpAddr, ok := localAddr.(*net.TCPAddr)
if ok {
serverName = tcpAddr.IP.String()
}
}
}
return serverName
}

View File

@@ -51,8 +51,6 @@ func (this *HTTPListener) Serve() error {
switch state {
case http.StateNew:
atomic.AddInt64(&this.countActiveConnections, 1)
case http.StateActive, http.StateIdle, http.StateHijacked:
// Nothing to do
case http.StateClosed:
atomic.AddInt64(&this.countActiveConnections, -1)
}

View File

@@ -881,12 +881,49 @@ func (this *Node) listenSock() error {
case "conns":
var connMaps = []maps.Map{}
var connMap = conns.SharedMap.AllConns()
for _, connInfo := range connMap {
for _, conn := range connMap {
var createdAt int64
var lastReadAt int64
var lastWriteAt int64
var lastErrString = ""
clientConn, ok := conn.(*ClientConn)
if ok {
createdAt = clientConn.CreatedAt()
lastReadAt = clientConn.LastReadAt()
lastWriteAt = clientConn.LastWriteAt()
var lastErr = clientConn.LastErr()
if lastErr != nil {
lastErrString = lastErr.Error()
}
}
var age int64
var lastReadAge int64
var lastWriteAge int64
var currentTime = time.Now().Unix()
if createdAt > 0 {
age = currentTime - createdAt
}
if lastReadAt > 0 {
lastReadAge = currentTime - lastReadAt
}
if lastWriteAt > 0 {
lastWriteAge = currentTime - lastWriteAt
}
connMaps = append(connMaps, maps.Map{
"addr": connInfo.Conn.RemoteAddr().String(),
"age": time.Now().Unix() - connInfo.CreatedAt,
"addr": conn.RemoteAddr().String(),
"age": age,
"readAge": lastReadAge,
"writeAge": lastWriteAge,
"lastErr": lastErrString,
})
}
sort.Slice(connMaps, func(i, j int) bool {
var m1 = connMaps[i]
var m2 = connMaps[j]
return m1.GetInt64("age") < m2.GetInt64("age")
})
_ = cmd.Reply(&gosock.Command{
Params: map[string]interface{}{