feat: 实现数据库备份和恢复并发调度 (#84)

This commit is contained in:
kanzihuang
2024-01-11 11:35:51 +08:00
committed by GitHub
parent 3857d674ba
commit bbec3eca0d
40 changed files with 1373 additions and 843 deletions

View File

@@ -93,3 +93,5 @@ require (
modernc.org/sqlite v1.23.1 // indirect modernc.org/sqlite v1.23.1 // indirect
vitess.io/vitess v0.17.3 // indirect vitess.io/vitess v0.17.3 // indirect
) )
require github.com/emirpasic/gods v1.18.1

View File

@@ -28,7 +28,7 @@ func (d *DbBackup) GetPageList(rc *req.Ctx) {
db, err := d.DbApp.GetById(new(entity.Db), dbId, "db_instance_id", "database") db, err := d.DbApp.GetById(new(entity.Db), dbId, "db_instance_id", "database")
biz.ErrIsNilAppendErr(err, "获取数据库信息失败: %v") biz.ErrIsNilAppendErr(err, "获取数据库信息失败: %v")
queryCond, page := ginx.BindQueryAndPage[*entity.DbBackupQuery](rc.GinCtx, new(entity.DbBackupQuery)) queryCond, page := ginx.BindQueryAndPage[*entity.DbJobQuery](rc.GinCtx, new(entity.DbJobQuery))
queryCond.DbInstanceId = db.InstanceId queryCond.DbInstanceId = db.InstanceId
queryCond.InDbNames = strings.Fields(db.Database) queryCond.InDbNames = strings.Fields(db.Database)
res, err := d.DbBackupApp.GetPageList(queryCond, page, new([]vo.DbBackup)) res, err := d.DbBackupApp.GetPageList(queryCond, page, new([]vo.DbBackup))
@@ -51,32 +51,30 @@ func (d *DbBackup) Create(rc *req.Ctx) {
db, err := d.DbApp.GetById(new(entity.Db), dbId, "instanceId") db, err := d.DbApp.GetById(new(entity.Db), dbId, "instanceId")
biz.ErrIsNilAppendErr(err, "获取数据库信息失败: %v") biz.ErrIsNilAppendErr(err, "获取数据库信息失败: %v")
tasks := make([]*entity.DbBackup, 0, len(dbNames)) jobs := make([]*entity.DbBackup, 0, len(dbNames))
for _, dbName := range dbNames { for _, dbName := range dbNames {
task := &entity.DbBackup{ job := &entity.DbBackup{
DbTaskBase: entity.NewDbBTaskBase(true, backupForm.Repeated, backupForm.StartTime, backupForm.Interval), DbJobBaseImpl: entity.NewDbBJobBase(db.InstanceId, dbName, entity.DbJobTypeBackup, true, backupForm.Repeated, backupForm.StartTime, backupForm.Interval),
DbName: dbName,
Name: backupForm.Name, Name: backupForm.Name,
DbInstanceId: db.InstanceId,
} }
tasks = append(tasks, task) jobs = append(jobs, job)
} }
biz.ErrIsNilAppendErr(d.DbBackupApp.Create(rc.MetaCtx, tasks...), "添加数据库备份任务失败: %v") biz.ErrIsNilAppendErr(d.DbBackupApp.Create(rc.MetaCtx, jobs), "添加数据库备份任务失败: %v")
} }
// Save 保存数据库备份任务 // Update 保存数据库备份任务
// @router /api/dbs/:dbId/backups/:backupId [PUT] // @router /api/dbs/:dbId/backups/:backupId [PUT]
func (d *DbBackup) Save(rc *req.Ctx) { func (d *DbBackup) Update(rc *req.Ctx) {
backupForm := &form.DbBackupForm{} backupForm := &form.DbBackupForm{}
ginx.BindJsonAndValid(rc.GinCtx, backupForm) ginx.BindJsonAndValid(rc.GinCtx, backupForm)
rc.ReqParam = backupForm rc.ReqParam = backupForm
task := &entity.DbBackup{} job := entity.NewDbJob(entity.DbJobTypeBackup).(*entity.DbBackup)
task.Id = backupForm.Id job.Id = backupForm.Id
task.Name = backupForm.Name job.Name = backupForm.Name
task.StartTime = backupForm.StartTime job.StartTime = backupForm.StartTime
task.Interval = backupForm.Interval job.Interval = backupForm.Interval
biz.ErrIsNilAppendErr(d.DbBackupApp.Save(rc.MetaCtx, task), "保存数据库备份任务失败: %v") biz.ErrIsNilAppendErr(d.DbBackupApp.Update(rc.MetaCtx, job), "保存数据库备份任务失败: %v")
} }
func (d *DbBackup) walk(rc *req.Ctx, fn func(ctx context.Context, backupId uint64) error) error { func (d *DbBackup) walk(rc *req.Ctx, fn func(ctx context.Context, backupId uint64) error) error {
@@ -89,8 +87,8 @@ func (d *DbBackup) walk(rc *req.Ctx, fn func(ctx context.Context, backupId uint6
if err != nil { if err != nil {
return err return err
} }
taskId := uint64(value) backupId := uint64(value)
err = fn(rc.MetaCtx, taskId) err = fn(rc.MetaCtx, backupId)
if err != nil { if err != nil {
return err return err
} }
@@ -99,28 +97,28 @@ func (d *DbBackup) walk(rc *req.Ctx, fn func(ctx context.Context, backupId uint6
} }
// Delete 删除数据库备份任务 // Delete 删除数据库备份任务
// @router /api/dbs/:dbId/backups/:taskId [DELETE] // @router /api/dbs/:dbId/backups/:backupId [DELETE]
func (d *DbBackup) Delete(rc *req.Ctx) { func (d *DbBackup) Delete(rc *req.Ctx) {
err := d.walk(rc, d.DbBackupApp.Delete) err := d.walk(rc, d.DbBackupApp.Delete)
biz.ErrIsNilAppendErr(err, "删除数据库备份任务失败: %v") biz.ErrIsNilAppendErr(err, "删除数据库备份任务失败: %v")
} }
// Enable 启用数据库备份任务 // Enable 启用数据库备份任务
// @router /api/dbs/:dbId/backups/:taskId/enable [PUT] // @router /api/dbs/:dbId/backups/:backupId/enable [PUT]
func (d *DbBackup) Enable(rc *req.Ctx) { func (d *DbBackup) Enable(rc *req.Ctx) {
err := d.walk(rc, d.DbBackupApp.Enable) err := d.walk(rc, d.DbBackupApp.Enable)
biz.ErrIsNilAppendErr(err, "启用数据库备份任务失败: %v") biz.ErrIsNilAppendErr(err, "启用数据库备份任务失败: %v")
} }
// Disable 禁用数据库备份任务 // Disable 禁用数据库备份任务
// @router /api/dbs/:dbId/backups/:taskId/disable [PUT] // @router /api/dbs/:dbId/backups/:backupId/disable [PUT]
func (d *DbBackup) Disable(rc *req.Ctx) { func (d *DbBackup) Disable(rc *req.Ctx) {
err := d.walk(rc, d.DbBackupApp.Disable) err := d.walk(rc, d.DbBackupApp.Disable)
biz.ErrIsNilAppendErr(err, "禁用数据库备份任务失败: %v") biz.ErrIsNilAppendErr(err, "禁用数据库备份任务失败: %v")
} }
// Start 禁用数据库备份任务 // Start 禁用数据库备份任务
// @router /api/dbs/:dbId/backups/:taskId/start [PUT] // @router /api/dbs/:dbId/backups/:backupId/start [PUT]
func (d *DbBackup) Start(rc *req.Ctx) { func (d *DbBackup) Start(rc *req.Ctx) {
err := d.walk(rc, d.DbBackupApp.Start) err := d.walk(rc, d.DbBackupApp.Start)
biz.ErrIsNilAppendErr(err, "运行数据库备份任务失败: %v") biz.ErrIsNilAppendErr(err, "运行数据库备份任务失败: %v")
@@ -138,7 +136,7 @@ func (d *DbBackup) GetDbNamesWithoutBackup(rc *req.Ctx) {
rc.ResData = dbNamesWithoutBackup rc.ResData = dbNamesWithoutBackup
} }
// GetPageList 获取数据库备份历史 // GetHistoryPageList 获取数据库备份历史
// @router /api/dbs/:dbId/backups/:backupId/histories [GET] // @router /api/dbs/:dbId/backups/:backupId/histories [GET]
func (d *DbBackup) GetHistoryPageList(rc *req.Ctx) { func (d *DbBackup) GetHistoryPageList(rc *req.Ctx) {
dbId := uint64(ginx.PathParamInt(rc.GinCtx, "dbId")) dbId := uint64(ginx.PathParamInt(rc.GinCtx, "dbId"))

View File

@@ -27,7 +27,7 @@ func (d *DbRestore) GetPageList(rc *req.Ctx) {
biz.ErrIsNilAppendErr(err, "获取数据库信息失败: %v") biz.ErrIsNilAppendErr(err, "获取数据库信息失败: %v")
var restores []vo.DbRestore var restores []vo.DbRestore
queryCond, page := ginx.BindQueryAndPage[*entity.DbRestoreQuery](rc.GinCtx, new(entity.DbRestoreQuery)) queryCond, page := ginx.BindQueryAndPage[*entity.DbJobQuery](rc.GinCtx, new(entity.DbJobQuery))
queryCond.DbInstanceId = db.InstanceId queryCond.DbInstanceId = db.InstanceId
queryCond.InDbNames = strings.Fields(db.Database) queryCond.InDbNames = strings.Fields(db.Database)
res, err := d.DbRestoreApp.GetPageList(queryCond, page, &restores) res, err := d.DbRestoreApp.GetPageList(queryCond, page, &restores)
@@ -47,33 +47,31 @@ func (d *DbRestore) Create(rc *req.Ctx) {
db, err := d.DbApp.GetById(new(entity.Db), dbId, "instanceId") db, err := d.DbApp.GetById(new(entity.Db), dbId, "instanceId")
biz.ErrIsNilAppendErr(err, "获取数据库信息失败: %v") biz.ErrIsNilAppendErr(err, "获取数据库信息失败: %v")
task := &entity.DbRestore{ job := &entity.DbRestore{
DbTaskBase: entity.NewDbBTaskBase(true, restoreForm.Repeated, restoreForm.StartTime, restoreForm.Interval), DbJobBaseImpl: entity.NewDbBJobBase(db.InstanceId, restoreForm.DbName, entity.DbJobTypeRestore, true, restoreForm.Repeated, restoreForm.StartTime, restoreForm.Interval),
DbName: restoreForm.DbName,
DbInstanceId: db.InstanceId,
PointInTime: restoreForm.PointInTime, PointInTime: restoreForm.PointInTime,
DbBackupId: restoreForm.DbBackupId, DbBackupId: restoreForm.DbBackupId,
DbBackupHistoryId: restoreForm.DbBackupHistoryId, DbBackupHistoryId: restoreForm.DbBackupHistoryId,
DbBackupHistoryName: restoreForm.DbBackupHistoryName, DbBackupHistoryName: restoreForm.DbBackupHistoryName,
} }
biz.ErrIsNilAppendErr(d.DbRestoreApp.Create(rc.MetaCtx, task), "添加数据库恢复任务失败: %v") biz.ErrIsNilAppendErr(d.DbRestoreApp.Create(rc.MetaCtx, job), "添加数据库恢复任务失败: %v")
} }
// Save 保存数据库恢复任务 // Update 保存数据库恢复任务
// @router /api/dbs/:dbId/restores/:restoreId [PUT] // @router /api/dbs/:dbId/restores/:restoreId [PUT]
func (d *DbRestore) Save(rc *req.Ctx) { func (d *DbRestore) Update(rc *req.Ctx) {
restoreForm := &form.DbRestoreForm{} restoreForm := &form.DbRestoreForm{}
ginx.BindJsonAndValid(rc.GinCtx, restoreForm) ginx.BindJsonAndValid(rc.GinCtx, restoreForm)
rc.ReqParam = restoreForm rc.ReqParam = restoreForm
task := &entity.DbRestore{} job := &entity.DbRestore{}
task.Id = restoreForm.Id job.Id = restoreForm.Id
task.StartTime = restoreForm.StartTime job.StartTime = restoreForm.StartTime
task.Interval = restoreForm.Interval job.Interval = restoreForm.Interval
biz.ErrIsNilAppendErr(d.DbRestoreApp.Save(rc.MetaCtx, task), "保存数据库恢复任务失败: %v") biz.ErrIsNilAppendErr(d.DbRestoreApp.Update(rc.MetaCtx, job), "保存数据库恢复任务失败: %v")
} }
func (d *DbRestore) walk(rc *req.Ctx, fn func(ctx context.Context, taskId uint64) error) error { func (d *DbRestore) walk(rc *req.Ctx, fn func(ctx context.Context, restoreId uint64) error) error {
idsStr := ginx.PathParam(rc.GinCtx, "restoreId") idsStr := ginx.PathParam(rc.GinCtx, "restoreId")
biz.NotEmpty(idsStr, "restoreId 为空") biz.NotEmpty(idsStr, "restoreId 为空")
rc.ReqParam = idsStr rc.ReqParam = idsStr
@@ -83,8 +81,8 @@ func (d *DbRestore) walk(rc *req.Ctx, fn func(ctx context.Context, taskId uint64
if err != nil { if err != nil {
return err return err
} }
taskId := uint64(value) restoreId := uint64(value)
err = fn(rc.MetaCtx, taskId) err = fn(rc.MetaCtx, restoreId)
if err != nil { if err != nil {
return err return err
} }
@@ -92,19 +90,22 @@ func (d *DbRestore) walk(rc *req.Ctx, fn func(ctx context.Context, taskId uint64
return nil return nil
} }
// @router /api/dbs/:dbId/restores/:taskId [DELETE] // Delete 删除数据库恢复任务
// @router /api/dbs/:dbId/restores/:restoreId [DELETE]
func (d *DbRestore) Delete(rc *req.Ctx) { func (d *DbRestore) Delete(rc *req.Ctx) {
err := d.walk(rc, d.DbRestoreApp.Delete) err := d.walk(rc, d.DbRestoreApp.Delete)
biz.ErrIsNilAppendErr(err, "删除数据库恢复任务失败: %v") biz.ErrIsNilAppendErr(err, "删除数据库恢复任务失败: %v")
} }
// @router /api/dbs/:dbId/restores/:taskId/enable [PUT] // Enable 启用数据库恢复任务
// @router /api/dbs/:dbId/restores/:restoreId/enable [PUT]
func (d *DbRestore) Enable(rc *req.Ctx) { func (d *DbRestore) Enable(rc *req.Ctx) {
err := d.walk(rc, d.DbRestoreApp.Enable) err := d.walk(rc, d.DbRestoreApp.Enable)
biz.ErrIsNilAppendErr(err, "启用数据库恢复任务失败: %v") biz.ErrIsNilAppendErr(err, "启用数据库恢复任务失败: %v")
} }
// @router /api/dbs/:dbId/restores/:taskId/disable [PUT] // Disable 禁用数据库恢复任务
// @router /api/dbs/:dbId/restores/:restoreId/disable [PUT]
func (d *DbRestore) Disable(rc *req.Ctx) { func (d *DbRestore) Disable(rc *req.Ctx) {
err := d.walk(rc, d.DbRestoreApp.Disable) err := d.walk(rc, d.DbRestoreApp.Disable)
biz.ErrIsNilAppendErr(err, "禁用数据库恢复任务失败: %v") biz.ErrIsNilAppendErr(err, "禁用数据库恢复任务失败: %v")

View File

@@ -19,11 +19,13 @@ var (
dataSyncApp DataSyncTask dataSyncApp DataSyncTask
) )
var repositories *repository.Repositories //var repositories *repository.Repositories
//var scheduler *dbScheduler[*entity.DbBackup]
//var scheduler1 *dbScheduler[*entity.DbRestore]
func Init() { func Init() {
sync.OnceFunc(func() { sync.OnceFunc(func() {
repositories = &repository.Repositories{ repositories := &repository.Repositories{
Instance: persistence.GetInstanceRepo(), Instance: persistence.GetInstanceRepo(),
Backup: persistence.NewDbBackupRepo(), Backup: persistence.NewDbBackupRepo(),
BackupHistory: persistence.NewDbBackupHistoryRepo(), BackupHistory: persistence.NewDbBackupHistoryRepo(),
@@ -40,15 +42,18 @@ func Init() {
dbSqlApp = newDbSqlApp(persistence.GetDbSqlRepo()) dbSqlApp = newDbSqlApp(persistence.GetDbSqlRepo())
dataSyncApp = newDataSyncApp(persistence.GetDataSyncTaskRepo(), persistence.GetDataSyncLogRepo()) dataSyncApp = newDataSyncApp(persistence.GetDataSyncTaskRepo(), persistence.GetDataSyncLogRepo())
dbBackupApp, err = newDbBackupApp(repositories, dbApp) scheduler, err := newDbScheduler(repositories)
if err != nil {
panic(fmt.Sprintf("初始化 dbScheduler 失败: %v", err))
}
dbBackupApp, err = newDbBackupApp(repositories, dbApp, scheduler)
if err != nil { if err != nil {
panic(fmt.Sprintf("初始化 dbBackupApp 失败: %v", err)) panic(fmt.Sprintf("初始化 dbBackupApp 失败: %v", err))
} }
dbRestoreApp, err = newDbRestoreApp(repositories, dbApp) dbRestoreApp, err = newDbRestoreApp(repositories, dbApp, scheduler)
if err != nil { if err != nil {
panic(fmt.Sprintf("初始化 dbRestoreApp 失败: %v", err)) panic(fmt.Sprintf("初始化 dbRestoreApp 失败: %v", err))
} }
dbBinlogApp, err = newDbBinlogApp(repositories, dbApp) dbBinlogApp, err = newDbBinlogApp(repositories, dbApp)
if err != nil { if err != nil {
panic(fmt.Sprintf("初始化 dbBinlogApp 失败: %v", err)) panic(fmt.Sprintf("初始化 dbBinlogApp 失败: %v", err))

View File

@@ -3,29 +3,27 @@ package application
import ( import (
"context" "context"
"encoding/binary" "encoding/binary"
"fmt" "github.com/google/uuid"
"mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository" "mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/model" "mayfly-go/pkg/model"
"time"
"github.com/google/uuid"
) )
func newDbBackupApp(repositories *repository.Repositories, dbApp Db) (*DbBackupApp, error) { func newDbBackupApp(repositories *repository.Repositories, dbApp Db, scheduler *dbScheduler) (*DbBackupApp, error) {
var jobs []*entity.DbBackup
if err := repositories.Backup.ListToDo(&jobs); err != nil {
return nil, err
}
if err := scheduler.AddJob(context.Background(), false, entity.DbJobTypeBackup, jobs); err != nil {
return nil, err
}
app := &DbBackupApp{ app := &DbBackupApp{
backupRepo: repositories.Backup, backupRepo: repositories.Backup,
instanceRepo: repositories.Instance, instanceRepo: repositories.Instance,
backupHistoryRepo: repositories.BackupHistory, backupHistoryRepo: repositories.BackupHistory,
dbApp: dbApp, dbApp: dbApp,
scheduler: scheduler,
} }
scheduler, err := newDbScheduler[*entity.DbBackup](
repositories.Backup,
withRunBackupTask(app))
if err != nil {
return nil, err
}
app.scheduler = scheduler
return app, nil return app, nil
} }
@@ -34,41 +32,41 @@ type DbBackupApp struct {
instanceRepo repository.Instance instanceRepo repository.Instance
backupHistoryRepo repository.DbBackupHistory backupHistoryRepo repository.DbBackupHistory
dbApp Db dbApp Db
scheduler *dbScheduler[*entity.DbBackup] scheduler *dbScheduler
} }
func (app *DbBackupApp) Close() { func (app *DbBackupApp) Close() {
app.scheduler.Close() app.scheduler.Close()
} }
func (app *DbBackupApp) Create(ctx context.Context, tasks ...*entity.DbBackup) error { func (app *DbBackupApp) Create(ctx context.Context, jobs []*entity.DbBackup) error {
return app.scheduler.AddTask(ctx, tasks...) return app.scheduler.AddJob(ctx, true /* 保存到数据库 */, entity.DbJobTypeBackup, jobs)
} }
func (app *DbBackupApp) Save(ctx context.Context, task *entity.DbBackup) error { func (app *DbBackupApp) Update(ctx context.Context, job *entity.DbBackup) error {
return app.scheduler.UpdateTask(ctx, task) return app.scheduler.UpdateJob(ctx, job)
} }
func (app *DbBackupApp) Delete(ctx context.Context, taskId uint64) error { func (app *DbBackupApp) Delete(ctx context.Context, jobId uint64) error {
// todo: 删除数据库备份历史文件 // todo: 删除数据库备份历史文件
return app.scheduler.DeleteTask(ctx, taskId) return app.scheduler.RemoveJob(ctx, entity.DbJobTypeBackup, jobId)
} }
func (app *DbBackupApp) Enable(ctx context.Context, taskId uint64) error { func (app *DbBackupApp) Enable(ctx context.Context, jobId uint64) error {
return app.scheduler.EnableTask(ctx, taskId) return app.scheduler.EnableJob(ctx, entity.DbJobTypeBackup, jobId)
} }
func (app *DbBackupApp) Disable(ctx context.Context, taskId uint64) error { func (app *DbBackupApp) Disable(ctx context.Context, jobId uint64) error {
return app.scheduler.DisableTask(ctx, taskId) return app.scheduler.DisableJob(ctx, entity.DbJobTypeBackup, jobId)
} }
func (app *DbBackupApp) Start(ctx context.Context, taskId uint64) error { func (app *DbBackupApp) Start(ctx context.Context, jobId uint64) error {
return app.scheduler.StartTask(ctx, taskId) return app.scheduler.StartJobNow(ctx, entity.DbJobTypeBackup, jobId)
} }
// GetPageList 分页获取数据库备份任务 // GetPageList 分页获取数据库备份任务
func (app *DbBackupApp) GetPageList(condition *entity.DbBackupQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) { func (app *DbBackupApp) GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return app.backupRepo.GetDbBackupList(condition, pageParam, toEntity, orderBy...) return app.backupRepo.GetPageList(condition, pageParam, toEntity, orderBy...)
} }
// GetDbNamesWithoutBackup 获取未配置定时备份的数据库名称 // GetDbNamesWithoutBackup 获取未配置定时备份的数据库名称
@@ -76,54 +74,11 @@ func (app *DbBackupApp) GetDbNamesWithoutBackup(instanceId uint64, dbNames []str
return app.backupRepo.GetDbNamesWithoutBackup(instanceId, dbNames) return app.backupRepo.GetDbNamesWithoutBackup(instanceId, dbNames)
} }
// GetPageList 分页获取数据库备份历史 // GetHistoryPageList 分页获取数据库备份历史
func (app *DbBackupApp) GetHistoryPageList(condition *entity.DbBackupHistoryQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) { func (app *DbBackupApp) GetHistoryPageList(condition *entity.DbBackupHistoryQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return app.backupHistoryRepo.GetHistories(condition, pageParam, toEntity, orderBy...) return app.backupHistoryRepo.GetHistories(condition, pageParam, toEntity, orderBy...)
} }
func withRunBackupTask(app *DbBackupApp) dbSchedulerOption[*entity.DbBackup] {
return func(scheduler *dbScheduler[*entity.DbBackup]) {
scheduler.RunTask = app.runTask
}
}
func (app *DbBackupApp) runTask(ctx context.Context, task *entity.DbBackup) error {
id, err := NewIncUUID()
if err != nil {
return err
}
history := &entity.DbBackupHistory{
Uuid: id.String(),
DbBackupId: task.Id,
DbInstanceId: task.DbInstanceId,
DbName: task.DbName,
}
conn, err := app.dbApp.GetDbConnByInstanceId(task.DbInstanceId)
if err != nil {
return err
}
dbProgram := conn.GetDialect().GetDbProgram()
binlogInfo, err := dbProgram.Backup(ctx, history)
if err != nil {
return err
}
now := time.Now()
name := task.Name
if len(name) == 0 {
name = task.DbName
}
history.Name = fmt.Sprintf("%s[%s]", name, now.Format(time.DateTime))
history.CreateTime = now
history.BinlogFileName = binlogInfo.FileName
history.BinlogSequence = binlogInfo.Sequence
history.BinlogPosition = binlogInfo.Position
if err := app.backupHistoryRepo.Insert(ctx, history); err != nil {
return err
}
return nil
}
func NewIncUUID() (uuid.UUID, error) { func NewIncUUID() (uuid.UUID, error) {
var uid uuid.UUID var uid uuid.UUID
now, seq, err := uuid.GetTime() now, seq, err := uuid.GetTime()
@@ -144,19 +99,3 @@ func NewIncUUID() (uuid.UUID, error) {
return uid, nil return uid, nil
} }
// func newDbBackupHistoryApp(repositories *repository.Repositories) (*DbBackupHistoryApp, error) {
// app := &DbBackupHistoryApp{
// repo: repositories.BackupHistory,
// }
// return app, nil
// }
// type DbBackupHistoryApp struct {
// repo repository.DbBackupHistory
// }
// // GetPageList 分页获取数据库备份历史
// func (app *DbBackupHistoryApp) GetPageList(condition *entity.DbBackupHistoryQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
// return app.repo.GetHistories(condition, pageParam, toEntity, orderBy...)
// }

View File

@@ -28,12 +28,12 @@ type DbBinlogApp struct {
} }
var ( var (
binlogResult = map[entity.TaskStatus]string{ binlogResult = map[entity.DbJobStatus]string{
entity.TaskDelay: "等待备份BINLOG", entity.DbJobDelay: "等待备份BINLOG",
entity.TaskReady: "准备备份BINLOG", entity.DbJobReady: "准备备份BINLOG",
entity.TaskReserved: "BINLOG备份中", entity.DbJobRunning: "BINLOG备份中",
entity.TaskSuccess: "BINLOG备份成功", entity.DbJobSuccess: "BINLOG备份成功",
entity.TaskFailed: "BINLOG备份失败", entity.DbJobFailed: "BINLOG备份失败",
} }
) )
@@ -53,8 +53,8 @@ func newDbBinlogApp(repositories *repository.Repositories, dbApp Db) (*DbBinlogA
return svc, nil return svc, nil
} }
func (app *DbBinlogApp) runTask(ctx context.Context, backup *entity.DbBackup) error { func (app *DbBinlogApp) fetchBinlog(ctx context.Context, backup *entity.DbBackup) error {
if err := app.AddTaskIfNotExists(ctx, entity.NewDbBinlog(backup.DbInstanceId)); err != nil { if err := app.AddJobIfNotExists(ctx, entity.NewDbBinlog(backup.DbInstanceId)); err != nil {
return err return err
} }
latestBinlogSequence, earliestBackupSequence := int64(-1), int64(-1) latestBinlogSequence, earliestBackupSequence := int64(-1), int64(-1)
@@ -80,13 +80,13 @@ func (app *DbBinlogApp) runTask(ctx context.Context, backup *entity.DbBackup) er
if err == nil { if err == nil {
err = app.binlogHistoryRepo.InsertWithBinlogFiles(ctx, backup.DbInstanceId, binlogFiles) err = app.binlogHistoryRepo.InsertWithBinlogFiles(ctx, backup.DbInstanceId, binlogFiles)
} }
taskStatus := entity.TaskSuccess jobStatus := entity.DbJobSuccess
if err != nil { if err != nil {
taskStatus = entity.TaskFailed jobStatus = entity.DbJobFailed
} }
task := &entity.DbBinlog{} job := &entity.DbBinlog{}
task.Id = backup.DbInstanceId job.Id = backup.DbInstanceId
return app.updateCurTask(ctx, taskStatus, err, task) return app.updateCurJob(ctx, jobStatus, err, job)
} }
func (app *DbBinlogApp) run() { func (app *DbBinlogApp) run() {
@@ -99,16 +99,16 @@ func (app *DbBinlogApp) run() {
} }
func (app *DbBinlogApp) fetchFromAllInstances() { func (app *DbBinlogApp) fetchFromAllInstances() {
tasks, err := app.backupRepo.ListRepeating() var backups []*entity.DbBackup
if err != nil { if err := app.backupRepo.ListRepeating(&backups); err != nil {
logx.Errorf("DbBinlogApp: 获取数据库备份任务失败: %s", err.Error()) logx.Errorf("DbBinlogApp: 获取数据库备份任务失败: %s", err.Error())
return return
} }
for _, task := range tasks { for _, backup := range backups {
if app.closed() { if app.closed() {
break break
} }
if err := app.runTask(app.context, task); err != nil { if err := app.fetchBinlog(app.context, backup); err != nil {
logx.Errorf("DbBinlogApp: 下载 binlog 文件失败: %s", err.Error()) logx.Errorf("DbBinlogApp: 下载 binlog 文件失败: %s", err.Error())
return return
} }
@@ -124,31 +124,31 @@ func (app *DbBinlogApp) closed() bool {
return app.context.Err() != nil return app.context.Err() != nil
} }
func (app *DbBinlogApp) AddTaskIfNotExists(ctx context.Context, task *entity.DbBinlog) error { func (app *DbBinlogApp) AddJobIfNotExists(ctx context.Context, job *entity.DbBinlog) error {
if err := app.binlogRepo.AddTaskIfNotExists(ctx, task); err != nil { if err := app.binlogRepo.AddJobIfNotExists(ctx, job); err != nil {
return err return err
} }
if task.Id == 0 { if job.Id == 0 {
return nil return nil
} }
return nil return nil
} }
func (app *DbBinlogApp) DeleteTask(ctx context.Context, taskId uint64) error { func (app *DbBinlogApp) Delete(ctx context.Context, jobId uint64) error {
// todo: 删除 Binlog 历史文件 // todo: 删除 Binlog 历史文件
if err := app.binlogRepo.DeleteById(ctx, taskId); err != nil { if err := app.binlogRepo.DeleteById(ctx, jobId); err != nil {
return err return err
} }
return nil return nil
} }
func (app *DbBinlogApp) updateCurTask(ctx context.Context, status entity.TaskStatus, lastErr error, task *entity.DbBinlog) error { func (app *DbBinlogApp) updateCurJob(ctx context.Context, status entity.DbJobStatus, lastErr error, job *entity.DbBinlog) error {
task.LastStatus = status job.LastStatus = status
var result = binlogResult[status] var result = binlogResult[status]
if lastErr != nil { if lastErr != nil {
result = fmt.Sprintf("%v: %v", binlogResult[status], lastErr) result = fmt.Sprintf("%v: %v", binlogResult[status], lastErr)
} }
task.LastResult = stringx.TruncateStr(result, entity.LastResultSize) job.LastResult = stringx.TruncateStr(result, entity.LastResultSize)
task.LastTime = timex.NewNullTime(time.Now()) job.LastTime = timex.NewNullTime(time.Now())
return app.binlogRepo.UpdateById(ctx, task, "last_status", "last_result", "last_time") return app.binlogRepo.UpdateById(ctx, job, "last_status", "last_result", "last_time")
} }

View File

@@ -2,14 +2,19 @@ package application
import ( import (
"context" "context"
"mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository" "mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/model" "mayfly-go/pkg/model"
"time"
) )
func newDbRestoreApp(repositories *repository.Repositories, dbApp Db) (*DbRestoreApp, error) { func newDbRestoreApp(repositories *repository.Repositories, dbApp Db, scheduler *dbScheduler) (*DbRestoreApp, error) {
var jobs []*entity.DbRestore
if err := repositories.Restore.ListToDo(&jobs); err != nil {
return nil, err
}
if err := scheduler.AddJob(context.Background(), false, entity.DbJobTypeRestore, jobs); err != nil {
return nil, err
}
app := &DbRestoreApp{ app := &DbRestoreApp{
restoreRepo: repositories.Restore, restoreRepo: repositories.Restore,
instanceRepo: repositories.Instance, instanceRepo: repositories.Instance,
@@ -17,14 +22,8 @@ func newDbRestoreApp(repositories *repository.Repositories, dbApp Db) (*DbRestor
restoreHistoryRepo: repositories.RestoreHistory, restoreHistoryRepo: repositories.RestoreHistory,
binlogHistoryRepo: repositories.BinlogHistory, binlogHistoryRepo: repositories.BinlogHistory,
dbApp: dbApp, dbApp: dbApp,
scheduler: scheduler,
} }
scheduler, err := newDbScheduler[*entity.DbRestore](
repositories.Restore,
withRunRestoreTask(app))
if err != nil {
return nil, err
}
app.scheduler = scheduler
return app, nil return app, nil
} }
@@ -35,37 +34,37 @@ type DbRestoreApp struct {
restoreHistoryRepo repository.DbRestoreHistory restoreHistoryRepo repository.DbRestoreHistory
binlogHistoryRepo repository.DbBinlogHistory binlogHistoryRepo repository.DbBinlogHistory
dbApp Db dbApp Db
scheduler *dbScheduler[*entity.DbRestore] scheduler *dbScheduler
} }
func (app *DbRestoreApp) Close() { func (app *DbRestoreApp) Close() {
app.scheduler.Close() app.scheduler.Close()
} }
func (app *DbRestoreApp) Create(ctx context.Context, tasks ...*entity.DbRestore) error { func (app *DbRestoreApp) Create(ctx context.Context, job *entity.DbRestore) error {
return app.scheduler.AddTask(ctx, tasks...) return app.scheduler.AddJob(ctx, true /* 保存到数据库 */, entity.DbJobTypeRestore, job)
} }
func (app *DbRestoreApp) Save(ctx context.Context, task *entity.DbRestore) error { func (app *DbRestoreApp) Update(ctx context.Context, job *entity.DbRestore) error {
return app.scheduler.UpdateTask(ctx, task) return app.scheduler.UpdateJob(ctx, job)
} }
func (app *DbRestoreApp) Delete(ctx context.Context, taskId uint64) error { func (app *DbRestoreApp) Delete(ctx context.Context, jobId uint64) error {
// todo: 删除数据库恢复历史文件 // todo: 删除数据库恢复历史文件
return app.scheduler.DeleteTask(ctx, taskId) return app.scheduler.RemoveJob(ctx, entity.DbJobTypeRestore, jobId)
} }
func (app *DbRestoreApp) Enable(ctx context.Context, taskId uint64) error { func (app *DbRestoreApp) Enable(ctx context.Context, jobId uint64) error {
return app.scheduler.EnableTask(ctx, taskId) return app.scheduler.EnableJob(ctx, entity.DbJobTypeRestore, jobId)
} }
func (app *DbRestoreApp) Disable(ctx context.Context, taskId uint64) error { func (app *DbRestoreApp) Disable(ctx context.Context, jobId uint64) error {
return app.scheduler.DisableTask(ctx, taskId) return app.scheduler.DisableJob(ctx, entity.DbJobTypeRestore, jobId)
} }
// GetPageList 分页获取数据库恢复任务 // GetPageList 分页获取数据库恢复任务
func (app *DbRestoreApp) GetPageList(condition *entity.DbRestoreQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) { func (app *DbRestoreApp) GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return app.restoreRepo.GetDbRestoreList(condition, pageParam, toEntity, orderBy...) return app.restoreRepo.GetPageList(condition, pageParam, toEntity, orderBy...)
} }
// GetDbNamesWithoutRestore 获取未配置定时恢复的数据库名称 // GetDbNamesWithoutRestore 获取未配置定时恢复的数据库名称
@@ -73,108 +72,7 @@ func (app *DbRestoreApp) GetDbNamesWithoutRestore(instanceId uint64, dbNames []s
return app.restoreRepo.GetDbNamesWithoutRestore(instanceId, dbNames) return app.restoreRepo.GetDbNamesWithoutRestore(instanceId, dbNames)
} }
// 分页获取数据库备份历史 // GetHistoryPageList 分页获取数据库备份历史
func (app *DbRestoreApp) GetHistoryPageList(condition *entity.DbRestoreHistoryQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) { func (app *DbRestoreApp) GetHistoryPageList(condition *entity.DbRestoreHistoryQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return app.restoreHistoryRepo.GetDbRestoreHistories(condition, pageParam, toEntity, orderBy...) return app.restoreHistoryRepo.GetDbRestoreHistories(condition, pageParam, toEntity, orderBy...)
} }
func (app *DbRestoreApp) runTask(ctx context.Context, task *entity.DbRestore) error {
conn, err := app.dbApp.GetDbConnByInstanceId(task.DbInstanceId)
if err != nil {
return err
}
dbProgram := conn.GetDialect().GetDbProgram()
if task.PointInTime.Valid {
latestBinlogSequence, earliestBackupSequence := int64(-1), int64(-1)
binlogHistory, ok, err := app.binlogHistoryRepo.GetLatestHistory(task.DbInstanceId)
if err != nil {
return err
}
if ok {
latestBinlogSequence = binlogHistory.Sequence
} else {
backupHistory, err := app.backupHistoryRepo.GetEarliestHistory(task.DbInstanceId)
if err != nil {
return err
}
earliestBackupSequence = backupHistory.BinlogSequence
}
binlogFiles, err := dbProgram.FetchBinlogs(ctx, true, earliestBackupSequence, latestBinlogSequence)
if err != nil {
return err
}
if err := app.binlogHistoryRepo.InsertWithBinlogFiles(ctx, task.DbInstanceId, binlogFiles); err != nil {
return err
}
if err := app.restorePointInTime(ctx, dbProgram, task); err != nil {
return err
}
} else {
if err := app.restoreBackupHistory(ctx, dbProgram, task); err != nil {
return err
}
}
history := &entity.DbRestoreHistory{
CreateTime: time.Now(),
DbRestoreId: task.Id,
}
if err := app.restoreHistoryRepo.Insert(ctx, history); err != nil {
return err
}
return nil
}
func (app *DbRestoreApp) restorePointInTime(ctx context.Context, program dbm.DbProgram, task *entity.DbRestore) error {
binlogHistory, err := app.binlogHistoryRepo.GetHistoryByTime(task.DbInstanceId, task.PointInTime.Time)
if err != nil {
return err
}
position, err := program.GetBinlogEventPositionAtOrAfterTime(ctx, binlogHistory.FileName, task.PointInTime.Time)
if err != nil {
return err
}
target := &entity.BinlogInfo{
FileName: binlogHistory.FileName,
Sequence: binlogHistory.Sequence,
Position: position,
}
backupHistory, err := app.backupHistoryRepo.GetLatestHistory(task.DbInstanceId, task.DbName, target)
if err != nil {
return err
}
start := &entity.BinlogInfo{
FileName: backupHistory.BinlogFileName,
Sequence: backupHistory.BinlogSequence,
Position: backupHistory.BinlogPosition,
}
binlogHistories, err := app.binlogHistoryRepo.GetHistories(task.DbInstanceId, start, target)
if err != nil {
return err
}
restoreInfo := &dbm.RestoreInfo{
BackupHistory: backupHistory,
BinlogHistories: binlogHistories,
StartPosition: backupHistory.BinlogPosition,
TargetPosition: target.Position,
TargetTime: task.PointInTime.Time,
}
if err := program.RestoreBackupHistory(ctx, backupHistory.DbName, backupHistory.DbBackupId, backupHistory.Uuid); err != nil {
return err
}
return program.ReplayBinlog(ctx, task.DbName, task.DbName, restoreInfo)
}
func (app *DbRestoreApp) restoreBackupHistory(ctx context.Context, program dbm.DbProgram, task *entity.DbRestore) error {
backupHistory := &entity.DbBackupHistory{}
if err := app.backupHistoryRepo.GetById(backupHistory, task.DbBackupHistoryId); err != nil {
return err
}
return program.RestoreBackupHistory(ctx, backupHistory.DbName, backupHistory.DbBackupId, backupHistory.Uuid)
}
func withRunRestoreTask(app *DbRestoreApp) dbSchedulerOption[*entity.DbRestore] {
return func(scheduler *dbScheduler[*entity.DbRestore]) {
scheduler.RunTask = app.runTask
}
}

View File

@@ -4,232 +4,360 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository" "mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/queue" "mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/anyx" "mayfly-go/pkg/runner"
"mayfly-go/pkg/utils/stringx" "reflect"
"mayfly-go/pkg/utils/timex"
"sync" "sync"
"time" "time"
) )
const sleepAfterError = time.Minute const (
maxRunning = 8
)
type dbScheduler[T entity.DbTask] struct { type dbScheduler struct {
mutex sync.Mutex mutex sync.Mutex
waitGroup sync.WaitGroup runner *runner.Runner[entity.DbJob]
queue *queue.DelayQueue[T] dbApp Db
context context.Context backupRepo repository.DbBackup
cancel context.CancelFunc backupHistoryRepo repository.DbBackupHistory
RunTask func(ctx context.Context, task T) error restoreRepo repository.DbRestore
taskRepo repository.DbTask[T] restoreHistoryRepo repository.DbRestoreHistory
binlogHistoryRepo repository.DbBinlogHistory
} }
type dbSchedulerOption[T entity.DbTask] func(*dbScheduler[T]) func newDbScheduler(repositories *repository.Repositories) (*dbScheduler, error) {
scheduler := &dbScheduler{
func newDbScheduler[T entity.DbTask](taskRepo repository.DbTask[T], opts ...dbSchedulerOption[T]) (*dbScheduler[T], error) { runner: runner.NewRunner[entity.DbJob](maxRunning),
ctx, cancel := context.WithCancel(context.Background()) dbApp: dbApp,
scheduler := &dbScheduler[T]{ backupRepo: repositories.Backup,
taskRepo: taskRepo, backupHistoryRepo: repositories.BackupHistory,
queue: queue.NewDelayQueue[T](0), restoreRepo: repositories.Restore,
context: ctx, restoreHistoryRepo: repositories.RestoreHistory,
cancel: cancel, binlogHistoryRepo: repositories.BinlogHistory,
} }
for _, opt := range opts {
opt(scheduler)
}
if scheduler.RunTask == nil {
return nil, errors.New("数据库任务调度器没有设置 RunTask")
}
if err := scheduler.loadTask(context.Background()); err != nil {
return nil, err
}
scheduler.waitGroup.Add(1)
go scheduler.run()
return scheduler, nil return scheduler, nil
} }
func (s *dbScheduler[T]) updateTaskStatus(ctx context.Context, status entity.TaskStatus, lastErr error, task T) error { func (s *dbScheduler) repo(typ entity.DbJobType) repository.DbJob {
base := task.GetTaskBase() switch typ {
base.LastStatus = status case entity.DbJobTypeBackup:
var result = task.MessageWithStatus(status) return s.backupRepo
if lastErr != nil { case entity.DbJobTypeRestore:
result = fmt.Sprintf("%v: %v", result, lastErr) return s.restoreRepo
default:
panic(errors.New(fmt.Sprintf("无效的数据库任务类型: %v", typ)))
} }
base.LastResult = stringx.TruncateStr(result, entity.LastResultSize)
base.LastTime = timex.NewNullTime(time.Now())
return s.taskRepo.UpdateTaskStatus(ctx, task)
} }
func (s *dbScheduler[T]) UpdateTask(ctx context.Context, task T) error { func (s *dbScheduler) UpdateJob(ctx context.Context, job entity.DbJob) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if err := s.taskRepo.UpdateById(ctx, task); err != nil { if err := s.repo(job.GetJobType()).UpdateById(ctx, job); err != nil {
return err return err
} }
job.SetRun(s.run)
oldTask, ok := s.queue.Remove(ctx, task.GetId()) job.SetRunnable(s.runnable)
if !ok { _ = s.runner.UpdateOrAdd(ctx, job)
return errors.New("任务不存在")
}
oldTask.Update(task)
if !oldTask.Schedule() {
return nil
}
if !s.queue.Enqueue(ctx, oldTask) {
return errors.New("任务入队失败")
}
return nil return nil
} }
func (s *dbScheduler[T]) run() { func (s *dbScheduler) Close() {
defer s.waitGroup.Done() s.runner.Close()
for !s.closed() {
time.Sleep(time.Second)
s.mutex.Lock()
task, ok := s.queue.TryDequeue()
if !ok {
s.mutex.Unlock()
continue
}
if err := s.updateTaskStatus(s.context, entity.TaskReserved, nil, task); err != nil {
s.mutex.Unlock()
timex.SleepWithContext(s.context, sleepAfterError)
continue
}
s.mutex.Unlock()
errRun := s.RunTask(s.context, task)
taskStatus := entity.TaskSuccess
if errRun != nil {
taskStatus = entity.TaskFailed
}
s.mutex.Lock()
if err := s.updateTaskStatus(s.context, taskStatus, errRun, task); err != nil {
s.mutex.Unlock()
timex.SleepWithContext(s.context, sleepAfterError)
continue
}
task.Schedule()
if !task.IsFinished() {
s.queue.Enqueue(s.context, task)
}
s.mutex.Unlock()
}
} }
func (s *dbScheduler[T]) Close() { func (s *dbScheduler) AddJob(ctx context.Context, saving bool, jobType entity.DbJobType, jobs any) error {
s.cancel()
s.waitGroup.Wait()
}
func (s *dbScheduler[T]) closed() bool {
return s.context.Err() != nil
}
func (s *dbScheduler[T]) loadTask(ctx context.Context) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
tasks, err := s.taskRepo.ListToDo() if saving {
if err != nil { if err := s.repo(jobType).AddJob(ctx, jobs); err != nil {
return err return err
} }
for _, task := range tasks { }
if !task.Schedule() {
reflectValue := reflect.ValueOf(jobs)
switch reflectValue.Kind() {
case reflect.Array, reflect.Slice:
reflectLen := reflectValue.Len()
for i := 0; i < reflectLen; i++ {
job := reflectValue.Index(i).Interface().(entity.DbJob)
job.SetJobType(jobType)
if !job.Schedule() {
continue continue
} }
s.queue.Enqueue(ctx, task) job.SetRun(s.run)
job.SetRunnable(s.runnable)
_ = s.runner.Add(ctx, job)
}
default:
job := jobs.(entity.DbJob)
job.SetJobType(jobType)
if !job.Schedule() {
return nil
}
job.SetRun(s.run)
job.SetRunnable(s.runnable)
_ = s.runner.Add(ctx, job)
} }
return nil return nil
} }
func (s *dbScheduler[T]) AddTask(ctx context.Context, tasks ...T) error { func (s *dbScheduler) RemoveJob(ctx context.Context, jobType entity.DbJobType, jobId uint64) error {
s.mutex.Lock()
defer s.mutex.Unlock()
for _, task := range tasks {
if err := s.taskRepo.AddTask(ctx, task); err != nil {
return err
}
if !task.Schedule() {
continue
}
s.queue.Enqueue(ctx, task)
}
return nil
}
func (s *dbScheduler[T]) DeleteTask(ctx context.Context, taskId uint64) error {
// todo: 删除数据库备份历史文件 // todo: 删除数据库备份历史文件
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if err := s.taskRepo.DeleteById(ctx, taskId); err != nil { if err := s.repo(jobType).DeleteById(ctx, jobId); err != nil {
return err return err
} }
s.queue.Remove(ctx, taskId) _ = s.runner.Remove(ctx, entity.FormatJobKey(jobType, jobId))
return nil return nil
} }
func (s *dbScheduler[T]) EnableTask(ctx context.Context, taskId uint64) error { func (s *dbScheduler) EnableJob(ctx context.Context, jobType entity.DbJobType, jobId uint64) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
task := anyx.DeepZero[T]() repo := s.repo(jobType)
if err := s.taskRepo.GetById(task, taskId); err != nil { job := entity.NewDbJob(jobType)
if err := repo.GetById(job, jobId); err != nil {
return err return err
} }
if task.IsEnabled() { if job.IsEnabled() {
return nil return nil
} }
task.GetTaskBase().Enabled = true job.GetJobBase().Enabled = true
if err := s.taskRepo.UpdateEnabled(ctx, taskId, true); err != nil { if err := repo.UpdateEnabled(ctx, jobId, true); err != nil {
return err return err
} }
s.queue.Remove(ctx, taskId) job.SetRun(s.run)
if !task.Schedule() { job.SetRunnable(s.runnable)
return nil _ = s.runner.Add(ctx, job)
}
s.queue.Enqueue(ctx, task)
return nil return nil
} }
func (s *dbScheduler[T]) DisableTask(ctx context.Context, taskId uint64) error { func (s *dbScheduler) DisableJob(ctx context.Context, jobType entity.DbJobType, jobId uint64) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
task := anyx.DeepZero[T]() repo := s.repo(jobType)
if err := s.taskRepo.GetById(task, taskId); err != nil { job := entity.NewDbJob(jobType)
if err := repo.GetById(job, jobId); err != nil {
return err return err
} }
if !task.IsEnabled() { if !job.IsEnabled() {
return nil return nil
} }
if err := s.taskRepo.UpdateEnabled(ctx, taskId, false); err != nil { if err := repo.UpdateEnabled(ctx, jobId, false); err != nil {
return err return err
} }
s.queue.Remove(ctx, taskId) _ = s.runner.Remove(ctx, job.GetKey())
return nil return nil
} }
func (s *dbScheduler[T]) StartTask(ctx context.Context, taskId uint64) error { func (s *dbScheduler) StartJobNow(ctx context.Context, jobType entity.DbJobType, jobId uint64) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
task := anyx.DeepZero[T]() job := entity.NewDbJob(jobType)
if err := s.taskRepo.GetById(task, taskId); err != nil { if err := s.repo(jobType).GetById(job, jobId); err != nil {
return err return err
} }
if !task.IsEnabled() { if !job.IsEnabled() {
return errors.New("任务未启用") return errors.New("任务未启用")
} }
s.queue.Remove(ctx, taskId) job.GetJobBase().Deadline = time.Now()
task.GetTaskBase().Deadline = time.Now() job.SetRun(s.run)
s.queue.Enqueue(ctx, task) job.SetRunnable(s.runnable)
_ = s.runner.StartNow(ctx, job)
return nil return nil
} }
func (s *dbScheduler) backupMysql(ctx context.Context, job entity.DbJob) error {
id, err := NewIncUUID()
if err != nil {
return err
}
backup := job.(*entity.DbBackup)
history := &entity.DbBackupHistory{
Uuid: id.String(),
DbBackupId: backup.Id,
DbInstanceId: backup.DbInstanceId,
DbName: backup.DbName,
}
conn, err := s.dbApp.GetDbConnByInstanceId(backup.DbInstanceId)
if err != nil {
return err
}
dbProgram := conn.GetDialect().GetDbProgram()
binlogInfo, err := dbProgram.Backup(ctx, history)
if err != nil {
return err
}
now := time.Now()
name := backup.Name
if len(name) == 0 {
name = backup.DbName
}
history.Name = fmt.Sprintf("%s[%s]", name, now.Format(time.DateTime))
history.CreateTime = now
history.BinlogFileName = binlogInfo.FileName
history.BinlogSequence = binlogInfo.Sequence
history.BinlogPosition = binlogInfo.Position
if err := s.backupHistoryRepo.Insert(ctx, history); err != nil {
return err
}
return nil
}
func (s *dbScheduler) restoreMysql(ctx context.Context, job entity.DbJob) error {
restore := job.(*entity.DbRestore)
conn, err := s.dbApp.GetDbConnByInstanceId(restore.DbInstanceId)
if err != nil {
return err
}
dbProgram := conn.GetDialect().GetDbProgram()
if restore.PointInTime.Valid {
latestBinlogSequence, earliestBackupSequence := int64(-1), int64(-1)
binlogHistory, ok, err := s.binlogHistoryRepo.GetLatestHistory(restore.DbInstanceId)
if err != nil {
return err
}
if ok {
latestBinlogSequence = binlogHistory.Sequence
} else {
backupHistory, err := s.backupHistoryRepo.GetEarliestHistory(restore.DbInstanceId)
if err != nil {
return err
}
earliestBackupSequence = backupHistory.BinlogSequence
}
binlogFiles, err := dbProgram.FetchBinlogs(ctx, true, earliestBackupSequence, latestBinlogSequence)
if err != nil {
return err
}
if err := s.binlogHistoryRepo.InsertWithBinlogFiles(ctx, restore.DbInstanceId, binlogFiles); err != nil {
return err
}
if err := s.restorePointInTime(ctx, dbProgram, restore); err != nil {
return err
}
} else {
if err := s.restoreBackupHistory(ctx, dbProgram, restore); err != nil {
return err
}
}
history := &entity.DbRestoreHistory{
CreateTime: time.Now(),
DbRestoreId: restore.Id,
}
if err := s.restoreHistoryRepo.Insert(ctx, history); err != nil {
return err
}
return nil
}
func (s *dbScheduler) run(ctx context.Context, job entity.DbJob) {
job.SetLastStatus(entity.DbJobRunning, nil)
if err := s.repo(job.GetJobType()).UpdateLastStatus(ctx, job); err != nil {
logx.Errorf("failed to update job status: %v", err)
return
}
var errRun error
switch typ := job.GetJobType(); typ {
case entity.DbJobTypeBackup:
errRun = s.backupMysql(ctx, job)
case entity.DbJobTypeRestore:
errRun = s.restoreMysql(ctx, job)
default:
errRun = errors.New(fmt.Sprintf("无效的数据库任务类型: %v", typ))
}
status := entity.DbJobSuccess
if errRun != nil {
status = entity.DbJobFailed
}
job.SetLastStatus(status, errRun)
if err := s.repo(job.GetJobType()).UpdateLastStatus(ctx, job); err != nil {
logx.Errorf("failed to update job status: %v", err)
return
}
}
func (s *dbScheduler) runnable(job entity.DbJob, next runner.NextFunc) bool {
const maxCountByInstanceId = 4
const maxCountByDbName = 1
var countByInstanceId, countByDbName int
jobBase := job.GetJobBase()
for item, ok := next(); ok; item, ok = next() {
itemBase := item.(entity.DbJob).GetJobBase()
if jobBase.DbInstanceId == itemBase.DbInstanceId {
countByInstanceId++
if countByInstanceId > maxCountByInstanceId {
return false
}
if jobBase.DbName == itemBase.DbName {
countByDbName++
if countByDbName > maxCountByDbName {
return false
}
}
}
}
return true
}
func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbm.DbProgram, job *entity.DbRestore) error {
binlogHistory, err := s.binlogHistoryRepo.GetHistoryByTime(job.DbInstanceId, job.PointInTime.Time)
if err != nil {
return err
}
position, err := program.GetBinlogEventPositionAtOrAfterTime(ctx, binlogHistory.FileName, job.PointInTime.Time)
if err != nil {
return err
}
target := &entity.BinlogInfo{
FileName: binlogHistory.FileName,
Sequence: binlogHistory.Sequence,
Position: position,
}
backupHistory, err := s.backupHistoryRepo.GetLatestHistory(job.DbInstanceId, job.DbName, target)
if err != nil {
return err
}
start := &entity.BinlogInfo{
FileName: backupHistory.BinlogFileName,
Sequence: backupHistory.BinlogSequence,
Position: backupHistory.BinlogPosition,
}
binlogHistories, err := s.binlogHistoryRepo.GetHistories(job.DbInstanceId, start, target)
if err != nil {
return err
}
restoreInfo := &dbm.RestoreInfo{
BackupHistory: backupHistory,
BinlogHistories: binlogHistories,
StartPosition: backupHistory.BinlogPosition,
TargetPosition: target.Position,
TargetTime: job.PointInTime.Time,
}
if err := program.RestoreBackupHistory(ctx, backupHistory.DbName, backupHistory.DbBackupId, backupHistory.Uuid); err != nil {
return err
}
return program.ReplayBinlog(ctx, job.DbName, job.DbName, restoreInfo)
}
func (s *dbScheduler) restoreBackupHistory(ctx context.Context, program dbm.DbProgram, job *entity.DbRestore) error {
backupHistory := &entity.DbBackupHistory{}
if err := s.backupHistoryRepo.GetById(backupHistory, job.DbBackupHistoryId); err != nil {
return err
}
return program.RestoreBackupHistory(ctx, backupHistory.DbName, backupHistory.DbBackupId, backupHistory.Uuid)
}

View File

@@ -97,12 +97,12 @@ func (s *DbInstanceSuite) TearDownTest() {
} }
func (s *DbInstanceSuite) TestBackup() { func (s *DbInstanceSuite) TestBackup() {
task := &entity.DbBackupHistory{ history := &entity.DbBackupHistory{
DbName: dbNameBackupTest, DbName: dbNameBackupTest,
Uuid: dbNameBackupTest, Uuid: dbNameBackupTest,
} }
task.Id = backupIdTest history.Id = backupIdTest
s.testBackup(task) s.testBackup(history)
} }
func (s *DbInstanceSuite) testBackup(backupHistory *entity.DbBackupHistory) { func (s *DbInstanceSuite) testBackup(backupHistory *entity.DbBackupHistory) {

View File

@@ -1,29 +1,27 @@
package entity package entity
var _ DbTask = (*DbBackup)(nil) import (
"context"
"mayfly-go/pkg/runner"
)
var _ DbJob = (*DbBackup)(nil)
// DbBackup 数据库备份任务 // DbBackup 数据库备份任务
type DbBackup struct { type DbBackup struct {
*DbTaskBase *DbJobBaseImpl
Name string `json:"name"` // 备份任务名称 Name string `json:"Name"` // 数据库备份名称
DbName string `json:"dbName"` // 数据库名
DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
} }
func (*DbBackup) MessageWithStatus(status TaskStatus) string { func (d *DbBackup) SetRun(fn func(ctx context.Context, job DbJob)) {
var result string d.run = func(ctx context.Context) {
switch status { fn(ctx, d)
case TaskDelay: }
result = "等待备份数据库" }
case TaskReady:
result = "准备备份数据库" func (d *DbBackup) SetRunnable(fn func(job DbJob, next runner.NextFunc) bool) {
case TaskReserved: d.runnable = func(next runner.NextFunc) bool {
result = "数据库备份中" return fn(d, next)
case TaskSuccess:
result = "数据库备份成功"
case TaskFailed:
result = "数据库备份失败"
} }
return result
} }

View File

@@ -10,17 +10,17 @@ import (
type DbBinlog struct { type DbBinlog struct {
model.Model model.Model
LastStatus TaskStatus // 最近一次执行状态 LastStatus DbJobStatus // 最近一次执行状态
LastResult string // 最近一次执行结果 LastResult string // 最近一次执行结果
LastTime timex.NullTime // 最近一次执行时间 LastTime timex.NullTime // 最近一次执行时间
DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
} }
func NewDbBinlog(instanceId uint64) *DbBinlog { func NewDbBinlog(instanceId uint64) *DbBinlog {
binlogTask := &DbBinlog{} job := &DbBinlog{}
binlogTask.Id = instanceId job.Id = instanceId
binlogTask.DbInstanceId = instanceId job.DbInstanceId = instanceId
return binlogTask return job
} }
// BinlogFile is the metadata of the MySQL binlog file. // BinlogFile is the metadata of the MySQL binlog file.

View File

@@ -0,0 +1,235 @@
package entity
import (
"context"
"fmt"
"mayfly-go/pkg/model"
"mayfly-go/pkg/runner"
"mayfly-go/pkg/utils/stringx"
"mayfly-go/pkg/utils/timex"
"time"
)
const LastResultSize = 256
type DbJobKey = runner.JobKey
type DbJobStatus = runner.JobStatus
const (
DbJobUnknown = runner.JobUnknown
DbJobDelay = runner.JobDelay
DbJobReady = runner.JobWaiting
DbJobRunning = runner.JobRunning
DbJobRemoved = runner.JobRemoved
)
const (
DbJobSuccess DbJobStatus = 0x20 + iota
DbJobFailed
)
type DbJobType = string
const (
DbJobTypeBackup DbJobType = "db-backup"
DbJobTypeRestore DbJobType = "db-restore"
)
const (
DbJobNameBackup = "数据库备份"
DbJobNameRestore = "数据库恢复"
)
var _ runner.Job = (DbJob)(nil)
type DbJobBase interface {
model.ModelI
runner.Job
GetId() uint64
GetJobType() DbJobType
SetJobType(typ DbJobType)
GetJobBase() *DbJobBaseImpl
SetLastStatus(status DbJobStatus, err error)
IsEnabled() bool
}
type DbJob interface {
DbJobBase
SetRun(fn func(ctx context.Context, job DbJob))
SetRunnable(fn func(job DbJob, next runner.NextFunc) bool)
}
func NewDbJob(typ DbJobType) DbJob {
switch typ {
case DbJobTypeBackup:
return &DbBackup{
DbJobBaseImpl: &DbJobBaseImpl{
jobType: DbJobTypeBackup},
}
case DbJobTypeRestore:
return &DbRestore{
DbJobBaseImpl: &DbJobBaseImpl{
jobType: DbJobTypeRestore},
}
default:
panic(fmt.Sprintf("invalid DbJobType: %v", typ))
}
}
var _ DbJobBase = (*DbJobBaseImpl)(nil)
type DbJobBaseImpl struct {
model.Model
DbInstanceId uint64 // 数据库实例ID
DbName string // 数据库名称
Enabled bool // 是否启用
StartTime time.Time // 开始时间
Interval time.Duration // 间隔时间
Repeated bool // 是否重复执行
LastStatus DbJobStatus // 最近一次执行状态
LastResult string // 最近一次执行结果
LastTime timex.NullTime // 最近一次执行时间
Deadline time.Time `gorm:"-" json:"-"` // 计划执行时间
run runner.RunFunc
runnable runner.RunnableFunc
jobType DbJobType
jobKey runner.JobKey
jobStatus runner.JobStatus
}
func NewDbBJobBase(instanceId uint64, dbName string, jobType DbJobType, enabled bool, repeated bool, startTime time.Time, interval time.Duration) *DbJobBaseImpl {
return &DbJobBaseImpl{
DbInstanceId: instanceId,
DbName: dbName,
jobType: jobType,
Enabled: enabled,
Repeated: repeated,
StartTime: startTime,
Interval: interval,
}
}
func (d *DbJobBaseImpl) GetJobType() DbJobType {
return d.jobType
}
func (d *DbJobBaseImpl) SetJobType(typ DbJobType) {
d.jobType = typ
}
func (d *DbJobBaseImpl) SetLastStatus(status DbJobStatus, err error) {
var statusName, jobName string
switch status {
case DbJobRunning:
statusName = "运行中"
case DbJobSuccess:
statusName = "成功"
case DbJobFailed:
statusName = "失败"
default:
return
}
switch d.jobType {
case DbJobTypeBackup:
jobName = DbJobNameBackup
case DbJobTypeRestore:
jobName = DbJobNameRestore
default:
jobName = d.jobType
}
d.LastStatus = status
var result = jobName + statusName
if err != nil {
result = fmt.Sprintf("%s: %v", result, err)
}
d.LastResult = stringx.TruncateStr(result, LastResultSize)
d.LastTime = timex.NewNullTime(time.Now())
}
func (d *DbJobBaseImpl) GetId() uint64 {
if d == nil {
return 0
}
return d.Id
}
func (d *DbJobBaseImpl) GetDeadline() time.Time {
return d.Deadline
}
func (d *DbJobBaseImpl) Schedule() bool {
if d.IsFinished() || !d.Enabled {
return false
}
switch d.LastStatus {
case DbJobSuccess:
if d.Interval == 0 {
return false
}
lastTime := d.LastTime.Time
if lastTime.Sub(d.StartTime) < 0 {
lastTime = d.StartTime.Add(-d.Interval)
}
d.Deadline = lastTime.Add(d.Interval - lastTime.Sub(d.StartTime)%d.Interval)
case DbJobFailed:
d.Deadline = time.Now().Add(time.Minute)
default:
d.Deadline = d.StartTime
}
return true
}
func (d *DbJobBaseImpl) IsFinished() bool {
return !d.Repeated && d.LastStatus == DbJobSuccess
}
func (d *DbJobBaseImpl) Renew(job runner.Job) {
jobBase := job.(DbJob).GetJobBase()
d.StartTime = jobBase.StartTime
d.Interval = jobBase.Interval
}
func (d *DbJobBaseImpl) GetJobBase() *DbJobBaseImpl {
return d
}
func (d *DbJobBaseImpl) IsEnabled() bool {
return d.Enabled
}
func (d *DbJobBaseImpl) Run(ctx context.Context) {
if d.run == nil {
return
}
d.run(ctx)
}
func (d *DbJobBaseImpl) Runnable(next runner.NextFunc) bool {
if d.runnable == nil {
return true
}
return d.runnable(next)
}
func FormatJobKey(typ DbJobType, jobId uint64) DbJobKey {
return fmt.Sprintf("%v-%d", typ, jobId)
}
func (d *DbJobBaseImpl) GetKey() DbJobKey {
if len(d.jobKey) == 0 {
d.jobKey = FormatJobKey(d.jobType, d.Id)
}
return d.jobKey
}
func (d *DbJobBaseImpl) GetStatus() DbJobStatus {
return d.jobStatus
}
func (d *DbJobBaseImpl) SetStatus(status DbJobStatus) {
d.jobStatus = status
}

View File

@@ -1,36 +1,31 @@
package entity package entity
import ( import (
"context"
"mayfly-go/pkg/runner"
"mayfly-go/pkg/utils/timex" "mayfly-go/pkg/utils/timex"
) )
var _ DbTask = (*DbRestore)(nil) var _ DbJob = (*DbRestore)(nil)
// DbRestore 数据库恢复任务 // DbRestore 数据库恢复任务
type DbRestore struct { type DbRestore struct {
*DbTaskBase *DbJobBaseImpl
DbName string `json:"dbName"` // 数据库名
PointInTime timex.NullTime `json:"pointInTime"` // 指定数据库恢复的时间点 PointInTime timex.NullTime `json:"pointInTime"` // 指定数据库恢复的时间点
DbBackupId uint64 `json:"dbBackupId"` // 用于恢复的数据库恢复任务ID DbBackupId uint64 `json:"dbBackupId"` // 用于恢复的数据库恢复任务ID
DbBackupHistoryId uint64 `json:"dbBackupHistoryId"` // 用于恢复的数据库恢复历史ID DbBackupHistoryId uint64 `json:"dbBackupHistoryId"` // 用于恢复的数据库恢复历史ID
DbBackupHistoryName string `json:"dbBackupHistoryName"` // 数据库恢复历史名称 DbBackupHistoryName string `json:"dbBackupHistoryName"` // 数据库恢复历史名称
DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
} }
func (*DbRestore) MessageWithStatus(status TaskStatus) string { func (d *DbRestore) SetRun(fn func(ctx context.Context, job DbJob)) {
var result string d.run = func(ctx context.Context) {
switch status { fn(ctx, d)
case TaskDelay: }
result = "等待恢复数据库" }
case TaskReady:
result = "准备恢复数据库" func (d *DbRestore) SetRunnable(fn func(job DbJob, next runner.NextFunc) bool) {
case TaskReserved: d.runnable = func(next runner.NextFunc) bool {
result = "数据库恢复中" return fn(d, next)
case TaskSuccess:
result = "数据库恢复成功"
case TaskFailed:
result = "数据库恢复失败"
} }
return result
} }

View File

@@ -1,109 +0,0 @@
package entity
import (
"mayfly-go/pkg/model"
"mayfly-go/pkg/utils/timex"
"time"
)
type TaskStatus int
const (
TaskDelay TaskStatus = iota
TaskReady
TaskReserved
TaskSuccess
TaskFailed
)
const LastResultSize = 256
type DbTask interface {
model.ModelI
GetId() uint64
GetDeadline() time.Time
IsFinished() bool
Schedule() bool
Update(task DbTask)
GetTaskBase() *DbTaskBase
MessageWithStatus(status TaskStatus) string
IsEnabled() bool
}
func NewDbBTaskBase(enabled bool, repeated bool, startTime time.Time, interval time.Duration) *DbTaskBase {
return &DbTaskBase{
Enabled: enabled,
Repeated: repeated,
StartTime: startTime,
Interval: interval,
}
}
type DbTaskBase struct {
model.Model
Enabled bool // 是否启用
StartTime time.Time // 开始时间
Interval time.Duration // 间隔时间
Repeated bool // 是否重复执行
LastStatus TaskStatus // 最近一次执行状态
LastResult string // 最近一次执行结果
LastTime timex.NullTime // 最近一次执行时间
Deadline time.Time `gorm:"-" json:"-"` // 计划执行时间
}
func (d *DbTaskBase) GetId() uint64 {
if d == nil {
return 0
}
return d.Id
}
func (d *DbTaskBase) GetDeadline() time.Time {
return d.Deadline
}
func (d *DbTaskBase) Schedule() bool {
if d.IsFinished() || !d.Enabled {
return false
}
switch d.LastStatus {
case TaskSuccess:
if d.Interval == 0 {
return false
}
lastTime := d.LastTime.Time
if lastTime.Sub(d.StartTime) < 0 {
lastTime = d.StartTime.Add(-d.Interval)
}
d.Deadline = lastTime.Add(d.Interval - lastTime.Sub(d.StartTime)%d.Interval)
case TaskFailed:
d.Deadline = time.Now().Add(time.Minute)
default:
d.Deadline = d.StartTime
}
return true
}
func (d *DbTaskBase) IsFinished() bool {
return !d.Repeated && d.LastStatus == TaskSuccess
}
func (d *DbTaskBase) Update(task DbTask) {
t := task.GetTaskBase()
d.StartTime = t.StartTime
d.Interval = t.Interval
}
func (d *DbTaskBase) GetTaskBase() *DbTaskBase {
return d
}
func (*DbTaskBase) MessageWithStatus(_ TaskStatus) string {
return ""
}
func (d *DbTaskBase) IsEnabled() bool {
return d.Enabled
}

View File

@@ -40,8 +40,8 @@ type DbSqlExecQuery struct {
CreatorId uint64 CreatorId uint64
} }
// DbBackupQuery 数据库备份任务查询 // DbJobQuery 数据库备份任务查询
type DbBackupQuery struct { type DbJobQuery struct {
Id uint64 `json:"id" form:"id"` Id uint64 `json:"id" form:"id"`
DbName string `json:"dbName" form:"dbName"` DbName string `json:"dbName" form:"dbName"`
IntervalDay int `json:"intervalDay" form:"intervalDay"` IntervalDay int `json:"intervalDay" form:"intervalDay"`
@@ -61,13 +61,13 @@ type DbBackupHistoryQuery struct {
} }
// DbRestoreQuery 数据库备份任务查询 // DbRestoreQuery 数据库备份任务查询
type DbRestoreQuery struct { //type DbRestoreQuery struct {
Id uint64 `json:"id" form:"id"` // Id uint64 `json:"id" form:"id"`
DbName string `json:"dbName" form:"dbName"` // DbName string `json:"dbName" form:"dbName"`
InDbNames []string `json:"-" form:"-"` // InDbNames []string `json:"-" form:"-"`
DbInstanceId uint64 `json:"-" form:"-"` // DbInstanceId uint64 `json:"-" form:"-"`
Repeated bool `json:"repeated" form:"repeated"` // 是否重复执行 // Repeated bool `json:"repeated" form:"repeated"` // 是否重复执行
} //}
// DbRestoreHistoryQuery 数据库备份任务查询 // DbRestoreHistoryQuery 数据库备份任务查询
type DbRestoreHistoryQuery struct { type DbRestoreHistoryQuery struct {

View File

@@ -1,16 +1,7 @@
package repository package repository
import (
"context"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/pkg/model"
)
type DbBackup interface { type DbBackup interface {
DbTask[*entity.DbBackup] DbJob
// GetDbBackupList 分页获取数据信息列表
GetDbBackupList(condition *entity.DbBackupQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
AddTask(ctx context.Context, tasks ...*entity.DbBackup) error
GetDbNamesWithoutBackup(instanceId uint64, dbNames []string) ([]string, error) GetDbNamesWithoutBackup(instanceId uint64, dbNames []string) ([]string, error)
} }

View File

@@ -9,5 +9,5 @@ import (
type DbBinlog interface { type DbBinlog interface {
base.Repo[*entity.DbBinlog] base.Repo[*entity.DbBinlog]
AddTaskIfNotExists(ctx context.Context, task *entity.DbBinlog) error AddJobIfNotExists(ctx context.Context, job *entity.DbBinlog) error
} }

View File

@@ -0,0 +1,28 @@
package repository
import (
"context"
"gorm.io/gorm"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/pkg/model"
)
type DbJob interface {
// AddJob 添加数据库任务
AddJob(ctx context.Context, jobs any) error
// GetById 根据实体id查询
GetById(e entity.DbJob, id uint64, cols ...string) error
// GetPageList 分页获取数据库任务列表
GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
// UpdateById 根据实体id更新实体信息
UpdateById(ctx context.Context, e entity.DbJob, columns ...string) error
// BatchInsertWithDb 使用指定gorm db执行主要用于事务执行
BatchInsertWithDb(ctx context.Context, db *gorm.DB, es any) error
// DeleteById 根据实体主键删除实体
DeleteById(ctx context.Context, id uint64) error
UpdateLastStatus(ctx context.Context, job entity.DbJob) error
UpdateEnabled(ctx context.Context, jobId uint64, enabled bool) error
ListToDo(jobs any) error
ListRepeating(jobs any) error
}

View File

@@ -1,16 +1,7 @@
package repository package repository
import (
"context"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/pkg/model"
)
type DbRestore interface { type DbRestore interface {
DbTask[*entity.DbRestore] DbJob
// GetDbRestoreList 分页获取数据信息列表
GetDbRestoreList(condition *entity.DbRestoreQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
AddTask(ctx context.Context, tasks ...*entity.DbRestore) error
GetDbNamesWithoutRestore(instanceId uint64, dbNames []string) ([]string, error) GetDbNamesWithoutRestore(instanceId uint64, dbNames []string) ([]string, error)
} }

View File

@@ -1,17 +0,0 @@
package repository
import (
"context"
"mayfly-go/pkg/base"
"mayfly-go/pkg/model"
)
type DbTask[T model.ModelI] interface {
base.Repo[T]
UpdateTaskStatus(ctx context.Context, task T) error
UpdateEnabled(ctx context.Context, taskId uint64, enabled bool) error
ListToDo() ([]T, error)
ListRepeating() ([]T, error)
AddTask(ctx context.Context, tasks ...T) error
}

View File

@@ -1,74 +1,22 @@
package persistence package persistence
import ( import (
"context"
"errors"
"fmt"
"mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository" "mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/gormx" "mayfly-go/pkg/gormx"
"mayfly-go/pkg/model"
"slices" "slices"
"gorm.io/gorm"
) )
var _ repository.DbBackup = (*dbBackupRepoImpl)(nil) var _ repository.DbBackup = (*dbBackupRepoImpl)(nil)
type dbBackupRepoImpl struct { type dbBackupRepoImpl struct {
//base.RepoImpl[*entity.DbBackup] dbJobBase[*entity.DbBackup]
dbTaskBase[*entity.DbBackup]
} }
func NewDbBackupRepo() repository.DbBackup { func NewDbBackupRepo() repository.DbBackup {
return &dbBackupRepoImpl{} return &dbBackupRepoImpl{}
} }
// GetDbBackupList 分页获取数据库备份任务列表
func (d *dbBackupRepoImpl) GetDbBackupList(condition *entity.DbBackupQuery, pageParam *model.PageParam, toEntity any, _ ...string) (*model.PageResult[any], error) {
qd := gormx.NewQuery(d.GetModel()).
Eq("id", condition.Id).
Eq0("db_instance_id", condition.DbInstanceId).
Eq0("repeated", condition.Repeated).
In0("db_name", condition.InDbNames).
Like("db_name", condition.DbName)
return gormx.PageQuery(qd, pageParam, toEntity)
}
func (d *dbBackupRepoImpl) AddTask(ctx context.Context, tasks ...*entity.DbBackup) error {
return gormx.Tx(func(db *gorm.DB) error {
var instanceId uint64
dbNames := make([]string, 0, len(tasks))
for _, task := range tasks {
if instanceId == 0 {
instanceId = task.DbInstanceId
}
if task.DbInstanceId != instanceId {
return errors.New("不支持同时为多个数据库实例添加备份任务")
}
if task.Interval == 0 {
// 单次执行的备份任务可重复创建
continue
}
dbNames = append(dbNames, task.DbName)
}
var res []string
err := db.Model(d.GetModel()).Select("db_name").
Where("db_instance_id = ?", instanceId).
Where("db_name in ?", dbNames).
Where("repeated = true").
Scopes(gormx.UndeleteScope).Find(&res).Error
if err != nil {
return err
}
if len(res) > 0 {
return fmt.Errorf("数据库备份任务已存在: %v", res)
}
return d.BatchInsertWithDb(ctx, db, tasks)
})
}
func (d *dbBackupRepoImpl) GetDbNamesWithoutBackup(instanceId uint64, dbNames []string) ([]string, error) { func (d *dbBackupRepoImpl) GetDbNamesWithoutBackup(instanceId uint64, dbNames []string) ([]string, error) {
var dbNamesWithBackup []string var dbNamesWithBackup []string
query := gormx.NewQuery(d.GetModel()). query := gormx.NewQuery(d.GetModel()).

View File

@@ -20,8 +20,8 @@ func NewDbBinlogRepo() repository.DbBinlog {
return &dbBinlogRepoImpl{} return &dbBinlogRepoImpl{}
} }
func (d *dbBinlogRepoImpl) AddTaskIfNotExists(_ context.Context, task *entity.DbBinlog) error { func (d *dbBinlogRepoImpl) AddJobIfNotExists(_ context.Context, job *entity.DbBinlog) error {
if err := global.Db.Clauses(clause.OnConflict{DoNothing: true}).Create(task).Error; err != nil { if err := global.Db.Clauses(clause.OnConflict{DoNothing: true}).Create(job).Error; err != nil {
return fmt.Errorf("启动 binlog 下载失败: %w", err) return fmt.Errorf("启动 binlog 下载失败: %w", err)
} }
return nil return nil

View File

@@ -94,7 +94,7 @@ func (repo *dbBinlogHistoryRepoImpl) InsertWithBinlogFiles(ctx context.Context,
if len(binlogFiles) == 0 { if len(binlogFiles) == 0 {
return nil return nil
} }
histories := make([]*entity.DbBinlogHistory, 0, len(binlogFiles)) histories := make([]any, 0, len(binlogFiles))
for _, fileOnServer := range binlogFiles { for _, fileOnServer := range binlogFiles {
if !fileOnServer.Downloaded { if !fileOnServer.Downloaded {
break break
@@ -115,7 +115,7 @@ func (repo *dbBinlogHistoryRepoImpl) InsertWithBinlogFiles(ctx context.Context,
} }
} }
if len(histories) > 0 { if len(histories) > 0 {
if err := repo.Upsert(ctx, histories[len(histories)-1]); err != nil { if err := repo.Upsert(ctx, histories[len(histories)-1].(*entity.DbBinlogHistory)); err != nil {
return err return err
} }
} }

View File

@@ -0,0 +1,127 @@
package persistence
import (
"context"
"errors"
"fmt"
"gorm.io/gorm"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/pkg/base"
"mayfly-go/pkg/global"
"mayfly-go/pkg/gormx"
"mayfly-go/pkg/model"
"reflect"
)
type dbJobBase[T entity.DbJob] struct {
base.RepoImpl[T]
}
func (d *dbJobBase[T]) GetById(e entity.DbJob, id uint64, cols ...string) error {
return d.RepoImpl.GetById(e.(T), id, cols...)
}
func (d *dbJobBase[T]) UpdateById(ctx context.Context, e entity.DbJob, columns ...string) error {
return d.RepoImpl.UpdateById(ctx, e.(T), columns...)
}
func (d *dbJobBase[T]) UpdateEnabled(_ context.Context, jobId uint64, enabled bool) error {
cond := map[string]any{
"id": jobId,
}
return d.Updates(cond, map[string]any{
"enabled": enabled,
})
}
func (d *dbJobBase[T]) UpdateLastStatus(ctx context.Context, job entity.DbJob) error {
return d.UpdateById(ctx, job.(T), "last_status", "last_result", "last_time")
}
func (d *dbJobBase[T]) ListToDo(jobs any) error {
db := global.Db.Model(d.GetModel())
err := db.Where("enabled = ?", true).
Where(db.Where("repeated = ?", true).Or("last_status <> ?", entity.DbJobSuccess)).
Scopes(gormx.UndeleteScope).
Find(jobs).Error
if err != nil {
return err
}
return nil
}
func (d *dbJobBase[T]) ListRepeating(jobs any) error {
cond := map[string]any{
"enabled": true,
"repeated": true,
}
if err := d.ListByCond(cond, jobs); err != nil {
return err
}
return nil
}
// GetPageList 分页获取数据库备份任务列表
func (d *dbJobBase[T]) GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, _ ...string) (*model.PageResult[any], error) {
d.GetModel()
qd := gormx.NewQuery(d.GetModel()).
Eq("id", condition.Id).
Eq0("db_instance_id", condition.DbInstanceId).
Eq0("repeated", condition.Repeated).
In0("db_name", condition.InDbNames).
Like("db_name", condition.DbName)
return gormx.PageQuery(qd, pageParam, toEntity)
}
func (d *dbJobBase[T]) AddJob(ctx context.Context, jobs any) error {
return gormx.Tx(func(db *gorm.DB) error {
var instanceId uint64
var dbNames []string
reflectValue := reflect.ValueOf(jobs)
var plural bool
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
plural = true
reflectLen := reflectValue.Len()
dbNames = make([]string, 0, reflectLen)
for i := 0; i < reflectLen; i++ {
job := reflectValue.Index(i).Interface().(entity.DbJob)
jobBase := job.GetJobBase()
if instanceId == 0 {
instanceId = jobBase.DbInstanceId
}
if jobBase.DbInstanceId != instanceId {
return errors.New("不支持同时为多个数据库实例添加数据库任务")
}
if jobBase.Interval == 0 {
// 单次执行的数据库任务可重复创建
continue
}
dbNames = append(dbNames, jobBase.DbName)
}
default:
jobBase := jobs.(entity.DbJob).GetJobBase()
instanceId = jobBase.DbInstanceId
if jobBase.Interval > 0 {
dbNames = append(dbNames, jobBase.DbName)
}
}
var res []string
err := db.Model(d.GetModel()).Select("db_name").
Where("db_instance_id = ?", instanceId).
Where("db_name in ?", dbNames).
Where("repeated = true").
Scopes(gormx.UndeleteScope).Find(&res).Error
if err != nil {
return err
}
if len(res) > 0 {
return errors.New(fmt.Sprintf("数据库任务已存在: %v", res))
}
if plural {
return d.BatchInsertWithDb(ctx, db, jobs)
}
return d.InsertWithDb(ctx, db, jobs.(T))
})
}

View File

@@ -1,73 +1,22 @@
package persistence package persistence
import ( import (
"context"
"errors"
"fmt"
"mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository" "mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/gormx" "mayfly-go/pkg/gormx"
"mayfly-go/pkg/model"
"slices" "slices"
"gorm.io/gorm"
) )
var _ repository.DbRestore = (*dbRestoreRepoImpl)(nil) var _ repository.DbRestore = (*dbRestoreRepoImpl)(nil)
type dbRestoreRepoImpl struct { type dbRestoreRepoImpl struct {
dbTaskBase[*entity.DbRestore] dbJobBase[*entity.DbRestore]
} }
func NewDbRestoreRepo() repository.DbRestore { func NewDbRestoreRepo() repository.DbRestore {
return &dbRestoreRepoImpl{} return &dbRestoreRepoImpl{}
} }
// GetDbRestoreList 分页获取数据库备份任务列表
func (d *dbRestoreRepoImpl) GetDbRestoreList(condition *entity.DbRestoreQuery, pageParam *model.PageParam, toEntity any, _ ...string) (*model.PageResult[any], error) {
qd := gormx.NewQuery(d.GetModel()).
Eq("id", condition.Id).
Eq0("db_instance_id", condition.DbInstanceId).
Eq0("repeated", condition.Repeated).
In0("db_name", condition.InDbNames).
Like("db_name", condition.DbName)
return gormx.PageQuery(qd, pageParam, toEntity)
}
func (d *dbRestoreRepoImpl) AddTask(ctx context.Context, tasks ...*entity.DbRestore) error {
return gormx.Tx(func(db *gorm.DB) error {
var instanceId uint64
dbNames := make([]string, 0, len(tasks))
for _, task := range tasks {
if instanceId == 0 {
instanceId = task.DbInstanceId
}
if task.DbInstanceId != instanceId {
return errors.New("不支持同时为多个数据库实例添加备份任务")
}
if task.Interval == 0 {
// 单次执行的恢复任务可重复创建
continue
}
dbNames = append(dbNames, task.DbName)
}
var res []string
err := db.Model(new(entity.DbRestore)).Select("db_name").
Where("db_instance_id = ?", instanceId).
Where("db_name in ?", dbNames).
Where("repeated = true").
Scopes(gormx.UndeleteScope).Find(&res).Error
if err != nil {
return err
}
if len(res) > 0 {
return fmt.Errorf("数据库备份任务已存在: %v", res)
}
return d.BatchInsertWithDb(ctx, db, tasks)
})
}
func (d *dbRestoreRepoImpl) GetDbNamesWithoutRestore(instanceId uint64, dbNames []string) ([]string, error) { func (d *dbRestoreRepoImpl) GetDbNamesWithoutRestore(instanceId uint64, dbNames []string) ([]string, error) {
var dbNamesWithRestore []string var dbNamesWithRestore []string
query := gormx.NewQuery(d.GetModel()). query := gormx.NewQuery(d.GetModel()).

View File

@@ -1,52 +0,0 @@
package persistence
import (
"context"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/pkg/base"
"mayfly-go/pkg/global"
"mayfly-go/pkg/gormx"
"mayfly-go/pkg/model"
)
type dbTaskBase[T model.ModelI] struct {
base.RepoImpl[T]
}
func (d *dbTaskBase[T]) UpdateEnabled(_ context.Context, taskId uint64, enabled bool) error {
cond := map[string]any{
"id": taskId,
}
return d.Updates(cond, map[string]any{
"enabled": enabled,
})
}
func (d *dbTaskBase[T]) UpdateTaskStatus(ctx context.Context, task T) error {
return d.UpdateById(ctx, task, "last_status", "last_result", "last_time")
}
func (d *dbTaskBase[T]) ListToDo() ([]T, error) {
var tasks []T
db := global.Db.Model(d.GetModel())
err := db.Where("enabled = ?", true).
Where(db.Where("repeated = ?", true).Or("last_status <> ?", entity.TaskSuccess)).
Scopes(gormx.UndeleteScope).
Find(&tasks).Error
if err != nil {
return nil, err
}
return tasks, nil
}
func (d *dbTaskBase[T]) ListRepeating() ([]T, error) {
cond := map[string]any{
"enabled": true,
"repeated": true,
}
var tasks []T
if err := d.ListByCond(cond, &tasks); err != nil {
return nil, err
}
return tasks, nil
}

View File

@@ -22,7 +22,7 @@ func InitDbBackupRouter(router *gin.RouterGroup) {
// 创建数据库备份任务 // 创建数据库备份任务
req.NewPost(":dbId/backups", d.Create).Log(req.NewLogSave("db-创建数据库备份任务")), req.NewPost(":dbId/backups", d.Create).Log(req.NewLogSave("db-创建数据库备份任务")),
// 保存数据库备份任务 // 保存数据库备份任务
req.NewPut(":dbId/backups/:backupId", d.Save).Log(req.NewLogSave("db-保存数据库备份任务")), req.NewPut(":dbId/backups/:backupId", d.Update).Log(req.NewLogSave("db-保存数据库备份任务")),
// 启用数据库备份任务 // 启用数据库备份任务
req.NewPut(":dbId/backups/:backupId/enable", d.Enable).Log(req.NewLogSave("db-启用数据库备份任务")), req.NewPut(":dbId/backups/:backupId/enable", d.Enable).Log(req.NewLogSave("db-启用数据库备份任务")),
// 禁用数据库备份任务 // 禁用数据库备份任务

View File

@@ -22,7 +22,7 @@ func InitDbRestoreRouter(router *gin.RouterGroup) {
// 创建数据库备份任务 // 创建数据库备份任务
req.NewPost(":dbId/restores", d.Create).Log(req.NewLogSave("db-创建数据库恢复任务")), req.NewPost(":dbId/restores", d.Create).Log(req.NewLogSave("db-创建数据库恢复任务")),
// 保存数据库备份任务 // 保存数据库备份任务
req.NewPut(":dbId/restores/:restoreId", d.Save).Log(req.NewLogSave("db-保存数据库恢复任务")), req.NewPut(":dbId/restores/:restoreId", d.Update).Log(req.NewLogSave("db-保存数据库恢复任务")),
// 启用数据库备份任务 // 启用数据库备份任务
req.NewPut(":dbId/restores/:restoreId/enable", d.Enable).Log(req.NewLogSave("db-启用数据库恢复任务")), req.NewPut(":dbId/restores/:restoreId/enable", d.Enable).Log(req.NewLogSave("db-启用数据库恢复任务")),
// 禁用数据库备份任务 // 禁用数据库备份任务

View File

@@ -62,7 +62,7 @@ func (m *roleRepoImpl) GetRoleResources(roleId uint64, toEntity any) {
} }
func (m *roleRepoImpl) SaveRoleResource(rr []*entity.RoleResource) { func (m *roleRepoImpl) SaveRoleResource(rr []*entity.RoleResource) {
gormx.BatchInsert(rr) gormx.BatchInsert[*entity.RoleResource](rr)
} }
func (m *roleRepoImpl) DeleteRoleResource(roleId uint64, resourceId uint64) { func (m *roleRepoImpl) DeleteRoleResource(roleId uint64, resourceId uint64) {

View File

@@ -103,8 +103,8 @@ func (ai *AppImpl[T, R]) BatchInsert(ctx context.Context, es []T) error {
} }
// 使用指定gorm db执行主要用于事务执行 // 使用指定gorm db执行主要用于事务执行
func (ai *AppImpl[T, R]) BatchInsertWithDb(ctx context.Context, db *gorm.DB, es []T) error { func (ai *AppImpl[T, R]) BatchInsertWithDb(ctx context.Context, db *gorm.DB, models []T) error {
return ai.GetRepo().BatchInsertWithDb(ctx, db, es) return ai.GetRepo().BatchInsertWithDb(ctx, db, models)
} }
// 根据实体id更新实体信息 (单纯更新,不做其他业务逻辑处理) // 根据实体id更新实体信息 (单纯更新,不做其他业务逻辑处理)

View File

@@ -22,10 +22,10 @@ type Repo[T model.ModelI] interface {
InsertWithDb(ctx context.Context, db *gorm.DB, e T) error InsertWithDb(ctx context.Context, db *gorm.DB, e T) error
// 批量新增实体 // 批量新增实体
BatchInsert(ctx context.Context, models []T) error BatchInsert(ctx context.Context, models any) error
// 使用指定gorm db执行主要用于事务执行 // 使用指定gorm db执行主要用于事务执行
BatchInsertWithDb(ctx context.Context, db *gorm.DB, es []T) error BatchInsertWithDb(ctx context.Context, db *gorm.DB, models any) error
// 根据实体id更新实体信息 // 根据实体id更新实体信息
UpdateById(ctx context.Context, e T, columns ...string) error UpdateById(ctx context.Context, e T, columns ...string) error
@@ -93,23 +93,22 @@ func (br *RepoImpl[T]) InsertWithDb(ctx context.Context, db *gorm.DB, e T) error
return gormx.InsertWithDb(db, br.fillBaseInfo(ctx, e)) return gormx.InsertWithDb(db, br.fillBaseInfo(ctx, e))
} }
func (br *RepoImpl[T]) BatchInsert(ctx context.Context, es []T) error { func (br *RepoImpl[T]) BatchInsert(ctx context.Context, es any) error {
if db := contextx.GetDb(ctx); db != nil { if db := contextx.GetDb(ctx); db != nil {
return br.BatchInsertWithDb(ctx, db, es) return br.BatchInsertWithDb(ctx, db, es)
} }
for _, e := range es.([]T) {
for _, e := range es {
br.fillBaseInfo(ctx, e) br.fillBaseInfo(ctx, e)
} }
return gormx.BatchInsert(es) return gormx.BatchInsert[T](es)
} }
// 使用指定gorm db执行主要用于事务执行 // 使用指定gorm db执行主要用于事务执行
func (br *RepoImpl[T]) BatchInsertWithDb(ctx context.Context, db *gorm.DB, es []T) error { func (br *RepoImpl[T]) BatchInsertWithDb(ctx context.Context, db *gorm.DB, es any) error {
for _, e := range es { for _, e := range es.([]T) {
br.fillBaseInfo(ctx, e) br.fillBaseInfo(ctx, e)
} }
return gormx.BatchInsertWithDb(db, es) return gormx.BatchInsertWithDb[T](db, es)
} }
func (br *RepoImpl[T]) UpdateById(ctx context.Context, e T, columns ...string) error { func (br *RepoImpl[T]) UpdateById(ctx context.Context, e T, columns ...string) error {

View File

@@ -135,13 +135,13 @@ func InsertWithDb(db *gorm.DB, model any) error {
} }
// 批量插入 // 批量插入
func BatchInsert[T any](models []T) error { func BatchInsert[T any](models any) error {
return BatchInsertWithDb[T](global.Db, models) return BatchInsertWithDb[T](global.Db, models)
} }
// 批量插入 // 批量插入
func BatchInsertWithDb[T any](db *gorm.DB, models []T) error { func BatchInsertWithDb[T any](db *gorm.DB, models any) error {
return db.CreateInBatches(models, len(models)).Error return db.CreateInBatches(models, len(models.([]T))).Error
} }
// 根据id更新model更新字段为model中不为空的值即int类型不为0ptr类型不为nil这类字段值 // 根据id更新model更新字段为model中不为空的值即int类型不为0ptr类型不为nil这类字段值

View File

@@ -1,4 +1,4 @@
package queue package runner
import ( import (
"context" "context"
@@ -7,7 +7,7 @@ import (
"time" "time"
) )
const minTimerDelay = time.Millisecond const minTimerDelay = time.Millisecond * 1
const maxTimerDelay = time.Nanosecond * math.MaxInt64 const maxTimerDelay = time.Nanosecond * math.MaxInt64
type DelayQueue[T Delayable] struct { type DelayQueue[T Delayable] struct {
@@ -17,14 +17,12 @@ type DelayQueue[T Delayable] struct {
singleDequeue chan struct{} singleDequeue chan struct{}
mutex sync.Mutex mutex sync.Mutex
priorityQueue *PriorityQueue[T] priorityQueue *PriorityQueue[T]
elmMap map[uint64]T
zero T zero T
} }
type Delayable interface { type Delayable interface {
GetDeadline() time.Time GetDeadline() time.Time
GetId() uint64 GetKey() string
} }
func NewDelayQueue[T Delayable](cap int) *DelayQueue[T] { func NewDelayQueue[T Delayable](cap int) *DelayQueue[T] {
@@ -35,7 +33,6 @@ func NewDelayQueue[T Delayable](cap int) *DelayQueue[T] {
dequeuedSignal: make(chan struct{}), dequeuedSignal: make(chan struct{}),
transferChan: make(chan T), transferChan: make(chan T),
singleDequeue: singleDequeue, singleDequeue: singleDequeue,
elmMap: make(map[uint64]T, 64),
priorityQueue: NewPriorityQueue[T](cap, func(src T, dst T) bool { priorityQueue: NewPriorityQueue[T](cap, func(src T, dst T) bool {
return src.GetDeadline().Before(dst.GetDeadline()) return src.GetDeadline().Before(dst.GetDeadline())
}), }),
@@ -135,7 +132,6 @@ func (s *DelayQueue[T]) dequeue() (T, bool) {
if !ok { if !ok {
return s.zero, false return s.zero, false
} }
delete(s.elmMap, elm.GetId())
select { select {
case s.dequeuedSignal <- struct{}{}: case s.dequeuedSignal <- struct{}{}:
default: default:
@@ -147,7 +143,6 @@ func (s *DelayQueue[T]) enqueue(val T) bool {
if ok := s.priorityQueue.Enqueue(val); !ok { if ok := s.priorityQueue.Enqueue(val); !ok {
return false return false
} }
s.elmMap[val.GetId()] = val
select { select {
case s.enqueuedSignal <- struct{}{}: case s.enqueuedSignal <- struct{}{}:
default: default:
@@ -169,10 +164,6 @@ func (s *DelayQueue[T]) Enqueue(ctx context.Context, val T) bool {
for { for {
// 全局锁:避免入队和出队信号的重置与激活出现并发问题 // 全局锁:避免入队和出队信号的重置与激活出现并发问题
s.mutex.Lock() s.mutex.Lock()
if _, ok := s.elmMap[val.GetId()]; ok {
s.mutex.Unlock()
return false
}
if ctx.Err() != nil { if ctx.Err() != nil {
s.mutex.Unlock() s.mutex.Unlock()
@@ -220,24 +211,20 @@ func (s *DelayQueue[T]) Enqueue(ctx context.Context, val T) bool {
} }
} }
func (s *DelayQueue[T]) Remove(_ context.Context, elmId uint64) (T, bool) { func (s *DelayQueue[T]) Remove(_ context.Context, key string) (T, bool) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if _, ok := s.elmMap[elmId]; ok { return s.priorityQueue.Remove(s.index(key))
delete(s.elmMap, elmId)
return s.priorityQueue.Remove(s.index(elmId))
}
return s.zero, false
} }
func (s *DelayQueue[T]) index(elmId uint64) int { func (s *DelayQueue[T]) index(key string) int {
for i := 0; i < s.priorityQueue.Len(); i++ { for i := 0; i < s.priorityQueue.Len(); i++ {
elm, ok := s.priorityQueue.Peek(i) elm, ok := s.priorityQueue.Peek(i)
if !ok { if !ok {
continue continue
} }
if elmId == elm.GetId() { if key == elm.GetKey() {
return i return i
} }
} }

View File

@@ -1,4 +1,4 @@
package queue package runner
import ( import (
"context" "context"
@@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"math/rand" "math/rand"
"runtime" "runtime"
"strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
@@ -28,6 +29,10 @@ func (elm *delayElement) GetId() uint64 {
return elm.id return elm.id
} }
func (elm *delayElement) GetKey() string {
return strconv.FormatUint(elm.id, 16)
}
type testDelayQueue = DelayQueue[*delayElement] type testDelayQueue = DelayQueue[*delayElement]
func newTestDelayQueue(cap int) *testDelayQueue { func newTestDelayQueue(cap int) *testDelayQueue {
@@ -42,7 +47,6 @@ func mustEnqueue(val int, delay int64) func(t *testing.T, queue *testDelayQueue)
} }
func newTestElm(value int, delay int64) *delayElement { func newTestElm(value int, delay int64) *delayElement {
return &delayElement{ return &delayElement{
id: elmId.Add(1), id: elmId.Add(1),
value: value, value: value,

View File

@@ -1,4 +1,4 @@
package queue package runner
//var ( //var (
// false = errors.New("queue: 队列已满") // false = errors.New("queue: 队列已满")

View File

@@ -1,4 +1,4 @@
package queue package runner
import ( import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"

337
server/pkg/runner/runner.go Normal file
View File

@@ -0,0 +1,337 @@
package runner
import (
"context"
"fmt"
"github.com/emirpasic/gods/maps/linkedhashmap"
"mayfly-go/pkg/logx"
"sync"
"time"
)
type JobKey = string
type RunFunc func(ctx context.Context)
type NextFunc func() (Job, bool)
type RunnableFunc func(next NextFunc) bool
type JobStatus int
const (
JobUnknown JobStatus = iota
JobDelay
JobWaiting
JobRunning
JobRemoved
)
type Job interface {
GetKey() JobKey
GetStatus() JobStatus
SetStatus(status JobStatus)
Run(ctx context.Context)
Runnable(next NextFunc) bool
GetDeadline() time.Time
Schedule() bool
Renew(job Job)
}
type iterator[T Job] struct {
index int
data []T
zero T
}
func (iter *iterator[T]) Begin() {
iter.index = -1
}
func (iter *iterator[T]) Next() (T, bool) {
if iter.index >= len(iter.data)-1 {
return iter.zero, false
}
iter.index++
return iter.data[iter.index], true
}
type array[T Job] struct {
size int
data []T
zero T
}
func newArray[T Job](size int) *array[T] {
return &array[T]{
size: size,
data: make([]T, 0, size),
}
}
func (a *array[T]) Iterator() *iterator[T] {
return &iterator[T]{
index: -1,
data: a.data,
}
}
func (a *array[T]) Full() bool {
return len(a.data) >= a.size
}
func (a *array[T]) Append(job T) bool {
if len(a.data) >= a.size {
return false
}
a.data = append(a.data, job)
return true
}
func (a *array[T]) Get(key JobKey) (T, bool) {
for _, job := range a.data {
if key == job.GetKey() {
return job, true
}
}
return a.zero, false
}
func (a *array[T]) Remove(key JobKey) {
length := len(a.data)
for i, elm := range a.data {
if key == elm.GetKey() {
a.data[i], a.data[length-1] = a.data[length-1], a.zero
a.data = a.data[:length-1]
return
}
}
}
type Runner[T Job] struct {
maxRunning int
waiting *linkedhashmap.Map
running *array[T]
runnable func(job T, iterateRunning func() (T, bool)) bool
mutex sync.Mutex
wg sync.WaitGroup
context context.Context
cancel context.CancelFunc
zero T
signal chan struct{}
all map[string]T
delayQueue *DelayQueue[T]
}
func NewRunner[T Job](maxRunning int) *Runner[T] {
ctx, cancel := context.WithCancel(context.Background())
runner := &Runner[T]{
maxRunning: maxRunning,
all: make(map[string]T, maxRunning),
waiting: linkedhashmap.New(),
running: newArray[T](maxRunning),
context: ctx,
cancel: cancel,
signal: make(chan struct{}, 1),
delayQueue: NewDelayQueue[T](0),
}
runner.wg.Add(maxRunning + 1)
for i := 0; i < maxRunning; i++ {
go runner.run()
}
go func() {
defer runner.wg.Done()
for runner.context.Err() == nil {
job, ok := runner.delayQueue.Dequeue(ctx)
if !ok {
continue
}
runner.mutex.Lock()
runner.waiting.Put(job.GetKey(), job)
job.SetStatus(JobWaiting)
runner.trigger()
runner.mutex.Unlock()
}
}()
return runner
}
func (r *Runner[T]) Close() {
r.cancel()
r.wg.Wait()
}
func (r *Runner[T]) run() {
defer r.wg.Done()
for r.context.Err() == nil {
select {
case <-r.signal:
job, ok := r.pickRunnable()
if !ok {
continue
}
r.doRun(job)
r.afterRun(job)
case <-r.context.Done():
}
}
}
func (r *Runner[T]) doRun(job T) {
defer func() {
if err := recover(); err != nil {
logx.Error(fmt.Sprintf("failed to run job: %v", err))
}
}()
job.Run(r.context)
}
func (r *Runner[T]) afterRun(job T) {
r.mutex.Lock()
defer r.mutex.Unlock()
key := job.GetKey()
r.running.Remove(key)
r.trigger()
switch job.GetStatus() {
case JobRunning:
r.schedule(r.context, job)
case JobRemoved:
delete(r.all, key)
default:
panic(fmt.Sprintf("invalid job status %v occurred after run", job.GetStatus()))
}
}
func (r *Runner[T]) pickRunnable() (T, bool) {
r.mutex.Lock()
defer r.mutex.Unlock()
iter := r.running.Iterator()
var runnable T
ok := r.waiting.Any(func(key interface{}, value interface{}) bool {
job := value.(T)
iter.Begin()
if job.Runnable(func() (Job, bool) { return iter.Next() }) {
if r.running.Full() {
return false
}
r.waiting.Remove(key)
r.running.Append(job)
job.SetStatus(JobRunning)
if !r.running.Full() && !r.waiting.Empty() {
r.trigger()
}
runnable = job
return true
}
return false
})
if !ok {
return r.zero, false
}
return runnable, true
}
func (r *Runner[T]) schedule(ctx context.Context, job T) {
if !job.Schedule() {
delete(r.all, job.GetKey())
job.SetStatus(JobRemoved)
return
}
r.delayQueue.Enqueue(ctx, job)
job.SetStatus(JobDelay)
}
//func (r *Runner[T]) Schedule(ctx context.Context, job T) {
// r.mutex.Lock()
// defer r.mutex.Unlock()
//
// switch job.GetStatus() {
// case JobUnknown:
// case JobDelay:
// r.delayQueue.Remove(ctx, job.GetKey())
// case JobWaiting:
// r.waiting.Remove(job)
// case JobRunning:
// // 标记为 removed, 任务执行完成后再删除
// return
// case JobRemoved:
// return
// }
// r.schedule(ctx, job)
//}
func (r *Runner[T]) Add(ctx context.Context, job T) error {
r.mutex.Lock()
defer r.mutex.Unlock()
if _, ok := r.all[job.GetKey()]; ok {
return nil
}
r.schedule(ctx, job)
return nil
}
func (r *Runner[T]) UpdateOrAdd(ctx context.Context, job T) error {
r.mutex.Lock()
defer r.mutex.Unlock()
if old, ok := r.all[job.GetKey()]; ok {
old.Renew(job)
job = old
}
r.schedule(ctx, job)
return nil
}
func (r *Runner[T]) StartNow(ctx context.Context, job T) error {
r.mutex.Lock()
defer r.mutex.Unlock()
key := job.GetKey()
if old, ok := r.all[key]; ok {
job = old
if job.GetStatus() == JobDelay {
r.delayQueue.Remove(ctx, key)
r.waiting.Put(key, job)
r.trigger()
}
return nil
}
r.all[key] = job
r.waiting.Put(key, job)
r.trigger()
return nil
}
func (r *Runner[T]) trigger() {
select {
case r.signal <- struct{}{}:
default:
}
}
func (r *Runner[T]) Remove(ctx context.Context, key JobKey) error {
r.mutex.Lock()
defer r.mutex.Unlock()
job, ok := r.all[key]
if !ok {
return nil
}
switch job.GetStatus() {
case JobUnknown:
panic(fmt.Sprintf("invalid job status %v occurred after added", job.GetStatus()))
case JobDelay:
r.delayQueue.Remove(ctx, key)
case JobWaiting:
r.waiting.Remove(key)
case JobRunning:
// 标记为 removed, 任务执行完成后再删除
case JobRemoved:
return nil
}
delete(r.all, key)
job.SetStatus(JobRemoved)
return nil
}

View File

@@ -0,0 +1,148 @@
package runner
import (
"context"
"github.com/stretchr/testify/require"
"mayfly-go/pkg/utils/timex"
"sync"
"testing"
"time"
)
var _ Job = &testJob{}
func newTestJob(key string, runTime time.Duration) *testJob {
return &testJob{
deadline: time.Now(),
Key: key,
run: func(ctx context.Context) {
timex.SleepWithContext(ctx, runTime)
},
}
}
type testJob struct {
run RunFunc
Key JobKey
status JobStatus
ran bool
deadline time.Time
}
func (t *testJob) Renew(job Job) {
}
func (t *testJob) GetDeadline() time.Time {
return t.deadline
}
func (t *testJob) Schedule() bool {
return !t.ran
}
func (t *testJob) Run(ctx context.Context) {
if t.run == nil {
return
}
t.run(ctx)
t.ran = true
}
func (t *testJob) Runnable(_ NextFunc) bool {
return true
}
func (t *testJob) GetKey() JobKey {
return t.Key
}
func (t *testJob) GetStatus() JobStatus {
return t.status
}
func (t *testJob) SetStatus(status JobStatus) {
t.status = status
}
func TestRunner_Close(t *testing.T) {
runner := NewRunner[*testJob](1)
signal := make(chan struct{}, 1)
waiting := sync.WaitGroup{}
waiting.Add(1)
go func() {
job := &testJob{
Key: "close",
run: func(ctx context.Context) {
waiting.Done()
timex.SleepWithContext(ctx, time.Hour)
signal <- struct{}{}
},
}
_ = runner.Add(context.Background(), job)
}()
waiting.Wait()
timer := time.NewTimer(time.Microsecond * 10)
runner.Close()
select {
case <-timer.C:
require.FailNow(t, "runner 未能及时退出")
case <-signal:
}
}
func TestRunner_AddJob(t *testing.T) {
type testCase struct {
name string
job *testJob
want bool
}
testCases := []testCase{
{
name: "first job",
job: newTestJob("single", time.Hour),
want: true,
},
{
name: "second job",
job: newTestJob("dual", time.Hour),
want: true,
},
{
name: "non repetitive job",
job: newTestJob("single", time.Hour),
want: true,
},
{
name: "repetitive job",
job: newTestJob("dual", time.Hour),
want: true,
},
}
runner := NewRunner[*testJob](1)
defer runner.Close()
for _, tc := range testCases {
err := runner.Add(context.Background(), tc.job)
require.NoError(t, err)
}
}
func TestJob_UpdateStatus(t *testing.T) {
const d = time.Millisecond * 20
runner := NewRunner[*testJob](1)
running := newTestJob("running", d*2)
waiting := newTestJob("waiting", d*2)
_ = runner.Add(context.Background(), running)
_ = runner.Add(context.Background(), waiting)
time.Sleep(d)
require.Equal(t, JobRunning, running.status)
require.Equal(t, JobWaiting, waiting.status)
time.Sleep(d * 2)
require.Equal(t, JobRemoved, running.status)
require.Equal(t, JobRunning, waiting.status)
time.Sleep(d * 2)
require.Equal(t, JobRemoved, running.status)
require.Equal(t, JobRemoved, waiting.status)
}

View File

@@ -41,6 +41,7 @@ func runWebServer(ctx context.Context) {
if err != nil { if err != nil {
logx.Errorf("Failed to Shutdown HTTP Server: %v", err) logx.Errorf("Failed to Shutdown HTTP Server: %v", err)
} }
initialize.Terminate() initialize.Terminate()
}() }()

View File

@@ -272,7 +272,6 @@ func pkcs7UnPadding(data []byte) ([]byte, error) {
} }
//获取填充的个数 //获取填充的个数
unPadding := int(data[length-1]) unPadding := int(data[length-1])
// todo fix: slice bounds out of range
if unPadding > length { if unPadding > length {
return nil, errors.New("解密字符串时去除填充个数超出字符串长度") return nil, errors.New("解密字符串时去除填充个数超出字符串长度")
} }