diff --git a/internal/caches/storage_file.go b/internal/caches/storage_file.go index 484b71f..cbd390f 100644 --- a/internal/caches/storage_file.go +++ b/internal/caches/storage_file.go @@ -803,9 +803,9 @@ func (this *FileStorage) decodeFile(path string) (*Item, error) { // URL if urlSize > 0 { - data := utils.BytePool1024.Get() + data := utils.BytePool1k.Get() result, ok, err := this.readN(fp, data, int(urlSize)) - utils.BytePool1024.Put(data) + utils.BytePool1k.Put(data) if err != nil { return nil, err } @@ -942,7 +942,8 @@ func (this *FileStorage) hotLoop() { size = len(result) / 10 } - var buf = make([]byte, 32*1024) + var buf = utils.BytePool16k.Get() + defer utils.BytePool16k.Put(buf) for _, item := range result[:size] { reader, err := this.openReader(item.Key, false, false) if err != nil { diff --git a/internal/nodes/client_listener.go b/internal/nodes/client_listener.go index 8aa0d1a..b11017e 100644 --- a/internal/nodes/client_listener.go +++ b/internal/nodes/client_listener.go @@ -37,17 +37,12 @@ func (this *ClientListener) IsTLS() bool { func (this *ClientListener) Accept() (net.Conn, error) { // 限制并发连接数 - var isOk = false var limiter = sharedConnectionsLimiter limiter.Ack() - defer func() { - if !isOk { - limiter.Release() - } - }() conn, err := this.rawListener.Accept() if err != nil { + limiter.Release() return nil, err } @@ -62,11 +57,11 @@ func (this *ClientListener) Accept() (net.Conn, error) { } _ = conn.Close() + limiter.Release() return this.Accept() } } - isOk = true return NewClientConn(conn, this.isTLS, this.quickClose, limiter), nil } diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index 03a5843..40a4363 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -28,12 +28,6 @@ import ( // 环境变量 var HOSTNAME, _ = os.Hostname() -// byte pool -var bytePool256b = utils.NewBytePool(20480, 256) -var bytePool1k = utils.NewBytePool(20480, 1024) -var bytePool32k = utils.NewBytePool(20480, 32*1024) -var bytePool128k = utils.NewBytePool(20480, 128*1024) - // errors var errWritingToClient = errors.New("writing to client error") @@ -1303,19 +1297,16 @@ func (this *HTTPRequest) addError(err error) { // 计算合适的buffer size func (this *HTTPRequest) bytePool(contentLength int64) *utils.BytePool { - if contentLength <= 0 { - return bytePool1k - } - if contentLength < 1024 { // 1K - return bytePool256b + if contentLength < 8192 { // 8K + return utils.BytePool1k } if contentLength < 32768 { // 32K - return bytePool1k + return utils.BytePool4k } - if contentLength < 1048576 { // 1M - return bytePool32k + if contentLength < 131072 { // 128K + return utils.BytePool16k } - return bytePool128k + return utils.BytePool32k } // 检查是否可以忽略错误 diff --git a/internal/nodes/http_request_cache.go b/internal/nodes/http_request_cache.go index 715efe1..f2b3261 100644 --- a/internal/nodes/http_request_cache.go +++ b/internal/nodes/http_request_cache.go @@ -194,13 +194,14 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { } // 准备Buffer - buf := bytePool32k.Get() + var pool = this.bytePool(reader.BodySize()) + var buf = pool.Get() defer func() { - bytePool32k.Put(buf) + pool.Put(buf) }() // 读取Header - headerBuf := []byte{} + var headerBuf = []byte{} err = reader.ReadHeader(buf, func(n int) (goNext bool, err error) { headerBuf = append(headerBuf, buf[:n]...) for { diff --git a/internal/nodes/http_request_page.go b/internal/nodes/http_request_page.go index e7abc54..bfe57e4 100644 --- a/internal/nodes/http_request_page.go +++ b/internal/nodes/http_request_page.go @@ -68,11 +68,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) { this.writer.Prepare(stat.Size(), status) this.writer.WriteHeader(status) } - buf := bytePool1k.Get() + buf := utils.BytePool1k.Get() _, err = utils.CopyWithFilter(this.writer, fp, buf, func(p []byte) []byte { return []byte(this.Format(string(p))) }) - bytePool1k.Put(buf) + utils.BytePool1k.Put(buf) if err != nil { if !this.canIgnore(err) { remotelogs.Warn("HTTP_REQUEST_PAGE", "write to client failed: "+err.Error()) diff --git a/internal/nodes/http_request_shutdown.go b/internal/nodes/http_request_shutdown.go index b9d260a..c767138 100644 --- a/internal/nodes/http_request_shutdown.go +++ b/internal/nodes/http_request_shutdown.go @@ -64,11 +64,11 @@ func (this *HTTPRequest) doShutdown() { this.processResponseHeaders(http.StatusOK) this.writer.WriteHeader(http.StatusOK) } - buf := bytePool1k.Get() + buf := utils.BytePool1k.Get() _, err = utils.CopyWithFilter(this.writer, fp, buf, func(p []byte) []byte { return []byte(this.Format(string(p))) }) - bytePool1k.Put(buf) + utils.BytePool1k.Put(buf) if err != nil { if !this.canIgnore(err) { remotelogs.Warn("HTTP_REQUEST_SHUTDOWN", "write to client failed: "+err.Error()) diff --git a/internal/nodes/http_request_utils.go b/internal/nodes/http_request_utils.go index fa44f1b..955de7c 100644 --- a/internal/nodes/http_request_utils.go +++ b/internal/nodes/http_request_utils.go @@ -2,12 +2,10 @@ package nodes import ( "crypto/rand" - "crypto/tls" "fmt" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" "github.com/TeaOSLab/EdgeNode/internal/utils" "io" - "net" "net/http" "strconv" "strings" @@ -153,14 +151,5 @@ func httpRequestNextId() string { } // timestamp + requestId + nodeId - return strconv.FormatInt(unixTime, 10) + strconv.Itoa(int(atomic.AddInt32(&httpRequestId, 1))) + teaconst.NodeIdString -} - -// 检查连接是否为TLS连接 -func httpIsTLSConn(conn net.Conn) bool { - if conn == nil { - return false - } - _, ok := conn.(*tls.Conn) - return ok + return strconv.FormatInt(unixTime, 10) + teaconst.NodeIdString + strconv.Itoa(int(atomic.AddInt32(&httpRequestId, 1))) } diff --git a/internal/nodes/http_request_websocket.go b/internal/nodes/http_request_websocket.go index d235904..68bc58f 100644 --- a/internal/nodes/http_request_websocket.go +++ b/internal/nodes/http_request_websocket.go @@ -3,6 +3,7 @@ package nodes import ( "errors" "github.com/TeaOSLab/EdgeNode/internal/goman" + "github.com/TeaOSLab/EdgeNode/internal/utils" "io" "net/http" "net/url" @@ -66,7 +67,8 @@ func (this *HTTPRequest) doWebsocket() { }() goman.New(func() { - buf := make([]byte, 4*1024) // TODO 使用内存池 + var buf = utils.BytePool4k.Get() + defer utils.BytePool4k.Put(buf) for { n, err := originConn.Read(buf) if n > 0 { diff --git a/internal/nodes/listener_tcp.go b/internal/nodes/listener_tcp.go index 6f43087..089060b 100644 --- a/internal/nodes/listener_tcp.go +++ b/internal/nodes/listener_tcp.go @@ -7,6 +7,7 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/goman" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/stats" + "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/pires/go-proxyproto" "net" "strings" @@ -112,9 +113,9 @@ func (this *TCPListener) handleConn(conn net.Conn) error { // 从源站读取 goman.New(func() { - originBuffer := bytePool32k.Get() + originBuffer := utils.BytePool16k.Get() defer func() { - bytePool32k.Put(originBuffer) + utils.BytePool16k.Put(originBuffer) }() for { n, err := originConn.Read(originBuffer) @@ -138,9 +139,9 @@ func (this *TCPListener) handleConn(conn net.Conn) error { }) // 从客户端读取 - clientBuffer := bytePool32k.Get() + clientBuffer := utils.BytePool16k.Get() defer func() { - bytePool32k.Put(clientBuffer) + utils.BytePool16k.Put(clientBuffer) }() for { n, err := conn.Read(clientBuffer) diff --git a/internal/nodes/listener_udp.go b/internal/nodes/listener_udp.go index 448738c..0f5aec7 100644 --- a/internal/nodes/listener_udp.go +++ b/internal/nodes/listener_udp.go @@ -190,9 +190,9 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyConn *ne } goman.New(func() { - buffer := bytePool32k.Get() + buffer := utils.BytePool4k.Get() defer func() { - bytePool32k.Put(buffer) + utils.BytePool4k.Put(buffer) }() for { diff --git a/internal/utils/byte_pool.go b/internal/utils/byte_pool.go index 60b87d2..7078179 100644 --- a/internal/utils/byte_pool.go +++ b/internal/utils/byte_pool.go @@ -1,16 +1,25 @@ package utils -var BytePool1024 = NewBytePool(20480, 1024) +import ( + "github.com/TeaOSLab/EdgeNode/internal/goman" + "github.com/iwind/TeaGo/Tea" + "time" +) -// pool for get byte slice +var BytePool1k = NewBytePool(20480, 1024) +var BytePool4k = NewBytePool(20480, 4*1024) +var BytePool16k = NewBytePool(40960, 16*1024) +var BytePool32k = NewBytePool(20480, 32*1024) + +// BytePool pool for get byte slice type BytePool struct { - c chan []byte - length int - - lastSize int + c chan []byte + maxSize int + length int + hasNew bool } -// 创建新对象 +// NewBytePool 创建新对象 func NewBytePool(maxSize, length int) *BytePool { if maxSize <= 0 { maxSize = 1024 @@ -18,24 +27,47 @@ func NewBytePool(maxSize, length int) *BytePool { if length <= 0 { length = 128 } - pool := &BytePool{ - c: make(chan []byte, maxSize), - length: length, + var pool = &BytePool{ + c: make(chan []byte, maxSize), + maxSize: maxSize, + length: length, } + + pool.init() + return pool } -// 获取一个新的byte slice +// 初始化 +func (this *BytePool) init() { + var ticker = time.NewTicker(2 * time.Minute) + if Tea.IsTesting() { + ticker = time.NewTicker(5 * time.Second) + } + goman.New(func() { + for range ticker.C { + if this.hasNew { + this.hasNew = false + continue + } + + this.Purge() + } + }) +} + +// Get 获取一个新的byte slice func (this *BytePool) Get() (b []byte) { select { case b = <-this.c: default: b = make([]byte, this.length) + this.hasNew = true } return } -// 放回一个使用过的byte slice +// Put 放回一个使用过的byte slice func (this *BytePool) Put(b []byte) { if cap(b) != this.length { return @@ -47,7 +79,30 @@ func (this *BytePool) Put(b []byte) { } } -// 当前的数量 +// Length 单个字节slice长度 +func (this *BytePool) Length() int { + return this.length +} + +// Size 当前的数量 func (this *BytePool) Size() int { return len(this.c) } + +// Purge 清理 +func (this *BytePool) Purge() { + // 1% + var count = len(this.c) / 100 + if count == 0 { + return + } + +Loop: + for i := 0; i < count; i++ { + select { + case <-this.c: + default: + break Loop + } + } +} diff --git a/internal/utils/byte_pool_test.go b/internal/utils/byte_pool_test.go index 8076065..0d305e6 100644 --- a/internal/utils/byte_pool_test.go +++ b/internal/utils/byte_pool_test.go @@ -27,6 +27,26 @@ func TestNewBytePool(t *testing.T) { a.IsTrue(len(pool.c) == 5) } +func TestBytePool_Memory(t *testing.T) { + var stat1 = &runtime.MemStats{} + runtime.ReadMemStats(stat1) + + var pool = NewBytePool(20480, 32*1024) + for i := 0; i < 20480; i++ { + pool.Put(make([]byte, 32*1024)) + } + + //pool.Purge() + + //time.Sleep(60 * time.Second) + + runtime.GC() + + var stat2 = &runtime.MemStats{} + runtime.ReadMemStats(stat2) + t.Log((stat2.HeapInuse-stat1.HeapInuse)/1024/1024, "MB,", pool.Size(), "slices") +} + func BenchmarkBytePool_Get(b *testing.B) { runtime.GOMAXPROCS(1) diff --git a/internal/waf/action_block.go b/internal/waf/action_block.go index 2add6a5..d3e14c6 100644 --- a/internal/waf/action_block.go +++ b/internal/waf/action_block.go @@ -64,7 +64,6 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque timeout = 60 // 默认封锁60秒 } - SharedIPBlackList.RecordIP(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(timeout), waf.Id, group.Id, set.Id) if writer != nil { @@ -99,8 +98,9 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque } } - buf := make([]byte, 1024) + buf := utils.BytePool1k.Get() _, _ = io.CopyBuffer(writer, resp.Body, buf) + utils.BytePool1k.Put(buf) } else { path := this.URL if !filepath.IsAbs(this.URL) { diff --git a/internal/waf/checkpoints/request_headers_test.go b/internal/waf/checkpoints/request_headers_test.go new file mode 100644 index 0000000..83d29b8 --- /dev/null +++ b/internal/waf/checkpoints/request_headers_test.go @@ -0,0 +1,33 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package checkpoints + +import ( + "net/http" + "runtime" + "sort" + "strings" + "testing" +) + +func BenchmarkRequestHeadersCheckpoint_RequestValue(b *testing.B) { + runtime.GOMAXPROCS(1) + + var header = http.Header{ + "Content-Type": []string{"keep-alive"}, + "User-Agent": []string{"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36"}, + "Accept-Encoding": []string{"gzip, deflate, br"}, + "Referer": []string{"https://goedge.cn/"}, + } + + for i := 0; i < b.N; i++ { + var headers = []string{} + for k, v := range header { + for _, subV := range v { + headers = append(headers, k+": "+subV) + } + } + sort.Strings(headers) + _ = strings.Join(headers, "\n") + } +}