!84 fix: 修复数据库备份与恢复问题

* refactor dbScheduler
* fix: 按团队名称检索团队
* feat: 创建数据库资源时支持全选数据库
* refactor dbScheduler
* fix: 修复数据库备份与恢复问题
This commit is contained in:
kanzihuang
2024-01-17 08:37:22 +00:00
committed by Coder慌
parent cc3981d99c
commit 94da6df33e
35 changed files with 846 additions and 609 deletions

View File

@@ -4,18 +4,21 @@
<el-form :model="state.form" ref="backupForm" label-width="auto" :rules="rules">
<el-form-item prop="dbNames" label="数据库名称">
<el-select
@change="changeDatabase"
v-model="state.selectedDbNames"
v-model="state.dbNamesSelected"
multiple
clearable
collapse-tags
collapse-tags-tooltip
filterable
:disabled="state.editOrCreate"
:filter-method="filterDbNames"
placeholder="数据库名称"
style="width: 100%"
>
<el-option v-for="db in state.dbNamesWithoutBackup" :key="db" :label="db" :value="db" />
<template #header>
<el-checkbox v-model="checkAllDbNames" :indeterminate="indeterminateDbNames" @change="handleCheckAll"> 全选 </el-checkbox>
</template>
<el-option v-for="db in state.dbNamesFiltered" :key="db" :label="db" :value="db" />
</el-select>
</el-form-item>
@@ -41,9 +44,10 @@
</template>
<script lang="ts" setup>
import { reactive, ref, watch } from 'vue';
import { reactive, ref, toRefs, watch } from 'vue';
import { dbApi } from './api';
import { ElMessage } from 'element-plus';
import type { CheckboxValueType } from 'element-plus';
const props = defineProps({
data: {
@@ -96,38 +100,38 @@ const state = reactive({
form: {
id: 0,
dbId: 0,
dbNames: String,
dbNames: '',
name: null as any,
intervalDay: null,
startTime: null as any,
repeated: null as any,
},
btnLoading: false,
selectedDbNames: [] as any,
dbNamesSelected: [] as any,
dbNamesWithoutBackup: [] as any,
dbNamesFiltered: [] as any,
filterString: '',
editOrCreate: false,
});
const { dbNamesSelected, dbNamesWithoutBackup } = toRefs(state);
const checkAllDbNames = ref(false);
const indeterminateDbNames = ref(false);
watch(visible, (newValue: any) => {
if (newValue) {
init(props.data);
}
});
/**
* 改变表单中的数据库字段,方便表单错误提示。如全部删光,可提示请添加数据库
*/
const changeDatabase = () => {
state.form.dbNames = state.selectedDbNames.length == 0 ? '' : state.selectedDbNames.join(' ');
};
const init = (data: any) => {
state.selectedDbNames = [];
state.dbNamesSelected = [];
state.form.dbId = props.dbId;
if (data) {
state.editOrCreate = true;
state.dbNamesWithoutBackup = [data.dbName];
state.selectedDbNames = [data.dbName];
state.dbNamesSelected = [data.dbName];
state.form.id = data.id;
state.form.dbNames = data.dbName;
state.form.name = data.name;
@@ -179,5 +183,50 @@ const cancel = () => {
visible.value = false;
emit('cancel');
};
/**
 * Sync the "select all" header checkbox (checked / indeterminate) with the
 * subset of currently selected names that match the active filter string.
 */
const checkDbSelect = (val: string[]) => {
    const matchedCount = val.filter((name: string) => name.includes(state.filterString)).length;
    if (matchedCount === 0) {
        // nothing selected in the filtered view
        checkAllDbNames.value = false;
        indeterminateDbNames.value = false;
    } else if (matchedCount === state.dbNamesFiltered.length) {
        // every visible option is selected
        checkAllDbNames.value = true;
        indeterminateDbNames.value = false;
    } else {
        // partial selection
        indeterminateDbNames.value = true;
    }
};
watch(dbNamesSelected, (val: string[]) => {
checkDbSelect(val);
state.form.dbNames = val.join(' ');
});
watch(dbNamesWithoutBackup, (val: string[]) => {
state.dbNamesFiltered = val.map((dbName: string) => dbName);
});
/**
 * Select-all toggle for the database select: selections that fall outside
 * the current filter are preserved; every currently filtered option is
 * added (checked) or removed (unchecked).
 */
const handleCheckAll = (val: CheckboxValueType) => {
    const keptOutsideFilter = state.dbNamesSelected.filter(
        (name: string) => !state.dbNamesFiltered.includes(name)
    );
    state.dbNamesSelected = val ? keptOutsideFilter.concat(state.dbNamesFiltered) : keptOutsideFilter;
};
/**
 * Custom el-select filter-method: narrow the option list to names containing
 * the typed string, remember the string, and refresh the select-all state
 * for the new filtered view.
 */
const filterDbNames = (filterString: string) => {
    state.filterString = filterString;
    state.dbNamesFiltered = state.dbNamesWithoutBackup.filter((name: string) =>
        name.includes(filterString)
    );
    checkDbSelect(state.dbNamesSelected);
};
</script>
<style lang="scss"></style>

View File

@@ -52,20 +52,23 @@
<el-input v-model.trim="form.name" placeholder="请输入数据库别名" auto-complete="off"></el-input>
</el-form-item>
<el-form-item prop="database" label="数据库名" required>
<el-form-item prop="database" label="数据库名">
<el-select
@change="changeDatabase"
v-model="databaseList"
v-model="dbNamesSelected"
multiple
clearable
collapse-tags
collapse-tags-tooltip
filterable
:filter-method="filterDbNames"
allow-create
placeholder="请确保数据库实例信息填写完整后获取库名"
style="width: 100%"
>
<el-option v-for="db in allDatabases" :key="db" :label="db" :value="db" />
<template #header>
<el-checkbox v-model="checkAllDbNames" :indeterminate="indeterminateDbNames" @change="handleCheckAll"> 全选 </el-checkbox>
</template>
<el-option v-for="db in state.dbNamesFiltered" :key="db" :label="db" :value="db" />
</el-select>
</el-form-item>
@@ -90,6 +93,7 @@ import { dbApi } from './api';
import { ElMessage } from 'element-plus';
import TagTreeSelect from '../component/TagTreeSelect.vue';
import { TagResourceTypeEnum } from '@/common/commonEnum';
import type { CheckboxValueType } from 'element-plus';
const props = defineProps({
visible: {
@@ -139,13 +143,18 @@ const rules = {
],
};
const checkAllDbNames = ref(false);
const indeterminateDbNames = ref(false);
const dbForm: any = ref(null);
const tagSelectRef: any = ref(null);
const state = reactive({
dialogVisible: false,
allDatabases: [] as any,
databaseList: [] as any,
dbNamesSelected: [] as any,
dbNamesFiltered: [] as any,
filterString: '',
form: {
id: null,
tagId: [],
@@ -158,7 +167,7 @@ const state = reactive({
instances: [] as any,
});
const { dialogVisible, allDatabases, form, databaseList } = toRefs(state);
const { dialogVisible, allDatabases, form, dbNamesSelected } = toRefs(state);
const { isFetching: saveBtnLoading, execute: saveDbExec } = dbApi.saveDb.useApi(form);
@@ -171,25 +180,18 @@ watch(props, async (newValue: any) => {
state.form = { ...newValue.db };
// 将数据库名使用空格切割,获取所有数据库列表
state.databaseList = newValue.db.database.split(' ');
state.dbNamesSelected = newValue.db.database.split(' ');
} else {
state.form = {} as any;
state.databaseList = [];
state.dbNamesSelected = [];
}
});
const changeInstance = () => {
state.databaseList = [];
state.dbNamesSelected = [];
getAllDatabase();
};
/**
* 改变表单中的数据库字段,方便表单错误提示。如全部删光,可提示请添加数据库
*/
const changeDatabase = () => {
state.form.database = state.databaseList.length == 0 ? '' : state.databaseList.join(' ');
};
const getAllDatabase = async () => {
if (state.form.instanceId > 0) {
state.allDatabases = await dbApi.getAllDatabase.request({ instanceId: state.form.instanceId });
@@ -210,7 +212,7 @@ const getInstances = async (instanceName: string = '', id = 0) => {
const open = async () => {
if (state.form.instanceId) {
// 根据id获取因为需要回显实例名称
getInstances('', state.form.instanceId);
await getInstances('', state.form.instanceId);
}
await getAllDatabase();
};
@@ -230,7 +232,7 @@ const btnOk = async () => {
};
const resetInputDb = () => {
state.databaseList = [];
state.dbNamesSelected = [];
state.allDatabases = [];
state.instances = [];
};
@@ -242,5 +244,62 @@ const cancel = () => {
resetInputDb();
}, 500);
};
/**
 * Recompute the header checkbox from the selections visible under the
 * active filter: unchecked when none match, checked when all filtered
 * options are selected, indeterminate otherwise.
 */
const checkDbSelect = (val: string[]) => {
    const hits = val.filter((name: string) => name.includes(state.filterString));
    if (hits.length === 0) {
        checkAllDbNames.value = false;
        indeterminateDbNames.value = false;
        return;
    }
    if (hits.length === state.dbNamesFiltered.length) {
        checkAllDbNames.value = true;
        indeterminateDbNames.value = false;
        return;
    }
    indeterminateDbNames.value = true;
};
watch(dbNamesSelected, (val: string[]) => {
checkDbSelect(val);
state.form.database = val.join(' ');
});
watch(allDatabases, (val: string[]) => {
state.dbNamesFiltered = val.map((dbName: string) => dbName);
});
/**
 * Apply the select-all checkbox: keep any selected names that are not in
 * the filtered option list, then either append all filtered options
 * (checked) or drop them (unchecked).
 */
const handleCheckAll = (val: CheckboxValueType) => {
    const outsideFilter = state.dbNamesSelected.filter(
        (name: string) => !state.dbNamesFiltered.includes(name)
    );
    if (val) {
        state.dbNamesSelected = outsideFilter.concat(state.dbNamesFiltered);
    } else {
        state.dbNamesSelected = outsideFilter;
    }
};
/**
 * Custom filter for the database-name select (allow-create is enabled).
 * User-created names (selected but absent from allDatabases) always stay in
 * the option list; the rest are narrowed by substring match, and the typed
 * string itself is pinned to the top so it can be picked as a new entry.
 */
const filterDbNames = (filterString: string) => {
// names the user typed manually (allow-create) — not in the instance's database list
const dbNamesCreated = state.dbNamesSelected.filter((dbName: string) => {
return !state.allDatabases.includes(dbName);
});
if (filterString.length === 0) {
// empty filter: show created names first, then all server databases
state.dbNamesFiltered = dbNamesCreated.concat(state.allDatabases);
checkDbSelect(state.dbNamesSelected);
return;
}
state.dbNamesFiltered = dbNamesCreated.concat(state.allDatabases).filter((dbName: string) => {
// drop an exact match here — it is re-inserted at the top below, avoiding a duplicate
if (dbName == filterString) {
return false;
}
return dbName.includes(filterString);
});
// pin the typed string first so the "create new" option is always offered
state.dbNamesFiltered.unshift(filterString);
state.filterString = filterString;
checkDbSelect(state.dbNamesSelected);
};
</script>
<style lang="scss"></style>

View File

@@ -198,11 +198,11 @@ const perms = {
const searchItems = [getTagPathSearchItem(TagResourceTypeEnum.Db.value), SearchItem.slot('instanceId', '实例', 'instanceSelect')];
const columns = ref([
TableColumn.new('instanceName', '实例名'),
TableColumn.new('name', '名'),
TableColumn.new('type', '类型').isSlot().setAddWidth(-15).alignCenter(),
TableColumn.new('instanceName', '实例名'),
TableColumn.new('host', 'ip:port').isSlot().setAddWidth(40),
TableColumn.new('username', 'username'),
TableColumn.new('name', '名称'),
TableColumn.new('tagPath', '关联标签').isSlot().setAddWidth(10).alignCenter(),
TableColumn.new('remark', '备注'),
]);

View File

@@ -155,7 +155,7 @@ const state = reactive({
pointInTime: null as any,
},
btnLoading: false,
selectedDbNames: [] as any,
dbNamesSelected: [] as any,
dbNamesWithoutRestore: [] as any,
editOrCreate: false,
histories: [] as any,
@@ -189,12 +189,12 @@ const changeHistory = async () => {
};
const init = async (data: any) => {
state.selectedDbNames = [];
state.dbNamesSelected = [];
state.form.dbId = props.dbId;
if (data) {
state.editOrCreate = true;
state.dbNamesWithoutRestore = [data.dbName];
state.selectedDbNames = [data.dbName];
state.dbNamesSelected = [data.dbName];
state.form.id = data.id;
state.form.dbName = data.dbName;
state.form.intervalDay = data.intervalDay;

View File

@@ -54,9 +54,14 @@ func (d *DbBackup) Create(rc *req.Ctx) {
jobs := make([]*entity.DbBackup, 0, len(dbNames))
for _, dbName := range dbNames {
job := &entity.DbBackup{
DbJobBaseImpl: entity.NewDbBJobBase(db.InstanceId, dbName, entity.DbJobTypeBackup, true, backupForm.Repeated, backupForm.StartTime, backupForm.Interval),
DbJobBaseImpl: entity.NewDbBJobBase(db.InstanceId, entity.DbJobTypeBackup),
Enabled: true,
Repeated: backupForm.Repeated,
StartTime: backupForm.StartTime,
Interval: backupForm.Interval,
Name: backupForm.Name,
}
job.DbName = dbName
jobs = append(jobs, job)
}
biz.ErrIsNilAppendErr(d.DbBackupApp.Create(rc.MetaCtx, jobs), "添加数据库备份任务失败: %v")

View File

@@ -48,12 +48,17 @@ func (d *DbRestore) Create(rc *req.Ctx) {
biz.ErrIsNilAppendErr(err, "获取数据库信息失败: %v")
job := &entity.DbRestore{
DbJobBaseImpl: entity.NewDbBJobBase(db.InstanceId, restoreForm.DbName, entity.DbJobTypeRestore, true, restoreForm.Repeated, restoreForm.StartTime, restoreForm.Interval),
DbJobBaseImpl: entity.NewDbBJobBase(db.InstanceId, entity.DbJobTypeRestore),
Enabled: true,
Repeated: restoreForm.Repeated,
StartTime: restoreForm.StartTime,
Interval: restoreForm.Interval,
PointInTime: restoreForm.PointInTime,
DbBackupId: restoreForm.DbBackupId,
DbBackupHistoryId: restoreForm.DbBackupHistoryId,
DbBackupHistoryName: restoreForm.DbBackupHistoryName,
}
job.DbName = restoreForm.DbName
biz.ErrIsNilAppendErr(d.DbRestoreApp.Create(rc.MetaCtx, job), "添加数据库恢复任务失败: %v")
}

View File

@@ -50,7 +50,7 @@ func Init() {
if err != nil {
panic(fmt.Sprintf("初始化 dbRestoreApp 失败: %v", err))
}
dbBinlogApp, err = newDbBinlogApp(repositories, dbApp)
dbBinlogApp, err = newDbBinlogApp(repositories, dbApp, scheduler)
if err != nil {
panic(fmt.Sprintf("初始化 dbBinlogApp 失败: %v", err))
}

View File

@@ -2,20 +2,14 @@ package application
import (
"context"
"fmt"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/stringx"
"mayfly-go/pkg/utils/timex"
"sync"
"time"
)
const (
binlogDownloadInterval = time.Minute * 15
)
type DbBinlogApp struct {
binlogRepo repository.DbBinlog
binlogHistoryRepo repository.DbBinlogHistory
@@ -25,19 +19,10 @@ type DbBinlogApp struct {
context context.Context
cancel context.CancelFunc
waitGroup sync.WaitGroup
scheduler *dbScheduler
}
var (
binlogResult = map[entity.DbJobStatus]string{
entity.DbJobDelay: "等待备份BINLOG",
entity.DbJobReady: "准备备份BINLOG",
entity.DbJobRunning: "BINLOG备份中",
entity.DbJobSuccess: "BINLOG备份成功",
entity.DbJobFailed: "BINLOG备份失败",
}
)
func newDbBinlogApp(repositories *repository.Repositories, dbApp Db) (*DbBinlogApp, error) {
func newDbBinlogApp(repositories *repository.Repositories, dbApp Db, scheduler *dbScheduler) (*DbBinlogApp, error) {
ctx, cancel := context.WithCancel(context.Background())
svc := &DbBinlogApp{
binlogRepo: repositories.Binlog,
@@ -45,6 +30,7 @@ func newDbBinlogApp(repositories *repository.Repositories, dbApp Db) (*DbBinlogA
backupRepo: repositories.Backup,
backupHistoryRepo: repositories.BackupHistory,
dbApp: dbApp,
scheduler: scheduler,
context: ctx,
cancel: cancel,
}
@@ -53,73 +39,47 @@ func newDbBinlogApp(repositories *repository.Repositories, dbApp Db) (*DbBinlogA
return svc, nil
}
func (app *DbBinlogApp) fetchBinlog(ctx context.Context, backup *entity.DbBackup) error {
if err := app.AddJobIfNotExists(ctx, entity.NewDbBinlog(backup.DbInstanceId)); err != nil {
return err
}
latestBinlogSequence, earliestBackupSequence := int64(-1), int64(-1)
binlogHistory, ok, err := app.binlogHistoryRepo.GetLatestHistory(backup.DbInstanceId)
if err != nil {
return err
}
if ok {
latestBinlogSequence = binlogHistory.Sequence
} else {
backupHistory, ok, err := app.backupHistoryRepo.GetEarliestHistory(backup.DbInstanceId)
if err != nil {
return err
}
if !ok {
return nil
}
earliestBackupSequence = backupHistory.BinlogSequence
}
conn, err := app.dbApp.GetDbConnByInstanceId(backup.DbInstanceId)
if err != nil {
return err
}
dbProgram := conn.GetDialect().GetDbProgram()
binlogFiles, err := dbProgram.FetchBinlogs(ctx, false, earliestBackupSequence, latestBinlogSequence)
if err == nil {
err = app.binlogHistoryRepo.InsertWithBinlogFiles(ctx, backup.DbInstanceId, binlogFiles)
}
jobStatus := entity.DbJobSuccess
if err != nil {
jobStatus = entity.DbJobFailed
}
job := &entity.DbBinlog{}
job.Id = backup.DbInstanceId
return app.updateCurJob(ctx, jobStatus, err, job)
}
func (app *DbBinlogApp) run() {
defer app.waitGroup.Done()
// todo: 实现 binlog 并发下载
timex.SleepWithContext(app.context, time.Minute)
for !app.closed() {
app.fetchFromAllInstances()
timex.SleepWithContext(app.context, binlogDownloadInterval)
}
}
func (app *DbBinlogApp) fetchFromAllInstances() {
var backups []*entity.DbBackup
if err := app.backupRepo.ListRepeating(&backups); err != nil {
logx.Errorf("DbBinlogApp: 获取数据库备份任务失败: %s", err.Error())
return
}
for _, backup := range backups {
jobs, err := app.loadJobs()
if err != nil {
logx.Errorf("DbBinlogApp: 加载 BINLOG 同步任务失败: %s", err.Error())
timex.SleepWithContext(app.context, time.Minute)
continue
}
if app.closed() {
break
}
if err := app.fetchBinlog(app.context, backup); err != nil {
logx.Errorf("DbBinlogApp: 下载 binlog 文件失败: %s", err.Error())
return
if err := app.scheduler.AddJob(app.context, false, entity.DbJobTypeBinlog, jobs); err != nil {
logx.Error("DbBinlogApp: 添加 BINLOG 同步任务失败: ", err.Error())
}
timex.SleepWithContext(app.context, entity.BinlogDownloadInterval)
}
}
func (app *DbBinlogApp) loadJobs() ([]*entity.DbBinlog, error) {
var instanceIds []uint64
if err := app.backupRepo.ListDbInstances(true, true, &instanceIds); err != nil {
return nil, err
}
jobs := make([]*entity.DbBinlog, 0, len(instanceIds))
for _, id := range instanceIds {
if app.closed() {
break
}
binlog := entity.NewDbBinlog(id)
if err := app.AddJobIfNotExists(app.context, binlog); err != nil {
return nil, err
}
jobs = append(jobs, binlog)
}
return jobs, nil
}
func (app *DbBinlogApp) Close() {
app.cancel()
app.waitGroup.Wait()
@@ -146,14 +106,3 @@ func (app *DbBinlogApp) Delete(ctx context.Context, jobId uint64) error {
}
return nil
}
func (app *DbBinlogApp) updateCurJob(ctx context.Context, status entity.DbJobStatus, lastErr error, job *entity.DbBinlog) error {
job.LastStatus = status
var result = binlogResult[status]
if lastErr != nil {
result = fmt.Sprintf("%v: %v", binlogResult[status], lastErr)
}
job.LastResult = stringx.TruncateStr(result, entity.LastResultSize)
job.LastTime = timex.NewNullTime(time.Now())
return app.binlogRepo.UpdateById(ctx, job, "last_status", "last_result", "last_time")
}

View File

@@ -26,28 +26,40 @@ type dbScheduler struct {
backupHistoryRepo repository.DbBackupHistory
restoreRepo repository.DbRestore
restoreHistoryRepo repository.DbRestoreHistory
binlogRepo repository.DbBinlog
binlogHistoryRepo repository.DbBinlogHistory
binlogTimes map[uint64]time.Time
}
func newDbScheduler(repositories *repository.Repositories) (*dbScheduler, error) {
scheduler := &dbScheduler{
runner: runner.NewRunner[entity.DbJob](maxRunning),
dbApp: dbApp,
backupRepo: repositories.Backup,
backupHistoryRepo: repositories.BackupHistory,
restoreRepo: repositories.Restore,
restoreHistoryRepo: repositories.RestoreHistory,
binlogRepo: repositories.Binlog,
binlogHistoryRepo: repositories.BinlogHistory,
}
scheduler.runner = runner.NewRunner[entity.DbJob](maxRunning, scheduler.runJob,
runner.WithScheduleJob[entity.DbJob](scheduler.scheduleJob),
runner.WithRunnableJob[entity.DbJob](scheduler.runnableJob),
)
return scheduler, nil
}
func (s *dbScheduler) scheduleJob(job entity.DbJob) (time.Time, error) {
return job.Schedule()
}
func (s *dbScheduler) repo(typ entity.DbJobType) repository.DbJob {
switch typ {
case entity.DbJobTypeBackup:
return s.backupRepo
case entity.DbJobTypeRestore:
return s.restoreRepo
case entity.DbJobTypeBinlog:
return s.binlogRepo
default:
panic(errors.New(fmt.Sprintf("无效的数据库任务类型: %v", typ)))
}
@@ -60,8 +72,6 @@ func (s *dbScheduler) UpdateJob(ctx context.Context, job entity.DbJob) error {
if err := s.repo(job.GetJobType()).UpdateById(ctx, job); err != nil {
return err
}
job.SetRun(s.run)
job.SetRunnable(s.runnable)
_ = s.runner.UpdateOrAdd(ctx, job)
return nil
}
@@ -87,21 +97,11 @@ func (s *dbScheduler) AddJob(ctx context.Context, saving bool, jobType entity.Db
for i := 0; i < reflectLen; i++ {
job := reflectValue.Index(i).Interface().(entity.DbJob)
job.SetJobType(jobType)
if !job.Schedule() {
continue
}
job.SetRun(s.run)
job.SetRunnable(s.runnable)
_ = s.runner.Add(ctx, job)
}
default:
job := jobs.(entity.DbJob)
job.SetJobType(jobType)
if !job.Schedule() {
return nil
}
job.SetRun(s.run)
job.SetRunnable(s.runnable)
_ = s.runner.Add(ctx, job)
}
return nil
@@ -131,12 +131,10 @@ func (s *dbScheduler) EnableJob(ctx context.Context, jobType entity.DbJobType, j
if job.IsEnabled() {
return nil
}
job.GetJobBase().Enabled = true
job.SetEnabled(true)
if err := repo.UpdateEnabled(ctx, jobId, true); err != nil {
return err
}
job.SetRun(s.run)
job.SetRunnable(s.runnable)
_ = s.runner.Add(ctx, job)
return nil
}
@@ -171,9 +169,6 @@ func (s *dbScheduler) StartJobNow(ctx context.Context, jobType entity.DbJobType,
if !job.IsEnabled() {
return errors.New("任务未启用")
}
job.GetJobBase().Deadline = time.Now()
job.SetRun(s.run)
job.SetRunnable(s.runnable)
_ = s.runner.StartNow(ctx, job)
return nil
}
@@ -267,7 +262,7 @@ func (s *dbScheduler) restoreMysql(ctx context.Context, job entity.DbJob) error
return nil
}
func (s *dbScheduler) run(ctx context.Context, job entity.DbJob) {
func (s *dbScheduler) runJob(ctx context.Context, job entity.DbJob) {
job.SetLastStatus(entity.DbJobRunning, nil)
if err := s.repo(job.GetJobType()).UpdateLastStatus(ctx, job); err != nil {
logx.Errorf("failed to update job status: %v", err)
@@ -280,6 +275,8 @@ func (s *dbScheduler) run(ctx context.Context, job entity.DbJob) {
errRun = s.backupMysql(ctx, job)
case entity.DbJobTypeRestore:
errRun = s.restoreMysql(ctx, job)
case entity.DbJobTypeBinlog:
errRun = s.fetchBinlogMysql(ctx, job)
default:
errRun = errors.New(fmt.Sprintf("无效的数据库任务类型: %v", typ))
}
@@ -294,19 +291,27 @@ func (s *dbScheduler) run(ctx context.Context, job entity.DbJob) {
}
}
func (s *dbScheduler) runnable(job entity.DbJob, next runner.NextFunc) bool {
func (s *dbScheduler) runnableJob(job entity.DbJob, next runner.NextJobFunc[entity.DbJob]) bool {
const maxCountByInstanceId = 4
const maxCountByDbName = 1
var countByInstanceId, countByDbName int
jobBase := job.GetJobBase()
for item, ok := next(); ok; item, ok = next() {
itemBase := item.(entity.DbJob).GetJobBase()
itemBase := item.GetJobBase()
if jobBase.DbInstanceId == itemBase.DbInstanceId {
countByInstanceId++
if countByInstanceId >= maxCountByInstanceId {
return false
}
if jobBase.DbName == itemBase.DbName {
if relatedToBinlog(job.GetJobType()) {
// todo: 恢复数据库前触发 BINLOG 同步BINLOG 同步完成后才能恢复数据库
if relatedToBinlog(item.GetJobType()) {
return false
}
}
if job.GetDbName() == item.GetDbName() {
countByDbName++
if countByDbName >= maxCountByDbName {
return false
@@ -317,6 +322,10 @@ func (s *dbScheduler) runnable(job entity.DbJob, next runner.NextFunc) bool {
return true
}
// relatedToBinlog reports whether jobs of the given type touch binlog
// files (restore consumes them, binlog-sync produces them), and therefore
// must not run concurrently on the same database.
func relatedToBinlog(typ entity.DbJobType) bool {
	switch typ {
	case entity.DbJobTypeRestore, entity.DbJobTypeBinlog:
		return true
	default:
		return false
	}
}
func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbi.DbProgram, job *entity.DbRestore) error {
binlogHistory, err := s.binlogHistoryRepo.GetHistoryByTime(job.DbInstanceId, job.PointInTime.Time)
if err != nil {
@@ -364,3 +373,34 @@ func (s *dbScheduler) restoreBackupHistory(ctx context.Context, program dbi.DbPr
}
return program.RestoreBackupHistory(ctx, backupHistory.DbName, backupHistory.DbBackupId, backupHistory.Uuid)
}
// fetchBinlogMysql downloads new binlog files for the job's instance and
// records them in the binlog history.
//
// The sync starting point is resolved as follows: if a binlog history row
// exists, continue after its sequence; otherwise fall back to the earliest
// backup history's binlog sequence; if neither exists there is no anchor
// yet and the sync is a no-op.
func (s *dbScheduler) fetchBinlogMysql(ctx context.Context, backup entity.DbJob) error {
	instanceId := backup.GetJobBase().DbInstanceId
	latestBinlogSequence, earliestBackupSequence := int64(-1), int64(-1)
	binlogHistory, ok, err := s.binlogHistoryRepo.GetLatestHistory(instanceId)
	if err != nil {
		return err
	}
	if ok {
		latestBinlogSequence = binlogHistory.Sequence
	} else {
		backupHistory, ok, err := s.backupHistoryRepo.GetEarliestHistory(instanceId)
		if err != nil {
			return err
		}
		if !ok {
			// No backup yet: binlog sync has no starting point, nothing to do.
			return nil
		}
		earliestBackupSequence = backupHistory.BinlogSequence
	}
	conn, err := s.dbApp.GetDbConnByInstanceId(instanceId)
	if err != nil {
		return err
	}
	dbProgram := conn.GetDialect().GetDbProgram()
	binlogFiles, err := dbProgram.FetchBinlogs(ctx, false, earliestBackupSequence, latestBinlogSequence)
	if err == nil {
		err = s.binlogHistoryRepo.InsertWithBinlogFiles(ctx, instanceId, binlogFiles)
	}
	// Fix: propagate the fetch/insert error instead of discarding it —
	// the scheduler records job success/failure from this return value.
	return err
}

View File

@@ -16,8 +16,6 @@ import (
"time"
"github.com/pkg/errors"
"golang.org/x/sync/singleflight"
"mayfly-go/internal/db/config"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/domain/entity"
@@ -186,6 +184,12 @@ func (svc *DbProgramMysql) downloadBinlogFilesOnServer(ctx context.Context, binl
return nil
}
// Parse the first binlog eventTs of a local binlog file.
func (svc *DbProgramMysql) parseLocalBinlogLastEventTime(ctx context.Context, filePath string) (eventTime time.Time, parseErr error) {
// todo: implement me
return time.Now(), nil
}
// Parse the first binlog eventTs of a local binlog file.
func (svc *DbProgramMysql) parseLocalBinlogFirstEventTime(ctx context.Context, filePath string) (eventTime time.Time, parseErr error) {
args := []string{
@@ -227,36 +231,8 @@ func (svc *DbProgramMysql) parseLocalBinlogFirstEventTime(ctx context.Context, f
return time.Time{}, errors.New("解析 binlog 文件失败")
}
var singleFlightGroup singleflight.Group
// FetchBinlogs downloads binlog files from startingFileName on server to `binlogDir`.
func (svc *DbProgramMysql) FetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool, earliestBackupSequence, latestBinlogSequence int64) ([]*entity.BinlogFile, error) {
var downloaded bool
key := strconv.FormatUint(svc.dbInfo().InstanceId, 16)
binlogFiles, err, _ := singleFlightGroup.Do(key, func() (interface{}, error) {
downloaded = true
return svc.fetchBinlogs(ctx, downloadLatestBinlogFile, earliestBackupSequence, latestBinlogSequence)
})
if err != nil {
return nil, err
}
if downloaded {
return binlogFiles.([]*entity.BinlogFile), nil
}
if !downloadLatestBinlogFile {
return nil, nil
}
binlogFiles, err, _ = singleFlightGroup.Do(key, func() (interface{}, error) {
return svc.fetchBinlogs(ctx, true, earliestBackupSequence, latestBinlogSequence)
})
if err != nil {
return nil, err
}
return binlogFiles.([]*entity.BinlogFile), err
}
// fetchBinlogs downloads binlog files from startingFileName on server to `binlogDir`.
func (svc *DbProgramMysql) fetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool, earliestBackupSequence, latestBinlogSequence int64) ([]*entity.BinlogFile, error) {
// Read binlog files list on server.
binlogFilesOnServerSorted, err := svc.GetSortedBinlogFilesOnServer(ctx)
if err != nil {
@@ -352,7 +328,13 @@ func (svc *DbProgramMysql) downloadBinlogFile(ctx context.Context, binlogFileToD
if err != nil {
return err
}
lastEventTime, err := svc.parseLocalBinlogLastEventTime(ctx, binlogFilePath)
if err != nil {
return err
}
binlogFileToDownload.FirstEventTime = firstEventTime
binlogFileToDownload.LastEventTime = lastEventTime
binlogFileToDownload.Downloaded = true
return nil

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"github.com/stretchr/testify/suite"
"mayfly-go/internal/db/config"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/internal/db/infrastructure/persistence"
@@ -32,21 +33,21 @@ type DbInstanceSuite struct {
suite.Suite
repositories *repository.Repositories
instanceSvc *DbProgramMysql
dbConn *DbConn
dbConn *dbi.DbConn
}
func (s *DbInstanceSuite) SetupSuite() {
if err := chdir("mayfly-go", "server"); err != nil {
panic(err)
}
dbInfo := DbInfo{
Type: DbTypeMysql,
dbInfo := dbi.DbInfo{
Type: dbi.DbTypeMysql,
Host: "localhost",
Port: 3306,
Username: "test",
Password: "test",
}
dbConn, err := dbInfo.Conn()
dbConn, err := dbInfo.Conn(GetMeta())
s.Require().NoError(err)
s.dbConn = dbConn
s.repositories = &repository.Repositories{
@@ -203,7 +204,7 @@ func (s *DbInstanceSuite) testReplayBinlog(backupHistory *entity.DbBackupHistory
binlogHistories = append(binlogHistories, binlogHistoryLast)
}
restoreInfo := &RestoreInfo{
restoreInfo := &dbi.RestoreInfo{
BackupHistory: backupHistory,
BinlogHistories: binlogHistories,
StartPosition: backupHistory.BinlogPosition,

View File

@@ -1,8 +1,8 @@
package entity
import (
"context"
"mayfly-go/pkg/runner"
"time"
)
var _ DbJob = (*DbBackup)(nil)
@@ -11,17 +11,56 @@ var _ DbJob = (*DbBackup)(nil)
type DbBackup struct {
*DbJobBaseImpl
Name string `json:"Name"` // 数据库备份名称
Enabled bool // 是否启用
StartTime time.Time // 开始时间
Interval time.Duration // 间隔时间
Repeated bool // 是否重复执行
DbName string // 数据库名称
Name string // 数据库备份名称
}
func (d *DbBackup) SetRun(fn func(ctx context.Context, job DbJob)) {
d.run = func(ctx context.Context) {
fn(ctx, d)
}
// GetDbName returns the name of the database this backup job targets.
func (b *DbBackup) GetDbName() string {
return b.DbName
}
func (d *DbBackup) SetRunnable(fn func(job DbJob, next runner.NextFunc) bool) {
d.runnable = func(next runner.NextFunc) bool {
return fn(d, next)
// Schedule computes the next execution time of the backup job.
// It returns runner.ErrFinished when the job will never run again:
// a one-shot job that already succeeded, a disabled job, or a
// repeating job with no usable interval.
func (b *DbBackup) Schedule() (time.Time, error) {
	var deadline time.Time
	if b.IsFinished() || !b.Enabled {
		return deadline, runner.ErrFinished
	}
	switch b.LastStatus {
	case DbJobSuccess:
		// Fix: guard against a zero/negative interval — the modulo below
		// would panic with a division by zero (the pre-refactor Schedule
		// had this check and treated such jobs as finished).
		if b.Interval <= 0 {
			return deadline, runner.ErrFinished
		}
		lastTime := b.LastTime.Time
		if lastTime.Before(b.StartTime) {
			lastTime = b.StartTime.Add(-b.Interval)
		}
		// Align the next run to the StartTime + k*Interval grid after lastTime.
		deadline = lastTime.Add(b.Interval - lastTime.Sub(b.StartTime)%b.Interval)
	case DbJobFailed:
		// Failed runs retry after a short fixed delay.
		deadline = time.Now().Add(time.Minute)
	default:
		// Never run (or unknown status): start at the configured time.
		deadline = b.StartTime
	}
	return deadline, nil
}
// IsFinished reports whether the job will never run again:
// a non-repeating backup that has already succeeded.
func (b *DbBackup) IsFinished() bool {
return !b.Repeated && b.LastStatus == DbJobSuccess
}
// IsEnabled reports whether the job is currently enabled.
func (b *DbBackup) IsEnabled() bool {
return b.Enabled
}
// SetEnabled switches the job on or off.
func (b *DbBackup) SetEnabled(enabled bool) {
b.Enabled = enabled
}
// Update copies the user-editable schedule fields from another job.
// NOTE(review): the type assertion panics if job is not a *DbBackup —
// presumably the runner only pairs jobs of the same type; confirm at call sites.
func (b *DbBackup) Update(job runner.Job) {
backup := job.(*DbBackup)
b.StartTime = backup.StartTime
b.Interval = backup.Interval
}
// GetInterval returns the configured repeat interval.
func (b *DbBackup) GetInterval() time.Duration {
return b.Interval
}

View File

@@ -1,27 +1,13 @@
package entity
import (
"mayfly-go/pkg/model"
"mayfly-go/pkg/utils/timex"
"mayfly-go/pkg/runner"
"time"
)
// DbBinlog 数据库备份任务
type DbBinlog struct {
model.Model
LastStatus DbJobStatus // 最近一次执行状态
LastResult string // 最近一次执行结果
LastTime timex.NullTime // 最近一次执行时间
DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
}
func NewDbBinlog(instanceId uint64) *DbBinlog {
job := &DbBinlog{}
job.Id = instanceId
job.DbInstanceId = instanceId
return job
}
const (
BinlogDownloadInterval = time.Minute * 15
)
// BinlogFile is the metadata of the MySQL binlog file.
type BinlogFile struct {
@@ -31,5 +17,49 @@ type BinlogFile struct {
// Sequence is parsed from Name and is for the sorting purpose.
Sequence int64
FirstEventTime time.Time
LastEventTime time.Time
Downloaded bool
}
var _ DbJob = (*DbBinlog)(nil)
// DbBinlog is the per-instance BINLOG synchronization job.
// (The previous comment called it a backup job; it actually downloads
// binlog files for a database instance.)
type DbBinlog struct {
DbJobBaseImpl
}
// NewDbBinlog builds a binlog job for the given instance. The job id is
// the instance id, so there is at most one binlog job per instance.
func NewDbBinlog(instanceId uint64) *DbBinlog {
job := &DbBinlog{}
job.Id = instanceId
job.DbInstanceId = instanceId
return job
}
// GetDbName returns "" — binlog sync operates on the whole instance,
// not on a single database.
func (b *DbBinlog) GetDbName() string {
// binlog is instance-wide, not scoped to one database
return ""
}
// Schedule returns the next run time. After a success it reports
// runner.ErrFinished so the runner drops the job; DbBinlogApp re-adds
// binlog jobs on its own cycle. After a failure it retries after
// BinlogDownloadInterval; otherwise it runs immediately.
func (b *DbBinlog) Schedule() (time.Time, error) {
switch b.GetJobBase().LastStatus {
case DbJobSuccess:
return time.Time{}, runner.ErrFinished
case DbJobFailed:
return time.Now().Add(BinlogDownloadInterval), nil
default:
return time.Now(), nil
}
}
// Update is a no-op: binlog jobs have no user-editable schedule fields.
func (b *DbBinlog) Update(_ runner.Job) {}
// IsEnabled always reports true; binlog sync cannot be disabled.
func (b *DbBinlog) IsEnabled() bool {
return true
}
// SetEnabled is a no-op (see IsEnabled).
func (b *DbBinlog) SetEnabled(_ bool) {}
// GetInterval returns 0; timing is handled entirely by Schedule.
func (b *DbBinlog) GetInterval() time.Duration {
return 0
}

View File

@@ -1,7 +1,6 @@
package entity
import (
"context"
"fmt"
"mayfly-go/pkg/model"
"mayfly-go/pkg/runner"
@@ -14,18 +13,11 @@ const LastResultSize = 256
type DbJobKey = runner.JobKey
type DbJobStatus = runner.JobStatus
type DbJobStatus int
const (
DbJobUnknown = runner.JobUnknown
DbJobDelay = runner.JobDelay
DbJobReady = runner.JobWaiting
DbJobRunning = runner.JobRunning
DbJobRemoved = runner.JobRemoved
)
const (
DbJobSuccess DbJobStatus = 0x20 + iota
DbJobRunning DbJobStatus = iota
DbJobSuccess
DbJobFailed
)
@@ -34,32 +26,37 @@ type DbJobType = string
const (
DbJobTypeBackup DbJobType = "db-backup"
DbJobTypeRestore DbJobType = "db-restore"
DbJobTypeBinlog DbJobType = "db-binlog"
)
const (
DbJobNameBackup = "数据库备份"
DbJobNameRestore = "数据库恢复"
DbJobNameBinlog = "BINLOG同步"
)
var _ runner.Job = (DbJob)(nil)
type DbJobBase interface {
model.ModelI
runner.Job
GetId() uint64
GetKey() string
GetJobType() DbJobType
SetJobType(typ DbJobType)
GetJobBase() *DbJobBaseImpl
SetLastStatus(status DbJobStatus, err error)
IsEnabled() bool
}
type DbJob interface {
runner.Job
DbJobBase
SetRun(fn func(ctx context.Context, job DbJob))
SetRunnable(fn func(job DbJob, next runner.NextFunc) bool)
GetDbName() string
Schedule() (time.Time, error)
IsEnabled() bool
SetEnabled(enabled bool)
Update(job runner.Job)
GetInterval() time.Duration
}
func NewDbJob(typ DbJobType) DbJob {
@@ -85,31 +82,17 @@ type DbJobBaseImpl struct {
model.Model
DbInstanceId uint64 // 数据库实例ID
DbName string // 数据库名称
Enabled bool // 是否启用
StartTime time.Time // 开始时间
Interval time.Duration // 间隔时间
Repeated bool // 是否重复执行
LastStatus DbJobStatus // 最近一次执行状态
LastResult string // 最近一次执行结果
LastTime timex.NullTime // 最近一次执行时间
Deadline time.Time `gorm:"-" json:"-"` // 计划执行时间
run runner.RunFunc
runnable runner.RunnableFunc
jobType DbJobType
jobKey runner.JobKey
jobStatus runner.JobStatus
}
func NewDbBJobBase(instanceId uint64, dbName string, jobType DbJobType, enabled bool, repeated bool, startTime time.Time, interval time.Duration) *DbJobBaseImpl {
func NewDbBJobBase(instanceId uint64, jobType DbJobType) *DbJobBaseImpl {
return &DbJobBaseImpl{
DbInstanceId: instanceId,
DbName: dbName,
jobType: jobType,
Enabled: enabled,
Repeated: repeated,
StartTime: startTime,
Interval: interval,
}
}
@@ -138,6 +121,8 @@ func (d *DbJobBaseImpl) SetLastStatus(status DbJobStatus, err error) {
jobName = DbJobNameBackup
case DbJobTypeRestore:
jobName = DbJobNameRestore
case DbJobNameBinlog:
jobName = DbJobNameBinlog
default:
jobName = d.jobType
}
@@ -150,71 +135,10 @@ func (d *DbJobBaseImpl) SetLastStatus(status DbJobStatus, err error) {
d.LastTime = timex.NewNullTime(time.Now())
}
// GetId returns the job's primary key, or 0 when the receiver is nil.
func (d *DbJobBaseImpl) GetId() uint64 {
	if d != nil {
		return d.Id
	}
	return 0
}
// GetDeadline returns the next planned execution time last computed by Schedule.
func (d *DbJobBaseImpl) GetDeadline() time.Time {
	return d.Deadline
}
// Schedule computes the next execution deadline for this job and stores it
// in d.Deadline. It returns false when the job should no longer be scheduled
// (disabled, or a finished one-shot job).
func (d *DbJobBaseImpl) Schedule() bool {
	if d.IsFinished() || !d.Enabled {
		return false
	}
	switch d.LastStatus {
	case DbJobSuccess:
		// A one-shot job (Interval == 0) never runs again after success.
		if d.Interval == 0 {
			return false
		}
		lastTime := d.LastTime.Time
		if lastTime.Sub(d.StartTime) < 0 {
			// Last run predates StartTime; back up one interval so the
			// next deadline lands exactly on StartTime.
			lastTime = d.StartTime.Add(-d.Interval)
		}
		// Advance to the next interval boundary after lastTime,
		// phase-aligned to StartTime.
		d.Deadline = lastTime.Add(d.Interval - lastTime.Sub(d.StartTime)%d.Interval)
	case DbJobFailed:
		// Failed jobs are retried after a fixed one-minute backoff.
		d.Deadline = time.Now().Add(time.Minute)
	default:
		// Never run yet (or unknown status): schedule at the configured start time.
		d.Deadline = d.StartTime
	}
	return true
}
// IsFinished reports whether this job is done for good: a non-repeating
// job that has already succeeded.
func (d *DbJobBaseImpl) IsFinished() bool {
	return !d.Repeated && d.LastStatus == DbJobSuccess
}
// Update copies the user-editable schedule fields (StartTime, Interval)
// from another job of the same concrete type onto this one.
func (d *DbJobBaseImpl) Update(job runner.Job) {
	jobBase := job.(DbJob).GetJobBase()
	d.StartTime = jobBase.StartTime
	d.Interval = jobBase.Interval
}
// GetJobBase exposes the embedded base so callers can reach the shared fields.
func (d *DbJobBaseImpl) GetJobBase() *DbJobBaseImpl {
	return d
}
// IsEnabled reports whether the job is currently enabled for scheduling.
func (d *DbJobBaseImpl) IsEnabled() bool {
	return d.Enabled
}
// Run executes the configured run callback; it is a no-op when none is set.
func (d *DbJobBaseImpl) Run(ctx context.Context) {
	if d.run != nil {
		d.run(ctx)
	}
}
// Runnable reports whether the job may run now; it defaults to true when no
// runnable callback has been configured.
func (d *DbJobBaseImpl) Runnable(next runner.NextFunc) bool {
	if fn := d.runnable; fn != nil {
		return fn(next)
	}
	return true
}
// FormatJobKey builds the unique runner key "<jobType>-<jobId>" for a job.
func FormatJobKey(typ DbJobType, jobId uint64) DbJobKey {
	return fmt.Sprintf("%s-%d", string(typ), jobId)
}
@@ -225,11 +149,3 @@ func (d *DbJobBaseImpl) GetKey() DbJobKey {
}
return d.jobKey
}
// GetStatus returns the runner-side lifecycle status of this job.
func (d *DbJobBaseImpl) GetStatus() DbJobStatus {
	return d.jobStatus
}
// SetStatus records the runner-side lifecycle status of this job.
func (d *DbJobBaseImpl) SetStatus(status DbJobStatus) {
	d.jobStatus = status
}

View File

@@ -1,9 +1,9 @@
package entity
import (
"context"
"mayfly-go/pkg/runner"
"mayfly-go/pkg/utils/timex"
"time"
)
var _ DbJob = (*DbRestore)(nil)
@@ -12,20 +12,59 @@ var _ DbJob = (*DbRestore)(nil)
type DbRestore struct {
*DbJobBaseImpl
DbName string // 数据库名称
Enabled bool // 是否启用
StartTime time.Time // 开始时间
Interval time.Duration // 间隔时间
Repeated bool // 是否重复执行
PointInTime timex.NullTime `json:"pointInTime"` // 指定数据库恢复的时间点
DbBackupId uint64 `json:"dbBackupId"` // 用于恢复的数据库恢复任务ID
DbBackupHistoryId uint64 `json:"dbBackupHistoryId"` // 用于恢复的数据库恢复历史ID
DbBackupHistoryName string `json:"dbBackupHistoryName"` // 数据库恢复历史名称
}
func (d *DbRestore) SetRun(fn func(ctx context.Context, job DbJob)) {
d.run = func(ctx context.Context) {
fn(ctx, d)
}
// GetDbName returns the name of the database this restore job targets.
func (r *DbRestore) GetDbName() string {
	return r.DbName
}
func (d *DbRestore) SetRunnable(fn func(job DbJob, next runner.NextFunc) bool) {
d.runnable = func(next runner.NextFunc) bool {
return fn(d, next)
// Schedule computes the next execution deadline for this restore job.
// It returns runner.ErrFinished when the job is disabled or is a one-shot
// job that already succeeded, so the runner drops it from scheduling.
func (r *DbRestore) Schedule() (time.Time, error) {
	var deadline time.Time
	if r.IsFinished() || !r.Enabled {
		return deadline, runner.ErrFinished
	}
	switch r.LastStatus {
	case DbJobSuccess:
		lastTime := r.LastTime.Time
		if lastTime.Before(r.StartTime) {
			// Last run predates StartTime; back up one interval so the
			// next deadline lands exactly on StartTime.
			lastTime = r.StartTime.Add(-r.Interval)
		}
		// Next interval boundary after lastTime, phase-aligned to StartTime.
		// NOTE(review): assumes r.Interval > 0 on this path; a repeated job
		// with a zero interval would panic on the modulo — confirm callers
		// guarantee this.
		deadline = lastTime.Add(r.Interval - lastTime.Sub(r.StartTime)%r.Interval)
	case DbJobFailed:
		// Fixed one-minute retry backoff after a failure.
		deadline = time.Now().Add(time.Minute)
	default:
		// Not run yet: first execution at the configured start time.
		deadline = r.StartTime
	}
	return deadline, nil
}
// IsEnabled reports whether the restore job is enabled for scheduling.
func (r *DbRestore) IsEnabled() bool {
	return r.Enabled
}
// SetEnabled toggles whether the restore job participates in scheduling.
func (r *DbRestore) SetEnabled(enabled bool) {
	r.Enabled = enabled
}
// IsFinished reports whether this is a non-repeating job that already succeeded.
func (r *DbRestore) IsFinished() bool {
	return !r.Repeated && r.LastStatus == DbJobSuccess
}
// Update copies the user-editable schedule fields from another restore job.
func (r *DbRestore) Update(job runner.Job) {
	src := job.(*DbRestore)
	r.StartTime, r.Interval = src.StartTime, src.Interval
}
// GetInterval returns the repeat interval; 0 means a one-shot job.
func (r *DbRestore) GetInterval() time.Duration {
	return r.Interval
}

View File

@@ -1,7 +1,17 @@
package repository
import (
"mayfly-go/internal/db/domain/entity"
"mayfly-go/pkg/model"
)
type DbBackup interface {
DbJob
ListToDo(jobs any) error
ListDbInstances(enabled bool, repeated bool, instanceIds *[]uint64) error
GetDbNamesWithoutBackup(instanceId uint64, dbNames []string) ([]string, error)
// GetPageList 分页获取数据库任务列表
GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
}

View File

@@ -3,11 +3,10 @@ package repository
import (
"context"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/pkg/base"
)
type DbBinlog interface {
base.Repo[*entity.DbBinlog]
DbJob
AddJobIfNotExists(ctx context.Context, job *entity.DbBinlog) error
}

View File

@@ -2,27 +2,27 @@ package repository
import (
"context"
"gorm.io/gorm"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/pkg/model"
)
type DbJob interface {
// AddJob 添加数据库任务
AddJob(ctx context.Context, jobs any) error
type DbJobBase interface {
// GetById 根据实体id查询
GetById(e entity.DbJob, id uint64, cols ...string) error
// GetPageList 分页获取数据库任务列表
GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
// UpdateById 根据实体id更新实体信息
UpdateById(ctx context.Context, e entity.DbJob, columns ...string) error
// BatchInsertWithDb 使用指定gorm db执行主要用于事务执行
BatchInsertWithDb(ctx context.Context, db *gorm.DB, es any) error
// DeleteById 根据实体主键删除实体
DeleteById(ctx context.Context, id uint64) error
// UpdateLastStatus 更新任务执行状态
UpdateLastStatus(ctx context.Context, job entity.DbJob) error
UpdateEnabled(ctx context.Context, jobId uint64, enabled bool) error
ListToDo(jobs any) error
ListRepeating(jobs any) error
}
type DbJob interface {
DbJobBase
// AddJob 添加数据库任务
AddJob(ctx context.Context, jobs any) error
UpdateEnabled(ctx context.Context, jobId uint64, enabled bool) error
}

View File

@@ -1,7 +1,16 @@
package repository
import (
"mayfly-go/internal/db/domain/entity"
"mayfly-go/pkg/model"
)
type DbRestore interface {
DbJob
ListToDo(jobs any) error
GetDbNamesWithoutRestore(instanceId uint64, dbNames []string) ([]string, error)
// GetPageList 分页获取数据库任务列表
GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
}

View File

@@ -1,16 +1,19 @@
package persistence
import (
"context"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/global"
"mayfly-go/pkg/gormx"
"mayfly-go/pkg/model"
"slices"
)
var _ repository.DbBackup = (*dbBackupRepoImpl)(nil)
type dbBackupRepoImpl struct {
dbJobBase[*entity.DbBackup]
dbJobBaseImpl[*entity.DbBackup]
}
func NewDbBackupRepo() repository.DbBackup {
@@ -21,7 +24,8 @@ func (d *dbBackupRepoImpl) GetDbNamesWithoutBackup(instanceId uint64, dbNames []
var dbNamesWithBackup []string
query := gormx.NewQuery(d.GetModel()).
Eq("db_instance_id", instanceId).
Eq("repeated", true)
Eq("repeated", true).
Undeleted()
if err := query.GenGdb().Pluck("db_name", &dbNamesWithBackup).Error; err != nil {
return nil, err
}
@@ -33,3 +37,49 @@ func (d *dbBackupRepoImpl) GetDbNamesWithoutBackup(instanceId uint64, dbNames []
}
return result, nil
}
// ListDbInstances collects the distinct db_instance_id values of all backup
// jobs matching the given enabled/repeated flags (soft-deleted rows excluded)
// into *instanceIds.
func (d *dbBackupRepoImpl) ListDbInstances(enabled bool, repeated bool, instanceIds *[]uint64) error {
	query := gormx.NewQuery(d.GetModel()).
		Eq0("enabled", enabled).
		Eq0("repeated", repeated).
		Undeleted()
	// Pass the *[]uint64 directly: taking &instanceIds would hand gorm a
	// **[]uint64, which Pluck cannot populate.
	return query.GenGdb().Distinct().Pluck("db_instance_id", instanceIds).Error
}
// ListToDo loads every backup job that still needs to run: enabled jobs that
// either repeat or have not yet succeeded (soft-deleted rows excluded).
func (d *dbBackupRepoImpl) ListToDo(jobs any) error {
	db := global.Db.Model(d.GetModel())
	pending := db.Where("repeated = ?", true).Or("last_status <> ?", entity.DbJobSuccess)
	return db.Where("enabled = ?", true).
		Where(pending).
		Scopes(gormx.UndeleteScope).
		Find(jobs).Error
}
// GetPageList returns a page of backup jobs matching the given condition
// (exact id/instance/repeated, db_name in-list or fuzzy match).
func (d *dbBackupRepoImpl) GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, _ ...string) (*model.PageResult[any], error) {
	// The previous stray `d.GetModel()` statement had no effect and was removed.
	qd := gormx.NewQuery(d.GetModel()).
		Eq("id", condition.Id).
		Eq0("db_instance_id", condition.DbInstanceId).
		Eq0("repeated", condition.Repeated).
		In0("db_name", condition.InDbNames).
		Like("db_name", condition.DbName)
	return gormx.PageQuery(qd, pageParam, toEntity)
}
// AddJob inserts one or more backup jobs via the shared addJob helper,
// which checks for duplicates inside a transaction.
func (d *dbBackupRepoImpl) AddJob(ctx context.Context, jobs any) error {
	return addJob[*entity.DbBackup](ctx, d.dbJobBaseImpl, jobs)
}
// UpdateEnabled flips the enabled flag of the backup job with the given id.
func (d *dbBackupRepoImpl) UpdateEnabled(_ context.Context, jobId uint64, enabled bool) error {
	return d.Updates(
		map[string]any{"id": jobId},
		map[string]any{"enabled": enabled},
	)
}

View File

@@ -6,14 +6,13 @@ import (
"gorm.io/gorm/clause"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/base"
"mayfly-go/pkg/global"
)
var _ repository.DbBinlog = (*dbBinlogRepoImpl)(nil)
type dbBinlogRepoImpl struct {
base.RepoImpl[*entity.DbBinlog]
dbJobBaseImpl[*entity.DbBinlog]
}
func NewDbBinlogRepo() repository.DbBinlog {
@@ -21,8 +20,18 @@ func NewDbBinlogRepo() repository.DbBinlog {
}
// AddJobIfNotExists inserts the binlog job, silently doing nothing when a
// conflicting row already exists (ON CONFLICT DO NOTHING).
func (d *dbBinlogRepoImpl) AddJobIfNotExists(_ context.Context, job *entity.DbBinlog) error {
	// TODO: how should an already soft-deleted record be handled here?
	if err := global.Db.Clauses(clause.OnConflict{DoNothing: true}).Create(job).Error; err != nil {
		return fmt.Errorf("启动 binlog 下载失败: %w", err)
	}
	return nil
}
// AddJob is not supported for binlog jobs; use AddJobIfNotExists instead.
// Unused parameters are blanked for consistency with UpdateEnabled below.
func (d *dbBinlogRepoImpl) AddJob(_ context.Context, _ any) error {
	panic("not implement, use AddJobIfNotExists")
}
// UpdateEnabled is not supported for binlog jobs.
func (d *dbBinlogRepoImpl) UpdateEnabled(_ context.Context, jobId uint64, enabled bool) error {
	panic("not implement")
}

View File

@@ -78,6 +78,7 @@ func (repo *dbBinlogHistoryRepoImpl) Upsert(_ context.Context, history *entity.D
old := &entity.DbBinlogHistory{}
err := db.Where("db_instance_id = ?", history.DbInstanceId).
Where("sequence = ?", history.Sequence).
Scopes(gormx.UndeleteScope).
First(old).Error
switch {
case err == nil:

View File

@@ -6,74 +6,32 @@ import (
"fmt"
"gorm.io/gorm"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/base"
"mayfly-go/pkg/global"
"mayfly-go/pkg/gormx"
"mayfly-go/pkg/model"
"reflect"
)
type dbJobBase[T entity.DbJob] struct {
var _ repository.DbJobBase = (*dbJobBaseImpl[entity.DbJob])(nil)
type dbJobBaseImpl[T entity.DbJob] struct {
base.RepoImpl[T]
}
func (d *dbJobBase[T]) GetById(e entity.DbJob, id uint64, cols ...string) error {
func (d *dbJobBaseImpl[T]) GetById(e entity.DbJob, id uint64, cols ...string) error {
return d.RepoImpl.GetById(e.(T), id, cols...)
}
func (d *dbJobBase[T]) UpdateById(ctx context.Context, e entity.DbJob, columns ...string) error {
func (d *dbJobBaseImpl[T]) UpdateById(ctx context.Context, e entity.DbJob, columns ...string) error {
return d.RepoImpl.UpdateById(ctx, e.(T), columns...)
}
func (d *dbJobBase[T]) UpdateEnabled(_ context.Context, jobId uint64, enabled bool) error {
cond := map[string]any{
"id": jobId,
}
return d.Updates(cond, map[string]any{
"enabled": enabled,
})
}
func (d *dbJobBase[T]) UpdateLastStatus(ctx context.Context, job entity.DbJob) error {
func (d *dbJobBaseImpl[T]) UpdateLastStatus(ctx context.Context, job entity.DbJob) error {
return d.UpdateById(ctx, job.(T), "last_status", "last_result", "last_time")
}
func (d *dbJobBase[T]) ListToDo(jobs any) error {
db := global.Db.Model(d.GetModel())
err := db.Where("enabled = ?", true).
Where(db.Where("repeated = ?", true).Or("last_status <> ?", entity.DbJobSuccess)).
Scopes(gormx.UndeleteScope).
Find(jobs).Error
if err != nil {
return err
}
return nil
}
func (d *dbJobBase[T]) ListRepeating(jobs any) error {
cond := map[string]any{
"enabled": true,
"repeated": true,
}
if err := d.ListByCond(cond, jobs); err != nil {
return err
}
return nil
}
// GetPageList 分页获取数据库备份任务列表
func (d *dbJobBase[T]) GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, _ ...string) (*model.PageResult[any], error) {
d.GetModel()
qd := gormx.NewQuery(d.GetModel()).
Eq("id", condition.Id).
Eq0("db_instance_id", condition.DbInstanceId).
Eq0("repeated", condition.Repeated).
In0("db_name", condition.InDbNames).
Like("db_name", condition.DbName)
return gormx.PageQuery(qd, pageParam, toEntity)
}
func (d *dbJobBase[T]) AddJob(ctx context.Context, jobs any) error {
func addJob[T entity.DbJob](ctx context.Context, repo dbJobBaseImpl[T], jobs any) error {
// refactor and jobs from any to []T
return gormx.Tx(func(db *gorm.DB) error {
var instanceId uint64
var dbNames []string
@@ -93,26 +51,28 @@ func (d *dbJobBase[T]) AddJob(ctx context.Context, jobs any) error {
if jobBase.DbInstanceId != instanceId {
return errors.New("不支持同时为多个数据库实例添加数据库任务")
}
if jobBase.Interval == 0 {
if job.GetInterval() == 0 {
// 单次执行的数据库任务可重复创建
continue
}
dbNames = append(dbNames, jobBase.DbName)
dbNames = append(dbNames, job.GetDbName())
}
default:
jobBase := jobs.(entity.DbJob).GetJobBase()
job := jobs.(entity.DbJob)
jobBase := job.GetJobBase()
instanceId = jobBase.DbInstanceId
if jobBase.Interval > 0 {
dbNames = append(dbNames, jobBase.DbName)
if job.GetInterval() > 0 {
dbNames = append(dbNames, job.GetDbName())
}
}
var res []string
err := db.Model(d.GetModel()).Select("db_name").
err := db.Model(repo.GetModel()).Select("db_name").
Where("db_instance_id = ?", instanceId).
Where("db_name in ?", dbNames).
Where("repeated = true").
Scopes(gormx.UndeleteScope).Find(&res).Error
Scopes(gormx.UndeleteScope).
Find(&res).Error
if err != nil {
return err
}
@@ -120,8 +80,8 @@ func (d *dbJobBase[T]) AddJob(ctx context.Context, jobs any) error {
return errors.New(fmt.Sprintf("数据库任务已存在: %v", res))
}
if plural {
return d.BatchInsertWithDb(ctx, db, jobs)
return repo.BatchInsertWithDb(ctx, db, jobs.([]T))
}
return d.InsertWithDb(ctx, db, jobs.(T))
return repo.InsertWithDb(ctx, db, jobs.(T))
})
}

View File

@@ -1,16 +1,19 @@
package persistence
import (
"context"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/global"
"mayfly-go/pkg/gormx"
"mayfly-go/pkg/model"
"slices"
)
var _ repository.DbRestore = (*dbRestoreRepoImpl)(nil)
type dbRestoreRepoImpl struct {
dbJobBase[*entity.DbRestore]
dbJobBaseImpl[*entity.DbRestore]
}
func NewDbRestoreRepo() repository.DbRestore {
@@ -21,7 +24,8 @@ func (d *dbRestoreRepoImpl) GetDbNamesWithoutRestore(instanceId uint64, dbNames
var dbNamesWithRestore []string
query := gormx.NewQuery(d.GetModel()).
Eq("db_instance_id", instanceId).
Eq("repeated", true)
Eq("repeated", true).
Undeleted()
if err := query.GenGdb().Pluck("db_name", &dbNamesWithRestore).Error; err != nil {
return nil, err
}
@@ -33,3 +37,41 @@ func (d *dbRestoreRepoImpl) GetDbNamesWithoutRestore(instanceId uint64, dbNames
}
return result, nil
}
// ListToDo loads every restore job that still needs to run: enabled jobs that
// either repeat or have not yet succeeded (soft-deleted rows excluded).
func (d *dbRestoreRepoImpl) ListToDo(jobs any) error {
	db := global.Db.Model(d.GetModel())
	pending := db.Where("repeated = ?", true).Or("last_status <> ?", entity.DbJobSuccess)
	return db.Where("enabled = ?", true).
		Where(pending).
		Scopes(gormx.UndeleteScope).
		Find(jobs).Error
}
// GetPageList returns a page of restore jobs matching the given condition
// (exact id/instance/repeated, db_name in-list or fuzzy match).
// The original comment said "backup"; this repository handles restores.
func (d *dbRestoreRepoImpl) GetPageList(condition *entity.DbJobQuery, pageParam *model.PageParam, toEntity any, _ ...string) (*model.PageResult[any], error) {
	// The previous stray `d.GetModel()` statement had no effect and was removed.
	qd := gormx.NewQuery(d.GetModel()).
		Eq("id", condition.Id).
		Eq0("db_instance_id", condition.DbInstanceId).
		Eq0("repeated", condition.Repeated).
		In0("db_name", condition.InDbNames).
		Like("db_name", condition.DbName)
	return gormx.PageQuery(qd, pageParam, toEntity)
}
// AddJob inserts one or more restore jobs via the shared addJob helper,
// which checks for duplicates inside a transaction.
func (d *dbRestoreRepoImpl) AddJob(ctx context.Context, jobs any) error {
	return addJob[*entity.DbRestore](ctx, d.dbJobBaseImpl, jobs)
}
// UpdateEnabled flips the enabled flag of the restore job with the given id.
func (d *dbRestoreRepoImpl) UpdateEnabled(_ context.Context, jobId uint64, enabled bool) error {
	return d.Updates(
		map[string]any{"id": jobId},
		map[string]any{"enabled": enabled},
	)
}

View File

@@ -23,8 +23,9 @@ type Team struct {
}
func (p *Team) GetTeams(rc *req.Ctx) {
queryCond, page := ginx.BindQueryAndPage(rc.GinCtx, new(entity.TeamQuery))
teams := &[]entity.Team{}
res, err := p.TeamApp.GetPageList(&entity.Team{}, ginx.GetPageParam(rc.GinCtx), teams)
res, err := p.TeamApp.GetPageList(queryCond, page, teams)
biz.ErrIsNil(err)
rc.ResData = res
}

View File

@@ -13,7 +13,7 @@ import (
type Team interface {
// 分页获取项目团队信息列表
GetPageList(condition *entity.Team, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
GetPageList(condition *entity.TeamQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
Save(ctx context.Context, team *entity.Team) error
@@ -55,7 +55,7 @@ type teamAppImpl struct {
tagTreeTeamRepo repository.TagTreeTeam
}
func (p *teamAppImpl) GetPageList(condition *entity.Team, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
func (p *teamAppImpl) GetPageList(condition *entity.TeamQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return p.teamRepo.GetPageList(condition, pageParam, toEntity, orderBy...)
}

View File

@@ -26,3 +26,9 @@ type TagResourceQuery struct {
TagPathLike string // 标签路径模糊查询
TagPathLikes []string
}
type TeamQuery struct {
model.Model
Name string `json:"name" form:"name"` // 团队名称
}

View File

@@ -9,5 +9,5 @@ import (
type Team interface {
base.Repo[*entity.Team]
GetPageList(condition *entity.Team, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
GetPageList(condition *entity.TeamQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
}

View File

@@ -13,10 +13,12 @@ type teamRepoImpl struct {
}
func newTeamRepo() repository.Team {
return &teamRepoImpl{base.RepoImpl[*entity.Team]{M: new(entity.Team)}}
return &teamRepoImpl{}
}
func (p *teamRepoImpl) GetPageList(condition *entity.Team, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
qd := gormx.NewQuery(condition).WithCondModel(condition).WithOrderBy(orderBy...)
func (p *teamRepoImpl) GetPageList(condition *entity.TeamQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
qd := gormx.NewQuery(p.GetModel()).
Like("name", condition.Name).
WithOrderBy()
return gormx.PageQuery(qd, pageParam, toEntity)
}

View File

@@ -22,10 +22,10 @@ type Repo[T model.ModelI] interface {
InsertWithDb(ctx context.Context, db *gorm.DB, e T) error
// 批量新增实体
BatchInsert(ctx context.Context, models any) error
BatchInsert(ctx context.Context, models []T) error
// 使用指定gorm db执行主要用于事务执行
BatchInsertWithDb(ctx context.Context, db *gorm.DB, models any) error
BatchInsertWithDb(ctx context.Context, db *gorm.DB, models []T) error
// 根据实体id更新实体信息
UpdateById(ctx context.Context, e T, columns ...string) error
@@ -93,19 +93,19 @@ func (br *RepoImpl[T]) InsertWithDb(ctx context.Context, db *gorm.DB, e T) error
return gormx.InsertWithDb(db, br.fillBaseInfo(ctx, e))
}
func (br *RepoImpl[T]) BatchInsert(ctx context.Context, es any) error {
func (br *RepoImpl[T]) BatchInsert(ctx context.Context, es []T) error {
if db := contextx.GetDb(ctx); db != nil {
return br.BatchInsertWithDb(ctx, db, es)
}
for _, e := range es.([]T) {
for _, e := range es {
br.fillBaseInfo(ctx, e)
}
return gormx.BatchInsert[T](es)
}
// 使用指定gorm db执行主要用于事务执行
func (br *RepoImpl[T]) BatchInsertWithDb(ctx context.Context, db *gorm.DB, es any) error {
for _, e := range es.([]T) {
func (br *RepoImpl[T]) BatchInsertWithDb(ctx context.Context, db *gorm.DB, es []T) error {
for _, e := range es {
br.fillBaseInfo(ctx, e)
}
return gormx.BatchInsertWithDb[T](db, es)

View File

@@ -135,13 +135,13 @@ func InsertWithDb(db *gorm.DB, model any) error {
}
// 批量插入
func BatchInsert[T any](models any) error {
func BatchInsert[T any](models []T) error {
return BatchInsertWithDb[T](global.Db, models)
}
// 批量插入
func BatchInsertWithDb[T any](db *gorm.DB, models any) error {
return db.CreateInBatches(models, len(models.([]T))).Error
func BatchInsertWithDb[T any](db *gorm.DB, models []T) error {
return db.CreateInBatches(models, len(models)).Error
}
// 根据id更新model更新字段为model中不为空的值即int类型不为0ptr类型不为nil这类字段值

View File

@@ -25,6 +25,35 @@ type Delayable interface {
GetKey() string
}
var _ Delayable = (*wrapper[Job])(nil)
// wrapper pairs a Job with the runner's internal bookkeeping (cached key,
// computed deadline, lifecycle status, removal mark) so the job type itself
// stays free of scheduler state.
type wrapper[T Job] struct {
	key      string    // cached job.GetKey(), stable for the wrapper's lifetime
	deadline time.Time // next scheduled execution time
	removed  bool      // marked for deletion once the current run completes
	status   JobStatus // runner-side lifecycle state
	job      T
}
// newWrapper wraps a job, caching its key; deadline and status are set later
// by the scheduler.
func newWrapper[T Job](job T) *wrapper[T] {
	return &wrapper[T]{
		key: job.GetKey(),
		job: job,
	}
}
// GetDeadline implements Delayable: the time at which the job becomes due.
func (d *wrapper[T]) GetDeadline() time.Time {
	return d.deadline
}
// GetKey implements Delayable: the cached unique key of the wrapped job.
func (d *wrapper[T]) GetKey() string {
	return d.key
}
// Payload returns the wrapped job itself.
func (d *wrapper[T]) Payload() T {
	return d.job
}
func NewDelayQueue[T Delayable](cap int) *DelayQueue[T] {
singleDequeue := make(chan struct{}, 1)
singleDequeue <- struct{}{}

View File

@@ -1,11 +1,5 @@
package runner
//var (
// false = errors.New("queue: 队列已满")
// false = errors.New("queue: 队列为空")
// false = errors.New("queue: 元素未找到")
//)
// PriorityQueue 是一个基于小顶堆的优先队列
// 当capacity <= 0时为无界队列切片容量会动态扩缩容
// 当capacity > 0 时,为有界队列,初始化后就固定容量,不会扩缩容

View File

@@ -2,43 +2,43 @@ package runner
import (
"context"
"errors"
"fmt"
"github.com/emirpasic/gods/maps/linkedhashmap"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/timex"
"sync"
"time"
)
var (
ErrNotFound = errors.New("job not found")
ErrExist = errors.New("job already exists")
ErrFinished = errors.New("job already finished")
)
type JobKey = string
type RunFunc func(ctx context.Context)
type NextFunc func() (Job, bool)
type RunnableFunc func(next NextFunc) bool
type RunJobFunc[T Job] func(ctx context.Context, job T)
type NextJobFunc[T Job] func() (T, bool)
type RunnableJobFunc[T Job] func(job T, next NextJobFunc[T]) bool
type ScheduleJobFunc[T Job] func(job T) (deadline time.Time, err error)
type JobStatus int
const (
JobUnknown JobStatus = iota
JobDelay
JobDelaying
JobWaiting
JobRunning
JobRemoved
)
type Job interface {
GetKey() JobKey
GetStatus() JobStatus
SetStatus(status JobStatus)
Run(ctx context.Context)
Runnable(next NextFunc) bool
GetDeadline() time.Time
Schedule() bool
Update(job Job)
}
type iterator[T Job] struct {
index int
data []T
data []*wrapper[T]
zero T
}
@@ -51,19 +51,18 @@ func (iter *iterator[T]) Next() (T, bool) {
return iter.zero, false
}
iter.index++
return iter.data[iter.index], true
return iter.data[iter.index].job, true
}
type array[T Job] struct {
size int
data []T
zero T
data []*wrapper[T]
}
func newArray[T Job](size int) *array[T] {
return &array[T]{
size: size,
data: make([]T, 0, size),
data: make([]*wrapper[T], 0, size),
}
}
@@ -78,7 +77,7 @@ func (a *array[T]) Full() bool {
return len(a.data) >= a.size
}
func (a *array[T]) Append(job T) bool {
func (a *array[T]) Append(job *wrapper[T]) bool {
if len(a.data) >= a.size {
return false
}
@@ -86,20 +85,20 @@ func (a *array[T]) Append(job T) bool {
return true
}
func (a *array[T]) Get(key JobKey) (T, bool) {
func (a *array[T]) Get(key JobKey) (*wrapper[T], bool) {
for _, job := range a.data {
if key == job.GetKey() {
return job, true
}
}
return a.zero, false
return nil, false
}
func (a *array[T]) Remove(key JobKey) {
length := len(a.data)
for i, elm := range a.data {
if key == elm.GetKey() {
a.data[i], a.data[length-1] = a.data[length-1], a.zero
a.data[i], a.data[length-1] = a.data[length-1], nil
a.data = a.data[:length-1]
return
}
@@ -107,47 +106,76 @@ func (a *array[T]) Remove(key JobKey) {
}
type Runner[T Job] struct {
maxRunning int
waiting *linkedhashmap.Map
running *array[T]
runnable func(job T, iterateRunning func() (T, bool)) bool
mutex sync.Mutex
wg sync.WaitGroup
context context.Context
cancel context.CancelFunc
zero T
signal chan struct{}
all map[string]T
delayQueue *DelayQueue[T]
maxRunning int
waiting *linkedhashmap.Map
running *array[T]
runJob RunJobFunc[T]
runnableJob RunnableJobFunc[T]
scheduleJob ScheduleJobFunc[T]
mutex sync.Mutex
wg sync.WaitGroup
context context.Context
cancel context.CancelFunc
zero T
signal chan struct{}
all map[JobKey]*wrapper[T]
delayQueue *DelayQueue[*wrapper[T]]
}
func NewRunner[T Job](maxRunning int) *Runner[T] {
type Opt[T Job] func(runner *Runner[T])
func WithRunnableJob[T Job](runnableJob RunnableJobFunc[T]) Opt[T] {
return func(runner *Runner[T]) {
runner.runnableJob = runnableJob
}
}
func WithScheduleJob[T Job](scheduleJob ScheduleJobFunc[T]) Opt[T] {
return func(runner *Runner[T]) {
runner.scheduleJob = scheduleJob
}
}
func NewRunner[T Job](maxRunning int, runJob RunJobFunc[T], opts ...Opt[T]) *Runner[T] {
ctx, cancel := context.WithCancel(context.Background())
runner := &Runner[T]{
maxRunning: maxRunning,
all: make(map[string]T, maxRunning),
all: make(map[string]*wrapper[T], maxRunning),
waiting: linkedhashmap.New(),
running: newArray[T](maxRunning),
context: ctx,
cancel: cancel,
signal: make(chan struct{}, 1),
delayQueue: NewDelayQueue[T](0),
delayQueue: NewDelayQueue[*wrapper[T]](0),
}
runner.runJob = runJob
for _, opt := range opts {
opt(runner)
}
if runner.runnableJob == nil {
runner.runnableJob = func(job T, _ NextJobFunc[T]) bool {
return true
}
}
runner.wg.Add(maxRunning + 1)
for i := 0; i < maxRunning; i++ {
go runner.run()
}
go func() {
defer runner.wg.Done()
timex.SleepWithContext(runner.context, time.Second*10)
for runner.context.Err() == nil {
job, ok := runner.delayQueue.Dequeue(ctx)
wrap, ok := runner.delayQueue.Dequeue(ctx)
if !ok {
continue
}
runner.mutex.Lock()
runner.waiting.Put(job.GetKey(), job)
job.SetStatus(JobWaiting)
if old, ok := runner.all[wrap.key]; !ok || wrap != old {
runner.mutex.Unlock()
continue
}
runner.waiting.Put(wrap.key, wrap)
wrap.status = JobWaiting
runner.trigger()
runner.mutex.Unlock()
}
@@ -166,144 +194,155 @@ func (r *Runner[T]) run() {
for r.context.Err() == nil {
select {
case <-r.signal:
job, ok := r.pickRunnable()
wrap, ok := r.pickRunnableJob()
if !ok {
continue
}
r.doRun(job)
r.afterRun(job)
r.doRun(wrap)
r.afterRun(wrap)
case <-r.context.Done():
}
}
}
func (r *Runner[T]) doRun(job T) {
func (r *Runner[T]) doRun(wrap *wrapper[T]) {
defer func() {
if err := recover(); err != nil {
logx.Error(fmt.Sprintf("failed to run job: %v", err))
}
}()
job.Run(r.context)
r.runJob(r.context, wrap.job)
}
func (r *Runner[T]) afterRun(job T) {
func (r *Runner[T]) afterRun(wrap *wrapper[T]) {
r.mutex.Lock()
defer r.mutex.Unlock()
key := job.GetKey()
r.running.Remove(key)
r.running.Remove(wrap.key)
delete(r.all, wrap.key)
wrap.status = JobUnknown
r.trigger()
switch job.GetStatus() {
case JobRunning:
r.schedule(r.context, job)
case JobRemoved:
delete(r.all, key)
default:
panic(fmt.Sprintf("invalid job status %v occurred after run", job.GetStatus()))
if wrap.removed {
return
}
deadline, err := r.doScheduleJob(wrap.job, true)
if err != nil {
return
}
_ = r.schedule(r.context, deadline, wrap.job)
}
func (r *Runner[T]) pickRunnable() (T, bool) {
func (r *Runner[T]) doScheduleJob(job T, finished bool) (time.Time, error) {
if r.scheduleJob == nil {
if finished {
return time.Time{}, ErrFinished
}
return time.Now(), nil
}
return r.scheduleJob(job)
}
func (r *Runner[T]) pickRunnableJob() (*wrapper[T], bool) {
r.mutex.Lock()
defer r.mutex.Unlock()
iter := r.running.Iterator()
var runnable T
var runnable *wrapper[T]
ok := r.waiting.Any(func(key interface{}, value interface{}) bool {
job := value.(T)
wrap := value.(*wrapper[T])
iter.Begin()
if job.Runnable(func() (Job, bool) { return iter.Next() }) {
if r.runnableJob(wrap.job, iter.Next) {
if r.running.Full() {
return false
}
r.waiting.Remove(key)
r.running.Append(job)
job.SetStatus(JobRunning)
r.running.Append(wrap)
wrap.status = JobRunning
if !r.running.Full() && !r.waiting.Empty() {
r.trigger()
}
runnable = job
runnable = wrap
return true
}
return false
})
if !ok {
return r.zero, false
return nil, false
}
return runnable, true
}
func (r *Runner[T]) schedule(ctx context.Context, job T) {
if !job.Schedule() {
delete(r.all, job.GetKey())
job.SetStatus(JobRemoved)
return
func (r *Runner[T]) schedule(ctx context.Context, deadline time.Time, job T) error {
wrap := newWrapper(job)
wrap.deadline = deadline
if wrap.deadline.After(time.Now()) {
r.delayQueue.Enqueue(ctx, wrap)
wrap.status = JobDelaying
} else {
r.waiting.Put(wrap.key, wrap)
wrap.status = JobWaiting
r.trigger()
}
r.delayQueue.Enqueue(ctx, job)
job.SetStatus(JobDelay)
r.all[wrap.key] = wrap
return nil
}
//func (r *Runner[T]) Schedule(ctx context.Context, job T) {
// r.mutex.Lock()
// defer r.mutex.Unlock()
//
// switch job.GetStatus() {
// case JobUnknown:
// case JobDelay:
// r.delayQueue.Remove(ctx, job.GetKey())
// case JobWaiting:
// r.waiting.Remove(job)
// case JobRunning:
// // 标记为 removed, 任务执行完成后再删除
// return
// case JobRemoved:
// return
// }
// r.schedule(ctx, job)
//}
func (r *Runner[T]) Add(ctx context.Context, job T) error {
r.mutex.Lock()
defer r.mutex.Unlock()
if _, ok := r.all[job.GetKey()]; ok {
return nil
return ErrExist
}
r.schedule(ctx, job)
return nil
deadline, err := r.doScheduleJob(job, false)
if err != nil {
return err
}
return r.schedule(ctx, deadline, job)
}
func (r *Runner[T]) UpdateOrAdd(ctx context.Context, job T) error {
r.mutex.Lock()
defer r.mutex.Unlock()
if old, ok := r.all[job.GetKey()]; ok {
old.Update(job)
job = old
wrap, ok := r.all[job.GetKey()]
if ok {
wrap.job.Update(job)
switch wrap.status {
case JobDelaying:
r.delayQueue.Remove(ctx, wrap.key)
delete(r.all, wrap.key)
case JobWaiting:
r.waiting.Remove(wrap.key)
delete(r.all, wrap.key)
case JobRunning:
return nil
default:
}
}
r.schedule(ctx, job)
return nil
deadline, err := r.doScheduleJob(job, false)
if err != nil {
return err
}
return r.schedule(ctx, deadline, wrap.job)
}
func (r *Runner[T]) StartNow(ctx context.Context, job T) error {
r.mutex.Lock()
defer r.mutex.Unlock()
key := job.GetKey()
if old, ok := r.all[key]; ok {
job = old
if job.GetStatus() == JobDelay {
r.delayQueue.Remove(ctx, key)
r.waiting.Put(key, job)
r.trigger()
if wrap, ok := r.all[job.GetKey()]; ok {
switch wrap.status {
case JobDelaying:
r.delayQueue.Remove(ctx, wrap.key)
delete(r.all, wrap.key)
case JobWaiting, JobRunning:
return nil
default:
}
return nil
}
r.all[key] = job
r.waiting.Put(key, job)
r.trigger()
return nil
return r.schedule(ctx, time.Now(), job)
}
func (r *Runner[T]) trigger() {
@@ -317,23 +356,21 @@ func (r *Runner[T]) Remove(ctx context.Context, key JobKey) error {
r.mutex.Lock()
defer r.mutex.Unlock()
job, ok := r.all[key]
wrap, ok := r.all[key]
if !ok {
return nil
return ErrNotFound
}
switch job.GetStatus() {
case JobUnknown:
panic(fmt.Sprintf("invalid job status %v occurred after added", job.GetStatus()))
case JobDelay:
switch wrap.status {
case JobDelaying:
r.delayQueue.Remove(ctx, key)
delete(r.all, key)
case JobWaiting:
r.waiting.Remove(key)
delete(r.all, key)
case JobRunning:
// 标记为 removed, 任务执行完成后再删除
case JobRemoved:
return nil
// 统一标记为 removed, 任务执行完成后再删除
wrap.removed = true
default:
}
delete(r.all, key)
job.SetStatus(JobRemoved)
return nil
}

View File

@@ -2,6 +2,7 @@ package runner
import (
"context"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"mayfly-go/pkg/utils/timex"
"sync"
@@ -11,72 +12,35 @@ import (
var _ Job = &testJob{}
func newTestJob(key string, runTime time.Duration) *testJob {
func newTestJob(key string) *testJob {
return &testJob{
deadline: time.Now(),
Key: key,
run: func(ctx context.Context) {
timex.SleepWithContext(ctx, runTime)
},
Key: key,
}
}
type testJob struct {
run RunFunc
Key JobKey
status JobStatus
ran bool
deadline time.Time
Key JobKey
status int
}
func (t *testJob) Update(_ Job) {
}
func (t *testJob) GetDeadline() time.Time {
return t.deadline
}
func (t *testJob) Schedule() bool {
return !t.ran
}
func (t *testJob) Run(ctx context.Context) {
if t.run == nil {
return
}
t.run(ctx)
t.ran = true
}
func (t *testJob) Runnable(_ NextFunc) bool {
return true
}
func (t *testJob) Update(_ Job) {}
func (t *testJob) GetKey() JobKey {
return t.Key
}
func (t *testJob) GetStatus() JobStatus {
return t.status
}
func (t *testJob) SetStatus(status JobStatus) {
t.status = status
}
func TestRunner_Close(t *testing.T) {
runner := NewRunner[*testJob](1)
signal := make(chan struct{}, 1)
waiting := sync.WaitGroup{}
waiting.Add(1)
runner := NewRunner[*testJob](1, func(ctx context.Context, job *testJob) {
waiting.Done()
timex.SleepWithContext(ctx, time.Hour)
signal <- struct{}{}
})
go func() {
job := &testJob{
Key: "close",
run: func(ctx context.Context) {
waiting.Done()
timex.SleepWithContext(ctx, time.Hour)
signal <- struct{}{}
},
}
_ = runner.Add(context.Background(), job)
}()
@@ -95,55 +59,65 @@ func TestRunner_AddJob(t *testing.T) {
type testCase struct {
name string
job *testJob
want bool
want error
}
testCases := []testCase{
{
name: "first job",
job: newTestJob("single", time.Hour),
want: true,
job: newTestJob("single"),
want: nil,
},
{
name: "second job",
job: newTestJob("dual", time.Hour),
want: true,
},
{
name: "non repetitive job",
job: newTestJob("single", time.Hour),
want: true,
job: newTestJob("dual"),
want: nil,
},
{
name: "repetitive job",
job: newTestJob("dual", time.Hour),
want: true,
job: newTestJob("dual"),
want: ErrExist,
},
}
runner := NewRunner[*testJob](1)
runner := NewRunner[*testJob](1, func(ctx context.Context, job *testJob) {
timex.SleepWithContext(ctx, time.Hour)
})
defer runner.Close()
for _, tc := range testCases {
err := runner.Add(context.Background(), tc.job)
if tc.want != nil {
require.ErrorIs(t, err, tc.want)
continue
}
require.NoError(t, err)
}
}
func TestJob_UpdateStatus(t *testing.T) {
const d = time.Millisecond * 20
runner := NewRunner[*testJob](1)
running := newTestJob("running", d*2)
waiting := newTestJob("waiting", d*2)
_ = runner.Add(context.Background(), running)
_ = runner.Add(context.Background(), waiting)
const (
unknown = iota
running
finished
)
runner := NewRunner[*testJob](1, func(ctx context.Context, job *testJob) {
job.status = running
timex.SleepWithContext(ctx, d*2)
job.status = finished
})
first := newTestJob("first")
second := newTestJob("second")
_ = runner.Add(context.Background(), first)
_ = runner.Add(context.Background(), second)
time.Sleep(d)
require.Equal(t, JobRunning, running.status)
require.Equal(t, JobWaiting, waiting.status)
assert.Equal(t, running, first.status)
assert.Equal(t, unknown, second.status)
time.Sleep(d * 2)
require.Equal(t, JobRemoved, running.status)
require.Equal(t, JobRunning, waiting.status)
assert.Equal(t, finished, first.status)
assert.Equal(t, running, second.status)
time.Sleep(d * 2)
require.Equal(t, JobRemoved, running.status)
require.Equal(t, JobRemoved, waiting.status)
assert.Equal(t, finished, first.status)
assert.Equal(t, finished, second.status)
}