diff --git a/mayfly_go_web/src/views/ops/db/DbBackupEdit.vue b/mayfly_go_web/src/views/ops/db/DbBackupEdit.vue index 2b023a6c..6fad3187 100644 --- a/mayfly_go_web/src/views/ops/db/DbBackupEdit.vue +++ b/mayfly_go_web/src/views/ops/db/DbBackupEdit.vue @@ -4,18 +4,21 @@ - + + @@ -41,9 +44,10 @@ diff --git a/mayfly_go_web/src/views/ops/db/DbEdit.vue b/mayfly_go_web/src/views/ops/db/DbEdit.vue index e98e0c25..dbe9fdec 100644 --- a/mayfly_go_web/src/views/ops/db/DbEdit.vue +++ b/mayfly_go_web/src/views/ops/db/DbEdit.vue @@ -52,20 +52,23 @@ - + - + + @@ -90,6 +93,7 @@ import { dbApi } from './api'; import { ElMessage } from 'element-plus'; import TagTreeSelect from '../component/TagTreeSelect.vue'; import { TagResourceTypeEnum } from '@/common/commonEnum'; +import type { CheckboxValueType } from 'element-plus'; const props = defineProps({ visible: { @@ -139,13 +143,18 @@ const rules = { ], }; +const checkAllDbNames = ref(false); +const indeterminateDbNames = ref(false); + const dbForm: any = ref(null); const tagSelectRef: any = ref(null); const state = reactive({ dialogVisible: false, allDatabases: [] as any, - databaseList: [] as any, + dbNamesSelected: [] as any, + dbNamesFiltered: [] as any, + filterString: '', form: { id: null, tagId: [], @@ -158,7 +167,7 @@ const state = reactive({ instances: [] as any, }); -const { dialogVisible, allDatabases, form, databaseList } = toRefs(state); +const { dialogVisible, allDatabases, form, dbNamesSelected } = toRefs(state); const { isFetching: saveBtnLoading, execute: saveDbExec } = dbApi.saveDb.useApi(form); @@ -171,25 +180,18 @@ watch(props, async (newValue: any) => { state.form = { ...newValue.db }; // 将数据库名使用空格切割,获取所有数据库列表 - state.databaseList = newValue.db.database.split(' '); + state.dbNamesSelected = newValue.db.database.split(' '); } else { state.form = {} as any; - state.databaseList = []; + state.dbNamesSelected = []; } }); const changeInstance = () => { - state.databaseList = []; + state.dbNamesSelected = []; getAllDatabase(); }; -/** - * 改变表单中的数据库字段,方便表单错误提示。如全部删光,可提示请添加数据库 - */ -const changeDatabase = () => { - state.form.database = state.databaseList.length == 0 ? '' : state.databaseList.join(' '); -}; - const getAllDatabase = async () => { if (state.form.instanceId > 0) { state.allDatabases = await dbApi.getAllDatabase.request({ instanceId: state.form.instanceId }); @@ -210,7 +212,7 @@ const getInstances = async (instanceName: string = '', id = 0) => { const open = async () => { if (state.form.instanceId) { // 根据id获取,因为需要回显实例名称 - getInstances('', state.form.instanceId); + await getInstances('', state.form.instanceId); } await getAllDatabase(); }; @@ -230,7 +232,7 @@ const btnOk = async () => { }; const resetInputDb = () => { - state.databaseList = []; + state.dbNamesSelected = []; state.allDatabases = []; state.instances = []; }; @@ -242,5 +244,62 @@ const cancel = () => { resetInputDb(); }, 500); }; + +const checkDbSelect = (val: string[]) => { + const selected = val.filter((dbName: string) => { + return dbName.includes(state.filterString); + }); + if (selected.length === 0) { + checkAllDbNames.value = false; + indeterminateDbNames.value = false; + return; + } + if (selected.length === state.dbNamesFiltered.length) { + checkAllDbNames.value = true; + indeterminateDbNames.value = false; + return; + } + indeterminateDbNames.value = true; +}; + +watch(dbNamesSelected, (val: string[]) => { + checkDbSelect(val); + state.form.database = val.join(' '); +}); + +watch(allDatabases, (val: string[]) => { + state.dbNamesFiltered = val.map((dbName: string) => dbName); +}); + +const handleCheckAll = (val: CheckboxValueType) => { + const otherSelected = state.dbNamesSelected.filter((dbName: string) => { + return !state.dbNamesFiltered.includes(dbName); + }); + if (val) { + state.dbNamesSelected = otherSelected.concat(state.dbNamesFiltered); + } else { + state.dbNamesSelected = otherSelected; + } +}; + +const filterDbNames = (filterString: string) => { + const dbNamesCreated = state.dbNamesSelected.filter((dbName: string) => { + return !state.allDatabases.includes(dbName); + }); + if (filterString.length === 0) { + state.dbNamesFiltered = dbNamesCreated.concat(state.allDatabases); + checkDbSelect(state.dbNamesSelected); + return; + } + state.dbNamesFiltered = dbNamesCreated.concat(state.allDatabases).filter((dbName: string) => { + if (dbName == filterString) { + return false; + } + return dbName.includes(filterString); + }); + state.dbNamesFiltered.unshift(filterString); + state.filterString = filterString; + checkDbSelect(state.dbNamesSelected); +}; diff --git a/mayfly_go_web/src/views/ops/db/DbList.vue b/mayfly_go_web/src/views/ops/db/DbList.vue index 97c6f751..157edb3e 100644 --- a/mayfly_go_web/src/views/ops/db/DbList.vue +++ b/mayfly_go_web/src/views/ops/db/DbList.vue @@ -198,11 +198,11 @@ const perms = { const searchItems = [getTagPathSearchItem(TagResourceTypeEnum.Db.value), SearchItem.slot('instanceId', '实例', 'instanceSelect')]; const columns = ref([ - TableColumn.new('instanceName', '实例名'), + TableColumn.new('name', '名称'), TableColumn.new('type', '类型').isSlot().setAddWidth(-15).alignCenter(), + TableColumn.new('instanceName', '实例名'), TableColumn.new('host', 'ip:port').isSlot().setAddWidth(40), TableColumn.new('username', 'username'), - TableColumn.new('name', '名称'), TableColumn.new('tagPath', '关联标签').isSlot().setAddWidth(10).alignCenter(), TableColumn.new('remark', '备注'), ]); diff --git a/mayfly_go_web/src/views/ops/db/DbRestoreEdit.vue b/mayfly_go_web/src/views/ops/db/DbRestoreEdit.vue index 501bf3ea..705d67be 100644 --- a/mayfly_go_web/src/views/ops/db/DbRestoreEdit.vue +++ b/mayfly_go_web/src/views/ops/db/DbRestoreEdit.vue @@ -155,7 +155,7 @@ const state = reactive({ pointInTime: null as any, }, btnLoading: false, - selectedDbNames: [] as any, + dbNamesSelected: [] as any, dbNamesWithoutRestore: [] as any, editOrCreate: false, histories: [] as any, @@ -189,12 +189,12 @@ const changeHistory = async () => { }; const init = async (data: any) => { - state.selectedDbNames = []; + state.dbNamesSelected = []; state.form.dbId = props.dbId; if (data) { state.editOrCreate = true; state.dbNamesWithoutRestore = [data.dbName]; - state.selectedDbNames = [data.dbName]; + state.dbNamesSelected = [data.dbName]; state.form.id = data.id; state.form.dbName = data.dbName; state.form.intervalDay = data.intervalDay; diff --git a/server/internal/db/api/db_backup.go b/server/internal/db/api/db_backup.go index ca57fb03..0c4d8cc2 100644 --- a/server/internal/db/api/db_backup.go +++ b/server/internal/db/api/db_backup.go @@ -54,9 +54,14 @@ func (d *DbBackup) Create(rc *req.Ctx) { jobs := make([]*entity.DbBackup, 0, len(dbNames)) for _, dbName := range dbNames { job := &entity.DbBackup{ - DbJobBaseImpl: entity.NewDbBJobBase(db.InstanceId, dbName, entity.DbJobTypeBackup, true, backupForm.Repeated, backupForm.StartTime, backupForm.Interval), + DbJobBaseImpl: entity.NewDbBJobBase(db.InstanceId, entity.DbJobTypeBackup), + Enabled: true, + Repeated: backupForm.Repeated, + StartTime: backupForm.StartTime, + Interval: backupForm.Interval, Name: backupForm.Name, } + job.DbName = dbName jobs = append(jobs, job) } biz.ErrIsNilAppendErr(d.DbBackupApp.Create(rc.MetaCtx, jobs), "添加数据库备份任务失败: %v") diff --git a/server/internal/db/api/db_restore.go b/server/internal/db/api/db_restore.go index b202983a..060d6e48 100644 --- a/server/internal/db/api/db_restore.go +++ b/server/internal/db/api/db_restore.go @@ -48,12 +48,17 @@ func (d *DbRestore) Create(rc *req.Ctx) { biz.ErrIsNilAppendErr(err, "获取数据库信息失败: %v") job := &entity.DbRestore{ - DbJobBaseImpl: entity.NewDbBJobBase(db.InstanceId, restoreForm.DbName, entity.DbJobTypeRestore, true, restoreForm.Repeated, restoreForm.StartTime, restoreForm.Interval), + DbJobBaseImpl: entity.NewDbBJobBase(db.InstanceId, entity.DbJobTypeRestore), + Enabled: true, + Repeated: restoreForm.Repeated, + StartTime: restoreForm.StartTime, + Interval: restoreForm.Interval, PointInTime: restoreForm.PointInTime, DbBackupId: restoreForm.DbBackupId, DbBackupHistoryId: restoreForm.DbBackupHistoryId, DbBackupHistoryName: restoreForm.DbBackupHistoryName, } + job.DbName = restoreForm.DbName biz.ErrIsNilAppendErr(d.DbRestoreApp.Create(rc.MetaCtx, job), "添加数据库恢复任务失败: %v") } diff --git a/server/internal/db/application/application.go b/server/internal/db/application/application.go index 94afbae0..58a15de8 100644 --- a/server/internal/db/application/application.go +++ b/server/internal/db/application/application.go @@ -50,7 +50,7 @@ func Init() { if err != nil { panic(fmt.Sprintf("初始化 dbRestoreApp 失败: %v", err)) } - dbBinlogApp, err = newDbBinlogApp(repositories, dbApp) + dbBinlogApp, err = newDbBinlogApp(repositories, dbApp, scheduler) if err != nil { panic(fmt.Sprintf("初始化 dbBinlogApp 失败: %v", err)) } diff --git a/server/internal/db/application/db_binlog.go b/server/internal/db/application/db_binlog.go index 33cbd644..c326d76b 100644 --- a/server/internal/db/application/db_binlog.go +++ b/server/internal/db/application/db_binlog.go @@ -2,20 +2,14 @@ package application import ( "context" - "fmt" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" "mayfly-go/pkg/logx" - "mayfly-go/pkg/utils/stringx" "mayfly-go/pkg/utils/timex" "sync" "time" ) -const ( - binlogDownloadInterval = time.Minute * 15 -) - type DbBinlogApp struct { binlogRepo repository.DbBinlog binlogHistoryRepo repository.DbBinlogHistory @@ -25,19 +19,10 @@ type DbBinlogApp struct { context context.Context cancel context.CancelFunc waitGroup sync.WaitGroup + scheduler *dbScheduler } -var ( - binlogResult = map[entity.DbJobStatus]string{ - entity.DbJobDelay: "等待备份BINLOG", - entity.DbJobReady: "准备备份BINLOG", - entity.DbJobRunning: "BINLOG备份中", - entity.DbJobSuccess: "BINLOG备份成功", - entity.DbJobFailed: "BINLOG备份失败", - } -) - -func newDbBinlogApp(repositories *repository.Repositories, dbApp Db) (*DbBinlogApp, error) { +func newDbBinlogApp(repositories *repository.Repositories, dbApp Db, scheduler *dbScheduler) (*DbBinlogApp, error) { ctx, cancel := context.WithCancel(context.Background()) svc := &DbBinlogApp{ binlogRepo: repositories.Binlog, @@ -45,6 +30,7 @@ func newDbBinlogApp(repositories *repository.Repositories, dbApp Db) (*DbBinlogA backupRepo: repositories.Backup, backupHistoryRepo: repositories.BackupHistory, dbApp: dbApp, + scheduler: scheduler, context: ctx, cancel: cancel, } @@ -53,73 +39,47 @@ func newDbBinlogApp(repositories *repository.Repositories, dbApp Db) (*DbBinlogA return svc, nil } -func (app *DbBinlogApp) fetchBinlog(ctx context.Context, backup *entity.DbBackup) error { - if err := app.AddJobIfNotExists(ctx, entity.NewDbBinlog(backup.DbInstanceId)); err != nil { - return err - } - latestBinlogSequence, earliestBackupSequence := int64(-1), int64(-1) - binlogHistory, ok, err := app.binlogHistoryRepo.GetLatestHistory(backup.DbInstanceId) - if err != nil { - return err - } - if ok { - latestBinlogSequence = binlogHistory.Sequence - } else { - backupHistory, ok, err := app.backupHistoryRepo.GetEarliestHistory(backup.DbInstanceId) - if err != nil { - return err - } - if !ok { - return nil - } - earliestBackupSequence = backupHistory.BinlogSequence - } - conn, err := app.dbApp.GetDbConnByInstanceId(backup.DbInstanceId) - if err != nil { - return err - } - dbProgram := conn.GetDialect().GetDbProgram() - binlogFiles, err := dbProgram.FetchBinlogs(ctx, false, earliestBackupSequence, latestBinlogSequence) - if err == nil { - err = app.binlogHistoryRepo.InsertWithBinlogFiles(ctx, backup.DbInstanceId, binlogFiles) - } - jobStatus := entity.DbJobSuccess - if err != nil { - jobStatus = entity.DbJobFailed - } - job := &entity.DbBinlog{} - job.Id = backup.DbInstanceId - return app.updateCurJob(ctx, jobStatus, err, job) -} - func (app *DbBinlogApp) run() { defer app.waitGroup.Done() // todo: 实现 binlog 并发下载 timex.SleepWithContext(app.context, time.Minute) for !app.closed() { - app.fetchFromAllInstances() - timex.SleepWithContext(app.context, binlogDownloadInterval) - } -} - -func (app *DbBinlogApp) fetchFromAllInstances() { - var backups []*entity.DbBackup - if err := app.backupRepo.ListRepeating(&backups); err != nil { - logx.Errorf("DbBinlogApp: 获取数据库备份任务失败: %s", err.Error()) - return - } - for _, backup := range backups { + jobs, err := app.loadJobs() + if err != nil { + logx.Errorf("DbBinlogApp: 加载 BINLOG 同步任务失败: %s", err.Error()) + timex.SleepWithContext(app.context, time.Minute) + continue + } if app.closed() { break } - if err := app.fetchBinlog(app.context, backup); err != nil { - logx.Errorf("DbBinlogApp: 下载 binlog 文件失败: %s", err.Error()) - return + if err := app.scheduler.AddJob(app.context, false, entity.DbJobTypeBinlog, jobs); err != nil { + logx.Error("DbBinlogApp: 添加 BINLOG 同步任务失败: ", err.Error()) } + timex.SleepWithContext(app.context, entity.BinlogDownloadInterval) } } +func (app *DbBinlogApp) loadJobs() ([]*entity.DbBinlog, error) { + var instanceIds []uint64 + if err := app.backupRepo.ListDbInstances(true, true, &instanceIds); err != nil { + return nil, err + } + jobs := make([]*entity.DbBinlog, 0, len(instanceIds)) + for _, id := range instanceIds { + if app.closed() { + break + } + binlog := entity.NewDbBinlog(id) + if err := app.AddJobIfNotExists(app.context, binlog); err != nil { + return nil, err + } + jobs = append(jobs, binlog) + } + return jobs, nil +} + func (app *DbBinlogApp) Close() { app.cancel() app.waitGroup.Wait() @@ -146,14 +106,3 @@ func (app *DbBinlogApp) Delete(ctx context.Context, jobId uint64) error { } return nil } - -func (app *DbBinlogApp) updateCurJob(ctx context.Context, status entity.DbJobStatus, lastErr error, job *entity.DbBinlog) error { - job.LastStatus = status - var result = binlogResult[status] - if lastErr != nil { - result = fmt.Sprintf("%v: %v", binlogResult[status], lastErr) - } - job.LastResult = stringx.TruncateStr(result, entity.LastResultSize) - job.LastTime = timex.NewNullTime(time.Now()) - return app.binlogRepo.UpdateById(ctx, job, "last_status", "last_result", "last_time") -} diff --git a/server/internal/db/application/db_scheduler.go b/server/internal/db/application/db_scheduler.go index 9a40e16b..8f3625e6 100644 --- a/server/internal/db/application/db_scheduler.go +++ b/server/internal/db/application/db_scheduler.go @@ -26,28 +26,40 @@ type dbScheduler struct { backupHistoryRepo repository.DbBackupHistory restoreRepo repository.DbRestore restoreHistoryRepo repository.DbRestoreHistory + binlogRepo repository.DbBinlog binlogHistoryRepo repository.DbBinlogHistory + binlogTimes map[uint64]time.Time } func newDbScheduler(repositories *repository.Repositories) (*dbScheduler, error) { scheduler := &dbScheduler{ - runner: runner.NewRunner[entity.DbJob](maxRunning), dbApp: dbApp, backupRepo: repositories.Backup, backupHistoryRepo: repositories.BackupHistory, restoreRepo: repositories.Restore, restoreHistoryRepo: repositories.RestoreHistory, + binlogRepo: repositories.Binlog, binlogHistoryRepo: repositories.BinlogHistory, } + scheduler.runner = runner.NewRunner[entity.DbJob](maxRunning, scheduler.runJob, + runner.WithScheduleJob[entity.DbJob](scheduler.scheduleJob), + runner.WithRunnableJob[entity.DbJob](scheduler.runnableJob), + ) return scheduler, nil } +func (s *dbScheduler) scheduleJob(job entity.DbJob) (time.Time, error) { + return job.Schedule() +} + func (s *dbScheduler) repo(typ entity.DbJobType) repository.DbJob { switch typ { case entity.DbJobTypeBackup: return s.backupRepo case entity.DbJobTypeRestore: return s.restoreRepo + case entity.DbJobTypeBinlog: + return s.binlogRepo default: panic(errors.New(fmt.Sprintf("无效的数据库任务类型: %v", typ))) } @@ -60,8 +72,6 @@ func (s *dbScheduler) UpdateJob(ctx context.Context, job entity.DbJob) error { if err := s.repo(job.GetJobType()).UpdateById(ctx, job); err != nil { return err } - job.SetRun(s.run) - job.SetRunnable(s.runnable) _ = s.runner.UpdateOrAdd(ctx, job) return nil } @@ -87,21 +97,11 @@ func (s *dbScheduler) AddJob(ctx context.Context, saving bool, jobType entity.Db for i := 0; i < reflectLen; i++ { job := reflectValue.Index(i).Interface().(entity.DbJob) job.SetJobType(jobType) - if !job.Schedule() { - continue - } - job.SetRun(s.run) - job.SetRunnable(s.runnable) _ = s.runner.Add(ctx, job) } default: job := jobs.(entity.DbJob) job.SetJobType(jobType) - if !job.Schedule() { - return nil - } - job.SetRun(s.run) - job.SetRunnable(s.runnable) _ = s.runner.Add(ctx, job) } return nil @@ -131,12 +131,10 @@ func (s *dbScheduler) EnableJob(ctx context.Context, jobType entity.DbJobType, j if job.IsEnabled() { return nil } - job.GetJobBase().Enabled = true + job.SetEnabled(true) if err := repo.UpdateEnabled(ctx, jobId, true); err != nil { return err } - job.SetRun(s.run) - job.SetRunnable(s.runnable) _ = s.runner.Add(ctx, job) return nil } @@ -171,9 +169,6 @@ func (s *dbScheduler) StartJobNow(ctx context.Context, jobType entity.DbJobType, if !job.IsEnabled() { return errors.New("任务未启用") } - job.GetJobBase().Deadline = time.Now() - job.SetRun(s.run) - job.SetRunnable(s.runnable) _ = s.runner.StartNow(ctx, job) return nil } @@ -267,7 +262,7 @@ func (s *dbScheduler) restoreMysql(ctx context.Context, job entity.DbJob) error return nil } -func (s *dbScheduler) run(ctx context.Context, job entity.DbJob) { +func (s *dbScheduler) runJob(ctx context.Context, job entity.DbJob) { job.SetLastStatus(entity.DbJobRunning, nil) if err := s.repo(job.GetJobType()).UpdateLastStatus(ctx, job); err != nil { logx.Errorf("failed to update job status: %v", err) @@ -280,6 +275,8 @@ func (s *dbScheduler) run(ctx context.Context, job entity.DbJob) { errRun = s.backupMysql(ctx, job) case entity.DbJobTypeRestore: errRun = s.restoreMysql(ctx, job) + case entity.DbJobTypeBinlog: + errRun = s.fetchBinlogMysql(ctx, job) default: errRun = errors.New(fmt.Sprintf("无效的数据库任务类型: %v", typ)) } @@ -294,19 +291,27 @@ func (s *dbScheduler) run(ctx context.Context, job entity.DbJob) { } } -func (s *dbScheduler) runnable(job entity.DbJob, next runner.NextFunc) bool { +func (s *dbScheduler) runnableJob(job entity.DbJob, next runner.NextJobFunc[entity.DbJob]) bool { const maxCountByInstanceId = 4 const maxCountByDbName = 1 var countByInstanceId, countByDbName int jobBase := job.GetJobBase() for item, ok := next(); ok; item, ok = next() { - itemBase := item.(entity.DbJob).GetJobBase() + itemBase := item.GetJobBase() if jobBase.DbInstanceId == itemBase.DbInstanceId { countByInstanceId++ if countByInstanceId >= maxCountByInstanceId { return false } - if jobBase.DbName == itemBase.DbName { + + if relatedToBinlog(job.GetJobType()) { + // todo: 恢复数据库前触发 BINLOG 同步,BINLOG 同步完成后才能恢复数据库 + if relatedToBinlog(item.GetJobType()) { + return false + } + } + + if job.GetDbName() == item.GetDbName() { countByDbName++ if countByDbName >= maxCountByDbName { return false @@ -317,6 +322,10 @@ func (s *dbScheduler) runnable(job entity.DbJob, next runner.NextFunc) bool { return true } +func relatedToBinlog(typ entity.DbJobType) bool { + return typ == entity.DbJobTypeRestore || typ == entity.DbJobTypeBinlog +} + func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbi.DbProgram, job *entity.DbRestore) error { binlogHistory, err := s.binlogHistoryRepo.GetHistoryByTime(job.DbInstanceId, job.PointInTime.Time) if err != nil { @@ -364,3 +373,34 @@ func (s *dbScheduler) restoreBackupHistory(ctx context.Context, program dbi.DbPr } return program.RestoreBackupHistory(ctx, backupHistory.DbName, backupHistory.DbBackupId, backupHistory.Uuid) } + +func (s *dbScheduler) fetchBinlogMysql(ctx context.Context, backup entity.DbJob) error { + instanceId := backup.GetJobBase().DbInstanceId + latestBinlogSequence, earliestBackupSequence := int64(-1), int64(-1) + binlogHistory, ok, err := s.binlogHistoryRepo.GetLatestHistory(instanceId) + if err != nil { + return err + } + if ok { + latestBinlogSequence = binlogHistory.Sequence + } else { + backupHistory, ok, err := s.backupHistoryRepo.GetEarliestHistory(instanceId) + if err != nil { + return err + } + if !ok { + return nil + } + earliestBackupSequence = backupHistory.BinlogSequence + } + conn, err := s.dbApp.GetDbConnByInstanceId(instanceId) + if err != nil { + return err + } + dbProgram := conn.GetDialect().GetDbProgram() + binlogFiles, err := dbProgram.FetchBinlogs(ctx, false, earliestBackupSequence, latestBinlogSequence) + if err == nil { + err = s.binlogHistoryRepo.InsertWithBinlogFiles(ctx, instanceId, binlogFiles) + } + return nil +} diff --git a/server/internal/db/dbm/mysql/program.go b/server/internal/db/dbm/mysql/program.go index 3fdf7523..2b33c4e0 100644 --- a/server/internal/db/dbm/mysql/program.go +++ b/server/internal/db/dbm/mysql/program.go @@ -16,8 +16,6 @@ import ( "time" "github.com/pkg/errors" - "golang.org/x/sync/singleflight" - "mayfly-go/internal/db/config" "mayfly-go/internal/db/dbm/dbi" "mayfly-go/internal/db/domain/entity" @@ -186,6 +184,12 @@ func (svc *DbProgramMysql) downloadBinlogFilesOnServer(ctx context.Context, binl return nil } +// Parse the first binlog eventTs of a local binlog file. +func (svc *DbProgramMysql) parseLocalBinlogLastEventTime(ctx context.Context, filePath string) (eventTime time.Time, parseErr error) { + // todo: implement me + return time.Now(), nil +} + // Parse the first binlog eventTs of a local binlog file. func (svc *DbProgramMysql) parseLocalBinlogFirstEventTime(ctx context.Context, filePath string) (eventTime time.Time, parseErr error) { args := []string{ @@ -227,36 +231,8 @@ func (svc *DbProgramMysql) parseLocalBinlogFirstEventTime(ctx context.Context, f return time.Time{}, errors.New("解析 binlog 文件失败") } -var singleFlightGroup singleflight.Group - // FetchBinlogs downloads binlog files from startingFileName on server to `binlogDir`. func (svc *DbProgramMysql) FetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool, earliestBackupSequence, latestBinlogSequence int64) ([]*entity.BinlogFile, error) { - var downloaded bool - key := strconv.FormatUint(svc.dbInfo().InstanceId, 16) - binlogFiles, err, _ := singleFlightGroup.Do(key, func() (interface{}, error) { - downloaded = true - return svc.fetchBinlogs(ctx, downloadLatestBinlogFile, earliestBackupSequence, latestBinlogSequence) - }) - if err != nil { - return nil, err - } - if downloaded { - return binlogFiles.([]*entity.BinlogFile), nil - } - if !downloadLatestBinlogFile { - return nil, nil - } - binlogFiles, err, _ = singleFlightGroup.Do(key, func() (interface{}, error) { - return svc.fetchBinlogs(ctx, true, earliestBackupSequence, latestBinlogSequence) - }) - if err != nil { - return nil, err - } - return binlogFiles.([]*entity.BinlogFile), err -} - -// fetchBinlogs downloads binlog files from startingFileName on server to `binlogDir`. -func (svc *DbProgramMysql) fetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool, earliestBackupSequence, latestBinlogSequence int64) ([]*entity.BinlogFile, error) { // Read binlog files list on server. binlogFilesOnServerSorted, err := svc.GetSortedBinlogFilesOnServer(ctx) if err != nil { @@ -352,7 +328,13 @@ func (svc *DbProgramMysql) downloadBinlogFile(ctx context.Context, binlogFileToD if err != nil { return err } + lastEventTime, err := svc.parseLocalBinlogLastEventTime(ctx, binlogFilePath) + if err != nil { + return err + } + binlogFileToDownload.FirstEventTime = firstEventTime + binlogFileToDownload.LastEventTime = lastEventTime binlogFileToDownload.Downloaded = true return nil diff --git a/server/internal/db/dbm/mysql/program_e2e_test.go b/server/internal/db/dbm/mysql/program_e2e_test.go index 63d5c62d..8fbb5ea3 100644 --- a/server/internal/db/dbm/mysql/program_e2e_test.go +++ b/server/internal/db/dbm/mysql/program_e2e_test.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/stretchr/testify/suite" "mayfly-go/internal/db/config" + "mayfly-go/internal/db/dbm/dbi" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" "mayfly-go/internal/db/infrastructure/persistence" @@ -32,21 +33,21 @@ type DbInstanceSuite struct { suite.Suite repositories *repository.Repositories instanceSvc *DbProgramMysql - dbConn *DbConn + dbConn *dbi.DbConn } func (s *DbInstanceSuite) SetupSuite() { if err := chdir("mayfly-go", "server"); err != nil { panic(err) } - dbInfo := DbInfo{ - Type: DbTypeMysql, + dbInfo := dbi.DbInfo{ + Type: dbi.DbTypeMysql, Host: "localhost", Port: 3306, Username: "test", Password: "test", } - dbConn, err := dbInfo.Conn() + dbConn, err := dbInfo.Conn(GetMeta()) s.Require().NoError(err) s.dbConn = dbConn s.repositories = &repository.Repositories{ @@ -203,7 +204,7 @@ func (s *DbInstanceSuite) testReplayBinlog(backupHistory *entity.DbBackupHistory binlogHistories = append(binlogHistories, binlogHistoryLast) } - restoreInfo := &RestoreInfo{ + restoreInfo := &dbi.RestoreInfo{ BackupHistory: backupHistory, BinlogHistories: binlogHistories, StartPosition: backupHistory.BinlogPosition, diff --git a/server/internal/db/domain/entity/db_backup.go b/server/internal/db/domain/entity/db_backup.go index ddd57681..c1193760 100644 --- a/server/internal/db/domain/entity/db_backup.go +++ b/server/internal/db/domain/entity/db_backup.go @@ -1,8 +1,8 @@ package entity import ( - "context" "mayfly-go/pkg/runner" + "time" ) var _ DbJob = (*DbBackup)(nil) @@ -11,17 +11,56 @@ var _ DbJob = (*DbBackup)(nil) type DbBackup struct { *DbJobBaseImpl - Name string `json:"Name"` // 数据库备份名称 + Enabled bool // 是否启用 + StartTime time.Time // 开始时间 + Interval time.Duration // 间隔时间 + Repeated bool // 是否重复执行 + DbName string // 数据库名称 + Name string // 数据库备份名称 } -func (d *DbBackup) SetRun(fn func(ctx context.Context, job DbJob)) { - d.run = func(ctx context.Context) { - fn(ctx, d) - } +func (b *DbBackup) GetDbName() string { + return b.DbName } -func (d *DbBackup) SetRunnable(fn func(job DbJob, next runner.NextFunc) bool) { - d.runnable = func(next runner.NextFunc) bool { - return fn(d, next) +func (b *DbBackup) Schedule() (time.Time, error) { + var deadline time.Time + if b.IsFinished() || !b.Enabled { + return deadline, runner.ErrFinished } + switch b.LastStatus { + case DbJobSuccess: + lastTime := b.LastTime.Time + if lastTime.Before(b.StartTime) { + lastTime = b.StartTime.Add(-b.Interval) + } + deadline = lastTime.Add(b.Interval - lastTime.Sub(b.StartTime)%b.Interval) + case DbJobFailed: + deadline = time.Now().Add(time.Minute) + default: + deadline = b.StartTime + } + return deadline, nil +} + +func (b *DbBackup) IsFinished() bool { + return !b.Repeated && b.LastStatus == DbJobSuccess +} + +func (b *DbBackup) IsEnabled() bool { + return b.Enabled +} + +func (b *DbBackup) SetEnabled(enabled bool) { + b.Enabled = enabled +} + +func (b *DbBackup) Update(job runner.Job) { + backup := job.(*DbBackup) + b.StartTime = backup.StartTime + b.Interval = backup.Interval +} + +func (b *DbBackup) GetInterval() time.Duration { + return b.Interval } diff --git a/server/internal/db/domain/entity/db_binlog.go b/server/internal/db/domain/entity/db_binlog.go index a77fe306..77d6586d 100644 --- a/server/internal/db/domain/entity/db_binlog.go +++ b/server/internal/db/domain/entity/db_binlog.go @@ -1,27 +1,13 @@ package entity import ( - "mayfly-go/pkg/model" - "mayfly-go/pkg/utils/timex" + "mayfly-go/pkg/runner" "time" ) -// DbBinlog 数据库备份任务 -type DbBinlog struct { - model.Model - - LastStatus DbJobStatus // 最近一次执行状态 - LastResult string // 最近一次执行结果 - LastTime timex.NullTime // 最近一次执行时间 - DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID -} - -func NewDbBinlog(instanceId uint64) *DbBinlog { - job := &DbBinlog{} - job.Id = instanceId - job.DbInstanceId = instanceId - return job -} +const ( + BinlogDownloadInterval = time.Minute * 15 +) // BinlogFile is the metadata of the MySQL binlog file. type BinlogFile struct { @@ -31,5 +17,49 @@ type BinlogFile struct { // Sequence is parsed from Name and is for the sorting purpose. Sequence int64 FirstEventTime time.Time + LastEventTime time.Time Downloaded bool } + +var _ DbJob = (*DbBinlog)(nil) + +// DbBinlog 数据库备份任务 +type DbBinlog struct { + DbJobBaseImpl +} + +func NewDbBinlog(instanceId uint64) *DbBinlog { + job := &DbBinlog{} + job.Id = instanceId + job.DbInstanceId = instanceId + return job +} + +func (b *DbBinlog) GetDbName() string { + // binlog 是全库级别的 + return "" +} + +func (b *DbBinlog) Schedule() (time.Time, error) { + switch b.GetJobBase().LastStatus { + case DbJobSuccess: + return time.Time{}, runner.ErrFinished + case DbJobFailed: + + return time.Now().Add(BinlogDownloadInterval), nil + default: + return time.Now(), nil + } +} + +func (b *DbBinlog) Update(_ runner.Job) {} + +func (b *DbBinlog) IsEnabled() bool { + return true +} + +func (b *DbBinlog) SetEnabled(_ bool) {} + +func (b *DbBinlog) GetInterval() time.Duration { + return 0 +} diff --git a/server/internal/db/domain/entity/db_job.go b/server/internal/db/domain/entity/db_job.go index 15c354dc..b73f6f14 100644 --- a/server/internal/db/domain/entity/db_job.go +++ b/server/internal/db/domain/entity/db_job.go @@ -1,7 +1,6 @@ package entity import ( - "context" "fmt" "mayfly-go/pkg/model" "mayfly-go/pkg/runner" @@ -14,18 +13,11 @@ const LastResultSize = 256 type DbJobKey = runner.JobKey -type DbJobStatus = runner.JobStatus +type DbJobStatus int const ( - DbJobUnknown = runner.JobUnknown - DbJobDelay = runner.JobDelay - DbJobReady = runner.JobWaiting - DbJobRunning = runner.JobRunning - DbJobRemoved = runner.JobRemoved -) - -const ( - DbJobSuccess DbJobStatus = 0x20 + iota + DbJobRunning DbJobStatus = iota + DbJobSuccess DbJobFailed ) @@ -34,32 +26,37 @@ type DbJobType = string const ( DbJobTypeBackup DbJobType = "db-backup" DbJobTypeRestore DbJobType = "db-restore" + DbJobTypeBinlog DbJobType = "db-binlog" ) const ( DbJobNameBackup = "数据库备份" DbJobNameRestore = "数据库恢复" + DbJobNameBinlog = "BINLOG同步" ) var _ runner.Job = (DbJob)(nil) type DbJobBase interface { model.ModelI - runner.Job - GetId() uint64 + GetKey() string GetJobType() DbJobType SetJobType(typ DbJobType) GetJobBase() *DbJobBaseImpl SetLastStatus(status DbJobStatus, err error) - IsEnabled() bool } type DbJob interface { + runner.Job DbJobBase - SetRun(fn func(ctx context.Context, job DbJob)) - SetRunnable(fn func(job DbJob, next runner.NextFunc) bool) + GetDbName() string + Schedule() (time.Time, error) + IsEnabled() bool + SetEnabled(enabled bool) + Update(job runner.Job) + GetInterval() time.Duration } func NewDbJob(typ DbJobType) DbJob { @@ -85,31 +82,17 @@ type DbJobBaseImpl struct { model.Model DbInstanceId uint64 // 数据库实例ID - DbName string // 数据库名称 - Enabled bool // 是否启用 - StartTime time.Time // 开始时间 - Interval time.Duration // 间隔时间 - Repeated bool // 是否重复执行 LastStatus DbJobStatus // 最近一次执行状态 LastResult string // 最近一次执行结果 LastTime timex.NullTime // 最近一次执行时间 - Deadline time.Time `gorm:"-" json:"-"` // 计划执行时间 - run runner.RunFunc - runnable runner.RunnableFunc jobType DbJobType jobKey runner.JobKey - jobStatus runner.JobStatus } -func NewDbBJobBase(instanceId uint64, dbName string, jobType DbJobType, enabled bool, repeated bool, startTime time.Time, interval time.Duration) *DbJobBaseImpl { +func NewDbBJobBase(instanceId uint64, jobType DbJobType) *DbJobBaseImpl { return &DbJobBaseImpl{ DbInstanceId: instanceId, - DbName: dbName, jobType: jobType, - Enabled: enabled, - Repeated: repeated, - StartTime: startTime, - Interval: interval, } } @@ -138,6 +121,8 @@ func (d *DbJobBaseImpl) SetLastStatus(status DbJobStatus, err error) { jobName = DbJobNameBackup case DbJobTypeRestore: jobName = DbJobNameRestore + case DbJobNameBinlog: + jobName = DbJobNameBinlog default: jobName = d.jobType } @@ -150,71 +135,10 @@ func (d *DbJobBaseImpl) SetLastStatus(status DbJobStatus, err error) { d.LastTime = timex.NewNullTime(time.Now()) } -func (d *DbJobBaseImpl) GetId() uint64 { - if d == nil { - return 0 - } - return d.Id -} - -func (d *DbJobBaseImpl) GetDeadline() time.Time { - return d.Deadline -} - -func (d *DbJobBaseImpl) Schedule() bool { - if d.IsFinished() || !d.Enabled { - return false - } - switch d.LastStatus { - case DbJobSuccess: - if d.Interval == 0 { - return false - } - lastTime := d.LastTime.Time - if lastTime.Sub(d.StartTime) < 0 { - lastTime = d.StartTime.Add(-d.Interval) - } - d.Deadline = lastTime.Add(d.Interval - lastTime.Sub(d.StartTime)%d.Interval) - case DbJobFailed: - d.Deadline = time.Now().Add(time.Minute) - default: - d.Deadline = d.StartTime - } - return true -} - -func (d *DbJobBaseImpl) IsFinished() bool { - return !d.Repeated && d.LastStatus == DbJobSuccess -} - -func (d *DbJobBaseImpl) Update(job runner.Job) { - jobBase := job.(DbJob).GetJobBase() - d.StartTime = jobBase.StartTime - d.Interval = jobBase.Interval -} - func (d *DbJobBaseImpl) GetJobBase() *DbJobBaseImpl { return d } -func (d *DbJobBaseImpl) IsEnabled() bool { - return d.Enabled -} - -func (d *DbJobBaseImpl) Run(ctx context.Context) { - if d.run == nil { - return - } - d.run(ctx) -} - -func (d *DbJobBaseImpl) Runnable(next runner.NextFunc) bool { - if d.runnable == nil { - return true - } - return d.runnable(next) -} - func FormatJobKey(typ DbJobType, jobId uint64) DbJobKey { return fmt.Sprintf("%v-%d", typ, jobId) } @@ -225,11 +149,3 @@ func (d *DbJobBaseImpl) GetKey() DbJobKey { } return d.jobKey } - -func (d *DbJobBaseImpl) GetStatus() DbJobStatus { - return d.jobStatus -} - -func (d *DbJobBaseImpl) SetStatus(status DbJobStatus) { - d.jobStatus = status -} diff --git a/server/internal/db/domain/entity/db_restore.go b/server/internal/db/domain/entity/db_restore.go index d8d0a081..92354e8a 100644 --- a/server/internal/db/domain/entity/db_restore.go +++ b/server/internal/db/domain/entity/db_restore.go @@ -1,9 +1,9 @@ package entity import ( - "context" "mayfly-go/pkg/runner" "mayfly-go/pkg/utils/timex" + "time" ) var _ DbJob = (*DbRestore)(nil) @@ -12,20 +12,59 @@ var _ DbJob = (*DbRestore)(nil) type DbRestore struct { *DbJobBaseImpl + DbName string // 数据库名称 + Enabled bool // 是否启用 + StartTime time.Time // 开始时间 + Interval time.Duration // 间隔时间 + Repeated bool // 是否重复执行 PointInTime timex.NullTime `json:"pointInTime"` // 指定数据库恢复的时间点 DbBackupId uint64 `json:"dbBackupId"` // 用于恢复的数据库恢复任务ID DbBackupHistoryId uint64 `json:"dbBackupHistoryId"` // 用于恢复的数据库恢复历史ID DbBackupHistoryName string `json:"dbBackupHistoryName"` // 数据库恢复历史名称 } -func (d *DbRestore) SetRun(fn func(ctx context.Context, job DbJob)) { - d.run = func(ctx context.Context) { - fn(ctx, d) - } +func (r *DbRestore) GetDbName() string { + return r.DbName } -func (d *DbRestore) SetRunnable(fn func(job DbJob, next runner.NextFunc) bool) { - d.runnable = func(next runner.NextFunc) bool { - return fn(d, next) +func (r *DbRestore) Schedule() (time.Time, error) { + var deadline time.Time + if r.IsFinished() || !r.Enabled { + return deadline, runner.ErrFinished } + switch r.LastStatus { + case DbJobSuccess: + lastTime := r.LastTime.Time + if lastTime.Before(r.StartTime) { + lastTime = r.StartTime.Add(-r.Interval) + } + deadline = lastTime.Add(r.Interval - lastTime.Sub(r.StartTime)%r.Interval) + case DbJobFailed: + deadline = time.Now().Add(time.Minute) + default: + deadline = r.StartTime + } + return deadline, nil +} + +func (r *DbRestore) IsEnabled() bool { + return r.Enabled +} + +func (r *DbRestore) SetEnabled(enabled bool) { + r.Enabled = enabled +} + +func (r *DbRestore) IsFinished() bool { + return !r.Repeated && r.LastStatus == DbJobSuccess +} + +func (r *DbRestore) Update(job runner.Job) { + restore := job.(*DbRestore) + r.StartTime = restore.StartTime + r.Interval = restore.Interval +} + +func (r *DbRestore) GetInterval() time.Duration { + return r.Interval } diff --git a/server/internal/db/domain/repository/db_backup.go b/server/internal/db/domain/repository/db_backup.go index f0e26e4a..77b688b8 100644 --- a/server/internal/db/domain/repository/db_backup.go +++ b/server/internal/db/domain/repository/db_backup.go @@ -1,7 +1,17 @@ package repository +import ( + "mayfly-go/internal/db/domain/entity" + "mayfly-go/pkg/model" +) + type DbBackup interface { DbJob + ListToDo(jobs any) error + ListDbInstances(enabled bool, repeated bool, instanceIds *[]uint64) error GetDbNamesWithoutBackup(instanceId uint64, dbNames []string) ([]string, error) + + // GetPageList 分页获取数据库任务列表 + GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) } diff --git a/server/internal/db/domain/repository/db_binlog.go b/server/internal/db/domain/repository/db_binlog.go index 30a274b9..76c1122a 100644 --- a/server/internal/db/domain/repository/db_binlog.go +++ b/server/internal/db/domain/repository/db_binlog.go @@ -3,11 +3,10 @@ package repository import ( "context" "mayfly-go/internal/db/domain/entity" - "mayfly-go/pkg/base" ) type DbBinlog interface { - base.Repo[*entity.DbBinlog] + DbJob AddJobIfNotExists(ctx context.Context, job *entity.DbBinlog) error } diff --git a/server/internal/db/domain/repository/db_job.go b/server/internal/db/domain/repository/db_job.go index 7be1da98..76ced093 100644 --- a/server/internal/db/domain/repository/db_job.go +++ b/server/internal/db/domain/repository/db_job.go @@ -2,27 +2,27 @@ package repository import ( "context" - "gorm.io/gorm" "mayfly-go/internal/db/domain/entity" - "mayfly-go/pkg/model" ) -type DbJob interface { - // AddJob 添加数据库任务 - AddJob(ctx context.Context, jobs any) error +type DbJobBase interface { // GetById 根据实体id查询 GetById(e entity.DbJob, id uint64, cols ...string) error - // GetPageList 分页获取数据库任务列表 - GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) + // UpdateById 根据实体id更新实体信息 UpdateById(ctx context.Context, e entity.DbJob, columns ...string) error - // BatchInsertWithDb 使用指定gorm db执行,主要用于事务执行 - BatchInsertWithDb(ctx context.Context, db *gorm.DB, es any) error + // DeleteById 根据实体主键删除实体 DeleteById(ctx context.Context, id uint64) error + // UpdateLastStatus 更新任务执行状态 UpdateLastStatus(ctx context.Context, job entity.DbJob) error - UpdateEnabled(ctx context.Context, jobId uint64, enabled bool) error - ListToDo(jobs any) error - ListRepeating(jobs any) error +} + +type DbJob interface { + DbJobBase + + // AddJob 添加数据库任务 + AddJob(ctx context.Context, jobs any) error + UpdateEnabled(ctx context.Context, jobId uint64, enabled bool) error } diff --git a/server/internal/db/domain/repository/db_restore.go b/server/internal/db/domain/repository/db_restore.go index 61271fe6..9b934d16 100644 --- a/server/internal/db/domain/repository/db_restore.go +++ b/server/internal/db/domain/repository/db_restore.go @@ -1,7 +1,16 @@ package repository +import ( + "mayfly-go/internal/db/domain/entity" + "mayfly-go/pkg/model" +) + type DbRestore interface { DbJob + ListToDo(jobs any) error GetDbNamesWithoutRestore(instanceId uint64, dbNames []string) ([]string, error) + + // GetPageList 分页获取数据库任务列表 + GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) } diff --git a/server/internal/db/infrastructure/persistence/db_backup.go b/server/internal/db/infrastructure/persistence/db_backup.go index 8fb4b09c..a42ebd3e 100644 --- a/server/internal/db/infrastructure/persistence/db_backup.go +++ b/server/internal/db/infrastructure/persistence/db_backup.go @@ -1,16 +1,19 @@ package persistence import ( + "context" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" + "mayfly-go/pkg/global" "mayfly-go/pkg/gormx" + "mayfly-go/pkg/model" "slices" ) var _ repository.DbBackup = (*dbBackupRepoImpl)(nil) type dbBackupRepoImpl struct { - dbJobBase[*entity.DbBackup] + dbJobBaseImpl[*entity.DbBackup] } func NewDbBackupRepo() repository.DbBackup { @@ -21,7 +24,8 @@ func (d *dbBackupRepoImpl) GetDbNamesWithoutBackup(instanceId uint64, dbNames [] var dbNamesWithBackup []string query := gormx.NewQuery(d.GetModel()). Eq("db_instance_id", instanceId). - Eq("repeated", true) + Eq("repeated", true). + Undeleted() if err := query.GenGdb().Pluck("db_name", &dbNamesWithBackup).Error; err != nil { return nil, err } @@ -33,3 +37,49 @@ func (d *dbBackupRepoImpl) GetDbNamesWithoutBackup(instanceId uint64, dbNames [] } return result, nil } + +func (d *dbBackupRepoImpl) ListDbInstances(enabled bool, repeated bool, instanceIds *[]uint64) error { + query := gormx.NewQuery(d.GetModel()). + Eq0("enabled", enabled). + Eq0("repeated", repeated). + Undeleted() + return query.GenGdb().Distinct().Pluck("db_instance_id", &instanceIds).Error +} + +func (d *dbBackupRepoImpl) ListToDo(jobs any) error { + db := global.Db.Model(d.GetModel()) + err := db.Where("enabled = ?", true). + Where(db.Where("repeated = ?", true).Or("last_status <> ?", entity.DbJobSuccess)). + Scopes(gormx.UndeleteScope). + Find(jobs).Error + if err != nil { + return err + } + return nil +} + +// GetPageList 分页获取数据库备份任务列表 +func (d *dbBackupRepoImpl) GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, _ ...string) (*model.PageResult[any], error) { + d.GetModel() + qd := gormx.NewQuery(d.GetModel()). + Eq("id", condition.Id). + Eq0("db_instance_id", condition.DbInstanceId). + Eq0("repeated", condition.Repeated). + In0("db_name", condition.InDbNames). + Like("db_name", condition.DbName) + return gormx.PageQuery(qd, pageParam, toEntity) +} + +// AddJob 添加数据库任务 +func (d *dbBackupRepoImpl) AddJob(ctx context.Context, jobs any) error { + return addJob[*entity.DbBackup](ctx, d.dbJobBaseImpl, jobs) +} + +func (d *dbBackupRepoImpl) UpdateEnabled(_ context.Context, jobId uint64, enabled bool) error { + cond := map[string]any{ + "id": jobId, + } + return d.Updates(cond, map[string]any{ + "enabled": enabled, + }) +} diff --git a/server/internal/db/infrastructure/persistence/db_binlog.go b/server/internal/db/infrastructure/persistence/db_binlog.go index ea787354..2cf58c37 100644 --- a/server/internal/db/infrastructure/persistence/db_binlog.go +++ b/server/internal/db/infrastructure/persistence/db_binlog.go @@ -6,14 +6,13 @@ import ( "gorm.io/gorm/clause" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" - "mayfly-go/pkg/base" "mayfly-go/pkg/global" ) var _ repository.DbBinlog = (*dbBinlogRepoImpl)(nil) type dbBinlogRepoImpl struct { - base.RepoImpl[*entity.DbBinlog] + dbJobBaseImpl[*entity.DbBinlog] } func NewDbBinlogRepo() repository.DbBinlog { @@ -21,8 +20,18 @@ func NewDbBinlogRepo() repository.DbBinlog { } func (d *dbBinlogRepoImpl) AddJobIfNotExists(_ context.Context, job *entity.DbBinlog) error { + // todo: 如果存在已删除记录,如何处理? if err := global.Db.Clauses(clause.OnConflict{DoNothing: true}).Create(job).Error; err != nil { return fmt.Errorf("启动 binlog 下载失败: %w", err) } return nil } + +// AddJob 添加数据库任务 +func (d *dbBinlogRepoImpl) AddJob(ctx context.Context, jobs any) error { + panic("not implement, use AddJobIfNotExists") +} + +func (d *dbBinlogRepoImpl) UpdateEnabled(_ context.Context, jobId uint64, enabled bool) error { + panic("not implement") +} diff --git a/server/internal/db/infrastructure/persistence/db_binlog_history.go b/server/internal/db/infrastructure/persistence/db_binlog_history.go index 09a11444..42345257 100644 --- a/server/internal/db/infrastructure/persistence/db_binlog_history.go +++ b/server/internal/db/infrastructure/persistence/db_binlog_history.go @@ -78,6 +78,7 @@ func (repo *dbBinlogHistoryRepoImpl) Upsert(_ context.Context, history *entity.D old := &entity.DbBinlogHistory{} err := db.Where("db_instance_id = ?", history.DbInstanceId). Where("sequence = ?", history.Sequence). + Scopes(gormx.UndeleteScope). First(old).Error switch { case err == nil: diff --git a/server/internal/db/infrastructure/persistence/db_job_base.go b/server/internal/db/infrastructure/persistence/db_job_base.go index 32340867..f12cf516 100644 --- a/server/internal/db/infrastructure/persistence/db_job_base.go +++ b/server/internal/db/infrastructure/persistence/db_job_base.go @@ -6,74 +6,32 @@ import ( "fmt" "gorm.io/gorm" "mayfly-go/internal/db/domain/entity" + "mayfly-go/internal/db/domain/repository" "mayfly-go/pkg/base" - "mayfly-go/pkg/global" "mayfly-go/pkg/gormx" - "mayfly-go/pkg/model" "reflect" ) -type dbJobBase[T entity.DbJob] struct { +var _ repository.DbJobBase = (*dbJobBaseImpl[entity.DbJob])(nil) + +type dbJobBaseImpl[T entity.DbJob] struct { base.RepoImpl[T] } -func (d *dbJobBase[T]) GetById(e entity.DbJob, id uint64, cols ...string) error { +func (d *dbJobBaseImpl[T]) GetById(e entity.DbJob, id uint64, cols ...string) error { return d.RepoImpl.GetById(e.(T), id, cols...) } -func (d *dbJobBase[T]) UpdateById(ctx context.Context, e entity.DbJob, columns ...string) error { +func (d *dbJobBaseImpl[T]) UpdateById(ctx context.Context, e entity.DbJob, columns ...string) error { return d.RepoImpl.UpdateById(ctx, e.(T), columns...) } -func (d *dbJobBase[T]) UpdateEnabled(_ context.Context, jobId uint64, enabled bool) error { - cond := map[string]any{ - "id": jobId, - } - return d.Updates(cond, map[string]any{ - "enabled": enabled, - }) -} - -func (d *dbJobBase[T]) UpdateLastStatus(ctx context.Context, job entity.DbJob) error { +func (d *dbJobBaseImpl[T]) UpdateLastStatus(ctx context.Context, job entity.DbJob) error { return d.UpdateById(ctx, job.(T), "last_status", "last_result", "last_time") } -func (d *dbJobBase[T]) ListToDo(jobs any) error { - db := global.Db.Model(d.GetModel()) - err := db.Where("enabled = ?", true). - Where(db.Where("repeated = ?", true).Or("last_status <> ?", entity.DbJobSuccess)). - Scopes(gormx.UndeleteScope). - Find(jobs).Error - if err != nil { - return err - } - return nil -} - -func (d *dbJobBase[T]) ListRepeating(jobs any) error { - cond := map[string]any{ - "enabled": true, - "repeated": true, - } - if err := d.ListByCond(cond, jobs); err != nil { - return err - } - return nil -} - -// GetPageList 分页获取数据库备份任务列表 -func (d *dbJobBase[T]) GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, _ ...string) (*model.PageResult[any], error) { - d.GetModel() - qd := gormx.NewQuery(d.GetModel()). - Eq("id", condition.Id). - Eq0("db_instance_id", condition.DbInstanceId). - Eq0("repeated", condition.Repeated). - In0("db_name", condition.InDbNames). - Like("db_name", condition.DbName) - return gormx.PageQuery(qd, pageParam, toEntity) -} - -func (d *dbJobBase[T]) AddJob(ctx context.Context, jobs any) error { +func addJob[T entity.DbJob](ctx context.Context, repo dbJobBaseImpl[T], jobs any) error { + // refactor and jobs from any to []T return gormx.Tx(func(db *gorm.DB) error { var instanceId uint64 var dbNames []string @@ -93,26 +51,28 @@ func (d *dbJobBase[T]) AddJob(ctx context.Context, jobs any) error { if jobBase.DbInstanceId != instanceId { return errors.New("不支持同时为多个数据库实例添加数据库任务") } - if jobBase.Interval == 0 { + if job.GetInterval() == 0 { // 单次执行的数据库任务可重复创建 continue } - dbNames = append(dbNames, jobBase.DbName) + dbNames = append(dbNames, job.GetDbName()) } default: - jobBase := jobs.(entity.DbJob).GetJobBase() + job := jobs.(entity.DbJob) + jobBase := job.GetJobBase() instanceId = jobBase.DbInstanceId - if jobBase.Interval > 0 { - dbNames = append(dbNames, jobBase.DbName) + if job.GetInterval() > 0 { + dbNames = append(dbNames, job.GetDbName()) } } var res []string - err := db.Model(d.GetModel()).Select("db_name"). + err := db.Model(repo.GetModel()).Select("db_name"). Where("db_instance_id = ?", instanceId). Where("db_name in ?", dbNames). Where("repeated = true"). - Scopes(gormx.UndeleteScope).Find(&res).Error + Scopes(gormx.UndeleteScope). + Find(&res).Error if err != nil { return err } @@ -120,8 +80,8 @@ func (d *dbJobBase[T]) AddJob(ctx context.Context, jobs any) error { return errors.New(fmt.Sprintf("数据库任务已存在: %v", res)) } if plural { - return d.BatchInsertWithDb(ctx, db, jobs) + return repo.BatchInsertWithDb(ctx, db, jobs.([]T)) } - return d.InsertWithDb(ctx, db, jobs.(T)) + return repo.InsertWithDb(ctx, db, jobs.(T)) }) } diff --git a/server/internal/db/infrastructure/persistence/db_restore.go b/server/internal/db/infrastructure/persistence/db_restore.go index 5f93f86e..b0884386 100644 --- a/server/internal/db/infrastructure/persistence/db_restore.go +++ b/server/internal/db/infrastructure/persistence/db_restore.go @@ -1,16 +1,19 @@ package persistence import ( + "context" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" + "mayfly-go/pkg/global" "mayfly-go/pkg/gormx" + "mayfly-go/pkg/model" "slices" ) var _ repository.DbRestore = (*dbRestoreRepoImpl)(nil) type dbRestoreRepoImpl struct { - dbJobBase[*entity.DbRestore] + dbJobBaseImpl[*entity.DbRestore] } func NewDbRestoreRepo() repository.DbRestore { @@ -21,7 +24,8 @@ func (d *dbRestoreRepoImpl) GetDbNamesWithoutRestore(instanceId uint64, dbNames var dbNamesWithRestore []string query := gormx.NewQuery(d.GetModel()). Eq("db_instance_id", instanceId). - Eq("repeated", true) + Eq("repeated", true). + Undeleted() if err := query.GenGdb().Pluck("db_name", &dbNamesWithRestore).Error; err != nil { return nil, err } @@ -33,3 +37,41 @@ func (d *dbRestoreRepoImpl) GetDbNamesWithoutRestore(instanceId uint64, dbNames } return result, nil } + +func (d *dbRestoreRepoImpl) ListToDo(jobs any) error { + db := global.Db.Model(d.GetModel()) + err := db.Where("enabled = ?", true). + Where(db.Where("repeated = ?", true).Or("last_status <> ?", entity.DbJobSuccess)). + Scopes(gormx.UndeleteScope). + Find(jobs).Error + if err != nil { + return err + } + return nil +} + +// GetPageList 分页获取数据库备份任务列表 +func (d *dbRestoreRepoImpl) GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, _ ...string) (*model.PageResult[any], error) { + d.GetModel() + qd := gormx.NewQuery(d.GetModel()). + Eq("id", condition.Id). + Eq0("db_instance_id", condition.DbInstanceId). + Eq0("repeated", condition.Repeated). + In0("db_name", condition.InDbNames). + Like("db_name", condition.DbName) + return gormx.PageQuery(qd, pageParam, toEntity) +} + +// AddJob 添加数据库任务 +func (d *dbRestoreRepoImpl) AddJob(ctx context.Context, jobs any) error { + return addJob[*entity.DbRestore](ctx, d.dbJobBaseImpl, jobs) +} + +func (d *dbRestoreRepoImpl) UpdateEnabled(_ context.Context, jobId uint64, enabled bool) error { + cond := map[string]any{ + "id": jobId, + } + return d.Updates(cond, map[string]any{ + "enabled": enabled, + }) +} diff --git a/server/internal/tag/api/team.go b/server/internal/tag/api/team.go index bb8e28de..6be11fed 100644 --- a/server/internal/tag/api/team.go +++ b/server/internal/tag/api/team.go @@ -23,8 +23,9 @@ type Team struct { } func (p *Team) GetTeams(rc *req.Ctx) { + queryCond, page := ginx.BindQueryAndPage(rc.GinCtx, new(entity.TeamQuery)) teams := &[]entity.Team{} - res, err := p.TeamApp.GetPageList(&entity.Team{}, ginx.GetPageParam(rc.GinCtx), teams) + res, err := p.TeamApp.GetPageList(queryCond, page, teams) biz.ErrIsNil(err) rc.ResData = res } diff --git a/server/internal/tag/application/team.go b/server/internal/tag/application/team.go index 053b06b6..dc962263 100644 --- a/server/internal/tag/application/team.go +++ b/server/internal/tag/application/team.go @@ -13,7 +13,7 @@ import ( type Team interface { // 分页获取项目团队信息列表 - GetPageList(condition *entity.Team, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) + GetPageList(condition *entity.TeamQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) Save(ctx context.Context, team *entity.Team) error @@ -55,7 +55,7 @@ type teamAppImpl struct { tagTreeTeamRepo repository.TagTreeTeam } -func (p *teamAppImpl) GetPageList(condition *entity.Team, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) { +func (p *teamAppImpl) GetPageList(condition *entity.TeamQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) { return p.teamRepo.GetPageList(condition, pageParam, toEntity, orderBy...) } diff --git a/server/internal/tag/domain/entity/query.go b/server/internal/tag/domain/entity/query.go index d392b2b7..f59b1a05 100644 --- a/server/internal/tag/domain/entity/query.go +++ b/server/internal/tag/domain/entity/query.go @@ -26,3 +26,9 @@ type TagResourceQuery struct { TagPathLike string // 标签路径模糊查询 TagPathLikes []string } + +type TeamQuery struct { + model.Model + + Name string `json:"name" form:"name"` // 团队名称 +} diff --git a/server/internal/tag/domain/repository/team.go b/server/internal/tag/domain/repository/team.go index 3deca5f1..589c0968 100644 --- a/server/internal/tag/domain/repository/team.go +++ b/server/internal/tag/domain/repository/team.go @@ -9,5 +9,5 @@ import ( type Team interface { base.Repo[*entity.Team] - GetPageList(condition *entity.Team, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) + GetPageList(condition *entity.TeamQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) } diff --git a/server/internal/tag/infrastructure/persistence/team.go b/server/internal/tag/infrastructure/persistence/team.go index 5ca2775d..16e3c3b4 100644 --- a/server/internal/tag/infrastructure/persistence/team.go +++ b/server/internal/tag/infrastructure/persistence/team.go @@ -13,10 +13,12 @@ type teamRepoImpl struct { } func newTeamRepo() repository.Team { - return &teamRepoImpl{base.RepoImpl[*entity.Team]{M: new(entity.Team)}} + return &teamRepoImpl{} } -func (p *teamRepoImpl) GetPageList(condition *entity.Team, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) { - qd := gormx.NewQuery(condition).WithCondModel(condition).WithOrderBy(orderBy...) +func (p *teamRepoImpl) GetPageList(condition *entity.TeamQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) { + qd := gormx.NewQuery(p.GetModel()). + Like("name", condition.Name). + WithOrderBy() return gormx.PageQuery(qd, pageParam, toEntity) } diff --git a/server/pkg/base/repo.go b/server/pkg/base/repo.go index 48915e71..e0cc273b 100644 --- a/server/pkg/base/repo.go +++ b/server/pkg/base/repo.go @@ -22,10 +22,10 @@ type Repo[T model.ModelI] interface { InsertWithDb(ctx context.Context, db *gorm.DB, e T) error // 批量新增实体 - BatchInsert(ctx context.Context, models any) error + BatchInsert(ctx context.Context, models []T) error // 使用指定gorm db执行,主要用于事务执行 - BatchInsertWithDb(ctx context.Context, db *gorm.DB, models any) error + BatchInsertWithDb(ctx context.Context, db *gorm.DB, models []T) error // 根据实体id更新实体信息 UpdateById(ctx context.Context, e T, columns ...string) error @@ -93,19 +93,19 @@ func (br *RepoImpl[T]) InsertWithDb(ctx context.Context, db *gorm.DB, e T) error return gormx.InsertWithDb(db, br.fillBaseInfo(ctx, e)) } -func (br *RepoImpl[T]) BatchInsert(ctx context.Context, es any) error { +func (br *RepoImpl[T]) BatchInsert(ctx context.Context, es []T) error { if db := contextx.GetDb(ctx); db != nil { return br.BatchInsertWithDb(ctx, db, es) } - for _, e := range es.([]T) { + for _, e := range es { br.fillBaseInfo(ctx, e) } return gormx.BatchInsert[T](es) } // 使用指定gorm db执行,主要用于事务执行 -func (br *RepoImpl[T]) BatchInsertWithDb(ctx context.Context, db *gorm.DB, es any) error { - for _, e := range es.([]T) { +func (br *RepoImpl[T]) BatchInsertWithDb(ctx context.Context, db *gorm.DB, es []T) error { + for _, e := range es { br.fillBaseInfo(ctx, e) } return gormx.BatchInsertWithDb[T](db, es) diff --git a/server/pkg/gormx/gormx.go b/server/pkg/gormx/gormx.go index e59f3a83..61a8eed5 100644 --- a/server/pkg/gormx/gormx.go +++ b/server/pkg/gormx/gormx.go @@ -135,13 +135,13 @@ func InsertWithDb(db *gorm.DB, model any) error { } // 批量插入 -func BatchInsert[T any](models any) error { +func BatchInsert[T any](models []T) error { return BatchInsertWithDb[T](global.Db, models) } // 批量插入 -func BatchInsertWithDb[T any](db *gorm.DB, models any) error { - return db.CreateInBatches(models, len(models.([]T))).Error +func BatchInsertWithDb[T any](db *gorm.DB, models []T) error { + return db.CreateInBatches(models, len(models)).Error } // 根据id更新model,更新字段为model中不为空的值,即int类型不为0,ptr类型不为nil这类字段值 diff --git a/server/pkg/runner/delay_queue.go b/server/pkg/runner/delay_queue.go index 9ae4b661..89484b70 100644 --- a/server/pkg/runner/delay_queue.go +++ b/server/pkg/runner/delay_queue.go @@ -25,6 +25,35 @@ type Delayable interface { GetKey() string } +var _ Delayable = (*wrapper[Job])(nil) + +type wrapper[T Job] struct { + key string + deadline time.Time + removed bool + status JobStatus + job T +} + +func newWrapper[T Job](job T) *wrapper[T] { + return &wrapper[T]{ + key: job.GetKey(), + job: job, + } +} + +func (d *wrapper[T]) GetDeadline() time.Time { + return d.deadline +} + +func (d *wrapper[T]) GetKey() string { + return d.key +} + +func (d *wrapper[T]) Payload() T { + return d.job +} + func NewDelayQueue[T Delayable](cap int) *DelayQueue[T] { singleDequeue := make(chan struct{}, 1) singleDequeue <- struct{}{} diff --git a/server/pkg/runner/priority_queue.go b/server/pkg/runner/priority_queue.go index cdb14074..68a3a1e7 100644 --- a/server/pkg/runner/priority_queue.go +++ b/server/pkg/runner/priority_queue.go @@ -1,11 +1,5 @@ package runner -//var ( -// false = errors.New("queue: 队列已满") -// false = errors.New("queue: 队列为空") -// false = errors.New("queue: 元素未找到") -//) - // PriorityQueue 是一个基于小顶堆的优先队列 // 当capacity <= 0时,为无界队列,切片容量会动态扩缩容 // 当capacity > 0 时,为有界队列,初始化后就固定容量,不会扩缩容 diff --git a/server/pkg/runner/runner.go b/server/pkg/runner/runner.go index 0e5d153d..44187b97 100644 --- a/server/pkg/runner/runner.go +++ b/server/pkg/runner/runner.go @@ -2,43 +2,43 @@ package runner import ( "context" + "errors" "fmt" "github.com/emirpasic/gods/maps/linkedhashmap" "mayfly-go/pkg/logx" - "mayfly-go/pkg/utils/timex" "sync" "time" ) +var ( + ErrNotFound = errors.New("job not found") + ErrExist = errors.New("job already exists") + ErrFinished = errors.New("job already finished") +) + type JobKey = string -type RunFunc func(ctx context.Context) -type NextFunc func() (Job, bool) -type RunnableFunc func(next NextFunc) bool +type RunJobFunc[T Job] func(ctx context.Context, job T) +type NextJobFunc[T Job] func() (T, bool) +type RunnableJobFunc[T Job] func(job T, next NextJobFunc[T]) bool +type ScheduleJobFunc[T Job] func(job T) (deadline time.Time, err error) type JobStatus int const ( JobUnknown JobStatus = iota - JobDelay + JobDelaying JobWaiting JobRunning - JobRemoved ) type Job interface { GetKey() JobKey - GetStatus() JobStatus - SetStatus(status JobStatus) - Run(ctx context.Context) - Runnable(next NextFunc) bool - GetDeadline() time.Time - Schedule() bool Update(job Job) } type iterator[T Job] struct { index int - data []T + data []*wrapper[T] zero T } @@ -51,19 +51,18 @@ func (iter *iterator[T]) Next() (T, bool) { return iter.zero, false } iter.index++ - return iter.data[iter.index], true + return iter.data[iter.index].job, true } type array[T Job] struct { size int - data []T - zero T + data []*wrapper[T] } func newArray[T Job](size int) *array[T] { return &array[T]{ size: size, - data: make([]T, 0, size), + data: make([]*wrapper[T], 0, size), } } @@ -78,7 +77,7 @@ func (a *array[T]) Full() bool { return len(a.data) >= a.size } -func (a *array[T]) Append(job T) bool { +func (a *array[T]) Append(job *wrapper[T]) bool { if len(a.data) >= a.size { return false } @@ -86,20 +85,20 @@ func (a *array[T]) Append(job T) bool { return true } -func (a *array[T]) Get(key JobKey) (T, bool) { +func (a *array[T]) Get(key JobKey) (*wrapper[T], bool) { for _, job := range a.data { if key == job.GetKey() { return job, true } } - return a.zero, false + return nil, false } func (a *array[T]) Remove(key JobKey) { length := len(a.data) for i, elm := range a.data { if key == elm.GetKey() { - a.data[i], a.data[length-1] = a.data[length-1], a.zero + a.data[i], a.data[length-1] = a.data[length-1], nil a.data = a.data[:length-1] return } @@ -107,47 +106,76 @@ func (a *array[T]) Remove(key JobKey) { } type Runner[T Job] struct { - maxRunning int - waiting *linkedhashmap.Map - running *array[T] - runnable func(job T, iterateRunning func() (T, bool)) bool - mutex sync.Mutex - wg sync.WaitGroup - context context.Context - cancel context.CancelFunc - zero T - signal chan struct{} - all map[string]T - delayQueue *DelayQueue[T] + maxRunning int + waiting *linkedhashmap.Map + running *array[T] + runJob RunJobFunc[T] + runnableJob RunnableJobFunc[T] + scheduleJob ScheduleJobFunc[T] + mutex sync.Mutex + wg sync.WaitGroup + context context.Context + cancel context.CancelFunc + zero T + signal chan struct{} + all map[JobKey]*wrapper[T] + delayQueue *DelayQueue[*wrapper[T]] } -func NewRunner[T Job](maxRunning int) *Runner[T] { +type Opt[T Job] func(runner *Runner[T]) + +func WithRunnableJob[T Job](runnableJob RunnableJobFunc[T]) Opt[T] { + return func(runner *Runner[T]) { + runner.runnableJob = runnableJob + } +} + +func WithScheduleJob[T Job](scheduleJob ScheduleJobFunc[T]) Opt[T] { + return func(runner *Runner[T]) { + runner.scheduleJob = scheduleJob + } +} + +func NewRunner[T Job](maxRunning int, runJob RunJobFunc[T], opts ...Opt[T]) *Runner[T] { ctx, cancel := context.WithCancel(context.Background()) runner := &Runner[T]{ maxRunning: maxRunning, - all: make(map[string]T, maxRunning), + all: make(map[string]*wrapper[T], maxRunning), waiting: linkedhashmap.New(), running: newArray[T](maxRunning), context: ctx, cancel: cancel, signal: make(chan struct{}, 1), - delayQueue: NewDelayQueue[T](0), + delayQueue: NewDelayQueue[*wrapper[T]](0), } + runner.runJob = runJob + for _, opt := range opts { + opt(runner) + } + if runner.runnableJob == nil { + runner.runnableJob = func(job T, _ NextJobFunc[T]) bool { + return true + } + } + runner.wg.Add(maxRunning + 1) for i := 0; i < maxRunning; i++ { go runner.run() } go func() { defer runner.wg.Done() - timex.SleepWithContext(runner.context, time.Second*10) for runner.context.Err() == nil { - job, ok := runner.delayQueue.Dequeue(ctx) + wrap, ok := runner.delayQueue.Dequeue(ctx) if !ok { continue } runner.mutex.Lock() - runner.waiting.Put(job.GetKey(), job) - job.SetStatus(JobWaiting) + if old, ok := runner.all[wrap.key]; !ok || wrap != old { + runner.mutex.Unlock() + continue + } + runner.waiting.Put(wrap.key, wrap) + wrap.status = JobWaiting runner.trigger() runner.mutex.Unlock() } @@ -166,144 +194,155 @@ func (r *Runner[T]) run() { for r.context.Err() == nil { select { case <-r.signal: - job, ok := r.pickRunnable() + wrap, ok := r.pickRunnableJob() if !ok { continue } - r.doRun(job) - r.afterRun(job) + r.doRun(wrap) + r.afterRun(wrap) case <-r.context.Done(): } } } -func (r *Runner[T]) doRun(job T) { +func (r *Runner[T]) doRun(wrap *wrapper[T]) { defer func() { if err := recover(); err != nil { logx.Error(fmt.Sprintf("failed to run job: %v", err)) } }() - job.Run(r.context) + r.runJob(r.context, wrap.job) } -func (r *Runner[T]) afterRun(job T) { +func (r *Runner[T]) afterRun(wrap *wrapper[T]) { r.mutex.Lock() defer r.mutex.Unlock() - key := job.GetKey() - r.running.Remove(key) + r.running.Remove(wrap.key) + delete(r.all, wrap.key) + wrap.status = JobUnknown r.trigger() - switch job.GetStatus() { - case JobRunning: - r.schedule(r.context, job) - case JobRemoved: - delete(r.all, key) - default: - panic(fmt.Sprintf("invalid job status %v occurred after run", job.GetStatus())) + if wrap.removed { + return } + deadline, err := r.doScheduleJob(wrap.job, true) + if err != nil { + return + } + _ = r.schedule(r.context, deadline, wrap.job) } -func (r *Runner[T]) pickRunnable() (T, bool) { +func (r *Runner[T]) doScheduleJob(job T, finished bool) (time.Time, error) { + if r.scheduleJob == nil { + if finished { + return time.Time{}, ErrFinished + } + return time.Now(), nil + } + return r.scheduleJob(job) +} + +func (r *Runner[T]) pickRunnableJob() (*wrapper[T], bool) { r.mutex.Lock() defer r.mutex.Unlock() iter := r.running.Iterator() - var runnable T + var runnable *wrapper[T] ok := r.waiting.Any(func(key interface{}, value interface{}) bool { - job := value.(T) + wrap := value.(*wrapper[T]) iter.Begin() - if job.Runnable(func() (Job, bool) { return iter.Next() }) { + if r.runnableJob(wrap.job, iter.Next) { if r.running.Full() { return false } r.waiting.Remove(key) - r.running.Append(job) - job.SetStatus(JobRunning) + r.running.Append(wrap) + wrap.status = JobRunning if !r.running.Full() && !r.waiting.Empty() { r.trigger() } - runnable = job + runnable = wrap return true } return false }) if !ok { - return r.zero, false + return nil, false } return runnable, true } -func (r *Runner[T]) schedule(ctx context.Context, job T) { - if !job.Schedule() { - delete(r.all, job.GetKey()) - job.SetStatus(JobRemoved) - return +func (r *Runner[T]) schedule(ctx context.Context, deadline time.Time, job T) error { + wrap := newWrapper(job) + wrap.deadline = deadline + if wrap.deadline.After(time.Now()) { + r.delayQueue.Enqueue(ctx, wrap) + wrap.status = JobDelaying + } else { + r.waiting.Put(wrap.key, wrap) + wrap.status = JobWaiting + r.trigger() } - r.delayQueue.Enqueue(ctx, job) - job.SetStatus(JobDelay) + r.all[wrap.key] = wrap + return nil } -//func (r *Runner[T]) Schedule(ctx context.Context, job T) { -// r.mutex.Lock() -// defer r.mutex.Unlock() -// -// switch job.GetStatus() { -// case JobUnknown: -// case JobDelay: -// r.delayQueue.Remove(ctx, job.GetKey()) -// case JobWaiting: -// r.waiting.Remove(job) -// case JobRunning: -// // 标记为 removed, 任务执行完成后再删除 -// return -// case JobRemoved: -// return -// } -// r.schedule(ctx, job) -//} - func (r *Runner[T]) Add(ctx context.Context, job T) error { r.mutex.Lock() defer r.mutex.Unlock() if _, ok := r.all[job.GetKey()]; ok { - return nil + return ErrExist } - r.schedule(ctx, job) - return nil + deadline, err := r.doScheduleJob(job, false) + if err != nil { + return err + } + return r.schedule(ctx, deadline, job) } func (r *Runner[T]) UpdateOrAdd(ctx context.Context, job T) error { r.mutex.Lock() defer r.mutex.Unlock() - if old, ok := r.all[job.GetKey()]; ok { - old.Update(job) - job = old + wrap, ok := r.all[job.GetKey()] + if ok { + wrap.job.Update(job) + switch wrap.status { + case JobDelaying: + r.delayQueue.Remove(ctx, wrap.key) + delete(r.all, wrap.key) + case JobWaiting: + r.waiting.Remove(wrap.key) + delete(r.all, wrap.key) + case JobRunning: + return nil + default: + } } - r.schedule(ctx, job) - return nil + deadline, err := r.doScheduleJob(job, false) + if err != nil { + return err + } + return r.schedule(ctx, deadline, wrap.job) } func (r *Runner[T]) StartNow(ctx context.Context, job T) error { r.mutex.Lock() defer r.mutex.Unlock() - key := job.GetKey() - if old, ok := r.all[key]; ok { - job = old - if job.GetStatus() == JobDelay { - r.delayQueue.Remove(ctx, key) - r.waiting.Put(key, job) - r.trigger() + if wrap, ok := r.all[job.GetKey()]; ok { + switch wrap.status { + case JobDelaying: + r.delayQueue.Remove(ctx, wrap.key) + delete(r.all, wrap.key) + case JobWaiting, JobRunning: + return nil + default: } - return nil } - r.all[key] = job - r.waiting.Put(key, job) - r.trigger() - return nil + return r.schedule(ctx, time.Now(), job) } func (r *Runner[T]) trigger() { @@ -317,23 +356,21 @@ func (r *Runner[T]) Remove(ctx context.Context, key JobKey) error { r.mutex.Lock() defer r.mutex.Unlock() - job, ok := r.all[key] + wrap, ok := r.all[key] if !ok { - return nil + return ErrNotFound } - switch job.GetStatus() { - case JobUnknown: - panic(fmt.Sprintf("invalid job status %v occurred after added", job.GetStatus())) - case JobDelay: + switch wrap.status { + case JobDelaying: r.delayQueue.Remove(ctx, key) + delete(r.all, key) case JobWaiting: r.waiting.Remove(key) + delete(r.all, key) case JobRunning: - // 标记为 removed, 任务执行完成后再删除 - case JobRemoved: - return nil + // 统一标记为 removed, 待任务执行完成后再删除 + wrap.removed = true + default: } - delete(r.all, key) - job.SetStatus(JobRemoved) return nil } diff --git a/server/pkg/runner/runner_test.go b/server/pkg/runner/runner_test.go index c5443037..2eedc8b0 100644 --- a/server/pkg/runner/runner_test.go +++ b/server/pkg/runner/runner_test.go @@ -2,6 +2,7 @@ package runner import ( "context" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "mayfly-go/pkg/utils/timex" "sync" @@ -11,72 +12,35 @@ import ( var _ Job = &testJob{} -func newTestJob(key string, runTime time.Duration) *testJob { +func newTestJob(key string) *testJob { return &testJob{ - deadline: time.Now(), - Key: key, - run: func(ctx context.Context) { - timex.SleepWithContext(ctx, runTime) - }, + Key: key, } } type testJob struct { - run RunFunc - Key JobKey - status JobStatus - ran bool - deadline time.Time + Key JobKey + status int } -func (t *testJob) Update(_ Job) { -} - -func (t *testJob) GetDeadline() time.Time { - return t.deadline -} - -func (t *testJob) Schedule() bool { - return !t.ran -} - -func (t *testJob) Run(ctx context.Context) { - if t.run == nil { - return - } - t.run(ctx) - t.ran = true -} - -func (t *testJob) Runnable(_ NextFunc) bool { - return true -} +func (t *testJob) Update(_ Job) {} func (t *testJob) GetKey() JobKey { return t.Key } -func (t *testJob) GetStatus() JobStatus { - return t.status -} - -func (t *testJob) SetStatus(status JobStatus) { - t.status = status -} - func TestRunner_Close(t *testing.T) { - runner := NewRunner[*testJob](1) signal := make(chan struct{}, 1) waiting := sync.WaitGroup{} waiting.Add(1) + runner := NewRunner[*testJob](1, func(ctx context.Context, job *testJob) { + waiting.Done() + timex.SleepWithContext(ctx, time.Hour) + signal <- struct{}{} + }) go func() { job := &testJob{ Key: "close", - run: func(ctx context.Context) { - waiting.Done() - timex.SleepWithContext(ctx, time.Hour) - signal <- struct{}{} - }, } _ = runner.Add(context.Background(), job) }() @@ -95,55 +59,65 @@ func TestRunner_AddJob(t *testing.T) { type testCase struct { name string job *testJob - want bool + want error } testCases := []testCase{ { name: "first job", - job: newTestJob("single", time.Hour), - want: true, + job: newTestJob("single"), + want: nil, }, { name: "second job", - job: newTestJob("dual", time.Hour), - want: true, - }, - { - name: "non repetitive job", - job: newTestJob("single", time.Hour), - want: true, + job: newTestJob("dual"), + want: nil, }, { name: "repetitive job", - job: newTestJob("dual", time.Hour), - want: true, + job: newTestJob("dual"), + want: ErrExist, }, } - runner := NewRunner[*testJob](1) + runner := NewRunner[*testJob](1, func(ctx context.Context, job *testJob) { + timex.SleepWithContext(ctx, time.Hour) + }) defer runner.Close() for _, tc := range testCases { err := runner.Add(context.Background(), tc.job) + if tc.want != nil { + require.ErrorIs(t, err, tc.want) + continue + } require.NoError(t, err) } } func TestJob_UpdateStatus(t *testing.T) { const d = time.Millisecond * 20 - runner := NewRunner[*testJob](1) - running := newTestJob("running", d*2) - waiting := newTestJob("waiting", d*2) - _ = runner.Add(context.Background(), running) - _ = runner.Add(context.Background(), waiting) + const ( + unknown = iota + running + finished + ) + runner := NewRunner[*testJob](1, func(ctx context.Context, job *testJob) { + job.status = running + timex.SleepWithContext(ctx, d*2) + job.status = finished + }) + first := newTestJob("first") + second := newTestJob("second") + _ = runner.Add(context.Background(), first) + _ = runner.Add(context.Background(), second) time.Sleep(d) - require.Equal(t, JobRunning, running.status) - require.Equal(t, JobWaiting, waiting.status) + assert.Equal(t, running, first.status) + assert.Equal(t, unknown, second.status) time.Sleep(d * 2) - require.Equal(t, JobRemoved, running.status) - require.Equal(t, JobRunning, waiting.status) + assert.Equal(t, finished, first.status) + assert.Equal(t, running, second.status) time.Sleep(d * 2) - require.Equal(t, JobRemoved, running.status) - require.Equal(t, JobRemoved, waiting.status) + assert.Equal(t, finished, first.status) + assert.Equal(t, finished, second.status) }