mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 16:00:25 +08:00 
			
		
		
		
	实现请求连接数等限制
This commit is contained in:
		@@ -108,6 +108,20 @@ func main() {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
						app.On("conns", func() {
 | 
				
			||||||
 | 
							var sock = gosock.NewTmpSock(teaconst.ProcessName)
 | 
				
			||||||
 | 
							reply, err := sock.Send(&gosock.Command{Code: "conns"})
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								fmt.Println("[ERROR]" + err.Error())
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								resultJSON, err := json.MarshalIndent(reply.Params, "", "  ")
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									fmt.Println("[ERROR]" + err.Error())
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									fmt.Println(string(resultJSON))
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
	app.Run(func() {
 | 
						app.Run(func() {
 | 
				
			||||||
		node := nodes.NewNode()
 | 
							node := nodes.NewNode()
 | 
				
			||||||
		node.Start()
 | 
							node.Start()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,63 +5,32 @@ package nodes
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
 | 
				
			||||||
	teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
 | 
						teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
					 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/goman"
 | 
					 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/monitor"
 | 
					 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
 | 
				
			||||||
	"github.com/iwind/TeaGo/maps"
 | 
					 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"sync/atomic"
 | 
						"sync/atomic"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 发送监控流量
 | 
					 | 
				
			||||||
func init() {
 | 
					 | 
				
			||||||
	events.On(events.EventStart, func() {
 | 
					 | 
				
			||||||
		ticker := time.NewTicker(1 * time.Minute)
 | 
					 | 
				
			||||||
		goman.New(func() {
 | 
					 | 
				
			||||||
			for range ticker.C {
 | 
					 | 
				
			||||||
				// 加入到数据队列中
 | 
					 | 
				
			||||||
				if teaconst.InTrafficBytes > 0 {
 | 
					 | 
				
			||||||
					monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemTrafficIn, maps.Map{
 | 
					 | 
				
			||||||
						"total": teaconst.InTrafficBytes,
 | 
					 | 
				
			||||||
					})
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				if teaconst.OutTrafficBytes > 0 {
 | 
					 | 
				
			||||||
					monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemTrafficOut, maps.Map{
 | 
					 | 
				
			||||||
						"total": teaconst.OutTrafficBytes,
 | 
					 | 
				
			||||||
					})
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				// 重置数据
 | 
					 | 
				
			||||||
				atomic.StoreUint64(&teaconst.InTrafficBytes, 0)
 | 
					 | 
				
			||||||
				atomic.StoreUint64(&teaconst.OutTrafficBytes, 0)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// ClientConn 客户端连接
 | 
					// ClientConn 客户端连接
 | 
				
			||||||
type ClientConn struct {
 | 
					type ClientConn struct {
 | 
				
			||||||
	rawConn  net.Conn
 | 
						once          sync.Once
 | 
				
			||||||
	isClosed bool
 | 
						globalLimiter *ratelimit.Counter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	once    sync.Once
 | 
						BaseClientConn
 | 
				
			||||||
	limiter *ratelimit.Counter
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewClientConn(conn net.Conn, quickClose bool, limiter *ratelimit.Counter) net.Conn {
 | 
					func NewClientConn(conn net.Conn, quickClose bool, globalLimiter *ratelimit.Counter) net.Conn {
 | 
				
			||||||
	if quickClose {
 | 
						if quickClose {
 | 
				
			||||||
		// TCP
 | 
							// TCP
 | 
				
			||||||
		tcpConn, ok := conn.(*net.TCPConn)
 | 
							tcpConn, ok := conn.(*net.TCPConn)
 | 
				
			||||||
		if ok {
 | 
							if ok {
 | 
				
			||||||
			// TODO 可以设置此值
 | 
								// TODO 可以在配置中设置此值
 | 
				
			||||||
			_ = tcpConn.SetLinger(nodeconfigs.DefaultTCPLinger)
 | 
								_ = tcpConn.SetLinger(nodeconfigs.DefaultTCPLinger)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &ClientConn{rawConn: conn, limiter: limiter}
 | 
						return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, globalLimiter: globalLimiter}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *ClientConn) Read(b []byte) (n int, err error) {
 | 
					func (this *ClientConn) Read(b []byte) (n int, err error) {
 | 
				
			||||||
@@ -82,11 +51,17 @@ func (this *ClientConn) Write(b []byte) (n int, err error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func (this *ClientConn) Close() error {
 | 
					func (this *ClientConn) Close() error {
 | 
				
			||||||
	this.isClosed = true
 | 
						this.isClosed = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 全局并发数限制
 | 
				
			||||||
	this.once.Do(func() {
 | 
						this.once.Do(func() {
 | 
				
			||||||
		if this.limiter != nil {
 | 
							if this.globalLimiter != nil {
 | 
				
			||||||
			this.limiter.Release()
 | 
								this.globalLimiter.Release()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 单个服务并发数限制
 | 
				
			||||||
 | 
						sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return this.rawConn.Close()
 | 
						return this.rawConn.Close()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -109,7 +84,3 @@ func (this *ClientConn) SetReadDeadline(t time.Time) error {
 | 
				
			|||||||
func (this *ClientConn) SetWriteDeadline(t time.Time) error {
 | 
					func (this *ClientConn) SetWriteDeadline(t time.Time) error {
 | 
				
			||||||
	return this.rawConn.SetWriteDeadline(t)
 | 
						return this.rawConn.SetWriteDeadline(t)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func (this *ClientConn) IsClosed() bool {
 | 
					 | 
				
			||||||
	return this.isClosed
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										38
									
								
								internal/nodes/client_conn_base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								internal/nodes/client_conn_base.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,38 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "net"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type BaseClientConn struct {
 | 
				
			||||||
 | 
						rawConn net.Conn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						isBound    bool
 | 
				
			||||||
 | 
						serverId   int64
 | 
				
			||||||
 | 
						remoteAddr string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						isClosed bool
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *BaseClientConn) IsClosed() bool {
 | 
				
			||||||
 | 
						return this.isClosed
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// IsBound 是否已绑定服务
 | 
				
			||||||
 | 
					func (this *BaseClientConn) IsBound() bool {
 | 
				
			||||||
 | 
						return this.isBound
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Bind 绑定服务
 | 
				
			||||||
 | 
					func (this *BaseClientConn) Bind(serverId int64, remoteAddr string, maxConnsPerServer int, maxConnsPerIP int) bool {
 | 
				
			||||||
 | 
						if this.isBound {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.isBound = true
 | 
				
			||||||
 | 
						this.serverId = serverId
 | 
				
			||||||
 | 
						this.remoteAddr = remoteAddr
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 检查是否可以连接
 | 
				
			||||||
 | 
						return sharedClientConnLimiter.Add(this.rawConn.RemoteAddr().String(), serverId, remoteAddr, maxConnsPerServer, maxConnsPerIP)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -1,7 +0,0 @@
 | 
				
			|||||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
package nodes
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ClientConnCloser interface {
 | 
					 | 
				
			||||||
	IsClosed() bool
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
							
								
								
									
										14
									
								
								internal/nodes/client_conn_interface.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								internal/nodes/client_conn_interface.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,14 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ClientConnInterface interface {
 | 
				
			||||||
 | 
						// IsClosed 是否已关闭
 | 
				
			||||||
 | 
						IsClosed() bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// IsBound 是否已绑定服务
 | 
				
			||||||
 | 
						IsBound() bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Bind 绑定服务
 | 
				
			||||||
 | 
						Bind(serverId int64, remoteAddr string, maxConnsPerServer int, maxConnsPerIP int) bool
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										130
									
								
								internal/nodes/client_conn_limiter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										130
									
								
								internal/nodes/client_conn_limiter.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,130 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/zero"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var sharedClientConnLimiter = NewClientConnLimiter()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ClientConnRemoteAddr 客户端地址定义
 | 
				
			||||||
 | 
					type ClientConnRemoteAddr struct {
 | 
				
			||||||
 | 
						remoteAddr string
 | 
				
			||||||
 | 
						serverId   int64
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ClientConnLimiter 客户端连接数限制
 | 
				
			||||||
 | 
					type ClientConnLimiter struct {
 | 
				
			||||||
 | 
						remoteAddrMap map[string]*ClientConnRemoteAddr // raw remote addr => remoteAddr
 | 
				
			||||||
 | 
						ipConns       map[string]map[string]zero.Zero  // remoteAddr => { raw remote addr => Zero }
 | 
				
			||||||
 | 
						serverConns   map[int64]map[string]zero.Zero   // serverId => { remoteAddr => Zero }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						locker sync.Mutex
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewClientConnLimiter() *ClientConnLimiter {
 | 
				
			||||||
 | 
						return &ClientConnLimiter{
 | 
				
			||||||
 | 
							remoteAddrMap: map[string]*ClientConnRemoteAddr{},
 | 
				
			||||||
 | 
							ipConns:       map[string]map[string]zero.Zero{},
 | 
				
			||||||
 | 
							serverConns:   map[int64]map[string]zero.Zero{},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Add 添加新连接
 | 
				
			||||||
 | 
					// 返回值为true的时候表示允许添加;否则表示不允许添加
 | 
				
			||||||
 | 
					func (this *ClientConnLimiter) Add(rawRemoteAddr string, serverId int64, remoteAddr string, maxConnsPerServer int, maxConnsPerIP int) bool {
 | 
				
			||||||
 | 
						if maxConnsPerServer <= 0 || maxConnsPerIP <= 0 || len(remoteAddr) == 0 || serverId <= 0 {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						defer this.locker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 检查服务连接数
 | 
				
			||||||
 | 
						var serverMap = this.serverConns[serverId]
 | 
				
			||||||
 | 
						if maxConnsPerServer > 0 {
 | 
				
			||||||
 | 
							if serverMap == nil {
 | 
				
			||||||
 | 
								serverMap = map[string]zero.Zero{}
 | 
				
			||||||
 | 
								this.serverConns[serverId] = serverMap
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if maxConnsPerServer <= len(serverMap) {
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 检查IP连接数
 | 
				
			||||||
 | 
						var ipMap = this.ipConns[remoteAddr]
 | 
				
			||||||
 | 
						if maxConnsPerIP > 0 {
 | 
				
			||||||
 | 
							if ipMap == nil {
 | 
				
			||||||
 | 
								ipMap = map[string]zero.Zero{}
 | 
				
			||||||
 | 
								this.ipConns[remoteAddr] = ipMap
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if maxConnsPerIP > 0 && maxConnsPerIP <= len(ipMap) {
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.remoteAddrMap[rawRemoteAddr] = &ClientConnRemoteAddr{
 | 
				
			||||||
 | 
							remoteAddr: remoteAddr,
 | 
				
			||||||
 | 
							serverId:   serverId,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if maxConnsPerServer > 0 {
 | 
				
			||||||
 | 
							serverMap[rawRemoteAddr] = zero.New()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if maxConnsPerIP > 0 {
 | 
				
			||||||
 | 
							ipMap[rawRemoteAddr] = zero.New()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Remove 删除连接
 | 
				
			||||||
 | 
					func (this *ClientConnLimiter) Remove(rawRemoteAddr string) {
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						defer this.locker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						addr, ok := this.remoteAddrMap[rawRemoteAddr]
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						delete(this.remoteAddrMap, rawRemoteAddr)
 | 
				
			||||||
 | 
						delete(this.ipConns[addr.remoteAddr], rawRemoteAddr)
 | 
				
			||||||
 | 
						delete(this.serverConns[addr.serverId], rawRemoteAddr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(this.ipConns[addr.remoteAddr]) == 0 {
 | 
				
			||||||
 | 
							delete(this.ipConns, addr.remoteAddr)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(this.serverConns[addr.serverId]) == 0 {
 | 
				
			||||||
 | 
							delete(this.serverConns, addr.serverId)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Conns 获取连接信息
 | 
				
			||||||
 | 
					// 用于调试
 | 
				
			||||||
 | 
					func (this *ClientConnLimiter) Conns() (ipConns map[string][]string, serverConns map[int64][]string) {
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						defer this.locker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ipConns = map[string][]string{}    // ip => [addr1, addr2, ...]
 | 
				
			||||||
 | 
						serverConns = map[int64][]string{} // serverId => [addr1, addr2, ...]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for ip, m := range this.ipConns {
 | 
				
			||||||
 | 
							for addr := range m {
 | 
				
			||||||
 | 
								ipConns[ip] = append(ipConns[ip], addr)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for serverId, m := range this.serverConns {
 | 
				
			||||||
 | 
							for addr := range m {
 | 
				
			||||||
 | 
								serverConns[serverId] = append(serverConns[serverId], addr)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										38
									
								
								internal/nodes/client_conn_limiter_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								internal/nodes/client_conn_limiter_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,38 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestClientConnLimiter_Add(t *testing.T) {
 | 
				
			||||||
 | 
						var limiter = NewClientConnLimiter()
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							b := limiter.Add("127.0.0.1:1234", 1, "192.168.1.100", 10, 5)
 | 
				
			||||||
 | 
							t.Log(b)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							b := limiter.Add("127.0.0.1:1235", 1, "192.168.1.100", 10, 5)
 | 
				
			||||||
 | 
							t.Log(b)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							b := limiter.Add("127.0.0.1:1236", 1, "192.168.1.100", 10, 5)
 | 
				
			||||||
 | 
							t.Log(b)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							b := limiter.Add("127.0.0.1:1237", 1, "192.168.1.101", 10, 5)
 | 
				
			||||||
 | 
							t.Log(b)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							b := limiter.Add("127.0.0.1:1238", 1, "192.168.1.100", 5, 5)
 | 
				
			||||||
 | 
							t.Log(b)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						limiter.Remove("127.0.0.1:1238")
 | 
				
			||||||
 | 
						limiter.Remove("127.0.0.1:1239")
 | 
				
			||||||
 | 
						limiter.Remove("127.0.0.1:1237")
 | 
				
			||||||
 | 
						logs.PrintAsJSON(limiter.remoteAddrMap, t)
 | 
				
			||||||
 | 
						logs.PrintAsJSON(limiter.ipConns, t)
 | 
				
			||||||
 | 
						logs.PrintAsJSON(limiter.serverConns, t)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										40
									
								
								internal/nodes/client_conn_traffic.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								internal/nodes/client_conn_traffic.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,40 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
 | 
				
			||||||
 | 
						teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/goman"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/monitor"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
 | 
						"sync/atomic"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 发送监控流量
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						events.On(events.EventStart, func() {
 | 
				
			||||||
 | 
							ticker := time.NewTicker(1 * time.Minute)
 | 
				
			||||||
 | 
							goman.New(func() {
 | 
				
			||||||
 | 
								for range ticker.C {
 | 
				
			||||||
 | 
									// 加入到数据队列中
 | 
				
			||||||
 | 
									if teaconst.InTrafficBytes > 0 {
 | 
				
			||||||
 | 
										monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemTrafficIn, maps.Map{
 | 
				
			||||||
 | 
											"total": teaconst.InTrafficBytes,
 | 
				
			||||||
 | 
										})
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									if teaconst.OutTrafficBytes > 0 {
 | 
				
			||||||
 | 
										monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemTrafficOut, maps.Map{
 | 
				
			||||||
 | 
											"total": teaconst.OutTrafficBytes,
 | 
				
			||||||
 | 
										})
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// 重置数据
 | 
				
			||||||
 | 
									atomic.StoreUint64(&teaconst.InTrafficBytes, 0)
 | 
				
			||||||
 | 
									atomic.StoreUint64(&teaconst.OutTrafficBytes, 0)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -11,7 +11,7 @@ func isClientConnClosed(conn net.Conn) bool {
 | 
				
			|||||||
	if conn == nil {
 | 
						if conn == nil {
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	clientConn, ok := conn.(ClientConnCloser)
 | 
						clientConn, ok := conn.(ClientConnInterface)
 | 
				
			||||||
	if ok {
 | 
						if ok {
 | 
				
			||||||
		return clientConn.IsClosed()
 | 
							return clientConn.IsClosed()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,13 +8,13 @@ import (
 | 
				
			|||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ClientTLSConn TLS连接封装
 | 
				
			||||||
type ClientTLSConn struct {
 | 
					type ClientTLSConn struct {
 | 
				
			||||||
	rawConn  *tls.Conn
 | 
						BaseClientConn
 | 
				
			||||||
	isClosed bool
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewClientTLSConn(conn *tls.Conn) net.Conn {
 | 
					func NewClientTLSConn(conn *tls.Conn) net.Conn {
 | 
				
			||||||
	return &ClientTLSConn{rawConn: conn}
 | 
						return &ClientTLSConn{BaseClientConn{rawConn: conn}}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *ClientTLSConn) Read(b []byte) (n int, err error) {
 | 
					func (this *ClientTLSConn) Read(b []byte) (n int, err error) {
 | 
				
			||||||
@@ -29,6 +29,10 @@ func (this *ClientTLSConn) Write(b []byte) (n int, err error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func (this *ClientTLSConn) Close() error {
 | 
					func (this *ClientTLSConn) Close() error {
 | 
				
			||||||
	this.isClosed = true
 | 
						this.isClosed = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 单个服务并发数限制
 | 
				
			||||||
 | 
						sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return this.rawConn.Close()
 | 
						return this.rawConn.Close()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -51,7 +55,3 @@ func (this *ClientTLSConn) SetReadDeadline(t time.Time) error {
 | 
				
			|||||||
func (this *ClientTLSConn) SetWriteDeadline(t time.Time) error {
 | 
					func (this *ClientTLSConn) SetWriteDeadline(t time.Time) error {
 | 
				
			||||||
	return this.rawConn.SetWriteDeadline(t)
 | 
						return this.rawConn.SetWriteDeadline(t)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func (this *ClientTLSConn) IsClosed() bool {
 | 
					 | 
				
			||||||
	return this.isClosed
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -207,6 +207,14 @@ func (this *HTTPRequest) Do() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// 开始调用
 | 
					// 开始调用
 | 
				
			||||||
func (this *HTTPRequest) doBegin() {
 | 
					func (this *HTTPRequest) doBegin() {
 | 
				
			||||||
 | 
						// 处理request limit
 | 
				
			||||||
 | 
						if this.web.RequestLimit != nil &&
 | 
				
			||||||
 | 
							this.web.RequestLimit.IsOn {
 | 
				
			||||||
 | 
							if this.doRequestLimit() {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 处理requestBody
 | 
						// 处理requestBody
 | 
				
			||||||
	if this.RawReq.ContentLength > 0 &&
 | 
						if this.RawReq.ContentLength > 0 &&
 | 
				
			||||||
		this.web.AccessLogRef != nil &&
 | 
							this.web.AccessLogRef != nil &&
 | 
				
			||||||
@@ -441,6 +449,11 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
 | 
				
			|||||||
		this.web.Auth = web.Auth
 | 
							this.web.Auth = web.Auth
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// request limit
 | 
				
			||||||
 | 
						if web.RequestLimit != nil && (web.RequestLimit.IsPrior || isTop) {
 | 
				
			||||||
 | 
							this.web.RequestLimit = web.RequestLimit
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 重写规则
 | 
						// 重写规则
 | 
				
			||||||
	if len(web.RewriteRefs) > 0 {
 | 
						if len(web.RewriteRefs) > 0 {
 | 
				
			||||||
		for index, ref := range web.RewriteRefs {
 | 
							for index, ref := range web.RewriteRefs {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,6 +15,16 @@ func (this *HTTPRequest) write404() {
 | 
				
			|||||||
	_, _ = this.writer.Write([]byte("404 page not found: '" + this.requestFullURL() + "'" + " (Request Id: " + this.requestId + ")"))
 | 
						_, _ = this.writer.Write([]byte("404 page not found: '" + this.requestFullURL() + "'" + " (Request Id: " + this.requestId + ")"))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *HTTPRequest) writeCode(code int) {
 | 
				
			||||||
 | 
						if this.doPage(code) {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.processResponseHeaders(code)
 | 
				
			||||||
 | 
						this.writer.WriteHeader(code)
 | 
				
			||||||
 | 
						_, _ = this.writer.Write([]byte(types.String(code) + " " + http.StatusText(code) + ": '" + this.requestFullURL() + "'" + " (Request Id: " + this.requestId + ")"))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *HTTPRequest) write50x(err error, statusCode int) {
 | 
					func (this *HTTPRequest) write50x(err error, statusCode int) {
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		this.addError(err)
 | 
							this.addError(err)
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										32
									
								
								internal/nodes/http_request_limit.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								internal/nodes/http_request_limit.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,32 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "net/http"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *HTTPRequest) doRequestLimit() (shouldStop bool) {
 | 
				
			||||||
 | 
						// 检查请求Body尺寸
 | 
				
			||||||
 | 
						// TODO 处理分片提交的内容
 | 
				
			||||||
 | 
						if this.web.RequestLimit.MaxBodyBytes() > 0 &&
 | 
				
			||||||
 | 
							this.RawReq.ContentLength > this.web.RequestLimit.MaxBodyBytes() {
 | 
				
			||||||
 | 
							this.writeCode(http.StatusRequestEntityTooLarge)
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 设置连接相关参数
 | 
				
			||||||
 | 
						if this.web.RequestLimit.MaxConns > 0 || this.web.RequestLimit.MaxConnsPerIP > 0 {
 | 
				
			||||||
 | 
							requestConn := this.RawReq.Context().Value(HTTPConnContextKey)
 | 
				
			||||||
 | 
							if requestConn != nil {
 | 
				
			||||||
 | 
								clientConn, ok := requestConn.(ClientConnInterface)
 | 
				
			||||||
 | 
								if ok && !clientConn.IsBound() {
 | 
				
			||||||
 | 
									if !clientConn.Bind(this.Server.Id, this.requestRemoteAddr(true), this.web.RequestLimit.MaxConns, this.web.RequestLimit.MaxConnsPerIP) {
 | 
				
			||||||
 | 
										this.writeCode(http.StatusTooManyRequests)
 | 
				
			||||||
 | 
										this.closeConn()
 | 
				
			||||||
 | 
										return true
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -76,22 +76,6 @@ func NewHTTPWriter(req *HTTPRequest, httpResponseWriter http.ResponseWriter) *HT
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Reset 重置
 | 
					 | 
				
			||||||
func (this *HTTPWriter) Reset(httpResponseWriter http.ResponseWriter) {
 | 
					 | 
				
			||||||
	this.writer = httpResponseWriter
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	this.compressionConfig = nil
 | 
					 | 
				
			||||||
	this.compressionWriter = nil
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	this.statusCode = 0
 | 
					 | 
				
			||||||
	this.sentBodyBytes = 0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	this.bodyCopying = false
 | 
					 | 
				
			||||||
	this.body = nil
 | 
					 | 
				
			||||||
	this.compressionBodyBuffer = nil
 | 
					 | 
				
			||||||
	this.compressionBodyWriter = nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// SetCompression 设置内容压缩配置
 | 
					// SetCompression 设置内容压缩配置
 | 
				
			||||||
func (this *HTTPWriter) SetCompression(config *serverconfigs.HTTPCompressionConfig) {
 | 
					func (this *HTTPWriter) SetCompression(config *serverconfigs.HTTPCompressionConfig) {
 | 
				
			||||||
	this.compressionConfig = config
 | 
						this.compressionConfig = config
 | 
				
			||||||
@@ -118,6 +102,14 @@ func (this *HTTPWriter) Prepare(size int64, status int) (delayHeaders bool) {
 | 
				
			|||||||
		this.PrepareCompression(size)
 | 
							this.PrepareCompression(size)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 是否限速写入
 | 
				
			||||||
 | 
						if this.req.web != nil &&
 | 
				
			||||||
 | 
							this.req.web.RequestLimit != nil &&
 | 
				
			||||||
 | 
							this.req.web.RequestLimit.IsOn &&
 | 
				
			||||||
 | 
							this.req.web.RequestLimit.OutBandwidthPerConnBytes() > 0 {
 | 
				
			||||||
 | 
							this.writer = NewHTTPRateWriter(this.writer, this.req.web.RequestLimit.OutBandwidthPerConnBytes())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										90
									
								
								internal/nodes/http_writer_rate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								internal/nodes/http_writer_rate.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,90 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bufio"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// HTTPRateWriter 限速写入
 | 
				
			||||||
 | 
					type HTTPRateWriter struct {
 | 
				
			||||||
 | 
						parentWriter http.ResponseWriter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rateBytes int
 | 
				
			||||||
 | 
						lastBytes int
 | 
				
			||||||
 | 
						timeCost  time.Duration
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewHTTPRateWriter(writer http.ResponseWriter, rateBytes int64) http.ResponseWriter {
 | 
				
			||||||
 | 
						return &HTTPRateWriter{
 | 
				
			||||||
 | 
							parentWriter: writer,
 | 
				
			||||||
 | 
							rateBytes:    types.Int(rateBytes),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *HTTPRateWriter) Header() http.Header {
 | 
				
			||||||
 | 
						return this.parentWriter.Header()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *HTTPRateWriter) Write(data []byte) (int, error) {
 | 
				
			||||||
 | 
						if len(data) == 0 {
 | 
				
			||||||
 | 
							return 0, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var left = this.rateBytes - this.lastBytes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if left <= 0 {
 | 
				
			||||||
 | 
							if this.timeCost > 0 && this.timeCost < 1*time.Second {
 | 
				
			||||||
 | 
								time.Sleep(1*time.Second - this.timeCost)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							this.lastBytes = 0
 | 
				
			||||||
 | 
							this.timeCost = 0
 | 
				
			||||||
 | 
							return this.Write(data)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var n = len(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// n <= left
 | 
				
			||||||
 | 
						if n <= left {
 | 
				
			||||||
 | 
							this.lastBytes += n
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							var before = time.Now()
 | 
				
			||||||
 | 
							defer func() {
 | 
				
			||||||
 | 
								this.timeCost += time.Since(before)
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
 | 
							return this.parentWriter.Write(data)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// n > left
 | 
				
			||||||
 | 
						var before = time.Now()
 | 
				
			||||||
 | 
						result, err := this.parentWriter.Write(data[:left])
 | 
				
			||||||
 | 
						this.timeCost += time.Since(before)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return result, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.lastBytes += left
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return this.Write(data[left:])
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *HTTPRateWriter) WriteHeader(statusCode int) {
 | 
				
			||||||
 | 
						this.parentWriter.WriteHeader(statusCode)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Hijack Hijack
 | 
				
			||||||
 | 
					func (this *HTTPRateWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) {
 | 
				
			||||||
 | 
						if this.parentWriter == nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						hijack, ok := this.parentWriter.(http.Hijacker)
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							return hijack.Hijack()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -621,6 +621,15 @@ func (this *Node) listenSock() error {
 | 
				
			|||||||
						"result": result,
 | 
											"result": result,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
 | 
								case "conns":
 | 
				
			||||||
 | 
									ipConns, serverConns := sharedClientConnLimiter.Conns()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									_ = cmd.Reply(&gosock.Command{
 | 
				
			||||||
 | 
										Params: map[string]interface{}{
 | 
				
			||||||
 | 
											"ipConns":     ipConns,
 | 
				
			||||||
 | 
											"serverConns": serverConns,
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									})
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										64
									
								
								internal/utils/expires/id_key_map.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								internal/utils/expires/id_key_map.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,64 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package expires
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "sync"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type IdKeyMap struct {
 | 
				
			||||||
 | 
						idKeys map[int64]string // id => key
 | 
				
			||||||
 | 
						keyIds map[string]int64 // key => id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						locker sync.Mutex
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewIdKeyMap() *IdKeyMap {
 | 
				
			||||||
 | 
						return &IdKeyMap{
 | 
				
			||||||
 | 
							idKeys: map[int64]string{},
 | 
				
			||||||
 | 
							keyIds: map[string]int64{},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *IdKeyMap) Add(id int64, key string) {
 | 
				
			||||||
 | 
						oldKey, ok := this.idKeys[id]
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							delete(this.keyIds, oldKey)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						oldId, ok := this.keyIds[key]
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							delete(this.idKeys, oldId)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.idKeys[id] = key
 | 
				
			||||||
 | 
						this.keyIds[key] = id
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *IdKeyMap) Key(id int64) (key string, ok bool) {
 | 
				
			||||||
 | 
						key, ok = this.idKeys[id]
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *IdKeyMap) Id(key string) (id int64, ok bool) {
 | 
				
			||||||
 | 
						id, ok = this.keyIds[key]
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *IdKeyMap) DeleteId(id int64) {
 | 
				
			||||||
 | 
						key, ok := this.idKeys[id]
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							delete(this.keyIds, key)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						delete(this.idKeys, id)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *IdKeyMap) DeleteKey(key string) {
 | 
				
			||||||
 | 
						id, ok := this.keyIds[key]
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							delete(this.idKeys, id)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						delete(this.keyIds, key)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *IdKeyMap) Len() int {
 | 
				
			||||||
 | 
						return len(this.idKeys)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										46
									
								
								internal/utils/expires/id_key_map_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								internal/utils/expires/id_key_map_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,46 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package expires
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/assert"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestNewIdKeyMap(t *testing.T) {
 | 
				
			||||||
 | 
						var a = assert.NewAssertion(t)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var m = NewIdKeyMap()
 | 
				
			||||||
 | 
						m.Add(1, "1")
 | 
				
			||||||
 | 
						m.Add(1, "2")
 | 
				
			||||||
 | 
						m.Add(100, "100")
 | 
				
			||||||
 | 
						logs.PrintAsJSON(m.idKeys, t)
 | 
				
			||||||
 | 
						logs.PrintAsJSON(m.keyIds, t)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							k, ok := m.Key(1)
 | 
				
			||||||
 | 
							a.IsTrue(ok)
 | 
				
			||||||
 | 
							a.IsTrue(k == "2")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							_, ok := m.Key(2)
 | 
				
			||||||
 | 
							a.IsFalse(ok)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m.DeleteKey("2")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							_, ok := m.Key(1)
 | 
				
			||||||
 | 
							a.IsFalse(ok)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						logs.PrintAsJSON(m.idKeys, t)
 | 
				
			||||||
 | 
						logs.PrintAsJSON(m.keyIds, t)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m.DeleteId(100)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						logs.PrintAsJSON(m.idKeys, t)
 | 
				
			||||||
 | 
						logs.PrintAsJSON(m.keyIds, t)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -27,14 +27,19 @@ func NewList() *List {
 | 
				
			|||||||
	return list
 | 
						return list
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Add 添加条目
 | 
				
			||||||
 | 
					// 如果条目已经存在,则覆盖
 | 
				
			||||||
func (this *List) Add(itemId int64, expiresAt int64) {
 | 
					func (this *List) Add(itemId int64, expiresAt int64) {
 | 
				
			||||||
	this.locker.Lock()
 | 
						this.locker.Lock()
 | 
				
			||||||
	defer this.locker.Unlock()
 | 
						defer this.locker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 是否已经存在
 | 
						// 是否已经存在
 | 
				
			||||||
	_, ok := this.itemsMap[itemId]
 | 
						oldExpiresAt, ok := this.itemsMap[itemId]
 | 
				
			||||||
	if ok {
 | 
						if ok {
 | 
				
			||||||
		this.removeItem(itemId)
 | 
							if oldExpiresAt == expiresAt {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							delete(this.expireMap, oldExpiresAt)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	expireItemMap, ok := this.expireMap[expiresAt]
 | 
						expireItemMap, ok := this.expireMap[expiresAt]
 | 
				
			||||||
@@ -68,8 +73,9 @@ func (this *List) GC(timestamp int64, callback func(itemId int64)) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *List) OnGC(callback func(itemId int64)) {
 | 
					func (this *List) OnGC(callback func(itemId int64)) *List {
 | 
				
			||||||
	this.gcCallback = callback
 | 
						this.gcCallback = callback
 | 
				
			||||||
 | 
						return this
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *List) removeItem(itemId int64) {
 | 
					func (this *List) removeItem(itemId int64) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,9 +1,11 @@
 | 
				
			|||||||
package expires
 | 
					package expires
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/assert"
 | 
				
			||||||
	"github.com/iwind/TeaGo/logs"
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
	timeutil "github.com/iwind/TeaGo/utils/time"
 | 
						timeutil "github.com/iwind/TeaGo/utils/time"
 | 
				
			||||||
	"math"
 | 
						"math"
 | 
				
			||||||
 | 
						"runtime"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -24,12 +26,19 @@ func TestList_Add(t *testing.T) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestList_Add_Overwrite(t *testing.T) {
 | 
					func TestList_Add_Overwrite(t *testing.T) {
 | 
				
			||||||
 | 
						var timestamp = time.Now().Unix()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	list := NewList()
 | 
						list := NewList()
 | 
				
			||||||
	list.Add(1, time.Now().Unix()+1)
 | 
						list.Add(1, timestamp+1)
 | 
				
			||||||
	list.Add(1, time.Now().Unix()+1)
 | 
						list.Add(1, timestamp+1)
 | 
				
			||||||
	list.Add(1, time.Now().Unix()+2)
 | 
						list.Add(1, timestamp+2)
 | 
				
			||||||
	logs.PrintAsJSON(list.expireMap, t)
 | 
						logs.PrintAsJSON(list.expireMap, t)
 | 
				
			||||||
	logs.PrintAsJSON(list.itemsMap, t)
 | 
						logs.PrintAsJSON(list.itemsMap, t)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var a = assert.NewAssertion(t)
 | 
				
			||||||
 | 
						a.IsTrue(len(list.itemsMap) == 1)
 | 
				
			||||||
 | 
						a.IsTrue(len(list.expireMap) == 1)
 | 
				
			||||||
 | 
						a.IsTrue(list.itemsMap[1] == timestamp+2)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestList_Remove(t *testing.T) {
 | 
					func TestList_Remove(t *testing.T) {
 | 
				
			||||||
@@ -77,7 +86,10 @@ func TestList_Start_GC(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func TestList_ManyItems(t *testing.T) {
 | 
					func TestList_ManyItems(t *testing.T) {
 | 
				
			||||||
	list := NewList()
 | 
						list := NewList()
 | 
				
			||||||
	for i := 0; i < 100_000; i++ {
 | 
						for i := 0; i < 1_000; i++ {
 | 
				
			||||||
 | 
							list.Add(int64(i), time.Now().Unix())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for i := 0; i < 1_000; i++ {
 | 
				
			||||||
		list.Add(int64(i), time.Now().Unix()+1)
 | 
							list.Add(int64(i), time.Now().Unix()+1)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -87,35 +99,69 @@ func TestList_ManyItems(t *testing.T) {
 | 
				
			|||||||
		count++
 | 
							count++
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	t.Log("gc", count, "items")
 | 
						t.Log("gc", count, "items")
 | 
				
			||||||
	t.Log(time.Since(now).Seconds()*1000, "ms")
 | 
						t.Log(time.Now().Sub(now))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestList_Map_Performance(t *testing.T) {
 | 
					func TestList_Map_Performance(t *testing.T) {
 | 
				
			||||||
	t.Log("max uint32", math.MaxUint32)
 | 
						t.Log("max uint32", math.MaxUint32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var timestamp = time.Now().Unix()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		m := map[int64]int64{}
 | 
							m := map[int64]int64{}
 | 
				
			||||||
		for i := 0; i < 1_000_000; i++ {
 | 
							for i := 0; i < 1_000_000; i++ {
 | 
				
			||||||
			m[int64(i)] = time.Now().Unix()
 | 
								m[int64(i)] = timestamp
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		now := time.Now()
 | 
							now := time.Now()
 | 
				
			||||||
		for i := 0; i < 100_000; i++ {
 | 
							for i := 0; i < 100_000; i++ {
 | 
				
			||||||
			delete(m, int64(i))
 | 
								delete(m, int64(i))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		t.Log(time.Since(now).Seconds()*1000, "ms")
 | 
							t.Log(time.Now().Sub(now))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							m := map[uint64]int64{}
 | 
				
			||||||
 | 
							for i := 0; i < 1_000_000; i++ {
 | 
				
			||||||
 | 
								m[uint64(i)] = timestamp
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							now := time.Now()
 | 
				
			||||||
 | 
							for i := 0; i < 100_000; i++ {
 | 
				
			||||||
 | 
								delete(m, uint64(i))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							t.Log(time.Now().Sub(now))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		m := map[uint32]int64{}
 | 
							m := map[uint32]int64{}
 | 
				
			||||||
		for i := 0; i < 1_000_000; i++ {
 | 
							for i := 0; i < 1_000_000; i++ {
 | 
				
			||||||
			m[uint32(i)] = time.Now().Unix()
 | 
								m[uint32(i)] = timestamp
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		now := time.Now()
 | 
							now := time.Now()
 | 
				
			||||||
		for i := 0; i < 100_000; i++ {
 | 
							for i := 0; i < 100_000; i++ {
 | 
				
			||||||
			delete(m, uint32(i))
 | 
								delete(m, uint32(i))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		t.Log(time.Since(now).Seconds()*1000, "ms")
 | 
							t.Log(time.Now().Sub(now))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func Benchmark_Map_Uint64(b *testing.B) {
 | 
				
			||||||
 | 
						runtime.GOMAXPROCS(1)
 | 
				
			||||||
 | 
						var timestamp = uint64(time.Now().Unix())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var i uint64
 | 
				
			||||||
 | 
						var count uint64 = 1_000_000
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m := map[uint64]uint64{}
 | 
				
			||||||
 | 
						for i = 0; i < count; i++ {
 | 
				
			||||||
 | 
							m[i] = timestamp
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for n := 0; n < b.N; n++ {
 | 
				
			||||||
 | 
							for i = 0; i < count; i++ {
 | 
				
			||||||
 | 
								_ = m[i]
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user