diff --git a/internal/caches/storage_file.go b/internal/caches/storage_file.go index cbd390f..c434c83 100644 --- a/internal/caches/storage_file.go +++ b/internal/caches/storage_file.go @@ -47,6 +47,9 @@ const ( HotItemSize = 1024 ) +var sharedWritingKeyMap = map[string]zero.Zero{} // key => bool +var sharedWritingKeyLocker = sync.Mutex{} + // FileStorage 文件缓存 // 文件结构: // [expires time] | [ status ] | [url length] | [header length] | [body length] | [url] [header data] [body data] @@ -56,10 +59,9 @@ type FileStorage struct { memoryStorage *MemoryStorage // 一级缓存 totalSize int64 - list ListInterface - writingKeyMap map[string]zero.Zero // key => bool - locker sync.RWMutex - purgeTicker *utils.Ticker + list ListInterface + locker sync.RWMutex + purgeTicker *utils.Ticker hotMap map[string]*HotItem // key => count hotMapLocker sync.Mutex @@ -69,10 +71,9 @@ type FileStorage struct { func NewFileStorage(policy *serverconfigs.HTTPCachePolicy) *FileStorage { return &FileStorage{ - policy: policy, - writingKeyMap: map[string]zero.Zero{}, - hotMap: map[string]*HotItem{}, - lastHotSize: -1, + policy: policy, + hotMap: map[string]*HotItem{}, + lastHotSize: -1, } } @@ -314,21 +315,20 @@ func (this *FileStorage) OpenWriter(key string, expiredAt int64, status int) (Wr } // 是否正在写入 - var isWriting = false - this.locker.Lock() - _, ok := this.writingKeyMap[key] - this.locker.Unlock() + var isOk = false + sharedWritingKeyLocker.Lock() + _, ok := sharedWritingKeyMap[key] if ok { + sharedWritingKeyLocker.Unlock() return nil, ErrFileIsWriting } - this.locker.Lock() - this.writingKeyMap[key] = zero.New() - this.locker.Unlock() + sharedWritingKeyMap[key] = zero.New() + sharedWritingKeyLocker.Unlock() defer func() { - if !isWriting { - this.locker.Lock() - delete(this.writingKeyMap, key) - this.locker.Unlock() + if !isOk { + sharedWritingKeyLocker.Lock() + delete(sharedWritingKeyMap, key) + sharedWritingKeyLocker.Unlock() } }() @@ -358,21 +358,27 @@ func (this *FileStorage) OpenWriter(key string, expiredAt int64, status int) (Wr } } + // 检查缓存是否已经生成 + var cachePath = dir + "/" + hash + ".cache" + stat, err := os.Stat(cachePath) + if err == nil && time.Now().Sub(stat.ModTime()) <= 1*time.Second { + // 防止并发连续写入 + return nil, ErrFileIsWriting + } + var tmpPath = cachePath + ".tmp" + // 先删除 err = this.list.Remove(hash) if err != nil { return nil, err } - path := dir + "/" + hash + ".cache.tmp" - writer, err := os.OpenFile(path, os.O_CREATE|os.O_SYNC|os.O_WRONLY, 0666) + writer, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_SYNC|os.O_WRONLY, 0666) if err != nil { return nil, err } - isWriting = true - isOk := false - removeOnFailure := true + var removeOnFailure = true defer func() { if err != nil { isOk = false @@ -382,7 +388,7 @@ func (this *FileStorage) OpenWriter(key string, expiredAt int64, status int) (Wr if !isOk { _ = writer.Close() if removeOnFailure { - _ = os.Remove(path) + _ = os.Remove(tmpPath) } } }() @@ -453,11 +459,10 @@ func (this *FileStorage) OpenWriter(key string, expiredAt int64, status int) (Wr } isOk = true - return NewFileWriter(writer, key, expiredAt, func() { - this.locker.Lock() - delete(this.writingKeyMap, key) - this.locker.Unlock() + sharedWritingKeyLocker.Lock() + delete(sharedWritingKeyMap, key) + sharedWritingKeyLocker.Unlock() }), nil } diff --git a/internal/caches/storage_file_test.go b/internal/caches/storage_file_test.go index 2418a46..15505cb 100644 --- a/internal/caches/storage_file_test.go +++ b/internal/caches/storage_file_test.go @@ -270,7 +270,7 @@ func TestFileStorage_Read(t *testing.T) { t.Fatal(err) } now := time.Now() - reader, err := storage.OpenReader("my-key") + reader, err := storage.OpenReader("my-key", false) if err != nil { t.Fatal(err) } @@ -306,7 +306,7 @@ func TestFileStorage_Read_HTTP_Response(t *testing.T) { t.Fatal(err) } now := time.Now() - reader, err := storage.OpenReader("my-http-response") + reader, err := storage.OpenReader("my-http-response", false) if err != nil { t.Fatal(err) } @@ -360,7 +360,7 @@ func TestFileStorage_Read_NotFound(t *testing.T) { } now := time.Now() buf := make([]byte, 6) - reader, err := storage.OpenReader("my-key-10000") + reader, err := storage.OpenReader("my-key-10000", false) if err != nil { if err == ErrNotFound { t.Log("cache not fund") @@ -506,7 +506,7 @@ func BenchmarkFileStorage_Read(b *testing.B) { b.Fatal(err) } for i := 0; i < b.N; i++ { - reader, err := storage.OpenReader("my-key") + reader, err := storage.OpenReader("my-key", false) if err != nil { b.Fatal(err) } diff --git a/internal/caches/storage_memory_test.go b/internal/caches/storage_memory_test.go index 11441cc..03185a2 100644 --- a/internal/caches/storage_memory_test.go +++ b/internal/caches/storage_memory_test.go @@ -25,7 +25,7 @@ func TestMemoryStorage_OpenWriter(t *testing.T) { t.Log(storage.valuesMap) { - reader, err := storage.OpenReader("abc") + reader, err := storage.OpenReader("abc", false) if err != nil { if err == ErrNotFound { t.Log("not found: abc") @@ -52,7 +52,7 @@ func TestMemoryStorage_OpenWriter(t *testing.T) { } { - _, err := storage.OpenReader("abc 2") + _, err := storage.OpenReader("abc 2", false) if err != nil { if err == ErrNotFound { t.Log("not found: abc2") @@ -68,7 +68,7 @@ func TestMemoryStorage_OpenWriter(t *testing.T) { } _, _ = writer.Write([]byte("Hello123")) { - reader, err := storage.OpenReader("abc") + reader, err := storage.OpenReader("abc", false) if err != nil { if err == ErrNotFound { t.Log("not found: abc") @@ -97,7 +97,7 @@ func TestMemoryStorage_OpenReaderLock(t *testing.T) { IsDone: true, }, } - _, _ = storage.OpenReader("test") + _, _ = storage.OpenReader("test", false) } func TestMemoryStorage_Delete(t *testing.T) { diff --git a/internal/caches/writer_file.go b/internal/caches/writer_file.go index 115ab35..06ff7ed 100644 --- a/internal/caches/writer_file.go +++ b/internal/caches/writer_file.go @@ -6,6 +6,7 @@ import ( "io" "os" "strings" + "sync" ) type FileWriter struct { @@ -15,6 +16,7 @@ type FileWriter struct { bodySize int64 expiredAt int64 endFunc func() + once sync.Once } func NewFileWriter(rawWriter *os.File, key string, expiredAt int64, endFunc func()) *FileWriter { @@ -82,18 +84,25 @@ func (this *FileWriter) WriteBodyLength(bodyLength int64) error { // Close 关闭 func (this *FileWriter) Close() error { - defer this.endFunc() + defer this.once.Do(func() { + this.endFunc() + }) + + path := this.rawWriter.Name() err := this.WriteHeaderLength(types.Int(this.headerSize)) if err != nil { + _ = this.rawWriter.Close() + _ = os.Remove(path) return err } err = this.WriteBodyLength(this.bodySize) if err != nil { + _ = this.rawWriter.Close() + _ = os.Remove(path) return err } - path := this.rawWriter.Name() err = this.rawWriter.Close() if err != nil { _ = os.Remove(path) @@ -109,7 +118,9 @@ func (this *FileWriter) Close() error { // Discard 丢弃 func (this *FileWriter) Discard() error { - defer this.endFunc() + defer this.once.Do(func() { + this.endFunc() + }) _ = this.rawWriter.Close()