From 158cb258f6049826c5e4d0bb6626e053e405caa8 Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Tue, 15 Feb 2022 14:55:49 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E5=AF=B9HTTP=E8=AF=B7?= =?UTF-8?q?=E6=B1=82=E7=9A=84=E5=A4=84=E7=90=86=E6=96=B9=E6=B3=95=EF=BC=9A?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E3=80=81=E5=8E=8B=E7=BC=A9=E3=80=81WebP?= =?UTF-8?q?=E3=80=81=E9=99=90=E9=80=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/caches/reader_file.go | 26 +- internal/nodes/http_request.go | 6 +- internal/nodes/http_request_cache.go | 36 +- internal/nodes/http_request_fastcgi.go | 2 +- internal/nodes/http_request_page.go | 8 +- internal/nodes/http_request_reverse_proxy.go | 4 +- internal/nodes/http_request_root.go | 2 +- internal/nodes/http_request_url.go | 4 +- internal/nodes/http_writer.go | 911 +++++++++--------- internal/nodes/http_writer_rate.go | 102 -- .../utils/readers/bytes_counter_reader.go | 26 + internal/utils/readers/filter_reader.go | 34 + internal/utils/readers/filter_reader_test.go | 41 + internal/utils/readers/tee_reader.go | 52 + internal/utils/readers/tee_reader_closer.go | 58 ++ .../utils/writers/bytes_counter_writer.go | 28 + internal/utils/writers/rate_limit_writer.go | 87 ++ .../utils/writers/rate_limit_writer_test.go | 41 + internal/utils/writers/tee_writer_closer.go | 51 + 19 files changed, 903 insertions(+), 616 deletions(-) delete mode 100644 internal/nodes/http_writer_rate.go create mode 100644 internal/utils/readers/bytes_counter_reader.go create mode 100644 internal/utils/readers/filter_reader.go create mode 100644 internal/utils/readers/filter_reader_test.go create mode 100644 internal/utils/readers/tee_reader.go create mode 100644 internal/utils/readers/tee_reader_closer.go create mode 100644 internal/utils/writers/bytes_counter_writer.go create mode 100644 internal/utils/writers/rate_limit_writer.go create mode 100644 internal/utils/writers/rate_limit_writer_test.go create mode 100644 internal/utils/writers/tee_writer_closer.go diff --git a/internal/caches/reader_file.go b/internal/caches/reader_file.go index f4442ae..5317ff8 100644 --- a/internal/caches/reader_file.go +++ b/internal/caches/reader_file.go @@ -278,17 +278,27 @@ func (this *FileReader) Read(buf []byte) (n int, err error) { }() // 直接返回从Header中剩余的 - if this.bodyBufLen > 0 && len(buf) >= this.bodyBufLen { - copy(buf, this.bodyBuf) - isOk = true - n = this.bodyBufLen + if this.bodyBufLen > 0 { + var bufLen = len(buf) + if bufLen < this.bodyBufLen { + this.bodyBufLen -= bufLen + copy(buf, this.bodyBuf[:bufLen]) + this.bodyBuf = this.bodyBuf[bufLen:] - if this.bodySize <= int64(this.bodyBufLen) { - err = io.EOF - return + n = bufLen + } else { + copy(buf, this.bodyBuf) + this.bodyBuf = nil + + if this.bodySize <= int64(this.bodyBufLen) { + err = io.EOF + } + + n = this.bodyBufLen + this.bodyBufLen = 0 } - this.bodyBufLen = 0 + isOk = true return } diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index 44df236..c268ffd 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -315,12 +315,12 @@ func (this *HTTPRequest) doEnd() { // TODO 增加Header统计,考虑从Conn中读取 if this.ReqServer != nil { if this.isCached { - stats.SharedTrafficStatManager.Add(this.ReqServer.Id, this.ReqHost, this.writer.sentBodyBytes, this.writer.sentBodyBytes, 1, 1, 0, 0, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId()) + stats.SharedTrafficStatManager.Add(this.ReqServer.Id, this.ReqHost, this.writer.SentBodyBytes(), this.writer.SentBodyBytes(), 1, 1, 0, 0, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId()) } else { if this.isAttack { - stats.SharedTrafficStatManager.Add(this.ReqServer.Id, this.ReqHost, this.writer.sentBodyBytes, 0, 1, 0, 1, this.writer.sentBodyBytes, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId()) + stats.SharedTrafficStatManager.Add(this.ReqServer.Id, this.ReqHost, this.writer.SentBodyBytes(), 0, 1, 0, 1, this.writer.SentBodyBytes(), this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId()) } else { - stats.SharedTrafficStatManager.Add(this.ReqServer.Id, this.ReqHost, this.writer.sentBodyBytes, 0, 1, 0, 0, 0, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId()) + stats.SharedTrafficStatManager.Add(this.ReqServer.Id, this.ReqHost, this.writer.SentBodyBytes(), 0, 1, 0, 0, 0, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId()) } } } diff --git a/internal/nodes/http_request_cache.go b/internal/nodes/http_request_cache.go index 1d21cf2..b63ee81 100644 --- a/internal/nodes/http_request_cache.go +++ b/internal/nodes/http_request_cache.go @@ -5,7 +5,6 @@ import ( "errors" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeNode/internal/caches" - "github.com/TeaOSLab/EdgeNode/internal/compressions" "github.com/TeaOSLab/EdgeNode/internal/goman" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/rpc" @@ -162,11 +161,15 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { var err error // 是否优先检查WebP + var isWebP = false if this.web.WebP != nil && this.web.WebP.IsOn && this.web.WebP.MatchRequest(filepath.Ext(this.Path()), this.Format) && this.web.WebP.MatchAccept(this.requestHeader("Accept")) { reader, _ = storage.OpenReader(key+webpSuffix, useStale) + if reader != nil { + isWebP = true + } } // 检查正常的文件 @@ -189,8 +192,11 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { return } } + defer func() { - _ = reader.Close() + if !this.writer.DelayRead() { + _ = reader.Close() + } }() if useStale { @@ -257,7 +263,11 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { var eTag = "" var lastModifiedAt = reader.LastModified() if lastModifiedAt > 0 { - eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "\"" + if isWebP { + eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "_webp" + "\"" + } else { + eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "\"" + } respHeader.Del("Etag") respHeader["ETag"] = []string{eTag} } @@ -439,25 +449,11 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { return true } } else { // 没有Range - var body io.Reader = reader - var contentEncoding = this.writer.Header().Get("Content-Encoding") - if len(contentEncoding) > 0 && !httpAcceptEncoding(this.RawReq.Header.Get("Accept-Encoding"), contentEncoding) { - decompressReader, err := compressions.NewReader(body, contentEncoding) - if err == nil { - body = decompressReader - defer func() { - _ = decompressReader.Close() - }() - - this.writer.Header().Del("Content-Encoding") - this.writer.Header().Del("Content-Length") - } - } - - this.writer.PrepareCompression(reader.BodySize()) + var resp = &http.Response{Body: reader} + this.writer.Prepare(resp, reader.BodySize(), reader.Status(), false) this.writer.WriteHeader(reader.Status()) - _, err = io.CopyBuffer(this.writer, body, buf) + _, err = io.CopyBuffer(this.writer, resp.Body, buf) if err == io.EOF { err = nil } diff --git a/internal/nodes/http_request_fastcgi.go b/internal/nodes/http_request_fastcgi.go index 557a92b..dd22f79 100644 --- a/internal/nodes/http_request_fastcgi.go +++ b/internal/nodes/http_request_fastcgi.go @@ -190,7 +190,7 @@ func (this *HTTPRequest) doFastcgi() (shouldStop bool) { this.processResponseHeaders(resp.StatusCode) // 准备 - this.writer.Prepare(resp.ContentLength, resp.StatusCode) + this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true) // 设置响应代码 this.writer.WriteHeader(resp.StatusCode) diff --git a/internal/nodes/http_request_page.go b/internal/nodes/http_request_page.go index bfe57e4..6380788 100644 --- a/internal/nodes/http_request_page.go +++ b/internal/nodes/http_request_page.go @@ -61,11 +61,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) { if page.NewStatus > 0 { // 自定义响应Headers this.processResponseHeaders(page.NewStatus) - this.writer.Prepare(stat.Size(), page.NewStatus) + this.writer.Prepare(nil, stat.Size(), page.NewStatus, true) this.writer.WriteHeader(page.NewStatus) } else { this.processResponseHeaders(status) - this.writer.Prepare(stat.Size(), status) + this.writer.Prepare(nil, stat.Size(), status, true) this.writer.WriteHeader(status) } buf := utils.BytePool1k.Get() @@ -100,11 +100,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) { if page.NewStatus > 0 { // 自定义响应Headers this.processResponseHeaders(page.NewStatus) - this.writer.Prepare(int64(len(content)), page.NewStatus) + this.writer.Prepare(nil, int64(len(content)), page.NewStatus, true) this.writer.WriteHeader(page.NewStatus) } else { this.processResponseHeaders(status) - this.writer.Prepare(int64(len(content)), status) + this.writer.Prepare(nil, int64(len(content)), status, true) this.writer.WriteHeader(status) } diff --git a/internal/nodes/http_request_reverse_proxy.go b/internal/nodes/http_request_reverse_proxy.go index 0b3f745..c4feeba 100644 --- a/internal/nodes/http_request_reverse_proxy.go +++ b/internal/nodes/http_request_reverse_proxy.go @@ -285,10 +285,10 @@ func (this *HTTPRequest) doReverseProxy() { this.processResponseHeaders(resp.StatusCode) // 是否需要刷新 - shouldAutoFlush := this.reverseProxy.AutoFlush || this.RawReq.Header.Get("Accept") == "text/event-stream" + var shouldAutoFlush = this.reverseProxy.AutoFlush || this.RawReq.Header.Get("Accept") == "text/event-stream" // 准备 - delayHeaders := this.writer.Prepare(resp.ContentLength, resp.StatusCode) + var delayHeaders = this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true) // 设置响应代码 if !delayHeaders { diff --git a/internal/nodes/http_request_root.go b/internal/nodes/http_request_root.go index a2bf24b..f22493b 100644 --- a/internal/nodes/http_request_root.go +++ b/internal/nodes/http_request_root.go @@ -302,7 +302,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { this.cacheRef = nil // 不支持缓存 } - this.writer.Prepare(fileSize, http.StatusOK) + this.writer.Prepare(nil, fileSize, http.StatusOK, true) pool := this.bytePool(fileSize) buf := pool.Get() diff --git a/internal/nodes/http_request_url.go b/internal/nodes/http_request_url.go index 3617771..f4a3996 100644 --- a/internal/nodes/http_request_url.go +++ b/internal/nodes/http_request_url.go @@ -54,9 +54,9 @@ func (this *HTTPRequest) doURL(method string, url string, host string, statusCod } this.writer.AddHeaders(resp.Header) if statusCode <= 0 { - this.writer.Prepare(resp.ContentLength, resp.StatusCode) + this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true) } else { - this.writer.Prepare(resp.ContentLength, statusCode) + this.writer.Prepare(resp, resp.ContentLength, statusCode, true) } // 设置响应代码 diff --git a/internal/nodes/http_writer.go b/internal/nodes/http_writer.go index 60521cb..90fb313 100644 --- a/internal/nodes/http_writer.go +++ b/internal/nodes/http_writer.go @@ -5,15 +5,14 @@ package nodes import ( "bufio" "bytes" - "compress/flate" - "compress/gzip" "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeNode/internal/caches" "github.com/TeaOSLab/EdgeNode/internal/compressions" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/utils" - "github.com/andybalholm/brotli" + "github.com/TeaOSLab/EdgeNode/internal/utils/readers" + "github.com/TeaOSLab/EdgeNode/internal/utils/writers" _ "github.com/biessek/golang-ico" "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/types" @@ -26,6 +25,7 @@ import ( _ "image/jpeg" _ "image/png" "io" + "io/ioutil" "net" "net/http" "os" @@ -34,77 +34,75 @@ import ( "sync/atomic" ) -// 限制WebP能够同时使用的Buffer内存使用量 -const webpMaxBufferSize int64 = 1_000_000_000 +// webp相关配置 const webpSuffix = "@GOEDGE_WEBP" +var webpMaxBufferSize int64 = 1_000_000_000 var webpTotalBufferSize int64 = 0 -var webpBufferPool = utils.NewBufferPool(1024) + +func init() { + var systemMemory = utils.SystemMemoryGB() / 8 + if systemMemory > 0 { + webpMaxBufferSize = int64(systemMemory) * 1024 * 1024 * 1024 + } +} // HTTPWriter 响应Writer type HTTPWriter struct { - req *HTTPRequest - writer http.ResponseWriter + req *HTTPRequest + rawWriter http.ResponseWriter + + rawReader io.ReadCloser + delayRead bool + + counterWriter *writers.BytesCounterWriter + writer io.WriteCloser size int64 - webpIsEncoding bool - webpBuffer *bytes.Buffer - webpIsWriting bool - webpOriginContentType string - webpOriginEncoding string // gzip - - compressionConfig *serverconfigs.HTTPCompressionConfig - compressionWriter compressions.Writer - compressionType serverconfigs.HTTPCompressionType - statusCode int sentBodyBytes int64 - bodyCopying bool - body []byte - compressionBodyBuffer *bytes.Buffer // 当使用压缩时使用 - compressionBodyWriter compressions.Writer // 当使用压缩时使用 - - cacheWriter caches.Writer // 缓存写入 - cacheStorage caches.StorageInterface - isOk bool // 是否完全成功 isFinished bool // 是否已完成 + + // WebP + webpIsEncoding bool + webpOriginContentType string + + // Compression + compressionConfig *serverconfigs.HTTPCompressionConfig + + // Cache + cacheStorage caches.StorageInterface + cacheWriter caches.Writer + cacheIsFinished bool } // NewHTTPWriter 包装对象 func NewHTTPWriter(req *HTTPRequest, httpResponseWriter http.ResponseWriter) *HTTPWriter { + var counterWriter = writers.NewBytesCounterWriter(httpResponseWriter) return &HTTPWriter{ - req: req, - writer: httpResponseWriter, + req: req, + rawWriter: httpResponseWriter, + writer: counterWriter, + counterWriter: counterWriter, } } -// SetCompression 设置内容压缩配置 -func (this *HTTPWriter) SetCompression(config *serverconfigs.HTTPCompressionConfig) { - this.compressionConfig = config -} - // Prepare 准备输出 -// 缓存不调用此函数 -func (this *HTTPWriter) Prepare(size int64, status int) (delayHeaders bool) { +func (this *HTTPWriter) Prepare(resp *http.Response, size int64, status int, enableCache bool) (delayHeaders bool) { this.size = size this.statusCode = status - if status == http.StatusOK { - this.prepareWebP(size) + if resp != nil { + this.rawReader = resp.Body - if this.webpIsEncoding { - delayHeaders = true + if enableCache { + this.PrepareCache(resp, size) } - } - - this.prepareCache(size) - - // 在WebP模式下,压缩暂不可用 - if !this.webpIsEncoding { - this.PrepareCompression(size) + this.PrepareWebP(resp, size) + this.PrepareCompression(resp, size) } // 是否限速写入 @@ -112,38 +110,308 @@ func (this *HTTPWriter) Prepare(size int64, status int) (delayHeaders bool) { this.req.web.RequestLimit != nil && this.req.web.RequestLimit.IsOn && this.req.web.RequestLimit.OutBandwidthPerConnBytes() > 0 { - this.writer = NewHTTPRateWriter(this.writer, this.req.web.RequestLimit.OutBandwidthPerConnBytes()) + this.writer = writers.NewRateLimitWriter(this.writer, this.req.web.RequestLimit.OutBandwidthPerConnBytes()) } return } +// PrepareCache 准备缓存 +func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) { + if resp == nil { + return + } + + var cachePolicy = this.req.ReqServer.HTTPCachePolicy + if cachePolicy == nil || !cachePolicy.IsOn { + return + } + + var cacheRef = this.req.cacheRef + if cacheRef == nil || !cacheRef.IsOn { + return + } + + var addStatusHeader = this.req.web != nil && this.req.web.Cache != nil && this.req.web.Cache.AddStatusHeader + + // 不支持Range + if len(this.Header().Get("Content-Range")) > 0 { + this.req.varMapping["cache.status"] = "BYPASS" + if addStatusHeader { + this.Header().Set("X-Cache", "BYPASS, not supported Content-Range") + } + return + } + + // 如果允许 ChunkedEncoding,就无需尺寸的判断,因为此时的 size 为 -1 + if !cacheRef.AllowChunkedEncoding && size < 0 { + this.req.varMapping["cache.status"] = "BYPASS" + if addStatusHeader { + this.Header().Set("X-Cache", "BYPASS, ChunkedEncoding") + } + return + } + if size >= 0 && ((cacheRef.MaxSizeBytes() > 0 && size > cacheRef.MaxSizeBytes()) || + (cachePolicy.MaxSizeBytes() > 0 && size > cachePolicy.MaxSizeBytes()) || (cacheRef.MinSizeBytes() > size)) { + this.req.varMapping["cache.status"] = "BYPASS" + if addStatusHeader { + this.Header().Set("X-Cache", "BYPASS, Content-Length") + } + return + } + + // 检查状态 + if len(cacheRef.Status) > 0 && !lists.ContainsInt(cacheRef.Status, this.StatusCode()) { + this.req.varMapping["cache.status"] = "BYPASS" + if addStatusHeader { + this.Header().Set("X-Cache", "BYPASS, Status: "+types.String(this.StatusCode())) + } + return + } + + // Cache-Control + if len(cacheRef.SkipResponseCacheControlValues) > 0 { + var cacheControl = this.Header().Get("Cache-Control") + if len(cacheControl) > 0 { + values := strings.Split(cacheControl, ",") + for _, value := range values { + if cacheRef.ContainsCacheControl(strings.TrimSpace(value)) { + this.req.varMapping["cache.status"] = "BYPASS" + if addStatusHeader { + this.Header().Set("X-Cache", "BYPASS, Cache-Control: "+cacheControl) + } + return + } + } + } + } + + // Set-Cookie + if cacheRef.SkipResponseSetCookie && len(this.Header().Get("Set-Cookie")) > 0 { + this.req.varMapping["cache.status"] = "BYPASS" + if addStatusHeader { + this.Header().Set("X-Cache", "BYPASS, Set-Cookie") + } + return + } + + // 校验其他条件 + if cacheRef.Conds != nil && cacheRef.Conds.HasResponseConds() && !cacheRef.Conds.MatchResponse(this.req.Format) { + this.req.varMapping["cache.status"] = "BYPASS" + if addStatusHeader { + this.Header().Set("X-Cache", "BYPASS, ResponseConds") + } + return + } + + // 打开缓存写入 + var storage = caches.SharedManager.FindStorageWithPolicy(cachePolicy.Id) + if storage == nil { + this.req.varMapping["cache.status"] = "BYPASS" + if addStatusHeader { + this.Header().Set("X-Cache", "BYPASS, Storage") + } + return + } + + this.req.varMapping["cache.status"] = "UPDATING" + if addStatusHeader { + this.Header().Set("X-Cache", "UPDATING") + } + + this.cacheStorage = storage + life := cacheRef.LifeSeconds() + + if life <= 0 { + life = 60 + } + + // 支持源站设置的max-age + if this.req.web.Cache != nil && this.req.web.Cache.EnableCacheControlMaxAge { + var cacheControl = this.Header().Get("Cache-Control") + var pieces = strings.Split(cacheControl, ";") + for _, piece := range pieces { + var eqIndex = strings.Index(piece, "=") + if eqIndex > 0 && piece[:eqIndex] == "max-age" { + var maxAge = types.Int64(piece[eqIndex+1:]) + if maxAge > 0 { + life = maxAge + } + } + } + } + + var expiredAt = utils.UnixTime() + life + var cacheKey = this.req.cacheKey + cacheWriter, err := storage.OpenWriter(cacheKey, expiredAt, this.StatusCode()) + if err != nil { + if !caches.CanIgnoreErr(err) { + remotelogs.Error("HTTP_WRITER", "write cache failed: "+err.Error()) + } + return + } + this.cacheWriter = cacheWriter + + // 写入Header + for k, v := range this.Header() { + for _, v1 := range v { + _, err = cacheWriter.WriteHeader([]byte(k + ":" + v1 + "\n")) + if err != nil { + remotelogs.Error("HTTP_WRITER", "write cache failed: "+err.Error()) + _ = this.cacheWriter.Discard() + this.cacheWriter = nil + return + } + } + } + + var cacheReader = readers.NewTeeReaderCloser(resp.Body, this.cacheWriter) + resp.Body = cacheReader + this.rawReader = cacheReader + + cacheReader.OnFail(func(err error) { + _ = this.cacheWriter.Discard() + this.cacheWriter = nil + }) + cacheReader.OnEOF(func() { + this.cacheIsFinished = true + }) +} + +// PrepareWebP 准备WebP +func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) { + if resp == nil { + return + } + + var contentType = this.Header().Get("Content-Type") + + if this.req.web != nil && + this.req.web.WebP != nil && + this.req.web.WebP.IsOn && + this.req.web.WebP.MatchResponse(contentType, size, filepath.Ext(this.req.Path()), this.req.Format) && + this.req.web.WebP.MatchAccept(this.req.requestHeader("Accept")) { + // 如果已经是WebP不再重复处理 + // TODO 考虑是否需要很严格的匹配 + if strings.Contains(contentType, "image/webp") { + return + } + + // 检查内存 + if atomic.LoadInt64(&webpTotalBufferSize) >= webpMaxBufferSize { + return + } + + var contentEncoding = resp.Header.Get("Content-Encoding") + switch contentEncoding { + case "gzip", "deflate", "br": + reader, err := compressions.NewReader(resp.Body, contentEncoding) + if err != nil { + return + } + this.Header().Del("Content-Encoding") + this.rawReader = reader + case "": // 空 + default: + return + } + + this.webpOriginContentType = contentType + this.webpIsEncoding = true + resp.Body = ioutil.NopCloser(&bytes.Buffer{}) + this.delayRead = true + + this.Header().Del("Content-Length") + this.Header().Set("Content-Type", "image/webp") + } +} + +// PrepareCompression 准备压缩 +func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) { + if this.compressionConfig == nil || !this.compressionConfig.IsOn || this.compressionConfig.Level <= 0 { + return + } + + // 如果已经有编码则不处理 + var contentEncoding = this.rawWriter.Header().Get("Content-Encoding") + if len(contentEncoding) > 0 && (!this.compressionConfig.DecompressData || !lists.ContainsString([]string{"gzip", "deflate", "br"}, contentEncoding)) { + return + } + + // 尺寸和类型 + if !this.compressionConfig.MatchResponse(this.Header().Get("Content-Type"), size, filepath.Ext(this.req.Path()), this.req.Format) { + return + } + + // 判断Accept是否支持压缩 + compressionType, compressionEncoding, ok := this.compressionConfig.MatchAcceptEncoding(this.req.RawReq.Header.Get("Accept-Encoding")) + if !ok { + return + } + + // 压缩前后如果编码一致,则不处理 + if compressionEncoding == contentEncoding { + return + } + + if len(contentEncoding) > 0 && resp != nil { + if !this.compressionConfig.DecompressData { + return + } + + reader, err := compressions.NewReader(resp.Body, contentEncoding) + if err != nil { + return + } + resp.Body = reader + } + + // compression writer + var err error = nil + compressionWriter, err := compressions.NewWriter(this.writer, compressionType, int(this.compressionConfig.Level)) + if err != nil { + remotelogs.Error("HTTP_WRITER", err.Error()) + return + } + this.writer = compressionWriter + + header := this.rawWriter.Header() + header.Set("Content-Encoding", compressionEncoding) + header.Set("Vary", "Accept-Encoding") + header.Del("Content-Length") +} + +// SetCompression 设置内容压缩配置 +func (this *HTTPWriter) SetCompression(config *serverconfigs.HTTPCompressionConfig) { + this.compressionConfig = config +} + // Raw 包装前的原始的Writer func (this *HTTPWriter) Raw() http.ResponseWriter { - return this.writer + return this.rawWriter } // Header 获取Header func (this *HTTPWriter) Header() http.Header { - if this.writer == nil { + if this.rawWriter == nil { return http.Header{} } - return this.writer.Header() + return this.rawWriter.Header() } // DeleteHeader 删除Header func (this *HTTPWriter) DeleteHeader(name string) { - this.writer.Header().Del(name) + this.rawWriter.Header().Del(name) } // SetHeader 设置Header func (this *HTTPWriter) SetHeader(name string, values []string) { - this.writer.Header()[name] = values + this.rawWriter.Header()[name] = values } // AddHeaders 添加一组Header func (this *HTTPWriter) AddHeaders(header http.Header) { - if this.writer == nil { + if this.rawWriter == nil { return } for key, value := range header { @@ -151,52 +419,17 @@ func (this *HTTPWriter) AddHeaders(header http.Header) { continue } for _, v := range value { - this.writer.Header().Add(key, v) + this.rawWriter.Header().Add(key, v) } } } // Write 写入数据 func (this *HTTPWriter) Write(data []byte) (n int, err error) { - n = len(data) - - if this.writer != nil { - if this.webpIsEncoding && !this.webpIsWriting { - this.webpBuffer.Write(data) - } else { - // 写入压缩 - var n1 int - if this.compressionWriter != nil { - n1, err = this.compressionWriter.Write(data) - } else { - n1, err = this.writer.Write(data) - } - if n1 > 0 { - this.sentBodyBytes += int64(n1) - } - - // 写入缓存 - if this.cacheWriter != nil { - _, err = this.cacheWriter.Write(data) - if err != nil { - _ = this.cacheWriter.Discard() - this.cacheWriter = nil - remotelogs.Error("HTTP_WRITER", "write cache failed: "+err.Error()) - } - } - - if this.bodyCopying { - if this.compressionBodyWriter != nil { - _, err := this.compressionBodyWriter.Write(data) - if err != nil { - remotelogs.Error("HTTP_WRITER", err.Error()) - } - } else { - this.body = append(this.body, data...) - } - } - } + if this.webpIsEncoding { + return } + n, err = this.writer.Write(data) return } @@ -213,8 +446,8 @@ func (this *HTTPWriter) SentBodyBytes() int64 { // WriteHeader 写入状态码 func (this *HTTPWriter) WriteHeader(statusCode int) { - if this.writer != nil { - this.writer.WriteHeader(statusCode) + if this.rawWriter != nil { + this.rawWriter.WriteHeader(statusCode) } this.statusCode = statusCode } @@ -267,24 +500,9 @@ func (this *HTTPWriter) StatusCode() int { return this.statusCode } -// SetBodyCopying 设置拷贝Body数据 -func (this *HTTPWriter) SetBodyCopying(b bool) { - this.bodyCopying = b -} - -// BodyIsCopying 判断是否在拷贝Body数据 -func (this *HTTPWriter) BodyIsCopying() bool { - return this.bodyCopying -} - -// Body 读取拷贝的Body数据 -func (this *HTTPWriter) Body() []byte { - return this.body -} - // HeaderData 读取Header二进制数据 func (this *HTTPWriter) HeaderData() []byte { - if this.writer == nil { + if this.rawWriter == nil { return nil } @@ -311,152 +529,136 @@ func (this *HTTPWriter) SetOk() { // Close 关闭 func (this *HTTPWriter) Close() { + // 处理WebP if this.webpIsEncoding { - defer func() { - atomic.AddInt64(&webpTotalBufferSize, -this.size*32) - webpBufferPool.Put(this.webpBuffer) - }() - } + var webpCacheWriter caches.Writer - // webp writer - if this.isOk && this.webpIsEncoding { - var bufferLen = int64(this.webpBuffer.Len()) - atomic.AddInt64(&webpTotalBufferSize, bufferLen*4) + // 准备WebP Cache + if this.cacheWriter != nil { + var cacheKey = this.cacheWriter.Key() + webpSuffix + + webpCacheWriter, _ = this.cacheStorage.OpenWriter(cacheKey, this.cacheWriter.ExpiredAt(), this.StatusCode()) + if webpCacheWriter != nil { + // 写入Header + for k, v := range this.Header() { + for _, v1 := range v { + _, err := webpCacheWriter.WriteHeader([]byte(k + ":" + v1 + "\n")) + if err != nil { + remotelogs.Error("HTTP_WRITER", "write webp cache failed: "+err.Error()) + _ = webpCacheWriter.Discard() + webpCacheWriter = nil + break + } + } + } + + if webpCacheWriter != nil { + var teeWriter = writers.NewTeeWriterCloser(this.writer, webpCacheWriter) + teeWriter.OnFail(func(err error) { + _ = webpCacheWriter.Discard() + webpCacheWriter = nil + }) + this.writer = teeWriter + } + } + } + + var reader = readers.NewBytesCounterReader(this.rawReader) - // 需要把字节读取出来做备份,防止在image.Decode()过程中丢失 - var imageBytes = this.webpBuffer.Bytes() var imageData image.Image var gifImage *gif.GIF var isGif = strings.Contains(this.webpOriginContentType, "image/gif") - var err error - if this.webpOriginEncoding == "gzip" { - this.Header().Del("Content-Encoding") - var reader *gzip.Reader - reader, err = gzip.NewReader(this.webpBuffer) - if err == nil { - defer func() { - _ = reader.Close() - }() - if isGif { - gifImage, err = gif.DecodeAll(reader) - } else { - imageData, _, err = image.Decode(reader) - } - } - } else if this.webpOriginEncoding == "deflate" { - this.Header().Del("Content-Encoding") - var reader io.ReadCloser - reader = flate.NewReader(this.webpBuffer) - defer func() { - _ = reader.Close() - }() - if isGif { - gifImage, err = gif.DecodeAll(reader) - } else { - imageData, _, err = image.Decode(reader) - } - } else if this.webpOriginEncoding == "br" { - this.Header().Del("Content-Encoding") - var reader *brotli.Reader - reader = brotli.NewReader(this.webpBuffer) - if isGif { - gifImage, err = gif.DecodeAll(reader) - } else { - imageData, _, err = image.Decode(reader) - } + if isGif { + gifImage, err = gif.DecodeAll(reader) } else { - if isGif { - gifImage, err = gif.DecodeAll(this.webpBuffer) - } else { - imageData, _, err = image.Decode(this.webpBuffer) - } + imageData, _, err = image.Decode(reader) } + if err != nil { - this.Header().Set("Content-Type", this.webpOriginContentType) - this.WriteHeader(http.StatusOK) - _, _ = this.writer.Write(imageBytes) + return + } - // 处理缓存 - if this.cacheWriter != nil { - _ = this.cacheWriter.Discard() - } - this.cacheWriter = nil - } else { - var f = types.Float32(this.req.web.WebP.Quality) - if f > 100 { - f = 100 - } - this.webpIsWriting = true + var totalBytes = reader.TotalBytes() + atomic.AddInt64(&webpTotalBufferSize, totalBytes) + defer func() { + atomic.AddInt64(&webpTotalBufferSize, -totalBytes) + }() - if imageData != nil { - err = gowebp.Encode(this, imageData, &gowebp.Options{ - Lossless: false, - Quality: f, - Exact: true, - }) - } else if gifImage != nil { - anim := gowebp.NewWebpAnimation(gifImage.Config.Width, gifImage.Config.Height, gifImage.LoopCount) - anim.WebPAnimEncoderOptions.SetKmin(9) - anim.WebPAnimEncoderOptions.SetKmax(17) - defer anim.ReleaseMemory() - webpConfig := gowebp.NewWebpConfig() - //webpConfig.SetLossless(1) - webpConfig.SetQuality(f) + var f = types.Float32(this.req.web.WebP.Quality) + if f > 100 { + f = 100 + } - timeline := 0 + if imageData != nil { + err = gowebp.Encode(this.writer, imageData, &gowebp.Options{ + Lossless: false, + Quality: f, + Exact: true, + }) + } else if gifImage != nil { + anim := gowebp.NewWebpAnimation(gifImage.Config.Width, gifImage.Config.Height, gifImage.LoopCount) + anim.WebPAnimEncoderOptions.SetKmin(9) + anim.WebPAnimEncoderOptions.SetKmax(17) + defer anim.ReleaseMemory() + webpConfig := gowebp.NewWebpConfig() + //webpConfig.SetLossless(1) + webpConfig.SetQuality(f) - for i, img := range gifImage.Image { - err = anim.AddFrame(img, timeline, webpConfig) - if err != nil { - break - } - timeline += gifImage.Delay[i] * 10 + timeline := 0 + + for i, img := range gifImage.Image { + err = anim.AddFrame(img, timeline, webpConfig) + if err != nil { + break } + timeline += gifImage.Delay[i] * 10 + } + if err == nil { + err = anim.AddFrame(nil, timeline, webpConfig) + if err == nil { - err = anim.AddFrame(nil, timeline, webpConfig) - - if err == nil { - err = anim.Encode(this) - } + err = anim.Encode(this.writer) } } + } + + if err != nil && !this.req.canIgnore(err) { + remotelogs.Error("HTTP_WRITER", "'"+this.req.URL()+"' encode webp failed: "+err.Error()) + } + + if err == nil && webpCacheWriter != nil { + err = webpCacheWriter.Close() if err != nil { - if !this.req.canIgnore(err) { - remotelogs.Error("HTTP_WRITER", "encode webp failed: "+err.Error()) - } - - this.Header().Set("Content-Type", this.webpOriginContentType) - this.WriteHeader(http.StatusOK) - _, _ = this.writer.Write(imageBytes) - - // 处理缓存 - if this.cacheWriter != nil { - _ = this.cacheWriter.Discard() - } - this.cacheWriter = nil + _ = webpCacheWriter.Discard() + } else { + this.cacheStorage.AddToList(&caches.Item{ + Type: webpCacheWriter.ItemType(), + Key: webpCacheWriter.Key(), + ExpiredAt: webpCacheWriter.ExpiredAt(), + StaleAt: webpCacheWriter.ExpiredAt() + int64(this.calculateStaleLife()), + HeaderSize: webpCacheWriter.HeaderSize(), + BodySize: webpCacheWriter.BodySize(), + Host: this.req.ReqHost, + ServerId: this.req.ReqServer.Id, + }) } } - - atomic.AddInt64(&webpTotalBufferSize, -bufferLen*4) - this.webpBuffer.Reset() } - // compression writer - if this.compressionWriter != nil { - if this.bodyCopying && this.compressionBodyWriter != nil { - _ = this.compressionBodyWriter.Close() - this.body = this.compressionBodyBuffer.Bytes() - } - _ = this.compressionWriter.Close() - this.compressionWriter = nil + if this.writer != nil { + _ = this.writer.Close() } - // cache writer + if this.rawReader != nil { + _ = this.rawReader.Close() + } + + // 缓存 if this.cacheWriter != nil { - if this.isOk { + if this.isOk && this.cacheIsFinished { // 对比Content-Length - contentLengthString := this.Header().Get("Content-Length") + var contentLengthString = this.Header().Get("Content-Length") if len(contentLengthString) > 0 { contentLength := types.Int64(contentLengthString) if contentLength != this.cacheWriter.BodySize() { @@ -485,11 +687,13 @@ func (this *HTTPWriter) Close() { _ = this.cacheWriter.Discard() } } + + this.sentBodyBytes = this.counterWriter.TotalBytes() } // Hijack Hijack func (this *HTTPWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) { - hijack, ok := this.writer.(http.Hijacker) + hijack, ok := this.rawWriter.(http.Hijacker) if ok { return hijack.Hijack() } @@ -498,254 +702,15 @@ func (this *HTTPWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err erro // Flush Flush func (this *HTTPWriter) Flush() { - flusher, ok := this.writer.(http.Flusher) + flusher, ok := this.rawWriter.(http.Flusher) if ok { flusher.Flush() } } -// 准备Webp -func (this *HTTPWriter) prepareWebP(size int64) { - if this.req.web != nil && - this.req.web.WebP != nil && - this.req.web.WebP.IsOn && - this.req.web.WebP.MatchResponse(this.Header().Get("Content-Type"), size, filepath.Ext(this.req.Path()), this.req.Format) && - this.req.web.WebP.MatchAccept(this.req.requestHeader("Accept")) && - atomic.LoadInt64(&webpTotalBufferSize) < webpMaxBufferSize { - - var contentEncoding = this.writer.Header().Get("Content-Encoding") - switch contentEncoding { - case "gzip", "deflate", "br": - this.webpOriginEncoding = contentEncoding - case "": // 空 - default: - return - } - - this.webpIsEncoding = true - this.webpOriginContentType = this.Header().Get("Content-Type") - this.webpBuffer = webpBufferPool.Get() - - this.Header().Del("Content-Length") - this.Header().Set("Content-Type", "image/webp") - - atomic.AddInt64(&webpTotalBufferSize, size*32) - } -} - -// PrepareCompression 准备压缩 -func (this *HTTPWriter) PrepareCompression(size int64) { - if this.compressionConfig == nil || !this.compressionConfig.IsOn || this.compressionConfig.Level <= 0 { - return - } - - // 如果已经有编码则不处理 - var contentEncoding = this.writer.Header().Get("Content-Encoding") - if len(contentEncoding) > 0 && (!this.compressionConfig.DecompressData || !lists.ContainsString([]string{"gzip", "deflate", "br"}, contentEncoding)) { - return - } - - // 尺寸和类型 - if !this.compressionConfig.MatchResponse(this.Header().Get("Content-Type"), size, filepath.Ext(this.req.Path()), this.req.Format) { - return - } - - // 判断Accept是否支持压缩 - compressionType, compressionEncoding, ok := this.compressionConfig.MatchAcceptEncoding(this.req.RawReq.Header.Get("Accept-Encoding")) - if !ok { - return - } - - // 压缩前后如果编码一致,则不处理 - if compressionEncoding == contentEncoding { - return - } - - this.compressionType = compressionType - - // compression writer - var err error = nil - this.compressionWriter, err = compressions.NewWriter(this.writer, compressionType, int(this.compressionConfig.Level)) - if err != nil { - remotelogs.Error("HTTP_WRITER", err.Error()) - return - } - - // convert between encodings - if len(contentEncoding) > 0 { - this.compressionWriter, err = compressions.NewEncodingWriter(contentEncoding, this.compressionWriter) - if err != nil { - remotelogs.Error("HTTP_WRITER", err.Error()) - return - } - } - - // body copy - if this.bodyCopying { - this.compressionBodyBuffer = bytes.NewBuffer([]byte{}) - this.compressionBodyWriter, err = compressions.NewWriter(this.compressionBodyBuffer, compressionType, int(this.compressionConfig.Level)) - if err != nil { - remotelogs.Error("HTTP_WRITER", err.Error()) - } - } - - header := this.writer.Header() - header.Set("Content-Encoding", compressionEncoding) - header.Set("Vary", "Accept-Encoding") - header.Del("Content-Length") -} - -// 准备缓存 -func (this *HTTPWriter) prepareCache(size int64) { - if this.writer == nil { - return - } - - cachePolicy := this.req.ReqServer.HTTPCachePolicy - if cachePolicy == nil || !cachePolicy.IsOn { - return - } - - cacheRef := this.req.cacheRef - if cacheRef == nil || !cacheRef.IsOn { - return - } - - var addStatusHeader = this.req.web != nil && this.req.web.Cache != nil && this.req.web.Cache.AddStatusHeader - - // 不支持Range - if len(this.Header().Get("Content-Range")) > 0 { - this.req.varMapping["cache.status"] = "BYPASS" - if addStatusHeader { - this.Header().Set("X-Cache", "BYPASS, not supported Content-Range") - } - return - } - - // 如果允许 ChunkedEncoding,就无需尺寸的判断,因为此时的 size 为 -1 - if !cacheRef.AllowChunkedEncoding && size < 0 { - this.req.varMapping["cache.status"] = "BYPASS" - if addStatusHeader { - this.Header().Set("X-Cache", "BYPASS, ChunkedEncoding") - } - return - } - if size >= 0 && ((cacheRef.MaxSizeBytes() > 0 && size > cacheRef.MaxSizeBytes()) || - (cachePolicy.MaxSizeBytes() > 0 && size > cachePolicy.MaxSizeBytes()) || (cacheRef.MinSizeBytes() > size)) { - this.req.varMapping["cache.status"] = "BYPASS" - if addStatusHeader { - this.Header().Set("X-Cache", "BYPASS, Content-Length") - } - return - } - - // 检查状态 - if len(cacheRef.Status) > 0 && !lists.ContainsInt(cacheRef.Status, this.StatusCode()) { - this.req.varMapping["cache.status"] = "BYPASS" - if addStatusHeader { - this.Header().Set("X-Cache", "BYPASS, Status: "+types.String(this.StatusCode())) - } - return - } - - // Cache-Control - if len(cacheRef.SkipResponseCacheControlValues) > 0 { - cacheControl := this.writer.Header().Get("Cache-Control") - if len(cacheControl) > 0 { - values := strings.Split(cacheControl, ",") - for _, value := range values { - if cacheRef.ContainsCacheControl(strings.TrimSpace(value)) { - this.req.varMapping["cache.status"] = "BYPASS" - if addStatusHeader { - this.Header().Set("X-Cache", "BYPASS, Cache-Control: "+cacheControl) - } - return - } - } - } - } - - // Set-Cookie - if cacheRef.SkipResponseSetCookie && len(this.writer.Header().Get("Set-Cookie")) > 0 { - this.req.varMapping["cache.status"] = "BYPASS" - if addStatusHeader { - this.Header().Set("X-Cache", "BYPASS, Set-Cookie") - } - return - } - - // 校验其他条件 - if cacheRef.Conds != nil && cacheRef.Conds.HasResponseConds() && !cacheRef.Conds.MatchResponse(this.req.Format) { - this.req.varMapping["cache.status"] = "BYPASS" - if addStatusHeader { - this.Header().Set("X-Cache", "BYPASS, ResponseConds") - } - return - } - - // 打开缓存写入 - storage := caches.SharedManager.FindStorageWithPolicy(cachePolicy.Id) - if storage == nil { - this.req.varMapping["cache.status"] = "BYPASS" - if addStatusHeader { - this.Header().Set("X-Cache", "BYPASS, Storage") - } - return - } - - this.req.varMapping["cache.status"] = "UPDATING" - if addStatusHeader { - this.Header().Set("X-Cache", "UPDATING") - } - - this.cacheStorage = storage - life := cacheRef.LifeSeconds() - - if life <= 0 { - life = 60 - } - - // 支持源站设置的max-age - if this.req.web.Cache != nil && this.req.web.Cache.EnableCacheControlMaxAge { - var cacheControl = this.Header().Get("Cache-Control") - var pieces = strings.Split(cacheControl, ";") - for _, piece := range pieces { - var eqIndex = strings.Index(piece, "=") - if eqIndex > 0 && piece[:eqIndex] == "max-age" { - var maxAge = types.Int64(piece[eqIndex+1:]) - if maxAge > 0 { - life = maxAge - } - } - } - } - - expiredAt := utils.UnixTime() + life - var cacheKey = this.req.cacheKey - if this.webpIsEncoding { - cacheKey += webpSuffix - } - cacheWriter, err := storage.OpenWriter(cacheKey, expiredAt, this.StatusCode()) - if err != nil { - if !caches.CanIgnoreErr(err) { - remotelogs.Error("HTTP_WRITER", "write cache failed: "+err.Error()) - } - return - } - this.cacheWriter = cacheWriter - - // 写入Header - for k, v := range this.Header() { - for _, v1 := range v { - _, err = cacheWriter.WriteHeader([]byte(k + ":" + v1 + "\n")) - if err != nil { - remotelogs.Error("HTTP_WRITER", "write cache failed: "+err.Error()) - _ = this.cacheWriter.Discard() - this.cacheWriter = nil - return - } - } - } +// DelayRead 是否延迟读取Reader +func (this *HTTPWriter) DelayRead() bool { + return this.delayRead } // 计算stale时长 diff --git a/internal/nodes/http_writer_rate.go b/internal/nodes/http_writer_rate.go deleted file mode 100644 index 1449dcb..0000000 --- a/internal/nodes/http_writer_rate.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. - -package nodes - -import ( - "bufio" - "github.com/iwind/TeaGo/types" - "net" - "net/http" - "time" -) - -// HTTPRateWriter 限速写入 -type HTTPRateWriter struct { - parentWriter http.ResponseWriter - - rateBytes int - lastBytes int - timeCost time.Duration -} - -func NewHTTPRateWriter(writer http.ResponseWriter, rateBytes int64) http.ResponseWriter { - return &HTTPRateWriter{ - parentWriter: writer, - rateBytes: types.Int(rateBytes), - } -} - -func (this *HTTPRateWriter) Header() http.Header { - return this.parentWriter.Header() -} - -func (this *HTTPRateWriter) Write(data []byte) (int, error) { - if len(data) == 0 { - return 0, nil - } - - var left = this.rateBytes - this.lastBytes - - if left <= 0 { - if this.timeCost > 0 && this.timeCost < 1*time.Second { - time.Sleep(1*time.Second - this.timeCost) - } - - this.lastBytes = 0 - this.timeCost = 0 - return this.Write(data) - } - - var n = len(data) - - // n <= left - if n <= left { - this.lastBytes += n - - var before = time.Now() - defer func() { - this.timeCost += time.Since(before) - }() - return this.parentWriter.Write(data) - } - - // n > left - var before = time.Now() - result, err := this.parentWriter.Write(data[:left]) - this.timeCost += time.Since(before) - - if err != nil { - return result, err - } - this.lastBytes += left - - return this.Write(data[left:]) -} - -func (this *HTTPRateWriter) WriteHeader(statusCode int) { - this.parentWriter.WriteHeader(statusCode) -} - -// Hijack Hijack -func (this *HTTPRateWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) { - if this.parentWriter == nil { - return - } - hijack, ok := this.parentWriter.(http.Hijacker) - if ok { - return hijack.Hijack() - } - return -} - -// Flush Flush -func (this *HTTPRateWriter) Flush() { - if this.parentWriter == nil { - return - } - flusher, ok := this.parentWriter.(http.Flusher) - if ok { - flusher.Flush() - return - } -} diff --git a/internal/utils/readers/bytes_counter_reader.go b/internal/utils/readers/bytes_counter_reader.go new file mode 100644 index 0000000..cfa276f --- /dev/null +++ b/internal/utils/readers/bytes_counter_reader.go @@ -0,0 +1,26 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package readers + +import "io" + +type BytesCounterReader struct { + rawReader io.Reader + count int64 +} + +func NewBytesCounterReader(rawReader io.Reader) *BytesCounterReader { + return &BytesCounterReader{ + rawReader: rawReader, + } +} + +func (this *BytesCounterReader) Read(p []byte) (n int, err error) { + n, err = this.rawReader.Read(p) + this.count += int64(n) + return +} + +func (this *BytesCounterReader) TotalBytes() int64 { + return this.count +} diff --git a/internal/utils/readers/filter_reader.go b/internal/utils/readers/filter_reader.go new file mode 100644 index 0000000..4e462e1 --- /dev/null +++ b/internal/utils/readers/filter_reader.go @@ -0,0 +1,34 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package readers + +import "io" + +type FilterFunc = func(p []byte, err error) error + +type FilterReader struct { + rawReader io.Reader + filters []FilterFunc +} + +func NewFilterReader(rawReader io.Reader) *FilterReader { + return &FilterReader{ + rawReader: rawReader, + } +} + +func (this *FilterReader) Add(filter FilterFunc) { + this.filters = append(this.filters, filter) +} + +func (this *FilterReader) Read(p []byte) (n int, err error) { + n, err = this.rawReader.Read(p) + for _, filter := range this.filters { + filterErr := filter(p[:n], err) + if filterErr != nil { + err = filterErr + return + } + } + return +} diff --git a/internal/utils/readers/filter_reader_test.go b/internal/utils/readers/filter_reader_test.go new file mode 100644 index 0000000..87808b2 --- /dev/null +++ b/internal/utils/readers/filter_reader_test.go @@ -0,0 +1,41 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package readers_test + +import ( + "bytes" + "errors" + "github.com/TeaOSLab/EdgeNode/internal/utils/readers" + "testing" +) + +func TestNewFilterReader(t *testing.T) { + var reader = readers.NewFilterReader(bytes.NewBufferString("0123456789")) + reader.Add(func(p []byte, err error) error { + t.Log("filter1:", string(p), err) + return nil + }) + reader.Add(func(p []byte, err error) error { + t.Log("filter2:", string(p), err) + if string(p) == "345" { + return errors.New("end") + } + return nil + }) + reader.Add(func(p []byte, err error) error { + t.Log("filter3:", string(p), err) + return nil + }) + + var buf = make([]byte, 3) + for { + n, err := reader.Read(buf) + if n > 0 { + t.Log(string(buf[:n])) + } + if err != nil { + t.Log(err) + break + } + } +} diff --git a/internal/utils/readers/tee_reader.go b/internal/utils/readers/tee_reader.go new file mode 100644 index 0000000..9abe601 --- /dev/null +++ b/internal/utils/readers/tee_reader.go @@ -0,0 +1,52 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package readers + +import ( + "io" +) + +type TeeReader struct { + r io.Reader + w io.Writer + + onFail func(err error) + onEOF func() +} + +func NewTeeReader(reader io.Reader, writer io.Writer) *TeeReader { + return &TeeReader{ + r: reader, + w: writer, + } +} + +func (this *TeeReader) Read(p []byte) (n int, err error) { + n, err = this.r.Read(p) + if n > 0 { + _, wErr := this.w.Write(p[:n]) + if err == nil && wErr != nil { + err = wErr + } + } + if err != nil { + if err == io.EOF { + if this.onEOF != nil { + this.onEOF() + } + } else { + if this.onFail != nil { + this.onFail(err) + } + } + } + return +} + +func (this *TeeReader) OnFail(onFail func(err error)) { + this.onFail = onFail +} + +func (this *TeeReader) OnEOF(onEOF func()) { + this.onEOF = onEOF +} diff --git a/internal/utils/readers/tee_reader_closer.go b/internal/utils/readers/tee_reader_closer.go new file mode 100644 index 0000000..87e529c --- /dev/null +++ b/internal/utils/readers/tee_reader_closer.go @@ -0,0 +1,58 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package readers + +import "io" + +type TeeReaderCloser struct { + r io.Reader + w io.Writer + + onFail func(err error) + onEOF func() +} + +func NewTeeReaderCloser(reader io.Reader, writer io.Writer) *TeeReaderCloser { + return &TeeReaderCloser{ + r: reader, + w: writer, + } +} + +func (this *TeeReaderCloser) Read(p []byte) (n int, err error) { + n, err = this.r.Read(p) + if n > 0 { + _, wErr := this.w.Write(p[:n]) + if err == nil && wErr != nil { + err = wErr + } + } + if err != nil { + if err == io.EOF { + if this.onEOF != nil { + this.onEOF() + } + } else { + if this.onFail != nil { + this.onFail(err) + } + } + } + return +} + +func (this *TeeReaderCloser) Close() error { + r, ok := this.r.(io.Closer) + if ok { + return r.Close() + } + return nil +} + +func (this *TeeReaderCloser) OnFail(onFail func(err error)) { + this.onFail = onFail +} + +func (this *TeeReaderCloser) OnEOF(onEOF func()) { + this.onEOF = onEOF +} diff --git a/internal/utils/writers/bytes_counter_writer.go b/internal/utils/writers/bytes_counter_writer.go new file mode 100644 index 0000000..9a5bdaf --- /dev/null +++ b/internal/utils/writers/bytes_counter_writer.go @@ -0,0 +1,28 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package writers + +import "io" + +type BytesCounterWriter struct { + writer io.Writer + count int64 +} + +func NewBytesCounterWriter(rawWriter io.Writer) *BytesCounterWriter { + return &BytesCounterWriter{writer: rawWriter} +} + +func (this *BytesCounterWriter) Write(p []byte) (n int, err error) { + n, err = this.writer.Write(p) + this.count += int64(n) + return +} + +func (this *BytesCounterWriter) Close() error { + return nil +} + +func (this *BytesCounterWriter) TotalBytes() int64 { + return this.count +} diff --git a/internal/utils/writers/rate_limit_writer.go b/internal/utils/writers/rate_limit_writer.go new file mode 100644 index 0000000..154cc9b --- /dev/null +++ b/internal/utils/writers/rate_limit_writer.go @@ -0,0 +1,87 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package writers + +import ( + "github.com/iwind/TeaGo/types" + "io" + "time" +) + +// RateLimitWriter 限速写入 +type RateLimitWriter struct { + rawWriter io.WriteCloser + + rateBytes int + + written int + before time.Time +} + +func NewRateLimitWriter(rawWriter io.WriteCloser, rateBytes int64) io.WriteCloser { + return &RateLimitWriter{ + rawWriter: rawWriter, + rateBytes: types.Int(rateBytes), + before: time.Now(), + } +} + +func (this *RateLimitWriter) Write(p []byte) (n int, err error) { + if this.rateBytes <= 0 { + return this.write(p) + } + + var size = len(p) + if size == 0 { + return 0, nil + } + + if size <= this.rateBytes { + return this.write(p) + } + + for { + size = len(p) + + var limit = this.rateBytes + if limit > size { + limit = size + } + n1, wErr := this.write(p[:limit]) + n += n1 + if wErr != nil { + return n, wErr + } + + if size > limit { + p = p[limit:] + } else { + break + } + } + + return +} + +func (this *RateLimitWriter) Close() error { + return this.rawWriter.Close() +} + +func (this *RateLimitWriter) write(p []byte) (n int, err error) { + n, err = this.rawWriter.Write(p) + + if err == nil { + this.written += n + + if this.written >= this.rateBytes { + var duration = 1*time.Second - time.Now().Sub(this.before) + if duration > 0 { + time.Sleep(duration) + } + this.before = time.Now() + this.written = 0 + } + } + + return +} diff --git a/internal/utils/writers/rate_limit_writer_test.go b/internal/utils/writers/rate_limit_writer_test.go new file mode 100644 index 0000000..dce5827 --- /dev/null +++ b/internal/utils/writers/rate_limit_writer_test.go @@ -0,0 +1,41 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package writers + +import ( + "sync" + "testing" + "time" +) + +func TestSleep(t *testing.T) { + var count = 2000 + var wg = sync.WaitGroup{} + wg.Add(count) + var before = time.Now() + for i := 0; i < count; i++ { + go func() { + defer wg.Done() + time.Sleep(1 * time.Second) + }() + } + wg.Wait() + t.Log(time.Since(before).Seconds()*1000, "ms") +} + +func TestTimeout(t *testing.T) { + var count = 2000 + var wg = sync.WaitGroup{} + wg.Add(count) + var before = time.Now() + for i := 0; i < count; i++ { + go func() { + defer wg.Done() + + var timeout = time.NewTimer(1 * time.Second) + <-timeout.C + }() + } + wg.Wait() + t.Log(time.Since(before).Seconds()*1000, "ms") +} diff --git a/internal/utils/writers/tee_writer_closer.go b/internal/utils/writers/tee_writer_closer.go new file mode 100644 index 0000000..b5d908b --- /dev/null +++ b/internal/utils/writers/tee_writer_closer.go @@ -0,0 +1,51 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package writers + +import "io" + +type TeeWriterCloser struct { + primaryW io.WriteCloser + secondaryW io.WriteCloser + + onFail func(err error) +} + +func NewTeeWriterCloser(primaryW io.WriteCloser, secondaryW io.WriteCloser) *TeeWriterCloser { + return &TeeWriterCloser{ + primaryW: primaryW, + secondaryW: secondaryW, + } +} + +func (this *TeeWriterCloser) Write(p []byte) (n int, err error) { + { + n, err = this.primaryW.Write(p) + + if err != nil { + if this.onFail != nil { + this.onFail(err) + } + } + } + + { + _, err2 := this.secondaryW.Write(p) + if err2 != nil { + if this.onFail != nil { + this.onFail(err2) + } + } + } + + return +} + +func (this *TeeWriterCloser) Close() error { + // 这里不关闭secondary + return this.primaryW.Close() +} + +func (this *TeeWriterCloser) OnFail(onFail func(err error)) { + this.onFail = onFail +}