mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 07:40:56 +08:00 
			
		
		
		
	优化连接相关代码
This commit is contained in:
		@@ -1,10 +0,0 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
 | 
			
		||||
 | 
			
		||||
package conns
 | 
			
		||||
 | 
			
		||||
import "net"
 | 
			
		||||
 | 
			
		||||
type ConnInfo struct {
 | 
			
		||||
	Conn      net.Conn
 | 
			
		||||
	CreatedAt int64
 | 
			
		||||
}
 | 
			
		||||
@@ -4,22 +4,20 @@ package conns
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var SharedMap = NewMap()
 | 
			
		||||
 | 
			
		||||
type Map struct {
 | 
			
		||||
	m map[string]map[int]*ConnInfo // ip => { port => ConnInfo  }
 | 
			
		||||
	m map[string]map[int]net.Conn // ip => { port => Conn  }
 | 
			
		||||
 | 
			
		||||
	locker sync.RWMutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMap() *Map {
 | 
			
		||||
	return &Map{
 | 
			
		||||
		m: map[string]map[int]*ConnInfo{},
 | 
			
		||||
		m: map[string]map[int]net.Conn{},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -35,20 +33,13 @@ func (this *Map) Add(conn net.Conn) {
 | 
			
		||||
	var ip = tcpAddr.IP.String()
 | 
			
		||||
	var port = tcpAddr.Port
 | 
			
		||||
 | 
			
		||||
	var connInfo = &ConnInfo{
 | 
			
		||||
		Conn:      conn,
 | 
			
		||||
		CreatedAt: time.Now().Unix(),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.locker.Lock()
 | 
			
		||||
	defer this.locker.Unlock()
 | 
			
		||||
	connMap, ok := this.m[ip]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		this.m[ip] = map[int]*ConnInfo{
 | 
			
		||||
			port: connInfo,
 | 
			
		||||
		}
 | 
			
		||||
		this.m[ip] = map[int]net.Conn{port: conn}
 | 
			
		||||
	} else {
 | 
			
		||||
		connMap[port] = connInfo
 | 
			
		||||
		connMap[port] = conn
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -93,8 +84,8 @@ func (this *Map) CloseIPConns(ip string) {
 | 
			
		||||
 | 
			
		||||
	// 复制,防止在Close时产生并发冲突
 | 
			
		||||
	if ok {
 | 
			
		||||
		for _, connInfo := range connMap {
 | 
			
		||||
			conns = append(conns, connInfo.Conn)
 | 
			
		||||
		for _, conn := range connMap {
 | 
			
		||||
			conns = append(conns, conn)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -117,22 +108,16 @@ func (this *Map) CloseIPConns(ip string) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Map) AllConns() []*ConnInfo {
 | 
			
		||||
func (this *Map) AllConns() []net.Conn {
 | 
			
		||||
	this.locker.RLock()
 | 
			
		||||
	defer this.locker.RUnlock()
 | 
			
		||||
 | 
			
		||||
	var result = []*ConnInfo{}
 | 
			
		||||
	var result = []net.Conn{}
 | 
			
		||||
	for _, m := range this.m {
 | 
			
		||||
		for _, connInfo := range m {
 | 
			
		||||
			result = append(result, connInfo)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 按时间排序
 | 
			
		||||
	sort.Slice(result, func(i, j int) bool {
 | 
			
		||||
		// 创建时间越大,Age越小
 | 
			
		||||
		return result[i].CreatedAt > result[j].CreatedAt
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return result
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@
 | 
			
		||||
package nodes
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/conns"
 | 
			
		||||
@@ -25,14 +26,19 @@ import (
 | 
			
		||||
type ClientConn struct {
 | 
			
		||||
	BaseClientConn
 | 
			
		||||
 | 
			
		||||
	createdAt int64
 | 
			
		||||
 | 
			
		||||
	isTLS   bool
 | 
			
		||||
	hasDeadline bool
 | 
			
		||||
	hasRead bool
 | 
			
		||||
 | 
			
		||||
	isLO          bool // 是否为环路
 | 
			
		||||
	isInAllowList bool
 | 
			
		||||
 | 
			
		||||
	hasResetSYNFlood bool
 | 
			
		||||
 | 
			
		||||
	lastReadAt  int64
 | 
			
		||||
	lastWriteAt int64
 | 
			
		||||
	lastErr     error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList bool) net.Conn {
 | 
			
		||||
@@ -45,6 +51,7 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList
 | 
			
		||||
		isTLS:          isTLS,
 | 
			
		||||
		isLO:           isLO,
 | 
			
		||||
		isInAllowList:  isInAllowList,
 | 
			
		||||
		createdAt:      time.Now().Unix(),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if quickClose {
 | 
			
		||||
@@ -59,6 +66,14 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *ClientConn) Read(b []byte) (n int, err error) {
 | 
			
		||||
	this.lastReadAt = time.Now().Unix()
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			this.lastErr = errors.New("read error: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// 环路直接读取
 | 
			
		||||
	if this.isLO {
 | 
			
		||||
		n, err = this.rawConn.Read(b)
 | 
			
		||||
@@ -68,26 +83,12 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TLS
 | 
			
		||||
	// TODO L1 -> L2 时,不计算synflood
 | 
			
		||||
	if this.isTLS {
 | 
			
		||||
		if !this.hasDeadline {
 | 
			
		||||
			_ = this.rawConn.SetReadDeadline(time.Now().Add(time.Duration(nodeconfigs.DefaultTLSHandshakeTimeout) * time.Second)) // TODO 握手超时时间可以设置
 | 
			
		||||
			this.hasDeadline = true
 | 
			
		||||
			defer func() {
 | 
			
		||||
				_ = this.rawConn.SetReadDeadline(time.Time{})
 | 
			
		||||
			}()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 开始读取
 | 
			
		||||
	n, err = this.rawConn.Read(b)
 | 
			
		||||
	if n > 0 {
 | 
			
		||||
		atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n))
 | 
			
		||||
		if !this.hasRead {
 | 
			
		||||
		this.hasRead = true
 | 
			
		||||
	}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检测是否为握手错误
 | 
			
		||||
	var isHandshakeError = err != nil && os.IsTimeout(err) && !this.hasRead
 | 
			
		||||
@@ -115,6 +116,14 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *ClientConn) Write(b []byte) (n int, err error) {
 | 
			
		||||
	this.lastWriteAt = time.Now().Unix()
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			this.lastErr = errors.New("write error: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// 设置超时时间
 | 
			
		||||
	// TODO L2 -> L1 写入时不限制时间
 | 
			
		||||
	var timeoutSeconds = len(b) / 4096
 | 
			
		||||
@@ -136,8 +145,6 @@ func (this *ClientConn) Write(b []byte) (n int, err error) {
 | 
			
		||||
 | 
			
		||||
	// 如果是写入超时,则立即关闭连接
 | 
			
		||||
	if err != nil && os.IsTimeout(err) {
 | 
			
		||||
		//logs.Println(this.RemoteAddr(), timeoutSeconds, "seconds", n, "bytes")
 | 
			
		||||
 | 
			
		||||
		// TODO 考虑对多次慢连接的IP做出惩罚
 | 
			
		||||
		conn, ok := this.rawConn.(LingerConn)
 | 
			
		||||
		if ok {
 | 
			
		||||
@@ -183,6 +190,22 @@ func (this *ClientConn) SetWriteDeadline(t time.Time) error {
 | 
			
		||||
	return this.rawConn.SetWriteDeadline(t)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *ClientConn) CreatedAt() int64 {
 | 
			
		||||
	return this.createdAt
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *ClientConn) LastReadAt() int64 {
 | 
			
		||||
	return this.lastReadAt
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *ClientConn) LastWriteAt() int64 {
 | 
			
		||||
	return this.lastWriteAt
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *ClientConn) LastErr() error {
 | 
			
		||||
	return this.lastErr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *ClientConn) resetSYNFlood() {
 | 
			
		||||
	ttlcache.SharedCache.Delete("SYN_FLOOD:" + this.RawIP())
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -36,7 +36,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
 | 
			
		||||
	return &tls.Config{
 | 
			
		||||
		Certificates: nil,
 | 
			
		||||
		GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) {
 | 
			
		||||
			tlsPolicy, _, err := this.matchSSL(clientInfo.ServerName)
 | 
			
		||||
			tlsPolicy, _, err := this.matchSSL(this.helloServerName(clientInfo))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
@@ -50,7 +50,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
 | 
			
		||||
			return tlsPolicy.TLSConfig(), nil
 | 
			
		||||
		},
 | 
			
		||||
		GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
 | 
			
		||||
			tlsPolicy, cert, err := this.matchSSL(clientInfo.ServerName)
 | 
			
		||||
			tlsPolicy, cert, err := this.matchSSL(this.helloServerName(clientInfo))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
@@ -182,3 +182,18 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
 | 
			
		||||
 | 
			
		||||
	return nil, name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 从Hello信息中获取服务名称
 | 
			
		||||
func (this *BaseListener) helloServerName(clientInfo *tls.ClientHelloInfo) string {
 | 
			
		||||
	var serverName = clientInfo.ServerName
 | 
			
		||||
	if len(serverName) == 0 {
 | 
			
		||||
		var localAddr = clientInfo.Conn.LocalAddr()
 | 
			
		||||
		if localAddr != nil {
 | 
			
		||||
			tcpAddr, ok := localAddr.(*net.TCPAddr)
 | 
			
		||||
			if ok {
 | 
			
		||||
				serverName = tcpAddr.IP.String()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return serverName
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -51,8 +51,6 @@ func (this *HTTPListener) Serve() error {
 | 
			
		||||
			switch state {
 | 
			
		||||
			case http.StateNew:
 | 
			
		||||
				atomic.AddInt64(&this.countActiveConnections, 1)
 | 
			
		||||
			case http.StateActive, http.StateIdle, http.StateHijacked:
 | 
			
		||||
				// Nothing to do
 | 
			
		||||
			case http.StateClosed:
 | 
			
		||||
				atomic.AddInt64(&this.countActiveConnections, -1)
 | 
			
		||||
			}
 | 
			
		||||
 
 | 
			
		||||
@@ -881,12 +881,49 @@ func (this *Node) listenSock() error {
 | 
			
		||||
			case "conns":
 | 
			
		||||
				var connMaps = []maps.Map{}
 | 
			
		||||
				var connMap = conns.SharedMap.AllConns()
 | 
			
		||||
				for _, connInfo := range connMap {
 | 
			
		||||
				for _, conn := range connMap {
 | 
			
		||||
					var createdAt int64
 | 
			
		||||
					var lastReadAt int64
 | 
			
		||||
					var lastWriteAt int64
 | 
			
		||||
					var lastErrString = ""
 | 
			
		||||
					clientConn, ok := conn.(*ClientConn)
 | 
			
		||||
					if ok {
 | 
			
		||||
						createdAt = clientConn.CreatedAt()
 | 
			
		||||
						lastReadAt = clientConn.LastReadAt()
 | 
			
		||||
						lastWriteAt = clientConn.LastWriteAt()
 | 
			
		||||
 | 
			
		||||
						var lastErr = clientConn.LastErr()
 | 
			
		||||
						if lastErr != nil {
 | 
			
		||||
							lastErrString = lastErr.Error()
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
					var age int64
 | 
			
		||||
					var lastReadAge int64
 | 
			
		||||
					var lastWriteAge int64
 | 
			
		||||
					var currentTime = time.Now().Unix()
 | 
			
		||||
					if createdAt > 0 {
 | 
			
		||||
						age = currentTime - createdAt
 | 
			
		||||
					}
 | 
			
		||||
					if lastReadAt > 0 {
 | 
			
		||||
						lastReadAge = currentTime - lastReadAt
 | 
			
		||||
					}
 | 
			
		||||
					if lastWriteAt > 0 {
 | 
			
		||||
						lastWriteAge = currentTime - lastWriteAt
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					connMaps = append(connMaps, maps.Map{
 | 
			
		||||
						"addr": connInfo.Conn.RemoteAddr().String(),
 | 
			
		||||
						"age":  time.Now().Unix() - connInfo.CreatedAt,
 | 
			
		||||
						"addr":     conn.RemoteAddr().String(),
 | 
			
		||||
						"age":      age,
 | 
			
		||||
						"readAge":  lastReadAge,
 | 
			
		||||
						"writeAge": lastWriteAge,
 | 
			
		||||
						"lastErr":  lastErrString,
 | 
			
		||||
					})
 | 
			
		||||
				}
 | 
			
		||||
				sort.Slice(connMaps, func(i, j int) bool {
 | 
			
		||||
					var m1 = connMaps[i]
 | 
			
		||||
					var m2 = connMaps[j]
 | 
			
		||||
					return m1.GetInt64("age") < m2.GetInt64("age")
 | 
			
		||||
				})
 | 
			
		||||
 | 
			
		||||
				_ = cmd.Reply(&gosock.Command{
 | 
			
		||||
					Params: map[string]interface{}{
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user