diff --git a/internal/caches/partial_ranges.go b/internal/caches/partial_ranges.go index d5f94f8..1623bb0 100644 --- a/internal/caches/partial_ranges.go +++ b/internal/caches/partial_ranges.go @@ -4,28 +4,37 @@ package caches import ( "encoding/json" + "errors" + "io/ioutil" ) // PartialRanges 内容分区范围定义 type PartialRanges struct { - ranges [][2]int64 + Ranges [][2]int64 `json:"ranges"` } // NewPartialRanges 获取新对象 func NewPartialRanges() *PartialRanges { - return &PartialRanges{ranges: [][2]int64{}} + return &PartialRanges{Ranges: [][2]int64{}} } // NewPartialRangesFromJSON 从JSON中解析范围 func NewPartialRangesFromJSON(data []byte) (*PartialRanges, error) { - var rs = [][2]int64{} + var rs = NewPartialRanges() err := json.Unmarshal(data, &rs) if err != nil { return nil, err } - var r = NewPartialRanges() - r.ranges = rs - return r, nil + + return rs, nil +} + +func NewPartialRangesFromFile(path string) (*PartialRanges, error) { + data, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + return NewPartialRangesFromJSON(data) } // Add 添加新范围 @@ -36,46 +45,41 @@ func (this *PartialRanges) Add(begin int64, end int64) { var nr = [2]int64{begin, end} - var count = len(this.ranges) + var count = len(this.Ranges) if count == 0 { - this.ranges = [][2]int64{nr} + this.Ranges = [][2]int64{nr} return } // insert // TODO 将来使用二分法改进 var index = -1 - for i, r := range this.ranges { + for i, r := range this.Ranges { if r[0] > begin || (r[0] == begin && r[1] >= end) { index = i - this.ranges = append(this.ranges, [2]int64{}) - copy(this.ranges[index+1:], this.ranges[index:]) - this.ranges[index] = nr + this.Ranges = append(this.Ranges, [2]int64{}) + copy(this.Ranges[index+1:], this.Ranges[index:]) + this.Ranges[index] = nr break } } if index == -1 { index = count - this.ranges = append(this.ranges, nr) + this.Ranges = append(this.Ranges, nr) } this.merge(index) } -// Ranges 获取所有范围 -func (this *PartialRanges) Ranges() [][2]int64 { - return this.ranges -} - // Contains 检查是否包含某个范围 func (this *PartialRanges) Contains(begin int64, end int64) bool { - if len(this.ranges) == 0 { - return true + if len(this.Ranges) == 0 { + return false } // TODO 使用二分法查找改进性能 - for _, r2 := range this.ranges { + for _, r2 := range this.Ranges { if r2[0] <= begin && r2[1] >= end { return true } @@ -86,21 +90,46 @@ func (this *PartialRanges) Contains(begin int64, end int64) bool { // AsJSON 转换为JSON func (this *PartialRanges) AsJSON() ([]byte, error) { - return json.Marshal(this.ranges) + return json.Marshal(this) +} + +// WriteToFile 写入到文件中 +func (this *PartialRanges) WriteToFile(path string) error { + data, err := this.AsJSON() + if err != nil { + return errors.New("convert to json failed: " + err.Error()) + } + return ioutil.WriteFile(path, data, 0666) +} + +// ReadFromFile 从文件中读取 +func (this *PartialRanges) ReadFromFile(path string) (*PartialRanges, error) { + data, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + return NewPartialRangesFromJSON(data) +} + +func (this *PartialRanges) Max() int64 { + if len(this.Ranges) > 0 { + return this.Ranges[len(this.Ranges)-1][1] + } + return 0 } func (this *PartialRanges) merge(index int) { // forward var lastIndex = index for i := index; i >= 1; i-- { - var curr = this.ranges[i] - var prev = this.ranges[i-1] + var curr = this.Ranges[i] + var prev = this.Ranges[i-1] var w1 = this.w(curr) var w2 = this.w(prev) if w1+w2 >= this.max(curr[1], prev[1])-this.min(curr[0], prev[0])-1 { prev = [2]int64{this.min(curr[0], prev[0]), this.max(curr[1], prev[1])} - 
this.ranges[i-1] = prev - this.ranges = append(this.ranges[:i], this.ranges[i+1:]...) + this.Ranges[i-1] = prev + this.Ranges = append(this.Ranges[:i], this.Ranges[i+1:]...) lastIndex = i - 1 } else { break @@ -109,15 +138,15 @@ func (this *PartialRanges) merge(index int) { // backward index = lastIndex - for index < len(this.ranges)-1 { - var curr = this.ranges[index] - var next = this.ranges[index+1] + for index < len(this.Ranges)-1 { + var curr = this.Ranges[index] + var next = this.Ranges[index+1] var w1 = this.w(curr) var w2 = this.w(next) if w1+w2 >= this.max(curr[1], next[1])-this.min(curr[0], next[0])-1 { curr = [2]int64{this.min(curr[0], next[0]), this.max(curr[1], next[1])} - this.ranges = append(this.ranges[:index], this.ranges[index+1:]...) - this.ranges[index] = curr + this.Ranges = append(this.Ranges[:index], this.Ranges[index+1:]...) + this.Ranges[index] = curr } else { break } diff --git a/internal/caches/partial_ranges_test.go b/internal/caches/partial_ranges_test.go index 5476970..900d9b2 100644 --- a/internal/caches/partial_ranges_test.go +++ b/internal/caches/partial_ranges_test.go @@ -21,7 +21,8 @@ func TestNewPartialRanges(t *testing.T) { r.Add(200, 1000) r.Add(200, 10040) - logs.PrintAsJSON(r.Ranges()) + logs.PrintAsJSON(r.Ranges, t) + t.Log("max:", r.Max()) } func TestNewPartialRanges1(t *testing.T) { @@ -35,7 +36,7 @@ func TestNewPartialRanges1(t *testing.T) { r.Add(200, 300) r.Add(1, 1000) - var rs = r.Ranges() + var rs = r.Ranges logs.PrintAsJSON(rs, t) a.IsTrue(len(rs) == 1) if len(rs) == 1 { @@ -56,7 +57,7 @@ func TestNewPartialRanges2(t *testing.T) { r.Add(303, 304) r.Add(250, 400) - var rs = r.Ranges() + var rs = r.Ranges logs.PrintAsJSON(rs, t) } @@ -68,7 +69,7 @@ func TestNewPartialRanges3(t *testing.T) { r.Add(200, 300) r.Add(250, 400) - var rs = r.Ranges() + var rs = r.Ranges logs.PrintAsJSON(rs, t) } @@ -83,7 +84,7 @@ func TestNewPartialRanges4(t *testing.T) { r.Add(410, 415) r.Add(400, 409) - var rs = r.Ranges() + var rs = r.Ranges logs.PrintAsJSON(rs, t) t.Log(r.Contains(400, 416)) } @@ -93,7 +94,7 @@ func TestNewPartialRanges5(t *testing.T) { for j := 0; j < 1000; j++ { r.Add(int64(j), int64(j+100)) } - logs.PrintAsJSON(r.Ranges(), t) + logs.PrintAsJSON(r.Ranges, t) } func TestNewPartialRanges_AsJSON(t *testing.T) { @@ -111,7 +112,7 @@ func TestNewPartialRanges_AsJSON(t *testing.T) { if err != nil { t.Fatal(err) } - t.Log(r2.Ranges()) + t.Log(r2.Ranges) } func BenchmarkNewPartialRanges(b *testing.B) { diff --git a/internal/caches/reader.go b/internal/caches/reader.go index 4ae9f21..d387cd1 100644 --- a/internal/caches/reader.go +++ b/internal/caches/reader.go @@ -1,5 +1,7 @@ package caches +import "github.com/TeaOSLab/EdgeNode/internal/utils/ranges" + type ReaderFunc func(n int) (goNext bool, err error) type Reader interface { @@ -36,6 +38,9 @@ type Reader interface { // BodySize Body Size BodySize() int64 + // ContainsRange 是否包含某个区间内容 + ContainsRange(r rangeutils.Range) bool + // Close 关闭 Close() error } diff --git a/internal/caches/reader_file.go b/internal/caches/reader_file.go index 6fec531..122a23c 100644 --- a/internal/caches/reader_file.go +++ b/internal/caches/reader_file.go @@ -3,6 +3,7 @@ package caches import ( "encoding/binary" "errors" + rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges" "github.com/iwind/TeaGo/types" "io" "os" @@ -332,6 +333,11 @@ func (this *FileReader) ReadBodyRange(buf []byte, start int64, end int64, callba return nil } +// ContainsRange 是否包含某些区间内容 +func (this *FileReader) ContainsRange(r 
rangeutils.Range) bool { + return true +} + func (this *FileReader) Close() error { if this.openFileCache != nil { if this.isClosed { diff --git a/internal/caches/reader_memory.go b/internal/caches/reader_memory.go index 76998b5..bb2ca5a 100644 --- a/internal/caches/reader_memory.go +++ b/internal/caches/reader_memory.go @@ -2,6 +2,7 @@ package caches import ( "errors" + rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges" "io" ) @@ -197,6 +198,11 @@ func (this *MemoryReader) ReadBodyRange(buf []byte, start int64, end int64, call return nil } +// ContainsRange 是否包含某些区间内容 +func (this *MemoryReader) ContainsRange(r rangeutils.Range) bool { + return true +} + func (this *MemoryReader) Close() error { return nil } diff --git a/internal/caches/reader_partial_file.go b/internal/caches/reader_partial_file.go new file mode 100644 index 0000000..70be30f --- /dev/null +++ b/internal/caches/reader_partial_file.go @@ -0,0 +1,142 @@ +package caches + +import ( + "encoding/binary" + "errors" + rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges" + "github.com/iwind/TeaGo/types" + "io" + "os" + "strings" +) + +type PartialFileReader struct { + *FileReader + + ranges *PartialRanges + rangePath string +} + +func NewPartialFileReader(fp *os.File) *PartialFileReader { + // range path + var path = fp.Name() + var dotIndex = strings.LastIndex(path, ".") + var rangePath = "" + if dotIndex < 0 { + rangePath = path + "@ranges.cache" + } else { + rangePath = path[:dotIndex] + "@ranges" + path[dotIndex:] + } + + return &PartialFileReader{ + FileReader: NewFileReader(fp), + rangePath: rangePath, + } +} + +func (this *PartialFileReader) Init() error { + return this.InitAutoDiscard(true) +} + +func (this *PartialFileReader) InitAutoDiscard(autoDiscard bool) error { + if this.openFile != nil { + this.meta = this.openFile.meta + this.header = this.openFile.header + } + + isOk := false + + if autoDiscard { + defer func() { + if !isOk { + _ = this.discard() + } + }() + } + + // 读取Range + ranges, err := NewPartialRangesFromFile(this.rangePath) + if err != nil { + return errors.New("read ranges failed: " + err.Error()) + } + this.ranges = ranges + + var buf = this.meta + if len(buf) == 0 { + buf = make([]byte, SizeMeta) + ok, err := this.readToBuff(this.fp, buf) + if err != nil { + return err + } + if !ok { + return ErrNotFound + } + this.meta = buf + } + + this.expiresAt = int64(binary.BigEndian.Uint32(buf[:SizeExpiresAt])) + + status := types.Int(string(buf[SizeExpiresAt : SizeExpiresAt+SizeStatus])) + if status < 100 || status > 999 { + return errors.New("invalid status") + } + this.status = status + + // URL + urlLength := binary.BigEndian.Uint32(buf[SizeExpiresAt+SizeStatus : SizeExpiresAt+SizeStatus+SizeURLLength]) + + // header + headerSize := int(binary.BigEndian.Uint32(buf[SizeExpiresAt+SizeStatus+SizeURLLength : SizeExpiresAt+SizeStatus+SizeURLLength+SizeHeaderLength])) + if headerSize == 0 { + return nil + } + this.headerSize = headerSize + this.headerOffset = int64(SizeMeta) + int64(urlLength) + + // body + this.bodyOffset = this.headerOffset + int64(headerSize) + bodySize := int(binary.BigEndian.Uint64(buf[SizeExpiresAt+SizeStatus+SizeURLLength+SizeHeaderLength : SizeExpiresAt+SizeStatus+SizeURLLength+SizeHeaderLength+SizeBodyLength])) + if bodySize == 0 { + isOk = true + return nil + } + this.bodySize = int64(bodySize) + + // read header + if this.openFileCache != nil && len(this.header) == 0 { + if headerSize > 0 && headerSize <= 512 { + this.header = make([]byte, headerSize) + _, err := 
this.fp.Seek(this.headerOffset, io.SeekStart) + if err != nil { + return err + } + _, err = this.readToBuff(this.fp, this.header) + if err != nil { + return err + } + } + } + + isOk = true + + return nil +} + +// ContainsRange 是否包含某些区间内容 +// 这里的 r 是已经经过格式化的 +func (this *PartialFileReader) ContainsRange(r rangeutils.Range) bool { + return this.ranges.Contains(r.Start(), r.End()) +} + +// MaxLength 获取区间最大长度 +func (this *PartialFileReader) MaxLength() int64 { + if this.bodySize > 0 { + return this.bodySize + } + return this.ranges.Max() + 1 +} + +func (this *PartialFileReader) discard() error { + _ = os.Remove(this.rangePath) + return this.FileReader.discard() +} diff --git a/internal/caches/storage_file.go b/internal/caches/storage_file.go index b6c5ac5..01f988f 100644 --- a/internal/caches/storage_file.go +++ b/internal/caches/storage_file.go @@ -216,19 +216,24 @@ func (this *FileStorage) Init() error { return nil } -func (this *FileStorage) OpenReader(key string, useStale bool) (Reader, error) { - return this.openReader(key, true, useStale) +func (this *FileStorage) OpenReader(key string, useStale bool, isPartial bool) (Reader, error) { + return this.openReader(key, true, useStale, isPartial) } -func (this *FileStorage) openReader(key string, allowMemory bool, useStale bool) (Reader, error) { +func (this *FileStorage) openReader(key string, allowMemory bool, useStale bool, isPartial bool) (Reader, error) { // 使用陈旧缓存的时候,我们认为是短暂的,只需要从文件里检查即可 if useStale { allowMemory = false } + // 区间缓存只存在文件中 + if isPartial { + allowMemory = false + } + // 先尝试内存缓存 if allowMemory && this.memoryStorage != nil { - reader, err := this.memoryStorage.OpenReader(key, useStale) + reader, err := this.memoryStorage.OpenReader(key, useStale, isPartial) if err == nil { return reader, err } @@ -273,9 +278,18 @@ func (this *FileStorage) openReader(key string, allowMemory bool, useStale bool) } }() - var reader = NewFileReader(fp) - reader.openFile = openFile - reader.openFileCache = this.openFileCache + var reader Reader + if isPartial { + var partialFileReader = NewPartialFileReader(fp) + partialFileReader.openFile = openFile + partialFileReader.openFileCache = this.openFileCache + reader = partialFileReader + } else { + var fileReader = NewFileReader(fp) + fileReader.openFile = openFile + fileReader.openFileCache = this.openFileCache + reader = fileReader + } err = reader.Init() if err != nil { return nil, err @@ -284,40 +298,7 @@ func (this *FileStorage) openReader(key string, allowMemory bool, useStale bool) // 增加点击量 // 1/1000采样 if allowMemory { - var rate = this.policy.PersistenceHitSampleRate - if rate <= 0 { - rate = 1000 - } - if this.lastHotSize == 0 { - // 自动降低采样率来增加热点数据的缓存几率 - rate = rate / 10 - } - if rands.Int(0, rate) == 0 { - var hitErr = this.list.IncreaseHit(hash) - if hitErr != nil { - // 此错误可以忽略 - remotelogs.Error("CACHE", "increase hit failed: "+hitErr.Error()) - } - - // 增加到热点 - // 这里不收录缓存尺寸过大的文件 - if this.memoryStorage != nil && reader.BodySize() > 0 && reader.BodySize() < 128*1024*1024 { - this.hotMapLocker.Lock() - hotItem, ok := this.hotMap[key] - if ok { - hotItem.Hits++ - hotItem.ExpiresAt = reader.expiresAt - } else if len(this.hotMap) < HotItemSize { // 控制数量 - this.hotMap[key] = &HotItem{ - Key: key, - ExpiresAt: reader.ExpiresAt(), - Status: reader.Status(), - Hits: 1, - } - } - this.hotMapLocker.Unlock() - } - } + this.increaseHit(key, hash, reader) } isOk = true @@ -380,7 +361,8 @@ func (this *FileStorage) OpenWriter(key string, expiredAt int64, status int, siz } // 检查缓存是否已经生成 - var 
cachePath = dir + "/" + hash + ".cache" + var cachePathName = dir + "/" + hash + var cachePath = cachePathName + ".cache" stat, err := os.Stat(cachePath) if err == nil && time.Now().Sub(stat.ModTime()) <= 1*time.Second { // 防止并发连续写入 @@ -388,7 +370,7 @@ func (this *FileStorage) OpenWriter(key string, expiredAt int64, status int, siz } var tmpPath = cachePath + ".tmp" if isPartial { - tmpPath = cachePath + tmpPath = cachePathName + ".cache" } // 先删除 @@ -502,7 +484,12 @@ func (this *FileStorage) OpenWriter(key string, expiredAt int64, status int, siz isOk = true if isPartial { - return NewPartialFileWriter(writer, key, expiredAt, isNewCreated, isPartial, partialBodyOffset, func() { + ranges, err := NewPartialRangesFromFile(cachePathName + "@ranges.cache") + if err != nil { + ranges = NewPartialRanges() + } + + return NewPartialFileWriter(writer, key, expiredAt, isNewCreated, isPartial, partialBodyOffset, ranges, func() { sharedWritingFileKeyLocker.Lock() delete(sharedWritingFileKeyMap, key) sharedWritingFileKeyLocker.Unlock() @@ -923,7 +910,7 @@ func (this *FileStorage) hotLoop() { var buf = utils.BytePool16k.Get() defer utils.BytePool16k.Put(buf) for _, item := range result[:size] { - reader, err := this.openReader(item.Key, false, false) + reader, err := this.openReader(item.Key, false, false, false) if err != nil { continue } @@ -1025,3 +1012,41 @@ func (this *FileStorage) cleanDeletedDirs(dir string) error { } return nil } + +// 增加某个Key的点击量 +func (this *FileStorage) increaseHit(key string, hash string, reader Reader) { + var rate = this.policy.PersistenceHitSampleRate + if rate <= 0 { + rate = 1000 + } + if this.lastHotSize == 0 { + // 自动降低采样率来增加热点数据的缓存几率 + rate = rate / 10 + } + if rands.Int(0, rate) == 0 { + var hitErr = this.list.IncreaseHit(hash) + if hitErr != nil { + // 此错误可以忽略 + remotelogs.Error("CACHE", "increase hit failed: "+hitErr.Error()) + } + + // 增加到热点 + // 这里不收录缓存尺寸过大的文件 + if this.memoryStorage != nil && reader.BodySize() > 0 && reader.BodySize() < 128*1024*1024 { + this.hotMapLocker.Lock() + hotItem, ok := this.hotMap[key] + if ok { + hotItem.Hits++ + hotItem.ExpiresAt = reader.ExpiresAt() + } else if len(this.hotMap) < HotItemSize { // 控制数量 + this.hotMap[key] = &HotItem{ + Key: key, + ExpiresAt: reader.ExpiresAt(), + Status: reader.Status(), + Hits: 1, + } + } + this.hotMapLocker.Unlock() + } + } +} diff --git a/internal/caches/storage_file_test.go b/internal/caches/storage_file_test.go index 38f3fc0..d8e2749 100644 --- a/internal/caches/storage_file_test.go +++ b/internal/caches/storage_file_test.go @@ -110,7 +110,7 @@ func TestFileStorage_OpenWriter_Partial(t *testing.T) { t.Fatal(err) } - err = writer.WriteAt([]byte("Hello, World"), 0) + err = writer.WriteAt(0, []byte("Hello, World")) if err != nil { t.Fatal(err) } @@ -311,7 +311,7 @@ func TestFileStorage_Read(t *testing.T) { t.Fatal(err) } now := time.Now() - reader, err := storage.OpenReader("my-key", false) + reader, err := storage.OpenReader("my-key", false, false) if err != nil { t.Fatal(err) } @@ -347,7 +347,7 @@ func TestFileStorage_Read_HTTP_Response(t *testing.T) { t.Fatal(err) } now := time.Now() - reader, err := storage.OpenReader("my-http-response", false) + reader, err := storage.OpenReader("my-http-response", false, false) if err != nil { t.Fatal(err) } @@ -401,7 +401,7 @@ func TestFileStorage_Read_NotFound(t *testing.T) { } now := time.Now() buf := make([]byte, 6) - reader, err := storage.OpenReader("my-key-10000", false) + reader, err := storage.OpenReader("my-key-10000", false, false) if err != nil { 
if err == ErrNotFound { t.Log("cache not fund") @@ -543,7 +543,7 @@ func BenchmarkFileStorage_Read(b *testing.B) { b.Fatal(err) } for i := 0; i < b.N; i++ { - reader, err := storage.OpenReader("my-key", false) + reader, err := storage.OpenReader("my-key", false, false) if err != nil { b.Fatal(err) } diff --git a/internal/caches/storage_interface.go b/internal/caches/storage_interface.go index 2ed8647..98d4f0b 100644 --- a/internal/caches/storage_interface.go +++ b/internal/caches/storage_interface.go @@ -10,7 +10,7 @@ type StorageInterface interface { Init() error // OpenReader 读取缓存 - OpenReader(key string, useStale bool) (reader Reader, err error) + OpenReader(key string, useStale bool, isPartial bool) (reader Reader, err error) // OpenWriter 打开缓存写入器等待写入 OpenWriter(key string, expiredAt int64, status int, size int64, isPartial bool) (Writer, error) diff --git a/internal/caches/storage_memory.go b/internal/caches/storage_memory.go index 9700fed..5b98339 100644 --- a/internal/caches/storage_memory.go +++ b/internal/caches/storage_memory.go @@ -105,7 +105,7 @@ func (this *MemoryStorage) Init() error { } // OpenReader 读取缓存 -func (this *MemoryStorage) OpenReader(key string, useStale bool) (Reader, error) { +func (this *MemoryStorage) OpenReader(key string, useStale bool, isPartial bool) (Reader, error) { hash := this.hash(key) this.locker.RLock() diff --git a/internal/caches/storage_memory_test.go b/internal/caches/storage_memory_test.go index fe6da81..38c5c1d 100644 --- a/internal/caches/storage_memory_test.go +++ b/internal/caches/storage_memory_test.go @@ -25,7 +25,7 @@ func TestMemoryStorage_OpenWriter(t *testing.T) { t.Log(storage.valuesMap) { - reader, err := storage.OpenReader("abc", false) + reader, err := storage.OpenReader("abc", false, false) if err != nil { if err == ErrNotFound { t.Log("not found: abc") @@ -52,7 +52,7 @@ func TestMemoryStorage_OpenWriter(t *testing.T) { } { - _, err := storage.OpenReader("abc 2", false) + _, err := storage.OpenReader("abc 2", false, false) if err != nil { if err == ErrNotFound { t.Log("not found: abc2") @@ -68,7 +68,7 @@ func TestMemoryStorage_OpenWriter(t *testing.T) { } _, _ = writer.Write([]byte("Hello123")) { - reader, err := storage.OpenReader("abc", false) + reader, err := storage.OpenReader("abc", false, false) if err != nil { if err == ErrNotFound { t.Log("not found: abc") @@ -97,7 +97,7 @@ func TestMemoryStorage_OpenReaderLock(t *testing.T) { IsDone: true, }, } - _, _ = storage.OpenReader("test", false) + _, _ = storage.OpenReader("test", false, false) } func TestMemoryStorage_Delete(t *testing.T) { diff --git a/internal/caches/writer.go b/internal/caches/writer.go index 40e66bc..975764c 100644 --- a/internal/caches/writer.go +++ b/internal/caches/writer.go @@ -9,7 +9,7 @@ type Writer interface { Write(data []byte) (n int, err error) // WriteAt 在指定位置写入数据 - WriteAt(data []byte, offset int64) error + WriteAt(offset int64, data []byte) error // HeaderSize 写入的Header数据大小 HeaderSize() int64 diff --git a/internal/caches/writer_file.go b/internal/caches/writer_file.go index 550643a..5a34c68 100644 --- a/internal/caches/writer_file.go +++ b/internal/caches/writer_file.go @@ -67,7 +67,7 @@ func (this *FileWriter) Write(data []byte) (n int, err error) { } // WriteAt 在指定位置写入数据 -func (this *FileWriter) WriteAt(data []byte, offset int64) error { +func (this *FileWriter) WriteAt(offset int64, data []byte) error { _ = data _ = offset return errors.New("not supported") diff --git a/internal/caches/writer_memory.go b/internal/caches/writer_memory.go index 
e2e9771..9b675f8 100644 --- a/internal/caches/writer_memory.go +++ b/internal/caches/writer_memory.go @@ -57,7 +57,7 @@ func (this *MemoryWriter) Write(data []byte) (n int, err error) { } // WriteAt 在指定位置写入数据 -func (this *MemoryWriter) WriteAt(b []byte, offset int64) error { +func (this *MemoryWriter) WriteAt(offset int64, b []byte) error { _ = b _ = offset return errors.New("not supported") diff --git a/internal/caches/writer_partial_file.go b/internal/caches/writer_partial_file.go index 24821e9..2262151 100644 --- a/internal/caches/writer_partial_file.go +++ b/internal/caches/writer_partial_file.go @@ -23,9 +23,22 @@ type PartialFileWriter struct { isNew bool isPartial bool bodyOffset int64 + + ranges *PartialRanges + rangePath string } -func NewPartialFileWriter(rawWriter *os.File, key string, expiredAt int64, isNew bool, isPartial bool, bodyOffset int64, endFunc func()) *PartialFileWriter { +func NewPartialFileWriter(rawWriter *os.File, key string, expiredAt int64, isNew bool, isPartial bool, bodyOffset int64, ranges *PartialRanges, endFunc func()) *PartialFileWriter { + var path = rawWriter.Name() + // ranges路径 + var dotIndex = strings.LastIndex(path, ".") + var rangePath = "" + if dotIndex < 0 { + rangePath = path + "@ranges.cache" + } else { + rangePath = path[:dotIndex] + "@ranges" + path[dotIndex:] + } + return &PartialFileWriter{ key: key, rawWriter: rawWriter, @@ -34,6 +47,8 @@ func NewPartialFileWriter(rawWriter *os.File, key string, expiredAt int64, isNew isNew: isNew, isPartial: isPartial, bodyOffset: bodyOffset, + ranges: ranges, + rangePath: rangePath, } } @@ -50,6 +65,21 @@ func (this *PartialFileWriter) WriteHeader(data []byte) (n int, err error) { return } +func (this *PartialFileWriter) AppendHeader(data []byte) error { + _, err := this.rawWriter.Write(data) + if err != nil { + _ = this.Discard() + } else { + var c = len(data) + this.headerSize += int64(c) + err = this.WriteHeaderLength(int(this.headerSize)) + if err != nil { + _ = this.Discard() + } + } + return err +} + // WriteHeaderLength 写入Header长度数据 func (this *PartialFileWriter) WriteHeaderLength(headerLength int) error { bytes4 := make([]byte, 4) @@ -78,12 +108,34 @@ func (this *PartialFileWriter) Write(data []byte) (n int, err error) { } // WriteAt 在指定位置写入数据 -func (this *PartialFileWriter) WriteAt(data []byte, offset int64) error { +func (this *PartialFileWriter) WriteAt(offset int64, data []byte) error { + var c = int64(len(data)) + if c == 0 { + return nil + } + var end = offset + c - 1 + + // 是否已包含在内 + if this.ranges.Contains(offset, end) { + return nil + } + if this.bodyOffset == 0 { this.bodyOffset = SizeMeta + int64(len(this.key)) + this.headerSize } _, err := this.rawWriter.WriteAt(data, this.bodyOffset+offset) - return err + if err != nil { + return err + } + + this.ranges.Add(offset, end) + + return nil +} + +// SetBodyLength 设置内容总长度 +func (this *PartialFileWriter) SetBodyLength(bodyLength int64) { + this.bodySize = bodyLength } // WriteBodyLength 写入Body长度数据 @@ -109,31 +161,30 @@ func (this *PartialFileWriter) Close() error { this.endFunc() }) - var path = this.rawWriter.Name() + err := this.ranges.WriteToFile(this.rangePath) + if err != nil { + return err + } + // 关闭当前writer if this.isNew { - err := this.WriteHeaderLength(types.Int(this.headerSize)) + err = this.WriteHeaderLength(types.Int(this.headerSize)) if err != nil { _ = this.rawWriter.Close() - _ = os.Remove(path) + this.remove() return err } err = this.WriteBodyLength(this.bodySize) if err != nil { _ = this.rawWriter.Close() - _ = 
os.Remove(path) + this.remove() return err } } - err := this.rawWriter.Close() + err = this.rawWriter.Close() if err != nil { - _ = os.Remove(path) - } else if !this.isPartial { - err = os.Rename(path, strings.Replace(path, ".tmp", "", 1)) - if err != nil { - _ = os.Remove(path) - } + this.remove() } return err @@ -147,6 +198,8 @@ func (this *PartialFileWriter) Discard() error { _ = this.rawWriter.Close() + _ = os.Remove(this.rangePath) + err := os.Remove(this.rawWriter.Name()) return err } @@ -171,3 +224,12 @@ func (this *PartialFileWriter) Key() string { func (this *PartialFileWriter) ItemType() ItemType { return ItemTypeFile } + +func (this *PartialFileWriter) IsNew() bool { + return this.isNew && len(this.ranges.Ranges) == 0 +} + +func (this *PartialFileWriter) remove() { + _ = os.Remove(this.rawWriter.Name()) + _ = os.Remove(this.rangePath) +} diff --git a/internal/caches/writer_partial_file_test.go b/internal/caches/writer_partial_file_test.go index 0058d33..af50569 100644 --- a/internal/caches/writer_partial_file_test.go +++ b/internal/caches/writer_partial_file_test.go @@ -11,8 +11,8 @@ import ( "time" ) -func TestPartialFileWriter_SeekOffset(t *testing.T) { - var path = "/tmp/test@partial.cache" +func TestPartialFileWriter_Write(t *testing.T) { + var path = "/tmp/test_partial.cache" _ = os.Remove(path) var reader = func() { @@ -27,7 +27,8 @@ func TestPartialFileWriter_SeekOffset(t *testing.T) { if err != nil { t.Fatal(err) } - var writer = caches.NewPartialFileWriter(fp, "test", time.Now().Unix()+86500, true, true, 0, func() { + var ranges = caches.NewPartialRanges() + var writer = caches.NewPartialFileWriter(fp, "test", time.Now().Unix()+86500, true, true, 0, ranges, func() { t.Log("end") }) _, err = writer.WriteHeader([]byte("header")) @@ -36,7 +37,7 @@ func TestPartialFileWriter_SeekOffset(t *testing.T) { } // 移动位置 - err = writer.WriteAt([]byte("HELLO"), 100) + err = writer.WriteAt(100, []byte("HELLO")) if err != nil { t.Fatal(err) } diff --git a/internal/nodes/api_stream.go b/internal/nodes/api_stream.go index 1d23f16..9c51db3 100644 --- a/internal/nodes/api_stream.go +++ b/internal/nodes/api_stream.go @@ -240,7 +240,7 @@ func (this *APIStream) handleReadCache(message *pb.NodeStreamMessage) error { }() } - reader, err := storage.OpenReader(msg.Key, false) + reader, err := storage.OpenReader(msg.Key, false, false) if err != nil { if err == caches.ErrNotFound { this.replyFail(message.RequestId, "key not found") @@ -351,7 +351,11 @@ func (this *APIStream) handlePurgeCache(message *pb.NodeStreamMessage) error { if msg.Type == "file" { var keys = msg.Keys for _, key := range keys { - keys = append(keys, key+webpCacheSuffix, key+cacheMethodSuffix+"HEAD") + keys = append(keys, + key+cacheMethodSuffix+"HEAD", + key+webpCacheSuffix, + key+cachePartialSuffix, + ) // TODO 根据实际缓存的内容进行组合 for _, encoding := range compressions.AllEncodings() { keys = append(keys, key+compressionCacheSuffix+encoding) diff --git a/internal/nodes/http_request_cache.go b/internal/nodes/http_request_cache.go index d1d8dff..2da9018 100644 --- a/internal/nodes/http_request_cache.go +++ b/internal/nodes/http_request_cache.go @@ -10,6 +10,8 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/rpc" "github.com/TeaOSLab/EdgeNode/internal/utils" + rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges" + "github.com/iwind/TeaGo/types" "io" "net/http" "path/filepath" @@ -36,6 +38,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { return } 
+ // 添加 X-Cache Header var addStatusHeader = this.web.Cache.AddStatusHeader if addStatusHeader { defer func() { @@ -137,7 +140,12 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { if this.web.Cache.PurgeIsOn && strings.ToUpper(this.RawReq.Method) == "PURGE" && this.RawReq.Header.Get("X-Edge-Purge-Key") == this.web.Cache.PurgeKey { this.varMapping["cache.status"] = "PURGE" - var subKeys = []string{key, key + cacheMethodSuffix + "HEAD"} + var subKeys = []string{ + key, + key + cacheMethodSuffix + "HEAD", + key + webpCacheSuffix, + key + cachePartialSuffix, + } // TODO 根据实际缓存的内容进行组合 for _, encoding := range compressions.AllEncodings() { subKeys = append(subKeys, key+compressionCacheSuffix+encoding) @@ -180,10 +188,14 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { var reader caches.Reader var err error + var rangeHeader = this.RawReq.Header.Get("Range") + var isPartialRequest = len(rangeHeader) > 0 + // 检查是否支持WebP var webPIsEnabled = false var isHeadMethod = method == http.MethodHead - if !isHeadMethod && + if !isPartialRequest && + !isHeadMethod && this.web.WebP != nil && this.web.WebP.IsOn && this.web.WebP.MatchRequest(filepath.Ext(this.Path()), this.Format) && @@ -192,13 +204,13 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { } // 检查压缩缓存 - if !isHeadMethod && reader == nil { + if !isPartialRequest && !isHeadMethod && reader == nil { if this.web.Compression != nil && this.web.Compression.IsOn { _, encoding, ok := this.web.Compression.MatchAcceptEncoding(this.RawReq.Header.Get("Accept-Encoding")) if ok { // 检查支持WebP的压缩缓存 if webPIsEnabled { - reader, _ = storage.OpenReader(key+webpCacheSuffix+compressionCacheSuffix+encoding, useStale) + reader, _ = storage.OpenReader(key+webpCacheSuffix+compressionCacheSuffix+encoding, useStale, false) if reader != nil { tags = append(tags, "webp", encoding) } @@ -206,7 +218,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { // 检查普通缓存 if reader == nil { - reader, _ = storage.OpenReader(key+compressionCacheSuffix+encoding, useStale) + reader, _ = storage.OpenReader(key+compressionCacheSuffix+encoding, useStale, false) if reader != nil { tags = append(tags, encoding) } @@ -216,8 +228,11 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { } // 检查WebP - if !isHeadMethod && reader == nil && webPIsEnabled { - reader, _ = storage.OpenReader(key+webpCacheSuffix, useStale) + if !isPartialRequest && + !isHeadMethod && + reader == nil && + webPIsEnabled { + reader, _ = storage.OpenReader(key+webpCacheSuffix, useStale, false) if reader != nil { this.writer.cacheReaderSuffix = webpCacheSuffix tags = append(tags, "webp") @@ -225,8 +240,18 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { } // 检查正常的文件 + var isPartialCache = false if reader == nil { - reader, err = storage.OpenReader(key, useStale) + reader, err = storage.OpenReader(key, useStale, false) + if err != nil && this.cacheRef.AllowPartialContent { + pReader := this.tryPartialReader(storage, key, useStale, rangeHeader) + if pReader != nil { + isPartialCache = true + reader = pReader + err = nil + } + } + if err != nil { if err == caches.ErrNotFound { // cache相关变量 @@ -260,7 +285,16 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { } // 准备Buffer - var pool = this.bytePool(reader.BodySize()) + var fileSize = reader.BodySize() + var totalSizeString = types.String(fileSize) + if isPartialCache { + fileSize = 
reader.(*caches.PartialFileReader).MaxLength() + if totalSizeString == "0" { + totalSizeString = "*" + } + } + + var pool = this.bytePool(fileSize) var buf = pool.Get() defer func() { pool.Put(buf) @@ -323,7 +357,9 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "\"" } respHeader.Del("Etag") - respHeader["ETag"] = []string{eTag} + if !isPartialCache { + respHeader["ETag"] = []string{eTag} + } } // 支持 Last-Modified @@ -331,11 +367,13 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { var modifiedTime = "" if lastModifiedAt > 0 { modifiedTime = time.Unix(utils.GMTUnixTime(lastModifiedAt), 0).Format("Mon, 02 Jan 2006 15:04:05") + " GMT" - respHeader.Set("Last-Modified", modifiedTime) + if !isPartialCache { + respHeader.Set("Last-Modified", modifiedTime) + } } // 支持 If-None-Match - if len(eTag) > 0 && this.requestHeader("If-None-Match") == eTag { + if !isPartialCache && len(eTag) > 0 && this.requestHeader("If-None-Match") == eTag { // 自定义Header this.processResponseHeaders(http.StatusNotModified) this.writer.WriteHeader(http.StatusNotModified) @@ -346,7 +384,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { } // 支持 If-Modified-Since - if len(modifiedTime) > 0 && this.requestHeader("If-Modified-Since") == modifiedTime { + if !isPartialCache && len(modifiedTime) > 0 && this.requestHeader("If-Modified-Since") == modifiedTime { // 自定义Header this.processResponseHeaders(http.StatusNotModified) this.writer.WriteHeader(http.StatusNotModified) @@ -364,69 +402,55 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { this.writer.WriteHeader(reader.Status()) } else { ifRangeHeaders, ok := this.RawReq.Header["If-Range"] - supportRange := true + var supportRange = true if ok { supportRange = false for _, v := range ifRangeHeaders { if v == this.writer.Header().Get("ETag") || v == this.writer.Header().Get("Last-Modified") { supportRange = true + break } } } // 支持Range - rangeSet := [][]int64{} + var ranges = []rangeutils.Range{} if supportRange { - fileSize := reader.BodySize() - contentRange := this.RawReq.Header.Get("Range") - if len(contentRange) > 0 { + if len(rangeHeader) > 0 { if fileSize == 0 { this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return true } - set, ok := httpRequestParseContentRange(contentRange) + set, ok := httpRequestParseRangeHeader(rangeHeader) if !ok { this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return true } if len(set) > 0 { - rangeSet = set - for _, arr := range rangeSet { - if arr[0] == -1 { - arr[0] = fileSize + arr[1] - arr[1] = fileSize - 1 - - if arr[0] < 0 { - this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) - this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) - return true - } - } - if arr[1] < 0 { - arr[1] = fileSize - 1 - } - if arr[1] >= fileSize { - arr[1] = fileSize - 1 - } - if arr[0] > arr[1] { + ranges = set + for k, r := range ranges { + r2, ok := r.Convert(fileSize) + if !ok { this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return true } + + ranges[k] = r2 } } } } - if len(rangeSet) == 1 { - respHeader.Set("Content-Range", "bytes "+strconv.FormatInt(rangeSet[0][0], 10)+"-"+strconv.FormatInt(rangeSet[0][1], 
10)+"/"+strconv.FormatInt(reader.BodySize(), 10)) - respHeader.Set("Content-Length", strconv.FormatInt(rangeSet[0][1]-rangeSet[0][0]+1, 10)) + if len(ranges) == 1 { + respHeader.Set("Content-Range", ranges[0].ComposeContentRangeHeader(totalSizeString)) + respHeader.Set("Content-Length", strconv.FormatInt(ranges[0].Length(), 10)) this.writer.WriteHeader(http.StatusPartialContent) - err = reader.ReadBodyRange(buf, rangeSet[0][0], rangeSet[0][1], func(n int) (goNext bool, err error) { + err = reader.ReadBodyRange(buf, ranges[0].Start(), ranges[0].End(), func(n int) (goNext bool, err error) { _, err = this.writer.Write(buf[:n]) if err != nil { return false, errWritingToClient @@ -446,15 +470,15 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { } return } - } else if len(rangeSet) > 1 { - boundary := httpRequestGenBoundary() + } else if len(ranges) > 1 { + var boundary = httpRequestGenBoundary() respHeader.Set("Content-Type", "multipart/byteranges; boundary="+boundary) respHeader.Del("Content-Length") contentType := respHeader.Get("Content-Type") this.writer.WriteHeader(http.StatusPartialContent) - for index, set := range rangeSet { + for index, r := range ranges { if index == 0 { _, err = this.writer.WriteString("--" + boundary + "\r\n") } else { @@ -465,7 +489,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { return true } - _, err = this.writer.WriteString("Content-Range: " + "bytes " + strconv.FormatInt(set[0], 10) + "-" + strconv.FormatInt(set[1], 10) + "/" + strconv.FormatInt(reader.BodySize(), 10) + "\r\n") + _, err = this.writer.WriteString("Content-Range: " + r.ComposeContentRangeHeader(totalSizeString) + "\r\n") if err != nil { // 不提示写入客户端错误 return true @@ -479,7 +503,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { } } - err := reader.ReadBodyRange(buf, set[0], set[1], func(n int) (goNext bool, err error) { + err := reader.ReadBodyRange(buf, r.Start(), r.End(), func(n int) (goNext bool, err error) { _, err = this.writer.Write(buf[:n]) if err != nil { return false, errWritingToClient @@ -503,7 +527,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { } } else { // 没有Range var resp = &http.Response{Body: reader} - this.writer.Prepare(resp, reader.BodySize(), reader.Status(), false) + this.writer.Prepare(resp, fileSize, reader.Status(), false) this.writer.WriteHeader(reader.Status()) _, err = io.CopyBuffer(this.writer, resp.Body, buf) @@ -544,3 +568,47 @@ func (this *HTTPRequest) addExpiresHeader(expiresAt int64) { } } } + +// 尝试读取区间缓存 +func (this *HTTPRequest) tryPartialReader(storage caches.StorageInterface, key string, useStale bool, rangeHeader string) caches.Reader { + // 尝试读取Partial cache + if len(rangeHeader) == 0 { + return nil + } + + ranges, ok := httpRequestParseRangeHeader(rangeHeader) + if !ok { + return nil + } + + pReader, pErr := storage.OpenReader(key+cachePartialSuffix, useStale, true) + if pErr != nil { + return nil + } + + partialReader, ok := pReader.(*caches.PartialFileReader) + if !ok { + _ = pReader.Close() + return nil + } + var isOk = false + defer func() { + if !isOk { + _ = pReader.Close() + } + }() + + // 检查范围 + for _, r := range ranges { + r1, ok := r.Convert(partialReader.MaxLength()) + if !ok { + return nil + } + if !partialReader.ContainsRange(r1) { + return nil + } + } + + isOk = true + return pReader +} diff --git a/internal/nodes/http_request_root.go b/internal/nodes/http_request_root.go index f22493b..bba8d1b 100644 --- 
a/internal/nodes/http_request_root.go +++ b/internal/nodes/http_request_root.go @@ -2,10 +2,12 @@ package nodes import ( "fmt" + rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges" "github.com/TeaOSLab/EdgeNode/internal/zero" "github.com/cespare/xxhash" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/logs" + "github.com/iwind/TeaGo/types" "io" "io/fs" "mime" @@ -186,7 +188,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { } // length - fileSize := stat.Size() + var fileSize = stat.Size() // 支持 Last-Modified modifiedTime := stat.ModTime().Format("Mon, 02 Jan 2006 15:04:05 GMT") @@ -231,6 +233,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { for _, v := range ifRangeHeaders { if v == eTag || v == modifiedTime { supportRange = true + break } } if !supportRange { @@ -239,7 +242,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { } // 支持Range - rangeSet := [][]int64{} + var ranges = []rangeutils.Range{} if supportRange { contentRange := this.RawReq.Header.Get("Range") if len(contentRange) > 0 { @@ -249,36 +252,22 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { return true } - set, ok := httpRequestParseContentRange(contentRange) + set, ok := httpRequestParseRangeHeader(contentRange) if !ok { this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return true } if len(set) > 0 { - rangeSet = set - for _, arr := range rangeSet { - if arr[0] == -1 { - arr[0] = fileSize + arr[1] - arr[1] = fileSize - 1 - - if arr[0] < 0 { - this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) - this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) - return true - } - } - if arr[1] > 0 { - arr[1] = fileSize - 1 - } - if arr[1] < 0 { - arr[1] = fileSize - 1 - } - if arr[0] > arr[1] { + ranges = set + for k, r := range ranges { + r2, ok := r.Convert(fileSize) + if !ok { this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return true } + ranges[k] = r2 } } } else { @@ -298,7 +287,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { this.processResponseHeaders(http.StatusOK) // 在Range请求中不能缓存 - if len(rangeSet) > 0 { + if len(ranges) > 0 { this.cacheRef = nil // 不支持缓存 } @@ -311,11 +300,11 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { pool.Put(buf) }() - if len(rangeSet) == 1 { - respHeader.Set("Content-Range", "bytes "+strconv.FormatInt(rangeSet[0][0], 10)+"-"+strconv.FormatInt(rangeSet[0][1], 10)+"/"+strconv.FormatInt(fileSize, 10)) + if len(ranges) == 1 { + respHeader.Set("Content-Range", ranges[0].ComposeContentRangeHeader(types.String(fileSize))) this.writer.WriteHeader(http.StatusPartialContent) - ok, err := httpRequestReadRange(reader, buf, rangeSet[0][0], rangeSet[0][1], func(buf []byte, n int) error { + ok, err := httpRequestReadRange(reader, buf, ranges[0].Start(), ranges[0].End(), func(buf []byte, n int) error { _, err := this.writer.Write(buf[:n]) return err }) @@ -328,13 +317,13 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return true } - } else if len(rangeSet) > 1 { + } else if len(ranges) > 1 { boundary := httpRequestGenBoundary() respHeader.Set("Content-Type", "multipart/byteranges; boundary="+boundary) this.writer.WriteHeader(http.StatusPartialContent) - for index, set := range rangeSet { + for index, r := range ranges { if index == 0 { _, err = this.writer.WriteString("--" + boundary + "\r\n") } 
else { @@ -345,7 +334,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { return true } - _, err = this.writer.WriteString("Content-Range: " + "bytes " + strconv.FormatInt(set[0], 10) + "-" + strconv.FormatInt(set[1], 10) + "/" + strconv.FormatInt(fileSize, 10) + "\r\n") + _, err = this.writer.WriteString("Content-Range: " + r.ComposeContentRangeHeader(types.String(fileSize)) + "\r\n") if err != nil { logs.Error(err) return true @@ -359,7 +348,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { } } - ok, err := httpRequestReadRange(reader, buf, set[0], set[1], func(buf []byte, n int) error { + ok, err := httpRequestReadRange(reader, buf, r.Start(), r.End(), func(buf []byte, n int) error { _, err := this.writer.Write(buf[:n]) return err }) diff --git a/internal/nodes/http_request_utils.go b/internal/nodes/http_request_utils.go index 909b8b5..8b03ae3 100644 --- a/internal/nodes/http_request_utils.go +++ b/internal/nodes/http_request_utils.go @@ -5,15 +5,20 @@ import ( "fmt" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/TeaOSLab/EdgeNode/internal/utils/ranges" + "github.com/iwind/TeaGo/types" "io" "net/http" + "regexp" "strconv" "strings" "sync/atomic" ) +var contentRangeRegexp = regexp.MustCompile(`^bytes (\d+)-(\d+)/`) + // 分解Range -func httpRequestParseContentRange(rangeValue string) (result [][]int64, ok bool) { +func httpRequestParseRangeHeader(rangeValue string) (result []rangeutils.Range, ok bool) { // 参考RFC:https://tools.ietf.org/html/rfc7233 index := strings.Index(rangeValue, "=") if index == -1 { @@ -24,15 +29,15 @@ func httpRequestParseContentRange(rangeValue string) (result [][]int64, ok bool) return } - rangeSetString := rangeValue[index+1:] + var rangeSetString = rangeValue[index+1:] if len(rangeSetString) == 0 { ok = true return } - pieces := strings.Split(rangeSetString, ", ") + var pieces = strings.Split(rangeSetString, ", ") for _, piece := range pieces { - index := strings.Index(piece, "-") + index = strings.Index(piece, "-") if index == -1 { return } @@ -70,7 +75,7 @@ func httpRequestParseContentRange(rangeValue string) (result [][]int64, ok bool) lastInt = -lastInt } - result = append(result, []int64{firstInt, lastInt}) + result = append(result, [2]int64{firstInt, lastInt}) } ok = true @@ -119,6 +124,15 @@ func httpRequestReadRange(reader io.Reader, buf []byte, start int64, end int64, } } +// 分解Content-Range +func httpRequestParseContentRangeHeader(contentRange string) (start int64) { + var matches = contentRangeRegexp.FindStringSubmatch(contentRange) + if len(matches) < 3 { + return -1 + } + return types.Int64(matches[1]) +} + // 生成boundary // 仿照Golang自带的函数(multipart包) func httpRequestGenBoundary() string { @@ -130,6 +144,21 @@ func httpRequestGenBoundary() string { return fmt.Sprintf("%x", buf[:]) } +// 从content-type中读取boundary +func httpRequestParseBoundary(contentType string) string { + var delim = "boundary=" + var boundaryIndex = strings.Index(contentType, delim) + if boundaryIndex < 0 { + return "" + } + var boundary = contentType[boundaryIndex+len(delim):] + semicolonIndex := strings.Index(boundary, ";") + if semicolonIndex >= 0 { + return boundary[:semicolonIndex] + } + return boundary +} + // 判断状态是否为跳转 func httpStatusIsRedirect(statusCode int) bool { return statusCode == http.StatusPermanentRedirect || diff --git a/internal/nodes/http_request_utils_test.go b/internal/nodes/http_request_utils_test.go index a671ae5..ec337ba 100644 --- a/internal/nodes/http_request_utils_test.go +++ 
b/internal/nodes/http_request_utils_test.go @@ -17,52 +17,84 @@ func TestHTTPRequest_httpRequestGenBoundary(t *testing.T) { } } -func TestHTTPRequest_httpRequestParseContentRange(t *testing.T) { - a := assert.NewAssertion(t) +func TestHTTPRequest_httpRequestParseBoundary(t *testing.T) { + var a = assert.NewAssertion(t) + a.IsTrue(httpRequestParseBoundary("multipart/byteranges") == "") + a.IsTrue(httpRequestParseBoundary("multipart/byteranges; boundary=123") == "123") + a.IsTrue(httpRequestParseBoundary("multipart/byteranges; boundary=123; 456") == "123") +} + +func TestHTTPRequest_httpRequestParseRangeHeader(t *testing.T) { + var a = assert.NewAssertion(t) { - _, ok := httpRequestParseContentRange("") + _, ok := httpRequestParseRangeHeader("") a.IsFalse(ok) } { - _, ok := httpRequestParseContentRange("byte=") + _, ok := httpRequestParseRangeHeader("byte=") a.IsFalse(ok) } { - _, ok := httpRequestParseContentRange("byte=") + _, ok := httpRequestParseRangeHeader("byte=") a.IsFalse(ok) } { - set, ok := httpRequestParseContentRange("bytes=") + set, ok := httpRequestParseRangeHeader("bytes=") a.IsTrue(ok) a.IsTrue(len(set) == 0) } { - _, ok := httpRequestParseContentRange("bytes=60-50") + _, ok := httpRequestParseRangeHeader("bytes=60-50") a.IsFalse(ok) } { - set, ok := httpRequestParseContentRange("bytes=0-50") + set, ok := httpRequestParseRangeHeader("bytes=0-50") a.IsTrue(ok) a.IsTrue(len(set) > 0) t.Log(set) } { - set, ok := httpRequestParseContentRange("bytes=0-") + set, ok := httpRequestParseRangeHeader("bytes=0-") + a.IsTrue(ok) + a.IsTrue(len(set) > 0) + if len(set) > 0 { + a.IsTrue(set[0][0] == 0) + } + t.Log(set) + } + { + set, ok := httpRequestParseRangeHeader("bytes=-50") a.IsTrue(ok) a.IsTrue(len(set) > 0) t.Log(set) } { - set, ok := httpRequestParseContentRange("bytes=-50") + set, ok := httpRequestParseRangeHeader("bytes=0-50, 60-100") a.IsTrue(ok) a.IsTrue(len(set) > 0) t.Log(set) } +} + +func TestHTTPRequest_httpRequestParseContentRangeHeader(t *testing.T) { { - set, ok := httpRequestParseContentRange("bytes=0-50, 60-100") - a.IsTrue(ok) - a.IsTrue(len(set) > 0) - t.Log(set) + var c1 = "bytes 0-100/*" + t.Log(httpRequestParseContentRangeHeader(c1)) + } + { + var c1 = "bytes 30-100/*" + t.Log(httpRequestParseContentRangeHeader(c1)) + } + { + var c1 = "bytes1 0-100/*" + t.Log(httpRequestParseContentRangeHeader(c1)) + } +} + +func BenchmarkHTTPRequest_httpRequestParseContentRangeHeader(b *testing.B) { + for i := 0; i < b.N; i++ { + var c1 = "bytes 0-100/*" + httpRequestParseContentRangeHeader(c1) } } diff --git a/internal/nodes/http_writer.go b/internal/nodes/http_writer.go index bf6ae15..2757cb4 100644 --- a/internal/nodes/http_writer.go +++ b/internal/nodes/http_writer.go @@ -28,6 +28,7 @@ import ( "io/ioutil" "net" "net/http" + "net/textproto" "os" "path/filepath" "strings" @@ -45,6 +46,7 @@ const compressionCacheSuffix = "@GOEDGE_" // 缓存相关配置 const cacheMethodSuffix = "@GOEDGE_" +const cachePartialSuffix = "@GOEDGE_partial" func init() { var systemMemory = utils.SystemMemoryGB() / 8 @@ -157,12 +159,21 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) { var addStatusHeader = this.req.web != nil && this.req.web.Cache != nil && this.req.web.Cache.AddStatusHeader // 不支持Range - if this.StatusCode() == http.StatusPartialContent || len(this.Header().Get("Content-Range")) > 0 { - this.req.varMapping["cache.status"] = "BYPASS" - if addStatusHeader { - this.Header().Set("X-Cache", "BYPASS, not supported Content-Range") + if this.isPartial { + if 
!cacheRef.AllowPartialContent { + this.req.varMapping["cache.status"] = "BYPASS" + if addStatusHeader { + this.Header().Set("X-Cache", "BYPASS, not supported partial content") + } + return + } + if this.cacheStorage.Policy().Type != serverconfigs.CachePolicyStorageFile { + this.req.varMapping["cache.status"] = "BYPASS" + if addStatusHeader { + this.Header().Set("X-Cache", "BYPASS, not supported partial content in memory storage") + } + return } - return } // 如果允许 ChunkedEncoding,就无需尺寸的判断,因为此时的 size 为 -1 @@ -183,7 +194,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) { } // 检查状态 - if len(cacheRef.Status) > 0 && !lists.ContainsInt(cacheRef.Status, this.StatusCode()) { + if !this.isPartial && len(cacheRef.Status) > 0 && !lists.ContainsInt(cacheRef.Status, this.StatusCode()) { this.req.varMapping["cache.status"] = "BYPASS" if addStatusHeader { this.Header().Set("X-Cache", "BYPASS, Status: "+types.String(this.StatusCode())) @@ -193,7 +204,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) { // Cache-Control if len(cacheRef.SkipResponseCacheControlValues) > 0 { - var cacheControl = this.Header().Get("Cache-Control") + var cacheControl = this.GetHeader("Cache-Control") if len(cacheControl) > 0 { values := strings.Split(cacheControl, ",") for _, value := range values { @@ -209,7 +220,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) { } // Set-Cookie - if cacheRef.SkipResponseSetCookie && len(this.Header().Get("Set-Cookie")) > 0 { + if cacheRef.SkipResponseSetCookie && len(this.GetHeader("Set-Cookie")) > 0 { this.req.varMapping["cache.status"] = "BYPASS" if addStatusHeader { this.Header().Set("X-Cache", "BYPASS, Set-Cookie") @@ -242,7 +253,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) { } this.cacheStorage = storage - life := cacheRef.LifeSeconds() + var life = cacheRef.LifeSeconds() if life <= 0 { life = 60 @@ -250,7 +261,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) { // 支持源站设置的max-age if this.req.web.Cache != nil && this.req.web.Cache.EnableCacheControlMaxAge { - var cacheControl = this.Header().Get("Cache-Control") + var cacheControl = this.GetHeader("Cache-Control") var pieces = strings.Split(cacheControl, ";") for _, piece := range pieces { var eqIndex = strings.Index(piece, "=") @@ -265,7 +276,10 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) { var expiredAt = utils.UnixTime() + life var cacheKey = this.req.cacheKey - cacheWriter, err := storage.OpenWriter(cacheKey, expiredAt, this.StatusCode(), size, false) + if this.isPartial { + cacheKey += cachePartialSuffix + } + cacheWriter, err := storage.OpenWriter(cacheKey, expiredAt, this.StatusCode(), size, this.isPartial) if err != nil { if !caches.CanIgnoreErr(err) { remotelogs.Error("HTTP_WRITER", "write cache failed: "+err.Error()) @@ -277,6 +291,9 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) { // 写入Header for k, v := range this.Header() { for _, v1 := range v { + if this.isPartial && k == "Content-Type" && strings.Contains(v1, "multipart/byteranges") { + continue + } _, err = cacheWriter.WriteHeader([]byte(k + ":" + v1 + "\n")) if err != nil { remotelogs.Error("HTTP_WRITER", "write cache failed: "+err.Error()) @@ -287,6 +304,101 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) { } } + if this.isPartial { + // content-range + var contentRange = this.GetHeader("Content-Range") + if len(contentRange) > 0 { + var start = 
httpRequestParseContentRangeHeader(contentRange) + if start < 0 { + return + } + var filterReader = readers.NewFilterReaderCloser(resp.Body) + this.cacheIsFinished = true + var hasError = false + filterReader.Add(func(p []byte, err error) error { + if hasError { + return nil + } + + var l = len(p) + if l == 0 { + return nil + } + defer func() { + start += int64(l) + }() + err = cacheWriter.WriteAt(start, p) + if err != nil { + this.cacheIsFinished = false + hasError = true + } + return nil + }) + resp.Body = filterReader + this.rawReader = filterReader + return + } + + // multipart/byteranges + var contentType = this.GetHeader("Content-Type") + if strings.Contains(contentType, "multipart/byteranges") { + partialWriter, ok := cacheWriter.(*caches.PartialFileWriter) + if !ok { + return + } + + var boundary = httpRequestParseBoundary(contentType) + if len(boundary) == 0 { + return + } + + var reader = readers.NewByteRangesReaderCloser(resp.Body, boundary) + var contentTypeWritten = false + + this.cacheIsFinished = true + var hasError = false + var writtenTotal = false + reader.OnPartRead(func(start int64, end int64, total int64, data []byte, header textproto.MIMEHeader) { + if hasError { + return + } + + // 写入total + if !writtenTotal && total > 0 { + partialWriter.SetBodyLength(total) + writtenTotal = true + } + + // 写入Content-Type + if partialWriter.IsNew() && !contentTypeWritten { + var realContentType = header.Get("Content-Type") + if len(realContentType) > 0 { + var h = []byte("Content-Type:" + realContentType + "\n") + err = partialWriter.AppendHeader(h) + if err != nil { + hasError = true + this.cacheIsFinished = false + return + } + } + + contentTypeWritten = true + } + + err := cacheWriter.WriteAt(start, data) + if err != nil { + hasError = true + this.cacheIsFinished = false + } + }) + + resp.Body = reader + this.rawReader = reader + } + + return + } + var cacheReader = readers.NewTeeReaderCloser(resp.Body, this.cacheWriter) resp.Body = cacheReader this.rawReader = cacheReader @@ -306,7 +418,7 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) { return } - var contentType = this.Header().Get("Content-Type") + var contentType = this.GetHeader("Content-Type") if this.req.web != nil && this.req.web.WebP != nil && @@ -324,7 +436,7 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) { return } - var contentEncoding = this.Header().Get("Content-Encoding") + var contentEncoding = this.GetHeader("Content-Encoding") switch contentEncoding { case "gzip", "deflate", "br": reader, err := compressions.NewReader(resp.Body, contentEncoding) @@ -361,7 +473,7 @@ func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) { } var acceptEncodings = this.req.RawReq.Header.Get("Accept-Encoding") - var contentEncoding = this.Header().Get("Content-Encoding") + var contentEncoding = this.GetHeader("Content-Encoding") if this.compressionConfig == nil || !this.compressionConfig.IsOn { if lists.ContainsString([]string{"gzip", "deflate", "br"}, contentEncoding) && !httpAcceptEncoding(acceptEncodings, contentEncoding) { @@ -386,7 +498,7 @@ func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) { } // 尺寸和类型 - var contentType = this.Header().Get("Content-Type") + var contentType = this.GetHeader("Content-Type") if !this.compressionConfig.MatchResponse(contentType, size, filepath.Ext(this.req.Path()), this.req.Format) { return } @@ -504,6 +616,11 @@ func (this *HTTPWriter) Header() http.Header { return this.rawWriter.Header() } +// 
GetHeader 读取Header值 +func (this *HTTPWriter) GetHeader(name string) string { + return this.Header().Get(name) +} + // DeleteHeader 删除Header func (this *HTTPWriter) DeleteHeader(name string) { this.rawWriter.Header().Del(name) @@ -777,18 +894,19 @@ func (this *HTTPWriter) Close() { if this.isOk && this.cacheIsFinished { // 对比缓存前后的Content-Length var method = this.req.Method() - if method != http.MethodHead && this.StatusCode() != http.StatusNoContent { - var contentLengthString = this.Header().Get("Content-Length") + if method != http.MethodHead && this.StatusCode() != http.StatusNoContent && !this.isPartial { + var contentLengthString = this.GetHeader("Content-Length") if len(contentLengthString) > 0 { var contentLength = types.Int64(contentLengthString) if contentLength != this.cacheWriter.BodySize() { this.isOk = false _ = this.cacheWriter.Discard() + this.cacheWriter = nil } } } - if this.isOk { + if this.isOk && this.cacheWriter != nil { err := this.cacheWriter.Close() if err == nil { var expiredAt = this.cacheWriter.ExpiredAt() @@ -863,7 +981,7 @@ func (this *HTTPWriter) calculateStaleLife() int { // 从Header中读取stale-if-error var isDefinedInHeader = false if staleConfig.SupportStaleIfErrorHeader { - var cacheControl = this.Header().Get("Cache-Control") + var cacheControl = this.GetHeader("Cache-Control") var pieces = strings.Split(cacheControl, ",") for _, piece := range pieces { var eqIndex = strings.Index(piece, "=") diff --git a/internal/utils/ranges/range.go b/internal/utils/ranges/range.go new file mode 100644 index 0000000..c083228 --- /dev/null +++ b/internal/utils/ranges/range.go @@ -0,0 +1,53 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package rangeutils + +import "strconv" + +type Range [2]int64 + +func NewRange(start int64, end int64) Range { + return [2]int64{start, end} +} + +func (this Range) Start() int64 { + return this[0] +} + +func (this Range) End() int64 { + return this[1] +} + +func (this Range) Length() int64 { + return this[1] - this[0] + 1 +} + +func (this Range) Convert(total int64) (newRange Range, ok bool) { + if total <= 0 { + return this, false + } + if this[0] < 0 { + this[0] += total + if this[0] < 0 { + return this, false + } + this[1] = total - 1 + } + if this[1] < 0 { + this[1] = total - 1 + } + if this[1] > total-1 { + this[1] = total - 1 + } + if this[0] > this[1] { + return this, false + } + + return this, true +} + +// ComposeContentRangeHeader 组合Content-Range Header +// totalSize 可能是一个数字,也可能是一个星号(*) +func (this Range) ComposeContentRangeHeader(totalSize string) string { + return "bytes " + strconv.FormatInt(this[0], 10) + "-" + strconv.FormatInt(this[1], 10) + "/" + totalSize +} diff --git a/internal/utils/ranges/range_test.go b/internal/utils/ranges/range_test.go new file mode 100644 index 0000000..29df200 --- /dev/null +++ b/internal/utils/ranges/range_test.go @@ -0,0 +1,69 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. 
diff --git a/internal/utils/ranges/range_test.go b/internal/utils/ranges/range_test.go
new file mode 100644
index 0000000..29df200
--- /dev/null
+++ b/internal/utils/ranges/range_test.go
@@ -0,0 +1,69 @@
+// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package rangeutils_test
+
+import (
+    rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
+    "github.com/iwind/TeaGo/assert"
+    "testing"
+)
+
+func TestRange(t *testing.T) {
+    var a = assert.NewAssertion(t)
+
+    var r = rangeutils.NewRange(1, 100)
+    a.IsTrue(r.Start() == 1)
+    a.IsTrue(r.End() == 100)
+    t.Log("start:", r.Start(), "end:", r.End())
+}
+
+func TestRange_Convert(t *testing.T) {
+    var a = assert.NewAssertion(t)
+
+    {
+        var r = rangeutils.NewRange(1, 100)
+        newR, ok := r.Convert(200)
+        a.IsTrue(ok)
+        a.IsTrue(newR.Start() == 1)
+        a.IsTrue(newR.End() == 100)
+    }
+
+    {
+        var r = rangeutils.NewRange(1, 100)
+        newR, ok := r.Convert(50)
+        a.IsTrue(ok)
+        a.IsTrue(newR.Start() == 1)
+        a.IsTrue(newR.End() == 49)
+    }
+
+    {
+        var r = rangeutils.NewRange(1, 100)
+        _, ok := r.Convert(0)
+        a.IsFalse(ok)
+    }
+
+    {
+        var r = rangeutils.NewRange(-30, -1)
+        newR, ok := r.Convert(50)
+        a.IsTrue(ok)
+        a.IsTrue(newR.Start() == 50-30)
+        a.IsTrue(newR.End() == 49)
+    }
+
+    {
+        var r = rangeutils.NewRange(1000, 100)
+        _, ok := r.Convert(0)
+        a.IsFalse(ok)
+    }
+
+    {
+        var r = rangeutils.NewRange(50, 100)
+        _, ok := r.Convert(49)
+        a.IsFalse(ok)
+    }
+}
+
+func TestRange_ComposeContentRangeHeader(t *testing.T) {
+    var r = rangeutils.NewRange(1, 100)
+    t.Log(r.ComposeContentRangeHeader("1000"))
+}
diff --git a/internal/utils/readers/filter_reader.go b/internal/utils/readers/filter_reader.go
deleted file mode 100644
index 4e462e1..0000000
--- a/internal/utils/readers/filter_reader.go
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
-
-package readers
-
-import "io"
-
-type FilterFunc = func(p []byte, err error) error
-
-type FilterReader struct {
-    rawReader io.Reader
-    filters   []FilterFunc
-}
-
-func NewFilterReader(rawReader io.Reader) *FilterReader {
-    return &FilterReader{
-        rawReader: rawReader,
-    }
-}
-
-func (this *FilterReader) Add(filter FilterFunc) {
-    this.filters = append(this.filters, filter)
-}
-
-func (this *FilterReader) Read(p []byte) (n int, err error) {
-    n, err = this.rawReader.Read(p)
-    for _, filter := range this.filters {
-        filterErr := filter(p[:n], err)
-        if filterErr != nil {
-            err = filterErr
-            return
-        }
-    }
-    return
-}
diff --git a/internal/utils/readers/handlers.go b/internal/utils/readers/handlers.go
new file mode 100644
index 0000000..ac12e5e
--- /dev/null
+++ b/internal/utils/readers/handlers.go
@@ -0,0 +1,3 @@
+// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package readers
diff --git a/internal/utils/readers/reader_base.go b/internal/utils/readers/reader_base.go
new file mode 100644
index 0000000..ee7bb38
--- /dev/null
+++ b/internal/utils/readers/reader_base.go
@@ -0,0 +1,6 @@
+// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package readers
+
+type BaseReader struct {
+}
diff --git a/internal/utils/readers/bytes_counter_reader.go b/internal/utils/readers/reader_bytes_counter.go
similarity index 100%
rename from internal/utils/readers/bytes_counter_reader.go
rename to internal/utils/readers/reader_bytes_counter.go
diff --git a/internal/utils/readers/reader_closer_byte_ranges.go b/internal/utils/readers/reader_closer_byte_ranges.go
new file mode 100644
index 0000000..b28ccce
--- /dev/null
+++ b/internal/utils/readers/reader_closer_byte_ranges.go
@@ -0,0 +1,157 @@
+// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package readers
+
+import (
+    "bytes"
+    "github.com/iwind/TeaGo/types"
+    "io"
+    "mime/multipart"
+    "net/textproto"
+    "regexp"
+    "strings"
+)
+
+type OnPartReadHandler func(start int64, end int64, total int64, data []byte, header textproto.MIMEHeader)
+
+var contentRangeRegexp = regexp.MustCompile(`^(\d+)-(\d+)/(\d+|\*)`)
+
+type ByteRangesReaderCloser struct {
+    BaseReader
+
+    rawReader io.ReadCloser
+    boundary  string
+
+    mReader *multipart.Reader
+    part    *multipart.Part
+
+    buf   *bytes.Buffer
+    isEOF bool
+
+    onPartReadHandler OnPartReadHandler
+    rangeStart        int64
+    rangeEnd          int64
+    total             int64
+
+    isStarted bool
+    nl        string
+}
+
+func NewByteRangesReaderCloser(reader io.ReadCloser, boundary string) *ByteRangesReaderCloser {
+    return &ByteRangesReaderCloser{
+        rawReader: reader,
+        mReader:   multipart.NewReader(reader, boundary),
+        boundary:  boundary,
+        buf:       &bytes.Buffer{},
+        nl:        "\r\n",
+    }
+}
+
+func (this *ByteRangesReaderCloser) Read(p []byte) (n int, err error) {
+    n, err = this.read(p)
+    return
+}
+
+func (this *ByteRangesReaderCloser) Close() error {
+    return this.rawReader.Close()
+}
+
+func (this *ByteRangesReaderCloser) OnPartRead(handler OnPartReadHandler) {
+    this.onPartReadHandler = handler
+}
+
+func (this *ByteRangesReaderCloser) read(p []byte) (n int, err error) {
+    // read from buffer
+    n, err = this.buf.Read(p)
+    if !this.isEOF {
+        err = nil
+    }
+    if n > 0 {
+        return
+    }
+    if this.isEOF {
+        return
+    }
+
+    if this.part == nil {
+        part, partErr := this.mReader.NextPart()
+        if partErr != nil {
+            if partErr == io.EOF {
+                this.buf.WriteString(this.nl + "--" + this.boundary + "--" + this.nl)
+                this.isEOF = true
+                n, _ = this.buf.Read(p)
+                return
+            }
+
+            return 0, partErr
+        }
+
+        if !this.isStarted {
+            this.isStarted = true
+            this.buf.WriteString("--" + this.boundary + this.nl)
+        } else {
+            this.buf.WriteString(this.nl + "--" + this.boundary + this.nl)
+        }
+
+        // Headers
+        var hasRange = false
+        for k, v := range part.Header {
+            for _, v1 := range v {
+                this.buf.WriteString(k + ": " + v1 + this.nl)
+
+                // parse range
+                if k == "Content-Range" {
+                    var bytesPrefix = "bytes "
+                    if strings.HasPrefix(v1, bytesPrefix) {
+                        var r = v1[len(bytesPrefix):]
+                        var matches = contentRangeRegexp.FindStringSubmatch(r)
+                        if len(matches) > 2 {
+                            var start = types.Int64(matches[1])
+                            var end = types.Int64(matches[2])
+                            var total int64 = 0
+                            if matches[3] != "*" {
+                                total = types.Int64(matches[3])
+                            }
+                            if start <= end {
+                                hasRange = true
+                                this.rangeStart = start
+                                this.rangeEnd = end
+                                this.total = total
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        if !hasRange {
+            this.rangeStart = -1
+            this.rangeEnd = -1
+        }
+
+        this.buf.WriteString(this.nl)
+        this.part = part
+
+        n, _ = this.buf.Read(p)
+        return
+    }
+
+    n, err = this.part.Read(p)
+
+    if this.onPartReadHandler != nil && n > 0 && this.rangeStart >= 0 && this.rangeEnd >= 0 {
+        this.onPartReadHandler(this.rangeStart, this.rangeEnd, this.total, p[:n], this.part.Header)
+        this.rangeStart += int64(n)
+    }
+
+    if err == io.EOF {
+        this.part = nil
+        err = nil
+
+        // 如果没有读取到内容,则直接跳到下一个Part
+        if n == 0 {
+            return this.read(p)
+        }
+    }
+
+    return
+}
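ByteRangesReaderCloser forwards the multipart/byteranges body unchanged while reporting each chunk's absolute offset via OnPartRead. Below is a hedged sketch of a caller that drains a 206 response and stores the parts by offset; the map stands in for a real partial-cache writer, and extracting the boundary via mime.ParseMediaType is an assumption about the response headers, not code from this patch.

package sketch

import (
    "io"
    "mime"
    "net/http"
    "net/textproto"

    "github.com/TeaOSLab/EdgeNode/internal/utils/readers"
)

func cacheByteRanges(resp *http.Response) (map[int64][]byte, error) {
    _, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
    if err != nil {
        return nil, err
    }

    var parts = map[int64][]byte{} // offset -> bytes (stand-in for a partial cache writer)
    var reader = readers.NewByteRangesReaderCloser(resp.Body, params["boundary"])
    reader.OnPartRead(func(start int64, end int64, total int64, data []byte, header textproto.MIMEHeader) {
        // data arrives in read-sized chunks; start already accounts for
        // previously delivered bytes of the same part.
        parts[start] = append([]byte{}, data...)
    })

    // Draining the reader both forwards the original body and fires the callback.
    _, err = io.Copy(io.Discard, reader)
    if err != nil {
        return nil, err
    }
    return parts, reader.Close()
}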
diff --git a/internal/utils/readers/reader_closer_byte_ranges_test.go b/internal/utils/readers/reader_closer_byte_ranges_test.go
new file mode 100644
index 0000000..894ac21
--- /dev/null
+++ b/internal/utils/readers/reader_closer_byte_ranges_test.go
@@ -0,0 +1,52 @@
+// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package readers_test
+
+import (
+    "bytes"
+    "fmt"
+    "github.com/TeaOSLab/EdgeNode/internal/utils/readers"
+    "io"
+    "io/ioutil"
+    "net/textproto"
+    "testing"
+)
+
+func TestNewByteRangesReader(t *testing.T) {
+    var boundary = "7143cd51d2ee12a1"
+    var dashBoundary = "--" + boundary
+    var b = bytes.NewReader([]byte(dashBoundary + "\r\nContent-Range: bytes 0-4/36\r\nContent-Type: text/plain\r\n\r\n01234\r\n" + dashBoundary + "\r\nContent-Range: bytes 5-9/36\r\nContent-Type: text/plain\r\n\r\n56789\r\n--" + boundary + "\r\nContent-Range: bytes 10-12/36\r\nContent-Type: text/plain\r\n\r\nabc\r\n" + dashBoundary + "--\r\n"))
+
+    var reader = readers.NewByteRangesReaderCloser(ioutil.NopCloser(b), boundary)
+    var p = make([]byte, 16)
+    for {
+        n, err := reader.Read(p)
+        if n > 0 {
+            fmt.Print(string(p[:n]))
+        }
+        if err != nil {
+            if err != io.EOF {
+                t.Fatal(err)
+            }
+            break
+        }
+    }
+}
+
+func TestByteRangesReader_OnPartRead(t *testing.T) {
+    var boundary = "7143cd51d2ee12a1"
+    var dashBoundary = "--" + boundary
+    var b = bytes.NewReader([]byte(dashBoundary + "\r\nContent-Range: bytes 0-4/36\r\nContent-Type: text/plain\r\n\r\n01234\r\n" + dashBoundary + "\r\nContent-Range: bytes 5-9/36\r\nContent-Type: text/plain\r\n\r\n56789\r\n--" + boundary + "\r\nContent-Range: bytes 10-12/36\r\nContent-Type: text/plain\r\n\r\nabc\r\n" + dashBoundary + "--\r\n"))
+
+    var reader = readers.NewByteRangesReaderCloser(ioutil.NopCloser(b), boundary)
+    reader.OnPartRead(func(start int64, end int64, total int64, data []byte, header textproto.MIMEHeader) {
+        t.Log(start, "-", end, "/", total, string(data))
+    })
+    var p = make([]byte, 3)
+    for {
+        _, err := reader.Read(p)
+        if err != nil {
+            break
+        }
+    }
+}
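The tests above assemble the multipart body by hand-concatenating boundary strings. As an alternative, mime/multipart can generate an equivalent fixture, which avoids boundary and CRLF mistakes; the sketch below reuses the same offsets and lengths, and the fixture helper name is made up.

package sketch

import (
    "bytes"
    "mime/multipart"
    "net/textproto"
)

// byteRangesFixture builds a multipart/byteranges body equivalent to the
// hand-written test input above, using mime/multipart.
func byteRangesFixture(boundary string) (*bytes.Buffer, error) {
    var buf = &bytes.Buffer{}
    var w = multipart.NewWriter(buf)
    if err := w.SetBoundary(boundary); err != nil {
        return nil, err
    }
    for _, p := range []struct {
        contentRange string
        body         string
    }{
        {"bytes 0-4/36", "01234"},
        {"bytes 5-9/36", "56789"},
        {"bytes 10-12/36", "abc"},
    } {
        var h = textproto.MIMEHeader{}
        h.Set("Content-Range", p.contentRange)
        h.Set("Content-Type", "text/plain")
        part, err := w.CreatePart(h)
        if err != nil {
            return nil, err
        }
        if _, err := part.Write([]byte(p.body)); err != nil {
            return nil, err
        }
    }
    return buf, w.Close()
}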
+ +package readers + +import "io" + +type FilterFunc = func(p []byte, err error) error + +type FilterReaderCloser struct { + rawReader io.Reader + filters []FilterFunc +} + +func NewFilterReaderCloser(rawReader io.Reader) *FilterReaderCloser { + return &FilterReaderCloser{ + rawReader: rawReader, + } +} + +func (this *FilterReaderCloser) Add(filter FilterFunc) { + this.filters = append(this.filters, filter) +} + +func (this *FilterReaderCloser) Read(p []byte) (n int, err error) { + n, err = this.rawReader.Read(p) + for _, filter := range this.filters { + filterErr := filter(p[:n], err) + if (err == nil || err != io.EOF) && filterErr != nil { + err = filterErr + return + } + } + return +} + +func (this *FilterReaderCloser) Close() error { + closer, ok := this.rawReader.(io.Closer) + if ok { + return closer.Close() + } + return nil +} diff --git a/internal/utils/readers/filter_reader_test.go b/internal/utils/readers/reader_closer_filter_test.go similarity index 90% rename from internal/utils/readers/filter_reader_test.go rename to internal/utils/readers/reader_closer_filter_test.go index 87808b2..eaa41d5 100644 --- a/internal/utils/readers/filter_reader_test.go +++ b/internal/utils/readers/reader_closer_filter_test.go @@ -10,7 +10,7 @@ import ( ) func TestNewFilterReader(t *testing.T) { - var reader = readers.NewFilterReader(bytes.NewBufferString("0123456789")) + var reader = readers.NewFilterReaderCloser(bytes.NewBufferString("0123456789")) reader.Add(func(p []byte, err error) error { t.Log("filter1:", string(p), err) return nil diff --git a/internal/utils/readers/tee_reader_closer.go b/internal/utils/readers/reader_closer_tee.go similarity index 100% rename from internal/utils/readers/tee_reader_closer.go rename to internal/utils/readers/reader_closer_tee.go diff --git a/internal/utils/readers/tee_reader.go b/internal/utils/readers/reader_tee.go similarity index 100% rename from internal/utils/readers/tee_reader.go rename to internal/utils/readers/reader_tee.go diff --git a/internal/utils/writers/bytes_counter_writer.go b/internal/utils/writers/writer_bytes_counter.go similarity index 100% rename from internal/utils/writers/bytes_counter_writer.go rename to internal/utils/writers/writer_bytes_counter.go diff --git a/internal/utils/writers/tee_writer_closer.go b/internal/utils/writers/writer_closer_tee.go similarity index 98% rename from internal/utils/writers/tee_writer_closer.go rename to internal/utils/writers/writer_closer_tee.go index b5d908b..b14092d 100644 --- a/internal/utils/writers/tee_writer_closer.go +++ b/internal/utils/writers/writer_closer_tee.go @@ -2,7 +2,9 @@ package writers -import "io" +import ( + "io" +) type TeeWriterCloser struct { primaryW io.WriteCloser diff --git a/internal/utils/writers/rate_limit_writer.go b/internal/utils/writers/writer_rate_limit.go similarity index 100% rename from internal/utils/writers/rate_limit_writer.go rename to internal/utils/writers/writer_rate_limit.go diff --git a/internal/utils/writers/rate_limit_writer_test.go b/internal/utils/writers/writer_rate_limit_test.go similarity index 100% rename from internal/utils/writers/rate_limit_writer_test.go rename to internal/utils/writers/writer_rate_limit_test.go diff --git a/internal/waf/utils/utils_test.go b/internal/waf/utils/utils_test.go index 4ee83bd..d5a0612 100644 --- a/internal/waf/utils/utils_test.go +++ b/internal/waf/utils/utils_test.go @@ -52,8 +52,9 @@ func TestMatchBytesCache_WithoutCache(t *testing.T) { func BenchmarkMatchStringCache(b *testing.B) { runtime.GOMAXPROCS(1) - data 
diff --git a/internal/waf/utils/utils_test.go b/internal/waf/utils/utils_test.go
index 4ee83bd..d5a0612 100644
--- a/internal/waf/utils/utils_test.go
+++ b/internal/waf/utils/utils_test.go
@@ -52,8 +52,9 @@ func TestMatchBytesCache_WithoutCache(t *testing.T) {
 
 func BenchmarkMatchStringCache(b *testing.B) {
     runtime.GOMAXPROCS(1)
 
-    data := strings.Repeat("HELLO", 512)
-    regex := re.MustCompile(`(?iU)\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\b`)
+    var data = strings.Repeat("HELLO", 512)
+    var regex = re.MustCompile(`(?iU)\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\b`)
+    //b.Log(regex.Keywords())
 
     _ = MatchStringCache(regex, data)
     for i := 0; i < b.N; i++ {