diff --git a/internal/conns/info.go b/internal/conns/info.go deleted file mode 100644 index 10f35ed..0000000 --- a/internal/conns/info.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn . - -package conns - -import "net" - -type ConnInfo struct { - Conn net.Conn - CreatedAt int64 -} diff --git a/internal/conns/map.go b/internal/conns/map.go index de657ad..f02f3c8 100644 --- a/internal/conns/map.go +++ b/internal/conns/map.go @@ -4,22 +4,20 @@ package conns import ( "net" - "sort" "sync" - "time" ) var SharedMap = NewMap() type Map struct { - m map[string]map[int]*ConnInfo // ip => { port => ConnInfo } + m map[string]map[int]net.Conn // ip => { port => Conn } locker sync.RWMutex } func NewMap() *Map { return &Map{ - m: map[string]map[int]*ConnInfo{}, + m: map[string]map[int]net.Conn{}, } } @@ -35,20 +33,13 @@ func (this *Map) Add(conn net.Conn) { var ip = tcpAddr.IP.String() var port = tcpAddr.Port - var connInfo = &ConnInfo{ - Conn: conn, - CreatedAt: time.Now().Unix(), - } - this.locker.Lock() defer this.locker.Unlock() connMap, ok := this.m[ip] if !ok { - this.m[ip] = map[int]*ConnInfo{ - port: connInfo, - } + this.m[ip] = map[int]net.Conn{port: conn} } else { - connMap[port] = connInfo + connMap[port] = conn } } @@ -93,8 +84,8 @@ func (this *Map) CloseIPConns(ip string) { // 复制,防止在Close时产生并发冲突 if ok { - for _, connInfo := range connMap { - conns = append(conns, connInfo.Conn) + for _, conn := range connMap { + conns = append(conns, conn) } } @@ -117,22 +108,16 @@ func (this *Map) CloseIPConns(ip string) { } } -func (this *Map) AllConns() []*ConnInfo { +func (this *Map) AllConns() []net.Conn { this.locker.RLock() defer this.locker.RUnlock() - var result = []*ConnInfo{} + var result = []net.Conn{} for _, m := range this.m { for _, connInfo := range m { result = append(result, connInfo) } } - // 按时间排序 - sort.Slice(result, func(i, j int) bool { - // 创建时间越大,Age越小 - return result[i].CreatedAt > result[j].CreatedAt - }) - return result } diff --git a/internal/nodes/client_conn.go b/internal/nodes/client_conn.go index 1c11a34..9916aac 100644 --- a/internal/nodes/client_conn.go +++ b/internal/nodes/client_conn.go @@ -3,6 +3,7 @@ package nodes import ( + "errors" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeNode/internal/conns" @@ -25,14 +26,19 @@ import ( type ClientConn struct { BaseClientConn - isTLS bool - hasDeadline bool - hasRead bool + createdAt int64 + + isTLS bool + hasRead bool isLO bool // 是否为环路 isInAllowList bool hasResetSYNFlood bool + + lastReadAt int64 + lastWriteAt int64 + lastErr error } func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList bool) net.Conn { @@ -45,6 +51,7 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList isTLS: isTLS, isLO: isLO, isInAllowList: isInAllowList, + createdAt: time.Now().Unix(), } if quickClose { @@ -59,6 +66,14 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList } func (this *ClientConn) Read(b []byte) (n int, err error) { + this.lastReadAt = time.Now().Unix() + + defer func() { + if err != nil { + this.lastErr = errors.New("read error: " + err.Error()) + } + }() + // 环路直接读取 if this.isLO { n, err = this.rawConn.Read(b) @@ -68,25 +83,11 @@ func (this *ClientConn) Read(b []byte) (n int, err error) { return } - // TLS - // TODO L1 -> L2 时,不计算synflood - if this.isTLS { - if !this.hasDeadline { - _ = this.rawConn.SetReadDeadline(time.Now().Add(time.Duration(nodeconfigs.DefaultTLSHandshakeTimeout) * time.Second)) // TODO 握手超时时间可以设置 - this.hasDeadline = true - defer func() { - _ = this.rawConn.SetReadDeadline(time.Time{}) - }() - } - } - // 开始读取 n, err = this.rawConn.Read(b) if n > 0 { atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n)) - if !this.hasRead { - this.hasRead = true - } + this.hasRead = true } // 检测是否为握手错误 @@ -115,6 +116,14 @@ func (this *ClientConn) Read(b []byte) (n int, err error) { } func (this *ClientConn) Write(b []byte) (n int, err error) { + this.lastWriteAt = time.Now().Unix() + + defer func() { + if err != nil { + this.lastErr = errors.New("write error: " + err.Error()) + } + }() + // 设置超时时间 // TODO L2 -> L1 写入时不限制时间 var timeoutSeconds = len(b) / 4096 @@ -136,8 +145,6 @@ func (this *ClientConn) Write(b []byte) (n int, err error) { // 如果是写入超时,则立即关闭连接 if err != nil && os.IsTimeout(err) { - //logs.Println(this.RemoteAddr(), timeoutSeconds, "seconds", n, "bytes") - // TODO 考虑对多次慢连接的IP做出惩罚 conn, ok := this.rawConn.(LingerConn) if ok { @@ -183,6 +190,22 @@ func (this *ClientConn) SetWriteDeadline(t time.Time) error { return this.rawConn.SetWriteDeadline(t) } +func (this *ClientConn) CreatedAt() int64 { + return this.createdAt +} + +func (this *ClientConn) LastReadAt() int64 { + return this.lastReadAt +} + +func (this *ClientConn) LastWriteAt() int64 { + return this.lastWriteAt +} + +func (this *ClientConn) LastErr() error { + return this.lastErr +} + func (this *ClientConn) resetSYNFlood() { ttlcache.SharedCache.Delete("SYN_FLOOD:" + this.RawIP()) } diff --git a/internal/nodes/listener_base.go b/internal/nodes/listener_base.go index 3bf8fdf..41ccdec 100644 --- a/internal/nodes/listener_base.go +++ b/internal/nodes/listener_base.go @@ -36,7 +36,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config { return &tls.Config{ Certificates: nil, GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) { - tlsPolicy, _, err := this.matchSSL(clientInfo.ServerName) + tlsPolicy, _, err := this.matchSSL(this.helloServerName(clientInfo)) if err != nil { return nil, err } @@ -50,7 +50,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config { return tlsPolicy.TLSConfig(), nil }, GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) { - tlsPolicy, cert, err := this.matchSSL(clientInfo.ServerName) + tlsPolicy, cert, err := this.matchSSL(this.helloServerName(clientInfo)) if err != nil { return nil, err } @@ -182,3 +182,18 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser return nil, name } + +// 从Hello信息中获取服务名称 +func (this *BaseListener) helloServerName(clientInfo *tls.ClientHelloInfo) string { + var serverName = clientInfo.ServerName + if len(serverName) == 0 { + var localAddr = clientInfo.Conn.LocalAddr() + if localAddr != nil { + tcpAddr, ok := localAddr.(*net.TCPAddr) + if ok { + serverName = tcpAddr.IP.String() + } + } + } + return serverName +} diff --git a/internal/nodes/listener_http.go b/internal/nodes/listener_http.go index a9a3729..7297af5 100644 --- a/internal/nodes/listener_http.go +++ b/internal/nodes/listener_http.go @@ -51,8 +51,6 @@ func (this *HTTPListener) Serve() error { switch state { case http.StateNew: atomic.AddInt64(&this.countActiveConnections, 1) - case http.StateActive, http.StateIdle, http.StateHijacked: - // Nothing to do case http.StateClosed: atomic.AddInt64(&this.countActiveConnections, -1) } diff --git a/internal/nodes/node.go b/internal/nodes/node.go index 775fa7c..28f10b6 100644 --- a/internal/nodes/node.go +++ b/internal/nodes/node.go @@ -881,12 +881,49 @@ func (this *Node) listenSock() error { case "conns": var connMaps = []maps.Map{} var connMap = conns.SharedMap.AllConns() - for _, connInfo := range connMap { + for _, conn := range connMap { + var createdAt int64 + var lastReadAt int64 + var lastWriteAt int64 + var lastErrString = "" + clientConn, ok := conn.(*ClientConn) + if ok { + createdAt = clientConn.CreatedAt() + lastReadAt = clientConn.LastReadAt() + lastWriteAt = clientConn.LastWriteAt() + + var lastErr = clientConn.LastErr() + if lastErr != nil { + lastErrString = lastErr.Error() + } + } + var age int64 + var lastReadAge int64 + var lastWriteAge int64 + var currentTime = time.Now().Unix() + if createdAt > 0 { + age = currentTime - createdAt + } + if lastReadAt > 0 { + lastReadAge = currentTime - lastReadAt + } + if lastWriteAt > 0 { + lastWriteAge = currentTime - lastWriteAt + } + connMaps = append(connMaps, maps.Map{ - "addr": connInfo.Conn.RemoteAddr().String(), - "age": time.Now().Unix() - connInfo.CreatedAt, + "addr": conn.RemoteAddr().String(), + "age": age, + "readAge": lastReadAge, + "writeAge": lastWriteAge, + "lastErr": lastErrString, }) } + sort.Slice(connMaps, func(i, j int) bool { + var m1 = connMaps[i] + var m2 = connMaps[j] + return m1.GetInt64("age") < m2.GetInt64("age") + }) _ = cmd.Reply(&gosock.Command{ Params: map[string]interface{}{