diff --git a/mayfly_go_web/src/common/utils/date.ts b/mayfly_go_web/src/common/utils/date.ts
index 2de40c3e..48ed6ad3 100644
--- a/mayfly_go_web/src/common/utils/date.ts
+++ b/mayfly_go_web/src/common/utils/date.ts
@@ -24,7 +24,7 @@ export function dateStrFormat(fmt: string, dateStr: string) {
}
export function dateFormat(dateStr: string) {
- if (dateStr?.startsWith('0001-01-01', 0)) {
+ if (!dateStr) {
return '';
}
return dateFormat2('yyyy-MM-dd HH:mm:ss', new Date(dateStr));
diff --git a/mayfly_go_web/src/views/ops/db/DbBackupList.vue b/mayfly_go_web/src/views/ops/db/DbBackupList.vue
index 69fa6634..51f3c43a 100644
--- a/mayfly_go_web/src/views/ops/db/DbBackupList.vue
+++ b/mayfly_go_web/src/views/ops/db/DbBackupList.vue
@@ -24,9 +24,12 @@
- 编辑
- 启用
- 禁用
+
+ 编辑
+ 启用
+ 禁用
+ 立即备份
+
@@ -150,5 +153,20 @@ const disableDbBackup = async (data: any) => {
await search();
ElMessage.success('禁用成功');
};
+
+const startDbBackup = async (data: any) => {
+ let backupId: String;
+ if (data) {
+ backupId = data.id;
+ } else if (state.selectedData.length > 0) {
+ backupId = state.selectedData.map((x: any) => x.id).join(' ');
+ } else {
+ ElMessage.error('请选择需要启用的备份任务');
+ return;
+ }
+ await dbApi.startDbBackup.request({ dbId: props.dbId, backupId: backupId });
+ await search();
+ ElMessage.success('备份任务启动成功');
+};
diff --git a/mayfly_go_web/src/views/ops/db/DbRestoreEdit.vue b/mayfly_go_web/src/views/ops/db/DbRestoreEdit.vue
index afd4a164..1b887b76 100644
--- a/mayfly_go_web/src/views/ops/db/DbRestoreEdit.vue
+++ b/mayfly_go_web/src/views/ops/db/DbRestoreEdit.vue
@@ -199,10 +199,10 @@ const init = async (data: any) => {
state.form.dbBackupId = data.dbBackupId;
state.form.dbBackupHistoryId = data.dbBackupHistoryId;
state.form.dbBackupHistoryName = data.dbBackupHistoryName;
- if (data.dbBackupHistoryId > 0) {
- state.restoreMode = 'backup-history';
- } else {
+ if (data.pointInTime) {
state.restoreMode = 'point-in-time';
+ } else {
+ state.restoreMode = 'backup-history';
}
state.history = {
dbBackupId: data.dbBackupId,
@@ -232,35 +232,33 @@ const getDbNamesWithoutRestore = async () => {
const btnOk = async () => {
restoreForm.value.validate(async (valid: any) => {
- if (!valid) {
+ if (valid) {
+ if (state.restoreMode == 'point-in-time') {
+ state.form.dbBackupId = 0;
+ state.form.dbBackupHistoryId = 0;
+ state.form.dbBackupHistoryName = '';
+ } else {
+ state.form.pointInTime = null;
+ }
+ state.form.repeated = false;
+ const reqForm = { ...state.form };
+ let api = dbApi.createDbRestore;
+ if (props.data) {
+ api = dbApi.saveDbRestore;
+ }
+ api.request(reqForm).then(() => {
+ ElMessage.success('保存成功');
+ emit('val-change', state.form);
+ state.btnLoading = true;
+ setTimeout(() => {
+ state.btnLoading = false;
+ }, 1000);
+ cancel();
+ });
+ } else {
ElMessage.error('请正确填写信息');
return false;
}
-
- if (state.restoreMode == 'point-in-time') {
- state.form.dbBackupId = 0;
- state.form.dbBackupHistoryId = 0;
- state.form.dbBackupHistoryName = '';
- } else {
- state.form.pointInTime = '0001-01-01T00:00:00Z';
- }
-
- state.form.repeated = false;
- const reqForm = { ...state.form };
- let api = dbApi.createDbRestore;
- if (props.data) {
- api = dbApi.saveDbRestore;
- }
-
- try {
- state.btnLoading = true;
- await api.request(reqForm);
- ElMessage.success('保存成功');
- emit('val-change', state.form);
- cancel();
- } finally {
- state.btnLoading = false;
- }
});
};
diff --git a/mayfly_go_web/src/views/ops/db/DbRestoreList.vue b/mayfly_go_web/src/views/ops/db/DbRestoreList.vue
index 137f79e6..bbbad0b4 100644
--- a/mayfly_go_web/src/views/ops/db/DbRestoreList.vue
+++ b/mayfly_go_web/src/views/ops/db/DbRestoreList.vue
@@ -25,8 +25,8 @@
详情
- 启用
- 禁用
+ 启用
+ 禁用
@@ -42,10 +42,10 @@
{{ infoDialog.data.dbName }}
- {{
+ {{
dateFormat(infoDialog.data.pointInTime)
}}
- {{
+ {{
infoDialog.data.dbBackupHistoryName
}}
{{ dateFormat(infoDialog.data.startTime) }}
diff --git a/mayfly_go_web/src/views/ops/db/api.ts b/mayfly_go_web/src/views/ops/db/api.ts
index 58a8a4e7..84b37197 100644
--- a/mayfly_go_web/src/views/ops/db/api.ts
+++ b/mayfly_go_web/src/views/ops/db/api.ts
@@ -47,6 +47,7 @@ export const dbApi = {
getDbNamesWithoutBackup: Api.newGet('/dbs/{dbId}/db-names-without-backup'),
enableDbBackup: Api.newPut('/dbs/{dbId}/backups/{backupId}/enable'),
disableDbBackup: Api.newPut('/dbs/{dbId}/backups/{backupId}/disable'),
+ startDbBackup: Api.newPut('/dbs/{dbId}/backups/{backupId}/start'),
saveDbBackup: Api.newPut('/dbs/{dbId}/backups/{id}'),
getDbBackupHistories: Api.newGet('/dbs/{dbId}/backup-histories'),
diff --git a/server/.gitignore b/server/.gitignore
index 900776b3..6ca34a6b 100644
--- a/server/.gitignore
+++ b/server/.gitignore
@@ -1,4 +1,11 @@
-static/static
+/static/static/
config.yml
mayfly_rsa
-mayfly_rsa.pub
\ No newline at end of file
+mayfly_rsa.pub
+
+# 数据库备份目录
+/db/backup/
+# mysql 程序目录
+/db/mysql/
+# mariadb 程序目录
+/db/mariadb/
\ No newline at end of file
diff --git a/server/internal/auth/api/account_login.go b/server/internal/auth/api/account_login.go
index dd8b23d4..6663d115 100644
--- a/server/internal/auth/api/account_login.go
+++ b/server/internal/auth/api/account_login.go
@@ -107,7 +107,7 @@ func (a *AccountLogin) OtpVerify(rc *req.Ctx) {
if otpStatus == OtpStatusNoReg {
update := &sysentity.Account{OtpSecret: otpSecret}
update.Id = accountId
- update.OtpSecretEncrypt()
+ biz.ErrIsNil(update.OtpSecretEncrypt())
biz.ErrIsNil(a.AccountApp.Update(context.Background(), update))
}
diff --git a/server/internal/auth/api/common.go b/server/internal/auth/api/common.go
index ba423a93..9457f8d6 100644
--- a/server/internal/auth/api/common.go
+++ b/server/internal/auth/api/common.go
@@ -65,7 +65,7 @@ func LastLoginCheck(account *sysentity.Account, accountLoginSecurity *config.Acc
}
func useOtp(account *sysentity.Account, otpIssuer, accessToken string) (*OtpVerifyInfo, string, string) {
- account.OtpSecretDecrypt()
+ biz.ErrIsNil(account.OtpSecretDecrypt())
otpSecret := account.OtpSecret
// 修改状态为已注册
otpStatus := OtpStatusReg
diff --git a/server/internal/common/utils/pwd.go b/server/internal/common/utils/pwd.go
index fd0289f9..6dd48614 100644
--- a/server/internal/common/utils/pwd.go
+++ b/server/internal/common/utils/pwd.go
@@ -1,7 +1,6 @@
package utils
import (
- "mayfly-go/pkg/biz"
"mayfly-go/pkg/config"
"regexp"
)
@@ -27,30 +26,34 @@ func CheckAccountPasswordLever(ps string) bool {
}
// 使用config.yml的aes.key进行密码加密
-func PwdAesEncrypt(password string) string {
+func PwdAesEncrypt(password string) (string, error) {
if password == "" {
- return ""
+ return "", nil
}
aes := config.Conf.Aes
if aes.Key == "" {
- return password
+ return password, nil
}
encryptPwd, err := aes.EncryptBase64([]byte(password))
- biz.ErrIsNilAppendErr(err, "密码加密失败: %s")
- return encryptPwd
+ if err != nil {
+ return "", err
+ }
+ return encryptPwd, nil
}
// 使用config.yml的aes.key进行密码解密
-func PwdAesDecrypt(encryptPwd string) string {
+func PwdAesDecrypt(encryptPwd string) (string, error) {
if encryptPwd == "" {
- return ""
+ return "", nil
}
aes := config.Conf.Aes
if aes.Key == "" {
- return encryptPwd
+ return encryptPwd, nil
}
decryptPwd, err := aes.DecryptBase64(encryptPwd)
- biz.ErrIsNilAppendErr(err, "密码解密失败: %s")
+ if err != nil {
+ return "", err
+ }
// 解密后的密码
- return string(decryptPwd)
+ return string(decryptPwd), nil
}
diff --git a/server/internal/db/api/db_backup.go b/server/internal/db/api/db_backup.go
index ac05c185..f5883cdf 100644
--- a/server/internal/db/api/db_backup.go
+++ b/server/internal/db/api/db_backup.go
@@ -39,11 +39,11 @@ func (d *DbBackup) GetPageList(rc *req.Ctx) {
// Create 保存数据库备份任务
// @router /api/dbs/:dbId/backups [POST]
func (d *DbBackup) Create(rc *req.Ctx) {
- form := &form.DbBackupForm{}
- ginx.BindJsonAndValid(rc.GinCtx, form)
- rc.ReqParam = form
+ backupForm := &form.DbBackupForm{}
+ ginx.BindJsonAndValid(rc.GinCtx, backupForm)
+ rc.ReqParam = backupForm
- dbNames := strings.Fields(form.DbNames)
+ dbNames := strings.Fields(backupForm.DbNames)
biz.IsTrue(len(dbNames) > 0, "解析数据库备份任务失败:数据库名称未定义")
dbId := uint64(ginx.PathParamInt(rc.GinCtx, "dbId"))
@@ -54,14 +54,10 @@ func (d *DbBackup) Create(rc *req.Ctx) {
tasks := make([]*entity.DbBackup, 0, len(dbNames))
for _, dbName := range dbNames {
task := &entity.DbBackup{
+ DbTaskBase: entity.NewDbBTaskBase(true, backupForm.Repeated, backupForm.StartTime, backupForm.Interval),
DbName: dbName,
- Name: form.Name,
- StartTime: form.StartTime,
- Interval: form.Interval,
- Enabled: true,
- Repeated: form.Repeated,
+ Name: backupForm.Name,
DbInstanceId: db.InstanceId,
- LastTime: form.StartTime,
}
tasks = append(tasks, task)
}
@@ -71,17 +67,15 @@ func (d *DbBackup) Create(rc *req.Ctx) {
// Save 保存数据库备份任务
// @router /api/dbs/:dbId/backups/:backupId [PUT]
func (d *DbBackup) Save(rc *req.Ctx) {
- form := &form.DbBackupForm{}
- ginx.BindJsonAndValid(rc.GinCtx, form)
- rc.ReqParam = form
+ backupForm := &form.DbBackupForm{}
+ ginx.BindJsonAndValid(rc.GinCtx, backupForm)
+ rc.ReqParam = backupForm
- task := &entity.DbBackup{
- Name: form.Name,
- StartTime: form.StartTime,
- Interval: form.Interval,
- LastTime: form.StartTime,
- }
- task.Id = form.Id
+ task := &entity.DbBackup{}
+ task.Id = backupForm.Id
+ task.Name = backupForm.Name
+ task.StartTime = backupForm.StartTime
+ task.Interval = backupForm.Interval
biz.ErrIsNilAppendErr(d.DbBackupApp.Save(rc.MetaCtx, task), "保存数据库备份任务失败: %v")
}
@@ -125,6 +119,13 @@ func (d *DbBackup) Disable(rc *req.Ctx) {
biz.ErrIsNilAppendErr(err, "禁用数据库备份任务失败: %v")
}
+// Start 禁用数据库备份任务
+// @router /api/dbs/:dbId/backups/:taskId/start [PUT]
+func (d *DbBackup) Start(rc *req.Ctx) {
+ err := d.walk(rc, d.DbBackupApp.Start)
+ biz.ErrIsNilAppendErr(err, "运行数据库备份任务失败: %v")
+}
+
// GetDbNamesWithoutBackup 获取未配置定时备份的数据库名称
// @router /api/dbs/:dbId/db-names-without-backup [GET]
func (d *DbBackup) GetDbNamesWithoutBackup(rc *req.Ctx) {
diff --git a/server/internal/db/api/db_restore.go b/server/internal/db/api/db_restore.go
index 8af0660d..4146aafb 100644
--- a/server/internal/db/api/db_restore.go
+++ b/server/internal/db/api/db_restore.go
@@ -38,9 +38,9 @@ func (d *DbRestore) GetPageList(rc *req.Ctx) {
// Create 保存数据库恢复任务
// @router /api/dbs/:dbId/restores [POST]
func (d *DbRestore) Create(rc *req.Ctx) {
- form := &form.DbRestoreForm{}
- ginx.BindJsonAndValid(rc.GinCtx, form)
- rc.ReqParam = form
+ restoreForm := &form.DbRestoreForm{}
+ ginx.BindJsonAndValid(rc.GinCtx, restoreForm)
+ rc.ReqParam = restoreForm
dbId := uint64(ginx.PathParamInt(rc.GinCtx, "dbId"))
biz.IsTrue(dbId > 0, "无效的 dbId: %v", dbId)
@@ -48,16 +48,13 @@ func (d *DbRestore) Create(rc *req.Ctx) {
biz.ErrIsNilAppendErr(err, "获取数据库信息失败: %v")
task := &entity.DbRestore{
- DbName: form.DbName,
- StartTime: form.StartTime,
- Interval: form.Interval,
- Enabled: true,
- Repeated: form.Repeated,
+ DbTaskBase: entity.NewDbBTaskBase(true, restoreForm.Repeated, restoreForm.StartTime, restoreForm.Interval),
+ DbName: restoreForm.DbName,
DbInstanceId: db.InstanceId,
- PointInTime: form.PointInTime,
- DbBackupId: form.DbBackupId,
- DbBackupHistoryId: form.DbBackupHistoryId,
- DbBackupHistoryName: form.DbBackupHistoryName,
+ PointInTime: restoreForm.PointInTime,
+ DbBackupId: restoreForm.DbBackupId,
+ DbBackupHistoryId: restoreForm.DbBackupHistoryId,
+ DbBackupHistoryName: restoreForm.DbBackupHistoryName,
}
biz.ErrIsNilAppendErr(d.DbRestoreApp.Create(rc.MetaCtx, task), "添加数据库恢复任务失败: %v")
}
@@ -65,15 +62,14 @@ func (d *DbRestore) Create(rc *req.Ctx) {
// Save 保存数据库恢复任务
// @router /api/dbs/:dbId/restores/:restoreId [PUT]
func (d *DbRestore) Save(rc *req.Ctx) {
- form := &form.DbRestoreForm{}
- ginx.BindJsonAndValid(rc.GinCtx, form)
- rc.ReqParam = form
+ restoreForm := &form.DbRestoreForm{}
+ ginx.BindJsonAndValid(rc.GinCtx, restoreForm)
+ rc.ReqParam = restoreForm
- task := &entity.DbRestore{
- StartTime: form.StartTime,
- Interval: form.Interval,
- }
- task.Id = form.Id
+ task := &entity.DbRestore{}
+ task.Id = restoreForm.Id
+ task.StartTime = restoreForm.StartTime
+ task.Interval = restoreForm.Interval
biz.ErrIsNilAppendErr(d.DbRestoreApp.Save(rc.MetaCtx, task), "保存数据库恢复任务失败: %v")
}
diff --git a/server/internal/db/api/form/db_restore.go b/server/internal/db/api/form/db_restore.go
index 6348bbbf..936f7583 100644
--- a/server/internal/db/api/form/db_restore.go
+++ b/server/internal/db/api/form/db_restore.go
@@ -2,21 +2,22 @@ package form
import (
"encoding/json"
+ "mayfly-go/pkg/utils/timex"
"time"
)
// DbRestoreForm 数据库备份表单
type DbRestoreForm struct {
- Id uint64 `json:"id"`
- DbName string `binding:"required" json:"dbName"` // 数据库名
- StartTime time.Time `binding:"required" json:"startTime"` // 开始时间: 2023-11-08 02:00:00
- PointInTime time.Time `json:"PointInTime"` // 指定时间
- DbBackupId uint64 `json:"dbBackupId"` // 数据库备份任务ID
- DbBackupHistoryId uint64 `json:"dbBackupHistoryId"` // 数据库备份历史ID
- DbBackupHistoryName string `json:"dbBackupHistoryName"` // 数据库备份历史名称
- Interval time.Duration `json:"-"` // 间隔时间: 为零表示单次执行,为正表示反复执行
- IntervalDay uint64 `json:"intervalDay"` // 间隔天数: 为零表示单次执行,为正表示反复执行
- Repeated bool `json:"repeated"` // 是否重复执行
+ Id uint64 `json:"id"`
+ DbName string `binding:"required" json:"dbName"` // 数据库名
+ StartTime time.Time `binding:"required" json:"startTime"` // 开始时间: 2023-11-08 02:00:00
+ PointInTime timex.NullTime `json:"pointInTime"` // 指定时间
+ DbBackupId uint64 `json:"dbBackupId"` // 数据库备份任务ID
+ DbBackupHistoryId uint64 `json:"dbBackupHistoryId"` // 数据库备份历史ID
+ DbBackupHistoryName string `json:"dbBackupHistoryName"` // 数据库备份历史名称
+ Interval time.Duration `json:"-"` // 间隔时间: 为零表示单次执行,为正表示反复执行
+ IntervalDay uint64 `json:"intervalDay"` // 间隔天数: 为零表示单次执行,为正表示反复执行
+ Repeated bool `json:"repeated"` // 是否重复执行
}
func (restore *DbRestoreForm) UnmarshalJSON(data []byte) error {
diff --git a/server/internal/db/api/instance.go b/server/internal/db/api/instance.go
index 6853abea..84af750c 100644
--- a/server/internal/db/api/instance.go
+++ b/server/internal/db/api/instance.go
@@ -74,7 +74,7 @@ func (d *Instance) GetInstancePwd(rc *req.Ctx) {
instanceId := getInstanceId(rc.GinCtx)
instanceEntity, err := d.InstanceApp.GetById(new(entity.DbInstance), instanceId, "Password")
biz.ErrIsNil(err, "获取数据库实例错误")
- instanceEntity.PwdDecrypt()
+ biz.ErrIsNil(instanceEntity.PwdDecrypt())
rc.ResData = instanceEntity.Password
}
@@ -105,7 +105,7 @@ func (d *Instance) GetDatabaseNames(rc *req.Ctx) {
instanceId := getInstanceId(rc.GinCtx)
instance, err := d.InstanceApp.GetById(new(entity.DbInstance), instanceId, "Password")
biz.ErrIsNil(err, "获取数据库实例错误")
- instance.PwdDecrypt()
+ biz.ErrIsNil(instance.PwdDecrypt())
res, err := d.InstanceApp.GetDatabases(instance)
biz.ErrIsNil(err)
rc.ResData = res
diff --git a/server/internal/db/api/vo/db_backup.go b/server/internal/db/api/vo/db_backup.go
index 22052cb4..4c66d6cc 100644
--- a/server/internal/db/api/vo/db_backup.go
+++ b/server/internal/db/api/vo/db_backup.go
@@ -2,27 +2,28 @@ package vo
import (
"encoding/json"
+ "mayfly-go/pkg/utils/timex"
"time"
)
-// DbBackupHistory 数据库备份任务
+// DbBackup 数据库备份任务
type DbBackup struct {
- Id uint64 `json:"id"`
- DbName string `json:"dbName"` // 数据库名
- CreateTime time.Time `json:"createTime"` // 创建时间: 2023-11-08 02:00:00
- StartTime time.Time `json:"startTime"` // 开始时间: 2023-11-08 02:00:00
- Interval time.Duration `json:"-"` // 间隔时间: 为零表示单次执行,为正表示反复执行
- IntervalDay uint64 `json:"intervalDay" gorm:"-"` // 间隔天数: 为零表示单次执行,为正表示反复执行
- Enabled bool `json:"enabled"` // 是否启用
- LastTime time.Time `json:"lastTime"` // 最近一次执行时间: 2023-11-08 02:00:00
- LastStatus string `json:"lastStatus"` // 最近一次执行状态
- LastResult string `json:"lastResult"` // 最近一次执行结果
- DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
- Name string `json:"name"` // 备份任务名称
+ Id uint64 `json:"id"`
+ DbName string `json:"dbName"` // 数据库名
+ CreateTime time.Time `json:"createTime"` // 创建时间
+ StartTime time.Time `json:"startTime"` // 开始时间
+ Interval time.Duration `json:"-"` // 间隔时间
+ IntervalDay uint64 `json:"intervalDay" gorm:"-"` // 间隔天数
+ Enabled bool `json:"enabled"` // 是否启用
+ LastTime timex.NullTime `json:"lastTime"` // 最近一次执行时间
+ LastStatus string `json:"lastStatus"` // 最近一次执行状态
+ LastResult string `json:"lastResult"` // 最近一次执行结果
+ DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
+ Name string `json:"name"` // 备份任务名称
}
-func (restore *DbBackup) MarshalJSON() ([]byte, error) {
+func (backup *DbBackup) MarshalJSON() ([]byte, error) {
type dbBackup DbBackup
- restore.IntervalDay = uint64(restore.Interval / time.Hour / 24)
- return json.Marshal((*dbBackup)(restore))
+ backup.IntervalDay = uint64(backup.Interval / time.Hour / 24)
+ return json.Marshal((*dbBackup)(backup))
}
diff --git a/server/internal/db/api/vo/db_restore.go b/server/internal/db/api/vo/db_restore.go
index 4f50cf3b..c57bf2fe 100644
--- a/server/internal/db/api/vo/db_restore.go
+++ b/server/internal/db/api/vo/db_restore.go
@@ -2,25 +2,26 @@ package vo
import (
"encoding/json"
+ "mayfly-go/pkg/utils/timex"
"time"
)
// DbRestore 数据库备份任务
type DbRestore struct {
- Id uint64 `json:"id"`
- DbName string `json:"dbName"` // 数据库名
- StartTime time.Time `json:"startTime"` // 开始时间: 2023-11-08 02:00:00
- Interval time.Duration `json:"-"` // 间隔时间: 为零表示单次执行,为正表示反复执行
- IntervalDay uint64 `json:"intervalDay" gorm:"-"` // 间隔天数: 为零表示单次执行,为正表示反复执行
- Enabled bool `json:"enabled"` // 是否启用
- LastTime time.Time `json:"lastTime"` // 最近一次执行时间: 2023-11-08 02:00:00
- LastStatus string `json:"lastStatus"` // 最近一次执行状态
- LastResult string `json:"lastResult"` // 最近一次执行结果
- PointInTime time.Time `json:"pointInTime"` // 指定数据库恢复的时间点
- DbBackupId uint64 `json:"dbBackupId"` // 数据库备份任务ID
- DbBackupHistoryId uint64 `json:"dbBackupHistoryId"` // 数据库备份历史ID
- DbBackupHistoryName string `json:"dbBackupHistoryName"` // 数据库备份历史名称
- DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
+ Id uint64 `json:"id"`
+ DbName string `json:"dbName"` // 数据库名
+ StartTime time.Time `json:"startTime"` // 开始时间
+ Interval time.Duration `json:"-"` // 间隔时间
+ IntervalDay uint64 `json:"intervalDay" gorm:"-"` // 间隔天数
+ Enabled bool `json:"enabled"` // 是否启用
+ LastTime timex.NullTime `json:"lastTime"` // 最近一次执行时间
+ LastStatus string `json:"lastStatus"` // 最近一次执行状态
+ LastResult string `json:"lastResult"` // 最近一次执行结果
+ PointInTime timex.NullTime `json:"pointInTime"` // 指定数据库恢复的时间点
+ DbBackupId uint64 `json:"dbBackupId"` // 数据库备份任务ID
+ DbBackupHistoryId uint64 `json:"dbBackupHistoryId"` // 数据库备份历史ID
+ DbBackupHistoryName string `json:"dbBackupHistoryName"` // 数据库备份历史名称
+ DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
}
func (restore *DbRestore) MarshalJSON() ([]byte, error) {
diff --git a/server/internal/db/application/application.go b/server/internal/db/application/application.go
index a6f103eb..241ae04b 100644
--- a/server/internal/db/application/application.go
+++ b/server/internal/db/application/application.go
@@ -17,6 +17,7 @@ var (
dbBackupHistoryApp *DbBackupHistoryApp
dbRestoreApp *DbRestoreApp
dbRestoreHistoryApp *DbRestoreHistoryApp
+ dbBinlogApp *DbBinlogApp
)
var repositories *repository.Repositories
@@ -39,16 +40,26 @@ func Init() {
dbSqlExecApp = newDbSqlExecApp(persistence.GetDbSqlExecRepo())
dbSqlApp = newDbSqlApp(persistence.GetDbSqlRepo())
- dbBackupApp, err = newDbBackupApp(repositories)
+ dbBackupApp, err = newDbBackupApp(repositories, dbApp)
if err != nil {
panic(fmt.Sprintf("初始化 dbBackupApp 失败: %v", err))
}
- dbRestoreApp, err = newDbRestoreApp(repositories)
+ dbRestoreApp, err = newDbRestoreApp(repositories, dbApp)
if err != nil {
panic(fmt.Sprintf("初始化 dbRestoreApp 失败: %v", err))
}
dbBackupHistoryApp, err = newDbBackupHistoryApp(repositories)
+ if err != nil {
+ panic(fmt.Sprintf("初始化 dbBackupHistoryApp 失败: %v", err))
+ }
dbRestoreHistoryApp, err = newDbRestoreHistoryApp(repositories)
+ if err != nil {
+ panic(fmt.Sprintf("初始化 dbRestoreHistoryApp 失败: %v", err))
+ }
+ dbBinlogApp, err = newDbBinlogApp(repositories, dbApp)
+ if err != nil {
+ panic(fmt.Sprintf("初始化 dbBinlogApp 失败: %v", err))
+ }
})()
}
@@ -83,3 +94,7 @@ func GetDbRestoreApp() *DbRestoreApp {
func GetDbRestoreHistoryApp() *DbRestoreHistoryApp {
return dbRestoreHistoryApp
}
+
+func GetDbBinlogApp() *DbBinlogApp {
+ return dbBinlogApp
+}
diff --git a/server/internal/db/application/db.go b/server/internal/db/application/db.go
index ef711003..0af1a9cd 100644
--- a/server/internal/db/application/db.go
+++ b/server/internal/db/application/db.go
@@ -156,7 +156,7 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) {
checkDb := dbName
// 兼容pgsql/dm db/schema模式
- if instance.Type == dbm.DbTypePostgres || instance.Type == dbm.DM {
+ if dbm.DbTypePostgres.Equal(instance.Type) || dbm.DM.Equal(instance.Type) {
ss := strings.Split(dbName, "/")
if len(ss) > 1 {
checkDb = ss[0]
@@ -167,7 +167,9 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) {
}
// 密码解密
- instance.PwdDecrypt()
+ if err := instance.PwdDecrypt(); err != nil {
+ return nil, errorx.NewBiz(err.Error())
+ }
return toDbInfo(instance, dbId, dbName, d.tagApp.ListTagPathByResource(consts.TagResourceTypeDb, db.Code)...), nil
})
}
diff --git a/server/internal/db/application/db_backup.go b/server/internal/db/application/db_backup.go
index 48b4b7fd..85b2b701 100644
--- a/server/internal/db/application/db_backup.go
+++ b/server/internal/db/application/db_backup.go
@@ -2,62 +2,155 @@ package application
import (
"context"
+ "encoding/binary"
+ "fmt"
+ "github.com/google/uuid"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
- "mayfly-go/internal/db/domain/service"
- service2 "mayfly-go/internal/db/infrastructure/service"
"mayfly-go/pkg/model"
+ "time"
)
-func newDbBackupApp(repositories *repository.Repositories) (*DbBackupApp, error) {
- binlogSvc, err := service2.NewDbBinlogSvc(repositories)
- if err != nil {
- return nil, err
- }
- dbBackupSvc, err := service2.NewDbBackupSvc(repositories, binlogSvc)
- if err != nil {
- return nil, err
- }
-
+func newDbBackupApp(repositories *repository.Repositories, dbApp Db) (*DbBackupApp, error) {
app := &DbBackupApp{
- repo: repositories.Backup,
- dbBackupSvc: dbBackupSvc,
+ backupRepo: repositories.Backup,
+ instanceRepo: repositories.Instance,
+ backupHistoryRepo: repositories.BackupHistory,
+ dbApp: dbApp,
}
+ scheduler, err := newDbScheduler[*entity.DbBackup](
+ repositories.Backup,
+ withRunBackupTask(app))
+ if err != nil {
+ return nil, err
+ }
+ app.scheduler = scheduler
return app, nil
}
type DbBackupApp struct {
- repo repository.DbBackup
- dbBackupSvc service.DbBackupSvc
+ backupRepo repository.DbBackup
+ instanceRepo repository.Instance
+ backupHistoryRepo repository.DbBackupHistory
+ dbApp Db
+ scheduler *dbScheduler[*entity.DbBackup]
+}
+
+func (app *DbBackupApp) Close() {
+ app.scheduler.Close()
}
func (app *DbBackupApp) Create(ctx context.Context, tasks ...*entity.DbBackup) error {
- return app.dbBackupSvc.AddTask(ctx, tasks...)
+ return app.scheduler.AddTask(ctx, tasks...)
}
func (app *DbBackupApp) Save(ctx context.Context, task *entity.DbBackup) error {
- return app.dbBackupSvc.UpdateTask(ctx, task)
+ return app.scheduler.UpdateTask(ctx, task)
}
func (app *DbBackupApp) Delete(ctx context.Context, taskId uint64) error {
// todo: 删除数据库备份历史文件
- return app.dbBackupSvc.DeleteTask(ctx, taskId)
+ return app.scheduler.DeleteTask(ctx, taskId)
}
func (app *DbBackupApp) Enable(ctx context.Context, taskId uint64) error {
- return app.dbBackupSvc.EnableTask(ctx, taskId)
+ return app.scheduler.EnableTask(ctx, taskId)
}
func (app *DbBackupApp) Disable(ctx context.Context, taskId uint64) error {
- return app.dbBackupSvc.DisableTask(ctx, taskId)
+ return app.scheduler.DisableTask(ctx, taskId)
+}
+
+func (app *DbBackupApp) Start(ctx context.Context, taskId uint64) error {
+ return app.scheduler.StartTask(ctx, taskId)
}
// GetPageList 分页获取数据库备份任务
func (app *DbBackupApp) GetPageList(condition *entity.DbBackupQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
- return app.repo.GetDbBackupList(condition, pageParam, toEntity, orderBy...)
+ return app.backupRepo.GetDbBackupList(condition, pageParam, toEntity, orderBy...)
}
// GetDbNamesWithoutBackup 获取未配置定时备份的数据库名称
func (app *DbBackupApp) GetDbNamesWithoutBackup(instanceId uint64, dbNames []string) ([]string, error) {
- return app.repo.GetDbNamesWithoutBackup(instanceId, dbNames)
+ return app.backupRepo.GetDbNamesWithoutBackup(instanceId, dbNames)
+}
+
+func withRunBackupTask(app *DbBackupApp) dbSchedulerOption[*entity.DbBackup] {
+ return func(scheduler *dbScheduler[*entity.DbBackup]) {
+ scheduler.RunTask = app.runTask
+ }
+}
+
+func (app *DbBackupApp) runTask(ctx context.Context, task *entity.DbBackup) error {
+ id, err := NewIncUUID()
+ if err != nil {
+ return err
+ }
+ history := &entity.DbBackupHistory{
+ Uuid: id.String(),
+ DbBackupId: task.Id,
+ DbInstanceId: task.DbInstanceId,
+ DbName: task.DbName,
+ }
+ conn, err := app.dbApp.GetDbConnByInstanceId(task.DbInstanceId)
+ if err != nil {
+ return err
+ }
+ dbProgram := conn.GetDialect().GetDbProgram()
+ binlogInfo, err := dbProgram.Backup(ctx, history)
+ if err != nil {
+ return err
+ }
+ now := time.Now()
+ name := task.Name
+ if len(name) == 0 {
+ name = task.DbName
+ }
+ history.Name = fmt.Sprintf("%s[%s]", name, now.Format(time.DateTime))
+ history.CreateTime = now
+ history.BinlogFileName = binlogInfo.FileName
+ history.BinlogSequence = binlogInfo.Sequence
+ history.BinlogPosition = binlogInfo.Position
+
+ if err := app.backupHistoryRepo.Insert(ctx, history); err != nil {
+ return err
+ }
+ return nil
+}
+
+func NewIncUUID() (uuid.UUID, error) {
+ var uid uuid.UUID
+ now, seq, err := uuid.GetTime()
+ if err != nil {
+ return uid, err
+ }
+ timeHi := uint32((now >> 28) & 0xffffffff)
+ timeMid := uint16((now >> 12) & 0xffff)
+ timeLow := uint16(now & 0x0fff)
+ timeLow |= 0x1000 // Version 1
+
+ binary.BigEndian.PutUint32(uid[0:], timeHi)
+ binary.BigEndian.PutUint16(uid[4:], timeMid)
+ binary.BigEndian.PutUint16(uid[6:], timeLow)
+ binary.BigEndian.PutUint16(uid[8:], seq)
+
+ copy(uid[10:], uuid.NodeID())
+
+ return uid, nil
+}
+
+func newDbBackupHistoryApp(repositories *repository.Repositories) (*DbBackupHistoryApp, error) {
+ app := &DbBackupHistoryApp{
+ repo: repositories.BackupHistory,
+ }
+ return app, nil
+}
+
+type DbBackupHistoryApp struct {
+ repo repository.DbBackupHistory
+}
+
+// GetPageList 分页获取数据库备份历史
+func (app *DbBackupHistoryApp) GetPageList(condition *entity.DbBackupHistoryQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
+ return app.repo.GetHistories(condition, pageParam, toEntity, orderBy...)
}
diff --git a/server/internal/db/application/db_backup_history.go b/server/internal/db/application/db_backup_history.go
deleted file mode 100644
index 3724d207..00000000
--- a/server/internal/db/application/db_backup_history.go
+++ /dev/null
@@ -1,23 +0,0 @@
-package application
-
-import (
- "mayfly-go/internal/db/domain/entity"
- "mayfly-go/internal/db/domain/repository"
- "mayfly-go/pkg/model"
-)
-
-func newDbBackupHistoryApp(repositories *repository.Repositories) (*DbBackupHistoryApp, error) {
- app := &DbBackupHistoryApp{
- repo: repositories.BackupHistory,
- }
- return app, nil
-}
-
-type DbBackupHistoryApp struct {
- repo repository.DbBackupHistory
-}
-
-// GetPageList 分页获取数据库备份历史
-func (app *DbBackupHistoryApp) GetPageList(condition *entity.DbBackupHistoryQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
- return app.repo.GetHistories(condition, pageParam, toEntity, orderBy...)
-}
diff --git a/server/internal/db/application/db_binlog.go b/server/internal/db/application/db_binlog.go
new file mode 100644
index 00000000..310d1ed3
--- /dev/null
+++ b/server/internal/db/application/db_binlog.go
@@ -0,0 +1,154 @@
+package application
+
+import (
+ "context"
+ "fmt"
+ "mayfly-go/internal/db/domain/entity"
+ "mayfly-go/internal/db/domain/repository"
+ "mayfly-go/pkg/logx"
+ "mayfly-go/pkg/utils/stringx"
+ "mayfly-go/pkg/utils/timex"
+ "sync"
+ "time"
+)
+
+const (
+ binlogDownloadInterval = time.Minute * 15
+)
+
+type DbBinlogApp struct {
+ binlogRepo repository.DbBinlog
+ binlogHistoryRepo repository.DbBinlogHistory
+ backupRepo repository.DbBackup
+ backupHistoryRepo repository.DbBackupHistory
+ dbApp Db
+ context context.Context
+ cancel context.CancelFunc
+ waitGroup sync.WaitGroup
+}
+
+var (
+ binlogResult = map[entity.TaskStatus]string{
+ entity.TaskDelay: "等待备份BINLOG",
+ entity.TaskReady: "准备备份BINLOG",
+ entity.TaskReserved: "BINLOG备份中",
+ entity.TaskSuccess: "BINLOG备份成功",
+ entity.TaskFailed: "BINLOG备份失败",
+ }
+)
+
+func newDbBinlogApp(repositories *repository.Repositories, dbApp Db) (*DbBinlogApp, error) {
+ ctx, cancel := context.WithCancel(context.Background())
+ svc := &DbBinlogApp{
+ binlogRepo: repositories.Binlog,
+ binlogHistoryRepo: repositories.BinlogHistory,
+ backupRepo: repositories.Backup,
+ backupHistoryRepo: repositories.BackupHistory,
+ dbApp: dbApp,
+ context: ctx,
+ cancel: cancel,
+ }
+ svc.waitGroup.Add(1)
+ go svc.run()
+ return svc, nil
+}
+
+func (app *DbBinlogApp) runTask(ctx context.Context, backup *entity.DbBackup) error {
+ if err := app.AddTaskIfNotExists(ctx, entity.NewDbBinlog(backup.DbInstanceId)); err != nil {
+ return err
+ }
+ latestBinlogSequence, earliestBackupSequence := int64(-1), int64(-1)
+ binlogHistory, ok, err := app.binlogHistoryRepo.GetLatestHistory(backup.DbInstanceId)
+ if err != nil {
+ return err
+ }
+ if ok {
+ latestBinlogSequence = binlogHistory.Sequence
+ } else {
+ backupHistory, err := app.backupHistoryRepo.GetEarliestHistory(backup.DbInstanceId)
+ if err != nil {
+ return err
+ }
+ earliestBackupSequence = backupHistory.BinlogSequence
+ }
+ conn, err := app.dbApp.GetDbConnByInstanceId(backup.DbInstanceId)
+ if err != nil {
+ return err
+ }
+ dbProgram := conn.GetDialect().GetDbProgram()
+ binlogFiles, err := dbProgram.FetchBinlogs(ctx, false, earliestBackupSequence, latestBinlogSequence)
+ if err == nil {
+ err = app.binlogHistoryRepo.InsertWithBinlogFiles(ctx, backup.DbInstanceId, binlogFiles)
+ }
+ taskStatus := entity.TaskSuccess
+ if err != nil {
+ taskStatus = entity.TaskFailed
+ }
+ task := &entity.DbBinlog{}
+ task.Id = backup.DbInstanceId
+ return app.updateCurTask(ctx, taskStatus, err, task)
+}
+
+func (app *DbBinlogApp) run() {
+ defer app.waitGroup.Done()
+
+ for !app.closed() {
+ app.fetchFromAllInstances()
+ timex.SleepWithContext(app.context, binlogDownloadInterval)
+ }
+}
+
+func (app *DbBinlogApp) fetchFromAllInstances() {
+ tasks, err := app.backupRepo.ListRepeating()
+ if err != nil {
+ logx.Errorf("DbBinlogApp: 获取数据库备份任务失败: %s", err.Error())
+ return
+ }
+ for _, task := range tasks {
+ if app.closed() {
+ break
+ }
+ if err := app.runTask(app.context, task); err != nil {
+ logx.Errorf("DbBinlogApp: 下载 binlog 文件失败: %s", err.Error())
+ return
+ }
+ }
+}
+
+func (app *DbBinlogApp) Close() {
+ app.cancel()
+ app.waitGroup.Wait()
+}
+
+func (app *DbBinlogApp) closed() bool {
+ return app.context.Err() != nil
+}
+
+func (app *DbBinlogApp) AddTaskIfNotExists(ctx context.Context, task *entity.DbBinlog) error {
+ if err := app.binlogRepo.AddTaskIfNotExists(ctx, task); err != nil {
+ return err
+ }
+ if task.Id == 0 {
+ return nil
+ }
+ return nil
+}
+
+func (app *DbBinlogApp) DeleteTask(ctx context.Context, taskId uint64) error {
+ // todo: 删除 Binlog 历史文件
+ if err := app.binlogRepo.DeleteById(ctx, taskId); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (app *DbBinlogApp) updateCurTask(ctx context.Context, status entity.TaskStatus, lastErr error, task *entity.DbBinlog) error {
+ task.LastStatus = status
+ var result = binlogResult[status]
+ if lastErr != nil {
+ result = fmt.Sprintf("%v: %v", binlogResult[status], lastErr)
+ }
+ task.LastResult = stringx.TruncateStr(result, entity.LastResultSize)
+ task.LastTime = timex.NewNullTime(time.Now())
+ return app.binlogRepo.UpdateById(ctx, task, "last_status", "last_result", "last_time")
+}
diff --git a/server/internal/db/application/db_restore.go b/server/internal/db/application/db_restore.go
index c746fc4b..cc8b0357 100644
--- a/server/internal/db/application/db_restore.go
+++ b/server/internal/db/application/db_restore.go
@@ -2,57 +2,190 @@ package application
import (
"context"
+ "mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
- "mayfly-go/internal/db/domain/service"
- serviceImpl "mayfly-go/internal/db/infrastructure/service"
"mayfly-go/pkg/model"
+ "time"
)
-func newDbRestoreApp(repositories *repository.Repositories) (*DbRestoreApp, error) {
- dbRestoreSvc, err := serviceImpl.NewDbRestoreSvc(repositories)
+func newDbRestoreApp(repositories *repository.Repositories, dbApp Db) (*DbRestoreApp, error) {
+ app := &DbRestoreApp{
+ restoreRepo: repositories.Restore,
+ instanceRepo: repositories.Instance,
+ backupHistoryRepo: repositories.BackupHistory,
+ restoreHistoryRepo: repositories.RestoreHistory,
+ binlogHistoryRepo: repositories.BinlogHistory,
+ dbApp: dbApp,
+ }
+ scheduler, err := newDbScheduler[*entity.DbRestore](
+ repositories.Restore,
+ withRunRestoreTask(app))
if err != nil {
return nil, err
}
- app := &DbRestoreApp{
- repo: repositories.Restore,
- dbRestoreSvc: dbRestoreSvc,
- }
+ app.scheduler = scheduler
return app, nil
}
type DbRestoreApp struct {
- repo repository.DbRestore
- dbRestoreSvc service.DbRestoreSvc
+ restoreRepo repository.DbRestore
+ instanceRepo repository.Instance
+ backupHistoryRepo repository.DbBackupHistory
+ restoreHistoryRepo repository.DbRestoreHistory
+ binlogHistoryRepo repository.DbBinlogHistory
+ dbApp Db
+ scheduler *dbScheduler[*entity.DbRestore]
+}
+
+func (app *DbRestoreApp) Close() {
+ app.scheduler.Close()
}
func (app *DbRestoreApp) Create(ctx context.Context, tasks ...*entity.DbRestore) error {
- return app.dbRestoreSvc.AddTask(ctx, tasks...)
+ return app.scheduler.AddTask(ctx, tasks...)
}
func (app *DbRestoreApp) Save(ctx context.Context, task *entity.DbRestore) error {
- return app.dbRestoreSvc.UpdateTask(ctx, task)
+ return app.scheduler.UpdateTask(ctx, task)
}
func (app *DbRestoreApp) Delete(ctx context.Context, taskId uint64) error {
// todo: 删除数据库恢复历史文件
- return app.dbRestoreSvc.DeleteTask(ctx, taskId)
+ return app.scheduler.DeleteTask(ctx, taskId)
}
func (app *DbRestoreApp) Enable(ctx context.Context, taskId uint64) error {
- return app.dbRestoreSvc.EnableTask(ctx, taskId)
+ return app.scheduler.EnableTask(ctx, taskId)
}
func (app *DbRestoreApp) Disable(ctx context.Context, taskId uint64) error {
- return app.dbRestoreSvc.DisableTask(ctx, taskId)
+ return app.scheduler.DisableTask(ctx, taskId)
}
// GetPageList 分页获取数据库恢复任务
func (app *DbRestoreApp) GetPageList(condition *entity.DbRestoreQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
- return app.repo.GetDbRestoreList(condition, pageParam, toEntity, orderBy...)
+ return app.restoreRepo.GetDbRestoreList(condition, pageParam, toEntity, orderBy...)
}
// GetDbNamesWithoutRestore 获取未配置定时恢复的数据库名称
func (app *DbRestoreApp) GetDbNamesWithoutRestore(instanceId uint64, dbNames []string) ([]string, error) {
- return app.repo.GetDbNamesWithoutRestore(instanceId, dbNames)
+ return app.restoreRepo.GetDbNamesWithoutRestore(instanceId, dbNames)
+}
+
+func (app *DbRestoreApp) runTask(ctx context.Context, task *entity.DbRestore) error {
+ conn, err := app.dbApp.GetDbConnByInstanceId(task.DbInstanceId)
+ if err != nil {
+ return err
+ }
+ dbProgram := conn.GetDialect().GetDbProgram()
+ if task.PointInTime.Valid {
+ latestBinlogSequence, earliestBackupSequence := int64(-1), int64(-1)
+ binlogHistory, ok, err := app.binlogHistoryRepo.GetLatestHistory(task.DbInstanceId)
+ if err != nil {
+ return err
+ }
+ if ok {
+ latestBinlogSequence = binlogHistory.Sequence
+ } else {
+ backupHistory, err := app.backupHistoryRepo.GetEarliestHistory(task.DbInstanceId)
+ if err != nil {
+ return err
+ }
+ earliestBackupSequence = backupHistory.BinlogSequence
+ }
+ binlogFiles, err := dbProgram.FetchBinlogs(ctx, true, earliestBackupSequence, latestBinlogSequence)
+ if err != nil {
+ return err
+ }
+ if err := app.binlogHistoryRepo.InsertWithBinlogFiles(ctx, task.DbInstanceId, binlogFiles); err != nil {
+ return err
+ }
+ if err := app.restorePointInTime(ctx, dbProgram, task); err != nil {
+ return err
+ }
+ } else {
+ if err := app.restoreBackupHistory(ctx, dbProgram, task); err != nil {
+ return err
+ }
+ }
+
+ history := &entity.DbRestoreHistory{
+ CreateTime: time.Now(),
+ DbRestoreId: task.Id,
+ }
+ if err := app.restoreHistoryRepo.Insert(ctx, history); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (app *DbRestoreApp) restorePointInTime(ctx context.Context, program dbm.DbProgram, task *entity.DbRestore) error {
+ binlogHistory, err := app.binlogHistoryRepo.GetHistoryByTime(task.DbInstanceId, task.PointInTime.Time)
+ if err != nil {
+ return err
+ }
+ position, err := program.GetBinlogEventPositionAtOrAfterTime(ctx, binlogHistory.FileName, task.PointInTime.Time)
+ if err != nil {
+ return err
+ }
+ target := &entity.BinlogInfo{
+ FileName: binlogHistory.FileName,
+ Sequence: binlogHistory.Sequence,
+ Position: position,
+ }
+ backupHistory, err := app.backupHistoryRepo.GetLatestHistory(task.DbInstanceId, task.DbName, target)
+ if err != nil {
+ return err
+ }
+ start := &entity.BinlogInfo{
+ FileName: backupHistory.BinlogFileName,
+ Sequence: backupHistory.BinlogSequence,
+ Position: backupHistory.BinlogPosition,
+ }
+ binlogHistories, err := app.binlogHistoryRepo.GetHistories(task.DbInstanceId, start, target)
+ if err != nil {
+ return err
+ }
+ restoreInfo := &dbm.RestoreInfo{
+ BackupHistory: backupHistory,
+ BinlogHistories: binlogHistories,
+ StartPosition: backupHistory.BinlogPosition,
+ TargetPosition: target.Position,
+ TargetTime: task.PointInTime.Time,
+ }
+ if err := program.RestoreBackupHistory(ctx, backupHistory.DbName, backupHistory.DbBackupId, backupHistory.Uuid); err != nil {
+ return err
+ }
+ return program.ReplayBinlog(ctx, task.DbName, task.DbName, restoreInfo)
+}
+
+func (app *DbRestoreApp) restoreBackupHistory(ctx context.Context, program dbm.DbProgram, task *entity.DbRestore) error {
+ backupHistory := &entity.DbBackupHistory{}
+ if err := app.backupHistoryRepo.GetById(backupHistory, task.DbBackupHistoryId); err != nil {
+ return err
+ }
+ return program.RestoreBackupHistory(ctx, backupHistory.DbName, backupHistory.DbBackupId, backupHistory.Uuid)
+}
+
+func withRunRestoreTask(app *DbRestoreApp) dbSchedulerOption[*entity.DbRestore] {
+ return func(scheduler *dbScheduler[*entity.DbRestore]) {
+ scheduler.RunTask = app.runTask
+ }
+}
+
+func newDbRestoreHistoryApp(repositories *repository.Repositories) (*DbRestoreHistoryApp, error) {
+ app := &DbRestoreHistoryApp{
+ repo: repositories.RestoreHistory,
+ }
+ return app, nil
+}
+
+type DbRestoreHistoryApp struct {
+ repo repository.DbRestoreHistory
+}
+
+// GetPageList 分页获取数据库备份历史
+func (app *DbRestoreHistoryApp) GetPageList(condition *entity.DbRestoreHistoryQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
+ return app.repo.GetDbRestoreHistories(condition, pageParam, toEntity, orderBy...)
}
diff --git a/server/internal/db/application/db_restore_history.go b/server/internal/db/application/db_restore_history.go
deleted file mode 100644
index 30b0c914..00000000
--- a/server/internal/db/application/db_restore_history.go
+++ /dev/null
@@ -1,23 +0,0 @@
-package application
-
-import (
- "mayfly-go/internal/db/domain/entity"
- "mayfly-go/internal/db/domain/repository"
- "mayfly-go/pkg/model"
-)
-
-func newDbRestoreHistoryApp(repositories *repository.Repositories) (*DbRestoreHistoryApp, error) {
- app := &DbRestoreHistoryApp{
- repo: repositories.RestoreHistory,
- }
- return app, nil
-}
-
-type DbRestoreHistoryApp struct {
- repo repository.DbRestoreHistory
-}
-
-// GetPageList 分页获取数据库备份历史
-func (app *DbRestoreHistoryApp) GetPageList(condition *entity.DbRestoreHistoryQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
- return app.repo.GetDbRestoreHistories(condition, pageParam, toEntity, orderBy...)
-}
diff --git a/server/internal/db/application/db_scheduler.go b/server/internal/db/application/db_scheduler.go
new file mode 100644
index 00000000..fca76fae
--- /dev/null
+++ b/server/internal/db/application/db_scheduler.go
@@ -0,0 +1,235 @@
+package application
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "mayfly-go/internal/db/domain/entity"
+ "mayfly-go/internal/db/domain/repository"
+ "mayfly-go/pkg/queue"
+ "mayfly-go/pkg/utils/anyx"
+ "mayfly-go/pkg/utils/stringx"
+ "mayfly-go/pkg/utils/timex"
+ "sync"
+ "time"
+)
+
+const sleepAfterError = time.Minute
+
+type dbScheduler[T entity.DbTask] struct {
+ mutex sync.Mutex
+ waitGroup sync.WaitGroup
+ queue *queue.DelayQueue[T]
+ context context.Context
+ cancel context.CancelFunc
+ RunTask func(ctx context.Context, task T) error
+ taskRepo repository.DbTask[T]
+}
+
+type dbSchedulerOption[T entity.DbTask] func(*dbScheduler[T])
+
+func newDbScheduler[T entity.DbTask](taskRepo repository.DbTask[T], opts ...dbSchedulerOption[T]) (*dbScheduler[T], error) {
+ ctx, cancel := context.WithCancel(context.Background())
+ scheduler := &dbScheduler[T]{
+ taskRepo: taskRepo,
+ queue: queue.NewDelayQueue[T](0),
+ context: ctx,
+ cancel: cancel,
+ }
+ for _, opt := range opts {
+ opt(scheduler)
+ }
+ if scheduler.RunTask == nil {
+ return nil, errors.New("数据库任务调度器没有设置 RunTask")
+ }
+ if err := scheduler.loadTask(context.Background()); err != nil {
+ return nil, err
+ }
+ scheduler.waitGroup.Add(1)
+ go scheduler.run()
+ return scheduler, nil
+}
+
+func (s *dbScheduler[T]) updateTaskStatus(ctx context.Context, status entity.TaskStatus, lastErr error, task T) error {
+ base := task.GetTaskBase()
+ base.LastStatus = status
+ var result = task.MessageWithStatus(status)
+ if lastErr != nil {
+ result = fmt.Sprintf("%v: %v", result, lastErr)
+ }
+ base.LastResult = stringx.TruncateStr(result, entity.LastResultSize)
+ base.LastTime = timex.NewNullTime(time.Now())
+ return s.taskRepo.UpdateTaskStatus(ctx, task)
+}
+
+func (s *dbScheduler[T]) UpdateTask(ctx context.Context, task T) error {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ if err := s.taskRepo.UpdateById(ctx, task); err != nil {
+ return err
+ }
+
+ oldTask, ok := s.queue.Remove(ctx, task.GetId())
+ if !ok {
+ return errors.New("任务不存在")
+ }
+ oldTask.Update(task)
+ if !oldTask.Schedule() {
+ return nil
+ }
+ if !s.queue.Enqueue(ctx, oldTask) {
+ return errors.New("任务入队失败")
+ }
+ return nil
+}
+
+func (s *dbScheduler[T]) run() {
+ defer s.waitGroup.Done()
+
+ for !s.closed() {
+ time.Sleep(time.Second)
+
+ s.mutex.Lock()
+ task, ok := s.queue.TryDequeue()
+ if !ok {
+ s.mutex.Unlock()
+ continue
+ }
+ if err := s.updateTaskStatus(s.context, entity.TaskReserved, nil, task); err != nil {
+ s.mutex.Unlock()
+ timex.SleepWithContext(s.context, sleepAfterError)
+ continue
+ }
+ s.mutex.Unlock()
+
+ errRun := s.RunTask(s.context, task)
+ taskStatus := entity.TaskSuccess
+ if errRun != nil {
+ taskStatus = entity.TaskFailed
+ }
+ s.mutex.Lock()
+ if err := s.updateTaskStatus(s.context, taskStatus, errRun, task); err != nil {
+ s.mutex.Unlock()
+ timex.SleepWithContext(s.context, sleepAfterError)
+ continue
+ }
+ task.Schedule()
+ if !task.IsFinished() {
+ s.queue.Enqueue(s.context, task)
+ }
+ s.mutex.Unlock()
+ }
+}
+
+func (s *dbScheduler[T]) Close() {
+ s.cancel()
+ s.waitGroup.Wait()
+}
+
+func (s *dbScheduler[T]) closed() bool {
+ return s.context.Err() != nil
+}
+
+func (s *dbScheduler[T]) loadTask(ctx context.Context) error {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ tasks, err := s.taskRepo.ListToDo()
+ if err != nil {
+ return err
+ }
+ for _, task := range tasks {
+ if !task.Schedule() {
+ continue
+ }
+ s.queue.Enqueue(ctx, task)
+ }
+ return nil
+}
+
+func (s *dbScheduler[T]) AddTask(ctx context.Context, tasks ...T) error {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ for _, task := range tasks {
+ if err := s.taskRepo.AddTask(ctx, task); err != nil {
+ return err
+ }
+ if !task.Schedule() {
+ continue
+ }
+ s.queue.Enqueue(ctx, task)
+ }
+ return nil
+}
+
+func (s *dbScheduler[T]) DeleteTask(ctx context.Context, taskId uint64) error {
+ // todo: 删除数据库备份历史文件
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ if err := s.taskRepo.DeleteById(ctx, taskId); err != nil {
+ return err
+ }
+ s.queue.Remove(ctx, taskId)
+ return nil
+}
+
+func (s *dbScheduler[T]) EnableTask(ctx context.Context, taskId uint64) error {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ task := anyx.DeepZero[T]()
+ if err := s.taskRepo.GetById(task, taskId); err != nil {
+ return err
+ }
+ if task.IsEnabled() {
+ return nil
+ }
+ task.GetTaskBase().Enabled = true
+ if err := s.taskRepo.UpdateEnabled(ctx, taskId, true); err != nil {
+ return err
+ }
+ s.queue.Remove(ctx, taskId)
+ if !task.Schedule() {
+ return nil
+ }
+ s.queue.Enqueue(ctx, task)
+ return nil
+}
+
+func (s *dbScheduler[T]) DisableTask(ctx context.Context, taskId uint64) error {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ task := anyx.DeepZero[T]()
+ if err := s.taskRepo.GetById(task, taskId); err != nil {
+ return err
+ }
+ if !task.IsEnabled() {
+ return nil
+ }
+ if err := s.taskRepo.UpdateEnabled(ctx, taskId, false); err != nil {
+ return err
+ }
+ s.queue.Remove(ctx, taskId)
+ return nil
+}
+
+func (s *dbScheduler[T]) StartTask(ctx context.Context, taskId uint64) error {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ task := anyx.DeepZero[T]()
+ if err := s.taskRepo.GetById(task, taskId); err != nil {
+ return err
+ }
+ if !task.IsEnabled() {
+ return errors.New("任务未启用")
+ }
+ s.queue.Remove(ctx, taskId)
+ task.GetTaskBase().Deadline = time.Now()
+ s.queue.Enqueue(ctx, task)
+ return nil
+}
diff --git a/server/internal/db/application/instance.go b/server/internal/db/application/instance.go
index 7508e158..40dde07a 100644
--- a/server/internal/db/application/instance.go
+++ b/server/internal/db/application/instance.go
@@ -2,6 +2,7 @@ package application
import (
"context"
+ "mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/base"
@@ -77,7 +78,9 @@ func (app *instanceAppImpl) Save(ctx context.Context, instanceEntity *entity.DbI
if err == nil {
return errorx.NewBiz("该数据库实例已存在")
}
- instanceEntity.PwdEncrypt()
+ if err := instanceEntity.PwdEncrypt(); err != nil {
+ return errorx.NewBiz(err.Error())
+ }
return app.Insert(ctx, instanceEntity)
}
@@ -85,7 +88,9 @@ func (app *instanceAppImpl) Save(ctx context.Context, instanceEntity *entity.DbI
if err == nil && oldInstance.Id != instanceEntity.Id {
return errorx.NewBiz("该数据库实例已存在")
}
- instanceEntity.PwdEncrypt()
+ if err := instanceEntity.PwdEncrypt(); err != nil {
+ return errorx.NewBiz(err.Error())
+ }
return app.UpdateById(ctx, instanceEntity)
}
@@ -95,7 +100,7 @@ func (app *instanceAppImpl) Delete(ctx context.Context, id uint64) error {
func (app *instanceAppImpl) GetDatabases(ed *entity.DbInstance) ([]string, error) {
ed.Network = ed.GetNetwork()
- metaDb := ed.Type.MetaDbName()
+ metaDb := dbm.ToDbType(ed.Type).MetaDbName()
dbConn, err := toDbInfo(ed, 0, metaDb, "").Conn()
if err != nil {
diff --git a/server/internal/db/config/config.go b/server/internal/db/config/config.go
index 555205e6..94a7c298 100644
--- a/server/internal/db/config/config.go
+++ b/server/internal/db/config/config.go
@@ -2,6 +2,8 @@ package config
import (
sysapp "mayfly-go/internal/sys/application"
+ "path/filepath"
+ "runtime"
)
const (
@@ -9,6 +11,7 @@ const (
ConfigKeyDbQueryMaxCount string = "DbQueryMaxCount" // 数据库查询的最大数量
ConfigKeyDbBackupRestore string = "DbBackupRestore" // 数据库备份
ConfigKeyDbMysqlBin string = "MysqlBin" // mysql可执行文件配置
+ ConfigKeyDbMariaDbBin string = "MariaDbBin" // mariadb可执行文件配置
)
// 获取数据库最大查询数量配置
@@ -36,7 +39,7 @@ func GetDbBackupRestore() *DbBackupRestore {
if backupPath == "" {
backupPath = "./db/backup"
}
- dbrc.BackupPath = backupPath
+ dbrc.BackupPath = filepath.Join(backupPath)
return dbrc
}
@@ -50,35 +53,39 @@ type MysqlBin struct {
}
// 获取数据库备份配置
-func GetMysqlBin() *MysqlBin {
- c := sysapp.GetConfigApp().GetConfig(ConfigKeyDbMysqlBin)
+func GetMysqlBin(configKey string) *MysqlBin {
+ c := sysapp.GetConfigApp().GetConfig(configKey)
jm := c.GetJsonMap()
mbc := new(MysqlBin)
path := jm["path"]
if path == "" {
- path = "./db/backup"
+ path = "./db/mysql/bin"
}
- mbc.Path = path
+ mbc.Path = filepath.Join(path)
+ var extName string
+ if runtime.GOOS == "windows" {
+ extName = ".exe"
+ }
mysqlPath := jm["mysql"]
if mysqlPath == "" {
- mysqlPath = path + "mysql"
+ mysqlPath = filepath.Join(path, "mysql"+extName)
}
- mbc.MysqlPath = mysqlPath
+ mbc.MysqlPath = filepath.Join(mysqlPath)
mysqldumpPath := jm["mysqldump"]
if mysqldumpPath == "" {
- mysqldumpPath = path + "mysqldump"
+ mysqldumpPath = filepath.Join(path, "mysqldump"+extName)
}
- mbc.MysqldumpPath = mysqldumpPath
+ mbc.MysqldumpPath = filepath.Join(mysqldumpPath)
mysqlbinlogPath := jm["mysqlbinlog"]
if mysqlbinlogPath == "" {
- mysqlbinlogPath = path + "mysqlbinlog"
+ mysqlbinlogPath = filepath.Join(path, "mysqlbinlog"+extName)
}
- mbc.MysqlbinlogPath = mysqlbinlogPath
+ mbc.MysqlbinlogPath = filepath.Join(mysqlbinlogPath)
return mbc
}
diff --git a/server/internal/db/dbm/db_program.go b/server/internal/db/dbm/db_program.go
new file mode 100644
index 00000000..ce06736a
--- /dev/null
+++ b/server/internal/db/dbm/db_program.go
@@ -0,0 +1,32 @@
+package dbm
+
+import (
+ "context"
+ "mayfly-go/internal/db/domain/entity"
+ "path/filepath"
+ "time"
+)
+
+type DbProgram interface {
+ Backup(ctx context.Context, backupHistory *entity.DbBackupHistory) (*entity.BinlogInfo, error)
+ FetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool, earliestBackupSequence, latestBinlogSequence int64) ([]*entity.BinlogFile, error)
+ ReplayBinlog(ctx context.Context, originalDatabase, targetDatabase string, restoreInfo *RestoreInfo) error
+ RestoreBackupHistory(ctx context.Context, dbName string, dbBackupId uint64, dbBackupHistoryUuid string) error
+ GetBinlogEventPositionAtOrAfterTime(ctx context.Context, binlogName string, targetTime time.Time) (position int64, parseErr error)
+}
+
+type RestoreInfo struct {
+ BackupHistory *entity.DbBackupHistory
+ BinlogHistories []*entity.DbBinlogHistory
+ StartPosition int64
+ TargetPosition int64
+ TargetTime time.Time
+}
+
+func (ri *RestoreInfo) GetBinlogPaths(binlogDir string) []string {
+ files := make([]string, 0, len(ri.BinlogHistories))
+ for _, history := range ri.BinlogHistories {
+ files = append(files, filepath.Join(binlogDir, history.FileName))
+ }
+ return files
+}
diff --git a/server/internal/db/infrastructure/service/db_instance.go b/server/internal/db/dbm/db_program_mysql.go
similarity index 63%
rename from server/internal/db/infrastructure/service/db_instance.go
rename to server/internal/db/dbm/db_program_mysql.go
index 9e312062..7282613f 100644
--- a/server/internal/db/infrastructure/service/db_instance.go
+++ b/server/internal/db/dbm/db_program_mysql.go
@@ -1,4 +1,4 @@
-package service
+package dbm
import (
"bufio"
@@ -19,112 +19,55 @@ import (
"golang.org/x/sync/singleflight"
"mayfly-go/internal/db/config"
- "mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/domain/entity"
- "mayfly-go/internal/db/domain/repository"
- "mayfly-go/internal/db/domain/service"
-
"mayfly-go/pkg/logx"
- "mayfly-go/pkg/utils/structx"
)
-// BinlogFile is the metadata of the MySQL binlog file.
-type BinlogFile struct {
- Name string
- Size int64
+var _ DbProgram = (*DbProgramMysql)(nil)
- // Sequence is parsed from Name and is for the sorting purpose.
- Sequence int64
- FirstEventTime time.Time
- Downloaded bool
+type DbProgramMysql struct {
+ dbConn *DbConn
+ // mysqlBin 用于集成测试
+ mysqlBin *config.MysqlBin
+ // backupPath 用于集成测试
+ backupPath string
}
-func newBinlogFile(name string, size int64) (*BinlogFile, error) {
- _, seq, err := ParseBinlogName(name)
- if err != nil {
- return nil, err
- }
- return &BinlogFile{Name: name, Size: size, Sequence: seq}, nil
-}
-
-var _ service.DbInstanceSvc = (*DbInstanceSvcImpl)(nil)
-
-type DbInstanceSvcImpl struct {
- instanceId uint64
- dbInfo *dbm.DbInfo
- backupHistoryRepo repository.DbBackupHistory
- binlogHistoryRepo repository.DbBinlogHistory
-}
-
-func NewDbInstanceSvc(instance *entity.DbInstance, repositories *repository.Repositories) *DbInstanceSvcImpl {
- dbInfo := new(dbm.DbInfo)
- _ = structx.Copy(dbInfo, instance)
- return &DbInstanceSvcImpl{
- instanceId: instance.Id,
- dbInfo: dbInfo,
- backupHistoryRepo: repositories.BackupHistory,
- binlogHistoryRepo: repositories.BinlogHistory,
+func NewDbProgramMysql(dbConn *DbConn) *DbProgramMysql {
+ return &DbProgramMysql{
+ dbConn: dbConn,
}
}
-type RestoreInfo struct {
- backupHistory *entity.DbBackupHistory
- binlogHistories []*entity.DbBinlogHistory
- startPosition int64
- targetPosition int64
- targetTime time.Time
+func (svc *DbProgramMysql) dbInfo() *DbInfo {
+ return svc.dbConn.Info
}
-func (ri *RestoreInfo) getBinlogFiles(binlogDir string) []string {
- files := make([]string, 0, len(ri.binlogHistories))
- for _, history := range ri.binlogHistories {
- files = append(files, filepath.Join(binlogDir, history.FileName))
+func (svc *DbProgramMysql) getMysqlBin() *config.MysqlBin {
+ if svc.mysqlBin != nil {
+ return svc.mysqlBin
}
- return files
+ var mysqlBin *config.MysqlBin
+ switch svc.dbInfo().Type {
+ default:
+ mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMysqlBin)
+ }
+ return mysqlBin
}
-func (svc *DbInstanceSvcImpl) getBinlogFilePath(fileName string) string {
- return filepath.Join(getBinlogDir(svc.instanceId), fileName)
+func (svc *DbProgramMysql) getBackupPath() string {
+ if len(svc.backupPath) > 0 {
+ return svc.backupPath
+ }
+ return config.GetDbBackupRestore().BackupPath
}
-func (svc *DbInstanceSvcImpl) GetRestoreInfo(ctx context.Context, dbName string, targetTime time.Time) (*RestoreInfo, error) {
- binlogHistory, err := svc.binlogHistoryRepo.GetHistoryByTime(svc.instanceId, targetTime)
- if err != nil {
- return nil, err
- }
- position, err := getBinlogEventPositionAtOrAfterTime(ctx, svc.getBinlogFilePath(binlogHistory.FileName), targetTime)
- if err != nil {
- return nil, err
- }
- target := &entity.BinlogInfo{
- FileName: binlogHistory.FileName,
- Sequence: binlogHistory.Sequence,
- Position: position,
- }
- backupHistory, err := svc.backupHistoryRepo.GetLatestHistory(svc.instanceId, dbName, target)
- if err != nil {
- return nil, err
- }
- start := &entity.BinlogInfo{
- FileName: backupHistory.BinlogFileName,
- Sequence: backupHistory.BinlogSequence,
- Position: backupHistory.BinlogPosition,
- }
- binlogHistories, err := svc.binlogHistoryRepo.GetHistories(svc.instanceId, start, target)
- if err != nil {
- return nil, err
- }
- return &RestoreInfo{
- backupHistory: backupHistory,
- binlogHistories: binlogHistories,
- startPosition: backupHistory.BinlogPosition,
- targetPosition: target.Position,
- targetTime: targetTime,
- }, nil
+func (svc *DbProgramMysql) GetBinlogFilePath(fileName string) string {
+ return filepath.Join(svc.getBinlogDir(svc.dbInfo().InstanceId), fileName)
}
-func (svc *DbInstanceSvcImpl) Backup(ctx context.Context, backupHistory *entity.DbBackupHistory) (*entity.BinlogInfo, error) {
- dir := getDbBackupDir(backupHistory.DbInstanceId, backupHistory.DbBackupId)
+func (svc *DbProgramMysql) Backup(ctx context.Context, backupHistory *entity.DbBackupHistory) (*entity.BinlogInfo, error) {
+ dir := svc.getDbBackupDir(backupHistory.DbInstanceId, backupHistory.DbBackupId)
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
return nil, err
}
@@ -134,10 +77,10 @@ func (svc *DbInstanceSvcImpl) Backup(ctx context.Context, backupHistory *entity.
}()
args := []string{
- "--host", svc.dbInfo.Host,
- "--port", strconv.Itoa(svc.dbInfo.Port),
- "--user", svc.dbInfo.Username,
- "--password=" + svc.dbInfo.Password,
+ "--host", svc.dbInfo().Host,
+ "--port", strconv.Itoa(svc.dbInfo().Port),
+ "--user", svc.dbInfo().Username,
+ "--password=" + svc.dbInfo().Password,
"--add-drop-database",
"--result-file", tmpFile,
"--single-transaction",
@@ -145,7 +88,7 @@ func (svc *DbInstanceSvcImpl) Backup(ctx context.Context, backupHistory *entity.
"--databases", backupHistory.DbName,
}
- cmd := exec.CommandContext(ctx, mysqldumpPath(), args...)
+ cmd := exec.CommandContext(ctx, svc.getMysqlBin().MysqldumpPath, args...)
logx.Debugf("backup database using mysqldump binary: %s", cmd.String())
if err := runCmd(cmd); err != nil {
logx.Errorf("运行 mysqldump 程序失败: %v", err)
@@ -174,15 +117,17 @@ func (svc *DbInstanceSvcImpl) Backup(ctx context.Context, backupHistory *entity.
return binlogInfo, nil
}
-func (svc *DbInstanceSvcImpl) RestoreBackup(ctx context.Context, database, fileName string) error {
+func (svc *DbProgramMysql) RestoreBackupHistory(ctx context.Context, dbName string, dbBackupId uint64, dbBackupHistoryUuid string) error {
args := []string{
- "--host", svc.dbInfo.Host,
- "--port", strconv.Itoa(svc.dbInfo.Port),
- "--database", database,
- "--user", svc.dbInfo.Username,
- "--password=" + svc.dbInfo.Password,
+ "--host", svc.dbInfo().Host,
+ "--port", strconv.Itoa(svc.dbInfo().Port),
+ "--database", dbName,
+ "--user", svc.dbInfo().Username,
+ "--password=" + svc.dbInfo().Password,
}
+ fileName := filepath.Join(svc.getDbBackupDir(svc.dbInfo().InstanceId, dbBackupId),
+ fmt.Sprintf("%v.sql", dbBackupHistoryUuid))
file, err := os.Open(fileName)
if err != nil {
return errors.Wrap(err, "打开备份文件失败")
@@ -191,7 +136,7 @@ func (svc *DbInstanceSvcImpl) RestoreBackup(ctx context.Context, database, fileN
_ = file.Close()
}()
- cmd := exec.CommandContext(ctx, mysqlPath(), args...)
+ cmd := exec.CommandContext(ctx, svc.getMysqlBin().MysqlPath, args...)
cmd.Stdin = file
logx.Debug("恢复数据库: ", cmd.String())
if err := runCmd(cmd); err != nil {
@@ -201,42 +146,14 @@ func (svc *DbInstanceSvcImpl) RestoreBackup(ctx context.Context, database, fileN
return nil
}
-func (svc *DbInstanceSvcImpl) Restore(ctx context.Context, task *entity.DbRestore) error {
- if task.PointInTime.IsZero() {
- backupHistory := &entity.DbBackupHistory{}
- err := svc.backupHistoryRepo.GetById(backupHistory, task.DbBackupHistoryId)
- if err != nil {
- return err
- }
- fileName := filepath.Join(getDbBackupDir(backupHistory.DbInstanceId, backupHistory.DbBackupId),
- fmt.Sprintf("%v.sql", backupHistory.Uuid))
- return svc.RestoreBackup(ctx, task.DbName, fileName)
- }
-
- if err := svc.FetchBinlogs(ctx, true); err != nil {
- return err
- }
- restoreInfo, err := svc.GetRestoreInfo(ctx, task.DbName, task.PointInTime)
- if err != nil {
- return err
- }
- fileName := filepath.Join(getDbBackupDir(restoreInfo.backupHistory.DbInstanceId, restoreInfo.backupHistory.DbBackupId),
- fmt.Sprintf("%s.sql", restoreInfo.backupHistory.Uuid))
-
- if err := svc.RestoreBackup(ctx, task.DbName, fileName); err != nil {
- return err
- }
- return svc.ReplayBinlogToDatabase(ctx, task.DbName, task.DbName, restoreInfo)
-}
-
// Download binlog files on server.
-func (svc *DbInstanceSvcImpl) downloadBinlogFilesOnServer(ctx context.Context, binlogFilesOnServerSorted []*BinlogFile, downloadLatestBinlogFile bool) error {
+func (svc *DbProgramMysql) downloadBinlogFilesOnServer(ctx context.Context, binlogFilesOnServerSorted []*entity.BinlogFile, downloadLatestBinlogFile bool) error {
if len(binlogFilesOnServerSorted) == 0 {
logx.Debug("No binlog file found on server to download")
return nil
}
- if err := os.MkdirAll(getBinlogDir(svc.instanceId), os.ModePerm); err != nil {
- return errors.Wrapf(err, "创建 binlog 目录失败: %q", getBinlogDir(svc.instanceId))
+ if err := os.MkdirAll(svc.getBinlogDir(svc.dbInfo().InstanceId), os.ModePerm); err != nil {
+ return errors.Wrapf(err, "创建 binlog 目录失败: %q", svc.getBinlogDir(svc.dbInfo().InstanceId))
}
latestBinlogFileOnServer := binlogFilesOnServerSorted[len(binlogFilesOnServerSorted)-1]
for _, fileOnServer := range binlogFilesOnServerSorted {
@@ -244,7 +161,7 @@ func (svc *DbInstanceSvcImpl) downloadBinlogFilesOnServer(ctx context.Context, b
if isLatest && !downloadLatestBinlogFile {
continue
}
- binlogFilePath := filepath.Join(getBinlogDir(svc.instanceId), fileOnServer.Name)
+ binlogFilePath := filepath.Join(svc.getBinlogDir(svc.dbInfo().InstanceId), fileOnServer.Name)
logx.Debug("Downloading binlog file from MySQL server.", logx.String("path", binlogFilePath), logx.Bool("isLatest", isLatest))
if err := svc.downloadBinlogFile(ctx, fileOnServer, isLatest); err != nil {
logx.Error("下载 binlog 文件失败", logx.String("path", binlogFilePath), logx.String("error", err.Error()))
@@ -255,7 +172,7 @@ func (svc *DbInstanceSvcImpl) downloadBinlogFilesOnServer(ctx context.Context, b
}
// Parse the first binlog eventTs of a local binlog file.
-func parseLocalBinlogFirstEventTime(ctx context.Context, filePath string) (eventTime time.Time, parseErr error) {
+func (svc *DbProgramMysql) parseLocalBinlogFirstEventTime(ctx context.Context, filePath string) (eventTime time.Time, parseErr error) {
args := []string{
// Local binlog file path.
filePath,
@@ -264,7 +181,7 @@ func parseLocalBinlogFirstEventTime(ctx context.Context, filePath string) (event
// Tell mysqlbinlog to suppress the BINLOG statements for row events, which reduces the unneeded output.
"--base64-output=DECODE-ROWS",
}
- cmd := exec.CommandContext(ctx, mysqlbinlogPath(), args...)
+ cmd := exec.CommandContext(ctx, svc.getMysqlBin().MysqlbinlogPath, args...)
var stderr strings.Builder
cmd.Stderr = &stderr
pr, err := cmd.StdoutPipe()
@@ -282,7 +199,7 @@ func parseLocalBinlogFirstEventTime(ctx context.Context, filePath string) (event
}
}()
- for s := bufio.NewScanner(pr); ; s.Scan() {
+ for s := bufio.NewScanner(pr); s.Scan(); {
line := s.Text()
eventTimeParsed, found, err := parseBinlogEventTimeInLine(line)
if err != nil {
@@ -295,89 +212,58 @@ func parseLocalBinlogFirstEventTime(ctx context.Context, filePath string) (event
return time.Time{}, errors.New("解析 binlog 文件失败")
}
-// getBinlogDir gets the binlogDir.
-func getBinlogDir(instanceId uint64) string {
- return filepath.Join(
- config.GetDbBackupRestore().BackupPath,
- fmt.Sprintf("instance-%d", instanceId),
- "binlog")
-}
-
-func getDbInstanceBackupRoot(instanceId uint64) string {
- return filepath.Join(
- config.GetDbBackupRestore().BackupPath,
- fmt.Sprintf("instance-%d", instanceId))
-}
-
-func getDbBackupDir(instanceId, backupId uint64) string {
- return filepath.Join(
- config.GetDbBackupRestore().BackupPath,
- fmt.Sprintf("instance-%d", instanceId),
- fmt.Sprintf("backup-%d", backupId))
-}
-
var singleFlightGroup singleflight.Group
// FetchBinlogs downloads binlog files from startingFileName on server to `binlogDir`.
-func (svc *DbInstanceSvcImpl) FetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool) error {
- latestDownloaded := false
- _, err, _ := singleFlightGroup.Do(strconv.FormatUint(svc.instanceId, 10), func() (interface{}, error) {
- latestDownloaded = downloadLatestBinlogFile
- err := svc.fetchBinlogs(ctx, downloadLatestBinlogFile)
- return nil, err
+func (svc *DbProgramMysql) FetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool, earliestBackupSequence, latestBinlogSequence int64) ([]*entity.BinlogFile, error) {
+ var downloaded bool
+ key := strconv.FormatUint(svc.dbInfo().InstanceId, 16)
+ binlogFiles, err, _ := singleFlightGroup.Do(key, func() (interface{}, error) {
+ downloaded = true
+ return svc.fetchBinlogs(ctx, downloadLatestBinlogFile, earliestBackupSequence, latestBinlogSequence)
})
-
- if downloadLatestBinlogFile && !latestDownloaded {
- _, err, _ = singleFlightGroup.Do(strconv.FormatUint(svc.instanceId, 10), func() (interface{}, error) {
- err := svc.fetchBinlogs(ctx, true)
- return nil, err
- })
+ if err != nil {
+ return nil, err
}
- return err
+ if downloaded {
+ return binlogFiles.([]*entity.BinlogFile), nil
+ }
+ if !downloadLatestBinlogFile {
+ return nil, nil
+ }
+ binlogFiles, err, _ = singleFlightGroup.Do(key, func() (interface{}, error) {
+ return svc.fetchBinlogs(ctx, true, earliestBackupSequence, latestBinlogSequence)
+ })
+ if err != nil {
+ return nil, err
+ }
+ return binlogFiles.([]*entity.BinlogFile), err
}
// fetchBinlogs downloads binlog files from startingFileName on server to `binlogDir`.
-func (svc *DbInstanceSvcImpl) fetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool) error {
+func (svc *DbProgramMysql) fetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool, earliestBackupSequence, latestBinlogSequence int64) ([]*entity.BinlogFile, error) {
// Read binlog files list on server.
binlogFilesOnServerSorted, err := svc.GetSortedBinlogFilesOnServer(ctx)
if err != nil {
- return err
+ return nil, err
}
if len(binlogFilesOnServerSorted) == 0 {
logx.Debug("No binlog file found on server to download")
- return nil
- }
- latest, ok, err := svc.binlogHistoryRepo.GetLatestHistory(svc.instanceId)
- if err != nil {
- return err
- }
- binlogFileName := ""
- latestSequence := int64(-1)
- earliestSequence := int64(-1)
- if ok {
- latestSequence = latest.Sequence
- binlogFileName = latest.FileName
- } else {
- earliest, err := svc.backupHistoryRepo.GetEarliestHistory(svc.instanceId)
- if err != nil {
- return err
- }
- earliestSequence = earliest.BinlogSequence
- binlogFileName = earliest.BinlogFileName
+ return nil, nil
}
indexHistory := -1
for i, file := range binlogFilesOnServerSorted {
- if latestSequence == file.Sequence {
+ if latestBinlogSequence == file.Sequence {
indexHistory = i + 1
break
}
- if earliestSequence == file.Sequence {
+ if earliestBackupSequence == file.Sequence {
indexHistory = i
break
}
}
if indexHistory < 0 {
- return errors.New(fmt.Sprintf("在数据库服务器上未找到 binlog 文件 %q", binlogFileName))
+ return nil, errors.New(fmt.Sprintf("在数据库服务器上未找到 binlog 文件: %d, %d", earliestBackupSequence, latestBinlogSequence))
}
if indexHistory > len(binlogFilesOnServerSorted)-1 {
indexHistory = len(binlogFilesOnServerSorted) - 1
@@ -385,58 +271,36 @@ func (svc *DbInstanceSvcImpl) fetchBinlogs(ctx context.Context, downloadLatestBi
binlogFilesOnServerSorted = binlogFilesOnServerSorted[indexHistory:]
if err := svc.downloadBinlogFilesOnServer(ctx, binlogFilesOnServerSorted, downloadLatestBinlogFile); err != nil {
- return err
- }
- for i, fileOnServer := range binlogFilesOnServerSorted {
- if !fileOnServer.Downloaded {
- break
- }
- history := &entity.DbBinlogHistory{
- CreateTime: time.Now(),
- FileName: fileOnServer.Name,
- FileSize: fileOnServer.Size,
- Sequence: fileOnServer.Sequence,
- FirstEventTime: fileOnServer.FirstEventTime,
- DbInstanceId: svc.instanceId,
- }
- if i == len(binlogFilesOnServerSorted)-1 {
- if err := svc.binlogHistoryRepo.Upsert(ctx, history); err != nil {
- return err
- }
- } else {
- if err := svc.binlogHistoryRepo.Insert(ctx, history); err != nil {
- return err
- }
- }
+ return nil, err
}
- return nil
+ return binlogFilesOnServerSorted, nil
}
// Syncs the binlog specified by `meta` between the instance and local.
// If isLast is true, it means that this is the last binlog file containing the targetTs event.
// It may keep growing as there are ongoing writes to the database. So we just need to check that
// the file size is larger or equal to the binlog file size we queried from the MySQL server earlier.
-func (svc *DbInstanceSvcImpl) downloadBinlogFile(ctx context.Context, binlogFileToDownload *BinlogFile, isLast bool) error {
- tempBinlogPrefix := filepath.Join(getBinlogDir(svc.instanceId), "tmp-")
+func (svc *DbProgramMysql) downloadBinlogFile(ctx context.Context, binlogFileToDownload *entity.BinlogFile, isLast bool) error {
+ tempBinlogPrefix := filepath.Join(svc.getBinlogDir(svc.dbInfo().InstanceId), "tmp-")
args := []string{
binlogFileToDownload.Name,
"--read-from-remote-server",
// Verify checksum binlog events.
"--verify-binlog-checksum",
- "--host", svc.dbInfo.Host,
- "--port", strconv.Itoa(svc.dbInfo.Port),
- "--user", svc.dbInfo.Username,
+ "--host", svc.dbInfo().Host,
+ "--port", strconv.Itoa(svc.dbInfo().Port),
+ "--user", svc.dbInfo().Username,
"--raw",
// With --raw this is a prefix for the file names.
"--result-file", tempBinlogPrefix,
}
- cmd := exec.CommandContext(ctx, mysqlbinlogPath(), args...)
+ cmd := exec.CommandContext(ctx, svc.getMysqlBin().MysqlbinlogPath, args...)
// We cannot set password as a flag. Otherwise, there is warning message
// "mysqlbinlog: [Warning] Using a password on the command line interface can be insecure."
- if svc.dbInfo.Password != "" {
- cmd.Env = append(cmd.Env, fmt.Sprintf("MYSQL_PWD=%s", svc.dbInfo.Password))
+ if svc.dbInfo().Password != "" {
+ cmd.Env = append(cmd.Env, fmt.Sprintf("MYSQL_PWD=%s", svc.dbInfo().Password))
}
logx.Debug("Downloading binlog files using mysqlbinlog:", cmd.String())
@@ -464,11 +328,11 @@ func (svc *DbInstanceSvcImpl) downloadBinlogFile(ctx context.Context, binlogFile
return errors.Errorf("下载的 binlog 文件 %q 与服务上的文件大小不一致 %d != %d", binlogFilePathTemp, binlogFileTempInfo.Size(), binlogFileToDownload.Size)
}
- binlogFilePath := svc.getBinlogFilePath(binlogFileToDownload.Name)
+ binlogFilePath := svc.GetBinlogFilePath(binlogFileToDownload.Name)
if err := os.Rename(binlogFilePathTemp, binlogFilePath); err != nil {
return errors.Wrapf(err, "binlog 文件更名失败: %q -> %q", binlogFilePathTemp, binlogFilePath)
}
- firstEventTime, err := parseLocalBinlogFirstEventTime(ctx, binlogFilePath)
+ firstEventTime, err := svc.parseLocalBinlogFirstEventTime(ctx, binlogFilePath)
if err != nil {
return err
}
@@ -479,14 +343,9 @@ func (svc *DbInstanceSvcImpl) downloadBinlogFile(ctx context.Context, binlogFile
}
// GetSortedBinlogFilesOnServer returns the information of binlog files in ascending order by their numeric extension.
-func (svc *DbInstanceSvcImpl) GetSortedBinlogFilesOnServer(_ context.Context) ([]*BinlogFile, error) {
- conn, err := svc.dbInfo.Conn()
- if err != nil {
- return nil, err
- }
- defer conn.Close()
+func (svc *DbProgramMysql) GetSortedBinlogFilesOnServer(_ context.Context) ([]*entity.BinlogFile, error) {
query := "SHOW BINARY LOGS"
- columns, rows, err := conn.Query(query)
+ columns, rows, err := svc.dbConn.Query(query)
if err != nil {
return nil, errors.Wrapf(err, "SQL 语句 %q 执行失败", query)
}
@@ -504,7 +363,7 @@ func (svc *DbInstanceSvcImpl) GetSortedBinlogFilesOnServer(_ context.Context) ([
return nil, errors.Errorf("SQL 语句 %q 执行结果解析失败", query)
}
- var binlogFiles []*BinlogFile
+ var binlogFiles []*entity.BinlogFile
for _, row := range rows {
name, nameOk := row["Log_name"].(string)
@@ -512,11 +371,15 @@ func (svc *DbInstanceSvcImpl) GetSortedBinlogFilesOnServer(_ context.Context) ([
if !nameOk || !sizeOk {
return nil, errors.Errorf("SQL 语句 %q 执行结果解析失败", query)
}
-
- binlogFile, err := newBinlogFile(name, int64(size))
+ _, seq, err := ParseBinlogName(name)
if err != nil {
return nil, errors.Wrapf(err, "SQL 语句 %q 执行结果解析失败", query)
}
+ binlogFile := &entity.BinlogFile{
+ Name: name,
+ Size: int64(size),
+ Sequence: seq,
+ }
binlogFiles = append(binlogFiles, binlogFile)
}
@@ -566,18 +429,19 @@ func readBinlogInfoFromBackup(reader io.Reader) (*entity.BinlogInfo, error) {
}
// Use command like mysqlbinlog --start-datetime=targetTs binlog.000001 to parse the first binlog event position with timestamp equal or after targetTs.
-func getBinlogEventPositionAtOrAfterTime(ctx context.Context, filePath string, targetTime time.Time) (position int64, parseErr error) {
+func (svc *DbProgramMysql) GetBinlogEventPositionAtOrAfterTime(ctx context.Context, binlogName string, targetTime time.Time) (position int64, parseErr error) {
+ binlogPath := svc.GetBinlogFilePath(binlogName)
args := []string{
// Local binlog file path.
- filePath,
+ binlogPath,
// Verify checksum binlog events.
"--verify-binlog-checksum",
// Tell mysqlbinlog to suppress the BINLOG statements for row events, which reduces the unneeded output.
"--base64-output=DECODE-ROWS",
// Instruct mysqlbinlog to start output only after encountering the first binlog event with timestamp equal or after targetTime.
- "--start-datetime", formatDateTime(targetTime),
+ "--start-datetime", targetTime.Local().Format(time.DateTime),
}
- cmd := exec.CommandContext(ctx, mysqlbinlogPath(), args...)
+ cmd := exec.CommandContext(ctx, svc.getMysqlBin().MysqlbinlogPath, args...)
var stderr strings.Builder
cmd.Stderr = &stderr
pr, err := cmd.StdoutPipe()
@@ -594,7 +458,7 @@ func getBinlogEventPositionAtOrAfterTime(ctx context.Context, filePath string, t
}
}()
- for s := bufio.NewScanner(pr); ; s.Scan() {
+ for s := bufio.NewScanner(pr); s.Scan(); {
line := s.Text()
posParsed, found, err := parseBinlogEventPosInLine(line)
if err != nil {
@@ -605,11 +469,11 @@ func getBinlogEventPositionAtOrAfterTime(ctx context.Context, filePath string, t
return posParsed, nil
}
}
- return 0, errors.Errorf("在 %v 之后没有 binlog 事件", targetTime)
+ return 0, errors.Errorf("在 %s 之后没有 binlog 事件", targetTime.Format(time.DateTime))
}
-// replayBinlog replays the binlog for `originDatabase` from `startBinlogInfo.Position` to `targetTs`, read binlog from `binlogDir`.
-func (svc *DbInstanceSvcImpl) replayBinlog(ctx context.Context, originalDatabase, targetDatabase string, restoreInfo *RestoreInfo) (replayErr error) {
+// ReplayBinlog replays the binlog for `originDatabase` from `startBinlogInfo.Position` to `targetTs`, read binlog from `binlogDir`.
+func (svc *DbProgramMysql) ReplayBinlog(ctx context.Context, originalDatabase, targetDatabase string, restoreInfo *RestoreInfo) (replayErr error) {
const (
// Variable lower_case_table_names related.
@@ -657,27 +521,27 @@ func (svc *DbInstanceSvcImpl) replayBinlog(ctx context.Context, originalDatabase
// List entries for just this database. It's applied after the --rewrite-db option, so we should provide the rewritten database, i.e., pitrDatabase.
"--database", targetDatabase,
// Decode binary log from first event with position equal to or greater than argument.
- "--start-position", fmt.Sprintf("%d", restoreInfo.startPosition),
+ "--start-position", fmt.Sprintf("%d", restoreInfo.StartPosition),
// Stop decoding binary log at first event with position equal to or greater than argument.
- "--stop-position", fmt.Sprintf("%d", restoreInfo.targetPosition),
+ "--stop-position", fmt.Sprintf("%d", restoreInfo.TargetPosition),
}
- mysqlbinlogArgs = append(mysqlbinlogArgs, restoreInfo.getBinlogFiles(getBinlogDir(svc.instanceId))...)
+ mysqlbinlogArgs = append(mysqlbinlogArgs, restoreInfo.GetBinlogPaths(svc.getBinlogDir(svc.dbInfo().InstanceId))...)
mysqlArgs := []string{
- "--host", svc.dbInfo.Host,
- "--port", strconv.Itoa(svc.dbInfo.Port),
- "--user", svc.dbInfo.Username,
+ "--host", svc.dbInfo().Host,
+ "--port", strconv.Itoa(svc.dbInfo().Port),
+ "--user", svc.dbInfo().Username,
}
- if svc.dbInfo.Password != "" {
+ if svc.dbInfo().Password != "" {
// The --password parameter of mysql/mysqlbinlog does not support the "--password PASSWORD" format (split by space).
// If provided like that, the program will hang.
- mysqlArgs = append(mysqlArgs, fmt.Sprintf("--password=%s", svc.dbInfo.Password))
+ mysqlArgs = append(mysqlArgs, fmt.Sprintf("--password=%s", svc.dbInfo().Password))
}
- mysqlbinlogCmd := exec.CommandContext(ctx, mysqlbinlogPath(), mysqlbinlogArgs...)
- mysqlCmd := exec.CommandContext(ctx, mysqlPath(), mysqlArgs...)
+ mysqlbinlogCmd := exec.CommandContext(ctx, svc.getMysqlBin().MysqlbinlogPath, mysqlbinlogArgs...)
+ mysqlCmd := exec.CommandContext(ctx, svc.getMysqlBin().MysqlPath, mysqlArgs...)
logx.Debug("Start replay binlog commands.",
logx.String("mysqlbinlog", mysqlbinlogCmd.String()),
logx.String("mysql", mysqlCmd.String()))
@@ -717,26 +581,15 @@ func (svc *DbInstanceSvcImpl) replayBinlog(ctx context.Context, originalDatabase
return errors.Wrap(err, "启动 mysql 程序失败")
}
if err := mysqlCmd.Wait(); err != nil {
- return errors.Errorf("运行 mysql 程序失败: %s", mysqlbinlogErr.String())
+ return errors.Errorf("运行 mysql 程序失败: %s", mysqlErr.String())
}
return nil
}
-// ReplayBinlogToDatabase replays the binlog of originDatabaseName to the targetDatabaseName.
-func (svc *DbInstanceSvcImpl) ReplayBinlogToDatabase(ctx context.Context, originDatabaseName, targetDatabaseName string, restoreInfo *RestoreInfo) error {
- return svc.replayBinlog(ctx, originDatabaseName, targetDatabaseName, restoreInfo)
-}
-
-func (svc *DbInstanceSvcImpl) getServerVariable(_ context.Context, varName string) (string, error) {
- conn, err := svc.dbInfo.Conn()
- if err != nil {
- return "", err
- }
- defer conn.Close()
-
+func (svc *DbProgramMysql) getServerVariable(_ context.Context, varName string) (string, error) {
query := fmt.Sprintf("SHOW VARIABLES LIKE '%s'", varName)
- _, rows, err := conn.Query(query)
+ _, rows, err := svc.dbConn.Query(query)
if err != nil {
return "", err
}
@@ -754,7 +607,7 @@ func (svc *DbInstanceSvcImpl) getServerVariable(_ context.Context, varName strin
}
// CheckBinlogEnabled checks whether binlog is enabled for the current instance.
-func (svc *DbInstanceSvcImpl) CheckBinlogEnabled(ctx context.Context) error {
+func (svc *DbProgramMysql) CheckBinlogEnabled(ctx context.Context) error {
value, err := svc.getServerVariable(ctx, "log_bin")
if err != nil {
return err
@@ -766,7 +619,7 @@ func (svc *DbInstanceSvcImpl) CheckBinlogEnabled(ctx context.Context) error {
}
// CheckBinlogRowFormat checks whether the binlog format is ROW.
-func (svc *DbInstanceSvcImpl) CheckBinlogRowFormat(ctx context.Context) error {
+func (svc *DbProgramMysql) CheckBinlogRowFormat(ctx context.Context) error {
value, err := svc.getServerVariable(ctx, "binlog_format")
if err != nil {
return err
@@ -790,19 +643,19 @@ func runCmd(cmd *exec.Cmd) error {
return nil
}
-func (svc *DbInstanceSvcImpl) execute(database string, sql string) error {
+func (svc *DbProgramMysql) execute(database string, sql string) error {
args := []string{
- "--host", svc.dbInfo.Host,
- "--port", strconv.Itoa(svc.dbInfo.Port),
- "--user", svc.dbInfo.Username,
- "--password=" + svc.dbInfo.Password,
+ "--host", svc.dbInfo().Host,
+ "--port", strconv.Itoa(svc.dbInfo().Port),
+ "--user", svc.dbInfo().Username,
+ "--password=" + svc.dbInfo().Password,
"--execute", sql,
}
if len(database) > 0 {
args = append(args, database)
}
- cmd := exec.Command(mysqlPath(), args...)
+ cmd := exec.Command(svc.getMysqlBin().MysqlPath, args...)
logx.Debug("execute sql using mysql binary: ", cmd.String())
if err := runCmd(cmd); err != nil {
logx.Errorf("运行 mysql 程序失败: %v", err)
@@ -814,8 +667,8 @@ func (svc *DbInstanceSvcImpl) execute(database string, sql string) error {
// sortBinlogFiles will sort binlog files in ascending order by their numeric extension.
// For mysql binlog, after the serial number reaches 999999, the next serial number will not return to 000000, but 1000000,
// so we cannot directly use string to compare lexicographical order.
-func sortBinlogFiles(binlogFiles []*BinlogFile) []*BinlogFile {
- var sorted []*BinlogFile
+func sortBinlogFiles(binlogFiles []*entity.BinlogFile) []*entity.BinlogFile {
+ var sorted []*entity.BinlogFile
sorted = append(sorted, binlogFiles...)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Sequence < sorted[j].Sequence
@@ -879,20 +732,23 @@ func ParseBinlogName(name string) (string, int64, error) {
return s[0], seq, nil
}
-// formatDateTime formats the timestamp to the local time string.
-func formatDateTime(t time.Time) string {
- t = t.Local()
- return fmt.Sprintf("%d-%d-%d %d:%d:%d", t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second())
+// getBinlogDir gets the binlogDir.
+func (svc *DbProgramMysql) getBinlogDir(instanceId uint64) string {
+ return filepath.Join(
+ svc.getBackupPath(),
+ fmt.Sprintf("instance-%d", instanceId),
+ "binlog")
}
-func mysqlPath() string {
- return config.GetMysqlBin().MysqlPath
+func (svc *DbProgramMysql) getDbInstanceBackupRoot(instanceId uint64) string {
+ return filepath.Join(
+ svc.getBackupPath(),
+ fmt.Sprintf("instance-%d", instanceId))
}
-func mysqldumpPath() string {
- return config.GetMysqlBin().MysqldumpPath
-}
-
-func mysqlbinlogPath() string {
- return config.GetMysqlBin().MysqlbinlogPath
+func (svc *DbProgramMysql) getDbBackupDir(instanceId, backupId uint64) string {
+ return filepath.Join(
+ svc.getBackupPath(),
+ fmt.Sprintf("instance-%d", instanceId),
+ fmt.Sprintf("backup-%d", backupId))
}
diff --git a/server/internal/db/infrastructure/service/db_instance_e2e_test.go b/server/internal/db/dbm/db_program_mysql_e2e_test.go
similarity index 82%
rename from server/internal/db/infrastructure/service/db_instance_e2e_test.go
rename to server/internal/db/dbm/db_program_mysql_e2e_test.go
index cdf083a9..eee5015d 100644
--- a/server/internal/db/infrastructure/service/db_instance_e2e_test.go
+++ b/server/internal/db/dbm/db_program_mysql_e2e_test.go
@@ -1,19 +1,19 @@
//go:build e2e
-package service
+package dbm
import (
"context"
"errors"
"fmt"
"github.com/stretchr/testify/suite"
- "mayfly-go/internal/db/dbm"
+ "mayfly-go/internal/db/config"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/internal/db/infrastructure/persistence"
- "mayfly-go/pkg/config"
"os"
"path/filepath"
+ "runtime"
"strings"
"testing"
"time"
@@ -30,23 +30,25 @@ const (
type DbInstanceSuite struct {
suite.Suite
- instance *entity.DbInstance
repositories *repository.Repositories
- instanceSvc *DbInstanceSvcImpl
+ instanceSvc *DbProgramMysql
+ dbConn *DbConn
}
func (s *DbInstanceSuite) SetupSuite() {
if err := chdir("mayfly-go", "server"); err != nil {
panic(err)
}
- config.Init()
- s.instance = &entity.DbInstance{
- Type: dbm.DbTypeMysql,
+ dbInfo := DbInfo{
+ Type: DbTypeMysql,
Host: "localhost",
Port: 3306,
Username: "test",
- Password: "123456",
+ Password: "test",
}
+ dbConn, err := dbInfo.Conn()
+ s.Require().NoError(err)
+ s.dbConn = dbConn
s.repositories = &repository.Repositories{
Instance: persistence.GetInstanceRepo(),
Backup: persistence.NewDbBackupRepo(),
@@ -56,7 +58,26 @@ func (s *DbInstanceSuite) SetupSuite() {
Binlog: persistence.NewDbBinlogRepo(),
BinlogHistory: persistence.NewDbBinlogHistoryRepo(),
}
- s.instanceSvc = NewDbInstanceSvc(s.instance, s.repositories)
+ s.instanceSvc = NewDbProgramMysql(s.dbConn)
+ var extName string
+ if runtime.GOOS == "windows" {
+ extName = ".exe"
+ }
+ path := "db/mysql/bin"
+ s.instanceSvc.mysqlBin = &config.MysqlBin{
+ Path: filepath.Join(path),
+ MysqlPath: filepath.Join(path, "mysql"+extName),
+ MysqldumpPath: filepath.Join(path, "mysqldump"+extName),
+ MysqlbinlogPath: filepath.Join(path, "mysqlbinlog"+extName),
+ }
+ s.instanceSvc.backupPath = "db/backup"
+}
+
+func (s *DbInstanceSuite) TearDownSuite() {
+ if s.dbConn != nil {
+ s.dbConn.Close()
+ s.dbConn = nil
+ }
}
func (s *DbInstanceSuite) SetupTest() {
@@ -72,7 +93,7 @@ func (s *DbInstanceSuite) TearDownTest() {
sql := fmt.Sprintf("drop database if exists `%s`", dbNameBackupTest)
require.NoError(s.instanceSvc.execute("", sql))
- _ = os.RemoveAll(getDbInstanceBackupRoot(instanceIdTest))
+ _ = os.RemoveAll(s.instanceSvc.getDbInstanceBackupRoot(instanceIdTest))
}
func (s *DbInstanceSuite) TestBackup() {
@@ -89,7 +110,7 @@ func (s *DbInstanceSuite) testBackup(backupHistory *entity.DbBackupHistory) {
binlogInfo, err := s.instanceSvc.Backup(context.Background(), backupHistory)
require.NoError(err)
- fileName := filepath.Join(getDbBackupDir(s.instance.Id, backupHistory.Id), dbNameBackupTest+".sql")
+ fileName := filepath.Join(s.instanceSvc.getDbBackupDir(s.dbConn.Info.InstanceId, backupHistory.Id), dbNameBackupTest+".sql")
_, err = os.Stat(fileName)
require.NoError(err)
@@ -165,7 +186,7 @@ func (s *DbInstanceSuite) testReplayBinlog(backupHistory *entity.DbBackupHistory
require.NoError(err)
binlogFileLast := binlogFilesOnServerSorted[len(binlogFilesOnServerSorted)-1]
- position, err := getBinlogEventPositionAtOrAfterTime(context.Background(), s.instanceSvc.getBinlogFilePath(binlogFileLast.Name), targetTime)
+ position, err := s.instanceSvc.GetBinlogEventPositionAtOrAfterTime(context.Background(), binlogFileLast.Name, targetTime)
require.NoError(err)
binlogHistories := make([]*entity.DbBinlogHistory, 0, 2)
binlogHistoryBackup := &entity.DbBinlogHistory{
@@ -183,21 +204,19 @@ func (s *DbInstanceSuite) testReplayBinlog(backupHistory *entity.DbBackupHistory
}
restoreInfo := &RestoreInfo{
- backupHistory: backupHistory,
- binlogHistories: binlogHistories,
- startPosition: backupHistory.BinlogPosition,
- targetPosition: position,
- targetTime: targetTime,
+ BackupHistory: backupHistory,
+ BinlogHistories: binlogHistories,
+ StartPosition: backupHistory.BinlogPosition,
+ TargetPosition: position,
+ TargetTime: targetTime,
}
- err = s.instanceSvc.ReplayBinlogToDatabase(context.Background(), dbNameBackupTest, dbNameBackupTest, restoreInfo)
+ err = s.instanceSvc.ReplayBinlog(context.Background(), dbNameBackupTest, dbNameBackupTest, restoreInfo)
require.NoError(err)
}
func (s *DbInstanceSuite) testRestore(backupHistory *entity.DbBackupHistory) {
require := s.Require()
- fileName := filepath.Join(getDbBackupDir(instanceIdTest, backupIdTest),
- fmt.Sprintf("%v.sql", dbNameBackupTest))
- err := s.instanceSvc.RestoreBackup(context.Background(), dbNameBackupTest, fileName)
+ err := s.instanceSvc.RestoreBackupHistory(context.Background(), backupHistory.DbName, backupHistory.DbBackupId, backupHistory.Uuid)
require.NoError(err)
}
diff --git a/server/internal/db/infrastructure/service/db_instance_test.go b/server/internal/db/dbm/db_program_mysql_test.go
similarity index 97%
rename from server/internal/db/infrastructure/service/db_instance_test.go
rename to server/internal/db/dbm/db_program_mysql_test.go
index bbeb45f2..c03a8757 100644
--- a/server/internal/db/infrastructure/service/db_instance_test.go
+++ b/server/internal/db/dbm/db_program_mysql_test.go
@@ -1,4 +1,4 @@
-package service
+package dbm
import (
"github.com/stretchr/testify/require"
diff --git a/server/internal/db/dbm/db_type.go b/server/internal/db/dbm/db_type.go
index e129d4a4..2fde7a14 100644
--- a/server/internal/db/dbm/db_type.go
+++ b/server/internal/db/dbm/db_type.go
@@ -16,6 +16,14 @@ const (
DM DbType = "dm"
)
+func ToDbType(dbType string) DbType {
+ return DbType(dbType)
+}
+
+func (dbType DbType) Equal(typ string) bool {
+ return ToDbType(typ) == dbType
+}
+
func (dbType DbType) MetaDbName() string {
switch dbType {
case DbTypeMysql:
diff --git a/server/internal/db/dbm/dialect.go b/server/internal/db/dbm/dialect.go
index 087af53f..d32bdc7f 100644
--- a/server/internal/db/dbm/dialect.go
+++ b/server/internal/db/dbm/dialect.go
@@ -79,6 +79,9 @@ type DbDialect interface {
WalkTableRecord(tableName string, walk func(record map[string]any, columns []*QueryColumn)) error
GetSchemas() ([]string, error)
+
+ // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
+ GetDbProgram() DbProgram
}
// ------------------------- 元数据sql操作 -------------------------
diff --git a/server/internal/db/dbm/dialect_dm.go b/server/internal/db/dbm/dialect_dm.go
index 52e0b922..e9fb5d28 100644
--- a/server/internal/db/dbm/dialect_dm.go
+++ b/server/internal/db/dbm/dialect_dm.go
@@ -68,8 +68,8 @@ func (dd *DMDialect) GetDbServer() (*DbServer, error) {
return ds, nil
}
-func (pd *DMDialect) GetDbNames() ([]string, error) {
- _, res, err := pd.dc.Query("SELECT name AS DBNAME FROM v$database")
+func (dd *DMDialect) GetDbNames() ([]string, error) {
+ _, res, err := dd.dc.Query("SELECT name AS DBNAME FROM v$database")
if err != nil {
return nil, err
}
@@ -83,13 +83,13 @@ func (pd *DMDialect) GetDbNames() ([]string, error) {
}
// 获取表基础元信息, 如表名等
-func (pd *DMDialect) GetTables() ([]Table, error) {
+func (dd *DMDialect) GetTables() ([]Table, error) {
// 首先执行更新统计信息sql 这个统计信息在数据量比较大的时候就比较耗时,所以最好定时执行
// _, _, err := pd.dc.Query("dbms_stats.GATHER_SCHEMA_stats(SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID))")
// 查询表信息
- _, res, err := pd.dc.Query(GetLocalSql(DM_META_FILE, DM_TABLE_INFO_KEY))
+ _, res, err := dd.dc.Query(GetLocalSql(DM_META_FILE, DM_TABLE_INFO_KEY))
if err != nil {
return nil, err
}
@@ -109,7 +109,7 @@ func (pd *DMDialect) GetTables() ([]Table, error) {
}
// 获取列元信息, 如列名等
-func (pd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) {
+func (dd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) {
tableName := ""
for i := 0; i < len(tableNames); i++ {
if i != 0 {
@@ -118,7 +118,7 @@ func (pd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) {
tableName = tableName + "'" + tableNames[i] + "'"
}
- _, res, err := pd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName))
+ _, res, err := dd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName))
if err != nil {
return nil, err
}
@@ -139,8 +139,8 @@ func (pd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) {
return columns, nil
}
-func (pd *DMDialect) GetPrimaryKey(tablename string) (string, error) {
- columns, err := pd.GetColumns(tablename)
+func (dd *DMDialect) GetPrimaryKey(tablename string) (string, error) {
+ columns, err := dd.GetColumns(tablename)
if err != nil {
return "", err
}
@@ -157,8 +157,8 @@ func (pd *DMDialect) GetPrimaryKey(tablename string) (string, error) {
}
// 获取表索引信息
-func (pd *DMDialect) GetTableIndex(tableName string) ([]Index, error) {
- _, res, err := pd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_INDEX_INFO_KEY), tableName))
+func (dd *DMDialect) GetTableIndex(tableName string) ([]Index, error) {
+ _, res, err := dd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_INDEX_INFO_KEY), tableName))
if err != nil {
return nil, err
}
@@ -194,9 +194,9 @@ func (pd *DMDialect) GetTableIndex(tableName string) ([]Index, error) {
}
// 获取建表ddl
-func (pd *DMDialect) GetTableDDL(tableName string) (string, error) {
+func (dd *DMDialect) GetTableDDL(tableName string) (string, error) {
ddlSql := fmt.Sprintf("CALL SP_TABLEDEF((SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID)), '%s')", tableName)
- _, res, err := pd.dc.Query(ddlSql)
+ _, res, err := dd.dc.Query(ddlSql)
if err != nil {
return "", err
}
@@ -207,7 +207,7 @@ func (pd *DMDialect) GetTableDDL(tableName string) (string, error) {
}
// 表注释
- _, res, err = pd.dc.Query(fmt.Sprintf(`
+ _, res, err = dd.dc.Query(fmt.Sprintf(`
select OWNER, COMMENTS from DBA_TAB_COMMENTS where TABLE_TYPE='TABLE' and TABLE_NAME = '%s'
and owner = (SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID))
`, tableName))
@@ -229,7 +229,7 @@ func (pd *DMDialect) GetTableDDL(tableName string) (string, error) {
WHERE OWNER = (SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID))
AND TABLE_NAME = '%s'
`, tableName)
- _, res, err = pd.dc.Query(fieldSql)
+ _, res, err = dd.dc.Query(fieldSql)
if err != nil {
return "", err
}
@@ -251,7 +251,7 @@ func (pd *DMDialect) GetTableDDL(tableName string) (string, error) {
and a.table_name = '%s'
and indexdef(b.object_id,1) != '禁止查看系统定义的索引信息'
`, tableName)
- _, res, err = pd.dc.Query(indexSql)
+ _, res, err = dd.dc.Query(indexSql)
if err != nil {
return "", err
}
@@ -262,18 +262,18 @@ func (pd *DMDialect) GetTableDDL(tableName string) (string, error) {
return builder.String(), nil
}
-func (pd *DMDialect) GetTableRecord(tableName string, pageNum, pageSize int) ([]*QueryColumn, []map[string]any, error) {
- return pd.dc.Query(fmt.Sprintf("SELECT * FROM %s OFFSET %d LIMIT %d", tableName, (pageNum-1)*pageSize, pageSize))
+func (dd *DMDialect) GetTableRecord(tableName string, pageNum, pageSize int) ([]*QueryColumn, []map[string]any, error) {
+ return dd.dc.Query(fmt.Sprintf("SELECT * FROM %s OFFSET %d LIMIT %d", tableName, (pageNum-1)*pageSize, pageSize))
}
-func (pd *DMDialect) WalkTableRecord(tableName string, walk func(record map[string]any, columns []*QueryColumn)) error {
- return pd.dc.WalkTableRecord(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walk)
+func (dd *DMDialect) WalkTableRecord(tableName string, walk func(record map[string]any, columns []*QueryColumn)) error {
+ return dd.dc.WalkTableRecord(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walk)
}
// 获取DM当前连接的库可访问的schemaNames
-func (pd *DMDialect) GetSchemas() ([]string, error) {
+func (dd *DMDialect) GetSchemas() ([]string, error) {
sql := GetLocalSql(DM_META_FILE, DM_DB_SCHEMAS)
- _, res, err := pd.dc.Query(sql)
+ _, res, err := dd.dc.Query(sql)
if err != nil {
return nil, err
}
@@ -283,3 +283,8 @@ func (pd *DMDialect) GetSchemas() ([]string, error) {
}
return schemaNames, nil
}
+
+// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
+func (dd *DMDialect) GetDbProgram() DbProgram {
+ panic("implement me")
+}
diff --git a/server/internal/db/dbm/dialect_mysql.go b/server/internal/db/dbm/dialect_mysql.go
index 37122fdd..898bf1fb 100644
--- a/server/internal/db/dbm/dialect_mysql.go
+++ b/server/internal/db/dbm/dialect_mysql.go
@@ -194,6 +194,11 @@ func (md *MysqlDialect) WalkTableRecord(tableName string, walk func(record map[s
return md.dc.WalkTableRecord(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walk)
}
-func (pd *MysqlDialect) GetSchemas() ([]string, error) {
+func (md *MysqlDialect) GetSchemas() ([]string, error) {
return nil, nil
}
+
+// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
+func (md *MysqlDialect) GetDbProgram() DbProgram {
+ return NewDbProgramMysql(md.dc)
+}
diff --git a/server/internal/db/dbm/dialect_pgsql.go b/server/internal/db/dbm/dialect_pgsql.go
index d786f54e..d1850d6b 100644
--- a/server/internal/db/dbm/dialect_pgsql.go
+++ b/server/internal/db/dbm/dialect_pgsql.go
@@ -277,3 +277,8 @@ func (pd *PgsqlDialect) GetSchemas() ([]string, error) {
}
return schemaNames, nil
}
+
+// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
+func (pd *PgsqlDialect) GetDbProgram() DbProgram {
+ panic("implement me")
+}
diff --git a/server/internal/db/domain/entity/db_backup.go b/server/internal/db/domain/entity/db_backup.go
index e571ea42..c80a1b2d 100644
--- a/server/internal/db/domain/entity/db_backup.go
+++ b/server/internal/db/domain/entity/db_backup.go
@@ -1,77 +1,39 @@
package entity
-import (
- "mayfly-go/pkg/model"
- "time"
-)
-
var _ DbTask = (*DbBackup)(nil)
// DbBackup 数据库备份任务
type DbBackup struct {
- model.Model
+ *DbTaskBase
- Name string `gorm:"column(db_name)" json:"name"` // 备份任务名称
- DbName string `gorm:"column(db_name)" json:"dbName"` // 数据库名
- StartTime time.Time `gorm:"column(start_time)" json:"startTime"` // 开始时间: 2023-11-08 02:00:00
- Interval time.Duration `gorm:"column(interval)" json:"interval"` // 间隔时间: 为零表示单次执行,为正表示反复执行
- Enabled bool `gorm:"column(enabled)" json:"enabled"` // 是否启用
- Finished bool `gorm:"column(finished)" json:"finished"` // 是否完成
- Repeated bool `gorm:"column(repeated)" json:"repeated"` // 是否重复执行
- LastStatus TaskStatus `gorm:"column(last_status)" json:"lastStatus"` // 最近一次执行状态
- LastResult string `gorm:"column(last_result)" json:"lastResult"` // 最近一次执行结果
- LastTime time.Time `gorm:"column(last_time)" json:"lastTime"` // 最近一次执行时间: 2023-11-08 02:00:00
- DbInstanceId uint64 `gorm:"column(db_instance_id)" json:"dbInstanceId"`
- Deadline time.Time `gorm:"-" json:"-"`
+ Name string `json:"name"` // 备份任务名称
+ DbName string `json:"dbName"` // 数据库名
+ DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
}
-func (d *DbBackup) TableName() string {
- return "t_db_backup"
-}
-
-func (d *DbBackup) GetId() uint64 {
- if d == nil {
- return 0
+var (
+ backupResult = map[TaskStatus]string{
+ TaskDelay: "等待备份数据库",
+ TaskReady: "准备备份数据库",
+ TaskReserved: "数据库备份中",
+ TaskSuccess: "数据库备份成功",
+ TaskFailed: "数据库备份失败",
}
- return d.Id
-}
+)
-func (d *DbBackup) GetDeadline() time.Time {
- return d.Deadline
-}
-
-func (d *DbBackup) Schedule() bool {
- if d.Finished || !d.Enabled {
- return false
- }
- switch d.LastStatus {
+func (*DbBackup) MessageWithStatus(status TaskStatus) string {
+ var result string
+ switch status {
+ case TaskDelay:
+ result = "等待备份数据库"
+ case TaskReady:
+ result = "准备备份数据库"
+ case TaskReserved:
+ result = "数据库备份中"
case TaskSuccess:
- if d.Interval == 0 {
- return false
- }
- lastTime := d.LastTime
- if d.LastTime.Sub(d.StartTime) < 0 {
- lastTime = d.StartTime.Add(-d.Interval)
- }
- d.Deadline = lastTime.Add(d.Interval - d.LastTime.Sub(d.StartTime)%d.Interval)
+ result = "数据库备份成功"
case TaskFailed:
- d.Deadline = time.Now().Add(time.Minute)
- default:
- d.Deadline = d.StartTime
+ result = "数据库备份失败"
}
- return true
-}
-
-func (d *DbBackup) IsFinished() bool {
- return !d.Repeated && d.LastStatus == TaskSuccess
-}
-
-func (d *DbBackup) Update(task DbTask) bool {
- switch t := task.(type) {
- case *DbBackup:
- d.StartTime = t.StartTime
- d.Interval = t.Interval
- return true
- }
- return false
+ return result
}
diff --git a/server/internal/db/domain/entity/db_binlog.go b/server/internal/db/domain/entity/db_binlog.go
index 9e28748b..3a4e2e16 100644
--- a/server/internal/db/domain/entity/db_binlog.go
+++ b/server/internal/db/domain/entity/db_binlog.go
@@ -2,86 +2,34 @@ package entity
import (
"mayfly-go/pkg/model"
+ "mayfly-go/pkg/utils/timex"
"time"
)
-const BinlogDownloadInterval = time.Minute * 15
-
-var _ DbTask = (*DbBinlog)(nil)
-
// DbBinlog 数据库备份任务
type DbBinlog struct {
model.Model
- StartTime time.Time `gorm:"column(start_time)" json:"startTime"` // 开始时间: 2023-11-08 02:00:00
- Interval time.Duration `gorm:"column(interval)" json:"interval"` // 间隔时间: 为零表示单次执行,为正表示反复执行
- Enabled bool `gorm:"column(enabled)" json:"enabled"` // 是否启用
- LastStatus TaskStatus `gorm:"column(last_status)" json:"lastStatus"` // 最近一次执行状态
- LastResult string `gorm:"column(last_result)" json:"lastResult"` // 最近一次执行结果
- LastTime time.Time `gorm:"column(last_time)" json:"lastTime"` // 最近一次执行时间: 2023-11-08 02:00:00
- DbInstanceId uint64 `gorm:"column(db_instance_id)" json:"dbInstanceId"`
- Deadline time.Time `gorm:"-" json:"-"`
+ LastStatus TaskStatus // 最近一次执行状态
+ LastResult string // 最近一次执行结果
+ LastTime timex.NullTime // 最近一次执行时间
+ DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
}
-func NewDbBinlog(history *DbBackupHistory) *DbBinlog {
- binlogTask := &DbBinlog{
- StartTime: time.Now(),
- Enabled: true,
- Interval: BinlogDownloadInterval,
- DbInstanceId: history.DbInstanceId,
- LastTime: time.Now(),
- }
- binlogTask.Id = binlogTask.DbInstanceId
+func NewDbBinlog(instanceId uint64) *DbBinlog {
+ binlogTask := &DbBinlog{}
+ binlogTask.Id = instanceId
+ binlogTask.DbInstanceId = instanceId
return binlogTask
}
-func (d *DbBinlog) TableName() string {
- return "t_db_binlog"
-}
+// BinlogFile is the metadata of the MySQL binlog file.
+type BinlogFile struct {
+ Name string
+ Size int64
-func (d *DbBinlog) GetId() uint64 {
- if d == nil {
- return 0
- }
- return d.Id
-}
-
-func (d *DbBinlog) GetDeadline() time.Time {
- return d.Deadline
-}
-
-func (d *DbBinlog) Schedule() bool {
- if !d.Enabled {
- return false
- }
- switch d.LastStatus {
- case TaskSuccess:
- if d.Interval == 0 {
- return false
- }
- lastTime := d.LastTime
- if d.LastTime.Sub(d.StartTime) < 0 {
- lastTime = d.StartTime.Add(-d.Interval)
- }
- d.Deadline = lastTime.Add(d.Interval - d.LastTime.Sub(d.StartTime)%d.Interval)
- case TaskFailed:
- d.Deadline = time.Now().Add(time.Minute)
- default:
- d.Deadline = d.StartTime
- }
- return true
-}
-
-func (d *DbBinlog) IsFinished() bool {
- return false
-}
-
-func (d *DbBinlog) Update(task DbTask) bool {
- switch t := task.(type) {
- case *DbBinlog:
- d.StartTime = t.StartTime
- d.Interval = t.Interval
- return true
- }
- return false
+ // Sequence is parsed from Name and is for the sorting purpose.
+ Sequence int64
+ FirstEventTime time.Time
+ Downloaded bool
}
diff --git a/server/internal/db/domain/entity/db_instance.go b/server/internal/db/domain/entity/db_instance.go
index 2f1e03ec..212da345 100644
--- a/server/internal/db/domain/entity/db_instance.go
+++ b/server/internal/db/domain/entity/db_instance.go
@@ -1,25 +1,25 @@
package entity
import (
+ "errors"
"fmt"
"mayfly-go/internal/common/utils"
- "mayfly-go/internal/db/dbm"
"mayfly-go/pkg/model"
)
type DbInstance struct {
model.Model
- Name string `orm:"column(name)" json:"name"`
- Type dbm.DbType `orm:"column(type)" json:"type"` // 类型,mysql oracle等
- Host string `orm:"column(host)" json:"host"`
- Port int `orm:"column(port)" json:"port"`
- Network string `orm:"column(network)" json:"network"`
- Username string `orm:"column(username)" json:"username"`
- Password string `orm:"column(password)" json:"-"`
- Params string `orm:"column(params)" json:"params"`
- Remark string `orm:"column(remark)" json:"remark"`
- SshTunnelMachineId int `orm:"column(ssh_tunnel_machine_id)" json:"sshTunnelMachineId"` // ssh隧道机器id
+ Name string `json:"name"`
+ Type string `json:"type"` // 类型,mysql oracle等
+ Host string `json:"host"`
+ Port int `json:"port"`
+ Network string `json:"network"`
+ Username string `json:"username"`
+ Password string `json:"-"`
+ Params string `json:"params"`
+ Remark string `json:"remark"`
+ SshTunnelMachineId int `json:"sshTunnelMachineId"` // ssh隧道机器id
}
func (d *DbInstance) TableName() string {
@@ -39,12 +39,22 @@ func (d *DbInstance) GetNetwork() string {
return fmt.Sprintf("%s+ssh:%d", d.Type, d.SshTunnelMachineId)
}
-func (d *DbInstance) PwdEncrypt() {
+func (d *DbInstance) PwdEncrypt() error {
// 密码替换为加密后的密码
- d.Password = utils.PwdAesEncrypt(d.Password)
+ password, err := utils.PwdAesEncrypt(d.Password)
+ if err != nil {
+ return errors.New("加密数据库密码失败")
+ }
+ d.Password = password
+ return nil
}
-func (d *DbInstance) PwdDecrypt() {
+func (d *DbInstance) PwdDecrypt() error {
// 密码替换为解密后的密码
- d.Password = utils.PwdAesDecrypt(d.Password)
+ password, err := utils.PwdAesDecrypt(d.Password)
+ if err != nil {
+ return errors.New("解密数据库密码失败")
+ }
+ d.Password = password
+ return nil
}
diff --git a/server/internal/db/domain/entity/db_restore.go b/server/internal/db/domain/entity/db_restore.go
index adde6772..80fce949 100644
--- a/server/internal/db/domain/entity/db_restore.go
+++ b/server/internal/db/domain/entity/db_restore.go
@@ -1,80 +1,36 @@
package entity
import (
- "mayfly-go/pkg/model"
- "time"
+ "mayfly-go/pkg/utils/timex"
)
var _ DbTask = (*DbRestore)(nil)
// DbRestore 数据库恢复任务
type DbRestore struct {
- model.Model
+ *DbTaskBase
- DbName string `gorm:"column(db_name)" json:"dbName"` // 数据库名
- StartTime time.Time `gorm:"column(start_time)" json:"startTime"` // 开始时间
- Interval time.Duration `gorm:"column(interval)" json:"interval"` // 间隔时间: 为零表示单次执行,为正表示反复执行
- Enabled bool `gorm:"column(enabled)" json:"enabled"` // 是否启用
- Finished bool `gorm:"column(finished)" json:"finished"` // 是否完成
- Repeated bool `gorm:"column(repeated)" json:"repeated"` // 是否重复执行
- LastStatus TaskStatus `gorm:"column(last_status)" json:"lastStatus"` // 最近一次执行状态
- LastResult string `gorm:"column(last_result)" json:"lastResult"` // 最近一次执行结果
- LastTime time.Time `gorm:"column(last_time)" json:"lastTime"` // 最近一次执行时间
- PointInTime time.Time `gorm:"column(point_in_time)" json:"pointInTime"` // 指定数据库恢复的时间点
- DbBackupId uint64 `gorm:"column(db_backup_id)" json:"dbBackupId"` // 用于恢复的数据库备份任务ID
- DbBackupHistoryId uint64 `gorm:"column(db_backup_history_id)" json:"dbBackupHistoryId"` // 用于恢复的数据库备份历史ID
- DbBackupHistoryName string `gorm:"column(db_backup_history_name) json:"dbBackupHistoryName"` // 数据库备份历史名称
- DbInstanceId uint64 `gorm:"column(db_instance_id)" json:"dbInstanceId"`
- Deadline time.Time `gorm:"-" json:"-"`
+ DbName string `json:"dbName"` // 数据库名
+ PointInTime timex.NullTime `json:"pointInTime"` // 指定数据库恢复的时间点
+ DbBackupId uint64 `json:"dbBackupId"` // 用于恢复的数据库恢复任务ID
+ DbBackupHistoryId uint64 `json:"dbBackupHistoryId"` // 用于恢复的数据库恢复历史ID
+ DbBackupHistoryName string `json:"dbBackupHistoryName"` // 数据库恢复历史名称
+ DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
}
-func (d *DbRestore) TableName() string {
- return "t_db_restore"
-}
-
-func (d *DbRestore) GetId() uint64 {
- if d == nil {
- return 0
- }
- return d.Id
-}
-
-func (d *DbRestore) GetDeadline() time.Time {
- return d.Deadline
-}
-
-func (d *DbRestore) Schedule() bool {
- if d.Finished || !d.Enabled {
- return false
- }
- switch d.LastStatus {
+func (*DbRestore) MessageWithStatus(status TaskStatus) string {
+ var result string
+ switch status {
+ case TaskDelay:
+ result = "等待恢复数据库"
+ case TaskReady:
+ result = "准备恢复数据库"
+ case TaskReserved:
+ result = "数据库恢复中"
case TaskSuccess:
- if d.Interval == 0 {
- return false
- }
- lastTime := d.LastTime
- if d.LastTime.Sub(d.StartTime) < 0 {
- lastTime = d.StartTime.Add(-d.Interval)
- }
- d.Deadline = lastTime.Add(d.Interval - d.LastTime.Sub(d.StartTime)%d.Interval)
+ result = "数据库恢复成功"
case TaskFailed:
- d.Deadline = time.Now().Add(time.Minute)
- default:
- d.Deadline = d.StartTime
+ result = "数据库恢复失败"
}
- return true
-}
-
-func (d *DbRestore) IsFinished() bool {
- return !d.Repeated && d.LastStatus == TaskSuccess
-}
-
-func (d *DbRestore) Update(task DbTask) bool {
- switch backup := task.(type) {
- case *DbRestore:
- d.StartTime = backup.StartTime
- d.Interval = backup.Interval
- return true
- }
- return false
+ return result
}
diff --git a/server/internal/db/domain/entity/db_task_base.go b/server/internal/db/domain/entity/db_task_base.go
new file mode 100644
index 00000000..aa6f8928
--- /dev/null
+++ b/server/internal/db/domain/entity/db_task_base.go
@@ -0,0 +1,109 @@
+package entity
+
+import (
+ "mayfly-go/pkg/model"
+ "mayfly-go/pkg/utils/timex"
+ "time"
+)
+
+type TaskStatus int
+
+const (
+ TaskDelay TaskStatus = iota
+ TaskReady
+ TaskReserved
+ TaskSuccess
+ TaskFailed
+)
+
+const LastResultSize = 256
+
+type DbTask interface {
+ model.ModelI
+
+ GetId() uint64
+ GetDeadline() time.Time
+ IsFinished() bool
+ Schedule() bool
+ Update(task DbTask)
+ GetTaskBase() *DbTaskBase
+ MessageWithStatus(status TaskStatus) string
+ IsEnabled() bool
+}
+
+func NewDbBTaskBase(enabled bool, repeated bool, startTime time.Time, interval time.Duration) *DbTaskBase {
+ return &DbTaskBase{
+ Enabled: enabled,
+ Repeated: repeated,
+ StartTime: startTime,
+ Interval: interval,
+ }
+}
+
+type DbTaskBase struct {
+ model.Model
+
+ Enabled bool // 是否启用
+ StartTime time.Time // 开始时间
+ Interval time.Duration // 间隔时间
+ Repeated bool // 是否重复执行
+ LastStatus TaskStatus // 最近一次执行状态
+ LastResult string // 最近一次执行结果
+ LastTime timex.NullTime // 最近一次执行时间
+ Deadline time.Time `gorm:"-" json:"-"` // 计划执行时间
+}
+
+func (d *DbTaskBase) GetId() uint64 {
+ if d == nil {
+ return 0
+ }
+ return d.Id
+}
+
+func (d *DbTaskBase) GetDeadline() time.Time {
+ return d.Deadline
+}
+
+func (d *DbTaskBase) Schedule() bool {
+ if d.IsFinished() || !d.Enabled {
+ return false
+ }
+ switch d.LastStatus {
+ case TaskSuccess:
+ if d.Interval == 0 {
+ return false
+ }
+ lastTime := d.LastTime.Time
+ if lastTime.Sub(d.StartTime) < 0 {
+ lastTime = d.StartTime.Add(-d.Interval)
+ }
+ d.Deadline = lastTime.Add(d.Interval - lastTime.Sub(d.StartTime)%d.Interval)
+ case TaskFailed:
+ d.Deadline = time.Now().Add(time.Minute)
+ default:
+ d.Deadline = d.StartTime
+ }
+ return true
+}
+
+func (d *DbTaskBase) IsFinished() bool {
+ return !d.Repeated && d.LastStatus == TaskSuccess
+}
+
+func (d *DbTaskBase) Update(task DbTask) {
+ t := task.GetTaskBase()
+ d.StartTime = t.StartTime
+ d.Interval = t.Interval
+}
+
+func (d *DbTaskBase) GetTaskBase() *DbTaskBase {
+ return d
+}
+
+func (*DbTaskBase) MessageWithStatus(_ TaskStatus) string {
+ return ""
+}
+
+func (d *DbTaskBase) IsEnabled() bool {
+ return d.Enabled
+}
diff --git a/server/internal/db/domain/entity/types.go b/server/internal/db/domain/entity/types.go
deleted file mode 100644
index 9a1edb6d..00000000
--- a/server/internal/db/domain/entity/types.go
+++ /dev/null
@@ -1,21 +0,0 @@
-package entity
-
-import "time"
-
-type TaskStatus int
-
-const (
- TaskDelay TaskStatus = iota
- TaskReady
- TaskReserved
- TaskSuccess
- TaskFailed
-)
-
-type DbTask interface {
- GetId() uint64
- GetDeadline() time.Time
- IsFinished() bool
- Schedule() bool
- Update(task DbTask) bool
-}
diff --git a/server/internal/db/domain/repository/db_backup.go b/server/internal/db/domain/repository/db_backup.go
index b41b938b..0f5891ef 100644
--- a/server/internal/db/domain/repository/db_backup.go
+++ b/server/internal/db/domain/repository/db_backup.go
@@ -3,17 +3,14 @@ package repository
import (
"context"
"mayfly-go/internal/db/domain/entity"
- "mayfly-go/pkg/base"
"mayfly-go/pkg/model"
)
type DbBackup interface {
- base.Repo[*entity.DbBackup]
+ DbTask[*entity.DbBackup]
// GetDbBackupList 分页获取数据信息列表
GetDbBackupList(condition *entity.DbBackupQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
AddTask(ctx context.Context, tasks ...*entity.DbBackup) error
- UpdateTaskStatus(ctx context.Context, task *entity.DbBackup) error
GetDbNamesWithoutBackup(instanceId uint64, dbNames []string) ([]string, error)
- UpdateEnabled(ctx context.Context, taskId uint64, enabled bool) error
}
diff --git a/server/internal/db/domain/repository/db_binlog.go b/server/internal/db/domain/repository/db_binlog.go
index 267be23f..6a44336b 100644
--- a/server/internal/db/domain/repository/db_binlog.go
+++ b/server/internal/db/domain/repository/db_binlog.go
@@ -10,6 +10,4 @@ type DbBinlog interface {
base.Repo[*entity.DbBinlog]
AddTaskIfNotExists(ctx context.Context, task *entity.DbBinlog) error
- UpdateTaskStatus(ctx context.Context, task *entity.DbBinlog) error
- UpdateEnabled(ctx context.Context, taskId uint64, enabled bool) error
}
diff --git a/server/internal/db/domain/repository/db_binlog_history.go b/server/internal/db/domain/repository/db_binlog_history.go
index 6fcc03a9..b2aa9b71 100644
--- a/server/internal/db/domain/repository/db_binlog_history.go
+++ b/server/internal/db/domain/repository/db_binlog_history.go
@@ -12,5 +12,6 @@ type DbBinlogHistory interface {
GetHistories(instanceId uint64, start, target *entity.BinlogInfo) ([]*entity.DbBinlogHistory, error)
GetHistoryByTime(instanceId uint64, targetTime time.Time) (*entity.DbBinlogHistory, error)
GetLatestHistory(instanceId uint64) (*entity.DbBinlogHistory, bool, error)
+ InsertWithBinlogFiles(ctx context.Context, instanceId uint64, binlogFiles []*entity.BinlogFile) error
Upsert(ctx context.Context, history *entity.DbBinlogHistory) error
}
diff --git a/server/internal/db/domain/repository/db_restore.go b/server/internal/db/domain/repository/db_restore.go
index 9d9f1e82..8a28da15 100644
--- a/server/internal/db/domain/repository/db_restore.go
+++ b/server/internal/db/domain/repository/db_restore.go
@@ -3,17 +3,14 @@ package repository
import (
"context"
"mayfly-go/internal/db/domain/entity"
- "mayfly-go/pkg/base"
"mayfly-go/pkg/model"
)
type DbRestore interface {
- base.Repo[*entity.DbRestore]
+ DbTask[*entity.DbRestore]
// GetDbRestoreList 分页获取数据信息列表
GetDbRestoreList(condition *entity.DbRestoreQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
AddTask(ctx context.Context, tasks ...*entity.DbRestore) error
- UpdateTaskStatus(ctx context.Context, task *entity.DbRestore) error
GetDbNamesWithoutRestore(instanceId uint64, dbNames []string) ([]string, error)
- UpdateEnabled(ctx context.Context, taskId uint64, enabled bool) error
}
diff --git a/server/internal/db/domain/repository/db_task.go b/server/internal/db/domain/repository/db_task.go
new file mode 100644
index 00000000..ca2ba6e8
--- /dev/null
+++ b/server/internal/db/domain/repository/db_task.go
@@ -0,0 +1,17 @@
+package repository
+
+import (
+ "context"
+ "mayfly-go/pkg/base"
+ "mayfly-go/pkg/model"
+)
+
+type DbTask[T model.ModelI] interface {
+ base.Repo[T]
+
+ UpdateTaskStatus(ctx context.Context, task T) error
+ UpdateEnabled(ctx context.Context, taskId uint64, enabled bool) error
+ ListToDo() ([]T, error)
+ ListRepeating() ([]T, error)
+ AddTask(ctx context.Context, tasks ...T) error
+}
diff --git a/server/internal/db/domain/service/db_backup.go b/server/internal/db/domain/service/db_backup.go
deleted file mode 100644
index a7db3a57..00000000
--- a/server/internal/db/domain/service/db_backup.go
+++ /dev/null
@@ -1,14 +0,0 @@
-package service
-
-import (
- "context"
- "mayfly-go/internal/db/domain/entity"
-)
-
-type DbBackupSvc interface {
- AddTask(ctx context.Context, tasks ...*entity.DbBackup) error
- UpdateTask(ctx context.Context, task *entity.DbBackup) error
- DeleteTask(ctx context.Context, taskId uint64) error
- EnableTask(ctx context.Context, taskId uint64) error
- DisableTask(ctx context.Context, taskId uint64) error
-}
diff --git a/server/internal/db/domain/service/db_binlog.go b/server/internal/db/domain/service/db_binlog.go
deleted file mode 100644
index 5358f202..00000000
--- a/server/internal/db/domain/service/db_binlog.go
+++ /dev/null
@@ -1,14 +0,0 @@
-package service
-
-import (
- "context"
- "mayfly-go/internal/db/domain/entity"
-)
-
-type DbBinlogSvc interface {
- AddTaskIfNotExists(ctx context.Context, task *entity.DbBinlog) error
- UpdateTask(ctx context.Context, task *entity.DbBinlog) error
- DeleteTask(ctx context.Context, taskId uint64) error
- EnableTask(ctx context.Context, taskId uint64) error
- DisableTask(ctx context.Context, taskId uint64) error
-}
diff --git a/server/internal/db/domain/service/db_instance.go b/server/internal/db/domain/service/db_instance.go
deleted file mode 100644
index 1b1bc822..00000000
--- a/server/internal/db/domain/service/db_instance.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package service
-
-import (
- "context"
- "mayfly-go/internal/db/domain/entity"
-)
-
-type DbInstanceSvc interface {
- Backup(ctx context.Context, backupHistory *entity.DbBackupHistory) (*entity.BinlogInfo, error)
- Restore(ctx context.Context, task *entity.DbRestore) error
- FetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool) error
-}
diff --git a/server/internal/db/domain/service/db_restore.go b/server/internal/db/domain/service/db_restore.go
deleted file mode 100644
index 6d284a1b..00000000
--- a/server/internal/db/domain/service/db_restore.go
+++ /dev/null
@@ -1,14 +0,0 @@
-package service
-
-import (
- "context"
- "mayfly-go/internal/db/domain/entity"
-)
-
-type DbRestoreSvc interface {
- AddTask(ctx context.Context, tasks ...*entity.DbRestore) error
- UpdateTask(ctx context.Context, task *entity.DbRestore) error
- DeleteTask(ctx context.Context, taskId uint64) error
- EnableTask(ctx context.Context, taskId uint64) error
- DisableTask(ctx context.Context, taskId uint64) error
-}
diff --git a/server/internal/db/infrastructure/persistence/db_backup.go b/server/internal/db/infrastructure/persistence/db_backup.go
index d84267ef..221f6735 100644
--- a/server/internal/db/infrastructure/persistence/db_backup.go
+++ b/server/internal/db/infrastructure/persistence/db_backup.go
@@ -7,7 +7,6 @@ import (
"gorm.io/gorm"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
- "mayfly-go/pkg/base"
"mayfly-go/pkg/gormx"
"mayfly-go/pkg/model"
"slices"
@@ -16,7 +15,8 @@ import (
var _ repository.DbBackup = (*dbBackupRepoImpl)(nil)
type dbBackupRepoImpl struct {
- base.RepoImpl[*entity.DbBackup]
+ //base.RepoImpl[*entity.DbBackup]
+ dbTaskBase[*entity.DbBackup]
}
func NewDbBackupRepo() repository.DbBackup {
@@ -34,30 +34,6 @@ func (d *dbBackupRepoImpl) GetDbBackupList(condition *entity.DbBackupQuery, page
return gormx.PageQuery(qd, pageParam, toEntity)
}
-func (d *dbBackupRepoImpl) UpdateEnabled(ctx context.Context, taskId uint64, enabled bool) error {
- cond := map[string]any{
- "id": taskId,
- }
- return d.Updates(cond, map[string]any{
- "enabled": enabled,
- })
-}
-
-func (d *dbBackupRepoImpl) UpdateTaskStatus(ctx context.Context, task *entity.DbBackup) error {
- task = &entity.DbBackup{
- Model: model.Model{
- DeletedModel: model.DeletedModel{
- Id: task.Id,
- },
- },
- Finished: task.Finished,
- LastStatus: task.LastStatus,
- LastResult: task.LastResult,
- LastTime: task.LastTime,
- }
- return d.UpdateById(ctx, task)
-}
-
func (d *dbBackupRepoImpl) AddTask(ctx context.Context, tasks ...*entity.DbBackup) error {
return gormx.Tx(func(db *gorm.DB) error {
var instanceId uint64
@@ -94,7 +70,7 @@ func (d *dbBackupRepoImpl) AddTask(ctx context.Context, tasks ...*entity.DbBacku
func (d *dbBackupRepoImpl) GetDbNamesWithoutBackup(instanceId uint64, dbNames []string) ([]string, error) {
var dbNamesWithBackup []string
- query := gormx.NewQuery(d.M).
+ query := gormx.NewQuery(d.GetModel()).
Eq("db_instance_id", instanceId).
Eq("repeated", true)
if err := query.GenGdb().Pluck("db_name", &dbNamesWithBackup).Error; err != nil {
diff --git a/server/internal/db/infrastructure/persistence/db_backup_history.go b/server/internal/db/infrastructure/persistence/db_backup_history.go
index 46ebaae2..138fb2ad 100644
--- a/server/internal/db/infrastructure/persistence/db_backup_history.go
+++ b/server/internal/db/infrastructure/persistence/db_backup_history.go
@@ -57,5 +57,5 @@ func (repo *dbBackupHistoryRepoImpl) GetEarliestHistory(instanceId uint64) (*ent
if err != nil {
return nil, err
}
- return history, err
+ return history, nil
}
diff --git a/server/internal/db/infrastructure/persistence/db_binlog.go b/server/internal/db/infrastructure/persistence/db_binlog.go
index 9446818d..8bcbb4b3 100644
--- a/server/internal/db/infrastructure/persistence/db_binlog.go
+++ b/server/internal/db/infrastructure/persistence/db_binlog.go
@@ -2,12 +2,12 @@ package persistence
import (
"context"
+ "fmt"
"gorm.io/gorm/clause"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/base"
"mayfly-go/pkg/global"
- "mayfly-go/pkg/model"
)
var _ repository.DbBinlog = (*dbBinlogRepoImpl)(nil)
@@ -20,29 +20,9 @@ func NewDbBinlogRepo() repository.DbBinlog {
return &dbBinlogRepoImpl{}
}
-func (d *dbBinlogRepoImpl) UpdateEnabled(ctx context.Context, taskId uint64, enabled bool) error {
- cond := map[string]any{
- "id": taskId,
+func (d *dbBinlogRepoImpl) AddTaskIfNotExists(_ context.Context, task *entity.DbBinlog) error {
+ if err := global.Db.Clauses(clause.OnConflict{DoNothing: true}).Create(task).Error; err != nil {
+ return fmt.Errorf("启动 binlog 下载失败: %w", err)
}
- return d.Updates(cond, map[string]any{
- "enabled": enabled,
- })
-}
-
-func (d *dbBinlogRepoImpl) UpdateTaskStatus(ctx context.Context, task *entity.DbBinlog) error {
- task = &entity.DbBinlog{
- Model: model.Model{
- DeletedModel: model.DeletedModel{
- Id: task.Id,
- },
- },
- LastStatus: task.LastStatus,
- LastResult: task.LastResult,
- LastTime: task.LastTime,
- }
- return d.UpdateById(ctx, task)
-}
-
-func (d *dbBinlogRepoImpl) AddTaskIfNotExists(ctx context.Context, task *entity.DbBinlog) error {
- return global.Db.Clauses(clause.OnConflict{DoNothing: true}).Create(task).Error
+ return nil
}
diff --git a/server/internal/db/infrastructure/persistence/db_binlog_history.go b/server/internal/db/infrastructure/persistence/db_binlog_history.go
index e2b5e38d..09a11444 100644
--- a/server/internal/db/infrastructure/persistence/db_binlog_history.go
+++ b/server/internal/db/infrastructure/persistence/db_binlog_history.go
@@ -56,14 +56,14 @@ func (repo *dbBinlogHistoryRepoImpl) GetHistories(instanceId uint64, start, targ
}
func (repo *dbBinlogHistoryRepoImpl) GetLatestHistory(instanceId uint64) (*entity.DbBinlogHistory, bool, error) {
- gdb := gormx.NewQuery(repo.GetModel()).
+ history := &entity.DbBinlogHistory{}
+ err := gormx.NewQuery(repo.GetModel()).
Eq("db_instance_id", instanceId).
Undeleted().
OrderByDesc("sequence").
- GenGdb()
- history := &entity.DbBinlogHistory{}
-
- switch err := gdb.First(history).Error; {
+ GenGdb().
+ First(history).Error
+ switch {
case err == nil:
return history, true, nil
case errors.Is(err, gorm.ErrRecordNotFound):
@@ -89,3 +89,35 @@ func (repo *dbBinlogHistoryRepoImpl) Upsert(_ context.Context, history *entity.D
}
})
}
+
+func (repo *dbBinlogHistoryRepoImpl) InsertWithBinlogFiles(ctx context.Context, instanceId uint64, binlogFiles []*entity.BinlogFile) error {
+ if len(binlogFiles) == 0 {
+ return nil
+ }
+ histories := make([]*entity.DbBinlogHistory, 0, len(binlogFiles))
+ for _, fileOnServer := range binlogFiles {
+ if !fileOnServer.Downloaded {
+ break
+ }
+ history := &entity.DbBinlogHistory{
+ CreateTime: time.Now(),
+ FileName: fileOnServer.Name,
+ FileSize: fileOnServer.Size,
+ Sequence: fileOnServer.Sequence,
+ FirstEventTime: fileOnServer.FirstEventTime,
+ DbInstanceId: instanceId,
+ }
+ histories = append(histories, history)
+ }
+ if len(histories) > 1 {
+ if err := repo.BatchInsert(ctx, histories[:len(histories)-1]); err != nil {
+ return err
+ }
+ }
+ if len(histories) > 0 {
+ if err := repo.Upsert(ctx, histories[len(histories)-1]); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/server/internal/db/infrastructure/persistence/db_restore.go b/server/internal/db/infrastructure/persistence/db_restore.go
index 4352dc0e..ee8093d5 100644
--- a/server/internal/db/infrastructure/persistence/db_restore.go
+++ b/server/internal/db/infrastructure/persistence/db_restore.go
@@ -7,7 +7,6 @@ import (
"gorm.io/gorm"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
- "mayfly-go/pkg/base"
"mayfly-go/pkg/gormx"
"mayfly-go/pkg/model"
"slices"
@@ -16,7 +15,7 @@ import (
var _ repository.DbRestore = (*dbRestoreRepoImpl)(nil)
type dbRestoreRepoImpl struct {
- base.RepoImpl[*entity.DbRestore]
+ dbTaskBase[*entity.DbRestore]
}
func NewDbRestoreRepo() repository.DbRestore {
@@ -34,21 +33,6 @@ func (d *dbRestoreRepoImpl) GetDbRestoreList(condition *entity.DbRestoreQuery, p
return gormx.PageQuery(qd, pageParam, toEntity)
}
-func (d *dbRestoreRepoImpl) UpdateTaskStatus(ctx context.Context, task *entity.DbRestore) error {
- task = &entity.DbRestore{
- Model: model.Model{
- DeletedModel: model.DeletedModel{
- Id: task.Id,
- },
- },
- Finished: task.Finished,
- LastStatus: task.LastStatus,
- LastResult: task.LastResult,
- LastTime: task.LastTime,
- }
- return d.UpdateById(ctx, task)
-}
-
func (d *dbRestoreRepoImpl) AddTask(ctx context.Context, tasks ...*entity.DbRestore) error {
return gormx.Tx(func(db *gorm.DB) error {
var instanceId uint64
@@ -85,7 +69,7 @@ func (d *dbRestoreRepoImpl) AddTask(ctx context.Context, tasks ...*entity.DbRest
func (d *dbRestoreRepoImpl) GetDbNamesWithoutRestore(instanceId uint64, dbNames []string) ([]string, error) {
var dbNamesWithRestore []string
- query := gormx.NewQuery(d.M).
+ query := gormx.NewQuery(d.GetModel()).
Eq("db_instance_id", instanceId).
Eq("repeated", true)
if err := query.GenGdb().Pluck("db_name", &dbNamesWithRestore).Error; err != nil {
@@ -99,12 +83,3 @@ func (d *dbRestoreRepoImpl) GetDbNamesWithoutRestore(instanceId uint64, dbNames
}
return result, nil
}
-
-func (d *dbRestoreRepoImpl) UpdateEnabled(ctx context.Context, taskId uint64, enabled bool) error {
- cond := map[string]any{
- "id": taskId,
- }
- return d.Updates(cond, map[string]any{
- "enabled": enabled,
- })
-}
diff --git a/server/internal/db/infrastructure/persistence/db_task_base.go b/server/internal/db/infrastructure/persistence/db_task_base.go
new file mode 100644
index 00000000..de4daf00
--- /dev/null
+++ b/server/internal/db/infrastructure/persistence/db_task_base.go
@@ -0,0 +1,52 @@
+package persistence
+
+import (
+ "context"
+ "mayfly-go/internal/db/domain/entity"
+ "mayfly-go/pkg/base"
+ "mayfly-go/pkg/global"
+ "mayfly-go/pkg/gormx"
+ "mayfly-go/pkg/model"
+)
+
+type dbTaskBase[T model.ModelI] struct {
+ base.RepoImpl[T]
+}
+
+func (d *dbTaskBase[T]) UpdateEnabled(_ context.Context, taskId uint64, enabled bool) error {
+ cond := map[string]any{
+ "id": taskId,
+ }
+ return d.Updates(cond, map[string]any{
+ "enabled": enabled,
+ })
+}
+
+func (d *dbTaskBase[T]) UpdateTaskStatus(ctx context.Context, task T) error {
+ return d.UpdateById(ctx, task, "last_status", "last_result", "last_time")
+}
+
+func (d *dbTaskBase[T]) ListToDo() ([]T, error) {
+ var tasks []T
+ db := global.Db.Model(d.GetModel())
+ err := db.Where("enabled = ?", true).
+ Where(db.Where("repeated = ?", true).Or("last_status <> ?", entity.TaskSuccess)).
+ Scopes(gormx.UndeleteScope).
+ Find(&tasks).Error
+ if err != nil {
+ return nil, err
+ }
+ return tasks, nil
+}
+
+func (d *dbTaskBase[T]) ListRepeating() ([]T, error) {
+ cond := map[string]any{
+ "enabled": true,
+ "repeated": true,
+ }
+ var tasks []T
+ if err := d.ListByCond(cond, &tasks); err != nil {
+ return nil, err
+ }
+ return tasks, nil
+}
diff --git a/server/internal/db/infrastructure/service/db_backup.go b/server/internal/db/infrastructure/service/db_backup.go
deleted file mode 100644
index a4a8a0af..00000000
--- a/server/internal/db/infrastructure/service/db_backup.go
+++ /dev/null
@@ -1,212 +0,0 @@
-package service
-
-import (
- "context"
- "encoding/binary"
- "fmt"
- "mayfly-go/internal/db/domain/entity"
- "mayfly-go/internal/db/domain/repository"
- "mayfly-go/internal/db/domain/service"
- "mayfly-go/pkg/model"
- "time"
-
- "github.com/google/uuid"
-)
-
-var _ service.DbBackupSvc = (*DbBackupSvcImpl)(nil)
-
-type DbBackupSvcImpl struct {
- repo repository.DbBackup
- instanceRepo repository.Instance
- scheduler *Scheduler[*entity.DbBackup]
- binlogSvc service.DbBinlogSvc
-}
-
-func NewIncUUID() (uuid.UUID, error) {
- var uid uuid.UUID
- now, seq, err := uuid.GetTime()
- if err != nil {
- return uid, err
- }
- timeHi := uint32((now >> 28) & 0xffffffff)
- timeMid := uint16((now >> 12) & 0xffff)
- timeLow := uint16(now & 0x0fff)
- timeLow |= 0x1000 // Version 1
-
- binary.BigEndian.PutUint32(uid[0:], timeHi)
- binary.BigEndian.PutUint16(uid[4:], timeMid)
- binary.BigEndian.PutUint16(uid[6:], timeLow)
- binary.BigEndian.PutUint16(uid[8:], seq)
-
- copy(uid[10:], uuid.NodeID())
-
- return uid, nil
-}
-
-func withRunBackupTask(repositories *repository.Repositories, binlogSvc service.DbBinlogSvc) SchedulerOption[*entity.DbBackup] {
- return func(scheduler *Scheduler[*entity.DbBackup]) {
- scheduler.RunTask = func(ctx context.Context, task *entity.DbBackup) error {
- instance := new(entity.DbInstance)
- if err := repositories.Instance.GetById(instance, task.DbInstanceId); err != nil {
- return err
- }
- instance.PwdDecrypt()
- id, err := NewIncUUID()
- if err != nil {
- return err
- }
- history := &entity.DbBackupHistory{
- Uuid: id.String(),
- DbBackupId: task.Id,
- DbInstanceId: task.DbInstanceId,
- DbName: task.DbName,
- }
- binlogInfo, err := NewDbInstanceSvc(instance, repositories).Backup(ctx, history)
- if err != nil {
- return err
- }
- now := time.Now()
- name := task.Name
- if len(name) == 0 {
- name = task.DbName
- }
- history.Name = fmt.Sprintf("%s[%s]", name, now.Format(time.DateTime))
- history.CreateTime = now
- history.BinlogFileName = binlogInfo.FileName
- history.BinlogSequence = binlogInfo.Sequence
- history.BinlogPosition = binlogInfo.Position
-
- if err := repositories.BackupHistory.Insert(ctx, history); err != nil {
- return err
- }
- if err := binlogSvc.AddTaskIfNotExists(ctx, entity.NewDbBinlog(history)); err != nil {
- return err
- }
- return nil
- }
- }
-}
-
-var (
- backupResult = map[entity.TaskStatus]string{
- entity.TaskDelay: "等待备份数据库",
- entity.TaskReady: "准备备份数据库",
- entity.TaskReserved: "数据库备份中",
- entity.TaskSuccess: "数据库备份成功",
- entity.TaskFailed: "数据库备份失败",
- }
-)
-
-func withUpdateBackupStatus(repositories *repository.Repositories) SchedulerOption[*entity.DbBackup] {
- return func(scheduler *Scheduler[*entity.DbBackup]) {
- scheduler.UpdateTaskStatus = func(ctx context.Context, status entity.TaskStatus, lastErr error, task *entity.DbBackup) error {
- task.Finished = !task.Repeated && status == entity.TaskSuccess
- task.LastStatus = status
- var result = backupResult[status]
- if lastErr != nil {
- result = fmt.Sprintf("%v: %v", backupResult[status], lastErr)
- }
- task.LastResult = result
- task.LastTime = time.Now()
- return repositories.Backup.UpdateTaskStatus(ctx, task)
- }
- }
-}
-
-func NewDbBackupSvc(repositories *repository.Repositories, binlogSvc service.DbBinlogSvc) (service.DbBackupSvc, error) {
- scheduler, err := NewScheduler[*entity.DbBackup](
- withRunBackupTask(repositories, binlogSvc),
- withUpdateBackupStatus(repositories))
- if err != nil {
- return nil, err
- }
- svc := &DbBackupSvcImpl{
- repo: repositories.Backup,
- instanceRepo: repositories.Instance,
- scheduler: scheduler,
- binlogSvc: binlogSvc,
- }
- err = svc.loadTasks(context.Background())
- if err != nil {
- return nil, err
- }
- return svc, nil
-}
-
-func (svc *DbBackupSvcImpl) loadTasks(ctx context.Context) error {
- tasks := make([]*entity.DbBackup, 0, 64)
- cond := map[string]any{
- "Enabled": true,
- "Finished": false,
- }
- if err := svc.repo.ListByCond(cond, &tasks); err != nil {
- return err
- }
- for _, task := range tasks {
- svc.scheduler.PushTask(ctx, task)
- }
- return nil
-}
-
-func (svc *DbBackupSvcImpl) AddTask(ctx context.Context, tasks ...*entity.DbBackup) error {
- for _, task := range tasks {
- if err := svc.repo.AddTask(ctx, task); err != nil {
- return err
- }
- svc.scheduler.PushTask(ctx, task)
- }
- return nil
-}
-
-func (svc *DbBackupSvcImpl) UpdateTask(ctx context.Context, task *entity.DbBackup) error {
- if err := svc.repo.UpdateById(ctx, task); err != nil {
- return err
- }
- svc.scheduler.UpdateTask(ctx, task)
- return nil
-}
-
-func (svc *DbBackupSvcImpl) DeleteTask(ctx context.Context, taskId uint64) error {
- // todo: 删除数据库备份历史文件
- task := new(entity.DbBackup)
- if err := svc.repo.GetById(task, taskId); err != nil {
- return err
- }
- if err := svc.binlogSvc.DeleteTask(ctx, task.DbInstanceId); err != nil {
- return err
- }
- if err := svc.repo.DeleteById(ctx, taskId); err != nil {
- return err
- }
- svc.scheduler.RemoveTask(taskId)
- return nil
-}
-
-func (svc *DbBackupSvcImpl) EnableTask(ctx context.Context, taskId uint64) error {
- if err := svc.repo.UpdateEnabled(ctx, taskId, true); err != nil {
- return err
- }
- task := new(entity.DbBackup)
- if err := svc.repo.GetById(task, taskId); err != nil {
- return err
- }
- svc.scheduler.UpdateTask(ctx, task)
- return nil
-}
-
-func (svc *DbBackupSvcImpl) DisableTask(ctx context.Context, taskId uint64) error {
- if err := svc.repo.UpdateEnabled(ctx, taskId, false); err != nil {
- return err
- }
- task := new(entity.DbBackup)
- if err := svc.repo.GetById(task, taskId); err != nil {
- return err
- }
- svc.scheduler.RemoveTask(taskId)
- return nil
-}
-
-// GetPageList 分页获取数据库备份任务
-func (svc *DbBackupSvcImpl) GetPageList(condition *entity.DbBackupQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
- return svc.repo.GetDbBackupList(condition, pageParam, toEntity, orderBy...)
-}
diff --git a/server/internal/db/infrastructure/service/db_binlog.go b/server/internal/db/infrastructure/service/db_binlog.go
deleted file mode 100644
index 70fe9960..00000000
--- a/server/internal/db/infrastructure/service/db_binlog.go
+++ /dev/null
@@ -1,140 +0,0 @@
-package service
-
-import (
- "context"
- "fmt"
- "mayfly-go/internal/db/domain/entity"
- "mayfly-go/internal/db/domain/repository"
- "mayfly-go/internal/db/domain/service"
- "time"
-)
-
-var _ service.DbBinlogSvc = (*DbBinlogSvcImpl)(nil)
-
-type DbBinlogSvcImpl struct {
- repo repository.DbBinlog
- instanceRepo repository.Instance
- scheduler *Scheduler[*entity.DbBinlog]
-}
-
-func withDownloadBinlog(repositories *repository.Repositories) SchedulerOption[*entity.DbBinlog] {
- return func(scheduler *Scheduler[*entity.DbBinlog]) {
- scheduler.RunTask = func(ctx context.Context, task *entity.DbBinlog) error {
- instance := new(entity.DbInstance)
- if err := repositories.Instance.GetById(instance, task.DbInstanceId); err != nil {
- return err
- }
- instance.PwdDecrypt()
- svc := NewDbInstanceSvc(instance, repositories)
- err := svc.FetchBinlogs(ctx, false)
- if err != nil {
- return err
- }
- return nil
- }
- }
-}
-
-var (
- binlogResult = map[entity.TaskStatus]string{
- entity.TaskDelay: "等待备份BINLOG",
- entity.TaskReady: "准备备份BINLOG",
- entity.TaskReserved: "BINLOG备份中",
- entity.TaskSuccess: "BINLOG备份成功",
- entity.TaskFailed: "BINLOG备份失败",
- }
-)
-
-func withUpdateBinlogStatus(repositories *repository.Repositories) SchedulerOption[*entity.DbBinlog] {
- return func(scheduler *Scheduler[*entity.DbBinlog]) {
- scheduler.UpdateTaskStatus = func(ctx context.Context, status entity.TaskStatus, lastErr error, task *entity.DbBinlog) error {
- task.LastStatus = status
- var result = backupResult[status]
- if lastErr != nil {
- result = fmt.Sprintf("%v: %v", binlogResult[status], lastErr)
- }
- task.LastResult = result
- task.LastTime = time.Now()
- return repositories.Binlog.UpdateTaskStatus(ctx, task)
- }
- }
-}
-
-func NewDbBinlogSvc(repositories *repository.Repositories) (service.DbBinlogSvc, error) {
- scheduler, err := NewScheduler[*entity.DbBinlog](withDownloadBinlog(repositories), withUpdateBinlogStatus(repositories))
- if err != nil {
- return nil, err
- }
- svc := &DbBinlogSvcImpl{
- repo: repositories.Binlog,
- instanceRepo: repositories.Instance,
- scheduler: scheduler,
- }
- err = svc.loadTasks(context.Background())
- if err != nil {
- return nil, err
- }
- return svc, nil
-}
-
-func (svc *DbBinlogSvcImpl) loadTasks(ctx context.Context) error {
- tasks := make([]*entity.DbBinlog, 0, 64)
- cond := map[string]any{
- "Enabled": true,
- }
- if err := svc.repo.ListByCond(cond, &tasks); err != nil {
- return err
- }
- for _, task := range tasks {
- svc.scheduler.PushTask(ctx, task)
- }
- return nil
-}
-
-func (svc *DbBinlogSvcImpl) AddTaskIfNotExists(ctx context.Context, task *entity.DbBinlog) error {
- if err := svc.repo.AddTaskIfNotExists(ctx, task); err != nil {
- return err
- }
- if task.GetId() == 0 {
- return nil
- }
- svc.scheduler.PushTask(ctx, task)
- return nil
-}
-
-func (svc *DbBinlogSvcImpl) UpdateTask(ctx context.Context, task *entity.DbBinlog) error {
- if err := svc.repo.UpdateById(ctx, task); err != nil {
- return err
- }
- svc.scheduler.UpdateTask(ctx, task)
- return nil
-}
-
-func (svc *DbBinlogSvcImpl) DeleteTask(ctx context.Context, taskId uint64) error {
- // todo: 删除 Binlog 历史文件
- if err := svc.repo.DeleteById(ctx, taskId); err != nil {
- return err
- }
- svc.scheduler.RemoveTask(taskId)
- return nil
-}
-
-func (svc *DbBinlogSvcImpl) EnableTask(ctx context.Context, taskId uint64) error {
- if err := svc.repo.UpdateEnabled(ctx, taskId, true); err != nil {
- return err
- }
- task := new(entity.DbBinlog)
- if err := svc.repo.GetById(task, taskId); err != nil {
- return err
- }
- svc.scheduler.UpdateTask(ctx, task)
- return nil
-}
-
-func (svc *DbBinlogSvcImpl) DisableTask(ctx context.Context, taskId uint64) error {
- if err := svc.repo.UpdateEnabled(ctx, taskId, false); err != nil {
- return err
- }
- svc.scheduler.RemoveTask(taskId)
- return nil
-}
diff --git a/server/internal/db/infrastructure/service/db_restore.go b/server/internal/db/infrastructure/service/db_restore.go
deleted file mode 100644
index 52ae06d3..00000000
--- a/server/internal/db/infrastructure/service/db_restore.go
+++ /dev/null
@@ -1,155 +0,0 @@
-package service
-
-import (
- "context"
- "fmt"
- "mayfly-go/internal/db/domain/entity"
- "mayfly-go/internal/db/domain/repository"
- "mayfly-go/internal/db/domain/service"
- "mayfly-go/pkg/model"
- "time"
-)
-
-var _ service.DbRestoreSvc = (*DbRestoreSvcImpl)(nil)
-
-type DbRestoreSvcImpl struct {
- repo repository.DbRestore
- instanceRepo repository.Instance
- scheduler *Scheduler[*entity.DbRestore]
-}
-
-func withRunRestoreTask(repositories *repository.Repositories) SchedulerOption[*entity.DbRestore] {
- return func(scheduler *Scheduler[*entity.DbRestore]) {
- scheduler.RunTask = func(ctx context.Context, task *entity.DbRestore) error {
- instance := new(entity.DbInstance)
- if err := repositories.Instance.GetById(instance, task.DbInstanceId); err != nil {
- return err
- }
- instance.PwdDecrypt()
- if err := NewDbInstanceSvc(instance, repositories).Restore(ctx, task); err != nil {
- return err
- }
-
- history := &entity.DbRestoreHistory{
- CreateTime: time.Now(),
- DbRestoreId: task.Id,
- }
- if err := repositories.RestoreHistory.Insert(ctx, history); err != nil {
- return err
- }
-
- return nil
- }
- }
-}
-
-var (
- restoreResult = map[entity.TaskStatus]string{
- entity.TaskDelay: "等待恢复数据库",
- entity.TaskReady: "准备恢复数据库",
- entity.TaskReserved: "数据库恢复中",
- entity.TaskSuccess: "数据库恢复成功",
- entity.TaskFailed: "数据库恢复失败",
- }
-)
-
-func withUpdateRestoreStatus(repositories *repository.Repositories) SchedulerOption[*entity.DbRestore] {
- return func(scheduler *Scheduler[*entity.DbRestore]) {
- scheduler.UpdateTaskStatus = func(ctx context.Context, status entity.TaskStatus, lastErr error, task *entity.DbRestore) error {
- task.Finished = !task.Repeated && status == entity.TaskSuccess
- task.LastStatus = status
- var result = restoreResult[status]
- if lastErr != nil {
- result = fmt.Sprintf("%v: %v", restoreResult[status], lastErr)
- }
- task.LastResult = result
- task.LastTime = time.Now()
- return repositories.Restore.UpdateTaskStatus(ctx, task)
- }
- }
-}
-
-func NewDbRestoreSvc(repositories *repository.Repositories) (service.DbRestoreSvc, error) {
- scheduler, err := NewScheduler[*entity.DbRestore](
- withRunRestoreTask(repositories),
- withUpdateRestoreStatus(repositories))
- if err != nil {
- return nil, err
- }
- svc := &DbRestoreSvcImpl{
- repo: repositories.Restore,
- instanceRepo: repositories.Instance,
- scheduler: scheduler,
- }
- if err := svc.loadTasks(context.Background()); err != nil {
- return nil, err
- }
- return svc, nil
-}
-
-func (svc *DbRestoreSvcImpl) loadTasks(ctx context.Context) error {
- tasks := make([]*entity.DbRestore, 0, 64)
- cond := map[string]any{
- "Enabled": true,
- "Finished": false,
- }
- if err := svc.repo.ListByCond(cond, &tasks); err != nil {
- return err
- }
- for _, task := range tasks {
- svc.scheduler.PushTask(ctx, task)
- }
- return nil
-}
-
-func (svc *DbRestoreSvcImpl) AddTask(ctx context.Context, tasks ...*entity.DbRestore) error {
- for _, task := range tasks {
- if err := svc.repo.AddTask(ctx, task); err != nil {
- return err
- }
- svc.scheduler.PushTask(ctx, task)
- }
- return nil
-}
-
-func (svc *DbRestoreSvcImpl) UpdateTask(ctx context.Context, task *entity.DbRestore) error {
- if err := svc.repo.UpdateById(ctx, task); err != nil {
- return err
- }
- svc.scheduler.UpdateTask(ctx, task)
- return nil
-}
-
-func (svc *DbRestoreSvcImpl) DeleteTask(ctx context.Context, taskId uint64) error {
- // todo: 删除数据库恢复历史文件
- if err := svc.repo.DeleteById(ctx, taskId); err != nil {
- return err
- }
- svc.scheduler.RemoveTask(taskId)
- return nil
-}
-
-func (svc *DbRestoreSvcImpl) EnableTask(ctx context.Context, taskId uint64) error {
- if err := svc.repo.UpdateEnabled(ctx, taskId, true); err != nil {
- return err
- }
- task := new(entity.DbRestore)
- if err := svc.repo.GetById(task, taskId); err != nil {
- return err
- }
- svc.scheduler.UpdateTask(ctx, task)
- return nil
-}
-
-func (svc *DbRestoreSvcImpl) DisableTask(ctx context.Context, taskId uint64) error {
- if err := svc.repo.UpdateEnabled(ctx, taskId, false); err != nil {
- return err
- }
- svc.scheduler.RemoveTask(taskId)
- return nil
-}
-
-// GetPageList 分页获取数据库恢复任务
-func (svc *DbRestoreSvcImpl) GetPageList(condition *entity.DbRestoreQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
- return svc.repo.GetDbRestoreList(condition, pageParam, toEntity, orderBy...)
-}
diff --git a/server/internal/db/infrastructure/service/scheduler.go b/server/internal/db/infrastructure/service/scheduler.go
deleted file mode 100644
index 9597a9de..00000000
--- a/server/internal/db/infrastructure/service/scheduler.go
+++ /dev/null
@@ -1,160 +0,0 @@
-package service
-
-import (
- "context"
- "errors"
- "mayfly-go/internal/db/domain/entity"
- "mayfly-go/pkg/logx"
- "mayfly-go/pkg/queue"
- "sync"
- "time"
-)
-
-type Scheduler[T entity.DbTask] struct {
- mutex sync.Mutex
- wg sync.WaitGroup
- queue *queue.DelayQueue[T]
- closed bool
- curTask T
- curTaskContext context.Context
- curTaskCancel context.CancelFunc
- UpdateTaskStatus func(ctx context.Context, status entity.TaskStatus, lastErr error, task T) error
- RunTask func(ctx context.Context, task T) error
-}
-
-type SchedulerOption[T entity.DbTask] func(*Scheduler[T])
-
-func NewScheduler[T entity.DbTask](opts ...SchedulerOption[T]) (*Scheduler[T], error) {
- scheduler := &Scheduler[T]{
- queue: queue.NewDelayQueue[T](0),
- }
- for _, opt := range opts {
- opt(scheduler)
- }
- if scheduler.RunTask == nil || scheduler.UpdateTaskStatus == nil {
- return nil, errors.New("调度器没有设置 RunTask 或 UpdateTaskStatus")
- }
- scheduler.wg.Add(1)
- go scheduler.run()
- return scheduler, nil
-}
-
-func (m *Scheduler[T]) PushTask(ctx context.Context, task T) bool {
- if !task.Schedule() {
- return false
- }
-
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- return m.queue.Enqueue(ctx, task)
-}
-
-func (m *Scheduler[T]) UpdateTask(ctx context.Context, task T) bool {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- if task.GetId() == m.curTask.GetId() {
- return m.curTask.Update(task)
- }
- oldTask, ok := m.queue.Remove(ctx, task.GetId())
- if ok {
- if !oldTask.Update(task) {
- return false
- }
- } else {
- oldTask = task
- }
- if !oldTask.Schedule() {
- return false
- }
- return m.queue.Enqueue(ctx, oldTask)
-}
-
-func (m *Scheduler[T]) updateCurTask(status entity.TaskStatus, lastErr error, task T) bool {
- seconds := []time.Duration{time.Second * 1, time.Second * 8, time.Second * 64}
- for _, second := range seconds {
- if m.closed {
- return false
- }
- ctx, cancel := context.WithTimeout(context.Background(), second)
- err := m.UpdateTaskStatus(ctx, status, lastErr, task)
- cancel()
- if err != nil {
- logx.Errorf("保存任务失败: %v", err)
- time.Sleep(second)
- continue
- }
- return true
- }
- return false
-}
-
-func (m *Scheduler[T]) run() {
- defer m.wg.Done()
-
- var ctx context.Context
- var cancel context.CancelFunc
- for !m.closed {
- m.mutex.Lock()
- ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond)
- task, ok := m.queue.Dequeue(ctx)
- cancel()
- if !ok {
- m.mutex.Unlock()
- time.Sleep(time.Second)
- continue
- }
- m.curTask = task
- m.updateCurTask(entity.TaskReserved, nil, task)
- m.curTaskContext, m.curTaskCancel = context.WithCancel(context.Background())
- m.mutex.Unlock()
-
- err := m.RunTask(m.curTaskContext, task)
-
- m.mutex.Lock()
- taskStatus := entity.TaskSuccess
- if err != nil {
- taskStatus = entity.TaskFailed
- }
- m.updateCurTask(taskStatus, err, task)
- m.cancelCurTask()
- task.Schedule()
- if !task.IsFinished() {
- ctx, cancel = context.WithTimeout(context.Background(), time.Second)
- m.queue.Enqueue(ctx, task)
- cancel()
- }
- m.mutex.Unlock()
- }
-}
-
-func (m *Scheduler[T]) Close() {
- if m.closed {
- return
- }
- m.mutex.Lock()
- m.cancelCurTask()
- m.closed = true
- m.mutex.Unlock()
-
- m.wg.Wait()
-}
-
-func (m *Scheduler[T]) RemoveTask(taskId uint64) bool {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- m.queue.Remove(context.Background(), taskId)
- if taskId == m.curTask.GetId() {
- m.cancelCurTask()
- }
- return true
-}
-
-func (m *Scheduler[T]) cancelCurTask() {
- if m.curTaskCancel != nil {
- m.curTaskCancel()
- m.curTaskCancel = nil
- }
-}
diff --git a/server/internal/db/router/db_backup.go b/server/internal/db/router/db_backup.go
index f472d15c..66c67dfc 100644
--- a/server/internal/db/router/db_backup.go
+++ b/server/internal/db/router/db_backup.go
@@ -26,6 +26,8 @@ func InitDbBackupRouter(router *gin.RouterGroup) {
req.NewPut(":dbId/backups/:backupId/enable", d.Enable).Log(req.NewLogSave("db-启用数据库备份任务")),
// 禁用数据库备份任务
req.NewPut(":dbId/backups/:backupId/disable", d.Disable).Log(req.NewLogSave("db-禁用数据库备份任务")),
+ // 立即启动数据库备份任务
+ req.NewPut(":dbId/backups/:backupId/start", d.Start).Log(req.NewLogSave("db-立即启动数据库备份任务")),
// 删除数据库备份任务
req.NewDelete(":dbId/backups/:backupId", d.Delete),
// 获取未配置定时备份的数据库名称
diff --git a/server/internal/machine/application/machine.go b/server/internal/machine/application/machine.go
index 6482d744..e1c7b716 100644
--- a/server/internal/machine/application/machine.go
+++ b/server/internal/machine/application/machine.go
@@ -84,7 +84,9 @@ func (m *machineAppImpl) Save(ctx context.Context, me *entity.Machine, tagIds ..
err := m.GetBy(oldMachine)
- me.PwdEncrypt()
+ if errEnc := me.PwdEncrypt(); errEnc != nil {
+ return errorx.NewBiz(errEnc.Error())
+ }
if me.Id == 0 {
if err == nil {
return errorx.NewBiz("该机器信息已存在")
@@ -242,13 +244,17 @@ func (m *machineAppImpl) toMachineInfo(me *entity.Machine) (*mcm.MachineInfo, er
return nil, errorx.NewBiz("授权凭证信息已不存在,请重新关联")
}
mi.AuthMethod = ac.AuthMethod
- ac.PwdDecrypt()
+ if err := ac.PwdDecrypt(); err != nil {
+ return nil, errorx.NewBiz(err.Error())
+ }
mi.Password = ac.Password
mi.Passphrase = ac.Passphrase
} else {
mi.AuthMethod = entity.AuthCertAuthMethodPassword
if me.Id != 0 {
- me.PwdDecrypt()
+ if err := me.PwdDecrypt(); err != nil {
+ return nil, errorx.NewBiz(err.Error())
+ }
}
mi.Password = me.Password
}
diff --git a/server/internal/machine/domain/entity/auth_cert.go b/server/internal/machine/domain/entity/auth_cert.go
index 36e42c02..cd191c4b 100644
--- a/server/internal/machine/domain/entity/auth_cert.go
+++ b/server/internal/machine/domain/entity/auth_cert.go
@@ -1,6 +1,7 @@
package entity
import (
+ "errors"
"mayfly-go/internal/common/utils"
"mayfly-go/pkg/model"
)
@@ -16,7 +17,7 @@ type AuthCert struct {
Remark string `json:"remark"`
}
-func (a *AuthCert) TableName() string {
+func (ac *AuthCert) TableName() string {
return "t_auth_cert"
}
@@ -28,14 +29,32 @@ const (
AuthCertTypePublic int8 = 2
)
-// 密码加密
-func (ac *AuthCert) PwdEncrypt() {
- ac.Password = utils.PwdAesEncrypt(ac.Password)
- ac.Passphrase = utils.PwdAesEncrypt(ac.Passphrase)
+// PwdEncrypt 密码加密
+func (ac *AuthCert) PwdEncrypt() error {
+ password, err := utils.PwdAesEncrypt(ac.Password)
+ if err != nil {
+ return errors.New("加密授权凭证密码失败")
+ }
+ passphrase, err := utils.PwdAesEncrypt(ac.Passphrase)
+ if err != nil {
+ return errors.New("加密授权凭证私钥失败")
+ }
+ ac.Password = password
+ ac.Passphrase = passphrase
+ return nil
}
-// 密码解密
-func (ac *AuthCert) PwdDecrypt() {
- ac.Password = utils.PwdAesDecrypt(ac.Password)
- ac.Passphrase = utils.PwdAesDecrypt(ac.Passphrase)
+// PwdDecrypt 密码解密
+func (ac *AuthCert) PwdDecrypt() error {
+ password, err := utils.PwdAesDecrypt(ac.Password)
+ if err != nil {
+ return errors.New("解密授权凭证密码失败")
+ }
+ passphrase, err := utils.PwdAesDecrypt(ac.Passphrase)
+ if err != nil {
+ return errors.New("解密授权凭证私钥失败")
+ }
+ ac.Password = password
+ ac.Passphrase = passphrase
+ return nil
}
diff --git a/server/internal/machine/domain/entity/machine.go b/server/internal/machine/domain/entity/machine.go
index d701f08c..21c96832 100644
--- a/server/internal/machine/domain/entity/machine.go
+++ b/server/internal/machine/domain/entity/machine.go
@@ -1,6 +1,7 @@
package entity
import (
+ "errors"
"mayfly-go/internal/common/utils"
"mayfly-go/pkg/model"
)
@@ -26,14 +27,24 @@ const (
MachineStatusDisable int8 = -1 // 禁用状态
)
-func (m *Machine) PwdEncrypt() {
+func (m *Machine) PwdEncrypt() error {
// 密码替换为加密后的密码
- m.Password = utils.PwdAesEncrypt(m.Password)
+ password, err := utils.PwdAesEncrypt(m.Password)
+ if err != nil {
+ return errors.New("加密主机密码失败")
+ }
+ m.Password = password
+ return nil
}
-func (m *Machine) PwdDecrypt() {
+func (m *Machine) PwdDecrypt() error {
// 密码替换为解密后的密码
- m.Password = utils.PwdAesDecrypt(m.Password)
+ password, err := utils.PwdAesDecrypt(m.Password)
+ if err != nil {
+ return errors.New("解密主机密码失败")
+ }
+ m.Password = password
+ return nil
}
func (m *Machine) UseAuthCert() bool {
diff --git a/server/internal/redis/api/redis.go b/server/internal/redis/api/redis.go
index 4fedee5d..84295670 100644
--- a/server/internal/redis/api/redis.go
+++ b/server/internal/redis/api/redis.go
@@ -77,7 +77,9 @@ func (r *Redis) GetRedisPwd(rc *req.Ctx) {
rid := uint64(ginx.PathParamInt(rc.GinCtx, "id"))
re, err := r.RedisApp.GetById(new(entity.Redis), rid, "Password")
biz.ErrIsNil(err, "redis信息不存在")
- re.PwdDecrypt()
+ if err := re.PwdDecrypt(); err != nil {
+ biz.ErrIsNil(err)
+ }
rc.ResData = re.Password
}
diff --git a/server/internal/redis/application/redis.go b/server/internal/redis/application/redis.go
index 1b7d74ad..c1403a35 100644
--- a/server/internal/redis/application/redis.go
+++ b/server/internal/redis/application/redis.go
@@ -80,7 +80,9 @@ func (r *redisAppImpl) Save(ctx context.Context, re *entity.Redis, tagIds ...uin
if err == nil {
return errorx.NewBiz("该实例已存在")
}
- re.PwdEncrypt()
+ if errEnc := re.PwdEncrypt(); errEnc != nil {
+ return errorx.NewBiz(errEnc.Error())
+ }
resouceCode := stringx.Rand(16)
re.Code = resouceCode
@@ -108,7 +110,9 @@ func (r *redisAppImpl) Save(ctx context.Context, re *entity.Redis, tagIds ...uin
oldRedis, _ = r.GetById(new(entity.Redis), re.Id)
}
- re.PwdEncrypt()
+ if errEnc := re.PwdEncrypt(); errEnc != nil {
+ return errorx.NewBiz(errEnc.Error())
+ }
return r.Tx(ctx, func(ctx context.Context) error {
return r.UpdateById(ctx, re)
}, func(ctx context.Context) error {
@@ -144,8 +148,9 @@ func (r *redisAppImpl) GetRedisConn(id uint64, db int) (*rdm.RedisConn, error) {
if err != nil {
return nil, errorx.NewBiz("redis信息不存在")
}
- re.PwdDecrypt()
-
+ if err := re.PwdDecrypt(); err != nil {
+ return nil, errorx.NewBiz(err.Error())
+ }
return re.ToRedisInfo(db, r.tagApp.ListTagPathByResource(consts.TagResourceTypeRedis, re.Code)...), nil
})
}
diff --git a/server/internal/redis/domain/entity/redis.go b/server/internal/redis/domain/entity/redis.go
index c62ae209..cedc9c22 100644
--- a/server/internal/redis/domain/entity/redis.go
+++ b/server/internal/redis/domain/entity/redis.go
@@ -1,6 +1,7 @@
package entity
import (
+ "errors"
"mayfly-go/internal/common/utils"
"mayfly-go/internal/redis/rdm"
"mayfly-go/pkg/model"
@@ -21,20 +22,30 @@ type Redis struct {
Remark string
}
-func (r *Redis) PwdEncrypt() {
+func (r *Redis) PwdEncrypt() error {
// 密码替换为加密后的密码
- r.Password = utils.PwdAesEncrypt(r.Password)
+ password, err := utils.PwdAesEncrypt(r.Password)
+ if err != nil {
+ return errors.New("加密 Redis 密码失败")
+ }
+ r.Password = password
+ return nil
}
-func (r *Redis) PwdDecrypt() {
+func (r *Redis) PwdDecrypt() error {
// 密码替换为解密后的密码
- r.Password = utils.PwdAesDecrypt(r.Password)
+ password, err := utils.PwdAesDecrypt(r.Password)
+ if err != nil {
+ return errors.New("解密 Redis 密码失败")
+ }
+ r.Password = password
+ return nil
}
-// 转换为redisInfo进行连接
-func (re *Redis) ToRedisInfo(db int, tagPath ...string) *rdm.RedisInfo {
+// ToRedisInfo 转换为redisInfo进行连接
+func (r *Redis) ToRedisInfo(db int, tagPath ...string) *rdm.RedisInfo {
redisInfo := new(rdm.RedisInfo)
- structx.Copy(redisInfo, re)
+ _ = structx.Copy(redisInfo, r)
redisInfo.Db = db
redisInfo.TagPath = tagPath
return redisInfo
diff --git a/server/internal/sys/domain/entity/account.go b/server/internal/sys/domain/entity/account.go
index 4c4a6971..f7a9dd43 100644
--- a/server/internal/sys/domain/entity/account.go
+++ b/server/internal/sys/domain/entity/account.go
@@ -1,6 +1,7 @@
package entity
import (
+ "errors"
"mayfly-go/internal/common/utils"
"mayfly-go/pkg/model"
"time"
@@ -27,15 +28,25 @@ func (a *Account) IsEnable() bool {
return a.Status == AccountEnableStatus
}
-func (a *Account) OtpSecretEncrypt() {
- a.OtpSecret = utils.PwdAesEncrypt(a.OtpSecret)
+func (a *Account) OtpSecretEncrypt() error {
+ secret, err := utils.PwdAesEncrypt(a.OtpSecret)
+ if err != nil {
+ return errors.New("加密账户密码失败")
+ }
+ a.OtpSecret = secret
+ return nil
}
-func (a *Account) OtpSecretDecrypt() {
+func (a *Account) OtpSecretDecrypt() error {
if a.OtpSecret == "-" {
- return
+ return nil
}
- a.OtpSecret = utils.PwdAesDecrypt(a.OtpSecret)
+ secret, err := utils.PwdAesDecrypt(a.OtpSecret)
+ if err != nil {
+ return errors.New("解密账户密码失败")
+ }
+ a.OtpSecret = secret
+ return nil
}
const (
diff --git a/server/pkg/base/repo.go b/server/pkg/base/repo.go
index 9bef0067..a8867ac2 100644
--- a/server/pkg/base/repo.go
+++ b/server/pkg/base/repo.go
@@ -11,6 +11,9 @@ import (
// 基础repo接口
type Repo[T model.ModelI] interface {
+ // GetModel 获取表的模型实例
+ GetModel() T
+
// 新增一个实体
Insert(ctx context.Context, e T) error
@@ -24,10 +27,10 @@ type Repo[T model.ModelI] interface {
BatchInsertWithDb(ctx context.Context, db *gorm.DB, es []T) error
// 根据实体id更新实体信息
- UpdateById(ctx context.Context, e T) error
+ UpdateById(ctx context.Context, e T, columns ...string) error
// 使用指定gorm db执行,主要用于事务执行
- UpdateByIdWithDb(ctx context.Context, db *gorm.DB, e T) error
+ UpdateByIdWithDb(ctx context.Context, db *gorm.DB, e T, columns ...string) error
// 根据实体主键删除实体
DeleteById(ctx context.Context, id uint64) error
@@ -101,16 +104,16 @@ func (br *RepoImpl[T]) BatchInsertWithDb(ctx context.Context, db *gorm.DB, es []
return gormx.BatchInsertWithDb(db, es)
}
-func (br *RepoImpl[T]) UpdateById(ctx context.Context, e T) error {
+func (br *RepoImpl[T]) UpdateById(ctx context.Context, e T, columns ...string) error {
if db := contextx.GetDb(ctx); db != nil {
- return br.UpdateByIdWithDb(ctx, db, e)
+ return br.UpdateByIdWithDb(ctx, db, e, columns...)
}
- return gormx.UpdateById(br.setBaseInfo(ctx, e))
+ return gormx.UpdateById(br.setBaseInfo(ctx, e), columns...)
}
-func (br *RepoImpl[T]) UpdateByIdWithDb(ctx context.Context, db *gorm.DB, e T) error {
- return gormx.UpdateByIdWithDb(db, br.setBaseInfo(ctx, e))
+func (br *RepoImpl[T]) UpdateByIdWithDb(ctx context.Context, db *gorm.DB, e T, columns ...string) error {
+ return gormx.UpdateByIdWithDb(db, br.setBaseInfo(ctx, e), columns...)
}
func (br *RepoImpl[T]) Updates(cond any, udpateFields map[string]any) error {
diff --git a/server/pkg/biz/assert.go b/server/pkg/biz/assert.go
index b85d7970..74a5e76f 100644
--- a/server/pkg/biz/assert.go
+++ b/server/pkg/biz/assert.go
@@ -17,7 +17,7 @@ import (
func ErrIsNil(err error, msgAndParams ...any) {
if err != nil {
if len(msgAndParams) == 0 {
- panic(err)
+ panic(errorx.NewBiz(err.Error()))
}
panic(errorx.NewBiz(fmt.Sprintf(msgAndParams[0].(string), msgAndParams[1:]...)))
diff --git a/server/pkg/config/aes.go b/server/pkg/config/aes.go
index 6f718ea0..4e823901 100644
--- a/server/pkg/config/aes.go
+++ b/server/pkg/config/aes.go
@@ -20,11 +20,11 @@ func (a *Aes) DecryptBase64(data string) ([]byte, error) {
return cryptox.AesDecryptBase64(data, []byte(a.Key))
}
-func (j *Aes) Valid() {
- if j.Key == "" {
+func (a *Aes) Valid() {
+ if a.Key == "" {
return
}
- aesKeyLen := len(j.Key)
+ aesKeyLen := len(a.Key)
assert.IsTrue(aesKeyLen == 16 || aesKeyLen == 24 || aesKeyLen == 32,
fmt.Sprintf("config.yml之 [aes.key] 长度需为16、24、32位长度, 当前为%d位", aesKeyLen))
}
diff --git a/server/pkg/gormx/gormx.go b/server/pkg/gormx/gormx.go
index 72e4e549..61a8eed5 100644
--- a/server/pkg/gormx/gormx.go
+++ b/server/pkg/gormx/gormx.go
@@ -146,12 +146,12 @@ func BatchInsertWithDb[T any](db *gorm.DB, models []T) error {
// 根据id更新model,更新字段为model中不为空的值,即int类型不为0,ptr类型不为nil这类字段值
// @param model 数据库映射实体模型
-func UpdateById(model any) error {
- return UpdateByIdWithDb(global.Db, model)
+func UpdateById(model any, columns ...string) error {
+ return UpdateByIdWithDb(global.Db, model, columns...)
}
-func UpdateByIdWithDb(db *gorm.DB, model any) error {
- return db.Model(model).Updates(model).Error
+func UpdateByIdWithDb(db *gorm.DB, model any, columns ...string) error {
+ return db.Model(model).Select(columns).Updates(model).Error
}
// 根据实体条件,更新参数udpateFields指定字段
diff --git a/server/pkg/queue/delay_queue.go b/server/pkg/queue/delay_queue.go
index 59278c15..2a34245c 100644
--- a/server/pkg/queue/delay_queue.go
+++ b/server/pkg/queue/delay_queue.go
@@ -42,6 +42,28 @@ func NewDelayQueue[T Delayable](cap int) *DelayQueue[T] {
}
}
+func (s *DelayQueue[T]) TryDequeue() (T, bool) {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ if elm, ok := s.priorityQueue.Peek(0); ok {
+ delay := elm.GetDeadline().Sub(time.Now())
+ if delay < minTimerDelay {
+ // 无需延迟,头部元素出队后直接返回
+ _, _ = s.dequeue()
+ return elm, true
+ }
+ }
+ return s.zero, false
+}
+
+func (s *DelayQueue[T]) TryEnqueue(val T) bool {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ return s.enqueue(val)
+}
+
func (s *DelayQueue[T]) Dequeue(ctx context.Context) (T, bool) {
// 出队锁:避免因重复获取队列头部同一元素降低性能
select {
@@ -64,7 +86,6 @@ func (s *DelayQueue[T]) Dequeue(ctx context.Context) (T, bool) {
// 接收直接转发的不需要延迟的新元素
select {
case elm := <-s.transferChan:
- delete(s.elmMap, elm.GetId())
s.mutex.Unlock()
return elm, true
default:
@@ -78,7 +99,6 @@ func (s *DelayQueue[T]) Dequeue(ctx context.Context) (T, bool) {
if delay < minTimerDelay {
// 无需延迟,头部元素出队后直接返回
_, _ = s.dequeue()
- delete(s.elmMap, elm.GetId())
s.mutex.Unlock()
return elm, ok
}
@@ -122,6 +142,7 @@ func (s *DelayQueue[T]) dequeue() (T, bool) {
if !ok {
return s.zero, false
}
+ delete(s.elmMap, elm.GetId())
select {
case s.dequeuedSignal <- struct{}{}:
default:
@@ -133,6 +154,7 @@ func (s *DelayQueue[T]) enqueue(val T) bool {
if ok := s.priorityQueue.Enqueue(val); !ok {
return false
}
+ s.elmMap[val.GetId()] = val
select {
case s.enqueuedSignal <- struct{}{}:
default:
@@ -156,7 +178,6 @@ func (s *DelayQueue[T]) Enqueue(ctx context.Context, val T) bool {
// 如果队列未满,入队后直接返回
if !s.priorityQueue.IsFull() {
- s.elmMap[val.GetId()] = val
s.enqueue(val)
s.mutex.Unlock()
return true
diff --git a/server/pkg/starter/run.go b/server/pkg/starter/run.go
index bd8b3903..da3c9818 100644
--- a/server/pkg/starter/run.go
+++ b/server/pkg/starter/run.go
@@ -10,7 +10,6 @@ import (
"mayfly-go/pkg/validatorx"
"os"
"os/signal"
- "sync"
"syscall"
)
@@ -25,8 +24,6 @@ func RunWebServer() {
cancel()
}()
- runnerWG := &sync.WaitGroup{}
-
// 初始化config.yml配置文件映射信息或使用环境变量。并初始化系统日志相关配置
config.Init()
@@ -52,6 +49,4 @@ func RunWebServer() {
// 运行web服务
runWebServer(ctx)
-
- runnerWG.Wait()
}
diff --git a/server/pkg/starter/web-server.go b/server/pkg/starter/web-server.go
index 6213d58e..c8994e4f 100644
--- a/server/pkg/starter/web-server.go
+++ b/server/pkg/starter/web-server.go
@@ -5,6 +5,7 @@ import (
"errors"
"github.com/gin-gonic/gin"
"mayfly-go/initialize"
+ "mayfly-go/internal/db/application"
"mayfly-go/pkg/config"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/req"
@@ -40,6 +41,8 @@ func runWebServer(ctx context.Context) {
if err != nil {
logx.Errorf("Failed to Shutdown HTTP Server: %v", err)
}
+ closeDbTasks()
+ // todo: close backupApp and restoreApp
}()
confSrv := config.Conf.Server
@@ -56,3 +59,18 @@ func runWebServer(ctx context.Context) {
logx.Errorf("Failed to Start HTTP Server: %v", err)
}
}
+
+func closeDbTasks() {
+ restoreApp := application.GetDbRestoreApp()
+ if restoreApp != nil {
+ restoreApp.Close()
+ }
+ binlogApp := application.GetDbBinlogApp()
+ if binlogApp != nil {
+ binlogApp.Close()
+ }
+ backupApp := application.GetDbBackupApp()
+ if backupApp != nil {
+ backupApp.Close()
+ }
+}
diff --git a/server/pkg/utils/anyx/anyx.go b/server/pkg/utils/anyx/anyx.go
index 8857c54d..77a2d14d 100644
--- a/server/pkg/utils/anyx/anyx.go
+++ b/server/pkg/utils/anyx/anyx.go
@@ -120,3 +120,16 @@ func ToString(value any) string {
return string(newValue)
}
}
+
+// DeepZero 初始化对象
+// 如 T 为基本类型或结构体,则返回零值
+// 如 T 为指向基本类型或结构体的指针,则返回指向零值的指针
+func DeepZero[T any]() T {
+ var data T
+ typ := reflect.TypeOf(data)
+ kind := typ.Kind()
+ if kind == reflect.Pointer {
+ return reflect.New(typ.Elem()).Interface().(T)
+ }
+ return data
+}
diff --git a/server/pkg/utils/anyx/anyx_test.go b/server/pkg/utils/anyx/anyx_test.go
new file mode 100644
index 00000000..e3d5b45a
--- /dev/null
+++ b/server/pkg/utils/anyx/anyx_test.go
@@ -0,0 +1,14 @@
+package anyx
+
+import (
+ "github.com/stretchr/testify/assert"
+ "testing"
+ "time"
+)
+
+func TestDeepZero(t *testing.T) {
+ assert.Zero(t, DeepZero[int]())
+ assert.Zero(t, *DeepZero[*int]())
+ assert.Zero(t, DeepZero[time.Time]())
+ assert.Zero(t, *DeepZero[*time.Time]())
+}
diff --git a/server/pkg/utils/cryptox/cryptox.go b/server/pkg/utils/cryptox/cryptox.go
index 13db4091..412d001c 100644
--- a/server/pkg/utils/cryptox/cryptox.go
+++ b/server/pkg/utils/cryptox/cryptox.go
@@ -272,5 +272,9 @@ func pkcs7UnPadding(data []byte) ([]byte, error) {
}
//获取填充的个数
unPadding := int(data[length-1])
+ // todo fix: slice bounds out of range
+ if unPadding > length {
+ return nil, errors.New("解密字符串时去除填充个数超出字符串长度")
+ }
return data[:(length - unPadding)], nil
}
diff --git a/server/pkg/utils/stringx/stringx.go b/server/pkg/utils/stringx/stringx.go
index 7486c35e..4e231b3e 100644
--- a/server/pkg/utils/stringx/stringx.go
+++ b/server/pkg/utils/stringx/stringx.go
@@ -115,3 +115,17 @@ func ReverStrTemplate(temp, str string, res map[string]any) {
ReverStrTemplate(next, Trim(SubString(str, UnicodeIndex(str, value)+Len(value), Len(str))), res)
}
}
+
+func TruncateStr(s string, length int) string {
+ if length >= len(s) {
+ return s
+ }
+ var last int
+ for i := range s {
+ if i > length {
+ break
+ }
+ last = i
+ }
+ return s[:last]
+}
diff --git a/server/pkg/utils/stringx/stringx_test.go b/server/pkg/utils/stringx/stringx_test.go
new file mode 100644
index 00000000..36c1bb0e
--- /dev/null
+++ b/server/pkg/utils/stringx/stringx_test.go
@@ -0,0 +1,32 @@
+package stringx
+
+import (
+ "github.com/stretchr/testify/require"
+ "strconv"
+ "testing"
+)
+
+func TestTruncateStr(t *testing.T) {
+ testCases := []struct {
+ data string
+ length int
+ want string
+ }{
+ {"123一二三", 0, ""},
+ {"123一二三", 1, "1"},
+ {"123一二三", 3, "123"},
+ {"123一二三", 4, "123"},
+ {"123一二三", 5, "123"},
+ {"123一二三", 6, "123一"},
+ {"123一二三", 7, "123一"},
+ {"123一二三", 11, "123一二"},
+ {"123一二三", 12, "123一二三"},
+ {"123一二三", 13, "123一二三"},
+ }
+ for _, tc := range testCases {
+ t.Run(strconv.Itoa(tc.length), func(t *testing.T) {
+ got := TruncateStr(tc.data, tc.length)
+ require.Equal(t, tc.want, got)
+ })
+ }
+}
diff --git a/server/pkg/utils/timex/timex.go b/server/pkg/utils/timex/timex.go
index 2a940316..e787a2ac 100644
--- a/server/pkg/utils/timex/timex.go
+++ b/server/pkg/utils/timex/timex.go
@@ -1,9 +1,60 @@
package timex
-import "time"
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "time"
+)
const DefaultDateTimeFormat = "2006-01-02 15:04:05"
func DefaultFormat(time time.Time) string {
return time.Format(DefaultDateTimeFormat)
}
+
+func NewNullTime(t time.Time) NullTime {
+ return NullTime{
+ NullTime: sql.NullTime{
+ Time: t,
+ Valid: !t.IsZero(),
+ },
+ }
+}
+
+type NullTime struct {
+ sql.NullTime
+}
+
+func (nt *NullTime) UnmarshalJSON(bytes []byte) error {
+ if len(bytes) == 0 {
+ nt.NullTime = sql.NullTime{}
+ return nil
+ }
+ var t time.Time
+ if err := json.Unmarshal(bytes, &t); err != nil {
+ return err
+ }
+ if t.IsZero() {
+ nt.NullTime = sql.NullTime{}
+ return nil
+ }
+ nt.NullTime = sql.NullTime{
+ Valid: true,
+ Time: t,
+ }
+ return nil
+}
+
+func (nt *NullTime) MarshalJSON() ([]byte, error) {
+ if !nt.Valid || nt.Time.IsZero() {
+ return json.Marshal(nil)
+ }
+ return json.Marshal(nt.Time)
+}
+
+func SleepWithContext(ctx context.Context, d time.Duration) {
+ ctx, cancel := context.WithTimeout(ctx, d)
+ <-ctx.Done()
+ cancel()
+}
diff --git a/server/pkg/utils/timex/timex_test.go b/server/pkg/utils/timex/timex_test.go
new file mode 100644
index 00000000..e3238e48
--- /dev/null
+++ b/server/pkg/utils/timex/timex_test.go
@@ -0,0 +1,77 @@
+package timex
+
+import (
+ "github.com/stretchr/testify/require"
+ "testing"
+ "time"
+)
+
+func TestNullTime_UnmarshalJSON(t *testing.T) {
+ zero := time.Time{}
+ now := time.Now()
+ bytesNow, err := now.MarshalJSON()
+ require.NoError(t, err)
+ tests := []struct {
+ name string
+ want NullTime
+ bytes []byte
+ wantErr error
+ }{
+ {
+ name: "zero",
+ want: NewNullTime(zero),
+ bytes: nil,
+ },
+ {
+ name: "now",
+ want: NewNullTime(now),
+ bytes: bytesNow,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := &NullTime{}
+ err := got.UnmarshalJSON(tt.bytes)
+ require.ErrorIs(t, err, tt.wantErr)
+ if err != nil {
+ return
+ }
+ require.Equal(t, tt.want.Valid, got.Valid)
+ require.True(t, got.Time.Equal(tt.want.Time))
+ })
+ }
+}
+
+func TestNullTime_MarshalJSON(t *testing.T) {
+ zero := time.Time{}
+ now := time.Now()
+ bytes, err := now.MarshalJSON()
+ require.NoError(t, err)
+ tests := []struct {
+ name string
+ nullTime NullTime
+ want []byte
+ wantErr error
+ }{
+ {
+ name: "zero",
+ nullTime: NewNullTime(zero),
+ want: []byte("null"),
+ },
+ {
+ name: "now",
+ nullTime: NewNullTime(now),
+ want: bytes,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := tt.nullTime.MarshalJSON()
+ require.ErrorIs(t, err, tt.wantErr)
+ if err != nil {
+ return
+ }
+ require.Equal(t, tt.want, got)
+ })
+ }
+}
diff --git a/server/resources/script/sql/mayfly-go.sql b/server/resources/script/sql/mayfly-go.sql
index 06648777..174198c4 100644
--- a/server/resources/script/sql/mayfly-go.sql
+++ b/server/resources/script/sql/mayfly-go.sql
@@ -479,6 +479,8 @@ INSERT INTO `t_sys_config` (name, `key`, params, value, remark, permission, crea
INSERT INTO `t_sys_config` (name, `key`, params, value, remark, create_time, creator_id, creator, update_time, modifier_id, modifier)VALUES ('数据库查询最大结果集', 'DbQueryMaxCount', '[]', '200', '允许sql查询的最大结果集数。注: 0=不限制', '2023-02-11 14:29:03', 1, 'admin', '2023-02-11 14:40:56', 1, 'admin');
INSERT INTO `t_sys_config` (name, `key`, params, value, remark, create_time, creator_id, creator, update_time, modifier_id, modifier)VALUES ('数据库是否记录查询SQL', 'DbSaveQuerySQL', '[]', '0', '1: 记录、0:不记录', '2023-02-11 16:07:14', 1, 'admin', '2023-02-11 16:44:17', 1, 'admin');
INSERT INTO `t_sys_config` (name, `key`, params, value, remark, permission, create_time, creator_id, creator, update_time, modifier_id, modifier, is_deleted, delete_time) VALUES('机器相关配置', 'MachineConfig', '[{"name":"终端回放存储路径","model":"terminalRecPath","placeholder":"终端回放存储路径"},{"name":"uploadMaxFileSize","model":"uploadMaxFileSize","placeholder":"允许上传的最大文件大小(1MB\\\\2GB等)"}]', '{"terminalRecPath":"./rec","uploadMaxFileSize":"1GB"}', '机器相关配置,如终端回放路径等', 'admin,', '2023-07-13 16:26:44', 1, 'admin', '2023-11-09 22:01:31', 1, 'admin', 0, NULL);
+INSERT INTO `t_sys_config` (`name`, `key`, `params`, `value`, `remark`, `permission`, `create_time`, `creator_id`, `creator`, `update_time`, `modifier_id`, `modifier`, `is_deleted`, `delete_time`) VALUES('Mysql可执行文件', 'MysqlBin', '[{"model":"path","name":"路径","placeholder":"可执行文件路径","required":true},{"model":"mysql","name":"mysql","placeholder":"mysql命令路径(空则为 路径/mysql)","required":false},{"model":"mysqldump","name":"mysqldump","placeholder":"mysqldump命令路径(空则为 路径/mysqldump)","required":false},{"model":"mysqlbinlog","name":"mysqlbinlog","placeholder":"mysqlbinlog命令路径(空则为 路径/mysqlbinlog)","required":false}]', '{"mysql":"","mysqldump":"","mysqlbinlog":"","path":""}', '', 'admin,', '2023-12-29 10:01:33', 1, 'admin', '2023-12-29 13:34:40', 1, 'admin', 0, NULL);
+INSERT INTO `t_sys_config` (`name`, `key`, `params`, `value`, `remark`, `permission`, `create_time`, `creator_id`, `creator`, `update_time`, `modifier_id`, `modifier`, `is_deleted`, `delete_time`) VALUES('数据库备份恢复', 'DbBackupRestore', '[{"model":"backupPath","name":"备份路径","placeholder":"备份文件存储路径"}]', '{"backupPath":"./db/backup"}', '', 'admin,', '2023-12-29 09:55:26', 1, 'admin', '2023-12-29 15:45:24', 1, 'admin', 0, NULL);
COMMIT;
-- ----------------------------
@@ -865,7 +867,6 @@ CREATE TABLE `t_db_backup` (
`interval` bigint(20) DEFAULT NULL COMMENT '备份周期',
`start_time` datetime DEFAULT NULL COMMENT '首次备份时间',
`enabled` tinyint(1) DEFAULT NULL COMMENT '是否启用',
- `finished` tinyint(1) DEFAULT NULL COMMENT '是否完成',
`last_status` tinyint(4) DEFAULT NULL COMMENT '上次备份状态',
`last_result` varchar(256) DEFAULT NULL COMMENT '上次备份结果',
`last_time` datetime DEFAULT NULL COMMENT '上次备份时间',
@@ -917,7 +918,6 @@ CREATE TABLE `t_db_restore` (
`interval` bigint(20) DEFAULT NULL COMMENT '恢复周期',
`start_time` datetime DEFAULT NULL COMMENT '首次恢复时间',
`enabled` tinyint(1) DEFAULT NULL COMMENT '是否启用',
- `finished` tinyint(1) DEFAULT NULL COMMENT '是否完成',
`last_status` tinyint(4) DEFAULT NULL COMMENT '上次恢复状态',
`last_result` varchar(256) DEFAULT NULL COMMENT '上次恢复结果',
`last_time` datetime DEFAULT NULL COMMENT '上次恢复时间',
@@ -959,9 +959,6 @@ DROP TABLE IF EXISTS `t_db_binlog`;
CREATE TABLE `t_db_binlog` (
`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT,
`db_instance_id` bigint(20) unsigned NOT NULL COMMENT '数据库实例ID',
- `interval` bigint(20) DEFAULT NULL COMMENT '下载周期',
- `start_time` datetime DEFAULT NULL COMMENT '首次下载时间',
- `enabled` tinyint(1) DEFAULT NULL COMMENT '会否启用',
`last_status` bigint(20) DEFAULT NULL COMMENT '上次下载状态',
`last_result` varchar(256) DEFAULT NULL COMMENT '上次下载结果',
`last_time` datetime DEFAULT NULL COMMENT '上次下载时间',
diff --git a/server/resources/script/sql/v1.6.3.sql b/server/resources/script/sql/v1.6.3.sql
index a82a4524..7f974596 100644
--- a/server/resources/script/sql/v1.6.3.sql
+++ b/server/resources/script/sql/v1.6.3.sql
@@ -11,7 +11,6 @@ CREATE TABLE `t_db_backup` (
`interval` bigint(20) DEFAULT NULL COMMENT '备份周期',
`start_time` datetime DEFAULT NULL COMMENT '首次备份时间',
`enabled` tinyint(1) DEFAULT NULL COMMENT '是否启用',
- `finished` tinyint(1) DEFAULT NULL COMMENT '是否完成',
`last_status` tinyint(4) DEFAULT NULL COMMENT '上次备份状态',
`last_result` varchar(256) DEFAULT NULL COMMENT '上次备份结果',
`last_time` datetime DEFAULT NULL COMMENT '上次备份时间',
@@ -63,7 +62,6 @@ CREATE TABLE `t_db_restore` (
`interval` bigint(20) DEFAULT NULL COMMENT '恢复周期',
`start_time` datetime DEFAULT NULL COMMENT '首次恢复时间',
`enabled` tinyint(1) DEFAULT NULL COMMENT '是否启用',
- `finished` tinyint(1) DEFAULT NULL COMMENT '是否完成',
`last_status` tinyint(4) DEFAULT NULL COMMENT '上次恢复状态',
`last_result` varchar(256) DEFAULT NULL COMMENT '上次恢复结果',
`last_time` datetime DEFAULT NULL COMMENT '上次恢复时间',
@@ -105,9 +103,6 @@ DROP TABLE IF EXISTS `t_db_binlog`;
CREATE TABLE `t_db_binlog` (
`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT,
`db_instance_id` bigint(20) unsigned NOT NULL COMMENT '数据库实例ID',
- `interval` bigint(20) DEFAULT NULL COMMENT '下载周期',
- `start_time` datetime DEFAULT NULL COMMENT '首次下载时间',
- `enabled` tinyint(1) DEFAULT NULL COMMENT '会否启用',
`last_status` bigint(20) DEFAULT NULL COMMENT '上次下载状态',
`last_result` varchar(256) DEFAULT NULL COMMENT '上次下载结果',
`last_time` datetime DEFAULT NULL COMMENT '上次下载时间',
@@ -140,3 +135,7 @@ CREATE TABLE `t_db_binlog_history` (
PRIMARY KEY (`id`),
KEY `idx_db_instance_id` (`db_instance_id`) USING BTREE
) ENGINE=InnoDB AUTO_INCREMENT=17 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;
+
+
+INSERT INTO `t_sys_config` (`name`, `key`, `params`, `value`, `remark`, `permission`, `create_time`, `creator_id`, `creator`, `update_time`, `modifier_id`, `modifier`, `is_deleted`, `delete_time`) VALUES('Mysql可执行文件', 'MysqlBin', '[{"model":"path","name":"路径","placeholder":"可执行文件路径","required":true},{"model":"mysql","name":"mysql","placeholder":"mysql命令路径(空则为 路径/mysql)","required":false},{"model":"mysqldump","name":"mysqldump","placeholder":"mysqldump命令路径(空则为 路径/mysqldump)","required":false},{"model":"mysqlbinlog","name":"mysqlbinlog","placeholder":"mysqlbinlog命令路径(空则为 路径/mysqlbinlog)","required":false}]', '{"mysql":"","mysqldump":"","mysqlbinlog":"","path":""}', '', 'admin,', '2023-12-29 10:01:33', 1, 'admin', '2023-12-29 13:34:40', 1, 'admin', 0, NULL);
+INSERT INTO `t_sys_config` (`name`, `key`, `params`, `value`, `remark`, `permission`, `create_time`, `creator_id`, `creator`, `update_time`, `modifier_id`, `modifier`, `is_deleted`, `delete_time`) VALUES('数据库备份恢复', 'DbBackupRestore', '[{"model":"backupPath","name":"备份路径","placeholder":"备份文件存储路径"}]', '{"backupPath":"./db/backup"}', '', 'admin,', '2023-12-29 09:55:26', 1, 'admin', '2023-12-29 15:45:24', 1, 'admin', 0, NULL);
\ No newline at end of file