mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 07:40:56 +08:00 
			
		
		
		
	实现基础的206 partial content缓存
This commit is contained in:
		@@ -4,28 +4,37 @@ package caches
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// PartialRanges 内容分区范围定义
 | 
			
		||||
type PartialRanges struct {
 | 
			
		||||
	ranges [][2]int64
 | 
			
		||||
	Ranges [][2]int64 `json:"ranges"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewPartialRanges 获取新对象
 | 
			
		||||
func NewPartialRanges() *PartialRanges {
 | 
			
		||||
	return &PartialRanges{ranges: [][2]int64{}}
 | 
			
		||||
	return &PartialRanges{Ranges: [][2]int64{}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewPartialRangesFromJSON 从JSON中解析范围
 | 
			
		||||
func NewPartialRangesFromJSON(data []byte) (*PartialRanges, error) {
 | 
			
		||||
	var rs = [][2]int64{}
 | 
			
		||||
	var rs = NewPartialRanges()
 | 
			
		||||
	err := json.Unmarshal(data, &rs)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	var r = NewPartialRanges()
 | 
			
		||||
	r.ranges = rs
 | 
			
		||||
	return r, nil
 | 
			
		||||
 | 
			
		||||
	return rs, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPartialRangesFromFile(path string) (*PartialRanges, error) {
 | 
			
		||||
	data, err := ioutil.ReadFile(path)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return NewPartialRangesFromJSON(data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Add 添加新范围
 | 
			
		||||
@@ -36,46 +45,41 @@ func (this *PartialRanges) Add(begin int64, end int64) {
 | 
			
		||||
 | 
			
		||||
	var nr = [2]int64{begin, end}
 | 
			
		||||
 | 
			
		||||
	var count = len(this.ranges)
 | 
			
		||||
	var count = len(this.Ranges)
 | 
			
		||||
	if count == 0 {
 | 
			
		||||
		this.ranges = [][2]int64{nr}
 | 
			
		||||
		this.Ranges = [][2]int64{nr}
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// insert
 | 
			
		||||
	// TODO 将来使用二分法改进
 | 
			
		||||
	var index = -1
 | 
			
		||||
	for i, r := range this.ranges {
 | 
			
		||||
	for i, r := range this.Ranges {
 | 
			
		||||
		if r[0] > begin || (r[0] == begin && r[1] >= end) {
 | 
			
		||||
			index = i
 | 
			
		||||
			this.ranges = append(this.ranges, [2]int64{})
 | 
			
		||||
			copy(this.ranges[index+1:], this.ranges[index:])
 | 
			
		||||
			this.ranges[index] = nr
 | 
			
		||||
			this.Ranges = append(this.Ranges, [2]int64{})
 | 
			
		||||
			copy(this.Ranges[index+1:], this.Ranges[index:])
 | 
			
		||||
			this.Ranges[index] = nr
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if index == -1 {
 | 
			
		||||
		index = count
 | 
			
		||||
		this.ranges = append(this.ranges, nr)
 | 
			
		||||
		this.Ranges = append(this.Ranges, nr)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.merge(index)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Ranges 获取所有范围
 | 
			
		||||
func (this *PartialRanges) Ranges() [][2]int64 {
 | 
			
		||||
	return this.ranges
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Contains 检查是否包含某个范围
 | 
			
		||||
func (this *PartialRanges) Contains(begin int64, end int64) bool {
 | 
			
		||||
	if len(this.ranges) == 0 {
 | 
			
		||||
		return true
 | 
			
		||||
	if len(this.Ranges) == 0 {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO 使用二分法查找改进性能
 | 
			
		||||
	for _, r2 := range this.ranges {
 | 
			
		||||
	for _, r2 := range this.Ranges {
 | 
			
		||||
		if r2[0] <= begin && r2[1] >= end {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
@@ -86,21 +90,46 @@ func (this *PartialRanges) Contains(begin int64, end int64) bool {
 | 
			
		||||
 | 
			
		||||
// AsJSON 转换为JSON
 | 
			
		||||
func (this *PartialRanges) AsJSON() ([]byte, error) {
 | 
			
		||||
	return json.Marshal(this.ranges)
 | 
			
		||||
	return json.Marshal(this)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WriteToFile 写入到文件中
 | 
			
		||||
func (this *PartialRanges) WriteToFile(path string) error {
 | 
			
		||||
	data, err := this.AsJSON()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("convert to json failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	return ioutil.WriteFile(path, data, 0666)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ReadFromFile 从文件中读取
 | 
			
		||||
func (this *PartialRanges) ReadFromFile(path string) (*PartialRanges, error) {
 | 
			
		||||
	data, err := ioutil.ReadFile(path)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return NewPartialRangesFromJSON(data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *PartialRanges) Max() int64 {
 | 
			
		||||
	if len(this.Ranges) > 0 {
 | 
			
		||||
		return this.Ranges[len(this.Ranges)-1][1]
 | 
			
		||||
	}
 | 
			
		||||
	return 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *PartialRanges) merge(index int) {
 | 
			
		||||
	// forward
 | 
			
		||||
	var lastIndex = index
 | 
			
		||||
	for i := index; i >= 1; i-- {
 | 
			
		||||
		var curr = this.ranges[i]
 | 
			
		||||
		var prev = this.ranges[i-1]
 | 
			
		||||
		var curr = this.Ranges[i]
 | 
			
		||||
		var prev = this.Ranges[i-1]
 | 
			
		||||
		var w1 = this.w(curr)
 | 
			
		||||
		var w2 = this.w(prev)
 | 
			
		||||
		if w1+w2 >= this.max(curr[1], prev[1])-this.min(curr[0], prev[0])-1 {
 | 
			
		||||
			prev = [2]int64{this.min(curr[0], prev[0]), this.max(curr[1], prev[1])}
 | 
			
		||||
			this.ranges[i-1] = prev
 | 
			
		||||
			this.ranges = append(this.ranges[:i], this.ranges[i+1:]...)
 | 
			
		||||
			this.Ranges[i-1] = prev
 | 
			
		||||
			this.Ranges = append(this.Ranges[:i], this.Ranges[i+1:]...)
 | 
			
		||||
			lastIndex = i - 1
 | 
			
		||||
		} else {
 | 
			
		||||
			break
 | 
			
		||||
@@ -109,15 +138,15 @@ func (this *PartialRanges) merge(index int) {
 | 
			
		||||
 | 
			
		||||
	// backward
 | 
			
		||||
	index = lastIndex
 | 
			
		||||
	for index < len(this.ranges)-1 {
 | 
			
		||||
		var curr = this.ranges[index]
 | 
			
		||||
		var next = this.ranges[index+1]
 | 
			
		||||
	for index < len(this.Ranges)-1 {
 | 
			
		||||
		var curr = this.Ranges[index]
 | 
			
		||||
		var next = this.Ranges[index+1]
 | 
			
		||||
		var w1 = this.w(curr)
 | 
			
		||||
		var w2 = this.w(next)
 | 
			
		||||
		if w1+w2 >= this.max(curr[1], next[1])-this.min(curr[0], next[0])-1 {
 | 
			
		||||
			curr = [2]int64{this.min(curr[0], next[0]), this.max(curr[1], next[1])}
 | 
			
		||||
			this.ranges = append(this.ranges[:index], this.ranges[index+1:]...)
 | 
			
		||||
			this.ranges[index] = curr
 | 
			
		||||
			this.Ranges = append(this.Ranges[:index], this.Ranges[index+1:]...)
 | 
			
		||||
			this.Ranges[index] = curr
 | 
			
		||||
		} else {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -21,7 +21,8 @@ func TestNewPartialRanges(t *testing.T) {
 | 
			
		||||
	r.Add(200, 1000)
 | 
			
		||||
	r.Add(200, 10040)
 | 
			
		||||
 | 
			
		||||
	logs.PrintAsJSON(r.Ranges())
 | 
			
		||||
	logs.PrintAsJSON(r.Ranges, t)
 | 
			
		||||
	t.Log("max:", r.Max())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNewPartialRanges1(t *testing.T) {
 | 
			
		||||
@@ -35,7 +36,7 @@ func TestNewPartialRanges1(t *testing.T) {
 | 
			
		||||
	r.Add(200, 300)
 | 
			
		||||
	r.Add(1, 1000)
 | 
			
		||||
 | 
			
		||||
	var rs = r.Ranges()
 | 
			
		||||
	var rs = r.Ranges
 | 
			
		||||
	logs.PrintAsJSON(rs, t)
 | 
			
		||||
	a.IsTrue(len(rs) == 1)
 | 
			
		||||
	if len(rs) == 1 {
 | 
			
		||||
@@ -56,7 +57,7 @@ func TestNewPartialRanges2(t *testing.T) {
 | 
			
		||||
	r.Add(303, 304)
 | 
			
		||||
	r.Add(250, 400)
 | 
			
		||||
 | 
			
		||||
	var rs = r.Ranges()
 | 
			
		||||
	var rs = r.Ranges
 | 
			
		||||
	logs.PrintAsJSON(rs, t)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -68,7 +69,7 @@ func TestNewPartialRanges3(t *testing.T) {
 | 
			
		||||
	r.Add(200, 300)
 | 
			
		||||
	r.Add(250, 400)
 | 
			
		||||
 | 
			
		||||
	var rs = r.Ranges()
 | 
			
		||||
	var rs = r.Ranges
 | 
			
		||||
	logs.PrintAsJSON(rs, t)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -83,7 +84,7 @@ func TestNewPartialRanges4(t *testing.T) {
 | 
			
		||||
	r.Add(410, 415)
 | 
			
		||||
	r.Add(400, 409)
 | 
			
		||||
 | 
			
		||||
	var rs = r.Ranges()
 | 
			
		||||
	var rs = r.Ranges
 | 
			
		||||
	logs.PrintAsJSON(rs, t)
 | 
			
		||||
	t.Log(r.Contains(400, 416))
 | 
			
		||||
}
 | 
			
		||||
@@ -93,7 +94,7 @@ func TestNewPartialRanges5(t *testing.T) {
 | 
			
		||||
	for j := 0; j < 1000; j++ {
 | 
			
		||||
		r.Add(int64(j), int64(j+100))
 | 
			
		||||
	}
 | 
			
		||||
	logs.PrintAsJSON(r.Ranges(), t)
 | 
			
		||||
	logs.PrintAsJSON(r.Ranges, t)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNewPartialRanges_AsJSON(t *testing.T) {
 | 
			
		||||
@@ -111,7 +112,7 @@ func TestNewPartialRanges_AsJSON(t *testing.T) {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	t.Log(r2.Ranges())
 | 
			
		||||
	t.Log(r2.Ranges)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkNewPartialRanges(b *testing.B) {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,7 @@
 | 
			
		||||
package caches
 | 
			
		||||
 | 
			
		||||
import "github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
 | 
			
		||||
 | 
			
		||||
type ReaderFunc func(n int) (goNext bool, err error)
 | 
			
		||||
 | 
			
		||||
type Reader interface {
 | 
			
		||||
@@ -36,6 +38,9 @@ type Reader interface {
 | 
			
		||||
	// BodySize Body Size
 | 
			
		||||
	BodySize() int64
 | 
			
		||||
 | 
			
		||||
	// ContainsRange 是否包含某个区间内容
 | 
			
		||||
	ContainsRange(r rangeutils.Range) bool
 | 
			
		||||
 | 
			
		||||
	// Close 关闭
 | 
			
		||||
	Close() error
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package caches
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"errors"
 | 
			
		||||
	rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	"io"
 | 
			
		||||
	"os"
 | 
			
		||||
@@ -332,6 +333,11 @@ func (this *FileReader) ReadBodyRange(buf []byte, start int64, end int64, callba
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ContainsRange 是否包含某些区间内容
 | 
			
		||||
func (this *FileReader) ContainsRange(r rangeutils.Range) bool {
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *FileReader) Close() error {
 | 
			
		||||
	if this.openFileCache != nil {
 | 
			
		||||
		if this.isClosed {
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package caches
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
 | 
			
		||||
	"io"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -197,6 +198,11 @@ func (this *MemoryReader) ReadBodyRange(buf []byte, start int64, end int64, call
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ContainsRange 是否包含某些区间内容
 | 
			
		||||
func (this *MemoryReader) ContainsRange(r rangeutils.Range) bool {
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *MemoryReader) Close() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										142
									
								
								internal/caches/reader_partial_file.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								internal/caches/reader_partial_file.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,142 @@
 | 
			
		||||
package caches
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"errors"
 | 
			
		||||
	rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	"io"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type PartialFileReader struct {
 | 
			
		||||
	*FileReader
 | 
			
		||||
 | 
			
		||||
	ranges    *PartialRanges
 | 
			
		||||
	rangePath string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPartialFileReader(fp *os.File) *PartialFileReader {
 | 
			
		||||
	// range path
 | 
			
		||||
	var path = fp.Name()
 | 
			
		||||
	var dotIndex = strings.LastIndex(path, ".")
 | 
			
		||||
	var rangePath = ""
 | 
			
		||||
	if dotIndex < 0 {
 | 
			
		||||
		rangePath = path + "@ranges.cache"
 | 
			
		||||
	} else {
 | 
			
		||||
		rangePath = path[:dotIndex] + "@ranges" + path[dotIndex:]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &PartialFileReader{
 | 
			
		||||
		FileReader: NewFileReader(fp),
 | 
			
		||||
		rangePath:  rangePath,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *PartialFileReader) Init() error {
 | 
			
		||||
	return this.InitAutoDiscard(true)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *PartialFileReader) InitAutoDiscard(autoDiscard bool) error {
 | 
			
		||||
	if this.openFile != nil {
 | 
			
		||||
		this.meta = this.openFile.meta
 | 
			
		||||
		this.header = this.openFile.header
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	isOk := false
 | 
			
		||||
 | 
			
		||||
	if autoDiscard {
 | 
			
		||||
		defer func() {
 | 
			
		||||
			if !isOk {
 | 
			
		||||
				_ = this.discard()
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 读取Range
 | 
			
		||||
	ranges, err := NewPartialRangesFromFile(this.rangePath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("read ranges failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	this.ranges = ranges
 | 
			
		||||
 | 
			
		||||
	var buf = this.meta
 | 
			
		||||
	if len(buf) == 0 {
 | 
			
		||||
		buf = make([]byte, SizeMeta)
 | 
			
		||||
		ok, err := this.readToBuff(this.fp, buf)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		if !ok {
 | 
			
		||||
			return ErrNotFound
 | 
			
		||||
		}
 | 
			
		||||
		this.meta = buf
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.expiresAt = int64(binary.BigEndian.Uint32(buf[:SizeExpiresAt]))
 | 
			
		||||
 | 
			
		||||
	status := types.Int(string(buf[SizeExpiresAt : SizeExpiresAt+SizeStatus]))
 | 
			
		||||
	if status < 100 || status > 999 {
 | 
			
		||||
		return errors.New("invalid status")
 | 
			
		||||
	}
 | 
			
		||||
	this.status = status
 | 
			
		||||
 | 
			
		||||
	// URL
 | 
			
		||||
	urlLength := binary.BigEndian.Uint32(buf[SizeExpiresAt+SizeStatus : SizeExpiresAt+SizeStatus+SizeURLLength])
 | 
			
		||||
 | 
			
		||||
	// header
 | 
			
		||||
	headerSize := int(binary.BigEndian.Uint32(buf[SizeExpiresAt+SizeStatus+SizeURLLength : SizeExpiresAt+SizeStatus+SizeURLLength+SizeHeaderLength]))
 | 
			
		||||
	if headerSize == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	this.headerSize = headerSize
 | 
			
		||||
	this.headerOffset = int64(SizeMeta) + int64(urlLength)
 | 
			
		||||
 | 
			
		||||
	// body
 | 
			
		||||
	this.bodyOffset = this.headerOffset + int64(headerSize)
 | 
			
		||||
	bodySize := int(binary.BigEndian.Uint64(buf[SizeExpiresAt+SizeStatus+SizeURLLength+SizeHeaderLength : SizeExpiresAt+SizeStatus+SizeURLLength+SizeHeaderLength+SizeBodyLength]))
 | 
			
		||||
	if bodySize == 0 {
 | 
			
		||||
		isOk = true
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	this.bodySize = int64(bodySize)
 | 
			
		||||
 | 
			
		||||
	// read header
 | 
			
		||||
	if this.openFileCache != nil && len(this.header) == 0 {
 | 
			
		||||
		if headerSize > 0 && headerSize <= 512 {
 | 
			
		||||
			this.header = make([]byte, headerSize)
 | 
			
		||||
			_, err := this.fp.Seek(this.headerOffset, io.SeekStart)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			_, err = this.readToBuff(this.fp, this.header)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	isOk = true
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ContainsRange 是否包含某些区间内容
 | 
			
		||||
// 这里的 r 是已经经过格式化的
 | 
			
		||||
func (this *PartialFileReader) ContainsRange(r rangeutils.Range) bool {
 | 
			
		||||
	return this.ranges.Contains(r.Start(), r.End())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MaxLength 获取区间最大长度
 | 
			
		||||
func (this *PartialFileReader) MaxLength() int64 {
 | 
			
		||||
	if this.bodySize > 0 {
 | 
			
		||||
		return this.bodySize
 | 
			
		||||
	}
 | 
			
		||||
	return this.ranges.Max() + 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *PartialFileReader) discard() error {
 | 
			
		||||
	_ = os.Remove(this.rangePath)
 | 
			
		||||
	return this.FileReader.discard()
 | 
			
		||||
}
 | 
			
		||||
@@ -216,19 +216,24 @@ func (this *FileStorage) Init() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *FileStorage) OpenReader(key string, useStale bool) (Reader, error) {
 | 
			
		||||
	return this.openReader(key, true, useStale)
 | 
			
		||||
func (this *FileStorage) OpenReader(key string, useStale bool, isPartial bool) (Reader, error) {
 | 
			
		||||
	return this.openReader(key, true, useStale, isPartial)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *FileStorage) openReader(key string, allowMemory bool, useStale bool) (Reader, error) {
 | 
			
		||||
func (this *FileStorage) openReader(key string, allowMemory bool, useStale bool, isPartial bool) (Reader, error) {
 | 
			
		||||
	// 使用陈旧缓存的时候,我们认为是短暂的,只需要从文件里检查即可
 | 
			
		||||
	if useStale {
 | 
			
		||||
		allowMemory = false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 区间缓存只存在文件中
 | 
			
		||||
	if isPartial {
 | 
			
		||||
		allowMemory = false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 先尝试内存缓存
 | 
			
		||||
	if allowMemory && this.memoryStorage != nil {
 | 
			
		||||
		reader, err := this.memoryStorage.OpenReader(key, useStale)
 | 
			
		||||
		reader, err := this.memoryStorage.OpenReader(key, useStale, isPartial)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			return reader, err
 | 
			
		||||
		}
 | 
			
		||||
@@ -273,9 +278,18 @@ func (this *FileStorage) openReader(key string, allowMemory bool, useStale bool)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	var reader = NewFileReader(fp)
 | 
			
		||||
	reader.openFile = openFile
 | 
			
		||||
	reader.openFileCache = this.openFileCache
 | 
			
		||||
	var reader Reader
 | 
			
		||||
	if isPartial {
 | 
			
		||||
		var partialFileReader = NewPartialFileReader(fp)
 | 
			
		||||
		partialFileReader.openFile = openFile
 | 
			
		||||
		partialFileReader.openFileCache = this.openFileCache
 | 
			
		||||
		reader = partialFileReader
 | 
			
		||||
	} else {
 | 
			
		||||
		var fileReader = NewFileReader(fp)
 | 
			
		||||
		fileReader.openFile = openFile
 | 
			
		||||
		fileReader.openFileCache = this.openFileCache
 | 
			
		||||
		reader = fileReader
 | 
			
		||||
	}
 | 
			
		||||
	err = reader.Init()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -284,40 +298,7 @@ func (this *FileStorage) openReader(key string, allowMemory bool, useStale bool)
 | 
			
		||||
	// 增加点击量
 | 
			
		||||
	// 1/1000采样
 | 
			
		||||
	if allowMemory {
 | 
			
		||||
		var rate = this.policy.PersistenceHitSampleRate
 | 
			
		||||
		if rate <= 0 {
 | 
			
		||||
			rate = 1000
 | 
			
		||||
		}
 | 
			
		||||
		if this.lastHotSize == 0 {
 | 
			
		||||
			// 自动降低采样率来增加热点数据的缓存几率
 | 
			
		||||
			rate = rate / 10
 | 
			
		||||
		}
 | 
			
		||||
		if rands.Int(0, rate) == 0 {
 | 
			
		||||
			var hitErr = this.list.IncreaseHit(hash)
 | 
			
		||||
			if hitErr != nil {
 | 
			
		||||
				// 此错误可以忽略
 | 
			
		||||
				remotelogs.Error("CACHE", "increase hit failed: "+hitErr.Error())
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 增加到热点
 | 
			
		||||
			// 这里不收录缓存尺寸过大的文件
 | 
			
		||||
			if this.memoryStorage != nil && reader.BodySize() > 0 && reader.BodySize() < 128*1024*1024 {
 | 
			
		||||
				this.hotMapLocker.Lock()
 | 
			
		||||
				hotItem, ok := this.hotMap[key]
 | 
			
		||||
				if ok {
 | 
			
		||||
					hotItem.Hits++
 | 
			
		||||
					hotItem.ExpiresAt = reader.expiresAt
 | 
			
		||||
				} else if len(this.hotMap) < HotItemSize { // 控制数量
 | 
			
		||||
					this.hotMap[key] = &HotItem{
 | 
			
		||||
						Key:       key,
 | 
			
		||||
						ExpiresAt: reader.ExpiresAt(),
 | 
			
		||||
						Status:    reader.Status(),
 | 
			
		||||
						Hits:      1,
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				this.hotMapLocker.Unlock()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		this.increaseHit(key, hash, reader)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	isOk = true
 | 
			
		||||
@@ -380,7 +361,8 @@ func (this *FileStorage) OpenWriter(key string, expiredAt int64, status int, siz
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查缓存是否已经生成
 | 
			
		||||
	var cachePath = dir + "/" + hash + ".cache"
 | 
			
		||||
	var cachePathName = dir + "/" + hash
 | 
			
		||||
	var cachePath = cachePathName + ".cache"
 | 
			
		||||
	stat, err := os.Stat(cachePath)
 | 
			
		||||
	if err == nil && time.Now().Sub(stat.ModTime()) <= 1*time.Second {
 | 
			
		||||
		// 防止并发连续写入
 | 
			
		||||
@@ -388,7 +370,7 @@ func (this *FileStorage) OpenWriter(key string, expiredAt int64, status int, siz
 | 
			
		||||
	}
 | 
			
		||||
	var tmpPath = cachePath + ".tmp"
 | 
			
		||||
	if isPartial {
 | 
			
		||||
		tmpPath = cachePath
 | 
			
		||||
		tmpPath = cachePathName + ".cache"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 先删除
 | 
			
		||||
@@ -502,7 +484,12 @@ func (this *FileStorage) OpenWriter(key string, expiredAt int64, status int, siz
 | 
			
		||||
 | 
			
		||||
	isOk = true
 | 
			
		||||
	if isPartial {
 | 
			
		||||
		return NewPartialFileWriter(writer, key, expiredAt, isNewCreated, isPartial, partialBodyOffset, func() {
 | 
			
		||||
		ranges, err := NewPartialRangesFromFile(cachePathName + "@ranges.cache")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			ranges = NewPartialRanges()
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return NewPartialFileWriter(writer, key, expiredAt, isNewCreated, isPartial, partialBodyOffset, ranges, func() {
 | 
			
		||||
			sharedWritingFileKeyLocker.Lock()
 | 
			
		||||
			delete(sharedWritingFileKeyMap, key)
 | 
			
		||||
			sharedWritingFileKeyLocker.Unlock()
 | 
			
		||||
@@ -923,7 +910,7 @@ func (this *FileStorage) hotLoop() {
 | 
			
		||||
		var buf = utils.BytePool16k.Get()
 | 
			
		||||
		defer utils.BytePool16k.Put(buf)
 | 
			
		||||
		for _, item := range result[:size] {
 | 
			
		||||
			reader, err := this.openReader(item.Key, false, false)
 | 
			
		||||
			reader, err := this.openReader(item.Key, false, false, false)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
@@ -1025,3 +1012,41 @@ func (this *FileStorage) cleanDeletedDirs(dir string) error {
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 增加某个Key的点击量
 | 
			
		||||
func (this *FileStorage) increaseHit(key string, hash string, reader Reader) {
 | 
			
		||||
	var rate = this.policy.PersistenceHitSampleRate
 | 
			
		||||
	if rate <= 0 {
 | 
			
		||||
		rate = 1000
 | 
			
		||||
	}
 | 
			
		||||
	if this.lastHotSize == 0 {
 | 
			
		||||
		// 自动降低采样率来增加热点数据的缓存几率
 | 
			
		||||
		rate = rate / 10
 | 
			
		||||
	}
 | 
			
		||||
	if rands.Int(0, rate) == 0 {
 | 
			
		||||
		var hitErr = this.list.IncreaseHit(hash)
 | 
			
		||||
		if hitErr != nil {
 | 
			
		||||
			// 此错误可以忽略
 | 
			
		||||
			remotelogs.Error("CACHE", "increase hit failed: "+hitErr.Error())
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 增加到热点
 | 
			
		||||
		// 这里不收录缓存尺寸过大的文件
 | 
			
		||||
		if this.memoryStorage != nil && reader.BodySize() > 0 && reader.BodySize() < 128*1024*1024 {
 | 
			
		||||
			this.hotMapLocker.Lock()
 | 
			
		||||
			hotItem, ok := this.hotMap[key]
 | 
			
		||||
			if ok {
 | 
			
		||||
				hotItem.Hits++
 | 
			
		||||
				hotItem.ExpiresAt = reader.ExpiresAt()
 | 
			
		||||
			} else if len(this.hotMap) < HotItemSize { // 控制数量
 | 
			
		||||
				this.hotMap[key] = &HotItem{
 | 
			
		||||
					Key:       key,
 | 
			
		||||
					ExpiresAt: reader.ExpiresAt(),
 | 
			
		||||
					Status:    reader.Status(),
 | 
			
		||||
					Hits:      1,
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			this.hotMapLocker.Unlock()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -110,7 +110,7 @@ func TestFileStorage_OpenWriter_Partial(t *testing.T) {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = writer.WriteAt([]byte("Hello, World"), 0)
 | 
			
		||||
	err = writer.WriteAt(0, []byte("Hello, World"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -311,7 +311,7 @@ func TestFileStorage_Read(t *testing.T) {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	reader, err := storage.OpenReader("my-key", false)
 | 
			
		||||
	reader, err := storage.OpenReader("my-key", false, false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -347,7 +347,7 @@ func TestFileStorage_Read_HTTP_Response(t *testing.T) {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	reader, err := storage.OpenReader("my-http-response", false)
 | 
			
		||||
	reader, err := storage.OpenReader("my-http-response", false, false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -401,7 +401,7 @@ func TestFileStorage_Read_NotFound(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	buf := make([]byte, 6)
 | 
			
		||||
	reader, err := storage.OpenReader("my-key-10000", false)
 | 
			
		||||
	reader, err := storage.OpenReader("my-key-10000", false, false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == ErrNotFound {
 | 
			
		||||
			t.Log("cache not fund")
 | 
			
		||||
@@ -543,7 +543,7 @@ func BenchmarkFileStorage_Read(b *testing.B) {
 | 
			
		||||
		b.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		reader, err := storage.OpenReader("my-key", false)
 | 
			
		||||
		reader, err := storage.OpenReader("my-key", false, false)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			b.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -10,7 +10,7 @@ type StorageInterface interface {
 | 
			
		||||
	Init() error
 | 
			
		||||
 | 
			
		||||
	// OpenReader 读取缓存
 | 
			
		||||
	OpenReader(key string, useStale bool) (reader Reader, err error)
 | 
			
		||||
	OpenReader(key string, useStale bool, isPartial bool) (reader Reader, err error)
 | 
			
		||||
 | 
			
		||||
	// OpenWriter 打开缓存写入器等待写入
 | 
			
		||||
	OpenWriter(key string, expiredAt int64, status int, size int64, isPartial bool) (Writer, error)
 | 
			
		||||
 
 | 
			
		||||
@@ -105,7 +105,7 @@ func (this *MemoryStorage) Init() error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OpenReader 读取缓存
 | 
			
		||||
func (this *MemoryStorage) OpenReader(key string, useStale bool) (Reader, error) {
 | 
			
		||||
func (this *MemoryStorage) OpenReader(key string, useStale bool, isPartial bool) (Reader, error) {
 | 
			
		||||
	hash := this.hash(key)
 | 
			
		||||
 | 
			
		||||
	this.locker.RLock()
 | 
			
		||||
 
 | 
			
		||||
@@ -25,7 +25,7 @@ func TestMemoryStorage_OpenWriter(t *testing.T) {
 | 
			
		||||
	t.Log(storage.valuesMap)
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		reader, err := storage.OpenReader("abc", false)
 | 
			
		||||
		reader, err := storage.OpenReader("abc", false, false)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if err == ErrNotFound {
 | 
			
		||||
				t.Log("not found: abc")
 | 
			
		||||
@@ -52,7 +52,7 @@ func TestMemoryStorage_OpenWriter(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		_, err := storage.OpenReader("abc 2", false)
 | 
			
		||||
		_, err := storage.OpenReader("abc 2", false, false)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if err == ErrNotFound {
 | 
			
		||||
				t.Log("not found: abc2")
 | 
			
		||||
@@ -68,7 +68,7 @@ func TestMemoryStorage_OpenWriter(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
	_, _ = writer.Write([]byte("Hello123"))
 | 
			
		||||
	{
 | 
			
		||||
		reader, err := storage.OpenReader("abc", false)
 | 
			
		||||
		reader, err := storage.OpenReader("abc", false, false)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if err == ErrNotFound {
 | 
			
		||||
				t.Log("not found: abc")
 | 
			
		||||
@@ -97,7 +97,7 @@ func TestMemoryStorage_OpenReaderLock(t *testing.T) {
 | 
			
		||||
			IsDone: true,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	_, _ = storage.OpenReader("test", false)
 | 
			
		||||
	_, _ = storage.OpenReader("test", false, false)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMemoryStorage_Delete(t *testing.T) {
 | 
			
		||||
 
 | 
			
		||||
@@ -9,7 +9,7 @@ type Writer interface {
 | 
			
		||||
	Write(data []byte) (n int, err error)
 | 
			
		||||
 | 
			
		||||
	// WriteAt 在指定位置写入数据
 | 
			
		||||
	WriteAt(data []byte, offset int64) error
 | 
			
		||||
	WriteAt(offset int64, data []byte) error
 | 
			
		||||
 | 
			
		||||
	// HeaderSize 写入的Header数据大小
 | 
			
		||||
	HeaderSize() int64
 | 
			
		||||
 
 | 
			
		||||
@@ -67,7 +67,7 @@ func (this *FileWriter) Write(data []byte) (n int, err error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WriteAt 在指定位置写入数据
 | 
			
		||||
func (this *FileWriter) WriteAt(data []byte, offset int64) error {
 | 
			
		||||
func (this *FileWriter) WriteAt(offset int64, data []byte) error {
 | 
			
		||||
	_ = data
 | 
			
		||||
	_ = offset
 | 
			
		||||
	return errors.New("not supported")
 | 
			
		||||
 
 | 
			
		||||
@@ -57,7 +57,7 @@ func (this *MemoryWriter) Write(data []byte) (n int, err error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WriteAt 在指定位置写入数据
 | 
			
		||||
func (this *MemoryWriter) WriteAt(b []byte, offset int64) error {
 | 
			
		||||
func (this *MemoryWriter) WriteAt(offset int64, b []byte) error {
 | 
			
		||||
	_ = b
 | 
			
		||||
	_ = offset
 | 
			
		||||
	return errors.New("not supported")
 | 
			
		||||
 
 | 
			
		||||
@@ -23,9 +23,22 @@ type PartialFileWriter struct {
 | 
			
		||||
	isNew      bool
 | 
			
		||||
	isPartial  bool
 | 
			
		||||
	bodyOffset int64
 | 
			
		||||
 | 
			
		||||
	ranges    *PartialRanges
 | 
			
		||||
	rangePath string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPartialFileWriter(rawWriter *os.File, key string, expiredAt int64, isNew bool, isPartial bool, bodyOffset int64, endFunc func()) *PartialFileWriter {
 | 
			
		||||
func NewPartialFileWriter(rawWriter *os.File, key string, expiredAt int64, isNew bool, isPartial bool, bodyOffset int64, ranges *PartialRanges, endFunc func()) *PartialFileWriter {
 | 
			
		||||
	var path = rawWriter.Name()
 | 
			
		||||
	// ranges路径
 | 
			
		||||
	var dotIndex = strings.LastIndex(path, ".")
 | 
			
		||||
	var rangePath = ""
 | 
			
		||||
	if dotIndex < 0 {
 | 
			
		||||
		rangePath = path + "@ranges.cache"
 | 
			
		||||
	} else {
 | 
			
		||||
		rangePath = path[:dotIndex] + "@ranges" + path[dotIndex:]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &PartialFileWriter{
 | 
			
		||||
		key:        key,
 | 
			
		||||
		rawWriter:  rawWriter,
 | 
			
		||||
@@ -34,6 +47,8 @@ func NewPartialFileWriter(rawWriter *os.File, key string, expiredAt int64, isNew
 | 
			
		||||
		isNew:      isNew,
 | 
			
		||||
		isPartial:  isPartial,
 | 
			
		||||
		bodyOffset: bodyOffset,
 | 
			
		||||
		ranges:     ranges,
 | 
			
		||||
		rangePath:  rangePath,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -50,6 +65,21 @@ func (this *PartialFileWriter) WriteHeader(data []byte) (n int, err error) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *PartialFileWriter) AppendHeader(data []byte) error {
 | 
			
		||||
	_, err := this.rawWriter.Write(data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		_ = this.Discard()
 | 
			
		||||
	} else {
 | 
			
		||||
		var c = len(data)
 | 
			
		||||
		this.headerSize += int64(c)
 | 
			
		||||
		err = this.WriteHeaderLength(int(this.headerSize))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			_ = this.Discard()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WriteHeaderLength 写入Header长度数据
 | 
			
		||||
func (this *PartialFileWriter) WriteHeaderLength(headerLength int) error {
 | 
			
		||||
	bytes4 := make([]byte, 4)
 | 
			
		||||
@@ -78,12 +108,34 @@ func (this *PartialFileWriter) Write(data []byte) (n int, err error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WriteAt 在指定位置写入数据
 | 
			
		||||
func (this *PartialFileWriter) WriteAt(data []byte, offset int64) error {
 | 
			
		||||
func (this *PartialFileWriter) WriteAt(offset int64, data []byte) error {
 | 
			
		||||
	var c = int64(len(data))
 | 
			
		||||
	if c == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	var end = offset + c - 1
 | 
			
		||||
 | 
			
		||||
	// 是否已包含在内
 | 
			
		||||
	if this.ranges.Contains(offset, end) {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if this.bodyOffset == 0 {
 | 
			
		||||
		this.bodyOffset = SizeMeta + int64(len(this.key)) + this.headerSize
 | 
			
		||||
	}
 | 
			
		||||
	_, err := this.rawWriter.WriteAt(data, this.bodyOffset+offset)
 | 
			
		||||
	return err
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.ranges.Add(offset, end)
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetBodyLength 设置内容总长度
 | 
			
		||||
func (this *PartialFileWriter) SetBodyLength(bodyLength int64) {
 | 
			
		||||
	this.bodySize = bodyLength
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WriteBodyLength 写入Body长度数据
 | 
			
		||||
@@ -109,31 +161,30 @@ func (this *PartialFileWriter) Close() error {
 | 
			
		||||
		this.endFunc()
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	var path = this.rawWriter.Name()
 | 
			
		||||
	err := this.ranges.WriteToFile(this.rangePath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 关闭当前writer
 | 
			
		||||
	if this.isNew {
 | 
			
		||||
		err := this.WriteHeaderLength(types.Int(this.headerSize))
 | 
			
		||||
		err = this.WriteHeaderLength(types.Int(this.headerSize))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			_ = this.rawWriter.Close()
 | 
			
		||||
			_ = os.Remove(path)
 | 
			
		||||
			this.remove()
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		err = this.WriteBodyLength(this.bodySize)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			_ = this.rawWriter.Close()
 | 
			
		||||
			_ = os.Remove(path)
 | 
			
		||||
			this.remove()
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := this.rawWriter.Close()
 | 
			
		||||
	err = this.rawWriter.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		_ = os.Remove(path)
 | 
			
		||||
	} else if !this.isPartial {
 | 
			
		||||
		err = os.Rename(path, strings.Replace(path, ".tmp", "", 1))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			_ = os.Remove(path)
 | 
			
		||||
		}
 | 
			
		||||
		this.remove()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return err
 | 
			
		||||
@@ -147,6 +198,8 @@ func (this *PartialFileWriter) Discard() error {
 | 
			
		||||
 | 
			
		||||
	_ = this.rawWriter.Close()
 | 
			
		||||
 | 
			
		||||
	_ = os.Remove(this.rangePath)
 | 
			
		||||
 | 
			
		||||
	err := os.Remove(this.rawWriter.Name())
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
@@ -171,3 +224,12 @@ func (this *PartialFileWriter) Key() string {
 | 
			
		||||
func (this *PartialFileWriter) ItemType() ItemType {
 | 
			
		||||
	return ItemTypeFile
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *PartialFileWriter) IsNew() bool {
 | 
			
		||||
	return this.isNew && len(this.ranges.Ranges) == 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *PartialFileWriter) remove() {
 | 
			
		||||
	_ = os.Remove(this.rawWriter.Name())
 | 
			
		||||
	_ = os.Remove(this.rangePath)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -11,8 +11,8 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestPartialFileWriter_SeekOffset(t *testing.T) {
 | 
			
		||||
	var path = "/tmp/test@partial.cache"
 | 
			
		||||
func TestPartialFileWriter_Write(t *testing.T) {
 | 
			
		||||
	var path = "/tmp/test_partial.cache"
 | 
			
		||||
	_ = os.Remove(path)
 | 
			
		||||
 | 
			
		||||
	var reader = func() {
 | 
			
		||||
@@ -27,7 +27,8 @@ func TestPartialFileWriter_SeekOffset(t *testing.T) {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	var writer = caches.NewPartialFileWriter(fp, "test", time.Now().Unix()+86500, true, true, 0, func() {
 | 
			
		||||
	var ranges = caches.NewPartialRanges()
 | 
			
		||||
	var writer = caches.NewPartialFileWriter(fp, "test", time.Now().Unix()+86500, true, true, 0, ranges, func() {
 | 
			
		||||
		t.Log("end")
 | 
			
		||||
	})
 | 
			
		||||
	_, err = writer.WriteHeader([]byte("header"))
 | 
			
		||||
@@ -36,7 +37,7 @@ func TestPartialFileWriter_SeekOffset(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 移动位置
 | 
			
		||||
	err = writer.WriteAt([]byte("HELLO"), 100)
 | 
			
		||||
	err = writer.WriteAt(100, []byte("HELLO"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -240,7 +240,7 @@ func (this *APIStream) handleReadCache(message *pb.NodeStreamMessage) error {
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	reader, err := storage.OpenReader(msg.Key, false)
 | 
			
		||||
	reader, err := storage.OpenReader(msg.Key, false, false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == caches.ErrNotFound {
 | 
			
		||||
			this.replyFail(message.RequestId, "key not found")
 | 
			
		||||
@@ -351,7 +351,11 @@ func (this *APIStream) handlePurgeCache(message *pb.NodeStreamMessage) error {
 | 
			
		||||
	if msg.Type == "file" {
 | 
			
		||||
		var keys = msg.Keys
 | 
			
		||||
		for _, key := range keys {
 | 
			
		||||
			keys = append(keys, key+webpCacheSuffix, key+cacheMethodSuffix+"HEAD")
 | 
			
		||||
			keys = append(keys,
 | 
			
		||||
				key+cacheMethodSuffix+"HEAD",
 | 
			
		||||
				key+webpCacheSuffix,
 | 
			
		||||
				key+cachePartialSuffix,
 | 
			
		||||
			)
 | 
			
		||||
			// TODO 根据实际缓存的内容进行组合
 | 
			
		||||
			for _, encoding := range compressions.AllEncodings() {
 | 
			
		||||
				keys = append(keys, key+compressionCacheSuffix+encoding)
 | 
			
		||||
 
 | 
			
		||||
@@ -10,6 +10,8 @@ import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/rpc"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
			
		||||
	rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
@@ -36,6 +38,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 添加 X-Cache Header
 | 
			
		||||
	var addStatusHeader = this.web.Cache.AddStatusHeader
 | 
			
		||||
	if addStatusHeader {
 | 
			
		||||
		defer func() {
 | 
			
		||||
@@ -137,7 +140,12 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
	if this.web.Cache.PurgeIsOn && strings.ToUpper(this.RawReq.Method) == "PURGE" && this.RawReq.Header.Get("X-Edge-Purge-Key") == this.web.Cache.PurgeKey {
 | 
			
		||||
		this.varMapping["cache.status"] = "PURGE"
 | 
			
		||||
 | 
			
		||||
		var subKeys = []string{key, key + cacheMethodSuffix + "HEAD"}
 | 
			
		||||
		var subKeys = []string{
 | 
			
		||||
			key,
 | 
			
		||||
			key + cacheMethodSuffix + "HEAD",
 | 
			
		||||
			key + webpCacheSuffix,
 | 
			
		||||
			key + cachePartialSuffix,
 | 
			
		||||
		}
 | 
			
		||||
		// TODO 根据实际缓存的内容进行组合
 | 
			
		||||
		for _, encoding := range compressions.AllEncodings() {
 | 
			
		||||
			subKeys = append(subKeys, key+compressionCacheSuffix+encoding)
 | 
			
		||||
@@ -180,10 +188,14 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
	var reader caches.Reader
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	var rangeHeader = this.RawReq.Header.Get("Range")
 | 
			
		||||
	var isPartialRequest = len(rangeHeader) > 0
 | 
			
		||||
 | 
			
		||||
	// 检查是否支持WebP
 | 
			
		||||
	var webPIsEnabled = false
 | 
			
		||||
	var isHeadMethod = method == http.MethodHead
 | 
			
		||||
	if !isHeadMethod &&
 | 
			
		||||
	if !isPartialRequest &&
 | 
			
		||||
		!isHeadMethod &&
 | 
			
		||||
		this.web.WebP != nil &&
 | 
			
		||||
		this.web.WebP.IsOn &&
 | 
			
		||||
		this.web.WebP.MatchRequest(filepath.Ext(this.Path()), this.Format) &&
 | 
			
		||||
@@ -192,13 +204,13 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查压缩缓存
 | 
			
		||||
	if !isHeadMethod && reader == nil {
 | 
			
		||||
	if !isPartialRequest && !isHeadMethod && reader == nil {
 | 
			
		||||
		if this.web.Compression != nil && this.web.Compression.IsOn {
 | 
			
		||||
			_, encoding, ok := this.web.Compression.MatchAcceptEncoding(this.RawReq.Header.Get("Accept-Encoding"))
 | 
			
		||||
			if ok {
 | 
			
		||||
				// 检查支持WebP的压缩缓存
 | 
			
		||||
				if webPIsEnabled {
 | 
			
		||||
					reader, _ = storage.OpenReader(key+webpCacheSuffix+compressionCacheSuffix+encoding, useStale)
 | 
			
		||||
					reader, _ = storage.OpenReader(key+webpCacheSuffix+compressionCacheSuffix+encoding, useStale, false)
 | 
			
		||||
					if reader != nil {
 | 
			
		||||
						tags = append(tags, "webp", encoding)
 | 
			
		||||
					}
 | 
			
		||||
@@ -206,7 +218,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
 | 
			
		||||
				// 检查普通缓存
 | 
			
		||||
				if reader == nil {
 | 
			
		||||
					reader, _ = storage.OpenReader(key+compressionCacheSuffix+encoding, useStale)
 | 
			
		||||
					reader, _ = storage.OpenReader(key+compressionCacheSuffix+encoding, useStale, false)
 | 
			
		||||
					if reader != nil {
 | 
			
		||||
						tags = append(tags, encoding)
 | 
			
		||||
					}
 | 
			
		||||
@@ -216,8 +228,11 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查WebP
 | 
			
		||||
	if !isHeadMethod && reader == nil && webPIsEnabled {
 | 
			
		||||
		reader, _ = storage.OpenReader(key+webpCacheSuffix, useStale)
 | 
			
		||||
	if !isPartialRequest &&
 | 
			
		||||
		!isHeadMethod &&
 | 
			
		||||
		reader == nil &&
 | 
			
		||||
		webPIsEnabled {
 | 
			
		||||
		reader, _ = storage.OpenReader(key+webpCacheSuffix, useStale, false)
 | 
			
		||||
		if reader != nil {
 | 
			
		||||
			this.writer.cacheReaderSuffix = webpCacheSuffix
 | 
			
		||||
			tags = append(tags, "webp")
 | 
			
		||||
@@ -225,8 +240,18 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查正常的文件
 | 
			
		||||
	var isPartialCache = false
 | 
			
		||||
	if reader == nil {
 | 
			
		||||
		reader, err = storage.OpenReader(key, useStale)
 | 
			
		||||
		reader, err = storage.OpenReader(key, useStale, false)
 | 
			
		||||
		if err != nil && this.cacheRef.AllowPartialContent {
 | 
			
		||||
			pReader := this.tryPartialReader(storage, key, useStale, rangeHeader)
 | 
			
		||||
			if pReader != nil {
 | 
			
		||||
				isPartialCache = true
 | 
			
		||||
				reader = pReader
 | 
			
		||||
				err = nil
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if err == caches.ErrNotFound {
 | 
			
		||||
				// cache相关变量
 | 
			
		||||
@@ -260,7 +285,16 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 准备Buffer
 | 
			
		||||
	var pool = this.bytePool(reader.BodySize())
 | 
			
		||||
	var fileSize = reader.BodySize()
 | 
			
		||||
	var totalSizeString = types.String(fileSize)
 | 
			
		||||
	if isPartialCache {
 | 
			
		||||
		fileSize = reader.(*caches.PartialFileReader).MaxLength()
 | 
			
		||||
		if totalSizeString == "0" {
 | 
			
		||||
			totalSizeString = "*"
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var pool = this.bytePool(fileSize)
 | 
			
		||||
	var buf = pool.Get()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		pool.Put(buf)
 | 
			
		||||
@@ -323,7 +357,9 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
			eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "\""
 | 
			
		||||
		}
 | 
			
		||||
		respHeader.Del("Etag")
 | 
			
		||||
		respHeader["ETag"] = []string{eTag}
 | 
			
		||||
		if !isPartialCache {
 | 
			
		||||
			respHeader["ETag"] = []string{eTag}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 支持 Last-Modified
 | 
			
		||||
@@ -331,11 +367,13 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
	var modifiedTime = ""
 | 
			
		||||
	if lastModifiedAt > 0 {
 | 
			
		||||
		modifiedTime = time.Unix(utils.GMTUnixTime(lastModifiedAt), 0).Format("Mon, 02 Jan 2006 15:04:05") + " GMT"
 | 
			
		||||
		respHeader.Set("Last-Modified", modifiedTime)
 | 
			
		||||
		if !isPartialCache {
 | 
			
		||||
			respHeader.Set("Last-Modified", modifiedTime)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 支持 If-None-Match
 | 
			
		||||
	if len(eTag) > 0 && this.requestHeader("If-None-Match") == eTag {
 | 
			
		||||
	if !isPartialCache && len(eTag) > 0 && this.requestHeader("If-None-Match") == eTag {
 | 
			
		||||
		// 自定义Header
 | 
			
		||||
		this.processResponseHeaders(http.StatusNotModified)
 | 
			
		||||
		this.writer.WriteHeader(http.StatusNotModified)
 | 
			
		||||
@@ -346,7 +384,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 支持 If-Modified-Since
 | 
			
		||||
	if len(modifiedTime) > 0 && this.requestHeader("If-Modified-Since") == modifiedTime {
 | 
			
		||||
	if !isPartialCache && len(modifiedTime) > 0 && this.requestHeader("If-Modified-Since") == modifiedTime {
 | 
			
		||||
		// 自定义Header
 | 
			
		||||
		this.processResponseHeaders(http.StatusNotModified)
 | 
			
		||||
		this.writer.WriteHeader(http.StatusNotModified)
 | 
			
		||||
@@ -364,69 +402,55 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
		this.writer.WriteHeader(reader.Status())
 | 
			
		||||
	} else {
 | 
			
		||||
		ifRangeHeaders, ok := this.RawReq.Header["If-Range"]
 | 
			
		||||
		supportRange := true
 | 
			
		||||
		var supportRange = true
 | 
			
		||||
		if ok {
 | 
			
		||||
			supportRange = false
 | 
			
		||||
			for _, v := range ifRangeHeaders {
 | 
			
		||||
				if v == this.writer.Header().Get("ETag") || v == this.writer.Header().Get("Last-Modified") {
 | 
			
		||||
					supportRange = true
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 支持Range
 | 
			
		||||
		rangeSet := [][]int64{}
 | 
			
		||||
		var ranges = []rangeutils.Range{}
 | 
			
		||||
		if supportRange {
 | 
			
		||||
			fileSize := reader.BodySize()
 | 
			
		||||
			contentRange := this.RawReq.Header.Get("Range")
 | 
			
		||||
			if len(contentRange) > 0 {
 | 
			
		||||
			if len(rangeHeader) > 0 {
 | 
			
		||||
				if fileSize == 0 {
 | 
			
		||||
					this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
 | 
			
		||||
					this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
 | 
			
		||||
					return true
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				set, ok := httpRequestParseContentRange(contentRange)
 | 
			
		||||
				set, ok := httpRequestParseRangeHeader(rangeHeader)
 | 
			
		||||
				if !ok {
 | 
			
		||||
					this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
 | 
			
		||||
					this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
 | 
			
		||||
					return true
 | 
			
		||||
				}
 | 
			
		||||
				if len(set) > 0 {
 | 
			
		||||
					rangeSet = set
 | 
			
		||||
					for _, arr := range rangeSet {
 | 
			
		||||
						if arr[0] == -1 {
 | 
			
		||||
							arr[0] = fileSize + arr[1]
 | 
			
		||||
							arr[1] = fileSize - 1
 | 
			
		||||
 | 
			
		||||
							if arr[0] < 0 {
 | 
			
		||||
								this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
 | 
			
		||||
								this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
 | 
			
		||||
								return true
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
						if arr[1] < 0 {
 | 
			
		||||
							arr[1] = fileSize - 1
 | 
			
		||||
						}
 | 
			
		||||
						if arr[1] >= fileSize {
 | 
			
		||||
							arr[1] = fileSize - 1
 | 
			
		||||
						}
 | 
			
		||||
						if arr[0] > arr[1] {
 | 
			
		||||
					ranges = set
 | 
			
		||||
					for k, r := range ranges {
 | 
			
		||||
						r2, ok := r.Convert(fileSize)
 | 
			
		||||
						if !ok {
 | 
			
		||||
							this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
 | 
			
		||||
							this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
 | 
			
		||||
							return true
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
						ranges[k] = r2
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if len(rangeSet) == 1 {
 | 
			
		||||
			respHeader.Set("Content-Range", "bytes "+strconv.FormatInt(rangeSet[0][0], 10)+"-"+strconv.FormatInt(rangeSet[0][1], 10)+"/"+strconv.FormatInt(reader.BodySize(), 10))
 | 
			
		||||
			respHeader.Set("Content-Length", strconv.FormatInt(rangeSet[0][1]-rangeSet[0][0]+1, 10))
 | 
			
		||||
		if len(ranges) == 1 {
 | 
			
		||||
			respHeader.Set("Content-Range", ranges[0].ComposeContentRangeHeader(totalSizeString))
 | 
			
		||||
			respHeader.Set("Content-Length", strconv.FormatInt(ranges[0].Length(), 10))
 | 
			
		||||
			this.writer.WriteHeader(http.StatusPartialContent)
 | 
			
		||||
 | 
			
		||||
			err = reader.ReadBodyRange(buf, rangeSet[0][0], rangeSet[0][1], func(n int) (goNext bool, err error) {
 | 
			
		||||
			err = reader.ReadBodyRange(buf, ranges[0].Start(), ranges[0].End(), func(n int) (goNext bool, err error) {
 | 
			
		||||
				_, err = this.writer.Write(buf[:n])
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return false, errWritingToClient
 | 
			
		||||
@@ -446,15 +470,15 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
				}
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		} else if len(rangeSet) > 1 {
 | 
			
		||||
			boundary := httpRequestGenBoundary()
 | 
			
		||||
		} else if len(ranges) > 1 {
 | 
			
		||||
			var boundary = httpRequestGenBoundary()
 | 
			
		||||
			respHeader.Set("Content-Type", "multipart/byteranges; boundary="+boundary)
 | 
			
		||||
			respHeader.Del("Content-Length")
 | 
			
		||||
			contentType := respHeader.Get("Content-Type")
 | 
			
		||||
 | 
			
		||||
			this.writer.WriteHeader(http.StatusPartialContent)
 | 
			
		||||
 | 
			
		||||
			for index, set := range rangeSet {
 | 
			
		||||
			for index, r := range ranges {
 | 
			
		||||
				if index == 0 {
 | 
			
		||||
					_, err = this.writer.WriteString("--" + boundary + "\r\n")
 | 
			
		||||
				} else {
 | 
			
		||||
@@ -465,7 +489,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
					return true
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				_, err = this.writer.WriteString("Content-Range: " + "bytes " + strconv.FormatInt(set[0], 10) + "-" + strconv.FormatInt(set[1], 10) + "/" + strconv.FormatInt(reader.BodySize(), 10) + "\r\n")
 | 
			
		||||
				_, err = this.writer.WriteString("Content-Range: " + r.ComposeContentRangeHeader(totalSizeString) + "\r\n")
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					// 不提示写入客户端错误
 | 
			
		||||
					return true
 | 
			
		||||
@@ -479,7 +503,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				err := reader.ReadBodyRange(buf, set[0], set[1], func(n int) (goNext bool, err error) {
 | 
			
		||||
				err := reader.ReadBodyRange(buf, r.Start(), r.End(), func(n int) (goNext bool, err error) {
 | 
			
		||||
					_, err = this.writer.Write(buf[:n])
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return false, errWritingToClient
 | 
			
		||||
@@ -503,7 +527,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
 | 
			
		||||
			}
 | 
			
		||||
		} else { // 没有Range
 | 
			
		||||
			var resp = &http.Response{Body: reader}
 | 
			
		||||
			this.writer.Prepare(resp, reader.BodySize(), reader.Status(), false)
 | 
			
		||||
			this.writer.Prepare(resp, fileSize, reader.Status(), false)
 | 
			
		||||
			this.writer.WriteHeader(reader.Status())
 | 
			
		||||
 | 
			
		||||
			_, err = io.CopyBuffer(this.writer, resp.Body, buf)
 | 
			
		||||
@@ -544,3 +568,47 @@ func (this *HTTPRequest) addExpiresHeader(expiresAt int64) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 尝试读取区间缓存
 | 
			
		||||
func (this *HTTPRequest) tryPartialReader(storage caches.StorageInterface, key string, useStale bool, rangeHeader string) caches.Reader {
 | 
			
		||||
	// 尝试读取Partial cache
 | 
			
		||||
	if len(rangeHeader) == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ranges, ok := httpRequestParseRangeHeader(rangeHeader)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pReader, pErr := storage.OpenReader(key+cachePartialSuffix, useStale, true)
 | 
			
		||||
	if pErr != nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	partialReader, ok := pReader.(*caches.PartialFileReader)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		_ = pReader.Close()
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	var isOk = false
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if !isOk {
 | 
			
		||||
			_ = pReader.Close()
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// 检查范围
 | 
			
		||||
	for _, r := range ranges {
 | 
			
		||||
		r1, ok := r.Convert(partialReader.MaxLength())
 | 
			
		||||
		if !ok {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		if !partialReader.ContainsRange(r1) {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	isOk = true
 | 
			
		||||
	return pReader
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,10 +2,12 @@ package nodes
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/zero"
 | 
			
		||||
	"github.com/cespare/xxhash"
 | 
			
		||||
	"github.com/iwind/TeaGo/Tea"
 | 
			
		||||
	"github.com/iwind/TeaGo/logs"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	"io"
 | 
			
		||||
	"io/fs"
 | 
			
		||||
	"mime"
 | 
			
		||||
@@ -186,7 +188,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// length
 | 
			
		||||
	fileSize := stat.Size()
 | 
			
		||||
	var fileSize = stat.Size()
 | 
			
		||||
 | 
			
		||||
	// 支持 Last-Modified
 | 
			
		||||
	modifiedTime := stat.ModTime().Format("Mon, 02 Jan 2006 15:04:05 GMT")
 | 
			
		||||
@@ -231,6 +233,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
 | 
			
		||||
		for _, v := range ifRangeHeaders {
 | 
			
		||||
			if v == eTag || v == modifiedTime {
 | 
			
		||||
				supportRange = true
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if !supportRange {
 | 
			
		||||
@@ -239,7 +242,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 支持Range
 | 
			
		||||
	rangeSet := [][]int64{}
 | 
			
		||||
	var ranges = []rangeutils.Range{}
 | 
			
		||||
	if supportRange {
 | 
			
		||||
		contentRange := this.RawReq.Header.Get("Range")
 | 
			
		||||
		if len(contentRange) > 0 {
 | 
			
		||||
@@ -249,36 +252,22 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			set, ok := httpRequestParseContentRange(contentRange)
 | 
			
		||||
			set, ok := httpRequestParseRangeHeader(contentRange)
 | 
			
		||||
			if !ok {
 | 
			
		||||
				this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
 | 
			
		||||
				this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			if len(set) > 0 {
 | 
			
		||||
				rangeSet = set
 | 
			
		||||
				for _, arr := range rangeSet {
 | 
			
		||||
					if arr[0] == -1 {
 | 
			
		||||
						arr[0] = fileSize + arr[1]
 | 
			
		||||
						arr[1] = fileSize - 1
 | 
			
		||||
 | 
			
		||||
						if arr[0] < 0 {
 | 
			
		||||
							this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
 | 
			
		||||
							this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
 | 
			
		||||
							return true
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
					if arr[1] > 0 {
 | 
			
		||||
						arr[1] = fileSize - 1
 | 
			
		||||
					}
 | 
			
		||||
					if arr[1] < 0 {
 | 
			
		||||
						arr[1] = fileSize - 1
 | 
			
		||||
					}
 | 
			
		||||
					if arr[0] > arr[1] {
 | 
			
		||||
				ranges = set
 | 
			
		||||
				for k, r := range ranges {
 | 
			
		||||
					r2, ok := r.Convert(fileSize)
 | 
			
		||||
					if !ok {
 | 
			
		||||
						this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
 | 
			
		||||
						this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
 | 
			
		||||
						return true
 | 
			
		||||
					}
 | 
			
		||||
					ranges[k] = r2
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
@@ -298,7 +287,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
 | 
			
		||||
	this.processResponseHeaders(http.StatusOK)
 | 
			
		||||
 | 
			
		||||
	// 在Range请求中不能缓存
 | 
			
		||||
	if len(rangeSet) > 0 {
 | 
			
		||||
	if len(ranges) > 0 {
 | 
			
		||||
		this.cacheRef = nil // 不支持缓存
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -311,11 +300,11 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
 | 
			
		||||
		pool.Put(buf)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if len(rangeSet) == 1 {
 | 
			
		||||
		respHeader.Set("Content-Range", "bytes "+strconv.FormatInt(rangeSet[0][0], 10)+"-"+strconv.FormatInt(rangeSet[0][1], 10)+"/"+strconv.FormatInt(fileSize, 10))
 | 
			
		||||
	if len(ranges) == 1 {
 | 
			
		||||
		respHeader.Set("Content-Range", ranges[0].ComposeContentRangeHeader(types.String(fileSize)))
 | 
			
		||||
		this.writer.WriteHeader(http.StatusPartialContent)
 | 
			
		||||
 | 
			
		||||
		ok, err := httpRequestReadRange(reader, buf, rangeSet[0][0], rangeSet[0][1], func(buf []byte, n int) error {
 | 
			
		||||
		ok, err := httpRequestReadRange(reader, buf, ranges[0].Start(), ranges[0].End(), func(buf []byte, n int) error {
 | 
			
		||||
			_, err := this.writer.Write(buf[:n])
 | 
			
		||||
			return err
 | 
			
		||||
		})
 | 
			
		||||
@@ -328,13 +317,13 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
 | 
			
		||||
			this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	} else if len(rangeSet) > 1 {
 | 
			
		||||
	} else if len(ranges) > 1 {
 | 
			
		||||
		boundary := httpRequestGenBoundary()
 | 
			
		||||
		respHeader.Set("Content-Type", "multipart/byteranges; boundary="+boundary)
 | 
			
		||||
 | 
			
		||||
		this.writer.WriteHeader(http.StatusPartialContent)
 | 
			
		||||
 | 
			
		||||
		for index, set := range rangeSet {
 | 
			
		||||
		for index, r := range ranges {
 | 
			
		||||
			if index == 0 {
 | 
			
		||||
				_, err = this.writer.WriteString("--" + boundary + "\r\n")
 | 
			
		||||
			} else {
 | 
			
		||||
@@ -345,7 +334,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			_, err = this.writer.WriteString("Content-Range: " + "bytes " + strconv.FormatInt(set[0], 10) + "-" + strconv.FormatInt(set[1], 10) + "/" + strconv.FormatInt(fileSize, 10) + "\r\n")
 | 
			
		||||
			_, err = this.writer.WriteString("Content-Range: " + r.ComposeContentRangeHeader(types.String(fileSize)) + "\r\n")
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logs.Error(err)
 | 
			
		||||
				return true
 | 
			
		||||
@@ -359,7 +348,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			ok, err := httpRequestReadRange(reader, buf, set[0], set[1], func(buf []byte, n int) error {
 | 
			
		||||
			ok, err := httpRequestReadRange(reader, buf, r.Start(), r.End(), func(buf []byte, n int) error {
 | 
			
		||||
				_, err := this.writer.Write(buf[:n])
 | 
			
		||||
				return err
 | 
			
		||||
			})
 | 
			
		||||
 
 | 
			
		||||
@@ -5,15 +5,20 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var contentRangeRegexp = regexp.MustCompile(`^bytes (\d+)-(\d+)/`)
 | 
			
		||||
 | 
			
		||||
// 分解Range
 | 
			
		||||
func httpRequestParseContentRange(rangeValue string) (result [][]int64, ok bool) {
 | 
			
		||||
func httpRequestParseRangeHeader(rangeValue string) (result []rangeutils.Range, ok bool) {
 | 
			
		||||
	// 参考RFC:https://tools.ietf.org/html/rfc7233
 | 
			
		||||
	index := strings.Index(rangeValue, "=")
 | 
			
		||||
	if index == -1 {
 | 
			
		||||
@@ -24,15 +29,15 @@ func httpRequestParseContentRange(rangeValue string) (result [][]int64, ok bool)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rangeSetString := rangeValue[index+1:]
 | 
			
		||||
	var rangeSetString = rangeValue[index+1:]
 | 
			
		||||
	if len(rangeSetString) == 0 {
 | 
			
		||||
		ok = true
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pieces := strings.Split(rangeSetString, ", ")
 | 
			
		||||
	var pieces = strings.Split(rangeSetString, ", ")
 | 
			
		||||
	for _, piece := range pieces {
 | 
			
		||||
		index := strings.Index(piece, "-")
 | 
			
		||||
		index = strings.Index(piece, "-")
 | 
			
		||||
		if index == -1 {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
@@ -70,7 +75,7 @@ func httpRequestParseContentRange(rangeValue string) (result [][]int64, ok bool)
 | 
			
		||||
			lastInt = -lastInt
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		result = append(result, []int64{firstInt, lastInt})
 | 
			
		||||
		result = append(result, [2]int64{firstInt, lastInt})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ok = true
 | 
			
		||||
@@ -119,6 +124,15 @@ func httpRequestReadRange(reader io.Reader, buf []byte, start int64, end int64,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 分解Content-Range
 | 
			
		||||
func httpRequestParseContentRangeHeader(contentRange string) (start int64) {
 | 
			
		||||
	var matches = contentRangeRegexp.FindStringSubmatch(contentRange)
 | 
			
		||||
	if len(matches) < 3 {
 | 
			
		||||
		return -1
 | 
			
		||||
	}
 | 
			
		||||
	return types.Int64(matches[1])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 生成boundary
 | 
			
		||||
// 仿照Golang自带的函数(multipart包)
 | 
			
		||||
func httpRequestGenBoundary() string {
 | 
			
		||||
@@ -130,6 +144,21 @@ func httpRequestGenBoundary() string {
 | 
			
		||||
	return fmt.Sprintf("%x", buf[:])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 从content-type中读取boundary
 | 
			
		||||
func httpRequestParseBoundary(contentType string) string {
 | 
			
		||||
	var delim = "boundary="
 | 
			
		||||
	var boundaryIndex = strings.Index(contentType, delim)
 | 
			
		||||
	if boundaryIndex < 0 {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	var boundary = contentType[boundaryIndex+len(delim):]
 | 
			
		||||
	semicolonIndex := strings.Index(boundary, ";")
 | 
			
		||||
	if semicolonIndex >= 0 {
 | 
			
		||||
		return boundary[:semicolonIndex]
 | 
			
		||||
	}
 | 
			
		||||
	return boundary
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 判断状态是否为跳转
 | 
			
		||||
func httpStatusIsRedirect(statusCode int) bool {
 | 
			
		||||
	return statusCode == http.StatusPermanentRedirect ||
 | 
			
		||||
 
 | 
			
		||||
@@ -17,52 +17,84 @@ func TestHTTPRequest_httpRequestGenBoundary(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestHTTPRequest_httpRequestParseContentRange(t *testing.T) {
 | 
			
		||||
	a := assert.NewAssertion(t)
 | 
			
		||||
func TestHTTPRequest_httpRequestParseBoundary(t *testing.T) {
 | 
			
		||||
	var a = assert.NewAssertion(t)
 | 
			
		||||
	a.IsTrue(httpRequestParseBoundary("multipart/byteranges") == "")
 | 
			
		||||
	a.IsTrue(httpRequestParseBoundary("multipart/byteranges; boundary=123") == "123")
 | 
			
		||||
	a.IsTrue(httpRequestParseBoundary("multipart/byteranges; boundary=123; 456") == "123")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestHTTPRequest_httpRequestParseRangeHeader(t *testing.T) {
 | 
			
		||||
	var a = assert.NewAssertion(t)
 | 
			
		||||
	{
 | 
			
		||||
		_, ok := httpRequestParseContentRange("")
 | 
			
		||||
		_, ok := httpRequestParseRangeHeader("")
 | 
			
		||||
		a.IsFalse(ok)
 | 
			
		||||
	}
 | 
			
		||||
	{
 | 
			
		||||
		_, ok := httpRequestParseContentRange("byte=")
 | 
			
		||||
		_, ok := httpRequestParseRangeHeader("byte=")
 | 
			
		||||
		a.IsFalse(ok)
 | 
			
		||||
	}
 | 
			
		||||
	{
 | 
			
		||||
		_, ok := httpRequestParseContentRange("byte=")
 | 
			
		||||
		_, ok := httpRequestParseRangeHeader("byte=")
 | 
			
		||||
		a.IsFalse(ok)
 | 
			
		||||
	}
 | 
			
		||||
	{
 | 
			
		||||
		set, ok := httpRequestParseContentRange("bytes=")
 | 
			
		||||
		set, ok := httpRequestParseRangeHeader("bytes=")
 | 
			
		||||
		a.IsTrue(ok)
 | 
			
		||||
		a.IsTrue(len(set) == 0)
 | 
			
		||||
	}
 | 
			
		||||
	{
 | 
			
		||||
		_, ok := httpRequestParseContentRange("bytes=60-50")
 | 
			
		||||
		_, ok := httpRequestParseRangeHeader("bytes=60-50")
 | 
			
		||||
		a.IsFalse(ok)
 | 
			
		||||
	}
 | 
			
		||||
	{
 | 
			
		||||
		set, ok := httpRequestParseContentRange("bytes=0-50")
 | 
			
		||||
		set, ok := httpRequestParseRangeHeader("bytes=0-50")
 | 
			
		||||
		a.IsTrue(ok)
 | 
			
		||||
		a.IsTrue(len(set) > 0)
 | 
			
		||||
		t.Log(set)
 | 
			
		||||
	}
 | 
			
		||||
	{
 | 
			
		||||
		set, ok := httpRequestParseContentRange("bytes=0-")
 | 
			
		||||
		set, ok := httpRequestParseRangeHeader("bytes=0-")
 | 
			
		||||
		a.IsTrue(ok)
 | 
			
		||||
		a.IsTrue(len(set) > 0)
 | 
			
		||||
		if len(set) > 0 {
 | 
			
		||||
			a.IsTrue(set[0][0] == 0)
 | 
			
		||||
		}
 | 
			
		||||
		t.Log(set)
 | 
			
		||||
	}
 | 
			
		||||
	{
 | 
			
		||||
		set, ok := httpRequestParseRangeHeader("bytes=-50")
 | 
			
		||||
		a.IsTrue(ok)
 | 
			
		||||
		a.IsTrue(len(set) > 0)
 | 
			
		||||
		t.Log(set)
 | 
			
		||||
	}
 | 
			
		||||
	{
 | 
			
		||||
		set, ok := httpRequestParseContentRange("bytes=-50")
 | 
			
		||||
		set, ok := httpRequestParseRangeHeader("bytes=0-50, 60-100")
 | 
			
		||||
		a.IsTrue(ok)
 | 
			
		||||
		a.IsTrue(len(set) > 0)
 | 
			
		||||
		t.Log(set)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestHTTPRequest_httpRequestParseContentRangeHeader(t *testing.T) {
 | 
			
		||||
	{
 | 
			
		||||
		set, ok := httpRequestParseContentRange("bytes=0-50, 60-100")
 | 
			
		||||
		a.IsTrue(ok)
 | 
			
		||||
		a.IsTrue(len(set) > 0)
 | 
			
		||||
		t.Log(set)
 | 
			
		||||
		var c1 = "bytes 0-100/*"
 | 
			
		||||
		t.Log(httpRequestParseContentRangeHeader(c1))
 | 
			
		||||
	}
 | 
			
		||||
	{
 | 
			
		||||
		var c1 = "bytes 30-100/*"
 | 
			
		||||
		t.Log(httpRequestParseContentRangeHeader(c1))
 | 
			
		||||
	}
 | 
			
		||||
	{
 | 
			
		||||
		var c1 = "bytes1 0-100/*"
 | 
			
		||||
		t.Log(httpRequestParseContentRangeHeader(c1))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkHTTPRequest_httpRequestParseContentRangeHeader(b *testing.B) {
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		var c1 = "bytes 0-100/*"
 | 
			
		||||
		httpRequestParseContentRangeHeader(c1)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -28,6 +28,7 @@ import (
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/textproto"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -45,6 +46,7 @@ const compressionCacheSuffix = "@GOEDGE_"
 | 
			
		||||
 | 
			
		||||
// 缓存相关配置
 | 
			
		||||
const cacheMethodSuffix = "@GOEDGE_"
 | 
			
		||||
const cachePartialSuffix = "@GOEDGE_partial"
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	var systemMemory = utils.SystemMemoryGB() / 8
 | 
			
		||||
@@ -157,12 +159,21 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
 | 
			
		||||
	var addStatusHeader = this.req.web != nil && this.req.web.Cache != nil && this.req.web.Cache.AddStatusHeader
 | 
			
		||||
 | 
			
		||||
	// 不支持Range
 | 
			
		||||
	if this.StatusCode() == http.StatusPartialContent || len(this.Header().Get("Content-Range")) > 0 {
 | 
			
		||||
		this.req.varMapping["cache.status"] = "BYPASS"
 | 
			
		||||
		if addStatusHeader {
 | 
			
		||||
			this.Header().Set("X-Cache", "BYPASS, not supported Content-Range")
 | 
			
		||||
	if this.isPartial {
 | 
			
		||||
		if !cacheRef.AllowPartialContent {
 | 
			
		||||
			this.req.varMapping["cache.status"] = "BYPASS"
 | 
			
		||||
			if addStatusHeader {
 | 
			
		||||
				this.Header().Set("X-Cache", "BYPASS, not supported partial content")
 | 
			
		||||
			}
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if this.cacheStorage.Policy().Type != serverconfigs.CachePolicyStorageFile {
 | 
			
		||||
			this.req.varMapping["cache.status"] = "BYPASS"
 | 
			
		||||
			if addStatusHeader {
 | 
			
		||||
				this.Header().Set("X-Cache", "BYPASS, not supported partial content in memory storage")
 | 
			
		||||
			}
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果允许 ChunkedEncoding,就无需尺寸的判断,因为此时的 size 为 -1
 | 
			
		||||
@@ -183,7 +194,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查状态
 | 
			
		||||
	if len(cacheRef.Status) > 0 && !lists.ContainsInt(cacheRef.Status, this.StatusCode()) {
 | 
			
		||||
	if !this.isPartial && len(cacheRef.Status) > 0 && !lists.ContainsInt(cacheRef.Status, this.StatusCode()) {
 | 
			
		||||
		this.req.varMapping["cache.status"] = "BYPASS"
 | 
			
		||||
		if addStatusHeader {
 | 
			
		||||
			this.Header().Set("X-Cache", "BYPASS, Status: "+types.String(this.StatusCode()))
 | 
			
		||||
@@ -193,7 +204,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
 | 
			
		||||
 | 
			
		||||
	// Cache-Control
 | 
			
		||||
	if len(cacheRef.SkipResponseCacheControlValues) > 0 {
 | 
			
		||||
		var cacheControl = this.Header().Get("Cache-Control")
 | 
			
		||||
		var cacheControl = this.GetHeader("Cache-Control")
 | 
			
		||||
		if len(cacheControl) > 0 {
 | 
			
		||||
			values := strings.Split(cacheControl, ",")
 | 
			
		||||
			for _, value := range values {
 | 
			
		||||
@@ -209,7 +220,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Set-Cookie
 | 
			
		||||
	if cacheRef.SkipResponseSetCookie && len(this.Header().Get("Set-Cookie")) > 0 {
 | 
			
		||||
	if cacheRef.SkipResponseSetCookie && len(this.GetHeader("Set-Cookie")) > 0 {
 | 
			
		||||
		this.req.varMapping["cache.status"] = "BYPASS"
 | 
			
		||||
		if addStatusHeader {
 | 
			
		||||
			this.Header().Set("X-Cache", "BYPASS, Set-Cookie")
 | 
			
		||||
@@ -242,7 +253,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.cacheStorage = storage
 | 
			
		||||
	life := cacheRef.LifeSeconds()
 | 
			
		||||
	var life = cacheRef.LifeSeconds()
 | 
			
		||||
 | 
			
		||||
	if life <= 0 {
 | 
			
		||||
		life = 60
 | 
			
		||||
@@ -250,7 +261,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
 | 
			
		||||
 | 
			
		||||
	// 支持源站设置的max-age
 | 
			
		||||
	if this.req.web.Cache != nil && this.req.web.Cache.EnableCacheControlMaxAge {
 | 
			
		||||
		var cacheControl = this.Header().Get("Cache-Control")
 | 
			
		||||
		var cacheControl = this.GetHeader("Cache-Control")
 | 
			
		||||
		var pieces = strings.Split(cacheControl, ";")
 | 
			
		||||
		for _, piece := range pieces {
 | 
			
		||||
			var eqIndex = strings.Index(piece, "=")
 | 
			
		||||
@@ -265,7 +276,10 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
 | 
			
		||||
 | 
			
		||||
	var expiredAt = utils.UnixTime() + life
 | 
			
		||||
	var cacheKey = this.req.cacheKey
 | 
			
		||||
	cacheWriter, err := storage.OpenWriter(cacheKey, expiredAt, this.StatusCode(), size, false)
 | 
			
		||||
	if this.isPartial {
 | 
			
		||||
		cacheKey += cachePartialSuffix
 | 
			
		||||
	}
 | 
			
		||||
	cacheWriter, err := storage.OpenWriter(cacheKey, expiredAt, this.StatusCode(), size, this.isPartial)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if !caches.CanIgnoreErr(err) {
 | 
			
		||||
			remotelogs.Error("HTTP_WRITER", "write cache failed: "+err.Error())
 | 
			
		||||
@@ -277,6 +291,9 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
 | 
			
		||||
	// 写入Header
 | 
			
		||||
	for k, v := range this.Header() {
 | 
			
		||||
		for _, v1 := range v {
 | 
			
		||||
			if this.isPartial && k == "Content-Type" && strings.Contains(v1, "multipart/byteranges") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			_, err = cacheWriter.WriteHeader([]byte(k + ":" + v1 + "\n"))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				remotelogs.Error("HTTP_WRITER", "write cache failed: "+err.Error())
 | 
			
		||||
@@ -287,6 +304,101 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if this.isPartial {
 | 
			
		||||
		// content-range
 | 
			
		||||
		var contentRange = this.GetHeader("Content-Range")
 | 
			
		||||
		if len(contentRange) > 0 {
 | 
			
		||||
			var start = httpRequestParseContentRangeHeader(contentRange)
 | 
			
		||||
			if start < 0 {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			var filterReader = readers.NewFilterReaderCloser(resp.Body)
 | 
			
		||||
			this.cacheIsFinished = true
 | 
			
		||||
			var hasError = false
 | 
			
		||||
			filterReader.Add(func(p []byte, err error) error {
 | 
			
		||||
				if hasError {
 | 
			
		||||
					return nil
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				var l = len(p)
 | 
			
		||||
				if l == 0 {
 | 
			
		||||
					return nil
 | 
			
		||||
				}
 | 
			
		||||
				defer func() {
 | 
			
		||||
					start += int64(l)
 | 
			
		||||
				}()
 | 
			
		||||
				err = cacheWriter.WriteAt(start, p)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					this.cacheIsFinished = false
 | 
			
		||||
					hasError = true
 | 
			
		||||
				}
 | 
			
		||||
				return nil
 | 
			
		||||
			})
 | 
			
		||||
			resp.Body = filterReader
 | 
			
		||||
			this.rawReader = filterReader
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// multipart/byteranges
 | 
			
		||||
		var contentType = this.GetHeader("Content-Type")
 | 
			
		||||
		if strings.Contains(contentType, "multipart/byteranges") {
 | 
			
		||||
			partialWriter, ok := cacheWriter.(*caches.PartialFileWriter)
 | 
			
		||||
			if !ok {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var boundary = httpRequestParseBoundary(contentType)
 | 
			
		||||
			if len(boundary) == 0 {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var reader = readers.NewByteRangesReaderCloser(resp.Body, boundary)
 | 
			
		||||
			var contentTypeWritten = false
 | 
			
		||||
 | 
			
		||||
			this.cacheIsFinished = true
 | 
			
		||||
			var hasError = false
 | 
			
		||||
			var writtenTotal = false
 | 
			
		||||
			reader.OnPartRead(func(start int64, end int64, total int64, data []byte, header textproto.MIMEHeader) {
 | 
			
		||||
				if hasError {
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 写入total
 | 
			
		||||
				if !writtenTotal && total > 0 {
 | 
			
		||||
					partialWriter.SetBodyLength(total)
 | 
			
		||||
					writtenTotal = true
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 写入Content-Type
 | 
			
		||||
				if partialWriter.IsNew() && !contentTypeWritten {
 | 
			
		||||
					var realContentType = header.Get("Content-Type")
 | 
			
		||||
					if len(realContentType) > 0 {
 | 
			
		||||
						var h = []byte("Content-Type:" + realContentType + "\n")
 | 
			
		||||
						err = partialWriter.AppendHeader(h)
 | 
			
		||||
						if err != nil {
 | 
			
		||||
							hasError = true
 | 
			
		||||
							this.cacheIsFinished = false
 | 
			
		||||
							return
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					contentTypeWritten = true
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				err := cacheWriter.WriteAt(start, data)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					hasError = true
 | 
			
		||||
					this.cacheIsFinished = false
 | 
			
		||||
				}
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			resp.Body = reader
 | 
			
		||||
			this.rawReader = reader
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var cacheReader = readers.NewTeeReaderCloser(resp.Body, this.cacheWriter)
 | 
			
		||||
	resp.Body = cacheReader
 | 
			
		||||
	this.rawReader = cacheReader
 | 
			
		||||
@@ -306,7 +418,7 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var contentType = this.Header().Get("Content-Type")
 | 
			
		||||
	var contentType = this.GetHeader("Content-Type")
 | 
			
		||||
 | 
			
		||||
	if this.req.web != nil &&
 | 
			
		||||
		this.req.web.WebP != nil &&
 | 
			
		||||
@@ -324,7 +436,7 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var contentEncoding = this.Header().Get("Content-Encoding")
 | 
			
		||||
		var contentEncoding = this.GetHeader("Content-Encoding")
 | 
			
		||||
		switch contentEncoding {
 | 
			
		||||
		case "gzip", "deflate", "br":
 | 
			
		||||
			reader, err := compressions.NewReader(resp.Body, contentEncoding)
 | 
			
		||||
@@ -361,7 +473,7 @@ func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var acceptEncodings = this.req.RawReq.Header.Get("Accept-Encoding")
 | 
			
		||||
	var contentEncoding = this.Header().Get("Content-Encoding")
 | 
			
		||||
	var contentEncoding = this.GetHeader("Content-Encoding")
 | 
			
		||||
 | 
			
		||||
	if this.compressionConfig == nil || !this.compressionConfig.IsOn {
 | 
			
		||||
		if lists.ContainsString([]string{"gzip", "deflate", "br"}, contentEncoding) && !httpAcceptEncoding(acceptEncodings, contentEncoding) {
 | 
			
		||||
@@ -386,7 +498,7 @@ func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 尺寸和类型
 | 
			
		||||
	var contentType = this.Header().Get("Content-Type")
 | 
			
		||||
	var contentType = this.GetHeader("Content-Type")
 | 
			
		||||
	if !this.compressionConfig.MatchResponse(contentType, size, filepath.Ext(this.req.Path()), this.req.Format) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
@@ -504,6 +616,11 @@ func (this *HTTPWriter) Header() http.Header {
 | 
			
		||||
	return this.rawWriter.Header()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetHeader 读取Header值
 | 
			
		||||
func (this *HTTPWriter) GetHeader(name string) string {
 | 
			
		||||
	return this.Header().Get(name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteHeader 删除Header
 | 
			
		||||
func (this *HTTPWriter) DeleteHeader(name string) {
 | 
			
		||||
	this.rawWriter.Header().Del(name)
 | 
			
		||||
@@ -777,18 +894,19 @@ func (this *HTTPWriter) Close() {
 | 
			
		||||
		if this.isOk && this.cacheIsFinished {
 | 
			
		||||
			// 对比缓存前后的Content-Length
 | 
			
		||||
			var method = this.req.Method()
 | 
			
		||||
			if method != http.MethodHead && this.StatusCode() != http.StatusNoContent {
 | 
			
		||||
				var contentLengthString = this.Header().Get("Content-Length")
 | 
			
		||||
			if method != http.MethodHead && this.StatusCode() != http.StatusNoContent && !this.isPartial {
 | 
			
		||||
				var contentLengthString = this.GetHeader("Content-Length")
 | 
			
		||||
				if len(contentLengthString) > 0 {
 | 
			
		||||
					var contentLength = types.Int64(contentLengthString)
 | 
			
		||||
					if contentLength != this.cacheWriter.BodySize() {
 | 
			
		||||
						this.isOk = false
 | 
			
		||||
						_ = this.cacheWriter.Discard()
 | 
			
		||||
						this.cacheWriter = nil
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if this.isOk {
 | 
			
		||||
			if this.isOk && this.cacheWriter != nil {
 | 
			
		||||
				err := this.cacheWriter.Close()
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					var expiredAt = this.cacheWriter.ExpiredAt()
 | 
			
		||||
@@ -863,7 +981,7 @@ func (this *HTTPWriter) calculateStaleLife() int {
 | 
			
		||||
		// 从Header中读取stale-if-error
 | 
			
		||||
		var isDefinedInHeader = false
 | 
			
		||||
		if staleConfig.SupportStaleIfErrorHeader {
 | 
			
		||||
			var cacheControl = this.Header().Get("Cache-Control")
 | 
			
		||||
			var cacheControl = this.GetHeader("Cache-Control")
 | 
			
		||||
			var pieces = strings.Split(cacheControl, ",")
 | 
			
		||||
			for _, piece := range pieces {
 | 
			
		||||
				var eqIndex = strings.Index(piece, "=")
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										53
									
								
								internal/utils/ranges/range.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								internal/utils/ranges/range.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,53 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
 | 
			
		||||
package rangeutils
 | 
			
		||||
 | 
			
		||||
import "strconv"
 | 
			
		||||
 | 
			
		||||
type Range [2]int64
 | 
			
		||||
 | 
			
		||||
func NewRange(start int64, end int64) Range {
 | 
			
		||||
	return [2]int64{start, end}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this Range) Start() int64 {
 | 
			
		||||
	return this[0]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this Range) End() int64 {
 | 
			
		||||
	return this[1]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this Range) Length() int64 {
 | 
			
		||||
	return this[1] - this[0] + 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this Range) Convert(total int64) (newRange Range, ok bool) {
 | 
			
		||||
	if total <= 0 {
 | 
			
		||||
		return this, false
 | 
			
		||||
	}
 | 
			
		||||
	if this[0] < 0 {
 | 
			
		||||
		this[0] += total
 | 
			
		||||
		if this[0] < 0 {
 | 
			
		||||
			return this, false
 | 
			
		||||
		}
 | 
			
		||||
		this[1] = total - 1
 | 
			
		||||
	}
 | 
			
		||||
	if this[1] < 0 {
 | 
			
		||||
		this[1] = total - 1
 | 
			
		||||
	}
 | 
			
		||||
	if this[1] > total-1 {
 | 
			
		||||
		this[1] = total - 1
 | 
			
		||||
	}
 | 
			
		||||
	if this[0] > this[1] {
 | 
			
		||||
		return this, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return this, true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ComposeContentRangeHeader 组合Content-Range Header
 | 
			
		||||
// totalSize 可能是一个数字,也可能是一个星号(*)
 | 
			
		||||
func (this Range) ComposeContentRangeHeader(totalSize string) string {
 | 
			
		||||
	return "bytes " + strconv.FormatInt(this[0], 10) + "-" + strconv.FormatInt(this[1], 10) + "/" + totalSize
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										69
									
								
								internal/utils/ranges/range_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								internal/utils/ranges/range_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,69 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
 | 
			
		||||
package rangeutils_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
 | 
			
		||||
	"github.com/iwind/TeaGo/assert"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestRange(t *testing.T) {
 | 
			
		||||
	var a = assert.NewAssertion(t)
 | 
			
		||||
 | 
			
		||||
	var r = rangeutils.NewRange(1, 100)
 | 
			
		||||
	a.IsTrue(r.Start() == 1)
 | 
			
		||||
	a.IsTrue(r.End() == 100)
 | 
			
		||||
	t.Log("start:", r.Start(), "end:", r.End())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRange_Convert(t *testing.T) {
 | 
			
		||||
	var a = assert.NewAssertion(t)
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		var r = rangeutils.NewRange(1, 100)
 | 
			
		||||
		newR, ok := r.Convert(200)
 | 
			
		||||
		a.IsTrue(ok)
 | 
			
		||||
		a.IsTrue(newR.Start() == 1)
 | 
			
		||||
		a.IsTrue(newR.End() == 100)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		var r = rangeutils.NewRange(1, 100)
 | 
			
		||||
		newR, ok := r.Convert(50)
 | 
			
		||||
		a.IsTrue(ok)
 | 
			
		||||
		a.IsTrue(newR.Start() == 1)
 | 
			
		||||
		a.IsTrue(newR.End() == 49)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		var r = rangeutils.NewRange(1, 100)
 | 
			
		||||
		_, ok := r.Convert(0)
 | 
			
		||||
		a.IsFalse(ok)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		var r = rangeutils.NewRange(-30, -1)
 | 
			
		||||
		newR, ok := r.Convert(50)
 | 
			
		||||
		a.IsTrue(ok)
 | 
			
		||||
		a.IsTrue(newR.Start() == 50-30)
 | 
			
		||||
		a.IsTrue(newR.End() == 49)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		var r = rangeutils.NewRange(1000, 100)
 | 
			
		||||
		_, ok := r.Convert(0)
 | 
			
		||||
		a.IsFalse(ok)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		var r = rangeutils.NewRange(50, 100)
 | 
			
		||||
		_, ok := r.Convert(49)
 | 
			
		||||
		a.IsFalse(ok)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRange_ComposeContentRangeHeader(t *testing.T) {
 | 
			
		||||
	var r = rangeutils.NewRange(1, 100)
 | 
			
		||||
	t.Log(r.ComposeContentRangeHeader("1000"))
 | 
			
		||||
}
 | 
			
		||||
@@ -1,34 +0,0 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
 | 
			
		||||
package readers
 | 
			
		||||
 | 
			
		||||
import "io"
 | 
			
		||||
 | 
			
		||||
type FilterFunc = func(p []byte, err error) error
 | 
			
		||||
 | 
			
		||||
type FilterReader struct {
 | 
			
		||||
	rawReader io.Reader
 | 
			
		||||
	filters   []FilterFunc
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewFilterReader(rawReader io.Reader) *FilterReader {
 | 
			
		||||
	return &FilterReader{
 | 
			
		||||
		rawReader: rawReader,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *FilterReader) Add(filter FilterFunc) {
 | 
			
		||||
	this.filters = append(this.filters, filter)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *FilterReader) Read(p []byte) (n int, err error) {
 | 
			
		||||
	n, err = this.rawReader.Read(p)
 | 
			
		||||
	for _, filter := range this.filters {
 | 
			
		||||
		filterErr := filter(p[:n], err)
 | 
			
		||||
		if filterErr != nil {
 | 
			
		||||
			err = filterErr
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										3
									
								
								internal/utils/readers/handlers.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								internal/utils/readers/handlers.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
 | 
			
		||||
package readers
 | 
			
		||||
							
								
								
									
										6
									
								
								internal/utils/readers/reader_base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								internal/utils/readers/reader_base.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
 | 
			
		||||
package readers
 | 
			
		||||
 | 
			
		||||
type BaseReader struct {
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										157
									
								
								internal/utils/readers/reader_closer_byte_ranges.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										157
									
								
								internal/utils/readers/reader_closer_byte_ranges.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,157 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
 | 
			
		||||
package readers
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	"io"
 | 
			
		||||
	"mime/multipart"
 | 
			
		||||
	"net/textproto"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type OnPartReadHandler func(start int64, end int64, total int64, data []byte, header textproto.MIMEHeader)
 | 
			
		||||
 | 
			
		||||
var contentRangeRegexp = regexp.MustCompile(`^(\d+)-(\d+)/(\d+|\*)`)
 | 
			
		||||
 | 
			
		||||
type ByteRangesReaderCloser struct {
 | 
			
		||||
	BaseReader
 | 
			
		||||
 | 
			
		||||
	rawReader io.ReadCloser
 | 
			
		||||
	boundary  string
 | 
			
		||||
 | 
			
		||||
	mReader *multipart.Reader
 | 
			
		||||
	part    *multipart.Part
 | 
			
		||||
 | 
			
		||||
	buf   *bytes.Buffer
 | 
			
		||||
	isEOF bool
 | 
			
		||||
 | 
			
		||||
	onPartReadHandler OnPartReadHandler
 | 
			
		||||
	rangeStart        int64
 | 
			
		||||
	rangeEnd          int64
 | 
			
		||||
	total             int64
 | 
			
		||||
 | 
			
		||||
	isStarted bool
 | 
			
		||||
	nl        string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewByteRangesReaderCloser(reader io.ReadCloser, boundary string) *ByteRangesReaderCloser {
 | 
			
		||||
	return &ByteRangesReaderCloser{
 | 
			
		||||
		rawReader: reader,
 | 
			
		||||
		mReader:   multipart.NewReader(reader, boundary),
 | 
			
		||||
		boundary:  boundary,
 | 
			
		||||
		buf:       &bytes.Buffer{},
 | 
			
		||||
		nl:        "\r\n",
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *ByteRangesReaderCloser) Read(p []byte) (n int, err error) {
 | 
			
		||||
	n, err = this.read(p)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *ByteRangesReaderCloser) Close() error {
 | 
			
		||||
	return this.rawReader.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *ByteRangesReaderCloser) OnPartRead(handler OnPartReadHandler) {
 | 
			
		||||
	this.onPartReadHandler = handler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *ByteRangesReaderCloser) read(p []byte) (n int, err error) {
 | 
			
		||||
	// read from buffer
 | 
			
		||||
	n, err = this.buf.Read(p)
 | 
			
		||||
	if !this.isEOF {
 | 
			
		||||
		err = nil
 | 
			
		||||
	}
 | 
			
		||||
	if n > 0 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if this.isEOF {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if this.part == nil {
 | 
			
		||||
		part, partErr := this.mReader.NextPart()
 | 
			
		||||
		if partErr != nil {
 | 
			
		||||
			if partErr == io.EOF {
 | 
			
		||||
				this.buf.WriteString(this.nl + "--" + this.boundary + "--" + this.nl)
 | 
			
		||||
				this.isEOF = true
 | 
			
		||||
				n, _ = this.buf.Read(p)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			return 0, partErr
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !this.isStarted {
 | 
			
		||||
			this.isStarted = true
 | 
			
		||||
			this.buf.WriteString("--" + this.boundary + this.nl)
 | 
			
		||||
		} else {
 | 
			
		||||
			this.buf.WriteString(this.nl + "--" + this.boundary + this.nl)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Headers
 | 
			
		||||
		var hasRange = false
 | 
			
		||||
		for k, v := range part.Header {
 | 
			
		||||
			for _, v1 := range v {
 | 
			
		||||
				this.buf.WriteString(k + ": " + v1 + this.nl)
 | 
			
		||||
 | 
			
		||||
				// parse range
 | 
			
		||||
				if k == "Content-Range" {
 | 
			
		||||
					var bytesPrefix = "bytes "
 | 
			
		||||
					if strings.HasPrefix(v1, bytesPrefix) {
 | 
			
		||||
						var r = v1[len(bytesPrefix):]
 | 
			
		||||
						var matches = contentRangeRegexp.FindStringSubmatch(r)
 | 
			
		||||
						if len(matches) > 2 {
 | 
			
		||||
							var start = types.Int64(matches[1])
 | 
			
		||||
							var end = types.Int64(matches[2])
 | 
			
		||||
							var total int64 = 0
 | 
			
		||||
							if matches[3] != "*" {
 | 
			
		||||
								total = types.Int64(matches[3])
 | 
			
		||||
							}
 | 
			
		||||
							if start <= end {
 | 
			
		||||
								hasRange = true
 | 
			
		||||
								this.rangeStart = start
 | 
			
		||||
								this.rangeEnd = end
 | 
			
		||||
								this.total = total
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !hasRange {
 | 
			
		||||
			this.rangeStart = -1
 | 
			
		||||
			this.rangeEnd = -1
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		this.buf.WriteString(this.nl)
 | 
			
		||||
		this.part = part
 | 
			
		||||
 | 
			
		||||
		n, _ = this.buf.Read(p)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	n, err = this.part.Read(p)
 | 
			
		||||
 | 
			
		||||
	if this.onPartReadHandler != nil && n > 0 && this.rangeStart >= 0 && this.rangeEnd >= 0 {
 | 
			
		||||
		this.onPartReadHandler(this.rangeStart, this.rangeEnd, this.total, p[:n], this.part.Header)
 | 
			
		||||
		this.rangeStart += int64(n)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err == io.EOF {
 | 
			
		||||
		this.part = nil
 | 
			
		||||
		err = nil
 | 
			
		||||
 | 
			
		||||
		// 如果没有读取到内容,则直接跳到下一个Part
 | 
			
		||||
		if n == 0 {
 | 
			
		||||
			return this.read(p)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										52
									
								
								internal/utils/readers/reader_closer_byte_ranges_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								internal/utils/readers/reader_closer_byte_ranges_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,52 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
 | 
			
		||||
package readers_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils/readers"
 | 
			
		||||
	"io"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"net/textproto"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestNewByteRangesReader(t *testing.T) {
 | 
			
		||||
	var boundary = "7143cd51d2ee12a1"
 | 
			
		||||
	var dashBoundary = "--" + boundary
 | 
			
		||||
	var b = bytes.NewReader([]byte(dashBoundary + "\r\nContent-Range: bytes 0-4/36\r\nContent-Type: text/plain\r\n\r\n01234\r\n" + dashBoundary + "\r\nContent-Range: bytes 5-9/36\r\nContent-Type: text/plain\r\n\r\n56789\r\n--" + boundary + "\r\nContent-Range: bytes 10-12/36\r\nContent-Type: text/plain\r\n\r\nabc\r\n" + dashBoundary + "--\r\n"))
 | 
			
		||||
 | 
			
		||||
	var reader = readers.NewByteRangesReaderCloser(ioutil.NopCloser(b), boundary)
 | 
			
		||||
	var p = make([]byte, 16)
 | 
			
		||||
	for {
 | 
			
		||||
		n, err := reader.Read(p)
 | 
			
		||||
		if n > 0 {
 | 
			
		||||
			fmt.Print(string(p[:n]))
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if err != io.EOF {
 | 
			
		||||
				t.Fatal(err)
 | 
			
		||||
			}
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestByteRangesReader_OnPartRead(t *testing.T) {
 | 
			
		||||
	var boundary = "7143cd51d2ee12a1"
 | 
			
		||||
	var dashBoundary = "--" + boundary
 | 
			
		||||
	var b = bytes.NewReader([]byte(dashBoundary + "\r\nContent-Range: bytes 0-4/36\r\nContent-Type: text/plain\r\n\r\n01234\r\n" + dashBoundary + "\r\nContent-Range: bytes 5-9/36\r\nContent-Type: text/plain\r\n\r\n56789\r\n--" + boundary + "\r\nContent-Range: bytes 10-12/36\r\nContent-Type: text/plain\r\n\r\nabc\r\n" + dashBoundary + "--\r\n"))
 | 
			
		||||
 | 
			
		||||
	var reader = readers.NewByteRangesReaderCloser(ioutil.NopCloser(b), boundary)
 | 
			
		||||
	reader.OnPartRead(func(start int64, end int64, total int64, data []byte, header textproto.MIMEHeader) {
 | 
			
		||||
		t.Log(start, "-", end, "/", total, string(data))
 | 
			
		||||
	})
 | 
			
		||||
	var p = make([]byte, 3)
 | 
			
		||||
	for {
 | 
			
		||||
		_, err := reader.Read(p)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										42
									
								
								internal/utils/readers/reader_closer_filter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								internal/utils/readers/reader_closer_filter.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,42 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
 | 
			
		||||
package readers
 | 
			
		||||
 | 
			
		||||
import "io"
 | 
			
		||||
 | 
			
		||||
type FilterFunc = func(p []byte, err error) error
 | 
			
		||||
 | 
			
		||||
type FilterReaderCloser struct {
 | 
			
		||||
	rawReader io.Reader
 | 
			
		||||
	filters   []FilterFunc
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewFilterReaderCloser(rawReader io.Reader) *FilterReaderCloser {
 | 
			
		||||
	return &FilterReaderCloser{
 | 
			
		||||
		rawReader: rawReader,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *FilterReaderCloser) Add(filter FilterFunc) {
 | 
			
		||||
	this.filters = append(this.filters, filter)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *FilterReaderCloser) Read(p []byte) (n int, err error) {
 | 
			
		||||
	n, err = this.rawReader.Read(p)
 | 
			
		||||
	for _, filter := range this.filters {
 | 
			
		||||
		filterErr := filter(p[:n], err)
 | 
			
		||||
		if (err == nil || err != io.EOF) && filterErr != nil {
 | 
			
		||||
			err = filterErr
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *FilterReaderCloser) Close() error {
 | 
			
		||||
	closer, ok := this.rawReader.(io.Closer)
 | 
			
		||||
	if ok {
 | 
			
		||||
		return closer.Close()
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -10,7 +10,7 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestNewFilterReader(t *testing.T) {
 | 
			
		||||
	var reader = readers.NewFilterReader(bytes.NewBufferString("0123456789"))
 | 
			
		||||
	var reader = readers.NewFilterReaderCloser(bytes.NewBufferString("0123456789"))
 | 
			
		||||
	reader.Add(func(p []byte, err error) error {
 | 
			
		||||
		t.Log("filter1:", string(p), err)
 | 
			
		||||
		return nil
 | 
			
		||||
@@ -2,7 +2,9 @@
 | 
			
		||||
 | 
			
		||||
package writers
 | 
			
		||||
 | 
			
		||||
import "io"
 | 
			
		||||
import (
 | 
			
		||||
	"io"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TeeWriterCloser struct {
 | 
			
		||||
	primaryW   io.WriteCloser
 | 
			
		||||
@@ -52,8 +52,9 @@ func TestMatchBytesCache_WithoutCache(t *testing.T) {
 | 
			
		||||
func BenchmarkMatchStringCache(b *testing.B) {
 | 
			
		||||
	runtime.GOMAXPROCS(1)
 | 
			
		||||
 | 
			
		||||
	data := strings.Repeat("HELLO", 512)
 | 
			
		||||
	regex := re.MustCompile(`(?iU)\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\b`)
 | 
			
		||||
	var data = strings.Repeat("HELLO", 512)
 | 
			
		||||
	var regex = re.MustCompile(`(?iU)\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\b`)
 | 
			
		||||
	//b.Log(regex.Keywords())
 | 
			
		||||
	_ = MatchStringCache(regex, data)
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user