diff --git a/internal/db/models/sys_locker_dao.go b/internal/db/models/sys_locker_dao.go index 76448fa7..3b119276 100644 --- a/internal/db/models/sys_locker_dao.go +++ b/internal/db/models/sys_locker_dao.go @@ -1,12 +1,14 @@ package models import ( + "errors" "github.com/TeaOSLab/EdgeAPI/internal/zero" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" + "strings" "time" ) @@ -119,6 +121,11 @@ func (this *SysLockerDAO) Unlock(tx *dbs.Tx, key string) error { // Increase 增加版本号 func (this *SysLockerDAO) Increase(tx *dbs.Tx, key string, defaultValue int64) (int64, error) { + // validate key + if strings.Contains(key, "'") { + return 0, errors.New("invalid key '" + key + "'") + } + if tx == nil { var result int64 var err error @@ -137,7 +144,21 @@ func (this *SysLockerDAO) Increase(tx *dbs.Tx, key string, defaultValue int64) ( }) return result, err } - err := this.Query(tx). + + // combine statements to make increasing faster + colValue, err := tx.FindCol(0, "INSERT INTO `edgeSysLockers` (`key`, `version`) VALUES ('"+key+"', "+types.String(defaultValue)+") ON DUPLICATE KEY UPDATE `version`=`version`+1; SELECT `version` FROM edgeSysLockers WHERE `key`='"+key+"'") + if err != nil { + if CheckSQLErrCode(err, 1064 /** syntax error **/) { + // continue to use seperated query + err = nil + } else { + return 0, err + } + } else { + return types.Int64(colValue), nil + } + + err = this.Query(tx). Reuse(false). // no need to prepare statement in every transaction InsertOrUpdateQuickly(maps.Map{ "key": key, diff --git a/internal/db/models/sys_locker_dao_test.go b/internal/db/models/sys_locker_dao_test.go index 2df3c5d2..07120f97 100644 --- a/internal/db/models/sys_locker_dao_test.go +++ b/internal/db/models/sys_locker_dao_test.go @@ -12,14 +12,16 @@ import ( func TestSysLockerDAO_Lock(t *testing.T) { var tx *dbs.Tx - isOk, err := SharedSysLockerDAO.Lock(tx, "test", 600) + var dao = NewSysLockerDAO() + + isOk, err := dao.Lock(tx, "test", 600) if err != nil { t.Fatal(err) } t.Log(isOk) if isOk { - err = SharedSysLockerDAO.Unlock(tx, "test") + err = dao.Unlock(tx, "test") if err != nil { t.Fatal(err) } @@ -128,3 +130,19 @@ func TestSysLocker_Increase_Performance(t *testing.T) { t.Log("cost:", time.Since(before).Seconds()*1000, "ms") } + +func BenchmarkSysLockerDAO_Increase(b *testing.B) { + var dao = NewSysLockerDAO() + _, _ = dao.Increase(nil, "hello", 0) + + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := dao.Increase(nil, "hello", 0) + if err != nil { + b.Fatal(err) + } + } + }) +} diff --git a/internal/db/models/utils.go b/internal/db/models/utils.go index e31bc9f3..a33ac3b7 100644 --- a/internal/db/models/utils.go +++ b/internal/db/models/utils.go @@ -80,3 +80,12 @@ func CheckSQLDuplicateErr(err error) bool { } return CheckSQLErrCode(err, 1062) } + +// IsMySQLError Check error is MySQLError +func IsMySQLError(err error) bool { + if err == nil { + return false + } + _, ok := err.(*mysql.MySQLError) + return ok +} diff --git a/internal/setup/setup.go b/internal/setup/setup.go index 7da0a53e..a740d784 100644 --- a/internal/setup/setup.go +++ b/internal/setup/setup.go @@ -97,7 +97,13 @@ func (this *Setup) Run() error { } for _, db := range config.DBs { // 可以同时运行多条语句 - db.Dsn += "&multiStatements=true" + if !strings.Contains(db.Dsn, "multiStatements=") { + if strings.Contains(db.Dsn, "?") { + db.Dsn += "&multiStatements=true" + } else { + db.Dsn += "?multiStatements=true" + } + } } dbConfig, ok := config.DBs[Tea.Env] if !ok {