diff --git a/internal/caches/list_file_db.go b/internal/caches/list_file_db.go index 14cb444..ba2e50d 100644 --- a/internal/caches/list_file_db.go +++ b/internal/caches/list_file_db.go @@ -110,7 +110,7 @@ func (this *FileListDB) Open(dbPath string) error { } } - this.writeBatch = dbs.NewBatch(writeDB.RawDB(), 4) + this.writeBatch = dbs.NewBatch(writeDB, 4) this.writeBatch.OnFail(func(err error) { remotelogs.Warn("LIST_FILE_DB", "run batch failed: "+err.Error()+" ("+filepath.Base(this.dbPath)+")") }) @@ -539,10 +539,6 @@ func (this *FileListDB) Close() error { _ = this.listOlderItemsStmt.Close() } - if this.writeBatch != nil { - this.writeBatch.Close() - } - var errStrings []string if this.readDB != nil { diff --git a/internal/utils/dbs/batch.go b/internal/utils/dbs/batch.go index 1a264a3..cdb57a0 100644 --- a/internal/utils/dbs/batch.go +++ b/internal/utils/dbs/batch.go @@ -14,26 +14,28 @@ type batchItem struct { } type Batch struct { - db *sql.DB + db *DB n int enableStat bool onFail func(err error) - queue chan *batchItem - close chan bool + queue chan *batchItem + closeEvent chan bool isClosed bool } -func NewBatch(db *sql.DB, n int) *Batch { - return &Batch{ - db: db, - n: n, - queue: make(chan *batchItem), - close: make(chan bool, 1), +func NewBatch(db *DB, n int) *Batch { + var batch = &Batch{ + db: db, + n: n, + queue: make(chan *batchItem), + closeEvent: make(chan bool, 1), } + db.batches = append(db.batches, batch) + return batch } func (this *Batch) EnableStat(b bool) { @@ -68,7 +70,7 @@ For: // closed if this.isClosed { if lastTx != nil { - _ = lastTx.Commit() + _ = this.commitTx(lastTx) lastTx = nil } @@ -86,6 +88,9 @@ For: err := this.execItem(lastTx, item) if err != nil { + if IsClosedErr(err) { + return + } this.processErr(item.query, err) } @@ -93,9 +98,12 @@ For: if count == n { count = 0 - err = lastTx.Commit() + err = this.commitTx(lastTx) lastTx = nil if err != nil { + if IsClosedErr(err) { + return + } this.processErr("commit", err) } } @@ -104,16 +112,19 @@ For: continue For } count = 0 - err := lastTx.Commit() + err := this.commitTx(lastTx) lastTx = nil if err != nil { + if IsClosedErr(err) { + return + } this.processErr("commit", err) } - case <-this.close: + case <-this.closeEvent: // closed if lastTx != nil { - _ = lastTx.Commit() + _ = this.commitTx(lastTx) lastTx = nil } @@ -122,17 +133,21 @@ For: } } -func (this *Batch) Close() { +func (this *Batch) close() { this.isClosed = true select { - case this.close <- true: + case this.closeEvent <- true: default: } } func (this *Batch) beginTx() *sql.Tx { + if !this.db.BeginUpdating() { + return nil + } + tx, err := this.db.Begin() if err != nil { this.processErr("begin transaction", err) @@ -141,9 +156,18 @@ func (this *Batch) beginTx() *sql.Tx { return tx } +func (this *Batch) commitTx(tx *sql.Tx) error { + // always commit without checking database closing status + this.db.EndUpdating() + return tx.Commit() +} + func (this *Batch) execItem(tx *sql.Tx, item *batchItem) error { - if this.isClosed { - return nil + // check database status + if this.db.BeginUpdating() { + defer this.db.EndUpdating() + } else { + return errDBIsClosed } if this.enableStat { diff --git a/internal/utils/dbs/db.go b/internal/utils/dbs/db.go index 365d847..f137a6a 100644 --- a/internal/utils/dbs/db.go +++ b/internal/utils/dbs/db.go @@ -5,19 +5,31 @@ package dbs import ( "context" "database/sql" + "errors" "fmt" "github.com/TeaOSLab/EdgeNode/internal/events" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/utils/fileutils" _ "github.com/mattn/go-sqlite3" "strings" + "sync" + "time" ) +var errDBIsClosed = errors.New("the database is closed") + type DB struct { locker *fileutils.Locker rawDB *sql.DB + statusLocker sync.Mutex + countUpdating int32 + + isClosing bool + enableStat bool + + batches []*Batch } func OpenWriter(dsn string) (*DB, error) { @@ -63,10 +75,10 @@ func NewDB(rawDB *sql.DB) *DB { } events.OnKey(events.EventQuit, fmt.Sprintf("db_%p", db), func() { - _ = rawDB.Close() + _ = db.Close() }) events.OnKey(events.EventTerminated, fmt.Sprintf("db_%p", db), func() { - _ = rawDB.Close() + _ = db.Close() }) return db @@ -81,6 +93,13 @@ func (this *DB) EnableStat(b bool) { } func (this *DB) Begin() (*sql.Tx, error) { + // check database status + if this.BeginUpdating() { + defer this.EndUpdating() + } else { + return nil, errDBIsClosed + } + return this.rawDB.Begin() } @@ -90,7 +109,7 @@ func (this *DB) Prepare(query string) (*Stmt, error) { return nil, err } - var s = NewStmt(stmt, query) + var s = NewStmt(this, stmt, query) if this.enableStat { s.EnableStat() } @@ -98,13 +117,28 @@ func (this *DB) Prepare(query string) (*Stmt, error) { } func (this *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + // check database status + if this.BeginUpdating() { + defer this.EndUpdating() + } else { + return nil, errDBIsClosed + } + if this.enableStat { defer SharedQueryStatManager.AddQuery(query).End() } + return this.rawDB.ExecContext(ctx, query, args...) } func (this *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + // check database status + if this.BeginUpdating() { + defer this.EndUpdating() + } else { + return nil, errDBIsClosed + } + if this.enableStat { defer SharedQueryStatManager.AddQuery(query).End() } @@ -125,7 +159,32 @@ func (this *DB) QueryRow(query string, args ...interface{}) *sql.Row { return this.rawDB.QueryRow(query, args...) } +// Close the database func (this *DB) Close() error { + // check database status + this.statusLocker.Lock() + if this.isClosing { + this.statusLocker.Unlock() + return nil + } + this.isClosing = true + this.statusLocker.Unlock() + + // waiting for updating operations to finish + for { + this.statusLocker.Lock() + var countUpdating = this.countUpdating + this.statusLocker.Unlock() + if countUpdating <= 0 { + break + } + time.Sleep(1 * time.Millisecond) + } + + for _, batch := range this.batches { + batch.close() + } + events.Remove(fmt.Sprintf("db_%p", this)) defer func() { @@ -137,6 +196,24 @@ func (this *DB) Close() error { return this.rawDB.Close() } +func (this *DB) BeginUpdating() bool { + this.statusLocker.Lock() + defer this.statusLocker.Unlock() + + if this.isClosing { + return false + } + + this.countUpdating++ + return true +} + +func (this *DB) EndUpdating() { + this.statusLocker.Lock() + this.countUpdating-- + this.statusLocker.Unlock() +} + func (this *DB) RawDB() *sql.DB { return this.rawDB } diff --git a/internal/utils/dbs/stmt.go b/internal/utils/dbs/stmt.go index a5df4c2..2c1be7c 100644 --- a/internal/utils/dbs/stmt.go +++ b/internal/utils/dbs/stmt.go @@ -8,14 +8,16 @@ import ( ) type Stmt struct { + db *DB rawStmt *sql.Stmt query string enableStat bool } -func NewStmt(rawStmt *sql.Stmt, query string) *Stmt { +func NewStmt(db *DB, rawStmt *sql.Stmt, query string) *Stmt { return &Stmt{ + db: db, rawStmt: rawStmt, query: query, } @@ -26,6 +28,13 @@ func (this *Stmt) EnableStat() { } func (this *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { + // check database status + if this.db.BeginUpdating() { + defer this.db.EndUpdating() + } else { + return nil, errDBIsClosed + } + if this.enableStat { defer SharedQueryStatManager.AddQuery(this.query).End() } @@ -33,6 +42,13 @@ func (this *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Res } func (this *Stmt) Exec(args ...interface{}) (sql.Result, error) { + // check database status + if this.db.BeginUpdating() { + defer this.db.EndUpdating() + } else { + return nil, errDBIsClosed + } + if this.enableStat { defer SharedQueryStatManager.AddQuery(this.query).End() } diff --git a/internal/utils/dbs/utils.go b/internal/utils/dbs/utils.go new file mode 100644 index 0000000..aefdd82 --- /dev/null +++ b/internal/utils/dbs/utils.go @@ -0,0 +1,7 @@ +// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package dbs + +func IsClosedErr(err error) bool { + return err == errDBIsClosed +}