mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-03 23:20:25 +08:00
优化连接相关代码
This commit is contained in:
@@ -1,10 +0,0 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package conns
|
||||
|
||||
import "net"
|
||||
|
||||
type ConnInfo struct {
|
||||
Conn net.Conn
|
||||
CreatedAt int64
|
||||
}
|
||||
@@ -4,22 +4,20 @@ package conns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SharedMap = NewMap()
|
||||
|
||||
type Map struct {
|
||||
m map[string]map[int]*ConnInfo // ip => { port => ConnInfo }
|
||||
m map[string]map[int]net.Conn // ip => { port => Conn }
|
||||
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMap() *Map {
|
||||
return &Map{
|
||||
m: map[string]map[int]*ConnInfo{},
|
||||
m: map[string]map[int]net.Conn{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,20 +33,13 @@ func (this *Map) Add(conn net.Conn) {
|
||||
var ip = tcpAddr.IP.String()
|
||||
var port = tcpAddr.Port
|
||||
|
||||
var connInfo = &ConnInfo{
|
||||
Conn: conn,
|
||||
CreatedAt: time.Now().Unix(),
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
connMap, ok := this.m[ip]
|
||||
if !ok {
|
||||
this.m[ip] = map[int]*ConnInfo{
|
||||
port: connInfo,
|
||||
}
|
||||
this.m[ip] = map[int]net.Conn{port: conn}
|
||||
} else {
|
||||
connMap[port] = connInfo
|
||||
connMap[port] = conn
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,8 +84,8 @@ func (this *Map) CloseIPConns(ip string) {
|
||||
|
||||
// 复制,防止在Close时产生并发冲突
|
||||
if ok {
|
||||
for _, connInfo := range connMap {
|
||||
conns = append(conns, connInfo.Conn)
|
||||
for _, conn := range connMap {
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,22 +108,16 @@ func (this *Map) CloseIPConns(ip string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Map) AllConns() []*ConnInfo {
|
||||
func (this *Map) AllConns() []net.Conn {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
|
||||
var result = []*ConnInfo{}
|
||||
var result = []net.Conn{}
|
||||
for _, m := range this.m {
|
||||
for _, connInfo := range m {
|
||||
result = append(result, connInfo)
|
||||
}
|
||||
}
|
||||
|
||||
// 按时间排序
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
// 创建时间越大,Age越小
|
||||
return result[i].CreatedAt > result[j].CreatedAt
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/conns"
|
||||
@@ -25,14 +26,19 @@ import (
|
||||
type ClientConn struct {
|
||||
BaseClientConn
|
||||
|
||||
isTLS bool
|
||||
hasDeadline bool
|
||||
hasRead bool
|
||||
createdAt int64
|
||||
|
||||
isTLS bool
|
||||
hasRead bool
|
||||
|
||||
isLO bool // 是否为环路
|
||||
isInAllowList bool
|
||||
|
||||
hasResetSYNFlood bool
|
||||
|
||||
lastReadAt int64
|
||||
lastWriteAt int64
|
||||
lastErr error
|
||||
}
|
||||
|
||||
func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList bool) net.Conn {
|
||||
@@ -45,6 +51,7 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList
|
||||
isTLS: isTLS,
|
||||
isLO: isLO,
|
||||
isInAllowList: isInAllowList,
|
||||
createdAt: time.Now().Unix(),
|
||||
}
|
||||
|
||||
if quickClose {
|
||||
@@ -59,6 +66,14 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList
|
||||
}
|
||||
|
||||
func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||
this.lastReadAt = time.Now().Unix()
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
this.lastErr = errors.New("read error: " + err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
// 环路直接读取
|
||||
if this.isLO {
|
||||
n, err = this.rawConn.Read(b)
|
||||
@@ -68,25 +83,11 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// TLS
|
||||
// TODO L1 -> L2 时,不计算synflood
|
||||
if this.isTLS {
|
||||
if !this.hasDeadline {
|
||||
_ = this.rawConn.SetReadDeadline(time.Now().Add(time.Duration(nodeconfigs.DefaultTLSHandshakeTimeout) * time.Second)) // TODO 握手超时时间可以设置
|
||||
this.hasDeadline = true
|
||||
defer func() {
|
||||
_ = this.rawConn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// 开始读取
|
||||
n, err = this.rawConn.Read(b)
|
||||
if n > 0 {
|
||||
atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n))
|
||||
if !this.hasRead {
|
||||
this.hasRead = true
|
||||
}
|
||||
this.hasRead = true
|
||||
}
|
||||
|
||||
// 检测是否为握手错误
|
||||
@@ -115,6 +116,14 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (this *ClientConn) Write(b []byte) (n int, err error) {
|
||||
this.lastWriteAt = time.Now().Unix()
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
this.lastErr = errors.New("write error: " + err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
// 设置超时时间
|
||||
// TODO L2 -> L1 写入时不限制时间
|
||||
var timeoutSeconds = len(b) / 4096
|
||||
@@ -136,8 +145,6 @@ func (this *ClientConn) Write(b []byte) (n int, err error) {
|
||||
|
||||
// 如果是写入超时,则立即关闭连接
|
||||
if err != nil && os.IsTimeout(err) {
|
||||
//logs.Println(this.RemoteAddr(), timeoutSeconds, "seconds", n, "bytes")
|
||||
|
||||
// TODO 考虑对多次慢连接的IP做出惩罚
|
||||
conn, ok := this.rawConn.(LingerConn)
|
||||
if ok {
|
||||
@@ -183,6 +190,22 @@ func (this *ClientConn) SetWriteDeadline(t time.Time) error {
|
||||
return this.rawConn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (this *ClientConn) CreatedAt() int64 {
|
||||
return this.createdAt
|
||||
}
|
||||
|
||||
func (this *ClientConn) LastReadAt() int64 {
|
||||
return this.lastReadAt
|
||||
}
|
||||
|
||||
func (this *ClientConn) LastWriteAt() int64 {
|
||||
return this.lastWriteAt
|
||||
}
|
||||
|
||||
func (this *ClientConn) LastErr() error {
|
||||
return this.lastErr
|
||||
}
|
||||
|
||||
func (this *ClientConn) resetSYNFlood() {
|
||||
ttlcache.SharedCache.Delete("SYN_FLOOD:" + this.RawIP())
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
Certificates: nil,
|
||||
GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) {
|
||||
tlsPolicy, _, err := this.matchSSL(clientInfo.ServerName)
|
||||
tlsPolicy, _, err := this.matchSSL(this.helloServerName(clientInfo))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -50,7 +50,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
||||
return tlsPolicy.TLSConfig(), nil
|
||||
},
|
||||
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
|
||||
tlsPolicy, cert, err := this.matchSSL(clientInfo.ServerName)
|
||||
tlsPolicy, cert, err := this.matchSSL(this.helloServerName(clientInfo))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -182,3 +182,18 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
|
||||
|
||||
return nil, name
|
||||
}
|
||||
|
||||
// 从Hello信息中获取服务名称
|
||||
func (this *BaseListener) helloServerName(clientInfo *tls.ClientHelloInfo) string {
|
||||
var serverName = clientInfo.ServerName
|
||||
if len(serverName) == 0 {
|
||||
var localAddr = clientInfo.Conn.LocalAddr()
|
||||
if localAddr != nil {
|
||||
tcpAddr, ok := localAddr.(*net.TCPAddr)
|
||||
if ok {
|
||||
serverName = tcpAddr.IP.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
return serverName
|
||||
}
|
||||
|
||||
@@ -51,8 +51,6 @@ func (this *HTTPListener) Serve() error {
|
||||
switch state {
|
||||
case http.StateNew:
|
||||
atomic.AddInt64(&this.countActiveConnections, 1)
|
||||
case http.StateActive, http.StateIdle, http.StateHijacked:
|
||||
// Nothing to do
|
||||
case http.StateClosed:
|
||||
atomic.AddInt64(&this.countActiveConnections, -1)
|
||||
}
|
||||
|
||||
@@ -881,12 +881,49 @@ func (this *Node) listenSock() error {
|
||||
case "conns":
|
||||
var connMaps = []maps.Map{}
|
||||
var connMap = conns.SharedMap.AllConns()
|
||||
for _, connInfo := range connMap {
|
||||
for _, conn := range connMap {
|
||||
var createdAt int64
|
||||
var lastReadAt int64
|
||||
var lastWriteAt int64
|
||||
var lastErrString = ""
|
||||
clientConn, ok := conn.(*ClientConn)
|
||||
if ok {
|
||||
createdAt = clientConn.CreatedAt()
|
||||
lastReadAt = clientConn.LastReadAt()
|
||||
lastWriteAt = clientConn.LastWriteAt()
|
||||
|
||||
var lastErr = clientConn.LastErr()
|
||||
if lastErr != nil {
|
||||
lastErrString = lastErr.Error()
|
||||
}
|
||||
}
|
||||
var age int64
|
||||
var lastReadAge int64
|
||||
var lastWriteAge int64
|
||||
var currentTime = time.Now().Unix()
|
||||
if createdAt > 0 {
|
||||
age = currentTime - createdAt
|
||||
}
|
||||
if lastReadAt > 0 {
|
||||
lastReadAge = currentTime - lastReadAt
|
||||
}
|
||||
if lastWriteAt > 0 {
|
||||
lastWriteAge = currentTime - lastWriteAt
|
||||
}
|
||||
|
||||
connMaps = append(connMaps, maps.Map{
|
||||
"addr": connInfo.Conn.RemoteAddr().String(),
|
||||
"age": time.Now().Unix() - connInfo.CreatedAt,
|
||||
"addr": conn.RemoteAddr().String(),
|
||||
"age": age,
|
||||
"readAge": lastReadAge,
|
||||
"writeAge": lastWriteAge,
|
||||
"lastErr": lastErrString,
|
||||
})
|
||||
}
|
||||
sort.Slice(connMaps, func(i, j int) bool {
|
||||
var m1 = connMaps[i]
|
||||
var m2 = connMaps[j]
|
||||
return m1.GetInt64("age") < m2.GetInt64("age")
|
||||
})
|
||||
|
||||
_ = cmd.Reply(&gosock.Command{
|
||||
Params: map[string]interface{}{
|
||||
|
||||
Reference in New Issue
Block a user