重构数据库备份与恢复模块 (#80)

* fix: 保存 LastResult 时截断字符串过长部分,以避免数据库报错

* refactor: 新增 entity.DbTaskBase 和 persistence.dbTaskBase, 用于实现数据库备份和恢复任务处理相关部分

* fix: aeskey变更后,解密密码出现数组越界访问错误

* fix: 时间属性为零值时,保存到 mysql 数据库报错

* refactor db.infrastructure.service.scheduler

* feat: 实现立即备份功能

* refactor db.infrastructure.service.db_instance

* refactor: 从数据库中获取数据库备份目录、mysql文件路径等配置信息

* fix: 数据库备份和恢复问题

* fix: 修改 .gitignore 文件,忽略数据库备份目录和数据库程序目录
This commit is contained in:
kanzihuang
2024-01-05 08:55:34 +08:00
committed by GitHub
parent 76fd6675b5
commit ae3d2659aa
83 changed files with 1819 additions and 1688 deletions

View File

@@ -24,7 +24,7 @@ export function dateStrFormat(fmt: string, dateStr: string) {
}
export function dateFormat(dateStr: string) {
if (dateStr?.startsWith('0001-01-01', 0)) {
if (!dateStr) {
return '';
}
return dateFormat2('yyyy-MM-dd HH:mm:ss', new Date(dateStr));

View File

@@ -24,9 +24,12 @@
</template>
<template #action="{ data }">
<el-button @click="editDbBackup(data)" type="primary" link>编辑</el-button>
<el-button @click="enableDbBackup(data)" v-if="!data.enabled" type="success" link>启用</el-button>
<el-button @click="disableDbBackup(data)" v-if="data.enabled" type="warning" link></el-button>
<div style="text-align: left">
<el-button @click="editDbBackup(data)" type="primary" link>编辑</el-button>
<el-button v-if="!data.enabled" @click="enableDbBackup(data)" type="primary" link></el-button>
<el-button v-if="data.enabled" @click="disableDbBackup(data)" type="primary" link>禁用</el-button>
<el-button v-if="data.enabled" @click="startDbBackup(data)" type="primary" link>立即备份</el-button>
</div>
</template>
</page-table>
@@ -150,5 +153,20 @@ const disableDbBackup = async (data: any) => {
await search();
ElMessage.success('禁用成功');
};
const startDbBackup = async (data: any) => {
let backupId: String;
if (data) {
backupId = data.id;
} else if (state.selectedData.length > 0) {
backupId = state.selectedData.map((x: any) => x.id).join(' ');
} else {
ElMessage.error('请选择需要启用的备份任务');
return;
}
await dbApi.startDbBackup.request({ dbId: props.dbId, backupId: backupId });
await search();
ElMessage.success('备份任务启动成功');
};
</script>
<style lang="scss"></style>

View File

@@ -199,10 +199,10 @@ const init = async (data: any) => {
state.form.dbBackupId = data.dbBackupId;
state.form.dbBackupHistoryId = data.dbBackupHistoryId;
state.form.dbBackupHistoryName = data.dbBackupHistoryName;
if (data.dbBackupHistoryId > 0) {
state.restoreMode = 'backup-history';
} else {
if (data.pointInTime) {
state.restoreMode = 'point-in-time';
} else {
state.restoreMode = 'backup-history';
}
state.history = {
dbBackupId: data.dbBackupId,
@@ -232,35 +232,33 @@ const getDbNamesWithoutRestore = async () => {
const btnOk = async () => {
restoreForm.value.validate(async (valid: any) => {
if (!valid) {
if (valid) {
if (state.restoreMode == 'point-in-time') {
state.form.dbBackupId = 0;
state.form.dbBackupHistoryId = 0;
state.form.dbBackupHistoryName = '';
} else {
state.form.pointInTime = null;
}
state.form.repeated = false;
const reqForm = { ...state.form };
let api = dbApi.createDbRestore;
if (props.data) {
api = dbApi.saveDbRestore;
}
api.request(reqForm).then(() => {
ElMessage.success('保存成功');
emit('val-change', state.form);
state.btnLoading = true;
setTimeout(() => {
state.btnLoading = false;
}, 1000);
cancel();
});
} else {
ElMessage.error('请正确填写信息');
return false;
}
if (state.restoreMode == 'point-in-time') {
state.form.dbBackupId = 0;
state.form.dbBackupHistoryId = 0;
state.form.dbBackupHistoryName = '';
} else {
state.form.pointInTime = '0001-01-01T00:00:00Z';
}
state.form.repeated = false;
const reqForm = { ...state.form };
let api = dbApi.createDbRestore;
if (props.data) {
api = dbApi.saveDbRestore;
}
try {
state.btnLoading = true;
await api.request(reqForm);
ElMessage.success('保存成功');
emit('val-change', state.form);
cancel();
} finally {
state.btnLoading = false;
}
});
};

View File

@@ -25,8 +25,8 @@
<template #action="{ data }">
<el-button @click="showDbRestore(data)" type="primary" link>详情</el-button>
<el-button @click="enableDbRestore(data)" type="primary" link>启用</el-button>
<el-button @click="disableDbRestore(data)" type="primary" link>禁用</el-button>
<el-button @click="enableDbRestore(data)" v-if="!data.enabled" type="primary" link>启用</el-button>
<el-button @click="disableDbRestore(data)" v-if="data.enabled" type="primary" link>禁用</el-button>
</template>
</page-table>
@@ -42,10 +42,10 @@
<el-dialog v-model="infoDialog.visible" title="数据库恢复">
<el-descriptions :column="1" border>
<el-descriptions-item :span="1" label="数据库名称">{{ infoDialog.data.dbName }}</el-descriptions-item>
<el-descriptions-item v-if="!infoDialog.data.dbBackupHistoryName" :span="1" label="恢复时间点">{{
<el-descriptions-item v-if="infoDialog.data.pointInTime" :span="1" label="恢复时间点">{{
dateFormat(infoDialog.data.pointInTime)
}}</el-descriptions-item>
<el-descriptions-item v-if="infoDialog.data.dbBackupHistoryName" :span="1" label="数据库备份">{{
<el-descriptions-item v-if="!infoDialog.data.pointInTime" :span="1" label="数据库备份">{{
infoDialog.data.dbBackupHistoryName
}}</el-descriptions-item>
<el-descriptions-item :span="1" label="开始时间">{{ dateFormat(infoDialog.data.startTime) }}</el-descriptions-item>

View File

@@ -47,6 +47,7 @@ export const dbApi = {
getDbNamesWithoutBackup: Api.newGet('/dbs/{dbId}/db-names-without-backup'),
enableDbBackup: Api.newPut('/dbs/{dbId}/backups/{backupId}/enable'),
disableDbBackup: Api.newPut('/dbs/{dbId}/backups/{backupId}/disable'),
startDbBackup: Api.newPut('/dbs/{dbId}/backups/{backupId}/start'),
saveDbBackup: Api.newPut('/dbs/{dbId}/backups/{id}'),
getDbBackupHistories: Api.newGet('/dbs/{dbId}/backup-histories'),

9
server/.gitignore vendored
View File

@@ -1,4 +1,11 @@
static/static
/static/static/
config.yml
mayfly_rsa
mayfly_rsa.pub
# 数据库备份目录
/db/backup/
# mysql 程序目录
/db/mysql/
# mariadb 程序目录
/db/mariadb/

View File

@@ -107,7 +107,7 @@ func (a *AccountLogin) OtpVerify(rc *req.Ctx) {
if otpStatus == OtpStatusNoReg {
update := &sysentity.Account{OtpSecret: otpSecret}
update.Id = accountId
update.OtpSecretEncrypt()
biz.ErrIsNil(update.OtpSecretEncrypt())
biz.ErrIsNil(a.AccountApp.Update(context.Background(), update))
}

View File

@@ -65,7 +65,7 @@ func LastLoginCheck(account *sysentity.Account, accountLoginSecurity *config.Acc
}
func useOtp(account *sysentity.Account, otpIssuer, accessToken string) (*OtpVerifyInfo, string, string) {
account.OtpSecretDecrypt()
biz.ErrIsNil(account.OtpSecretDecrypt())
otpSecret := account.OtpSecret
// 修改状态为已注册
otpStatus := OtpStatusReg

View File

@@ -1,7 +1,6 @@
package utils
import (
"mayfly-go/pkg/biz"
"mayfly-go/pkg/config"
"regexp"
)
@@ -27,30 +26,34 @@ func CheckAccountPasswordLever(ps string) bool {
}
// 使用config.yml的aes.key进行密码加密
func PwdAesEncrypt(password string) string {
func PwdAesEncrypt(password string) (string, error) {
if password == "" {
return ""
return "", nil
}
aes := config.Conf.Aes
if aes.Key == "" {
return password
return password, nil
}
encryptPwd, err := aes.EncryptBase64([]byte(password))
biz.ErrIsNilAppendErr(err, "密码加密失败: %s")
return encryptPwd
if err != nil {
return "", err
}
return encryptPwd, nil
}
// 使用config.yml的aes.key进行密码解密
func PwdAesDecrypt(encryptPwd string) string {
func PwdAesDecrypt(encryptPwd string) (string, error) {
if encryptPwd == "" {
return ""
return "", nil
}
aes := config.Conf.Aes
if aes.Key == "" {
return encryptPwd
return encryptPwd, nil
}
decryptPwd, err := aes.DecryptBase64(encryptPwd)
biz.ErrIsNilAppendErr(err, "密码解密失败: %s")
if err != nil {
return "", err
}
// 解密后的密码
return string(decryptPwd)
return string(decryptPwd), nil
}

View File

@@ -39,11 +39,11 @@ func (d *DbBackup) GetPageList(rc *req.Ctx) {
// Create 保存数据库备份任务
// @router /api/dbs/:dbId/backups [POST]
func (d *DbBackup) Create(rc *req.Ctx) {
form := &form.DbBackupForm{}
ginx.BindJsonAndValid(rc.GinCtx, form)
rc.ReqParam = form
backupForm := &form.DbBackupForm{}
ginx.BindJsonAndValid(rc.GinCtx, backupForm)
rc.ReqParam = backupForm
dbNames := strings.Fields(form.DbNames)
dbNames := strings.Fields(backupForm.DbNames)
biz.IsTrue(len(dbNames) > 0, "解析数据库备份任务失败:数据库名称未定义")
dbId := uint64(ginx.PathParamInt(rc.GinCtx, "dbId"))
@@ -54,14 +54,10 @@ func (d *DbBackup) Create(rc *req.Ctx) {
tasks := make([]*entity.DbBackup, 0, len(dbNames))
for _, dbName := range dbNames {
task := &entity.DbBackup{
DbTaskBase: entity.NewDbBTaskBase(true, backupForm.Repeated, backupForm.StartTime, backupForm.Interval),
DbName: dbName,
Name: form.Name,
StartTime: form.StartTime,
Interval: form.Interval,
Enabled: true,
Repeated: form.Repeated,
Name: backupForm.Name,
DbInstanceId: db.InstanceId,
LastTime: form.StartTime,
}
tasks = append(tasks, task)
}
@@ -71,17 +67,15 @@ func (d *DbBackup) Create(rc *req.Ctx) {
// Save 保存数据库备份任务
// @router /api/dbs/:dbId/backups/:backupId [PUT]
func (d *DbBackup) Save(rc *req.Ctx) {
form := &form.DbBackupForm{}
ginx.BindJsonAndValid(rc.GinCtx, form)
rc.ReqParam = form
backupForm := &form.DbBackupForm{}
ginx.BindJsonAndValid(rc.GinCtx, backupForm)
rc.ReqParam = backupForm
task := &entity.DbBackup{
Name: form.Name,
StartTime: form.StartTime,
Interval: form.Interval,
LastTime: form.StartTime,
}
task.Id = form.Id
task := &entity.DbBackup{}
task.Id = backupForm.Id
task.Name = backupForm.Name
task.StartTime = backupForm.StartTime
task.Interval = backupForm.Interval
biz.ErrIsNilAppendErr(d.DbBackupApp.Save(rc.MetaCtx, task), "保存数据库备份任务失败: %v")
}
@@ -125,6 +119,13 @@ func (d *DbBackup) Disable(rc *req.Ctx) {
biz.ErrIsNilAppendErr(err, "禁用数据库备份任务失败: %v")
}
// Start 禁用数据库备份任务
// @router /api/dbs/:dbId/backups/:taskId/start [PUT]
func (d *DbBackup) Start(rc *req.Ctx) {
err := d.walk(rc, d.DbBackupApp.Start)
biz.ErrIsNilAppendErr(err, "运行数据库备份任务失败: %v")
}
// GetDbNamesWithoutBackup 获取未配置定时备份的数据库名称
// @router /api/dbs/:dbId/db-names-without-backup [GET]
func (d *DbBackup) GetDbNamesWithoutBackup(rc *req.Ctx) {

View File

@@ -38,9 +38,9 @@ func (d *DbRestore) GetPageList(rc *req.Ctx) {
// Create 保存数据库恢复任务
// @router /api/dbs/:dbId/restores [POST]
func (d *DbRestore) Create(rc *req.Ctx) {
form := &form.DbRestoreForm{}
ginx.BindJsonAndValid(rc.GinCtx, form)
rc.ReqParam = form
restoreForm := &form.DbRestoreForm{}
ginx.BindJsonAndValid(rc.GinCtx, restoreForm)
rc.ReqParam = restoreForm
dbId := uint64(ginx.PathParamInt(rc.GinCtx, "dbId"))
biz.IsTrue(dbId > 0, "无效的 dbId: %v", dbId)
@@ -48,16 +48,13 @@ func (d *DbRestore) Create(rc *req.Ctx) {
biz.ErrIsNilAppendErr(err, "获取数据库信息失败: %v")
task := &entity.DbRestore{
DbName: form.DbName,
StartTime: form.StartTime,
Interval: form.Interval,
Enabled: true,
Repeated: form.Repeated,
DbTaskBase: entity.NewDbBTaskBase(true, restoreForm.Repeated, restoreForm.StartTime, restoreForm.Interval),
DbName: restoreForm.DbName,
DbInstanceId: db.InstanceId,
PointInTime: form.PointInTime,
DbBackupId: form.DbBackupId,
DbBackupHistoryId: form.DbBackupHistoryId,
DbBackupHistoryName: form.DbBackupHistoryName,
PointInTime: restoreForm.PointInTime,
DbBackupId: restoreForm.DbBackupId,
DbBackupHistoryId: restoreForm.DbBackupHistoryId,
DbBackupHistoryName: restoreForm.DbBackupHistoryName,
}
biz.ErrIsNilAppendErr(d.DbRestoreApp.Create(rc.MetaCtx, task), "添加数据库恢复任务失败: %v")
}
@@ -65,15 +62,14 @@ func (d *DbRestore) Create(rc *req.Ctx) {
// Save 保存数据库恢复任务
// @router /api/dbs/:dbId/restores/:restoreId [PUT]
func (d *DbRestore) Save(rc *req.Ctx) {
form := &form.DbRestoreForm{}
ginx.BindJsonAndValid(rc.GinCtx, form)
rc.ReqParam = form
restoreForm := &form.DbRestoreForm{}
ginx.BindJsonAndValid(rc.GinCtx, restoreForm)
rc.ReqParam = restoreForm
task := &entity.DbRestore{
StartTime: form.StartTime,
Interval: form.Interval,
}
task.Id = form.Id
task := &entity.DbRestore{}
task.Id = restoreForm.Id
task.StartTime = restoreForm.StartTime
task.Interval = restoreForm.Interval
biz.ErrIsNilAppendErr(d.DbRestoreApp.Save(rc.MetaCtx, task), "保存数据库恢复任务失败: %v")
}

View File

@@ -2,21 +2,22 @@ package form
import (
"encoding/json"
"mayfly-go/pkg/utils/timex"
"time"
)
// DbRestoreForm 数据库备份表单
type DbRestoreForm struct {
Id uint64 `json:"id"`
DbName string `binding:"required" json:"dbName"` // 数据库名
StartTime time.Time `binding:"required" json:"startTime"` // 开始时间: 2023-11-08 02:00:00
PointInTime time.Time `json:"PointInTime"` // 指定时间
DbBackupId uint64 `json:"dbBackupId"` // 数据库备份任务ID
DbBackupHistoryId uint64 `json:"dbBackupHistoryId"` // 数据库备份历史ID
DbBackupHistoryName string `json:"dbBackupHistoryName"` // 数据库备份历史名称
Interval time.Duration `json:"-"` // 间隔时间: 为零表示单次执行,为正表示反复执行
IntervalDay uint64 `json:"intervalDay"` // 间隔天数: 为零表示单次执行,为正表示反复执行
Repeated bool `json:"repeated"` // 是否重复执行
Id uint64 `json:"id"`
DbName string `binding:"required" json:"dbName"` // 数据库名
StartTime time.Time `binding:"required" json:"startTime"` // 开始时间: 2023-11-08 02:00:00
PointInTime timex.NullTime `json:"pointInTime"` // 指定时间
DbBackupId uint64 `json:"dbBackupId"` // 数据库备份任务ID
DbBackupHistoryId uint64 `json:"dbBackupHistoryId"` // 数据库备份历史ID
DbBackupHistoryName string `json:"dbBackupHistoryName"` // 数据库备份历史名称
Interval time.Duration `json:"-"` // 间隔时间: 为零表示单次执行,为正表示反复执行
IntervalDay uint64 `json:"intervalDay"` // 间隔天数: 为零表示单次执行,为正表示反复执行
Repeated bool `json:"repeated"` // 是否重复执行
}
func (restore *DbRestoreForm) UnmarshalJSON(data []byte) error {

View File

@@ -74,7 +74,7 @@ func (d *Instance) GetInstancePwd(rc *req.Ctx) {
instanceId := getInstanceId(rc.GinCtx)
instanceEntity, err := d.InstanceApp.GetById(new(entity.DbInstance), instanceId, "Password")
biz.ErrIsNil(err, "获取数据库实例错误")
instanceEntity.PwdDecrypt()
biz.ErrIsNil(instanceEntity.PwdDecrypt())
rc.ResData = instanceEntity.Password
}
@@ -105,7 +105,7 @@ func (d *Instance) GetDatabaseNames(rc *req.Ctx) {
instanceId := getInstanceId(rc.GinCtx)
instance, err := d.InstanceApp.GetById(new(entity.DbInstance), instanceId, "Password")
biz.ErrIsNil(err, "获取数据库实例错误")
instance.PwdDecrypt()
biz.ErrIsNil(instance.PwdDecrypt())
res, err := d.InstanceApp.GetDatabases(instance)
biz.ErrIsNil(err)
rc.ResData = res

View File

@@ -2,27 +2,28 @@ package vo
import (
"encoding/json"
"mayfly-go/pkg/utils/timex"
"time"
)
// DbBackupHistory 数据库备份任务
// DbBackup 数据库备份任务
type DbBackup struct {
Id uint64 `json:"id"`
DbName string `json:"dbName"` // 数据库名
CreateTime time.Time `json:"createTime"` // 创建时间: 2023-11-08 02:00:00
StartTime time.Time `json:"startTime"` // 开始时间: 2023-11-08 02:00:00
Interval time.Duration `json:"-"` // 间隔时间: 为零表示单次执行,为正表示反复执行
IntervalDay uint64 `json:"intervalDay" gorm:"-"` // 间隔天数: 为零表示单次执行,为正表示反复执行
Enabled bool `json:"enabled"` // 是否启用
LastTime time.Time `json:"lastTime"` // 最近一次执行时间: 2023-11-08 02:00:00
LastStatus string `json:"lastStatus"` // 最近一次执行状态
LastResult string `json:"lastResult"` // 最近一次执行结果
DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
Name string `json:"name"` // 备份任务名称
Id uint64 `json:"id"`
DbName string `json:"dbName"` // 数据库名
CreateTime time.Time `json:"createTime"` // 创建时间
StartTime time.Time `json:"startTime"` // 开始时间
Interval time.Duration `json:"-"` // 间隔时间
IntervalDay uint64 `json:"intervalDay" gorm:"-"` // 间隔天数
Enabled bool `json:"enabled"` // 是否启用
LastTime timex.NullTime `json:"lastTime"` // 最近一次执行时间
LastStatus string `json:"lastStatus"` // 最近一次执行状态
LastResult string `json:"lastResult"` // 最近一次执行结果
DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
Name string `json:"name"` // 备份任务名称
}
func (restore *DbBackup) MarshalJSON() ([]byte, error) {
func (backup *DbBackup) MarshalJSON() ([]byte, error) {
type dbBackup DbBackup
restore.IntervalDay = uint64(restore.Interval / time.Hour / 24)
return json.Marshal((*dbBackup)(restore))
backup.IntervalDay = uint64(backup.Interval / time.Hour / 24)
return json.Marshal((*dbBackup)(backup))
}

View File

@@ -2,25 +2,26 @@ package vo
import (
"encoding/json"
"mayfly-go/pkg/utils/timex"
"time"
)
// DbRestore 数据库备份任务
type DbRestore struct {
Id uint64 `json:"id"`
DbName string `json:"dbName"` // 数据库名
StartTime time.Time `json:"startTime"` // 开始时间: 2023-11-08 02:00:00
Interval time.Duration `json:"-"` // 间隔时间: 为零表示单次执行,为正表示反复执行
IntervalDay uint64 `json:"intervalDay" gorm:"-"` // 间隔天数: 为零表示单次执行,为正表示反复执行
Enabled bool `json:"enabled"` // 是否启用
LastTime time.Time `json:"lastTime"` // 最近一次执行时间: 2023-11-08 02:00:00
LastStatus string `json:"lastStatus"` // 最近一次执行状态
LastResult string `json:"lastResult"` // 最近一次执行结果
PointInTime time.Time `json:"pointInTime"` // 指定数据库恢复的时间点
DbBackupId uint64 `json:"dbBackupId"` // 数据库备份任务ID
DbBackupHistoryId uint64 `json:"dbBackupHistoryId"` // 数据库备份历史ID
DbBackupHistoryName string `json:"dbBackupHistoryName"` // 数据库备份历史名称
DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
Id uint64 `json:"id"`
DbName string `json:"dbName"` // 数据库名
StartTime time.Time `json:"startTime"` // 开始时间
Interval time.Duration `json:"-"` // 间隔时间
IntervalDay uint64 `json:"intervalDay" gorm:"-"` // 间隔天数
Enabled bool `json:"enabled"` // 是否启用
LastTime timex.NullTime `json:"lastTime"` // 最近一次执行时间
LastStatus string `json:"lastStatus"` // 最近一次执行状态
LastResult string `json:"lastResult"` // 最近一次执行结果
PointInTime timex.NullTime `json:"pointInTime"` // 指定数据库恢复的时间点
DbBackupId uint64 `json:"dbBackupId"` // 数据库备份任务ID
DbBackupHistoryId uint64 `json:"dbBackupHistoryId"` // 数据库备份历史ID
DbBackupHistoryName string `json:"dbBackupHistoryName"` // 数据库备份历史名称
DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
}
func (restore *DbRestore) MarshalJSON() ([]byte, error) {

View File

@@ -17,6 +17,7 @@ var (
dbBackupHistoryApp *DbBackupHistoryApp
dbRestoreApp *DbRestoreApp
dbRestoreHistoryApp *DbRestoreHistoryApp
dbBinlogApp *DbBinlogApp
)
var repositories *repository.Repositories
@@ -39,16 +40,26 @@ func Init() {
dbSqlExecApp = newDbSqlExecApp(persistence.GetDbSqlExecRepo())
dbSqlApp = newDbSqlApp(persistence.GetDbSqlRepo())
dbBackupApp, err = newDbBackupApp(repositories)
dbBackupApp, err = newDbBackupApp(repositories, dbApp)
if err != nil {
panic(fmt.Sprintf("初始化 dbBackupApp 失败: %v", err))
}
dbRestoreApp, err = newDbRestoreApp(repositories)
dbRestoreApp, err = newDbRestoreApp(repositories, dbApp)
if err != nil {
panic(fmt.Sprintf("初始化 dbRestoreApp 失败: %v", err))
}
dbBackupHistoryApp, err = newDbBackupHistoryApp(repositories)
if err != nil {
panic(fmt.Sprintf("初始化 dbBackupHistoryApp 失败: %v", err))
}
dbRestoreHistoryApp, err = newDbRestoreHistoryApp(repositories)
if err != nil {
panic(fmt.Sprintf("初始化 dbRestoreHistoryApp 失败: %v", err))
}
dbBinlogApp, err = newDbBinlogApp(repositories, dbApp)
if err != nil {
panic(fmt.Sprintf("初始化 dbBinlogApp 失败: %v", err))
}
})()
}
@@ -83,3 +94,7 @@ func GetDbRestoreApp() *DbRestoreApp {
func GetDbRestoreHistoryApp() *DbRestoreHistoryApp {
return dbRestoreHistoryApp
}
func GetDbBinlogApp() *DbBinlogApp {
return dbBinlogApp
}

View File

@@ -156,7 +156,7 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) {
checkDb := dbName
// 兼容pgsql/dm db/schema模式
if instance.Type == dbm.DbTypePostgres || instance.Type == dbm.DM {
if dbm.DbTypePostgres.Equal(instance.Type) || dbm.DM.Equal(instance.Type) {
ss := strings.Split(dbName, "/")
if len(ss) > 1 {
checkDb = ss[0]
@@ -167,7 +167,9 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) {
}
// 密码解密
instance.PwdDecrypt()
if err := instance.PwdDecrypt(); err != nil {
return nil, errorx.NewBiz(err.Error())
}
return toDbInfo(instance, dbId, dbName, d.tagApp.ListTagPathByResource(consts.TagResourceTypeDb, db.Code)...), nil
})
}

View File

@@ -2,62 +2,155 @@ package application
import (
"context"
"encoding/binary"
"fmt"
"github.com/google/uuid"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/internal/db/domain/service"
service2 "mayfly-go/internal/db/infrastructure/service"
"mayfly-go/pkg/model"
"time"
)
func newDbBackupApp(repositories *repository.Repositories) (*DbBackupApp, error) {
binlogSvc, err := service2.NewDbBinlogSvc(repositories)
if err != nil {
return nil, err
}
dbBackupSvc, err := service2.NewDbBackupSvc(repositories, binlogSvc)
if err != nil {
return nil, err
}
func newDbBackupApp(repositories *repository.Repositories, dbApp Db) (*DbBackupApp, error) {
app := &DbBackupApp{
repo: repositories.Backup,
dbBackupSvc: dbBackupSvc,
backupRepo: repositories.Backup,
instanceRepo: repositories.Instance,
backupHistoryRepo: repositories.BackupHistory,
dbApp: dbApp,
}
scheduler, err := newDbScheduler[*entity.DbBackup](
repositories.Backup,
withRunBackupTask(app))
if err != nil {
return nil, err
}
app.scheduler = scheduler
return app, nil
}
type DbBackupApp struct {
repo repository.DbBackup
dbBackupSvc service.DbBackupSvc
backupRepo repository.DbBackup
instanceRepo repository.Instance
backupHistoryRepo repository.DbBackupHistory
dbApp Db
scheduler *dbScheduler[*entity.DbBackup]
}
func (app *DbBackupApp) Close() {
app.scheduler.Close()
}
func (app *DbBackupApp) Create(ctx context.Context, tasks ...*entity.DbBackup) error {
return app.dbBackupSvc.AddTask(ctx, tasks...)
return app.scheduler.AddTask(ctx, tasks...)
}
func (app *DbBackupApp) Save(ctx context.Context, task *entity.DbBackup) error {
return app.dbBackupSvc.UpdateTask(ctx, task)
return app.scheduler.UpdateTask(ctx, task)
}
func (app *DbBackupApp) Delete(ctx context.Context, taskId uint64) error {
// todo: 删除数据库备份历史文件
return app.dbBackupSvc.DeleteTask(ctx, taskId)
return app.scheduler.DeleteTask(ctx, taskId)
}
func (app *DbBackupApp) Enable(ctx context.Context, taskId uint64) error {
return app.dbBackupSvc.EnableTask(ctx, taskId)
return app.scheduler.EnableTask(ctx, taskId)
}
func (app *DbBackupApp) Disable(ctx context.Context, taskId uint64) error {
return app.dbBackupSvc.DisableTask(ctx, taskId)
return app.scheduler.DisableTask(ctx, taskId)
}
func (app *DbBackupApp) Start(ctx context.Context, taskId uint64) error {
return app.scheduler.StartTask(ctx, taskId)
}
// GetPageList 分页获取数据库备份任务
func (app *DbBackupApp) GetPageList(condition *entity.DbBackupQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return app.repo.GetDbBackupList(condition, pageParam, toEntity, orderBy...)
return app.backupRepo.GetDbBackupList(condition, pageParam, toEntity, orderBy...)
}
// GetDbNamesWithoutBackup 获取未配置定时备份的数据库名称
func (app *DbBackupApp) GetDbNamesWithoutBackup(instanceId uint64, dbNames []string) ([]string, error) {
return app.repo.GetDbNamesWithoutBackup(instanceId, dbNames)
return app.backupRepo.GetDbNamesWithoutBackup(instanceId, dbNames)
}
func withRunBackupTask(app *DbBackupApp) dbSchedulerOption[*entity.DbBackup] {
return func(scheduler *dbScheduler[*entity.DbBackup]) {
scheduler.RunTask = app.runTask
}
}
func (app *DbBackupApp) runTask(ctx context.Context, task *entity.DbBackup) error {
id, err := NewIncUUID()
if err != nil {
return err
}
history := &entity.DbBackupHistory{
Uuid: id.String(),
DbBackupId: task.Id,
DbInstanceId: task.DbInstanceId,
DbName: task.DbName,
}
conn, err := app.dbApp.GetDbConnByInstanceId(task.DbInstanceId)
if err != nil {
return err
}
dbProgram := conn.GetDialect().GetDbProgram()
binlogInfo, err := dbProgram.Backup(ctx, history)
if err != nil {
return err
}
now := time.Now()
name := task.Name
if len(name) == 0 {
name = task.DbName
}
history.Name = fmt.Sprintf("%s[%s]", name, now.Format(time.DateTime))
history.CreateTime = now
history.BinlogFileName = binlogInfo.FileName
history.BinlogSequence = binlogInfo.Sequence
history.BinlogPosition = binlogInfo.Position
if err := app.backupHistoryRepo.Insert(ctx, history); err != nil {
return err
}
return nil
}
func NewIncUUID() (uuid.UUID, error) {
var uid uuid.UUID
now, seq, err := uuid.GetTime()
if err != nil {
return uid, err
}
timeHi := uint32((now >> 28) & 0xffffffff)
timeMid := uint16((now >> 12) & 0xffff)
timeLow := uint16(now & 0x0fff)
timeLow |= 0x1000 // Version 1
binary.BigEndian.PutUint32(uid[0:], timeHi)
binary.BigEndian.PutUint16(uid[4:], timeMid)
binary.BigEndian.PutUint16(uid[6:], timeLow)
binary.BigEndian.PutUint16(uid[8:], seq)
copy(uid[10:], uuid.NodeID())
return uid, nil
}
func newDbBackupHistoryApp(repositories *repository.Repositories) (*DbBackupHistoryApp, error) {
app := &DbBackupHistoryApp{
repo: repositories.BackupHistory,
}
return app, nil
}
type DbBackupHistoryApp struct {
repo repository.DbBackupHistory
}
// GetPageList 分页获取数据库备份历史
func (app *DbBackupHistoryApp) GetPageList(condition *entity.DbBackupHistoryQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return app.repo.GetHistories(condition, pageParam, toEntity, orderBy...)
}

View File

@@ -1,23 +0,0 @@
package application
import (
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/model"
)
func newDbBackupHistoryApp(repositories *repository.Repositories) (*DbBackupHistoryApp, error) {
app := &DbBackupHistoryApp{
repo: repositories.BackupHistory,
}
return app, nil
}
type DbBackupHistoryApp struct {
repo repository.DbBackupHistory
}
// GetPageList 分页获取数据库备份历史
func (app *DbBackupHistoryApp) GetPageList(condition *entity.DbBackupHistoryQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return app.repo.GetHistories(condition, pageParam, toEntity, orderBy...)
}

View File

@@ -0,0 +1,154 @@
package application
import (
"context"
"fmt"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/stringx"
"mayfly-go/pkg/utils/timex"
"sync"
"time"
)
const (
binlogDownloadInterval = time.Minute * 15
)
type DbBinlogApp struct {
binlogRepo repository.DbBinlog
binlogHistoryRepo repository.DbBinlogHistory
backupRepo repository.DbBackup
backupHistoryRepo repository.DbBackupHistory
dbApp Db
context context.Context
cancel context.CancelFunc
waitGroup sync.WaitGroup
}
var (
binlogResult = map[entity.TaskStatus]string{
entity.TaskDelay: "等待备份BINLOG",
entity.TaskReady: "准备备份BINLOG",
entity.TaskReserved: "BINLOG备份中",
entity.TaskSuccess: "BINLOG备份成功",
entity.TaskFailed: "BINLOG备份失败",
}
)
func newDbBinlogApp(repositories *repository.Repositories, dbApp Db) (*DbBinlogApp, error) {
ctx, cancel := context.WithCancel(context.Background())
svc := &DbBinlogApp{
binlogRepo: repositories.Binlog,
binlogHistoryRepo: repositories.BinlogHistory,
backupRepo: repositories.Backup,
backupHistoryRepo: repositories.BackupHistory,
dbApp: dbApp,
context: ctx,
cancel: cancel,
}
svc.waitGroup.Add(1)
go svc.run()
return svc, nil
}
func (app *DbBinlogApp) runTask(ctx context.Context, backup *entity.DbBackup) error {
if err := app.AddTaskIfNotExists(ctx, entity.NewDbBinlog(backup.DbInstanceId)); err != nil {
return err
}
latestBinlogSequence, earliestBackupSequence := int64(-1), int64(-1)
binlogHistory, ok, err := app.binlogHistoryRepo.GetLatestHistory(backup.DbInstanceId)
if err != nil {
return err
}
if ok {
latestBinlogSequence = binlogHistory.Sequence
} else {
backupHistory, err := app.backupHistoryRepo.GetEarliestHistory(backup.DbInstanceId)
if err != nil {
return err
}
earliestBackupSequence = backupHistory.BinlogSequence
}
conn, err := app.dbApp.GetDbConnByInstanceId(backup.DbInstanceId)
if err != nil {
return err
}
dbProgram := conn.GetDialect().GetDbProgram()
binlogFiles, err := dbProgram.FetchBinlogs(ctx, false, earliestBackupSequence, latestBinlogSequence)
if err == nil {
err = app.binlogHistoryRepo.InsertWithBinlogFiles(ctx, backup.DbInstanceId, binlogFiles)
}
taskStatus := entity.TaskSuccess
if err != nil {
taskStatus = entity.TaskFailed
}
task := &entity.DbBinlog{}
task.Id = backup.DbInstanceId
return app.updateCurTask(ctx, taskStatus, err, task)
}
func (app *DbBinlogApp) run() {
defer app.waitGroup.Done()
for !app.closed() {
app.fetchFromAllInstances()
timex.SleepWithContext(app.context, binlogDownloadInterval)
}
}
func (app *DbBinlogApp) fetchFromAllInstances() {
tasks, err := app.backupRepo.ListRepeating()
if err != nil {
logx.Errorf("DbBinlogApp: 获取数据库备份任务失败: %s", err.Error())
return
}
for _, task := range tasks {
if app.closed() {
break
}
if err := app.runTask(app.context, task); err != nil {
logx.Errorf("DbBinlogApp: 下载 binlog 文件失败: %s", err.Error())
return
}
}
}
func (app *DbBinlogApp) Close() {
app.cancel()
app.waitGroup.Wait()
}
func (app *DbBinlogApp) closed() bool {
return app.context.Err() != nil
}
func (app *DbBinlogApp) AddTaskIfNotExists(ctx context.Context, task *entity.DbBinlog) error {
if err := app.binlogRepo.AddTaskIfNotExists(ctx, task); err != nil {
return err
}
if task.Id == 0 {
return nil
}
return nil
}
func (app *DbBinlogApp) DeleteTask(ctx context.Context, taskId uint64) error {
// todo: 删除 Binlog 历史文件
if err := app.binlogRepo.DeleteById(ctx, taskId); err != nil {
return err
}
return nil
}
func (app *DbBinlogApp) updateCurTask(ctx context.Context, status entity.TaskStatus, lastErr error, task *entity.DbBinlog) error {
task.LastStatus = status
var result = binlogResult[status]
if lastErr != nil {
result = fmt.Sprintf("%v: %v", binlogResult[status], lastErr)
}
task.LastResult = stringx.TruncateStr(result, entity.LastResultSize)
task.LastTime = timex.NewNullTime(time.Now())
return app.binlogRepo.UpdateById(ctx, task, "last_status", "last_result", "last_time")
}

View File

@@ -2,57 +2,190 @@ package application
import (
"context"
"mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/internal/db/domain/service"
serviceImpl "mayfly-go/internal/db/infrastructure/service"
"mayfly-go/pkg/model"
"time"
)
func newDbRestoreApp(repositories *repository.Repositories) (*DbRestoreApp, error) {
dbRestoreSvc, err := serviceImpl.NewDbRestoreSvc(repositories)
func newDbRestoreApp(repositories *repository.Repositories, dbApp Db) (*DbRestoreApp, error) {
app := &DbRestoreApp{
restoreRepo: repositories.Restore,
instanceRepo: repositories.Instance,
backupHistoryRepo: repositories.BackupHistory,
restoreHistoryRepo: repositories.RestoreHistory,
binlogHistoryRepo: repositories.BinlogHistory,
dbApp: dbApp,
}
scheduler, err := newDbScheduler[*entity.DbRestore](
repositories.Restore,
withRunRestoreTask(app))
if err != nil {
return nil, err
}
app := &DbRestoreApp{
repo: repositories.Restore,
dbRestoreSvc: dbRestoreSvc,
}
app.scheduler = scheduler
return app, nil
}
type DbRestoreApp struct {
repo repository.DbRestore
dbRestoreSvc service.DbRestoreSvc
restoreRepo repository.DbRestore
instanceRepo repository.Instance
backupHistoryRepo repository.DbBackupHistory
restoreHistoryRepo repository.DbRestoreHistory
binlogHistoryRepo repository.DbBinlogHistory
dbApp Db
scheduler *dbScheduler[*entity.DbRestore]
}
func (app *DbRestoreApp) Close() {
app.scheduler.Close()
}
func (app *DbRestoreApp) Create(ctx context.Context, tasks ...*entity.DbRestore) error {
return app.dbRestoreSvc.AddTask(ctx, tasks...)
return app.scheduler.AddTask(ctx, tasks...)
}
func (app *DbRestoreApp) Save(ctx context.Context, task *entity.DbRestore) error {
return app.dbRestoreSvc.UpdateTask(ctx, task)
return app.scheduler.UpdateTask(ctx, task)
}
func (app *DbRestoreApp) Delete(ctx context.Context, taskId uint64) error {
// todo: 删除数据库恢复历史文件
return app.dbRestoreSvc.DeleteTask(ctx, taskId)
return app.scheduler.DeleteTask(ctx, taskId)
}
func (app *DbRestoreApp) Enable(ctx context.Context, taskId uint64) error {
return app.dbRestoreSvc.EnableTask(ctx, taskId)
return app.scheduler.EnableTask(ctx, taskId)
}
func (app *DbRestoreApp) Disable(ctx context.Context, taskId uint64) error {
return app.dbRestoreSvc.DisableTask(ctx, taskId)
return app.scheduler.DisableTask(ctx, taskId)
}
// GetPageList 分页获取数据库恢复任务
func (app *DbRestoreApp) GetPageList(condition *entity.DbRestoreQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return app.repo.GetDbRestoreList(condition, pageParam, toEntity, orderBy...)
return app.restoreRepo.GetDbRestoreList(condition, pageParam, toEntity, orderBy...)
}
// GetDbNamesWithoutRestore 获取未配置定时恢复的数据库名称
func (app *DbRestoreApp) GetDbNamesWithoutRestore(instanceId uint64, dbNames []string) ([]string, error) {
return app.repo.GetDbNamesWithoutRestore(instanceId, dbNames)
return app.restoreRepo.GetDbNamesWithoutRestore(instanceId, dbNames)
}
func (app *DbRestoreApp) runTask(ctx context.Context, task *entity.DbRestore) error {
conn, err := app.dbApp.GetDbConnByInstanceId(task.DbInstanceId)
if err != nil {
return err
}
dbProgram := conn.GetDialect().GetDbProgram()
if task.PointInTime.Valid {
latestBinlogSequence, earliestBackupSequence := int64(-1), int64(-1)
binlogHistory, ok, err := app.binlogHistoryRepo.GetLatestHistory(task.DbInstanceId)
if err != nil {
return err
}
if ok {
latestBinlogSequence = binlogHistory.Sequence
} else {
backupHistory, err := app.backupHistoryRepo.GetEarliestHistory(task.DbInstanceId)
if err != nil {
return err
}
earliestBackupSequence = backupHistory.BinlogSequence
}
binlogFiles, err := dbProgram.FetchBinlogs(ctx, true, earliestBackupSequence, latestBinlogSequence)
if err != nil {
return err
}
if err := app.binlogHistoryRepo.InsertWithBinlogFiles(ctx, task.DbInstanceId, binlogFiles); err != nil {
return err
}
if err := app.restorePointInTime(ctx, dbProgram, task); err != nil {
return err
}
} else {
if err := app.restoreBackupHistory(ctx, dbProgram, task); err != nil {
return err
}
}
history := &entity.DbRestoreHistory{
CreateTime: time.Now(),
DbRestoreId: task.Id,
}
if err := app.restoreHistoryRepo.Insert(ctx, history); err != nil {
return err
}
return nil
}
func (app *DbRestoreApp) restorePointInTime(ctx context.Context, program dbm.DbProgram, task *entity.DbRestore) error {
binlogHistory, err := app.binlogHistoryRepo.GetHistoryByTime(task.DbInstanceId, task.PointInTime.Time)
if err != nil {
return err
}
position, err := program.GetBinlogEventPositionAtOrAfterTime(ctx, binlogHistory.FileName, task.PointInTime.Time)
if err != nil {
return err
}
target := &entity.BinlogInfo{
FileName: binlogHistory.FileName,
Sequence: binlogHistory.Sequence,
Position: position,
}
backupHistory, err := app.backupHistoryRepo.GetLatestHistory(task.DbInstanceId, task.DbName, target)
if err != nil {
return err
}
start := &entity.BinlogInfo{
FileName: backupHistory.BinlogFileName,
Sequence: backupHistory.BinlogSequence,
Position: backupHistory.BinlogPosition,
}
binlogHistories, err := app.binlogHistoryRepo.GetHistories(task.DbInstanceId, start, target)
if err != nil {
return err
}
restoreInfo := &dbm.RestoreInfo{
BackupHistory: backupHistory,
BinlogHistories: binlogHistories,
StartPosition: backupHistory.BinlogPosition,
TargetPosition: target.Position,
TargetTime: task.PointInTime.Time,
}
if err := program.RestoreBackupHistory(ctx, backupHistory.DbName, backupHistory.DbBackupId, backupHistory.Uuid); err != nil {
return err
}
return program.ReplayBinlog(ctx, task.DbName, task.DbName, restoreInfo)
}
func (app *DbRestoreApp) restoreBackupHistory(ctx context.Context, program dbm.DbProgram, task *entity.DbRestore) error {
backupHistory := &entity.DbBackupHistory{}
if err := app.backupHistoryRepo.GetById(backupHistory, task.DbBackupHistoryId); err != nil {
return err
}
return program.RestoreBackupHistory(ctx, backupHistory.DbName, backupHistory.DbBackupId, backupHistory.Uuid)
}
func withRunRestoreTask(app *DbRestoreApp) dbSchedulerOption[*entity.DbRestore] {
return func(scheduler *dbScheduler[*entity.DbRestore]) {
scheduler.RunTask = app.runTask
}
}
func newDbRestoreHistoryApp(repositories *repository.Repositories) (*DbRestoreHistoryApp, error) {
app := &DbRestoreHistoryApp{
repo: repositories.RestoreHistory,
}
return app, nil
}
type DbRestoreHistoryApp struct {
repo repository.DbRestoreHistory
}
// GetPageList 分页获取数据库备份历史
func (app *DbRestoreHistoryApp) GetPageList(condition *entity.DbRestoreHistoryQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return app.repo.GetDbRestoreHistories(condition, pageParam, toEntity, orderBy...)
}

View File

@@ -1,23 +0,0 @@
package application
import (
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/model"
)
func newDbRestoreHistoryApp(repositories *repository.Repositories) (*DbRestoreHistoryApp, error) {
app := &DbRestoreHistoryApp{
repo: repositories.RestoreHistory,
}
return app, nil
}
type DbRestoreHistoryApp struct {
repo repository.DbRestoreHistory
}
// GetPageList 分页获取数据库备份历史
func (app *DbRestoreHistoryApp) GetPageList(condition *entity.DbRestoreHistoryQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return app.repo.GetDbRestoreHistories(condition, pageParam, toEntity, orderBy...)
}

View File

@@ -0,0 +1,235 @@
package application
import (
"context"
"errors"
"fmt"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/queue"
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/stringx"
"mayfly-go/pkg/utils/timex"
"sync"
"time"
)
const sleepAfterError = time.Minute
type dbScheduler[T entity.DbTask] struct {
mutex sync.Mutex
waitGroup sync.WaitGroup
queue *queue.DelayQueue[T]
context context.Context
cancel context.CancelFunc
RunTask func(ctx context.Context, task T) error
taskRepo repository.DbTask[T]
}
type dbSchedulerOption[T entity.DbTask] func(*dbScheduler[T])
func newDbScheduler[T entity.DbTask](taskRepo repository.DbTask[T], opts ...dbSchedulerOption[T]) (*dbScheduler[T], error) {
ctx, cancel := context.WithCancel(context.Background())
scheduler := &dbScheduler[T]{
taskRepo: taskRepo,
queue: queue.NewDelayQueue[T](0),
context: ctx,
cancel: cancel,
}
for _, opt := range opts {
opt(scheduler)
}
if scheduler.RunTask == nil {
return nil, errors.New("数据库任务调度器没有设置 RunTask")
}
if err := scheduler.loadTask(context.Background()); err != nil {
return nil, err
}
scheduler.waitGroup.Add(1)
go scheduler.run()
return scheduler, nil
}
func (s *dbScheduler[T]) updateTaskStatus(ctx context.Context, status entity.TaskStatus, lastErr error, task T) error {
base := task.GetTaskBase()
base.LastStatus = status
var result = task.MessageWithStatus(status)
if lastErr != nil {
result = fmt.Sprintf("%v: %v", result, lastErr)
}
base.LastResult = stringx.TruncateStr(result, entity.LastResultSize)
base.LastTime = timex.NewNullTime(time.Now())
return s.taskRepo.UpdateTaskStatus(ctx, task)
}
func (s *dbScheduler[T]) UpdateTask(ctx context.Context, task T) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.taskRepo.UpdateById(ctx, task); err != nil {
return err
}
oldTask, ok := s.queue.Remove(ctx, task.GetId())
if !ok {
return errors.New("任务不存在")
}
oldTask.Update(task)
if !oldTask.Schedule() {
return nil
}
if !s.queue.Enqueue(ctx, oldTask) {
return errors.New("任务入队失败")
}
return nil
}
func (s *dbScheduler[T]) run() {
defer s.waitGroup.Done()
for !s.closed() {
time.Sleep(time.Second)
s.mutex.Lock()
task, ok := s.queue.TryDequeue()
if !ok {
s.mutex.Unlock()
continue
}
if err := s.updateTaskStatus(s.context, entity.TaskReserved, nil, task); err != nil {
s.mutex.Unlock()
timex.SleepWithContext(s.context, sleepAfterError)
continue
}
s.mutex.Unlock()
errRun := s.RunTask(s.context, task)
taskStatus := entity.TaskSuccess
if errRun != nil {
taskStatus = entity.TaskFailed
}
s.mutex.Lock()
if err := s.updateTaskStatus(s.context, taskStatus, errRun, task); err != nil {
s.mutex.Unlock()
timex.SleepWithContext(s.context, sleepAfterError)
continue
}
task.Schedule()
if !task.IsFinished() {
s.queue.Enqueue(s.context, task)
}
s.mutex.Unlock()
}
}
func (s *dbScheduler[T]) Close() {
s.cancel()
s.waitGroup.Wait()
}
func (s *dbScheduler[T]) closed() bool {
return s.context.Err() != nil
}
func (s *dbScheduler[T]) loadTask(ctx context.Context) error {
s.mutex.Lock()
defer s.mutex.Unlock()
tasks, err := s.taskRepo.ListToDo()
if err != nil {
return err
}
for _, task := range tasks {
if !task.Schedule() {
continue
}
s.queue.Enqueue(ctx, task)
}
return nil
}
func (s *dbScheduler[T]) AddTask(ctx context.Context, tasks ...T) error {
s.mutex.Lock()
defer s.mutex.Unlock()
for _, task := range tasks {
if err := s.taskRepo.AddTask(ctx, task); err != nil {
return err
}
if !task.Schedule() {
continue
}
s.queue.Enqueue(ctx, task)
}
return nil
}
func (s *dbScheduler[T]) DeleteTask(ctx context.Context, taskId uint64) error {
// todo: 删除数据库备份历史文件
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.taskRepo.DeleteById(ctx, taskId); err != nil {
return err
}
s.queue.Remove(ctx, taskId)
return nil
}
func (s *dbScheduler[T]) EnableTask(ctx context.Context, taskId uint64) error {
s.mutex.Lock()
defer s.mutex.Unlock()
task := anyx.DeepZero[T]()
if err := s.taskRepo.GetById(task, taskId); err != nil {
return err
}
if task.IsEnabled() {
return nil
}
task.GetTaskBase().Enabled = true
if err := s.taskRepo.UpdateEnabled(ctx, taskId, true); err != nil {
return err
}
s.queue.Remove(ctx, taskId)
if !task.Schedule() {
return nil
}
s.queue.Enqueue(ctx, task)
return nil
}
func (s *dbScheduler[T]) DisableTask(ctx context.Context, taskId uint64) error {
s.mutex.Lock()
defer s.mutex.Unlock()
task := anyx.DeepZero[T]()
if err := s.taskRepo.GetById(task, taskId); err != nil {
return err
}
if !task.IsEnabled() {
return nil
}
if err := s.taskRepo.UpdateEnabled(ctx, taskId, false); err != nil {
return err
}
s.queue.Remove(ctx, taskId)
return nil
}
func (s *dbScheduler[T]) StartTask(ctx context.Context, taskId uint64) error {
s.mutex.Lock()
defer s.mutex.Unlock()
task := anyx.DeepZero[T]()
if err := s.taskRepo.GetById(task, taskId); err != nil {
return err
}
if !task.IsEnabled() {
return errors.New("任务未启用")
}
s.queue.Remove(ctx, taskId)
task.GetTaskBase().Deadline = time.Now()
s.queue.Enqueue(ctx, task)
return nil
}

View File

@@ -2,6 +2,7 @@ package application
import (
"context"
"mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/base"
@@ -77,7 +78,9 @@ func (app *instanceAppImpl) Save(ctx context.Context, instanceEntity *entity.DbI
if err == nil {
return errorx.NewBiz("该数据库实例已存在")
}
instanceEntity.PwdEncrypt()
if err := instanceEntity.PwdEncrypt(); err != nil {
return errorx.NewBiz(err.Error())
}
return app.Insert(ctx, instanceEntity)
}
@@ -85,7 +88,9 @@ func (app *instanceAppImpl) Save(ctx context.Context, instanceEntity *entity.DbI
if err == nil && oldInstance.Id != instanceEntity.Id {
return errorx.NewBiz("该数据库实例已存在")
}
instanceEntity.PwdEncrypt()
if err := instanceEntity.PwdEncrypt(); err != nil {
return errorx.NewBiz(err.Error())
}
return app.UpdateById(ctx, instanceEntity)
}
@@ -95,7 +100,7 @@ func (app *instanceAppImpl) Delete(ctx context.Context, id uint64) error {
func (app *instanceAppImpl) GetDatabases(ed *entity.DbInstance) ([]string, error) {
ed.Network = ed.GetNetwork()
metaDb := ed.Type.MetaDbName()
metaDb := dbm.ToDbType(ed.Type).MetaDbName()
dbConn, err := toDbInfo(ed, 0, metaDb, "").Conn()
if err != nil {

View File

@@ -2,6 +2,8 @@ package config
import (
sysapp "mayfly-go/internal/sys/application"
"path/filepath"
"runtime"
)
const (
@@ -9,6 +11,7 @@ const (
ConfigKeyDbQueryMaxCount string = "DbQueryMaxCount" // 数据库查询的最大数量
ConfigKeyDbBackupRestore string = "DbBackupRestore" // 数据库备份
ConfigKeyDbMysqlBin string = "MysqlBin" // mysql可执行文件配置
ConfigKeyDbMariaDbBin string = "MariaDbBin" // mariadb可执行文件配置
)
// 获取数据库最大查询数量配置
@@ -36,7 +39,7 @@ func GetDbBackupRestore() *DbBackupRestore {
if backupPath == "" {
backupPath = "./db/backup"
}
dbrc.BackupPath = backupPath
dbrc.BackupPath = filepath.Join(backupPath)
return dbrc
}
@@ -50,35 +53,39 @@ type MysqlBin struct {
}
// 获取数据库备份配置
func GetMysqlBin() *MysqlBin {
c := sysapp.GetConfigApp().GetConfig(ConfigKeyDbMysqlBin)
func GetMysqlBin(configKey string) *MysqlBin {
c := sysapp.GetConfigApp().GetConfig(configKey)
jm := c.GetJsonMap()
mbc := new(MysqlBin)
path := jm["path"]
if path == "" {
path = "./db/backup"
path = "./db/mysql/bin"
}
mbc.Path = path
mbc.Path = filepath.Join(path)
var extName string
if runtime.GOOS == "windows" {
extName = ".exe"
}
mysqlPath := jm["mysql"]
if mysqlPath == "" {
mysqlPath = path + "mysql"
mysqlPath = filepath.Join(path, "mysql"+extName)
}
mbc.MysqlPath = mysqlPath
mbc.MysqlPath = filepath.Join(mysqlPath)
mysqldumpPath := jm["mysqldump"]
if mysqldumpPath == "" {
mysqldumpPath = path + "mysqldump"
mysqldumpPath = filepath.Join(path, "mysqldump"+extName)
}
mbc.MysqldumpPath = mysqldumpPath
mbc.MysqldumpPath = filepath.Join(mysqldumpPath)
mysqlbinlogPath := jm["mysqlbinlog"]
if mysqlbinlogPath == "" {
mysqlbinlogPath = path + "mysqlbinlog"
mysqlbinlogPath = filepath.Join(path, "mysqlbinlog"+extName)
}
mbc.MysqlbinlogPath = mysqlbinlogPath
mbc.MysqlbinlogPath = filepath.Join(mysqlbinlogPath)
return mbc
}

View File

@@ -0,0 +1,32 @@
package dbm
import (
"context"
"mayfly-go/internal/db/domain/entity"
"path/filepath"
"time"
)
type DbProgram interface {
Backup(ctx context.Context, backupHistory *entity.DbBackupHistory) (*entity.BinlogInfo, error)
FetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool, earliestBackupSequence, latestBinlogSequence int64) ([]*entity.BinlogFile, error)
ReplayBinlog(ctx context.Context, originalDatabase, targetDatabase string, restoreInfo *RestoreInfo) error
RestoreBackupHistory(ctx context.Context, dbName string, dbBackupId uint64, dbBackupHistoryUuid string) error
GetBinlogEventPositionAtOrAfterTime(ctx context.Context, binlogName string, targetTime time.Time) (position int64, parseErr error)
}
type RestoreInfo struct {
BackupHistory *entity.DbBackupHistory
BinlogHistories []*entity.DbBinlogHistory
StartPosition int64
TargetPosition int64
TargetTime time.Time
}
func (ri *RestoreInfo) GetBinlogPaths(binlogDir string) []string {
files := make([]string, 0, len(ri.BinlogHistories))
for _, history := range ri.BinlogHistories {
files = append(files, filepath.Join(binlogDir, history.FileName))
}
return files
}

View File

@@ -1,4 +1,4 @@
package service
package dbm
import (
"bufio"
@@ -19,112 +19,55 @@ import (
"golang.org/x/sync/singleflight"
"mayfly-go/internal/db/config"
"mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/internal/db/domain/service"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/structx"
)
// BinlogFile is the metadata of the MySQL binlog file.
type BinlogFile struct {
Name string
Size int64
var _ DbProgram = (*DbProgramMysql)(nil)
// Sequence is parsed from Name and is for the sorting purpose.
Sequence int64
FirstEventTime time.Time
Downloaded bool
type DbProgramMysql struct {
dbConn *DbConn
// mysqlBin 用于集成测试
mysqlBin *config.MysqlBin
// backupPath 用于集成测试
backupPath string
}
func newBinlogFile(name string, size int64) (*BinlogFile, error) {
_, seq, err := ParseBinlogName(name)
if err != nil {
return nil, err
}
return &BinlogFile{Name: name, Size: size, Sequence: seq}, nil
}
var _ service.DbInstanceSvc = (*DbInstanceSvcImpl)(nil)
type DbInstanceSvcImpl struct {
instanceId uint64
dbInfo *dbm.DbInfo
backupHistoryRepo repository.DbBackupHistory
binlogHistoryRepo repository.DbBinlogHistory
}
func NewDbInstanceSvc(instance *entity.DbInstance, repositories *repository.Repositories) *DbInstanceSvcImpl {
dbInfo := new(dbm.DbInfo)
_ = structx.Copy(dbInfo, instance)
return &DbInstanceSvcImpl{
instanceId: instance.Id,
dbInfo: dbInfo,
backupHistoryRepo: repositories.BackupHistory,
binlogHistoryRepo: repositories.BinlogHistory,
func NewDbProgramMysql(dbConn *DbConn) *DbProgramMysql {
return &DbProgramMysql{
dbConn: dbConn,
}
}
type RestoreInfo struct {
backupHistory *entity.DbBackupHistory
binlogHistories []*entity.DbBinlogHistory
startPosition int64
targetPosition int64
targetTime time.Time
func (svc *DbProgramMysql) dbInfo() *DbInfo {
return svc.dbConn.Info
}
func (ri *RestoreInfo) getBinlogFiles(binlogDir string) []string {
files := make([]string, 0, len(ri.binlogHistories))
for _, history := range ri.binlogHistories {
files = append(files, filepath.Join(binlogDir, history.FileName))
func (svc *DbProgramMysql) getMysqlBin() *config.MysqlBin {
if svc.mysqlBin != nil {
return svc.mysqlBin
}
return files
var mysqlBin *config.MysqlBin
switch svc.dbInfo().Type {
default:
mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMysqlBin)
}
return mysqlBin
}
func (svc *DbInstanceSvcImpl) getBinlogFilePath(fileName string) string {
return filepath.Join(getBinlogDir(svc.instanceId), fileName)
func (svc *DbProgramMysql) getBackupPath() string {
if len(svc.backupPath) > 0 {
return svc.backupPath
}
return config.GetDbBackupRestore().BackupPath
}
func (svc *DbInstanceSvcImpl) GetRestoreInfo(ctx context.Context, dbName string, targetTime time.Time) (*RestoreInfo, error) {
binlogHistory, err := svc.binlogHistoryRepo.GetHistoryByTime(svc.instanceId, targetTime)
if err != nil {
return nil, err
}
position, err := getBinlogEventPositionAtOrAfterTime(ctx, svc.getBinlogFilePath(binlogHistory.FileName), targetTime)
if err != nil {
return nil, err
}
target := &entity.BinlogInfo{
FileName: binlogHistory.FileName,
Sequence: binlogHistory.Sequence,
Position: position,
}
backupHistory, err := svc.backupHistoryRepo.GetLatestHistory(svc.instanceId, dbName, target)
if err != nil {
return nil, err
}
start := &entity.BinlogInfo{
FileName: backupHistory.BinlogFileName,
Sequence: backupHistory.BinlogSequence,
Position: backupHistory.BinlogPosition,
}
binlogHistories, err := svc.binlogHistoryRepo.GetHistories(svc.instanceId, start, target)
if err != nil {
return nil, err
}
return &RestoreInfo{
backupHistory: backupHistory,
binlogHistories: binlogHistories,
startPosition: backupHistory.BinlogPosition,
targetPosition: target.Position,
targetTime: targetTime,
}, nil
func (svc *DbProgramMysql) GetBinlogFilePath(fileName string) string {
return filepath.Join(svc.getBinlogDir(svc.dbInfo().InstanceId), fileName)
}
func (svc *DbInstanceSvcImpl) Backup(ctx context.Context, backupHistory *entity.DbBackupHistory) (*entity.BinlogInfo, error) {
dir := getDbBackupDir(backupHistory.DbInstanceId, backupHistory.DbBackupId)
func (svc *DbProgramMysql) Backup(ctx context.Context, backupHistory *entity.DbBackupHistory) (*entity.BinlogInfo, error) {
dir := svc.getDbBackupDir(backupHistory.DbInstanceId, backupHistory.DbBackupId)
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
return nil, err
}
@@ -134,10 +77,10 @@ func (svc *DbInstanceSvcImpl) Backup(ctx context.Context, backupHistory *entity.
}()
args := []string{
"--host", svc.dbInfo.Host,
"--port", strconv.Itoa(svc.dbInfo.Port),
"--user", svc.dbInfo.Username,
"--password=" + svc.dbInfo.Password,
"--host", svc.dbInfo().Host,
"--port", strconv.Itoa(svc.dbInfo().Port),
"--user", svc.dbInfo().Username,
"--password=" + svc.dbInfo().Password,
"--add-drop-database",
"--result-file", tmpFile,
"--single-transaction",
@@ -145,7 +88,7 @@ func (svc *DbInstanceSvcImpl) Backup(ctx context.Context, backupHistory *entity.
"--databases", backupHistory.DbName,
}
cmd := exec.CommandContext(ctx, mysqldumpPath(), args...)
cmd := exec.CommandContext(ctx, svc.getMysqlBin().MysqldumpPath, args...)
logx.Debugf("backup database using mysqldump binary: %s", cmd.String())
if err := runCmd(cmd); err != nil {
logx.Errorf("运行 mysqldump 程序失败: %v", err)
@@ -174,15 +117,17 @@ func (svc *DbInstanceSvcImpl) Backup(ctx context.Context, backupHistory *entity.
return binlogInfo, nil
}
func (svc *DbInstanceSvcImpl) RestoreBackup(ctx context.Context, database, fileName string) error {
func (svc *DbProgramMysql) RestoreBackupHistory(ctx context.Context, dbName string, dbBackupId uint64, dbBackupHistoryUuid string) error {
args := []string{
"--host", svc.dbInfo.Host,
"--port", strconv.Itoa(svc.dbInfo.Port),
"--database", database,
"--user", svc.dbInfo.Username,
"--password=" + svc.dbInfo.Password,
"--host", svc.dbInfo().Host,
"--port", strconv.Itoa(svc.dbInfo().Port),
"--database", dbName,
"--user", svc.dbInfo().Username,
"--password=" + svc.dbInfo().Password,
}
fileName := filepath.Join(svc.getDbBackupDir(svc.dbInfo().InstanceId, dbBackupId),
fmt.Sprintf("%v.sql", dbBackupHistoryUuid))
file, err := os.Open(fileName)
if err != nil {
return errors.Wrap(err, "打开备份文件失败")
@@ -191,7 +136,7 @@ func (svc *DbInstanceSvcImpl) RestoreBackup(ctx context.Context, database, fileN
_ = file.Close()
}()
cmd := exec.CommandContext(ctx, mysqlPath(), args...)
cmd := exec.CommandContext(ctx, svc.getMysqlBin().MysqlPath, args...)
cmd.Stdin = file
logx.Debug("恢复数据库: ", cmd.String())
if err := runCmd(cmd); err != nil {
@@ -201,42 +146,14 @@ func (svc *DbInstanceSvcImpl) RestoreBackup(ctx context.Context, database, fileN
return nil
}
func (svc *DbInstanceSvcImpl) Restore(ctx context.Context, task *entity.DbRestore) error {
if task.PointInTime.IsZero() {
backupHistory := &entity.DbBackupHistory{}
err := svc.backupHistoryRepo.GetById(backupHistory, task.DbBackupHistoryId)
if err != nil {
return err
}
fileName := filepath.Join(getDbBackupDir(backupHistory.DbInstanceId, backupHistory.DbBackupId),
fmt.Sprintf("%v.sql", backupHistory.Uuid))
return svc.RestoreBackup(ctx, task.DbName, fileName)
}
if err := svc.FetchBinlogs(ctx, true); err != nil {
return err
}
restoreInfo, err := svc.GetRestoreInfo(ctx, task.DbName, task.PointInTime)
if err != nil {
return err
}
fileName := filepath.Join(getDbBackupDir(restoreInfo.backupHistory.DbInstanceId, restoreInfo.backupHistory.DbBackupId),
fmt.Sprintf("%s.sql", restoreInfo.backupHistory.Uuid))
if err := svc.RestoreBackup(ctx, task.DbName, fileName); err != nil {
return err
}
return svc.ReplayBinlogToDatabase(ctx, task.DbName, task.DbName, restoreInfo)
}
// Download binlog files on server.
func (svc *DbInstanceSvcImpl) downloadBinlogFilesOnServer(ctx context.Context, binlogFilesOnServerSorted []*BinlogFile, downloadLatestBinlogFile bool) error {
func (svc *DbProgramMysql) downloadBinlogFilesOnServer(ctx context.Context, binlogFilesOnServerSorted []*entity.BinlogFile, downloadLatestBinlogFile bool) error {
if len(binlogFilesOnServerSorted) == 0 {
logx.Debug("No binlog file found on server to download")
return nil
}
if err := os.MkdirAll(getBinlogDir(svc.instanceId), os.ModePerm); err != nil {
return errors.Wrapf(err, "创建 binlog 目录失败: %q", getBinlogDir(svc.instanceId))
if err := os.MkdirAll(svc.getBinlogDir(svc.dbInfo().InstanceId), os.ModePerm); err != nil {
return errors.Wrapf(err, "创建 binlog 目录失败: %q", svc.getBinlogDir(svc.dbInfo().InstanceId))
}
latestBinlogFileOnServer := binlogFilesOnServerSorted[len(binlogFilesOnServerSorted)-1]
for _, fileOnServer := range binlogFilesOnServerSorted {
@@ -244,7 +161,7 @@ func (svc *DbInstanceSvcImpl) downloadBinlogFilesOnServer(ctx context.Context, b
if isLatest && !downloadLatestBinlogFile {
continue
}
binlogFilePath := filepath.Join(getBinlogDir(svc.instanceId), fileOnServer.Name)
binlogFilePath := filepath.Join(svc.getBinlogDir(svc.dbInfo().InstanceId), fileOnServer.Name)
logx.Debug("Downloading binlog file from MySQL server.", logx.String("path", binlogFilePath), logx.Bool("isLatest", isLatest))
if err := svc.downloadBinlogFile(ctx, fileOnServer, isLatest); err != nil {
logx.Error("下载 binlog 文件失败", logx.String("path", binlogFilePath), logx.String("error", err.Error()))
@@ -255,7 +172,7 @@ func (svc *DbInstanceSvcImpl) downloadBinlogFilesOnServer(ctx context.Context, b
}
// Parse the first binlog eventTs of a local binlog file.
func parseLocalBinlogFirstEventTime(ctx context.Context, filePath string) (eventTime time.Time, parseErr error) {
func (svc *DbProgramMysql) parseLocalBinlogFirstEventTime(ctx context.Context, filePath string) (eventTime time.Time, parseErr error) {
args := []string{
// Local binlog file path.
filePath,
@@ -264,7 +181,7 @@ func parseLocalBinlogFirstEventTime(ctx context.Context, filePath string) (event
// Tell mysqlbinlog to suppress the BINLOG statements for row events, which reduces the unneeded output.
"--base64-output=DECODE-ROWS",
}
cmd := exec.CommandContext(ctx, mysqlbinlogPath(), args...)
cmd := exec.CommandContext(ctx, svc.getMysqlBin().MysqlbinlogPath, args...)
var stderr strings.Builder
cmd.Stderr = &stderr
pr, err := cmd.StdoutPipe()
@@ -282,7 +199,7 @@ func parseLocalBinlogFirstEventTime(ctx context.Context, filePath string) (event
}
}()
for s := bufio.NewScanner(pr); ; s.Scan() {
for s := bufio.NewScanner(pr); s.Scan(); {
line := s.Text()
eventTimeParsed, found, err := parseBinlogEventTimeInLine(line)
if err != nil {
@@ -295,89 +212,58 @@ func parseLocalBinlogFirstEventTime(ctx context.Context, filePath string) (event
return time.Time{}, errors.New("解析 binlog 文件失败")
}
// getBinlogDir gets the binlogDir.
func getBinlogDir(instanceId uint64) string {
return filepath.Join(
config.GetDbBackupRestore().BackupPath,
fmt.Sprintf("instance-%d", instanceId),
"binlog")
}
func getDbInstanceBackupRoot(instanceId uint64) string {
return filepath.Join(
config.GetDbBackupRestore().BackupPath,
fmt.Sprintf("instance-%d", instanceId))
}
func getDbBackupDir(instanceId, backupId uint64) string {
return filepath.Join(
config.GetDbBackupRestore().BackupPath,
fmt.Sprintf("instance-%d", instanceId),
fmt.Sprintf("backup-%d", backupId))
}
var singleFlightGroup singleflight.Group
// FetchBinlogs downloads binlog files from startingFileName on server to `binlogDir`.
func (svc *DbInstanceSvcImpl) FetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool) error {
latestDownloaded := false
_, err, _ := singleFlightGroup.Do(strconv.FormatUint(svc.instanceId, 10), func() (interface{}, error) {
latestDownloaded = downloadLatestBinlogFile
err := svc.fetchBinlogs(ctx, downloadLatestBinlogFile)
return nil, err
func (svc *DbProgramMysql) FetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool, earliestBackupSequence, latestBinlogSequence int64) ([]*entity.BinlogFile, error) {
var downloaded bool
key := strconv.FormatUint(svc.dbInfo().InstanceId, 16)
binlogFiles, err, _ := singleFlightGroup.Do(key, func() (interface{}, error) {
downloaded = true
return svc.fetchBinlogs(ctx, downloadLatestBinlogFile, earliestBackupSequence, latestBinlogSequence)
})
if downloadLatestBinlogFile && !latestDownloaded {
_, err, _ = singleFlightGroup.Do(strconv.FormatUint(svc.instanceId, 10), func() (interface{}, error) {
err := svc.fetchBinlogs(ctx, true)
return nil, err
})
if err != nil {
return nil, err
}
return err
if downloaded {
return binlogFiles.([]*entity.BinlogFile), nil
}
if !downloadLatestBinlogFile {
return nil, nil
}
binlogFiles, err, _ = singleFlightGroup.Do(key, func() (interface{}, error) {
return svc.fetchBinlogs(ctx, true, earliestBackupSequence, latestBinlogSequence)
})
if err != nil {
return nil, err
}
return binlogFiles.([]*entity.BinlogFile), err
}
// fetchBinlogs downloads binlog files from startingFileName on server to `binlogDir`.
func (svc *DbInstanceSvcImpl) fetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool) error {
func (svc *DbProgramMysql) fetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool, earliestBackupSequence, latestBinlogSequence int64) ([]*entity.BinlogFile, error) {
// Read binlog files list on server.
binlogFilesOnServerSorted, err := svc.GetSortedBinlogFilesOnServer(ctx)
if err != nil {
return err
return nil, err
}
if len(binlogFilesOnServerSorted) == 0 {
logx.Debug("No binlog file found on server to download")
return nil
}
latest, ok, err := svc.binlogHistoryRepo.GetLatestHistory(svc.instanceId)
if err != nil {
return err
}
binlogFileName := ""
latestSequence := int64(-1)
earliestSequence := int64(-1)
if ok {
latestSequence = latest.Sequence
binlogFileName = latest.FileName
} else {
earliest, err := svc.backupHistoryRepo.GetEarliestHistory(svc.instanceId)
if err != nil {
return err
}
earliestSequence = earliest.BinlogSequence
binlogFileName = earliest.BinlogFileName
return nil, nil
}
indexHistory := -1
for i, file := range binlogFilesOnServerSorted {
if latestSequence == file.Sequence {
if latestBinlogSequence == file.Sequence {
indexHistory = i + 1
break
}
if earliestSequence == file.Sequence {
if earliestBackupSequence == file.Sequence {
indexHistory = i
break
}
}
if indexHistory < 0 {
return errors.New(fmt.Sprintf("在数据库服务器上未找到 binlog 文件 %q", binlogFileName))
return nil, errors.New(fmt.Sprintf("在数据库服务器上未找到 binlog 文件: %d, %d", earliestBackupSequence, latestBinlogSequence))
}
if indexHistory > len(binlogFilesOnServerSorted)-1 {
indexHistory = len(binlogFilesOnServerSorted) - 1
@@ -385,58 +271,36 @@ func (svc *DbInstanceSvcImpl) fetchBinlogs(ctx context.Context, downloadLatestBi
binlogFilesOnServerSorted = binlogFilesOnServerSorted[indexHistory:]
if err := svc.downloadBinlogFilesOnServer(ctx, binlogFilesOnServerSorted, downloadLatestBinlogFile); err != nil {
return err
}
for i, fileOnServer := range binlogFilesOnServerSorted {
if !fileOnServer.Downloaded {
break
}
history := &entity.DbBinlogHistory{
CreateTime: time.Now(),
FileName: fileOnServer.Name,
FileSize: fileOnServer.Size,
Sequence: fileOnServer.Sequence,
FirstEventTime: fileOnServer.FirstEventTime,
DbInstanceId: svc.instanceId,
}
if i == len(binlogFilesOnServerSorted)-1 {
if err := svc.binlogHistoryRepo.Upsert(ctx, history); err != nil {
return err
}
} else {
if err := svc.binlogHistoryRepo.Insert(ctx, history); err != nil {
return err
}
}
return nil, err
}
return nil
return binlogFilesOnServerSorted, nil
}
// Syncs the binlog specified by `meta` between the instance and local.
// If isLast is true, it means that this is the last binlog file containing the targetTs event.
// It may keep growing as there are ongoing writes to the database. So we just need to check that
// the file size is larger or equal to the binlog file size we queried from the MySQL server earlier.
func (svc *DbInstanceSvcImpl) downloadBinlogFile(ctx context.Context, binlogFileToDownload *BinlogFile, isLast bool) error {
tempBinlogPrefix := filepath.Join(getBinlogDir(svc.instanceId), "tmp-")
func (svc *DbProgramMysql) downloadBinlogFile(ctx context.Context, binlogFileToDownload *entity.BinlogFile, isLast bool) error {
tempBinlogPrefix := filepath.Join(svc.getBinlogDir(svc.dbInfo().InstanceId), "tmp-")
args := []string{
binlogFileToDownload.Name,
"--read-from-remote-server",
// Verify checksum binlog events.
"--verify-binlog-checksum",
"--host", svc.dbInfo.Host,
"--port", strconv.Itoa(svc.dbInfo.Port),
"--user", svc.dbInfo.Username,
"--host", svc.dbInfo().Host,
"--port", strconv.Itoa(svc.dbInfo().Port),
"--user", svc.dbInfo().Username,
"--raw",
// With --raw this is a prefix for the file names.
"--result-file", tempBinlogPrefix,
}
cmd := exec.CommandContext(ctx, mysqlbinlogPath(), args...)
cmd := exec.CommandContext(ctx, svc.getMysqlBin().MysqlbinlogPath, args...)
// We cannot set password as a flag. Otherwise, there is warning message
// "mysqlbinlog: [Warning] Using a password on the command line interface can be insecure."
if svc.dbInfo.Password != "" {
cmd.Env = append(cmd.Env, fmt.Sprintf("MYSQL_PWD=%s", svc.dbInfo.Password))
if svc.dbInfo().Password != "" {
cmd.Env = append(cmd.Env, fmt.Sprintf("MYSQL_PWD=%s", svc.dbInfo().Password))
}
logx.Debug("Downloading binlog files using mysqlbinlog:", cmd.String())
@@ -464,11 +328,11 @@ func (svc *DbInstanceSvcImpl) downloadBinlogFile(ctx context.Context, binlogFile
return errors.Errorf("下载的 binlog 文件 %q 与服务上的文件大小不一致 %d != %d", binlogFilePathTemp, binlogFileTempInfo.Size(), binlogFileToDownload.Size)
}
binlogFilePath := svc.getBinlogFilePath(binlogFileToDownload.Name)
binlogFilePath := svc.GetBinlogFilePath(binlogFileToDownload.Name)
if err := os.Rename(binlogFilePathTemp, binlogFilePath); err != nil {
return errors.Wrapf(err, "binlog 文件更名失败: %q -> %q", binlogFilePathTemp, binlogFilePath)
}
firstEventTime, err := parseLocalBinlogFirstEventTime(ctx, binlogFilePath)
firstEventTime, err := svc.parseLocalBinlogFirstEventTime(ctx, binlogFilePath)
if err != nil {
return err
}
@@ -479,14 +343,9 @@ func (svc *DbInstanceSvcImpl) downloadBinlogFile(ctx context.Context, binlogFile
}
// GetSortedBinlogFilesOnServer returns the information of binlog files in ascending order by their numeric extension.
func (svc *DbInstanceSvcImpl) GetSortedBinlogFilesOnServer(_ context.Context) ([]*BinlogFile, error) {
conn, err := svc.dbInfo.Conn()
if err != nil {
return nil, err
}
defer conn.Close()
func (svc *DbProgramMysql) GetSortedBinlogFilesOnServer(_ context.Context) ([]*entity.BinlogFile, error) {
query := "SHOW BINARY LOGS"
columns, rows, err := conn.Query(query)
columns, rows, err := svc.dbConn.Query(query)
if err != nil {
return nil, errors.Wrapf(err, "SQL 语句 %q 执行失败", query)
}
@@ -504,7 +363,7 @@ func (svc *DbInstanceSvcImpl) GetSortedBinlogFilesOnServer(_ context.Context) ([
return nil, errors.Errorf("SQL 语句 %q 执行结果解析失败", query)
}
var binlogFiles []*BinlogFile
var binlogFiles []*entity.BinlogFile
for _, row := range rows {
name, nameOk := row["Log_name"].(string)
@@ -512,11 +371,15 @@ func (svc *DbInstanceSvcImpl) GetSortedBinlogFilesOnServer(_ context.Context) ([
if !nameOk || !sizeOk {
return nil, errors.Errorf("SQL 语句 %q 执行结果解析失败", query)
}
binlogFile, err := newBinlogFile(name, int64(size))
_, seq, err := ParseBinlogName(name)
if err != nil {
return nil, errors.Wrapf(err, "SQL 语句 %q 执行结果解析失败", query)
}
binlogFile := &entity.BinlogFile{
Name: name,
Size: int64(size),
Sequence: seq,
}
binlogFiles = append(binlogFiles, binlogFile)
}
@@ -566,18 +429,19 @@ func readBinlogInfoFromBackup(reader io.Reader) (*entity.BinlogInfo, error) {
}
// Use command like mysqlbinlog --start-datetime=targetTs binlog.000001 to parse the first binlog event position with timestamp equal or after targetTs.
func getBinlogEventPositionAtOrAfterTime(ctx context.Context, filePath string, targetTime time.Time) (position int64, parseErr error) {
func (svc *DbProgramMysql) GetBinlogEventPositionAtOrAfterTime(ctx context.Context, binlogName string, targetTime time.Time) (position int64, parseErr error) {
binlogPath := svc.GetBinlogFilePath(binlogName)
args := []string{
// Local binlog file path.
filePath,
binlogPath,
// Verify checksum binlog events.
"--verify-binlog-checksum",
// Tell mysqlbinlog to suppress the BINLOG statements for row events, which reduces the unneeded output.
"--base64-output=DECODE-ROWS",
// Instruct mysqlbinlog to start output only after encountering the first binlog event with timestamp equal or after targetTime.
"--start-datetime", formatDateTime(targetTime),
"--start-datetime", targetTime.Local().Format(time.DateTime),
}
cmd := exec.CommandContext(ctx, mysqlbinlogPath(), args...)
cmd := exec.CommandContext(ctx, svc.getMysqlBin().MysqlbinlogPath, args...)
var stderr strings.Builder
cmd.Stderr = &stderr
pr, err := cmd.StdoutPipe()
@@ -594,7 +458,7 @@ func getBinlogEventPositionAtOrAfterTime(ctx context.Context, filePath string, t
}
}()
for s := bufio.NewScanner(pr); ; s.Scan() {
for s := bufio.NewScanner(pr); s.Scan(); {
line := s.Text()
posParsed, found, err := parseBinlogEventPosInLine(line)
if err != nil {
@@ -605,11 +469,11 @@ func getBinlogEventPositionAtOrAfterTime(ctx context.Context, filePath string, t
return posParsed, nil
}
}
return 0, errors.Errorf("在 %v 之后没有 binlog 事件", targetTime)
return 0, errors.Errorf("在 %s 之后没有 binlog 事件", targetTime.Format(time.DateTime))
}
// replayBinlog replays the binlog for `originDatabase` from `startBinlogInfo.Position` to `targetTs`, read binlog from `binlogDir`.
func (svc *DbInstanceSvcImpl) replayBinlog(ctx context.Context, originalDatabase, targetDatabase string, restoreInfo *RestoreInfo) (replayErr error) {
// ReplayBinlog replays the binlog for `originDatabase` from `startBinlogInfo.Position` to `targetTs`, read binlog from `binlogDir`.
func (svc *DbProgramMysql) ReplayBinlog(ctx context.Context, originalDatabase, targetDatabase string, restoreInfo *RestoreInfo) (replayErr error) {
const (
// Variable lower_case_table_names related.
@@ -657,27 +521,27 @@ func (svc *DbInstanceSvcImpl) replayBinlog(ctx context.Context, originalDatabase
// List entries for just this database. It's applied after the --rewrite-db option, so we should provide the rewritten database, i.e., pitrDatabase.
"--database", targetDatabase,
// Decode binary log from first event with position equal to or greater than argument.
"--start-position", fmt.Sprintf("%d", restoreInfo.startPosition),
"--start-position", fmt.Sprintf("%d", restoreInfo.StartPosition),
// Stop decoding binary log at first event with position equal to or greater than argument.
"--stop-position", fmt.Sprintf("%d", restoreInfo.targetPosition),
"--stop-position", fmt.Sprintf("%d", restoreInfo.TargetPosition),
}
mysqlbinlogArgs = append(mysqlbinlogArgs, restoreInfo.getBinlogFiles(getBinlogDir(svc.instanceId))...)
mysqlbinlogArgs = append(mysqlbinlogArgs, restoreInfo.GetBinlogPaths(svc.getBinlogDir(svc.dbInfo().InstanceId))...)
mysqlArgs := []string{
"--host", svc.dbInfo.Host,
"--port", strconv.Itoa(svc.dbInfo.Port),
"--user", svc.dbInfo.Username,
"--host", svc.dbInfo().Host,
"--port", strconv.Itoa(svc.dbInfo().Port),
"--user", svc.dbInfo().Username,
}
if svc.dbInfo.Password != "" {
if svc.dbInfo().Password != "" {
// The --password parameter of mysql/mysqlbinlog does not support the "--password PASSWORD" format (split by space).
// If provided like that, the program will hang.
mysqlArgs = append(mysqlArgs, fmt.Sprintf("--password=%s", svc.dbInfo.Password))
mysqlArgs = append(mysqlArgs, fmt.Sprintf("--password=%s", svc.dbInfo().Password))
}
mysqlbinlogCmd := exec.CommandContext(ctx, mysqlbinlogPath(), mysqlbinlogArgs...)
mysqlCmd := exec.CommandContext(ctx, mysqlPath(), mysqlArgs...)
mysqlbinlogCmd := exec.CommandContext(ctx, svc.getMysqlBin().MysqlbinlogPath, mysqlbinlogArgs...)
mysqlCmd := exec.CommandContext(ctx, svc.getMysqlBin().MysqlPath, mysqlArgs...)
logx.Debug("Start replay binlog commands.",
logx.String("mysqlbinlog", mysqlbinlogCmd.String()),
logx.String("mysql", mysqlCmd.String()))
@@ -717,26 +581,15 @@ func (svc *DbInstanceSvcImpl) replayBinlog(ctx context.Context, originalDatabase
return errors.Wrap(err, "启动 mysql 程序失败")
}
if err := mysqlCmd.Wait(); err != nil {
return errors.Errorf("运行 mysql 程序失败: %s", mysqlbinlogErr.String())
return errors.Errorf("运行 mysql 程序失败: %s", mysqlErr.String())
}
return nil
}
// ReplayBinlogToDatabase replays the binlog of originDatabaseName to the targetDatabaseName.
func (svc *DbInstanceSvcImpl) ReplayBinlogToDatabase(ctx context.Context, originDatabaseName, targetDatabaseName string, restoreInfo *RestoreInfo) error {
return svc.replayBinlog(ctx, originDatabaseName, targetDatabaseName, restoreInfo)
}
func (svc *DbInstanceSvcImpl) getServerVariable(_ context.Context, varName string) (string, error) {
conn, err := svc.dbInfo.Conn()
if err != nil {
return "", err
}
defer conn.Close()
func (svc *DbProgramMysql) getServerVariable(_ context.Context, varName string) (string, error) {
query := fmt.Sprintf("SHOW VARIABLES LIKE '%s'", varName)
_, rows, err := conn.Query(query)
_, rows, err := svc.dbConn.Query(query)
if err != nil {
return "", err
}
@@ -754,7 +607,7 @@ func (svc *DbInstanceSvcImpl) getServerVariable(_ context.Context, varName strin
}
// CheckBinlogEnabled checks whether binlog is enabled for the current instance.
func (svc *DbInstanceSvcImpl) CheckBinlogEnabled(ctx context.Context) error {
func (svc *DbProgramMysql) CheckBinlogEnabled(ctx context.Context) error {
value, err := svc.getServerVariable(ctx, "log_bin")
if err != nil {
return err
@@ -766,7 +619,7 @@ func (svc *DbInstanceSvcImpl) CheckBinlogEnabled(ctx context.Context) error {
}
// CheckBinlogRowFormat checks whether the binlog format is ROW.
func (svc *DbInstanceSvcImpl) CheckBinlogRowFormat(ctx context.Context) error {
func (svc *DbProgramMysql) CheckBinlogRowFormat(ctx context.Context) error {
value, err := svc.getServerVariable(ctx, "binlog_format")
if err != nil {
return err
@@ -790,19 +643,19 @@ func runCmd(cmd *exec.Cmd) error {
return nil
}
func (svc *DbInstanceSvcImpl) execute(database string, sql string) error {
func (svc *DbProgramMysql) execute(database string, sql string) error {
args := []string{
"--host", svc.dbInfo.Host,
"--port", strconv.Itoa(svc.dbInfo.Port),
"--user", svc.dbInfo.Username,
"--password=" + svc.dbInfo.Password,
"--host", svc.dbInfo().Host,
"--port", strconv.Itoa(svc.dbInfo().Port),
"--user", svc.dbInfo().Username,
"--password=" + svc.dbInfo().Password,
"--execute", sql,
}
if len(database) > 0 {
args = append(args, database)
}
cmd := exec.Command(mysqlPath(), args...)
cmd := exec.Command(svc.getMysqlBin().MysqlPath, args...)
logx.Debug("execute sql using mysql binary: ", cmd.String())
if err := runCmd(cmd); err != nil {
logx.Errorf("运行 mysql 程序失败: %v", err)
@@ -814,8 +667,8 @@ func (svc *DbInstanceSvcImpl) execute(database string, sql string) error {
// sortBinlogFiles will sort binlog files in ascending order by their numeric extension.
// For mysql binlog, after the serial number reaches 999999, the next serial number will not return to 000000, but 1000000,
// so we cannot directly use string to compare lexicographical order.
func sortBinlogFiles(binlogFiles []*BinlogFile) []*BinlogFile {
var sorted []*BinlogFile
func sortBinlogFiles(binlogFiles []*entity.BinlogFile) []*entity.BinlogFile {
var sorted []*entity.BinlogFile
sorted = append(sorted, binlogFiles...)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Sequence < sorted[j].Sequence
@@ -879,20 +732,23 @@ func ParseBinlogName(name string) (string, int64, error) {
return s[0], seq, nil
}
// formatDateTime formats the timestamp to the local time string.
func formatDateTime(t time.Time) string {
t = t.Local()
return fmt.Sprintf("%d-%d-%d %d:%d:%d", t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second())
// getBinlogDir gets the binlogDir.
func (svc *DbProgramMysql) getBinlogDir(instanceId uint64) string {
return filepath.Join(
svc.getBackupPath(),
fmt.Sprintf("instance-%d", instanceId),
"binlog")
}
func mysqlPath() string {
return config.GetMysqlBin().MysqlPath
func (svc *DbProgramMysql) getDbInstanceBackupRoot(instanceId uint64) string {
return filepath.Join(
svc.getBackupPath(),
fmt.Sprintf("instance-%d", instanceId))
}
func mysqldumpPath() string {
return config.GetMysqlBin().MysqldumpPath
}
func mysqlbinlogPath() string {
return config.GetMysqlBin().MysqlbinlogPath
func (svc *DbProgramMysql) getDbBackupDir(instanceId, backupId uint64) string {
return filepath.Join(
svc.getBackupPath(),
fmt.Sprintf("instance-%d", instanceId),
fmt.Sprintf("backup-%d", backupId))
}

View File

@@ -1,19 +1,19 @@
//go:build e2e
package service
package dbm
import (
"context"
"errors"
"fmt"
"github.com/stretchr/testify/suite"
"mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/config"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/internal/db/infrastructure/persistence"
"mayfly-go/pkg/config"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
@@ -30,23 +30,25 @@ const (
type DbInstanceSuite struct {
suite.Suite
instance *entity.DbInstance
repositories *repository.Repositories
instanceSvc *DbInstanceSvcImpl
instanceSvc *DbProgramMysql
dbConn *DbConn
}
func (s *DbInstanceSuite) SetupSuite() {
if err := chdir("mayfly-go", "server"); err != nil {
panic(err)
}
config.Init()
s.instance = &entity.DbInstance{
Type: dbm.DbTypeMysql,
dbInfo := DbInfo{
Type: DbTypeMysql,
Host: "localhost",
Port: 3306,
Username: "test",
Password: "123456",
Password: "test",
}
dbConn, err := dbInfo.Conn()
s.Require().NoError(err)
s.dbConn = dbConn
s.repositories = &repository.Repositories{
Instance: persistence.GetInstanceRepo(),
Backup: persistence.NewDbBackupRepo(),
@@ -56,7 +58,26 @@ func (s *DbInstanceSuite) SetupSuite() {
Binlog: persistence.NewDbBinlogRepo(),
BinlogHistory: persistence.NewDbBinlogHistoryRepo(),
}
s.instanceSvc = NewDbInstanceSvc(s.instance, s.repositories)
s.instanceSvc = NewDbProgramMysql(s.dbConn)
var extName string
if runtime.GOOS == "windows" {
extName = ".exe"
}
path := "db/mysql/bin"
s.instanceSvc.mysqlBin = &config.MysqlBin{
Path: filepath.Join(path),
MysqlPath: filepath.Join(path, "mysql"+extName),
MysqldumpPath: filepath.Join(path, "mysqldump"+extName),
MysqlbinlogPath: filepath.Join(path, "mysqlbinlog"+extName),
}
s.instanceSvc.backupPath = "db/backup"
}
func (s *DbInstanceSuite) TearDownSuite() {
if s.dbConn != nil {
s.dbConn.Close()
s.dbConn = nil
}
}
func (s *DbInstanceSuite) SetupTest() {
@@ -72,7 +93,7 @@ func (s *DbInstanceSuite) TearDownTest() {
sql := fmt.Sprintf("drop database if exists `%s`", dbNameBackupTest)
require.NoError(s.instanceSvc.execute("", sql))
_ = os.RemoveAll(getDbInstanceBackupRoot(instanceIdTest))
_ = os.RemoveAll(s.instanceSvc.getDbInstanceBackupRoot(instanceIdTest))
}
func (s *DbInstanceSuite) TestBackup() {
@@ -89,7 +110,7 @@ func (s *DbInstanceSuite) testBackup(backupHistory *entity.DbBackupHistory) {
binlogInfo, err := s.instanceSvc.Backup(context.Background(), backupHistory)
require.NoError(err)
fileName := filepath.Join(getDbBackupDir(s.instance.Id, backupHistory.Id), dbNameBackupTest+".sql")
fileName := filepath.Join(s.instanceSvc.getDbBackupDir(s.dbConn.Info.InstanceId, backupHistory.Id), dbNameBackupTest+".sql")
_, err = os.Stat(fileName)
require.NoError(err)
@@ -165,7 +186,7 @@ func (s *DbInstanceSuite) testReplayBinlog(backupHistory *entity.DbBackupHistory
require.NoError(err)
binlogFileLast := binlogFilesOnServerSorted[len(binlogFilesOnServerSorted)-1]
position, err := getBinlogEventPositionAtOrAfterTime(context.Background(), s.instanceSvc.getBinlogFilePath(binlogFileLast.Name), targetTime)
position, err := s.instanceSvc.GetBinlogEventPositionAtOrAfterTime(context.Background(), binlogFileLast.Name, targetTime)
require.NoError(err)
binlogHistories := make([]*entity.DbBinlogHistory, 0, 2)
binlogHistoryBackup := &entity.DbBinlogHistory{
@@ -183,21 +204,19 @@ func (s *DbInstanceSuite) testReplayBinlog(backupHistory *entity.DbBackupHistory
}
restoreInfo := &RestoreInfo{
backupHistory: backupHistory,
binlogHistories: binlogHistories,
startPosition: backupHistory.BinlogPosition,
targetPosition: position,
targetTime: targetTime,
BackupHistory: backupHistory,
BinlogHistories: binlogHistories,
StartPosition: backupHistory.BinlogPosition,
TargetPosition: position,
TargetTime: targetTime,
}
err = s.instanceSvc.ReplayBinlogToDatabase(context.Background(), dbNameBackupTest, dbNameBackupTest, restoreInfo)
err = s.instanceSvc.ReplayBinlog(context.Background(), dbNameBackupTest, dbNameBackupTest, restoreInfo)
require.NoError(err)
}
func (s *DbInstanceSuite) testRestore(backupHistory *entity.DbBackupHistory) {
require := s.Require()
fileName := filepath.Join(getDbBackupDir(instanceIdTest, backupIdTest),
fmt.Sprintf("%v.sql", dbNameBackupTest))
err := s.instanceSvc.RestoreBackup(context.Background(), dbNameBackupTest, fileName)
err := s.instanceSvc.RestoreBackupHistory(context.Background(), backupHistory.DbName, backupHistory.DbBackupId, backupHistory.Uuid)
require.NoError(err)
}

View File

@@ -1,4 +1,4 @@
package service
package dbm
import (
"github.com/stretchr/testify/require"

View File

@@ -16,6 +16,14 @@ const (
DM DbType = "dm"
)
func ToDbType(dbType string) DbType {
return DbType(dbType)
}
func (dbType DbType) Equal(typ string) bool {
return ToDbType(typ) == dbType
}
func (dbType DbType) MetaDbName() string {
switch dbType {
case DbTypeMysql:

View File

@@ -79,6 +79,9 @@ type DbDialect interface {
WalkTableRecord(tableName string, walk func(record map[string]any, columns []*QueryColumn)) error
GetSchemas() ([]string, error)
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
GetDbProgram() DbProgram
}
// ------------------------- 元数据sql操作 -------------------------

View File

@@ -68,8 +68,8 @@ func (dd *DMDialect) GetDbServer() (*DbServer, error) {
return ds, nil
}
func (pd *DMDialect) GetDbNames() ([]string, error) {
_, res, err := pd.dc.Query("SELECT name AS DBNAME FROM v$database")
func (dd *DMDialect) GetDbNames() ([]string, error) {
_, res, err := dd.dc.Query("SELECT name AS DBNAME FROM v$database")
if err != nil {
return nil, err
}
@@ -83,13 +83,13 @@ func (pd *DMDialect) GetDbNames() ([]string, error) {
}
// 获取表基础元信息, 如表名等
func (pd *DMDialect) GetTables() ([]Table, error) {
func (dd *DMDialect) GetTables() ([]Table, error) {
// 首先执行更新统计信息sql 这个统计信息在数据量比较大的时候就比较耗时,所以最好定时执行
// _, _, err := pd.dc.Query("dbms_stats.GATHER_SCHEMA_stats(SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID))")
// 查询表信息
_, res, err := pd.dc.Query(GetLocalSql(DM_META_FILE, DM_TABLE_INFO_KEY))
_, res, err := dd.dc.Query(GetLocalSql(DM_META_FILE, DM_TABLE_INFO_KEY))
if err != nil {
return nil, err
}
@@ -109,7 +109,7 @@ func (pd *DMDialect) GetTables() ([]Table, error) {
}
// 获取列元信息, 如列名等
func (pd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) {
func (dd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) {
tableName := ""
for i := 0; i < len(tableNames); i++ {
if i != 0 {
@@ -118,7 +118,7 @@ func (pd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) {
tableName = tableName + "'" + tableNames[i] + "'"
}
_, res, err := pd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName))
_, res, err := dd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName))
if err != nil {
return nil, err
}
@@ -139,8 +139,8 @@ func (pd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) {
return columns, nil
}
func (pd *DMDialect) GetPrimaryKey(tablename string) (string, error) {
columns, err := pd.GetColumns(tablename)
func (dd *DMDialect) GetPrimaryKey(tablename string) (string, error) {
columns, err := dd.GetColumns(tablename)
if err != nil {
return "", err
}
@@ -157,8 +157,8 @@ func (pd *DMDialect) GetPrimaryKey(tablename string) (string, error) {
}
// 获取表索引信息
func (pd *DMDialect) GetTableIndex(tableName string) ([]Index, error) {
_, res, err := pd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_INDEX_INFO_KEY), tableName))
func (dd *DMDialect) GetTableIndex(tableName string) ([]Index, error) {
_, res, err := dd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_INDEX_INFO_KEY), tableName))
if err != nil {
return nil, err
}
@@ -194,9 +194,9 @@ func (pd *DMDialect) GetTableIndex(tableName string) ([]Index, error) {
}
// 获取建表ddl
func (pd *DMDialect) GetTableDDL(tableName string) (string, error) {
func (dd *DMDialect) GetTableDDL(tableName string) (string, error) {
ddlSql := fmt.Sprintf("CALL SP_TABLEDEF((SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID)), '%s')", tableName)
_, res, err := pd.dc.Query(ddlSql)
_, res, err := dd.dc.Query(ddlSql)
if err != nil {
return "", err
}
@@ -207,7 +207,7 @@ func (pd *DMDialect) GetTableDDL(tableName string) (string, error) {
}
// 表注释
_, res, err = pd.dc.Query(fmt.Sprintf(`
_, res, err = dd.dc.Query(fmt.Sprintf(`
select OWNER, COMMENTS from DBA_TAB_COMMENTS where TABLE_TYPE='TABLE' and TABLE_NAME = '%s'
and owner = (SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID))
`, tableName))
@@ -229,7 +229,7 @@ func (pd *DMDialect) GetTableDDL(tableName string) (string, error) {
WHERE OWNER = (SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID))
AND TABLE_NAME = '%s'
`, tableName)
_, res, err = pd.dc.Query(fieldSql)
_, res, err = dd.dc.Query(fieldSql)
if err != nil {
return "", err
}
@@ -251,7 +251,7 @@ func (pd *DMDialect) GetTableDDL(tableName string) (string, error) {
and a.table_name = '%s'
and indexdef(b.object_id,1) != '禁止查看系统定义的索引信息'
`, tableName)
_, res, err = pd.dc.Query(indexSql)
_, res, err = dd.dc.Query(indexSql)
if err != nil {
return "", err
}
@@ -262,18 +262,18 @@ func (pd *DMDialect) GetTableDDL(tableName string) (string, error) {
return builder.String(), nil
}
func (pd *DMDialect) GetTableRecord(tableName string, pageNum, pageSize int) ([]*QueryColumn, []map[string]any, error) {
return pd.dc.Query(fmt.Sprintf("SELECT * FROM %s OFFSET %d LIMIT %d", tableName, (pageNum-1)*pageSize, pageSize))
func (dd *DMDialect) GetTableRecord(tableName string, pageNum, pageSize int) ([]*QueryColumn, []map[string]any, error) {
return dd.dc.Query(fmt.Sprintf("SELECT * FROM %s OFFSET %d LIMIT %d", tableName, (pageNum-1)*pageSize, pageSize))
}
func (pd *DMDialect) WalkTableRecord(tableName string, walk func(record map[string]any, columns []*QueryColumn)) error {
return pd.dc.WalkTableRecord(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walk)
func (dd *DMDialect) WalkTableRecord(tableName string, walk func(record map[string]any, columns []*QueryColumn)) error {
return dd.dc.WalkTableRecord(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walk)
}
// 获取DM当前连接的库可访问的schemaNames
func (pd *DMDialect) GetSchemas() ([]string, error) {
func (dd *DMDialect) GetSchemas() ([]string, error) {
sql := GetLocalSql(DM_META_FILE, DM_DB_SCHEMAS)
_, res, err := pd.dc.Query(sql)
_, res, err := dd.dc.Query(sql)
if err != nil {
return nil, err
}
@@ -283,3 +283,8 @@ func (pd *DMDialect) GetSchemas() ([]string, error) {
}
return schemaNames, nil
}
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
func (dd *DMDialect) GetDbProgram() DbProgram {
panic("implement me")
}

View File

@@ -194,6 +194,11 @@ func (md *MysqlDialect) WalkTableRecord(tableName string, walk func(record map[s
return md.dc.WalkTableRecord(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walk)
}
func (pd *MysqlDialect) GetSchemas() ([]string, error) {
func (md *MysqlDialect) GetSchemas() ([]string, error) {
return nil, nil
}
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
func (md *MysqlDialect) GetDbProgram() DbProgram {
return NewDbProgramMysql(md.dc)
}

View File

@@ -277,3 +277,8 @@ func (pd *PgsqlDialect) GetSchemas() ([]string, error) {
}
return schemaNames, nil
}
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
func (pd *PgsqlDialect) GetDbProgram() DbProgram {
panic("implement me")
}

View File

@@ -1,77 +1,39 @@
package entity
import (
"mayfly-go/pkg/model"
"time"
)
var _ DbTask = (*DbBackup)(nil)
// DbBackup 数据库备份任务
type DbBackup struct {
model.Model
*DbTaskBase
Name string `gorm:"column(db_name)" json:"name"` // 备份任务名称
DbName string `gorm:"column(db_name)" json:"dbName"` // 数据库名
StartTime time.Time `gorm:"column(start_time)" json:"startTime"` // 开始时间: 2023-11-08 02:00:00
Interval time.Duration `gorm:"column(interval)" json:"interval"` // 间隔时间: 为零表示单次执行,为正表示反复执行
Enabled bool `gorm:"column(enabled)" json:"enabled"` // 是否启用
Finished bool `gorm:"column(finished)" json:"finished"` // 是否完成
Repeated bool `gorm:"column(repeated)" json:"repeated"` // 是否重复执行
LastStatus TaskStatus `gorm:"column(last_status)" json:"lastStatus"` // 最近一次执行状态
LastResult string `gorm:"column(last_result)" json:"lastResult"` // 最近一次执行结果
LastTime time.Time `gorm:"column(last_time)" json:"lastTime"` // 最近一次执行时间: 2023-11-08 02:00:00
DbInstanceId uint64 `gorm:"column(db_instance_id)" json:"dbInstanceId"`
Deadline time.Time `gorm:"-" json:"-"`
Name string `json:"name"` // 备份任务名称
DbName string `json:"dbName"` // 数据库名
DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
}
func (d *DbBackup) TableName() string {
return "t_db_backup"
}
func (d *DbBackup) GetId() uint64 {
if d == nil {
return 0
var (
backupResult = map[TaskStatus]string{
TaskDelay: "等待备份数据库",
TaskReady: "准备备份数据库",
TaskReserved: "数据库备份中",
TaskSuccess: "数据库备份成功",
TaskFailed: "数据库备份失败",
}
return d.Id
}
)
func (d *DbBackup) GetDeadline() time.Time {
return d.Deadline
}
func (d *DbBackup) Schedule() bool {
if d.Finished || !d.Enabled {
return false
}
switch d.LastStatus {
func (*DbBackup) MessageWithStatus(status TaskStatus) string {
var result string
switch status {
case TaskDelay:
result = "等待备份数据库"
case TaskReady:
result = "准备备份数据库"
case TaskReserved:
result = "数据库备份中"
case TaskSuccess:
if d.Interval == 0 {
return false
}
lastTime := d.LastTime
if d.LastTime.Sub(d.StartTime) < 0 {
lastTime = d.StartTime.Add(-d.Interval)
}
d.Deadline = lastTime.Add(d.Interval - d.LastTime.Sub(d.StartTime)%d.Interval)
result = "数据库备份成功"
case TaskFailed:
d.Deadline = time.Now().Add(time.Minute)
default:
d.Deadline = d.StartTime
result = "数据库备份失败"
}
return true
}
func (d *DbBackup) IsFinished() bool {
return !d.Repeated && d.LastStatus == TaskSuccess
}
func (d *DbBackup) Update(task DbTask) bool {
switch t := task.(type) {
case *DbBackup:
d.StartTime = t.StartTime
d.Interval = t.Interval
return true
}
return false
return result
}

View File

@@ -2,86 +2,34 @@ package entity
import (
"mayfly-go/pkg/model"
"mayfly-go/pkg/utils/timex"
"time"
)
const BinlogDownloadInterval = time.Minute * 15
var _ DbTask = (*DbBinlog)(nil)
// DbBinlog 数据库备份任务
type DbBinlog struct {
model.Model
StartTime time.Time `gorm:"column(start_time)" json:"startTime"` // 开始时间: 2023-11-08 02:00:00
Interval time.Duration `gorm:"column(interval)" json:"interval"` // 间隔时间: 为零表示单次执行,为正表示反复执行
Enabled bool `gorm:"column(enabled)" json:"enabled"` // 是否启用
LastStatus TaskStatus `gorm:"column(last_status)" json:"lastStatus"` // 最近一次执行状态
LastResult string `gorm:"column(last_result)" json:"lastResult"` // 最近一次执行结果
LastTime time.Time `gorm:"column(last_time)" json:"lastTime"` // 最近一次执行时间: 2023-11-08 02:00:00
DbInstanceId uint64 `gorm:"column(db_instance_id)" json:"dbInstanceId"`
Deadline time.Time `gorm:"-" json:"-"`
LastStatus TaskStatus // 最近一次执行状态
LastResult string // 最近一次执行结果
LastTime timex.NullTime // 最近一次执行时间
DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
}
func NewDbBinlog(history *DbBackupHistory) *DbBinlog {
binlogTask := &DbBinlog{
StartTime: time.Now(),
Enabled: true,
Interval: BinlogDownloadInterval,
DbInstanceId: history.DbInstanceId,
LastTime: time.Now(),
}
binlogTask.Id = binlogTask.DbInstanceId
func NewDbBinlog(instanceId uint64) *DbBinlog {
binlogTask := &DbBinlog{}
binlogTask.Id = instanceId
binlogTask.DbInstanceId = instanceId
return binlogTask
}
func (d *DbBinlog) TableName() string {
return "t_db_binlog"
}
// BinlogFile is the metadata of the MySQL binlog file.
type BinlogFile struct {
Name string
Size int64
func (d *DbBinlog) GetId() uint64 {
if d == nil {
return 0
}
return d.Id
}
func (d *DbBinlog) GetDeadline() time.Time {
return d.Deadline
}
func (d *DbBinlog) Schedule() bool {
if !d.Enabled {
return false
}
switch d.LastStatus {
case TaskSuccess:
if d.Interval == 0 {
return false
}
lastTime := d.LastTime
if d.LastTime.Sub(d.StartTime) < 0 {
lastTime = d.StartTime.Add(-d.Interval)
}
d.Deadline = lastTime.Add(d.Interval - d.LastTime.Sub(d.StartTime)%d.Interval)
case TaskFailed:
d.Deadline = time.Now().Add(time.Minute)
default:
d.Deadline = d.StartTime
}
return true
}
func (d *DbBinlog) IsFinished() bool {
return false
}
func (d *DbBinlog) Update(task DbTask) bool {
switch t := task.(type) {
case *DbBinlog:
d.StartTime = t.StartTime
d.Interval = t.Interval
return true
}
return false
// Sequence is parsed from Name and is for the sorting purpose.
Sequence int64
FirstEventTime time.Time
Downloaded bool
}

View File

@@ -1,25 +1,25 @@
package entity
import (
"errors"
"fmt"
"mayfly-go/internal/common/utils"
"mayfly-go/internal/db/dbm"
"mayfly-go/pkg/model"
)
type DbInstance struct {
model.Model
Name string `orm:"column(name)" json:"name"`
Type dbm.DbType `orm:"column(type)" json:"type"` // 类型mysql oracle等
Host string `orm:"column(host)" json:"host"`
Port int `orm:"column(port)" json:"port"`
Network string `orm:"column(network)" json:"network"`
Username string `orm:"column(username)" json:"username"`
Password string `orm:"column(password)" json:"-"`
Params string `orm:"column(params)" json:"params"`
Remark string `orm:"column(remark)" json:"remark"`
SshTunnelMachineId int `orm:"column(ssh_tunnel_machine_id)" json:"sshTunnelMachineId"` // ssh隧道机器id
Name string `json:"name"`
Type string `json:"type"` // 类型mysql oracle等
Host string `json:"host"`
Port int `json:"port"`
Network string `json:"network"`
Username string `json:"username"`
Password string `json:"-"`
Params string `json:"params"`
Remark string `json:"remark"`
SshTunnelMachineId int `json:"sshTunnelMachineId"` // ssh隧道机器id
}
func (d *DbInstance) TableName() string {
@@ -39,12 +39,22 @@ func (d *DbInstance) GetNetwork() string {
return fmt.Sprintf("%s+ssh:%d", d.Type, d.SshTunnelMachineId)
}
func (d *DbInstance) PwdEncrypt() {
func (d *DbInstance) PwdEncrypt() error {
// 密码替换为加密后的密码
d.Password = utils.PwdAesEncrypt(d.Password)
password, err := utils.PwdAesEncrypt(d.Password)
if err != nil {
return errors.New("加密数据库密码失败")
}
d.Password = password
return nil
}
func (d *DbInstance) PwdDecrypt() {
func (d *DbInstance) PwdDecrypt() error {
// 密码替换为解密后的密码
d.Password = utils.PwdAesDecrypt(d.Password)
password, err := utils.PwdAesDecrypt(d.Password)
if err != nil {
return errors.New("解密数据库密码失败")
}
d.Password = password
return nil
}

View File

@@ -1,80 +1,36 @@
package entity
import (
"mayfly-go/pkg/model"
"time"
"mayfly-go/pkg/utils/timex"
)
var _ DbTask = (*DbRestore)(nil)
// DbRestore 数据库恢复任务
type DbRestore struct {
model.Model
*DbTaskBase
DbName string `gorm:"column(db_name)" json:"dbName"` // 数据库名
StartTime time.Time `gorm:"column(start_time)" json:"startTime"` // 开始时间
Interval time.Duration `gorm:"column(interval)" json:"interval"` // 间隔时间: 为零表示单次执行,为正表示反复执行
Enabled bool `gorm:"column(enabled)" json:"enabled"` // 是否启用
Finished bool `gorm:"column(finished)" json:"finished"` // 是否完成
Repeated bool `gorm:"column(repeated)" json:"repeated"` // 是否重复执行
LastStatus TaskStatus `gorm:"column(last_status)" json:"lastStatus"` // 最近一次执行状态
LastResult string `gorm:"column(last_result)" json:"lastResult"` // 最近一次执行结果
LastTime time.Time `gorm:"column(last_time)" json:"lastTime"` // 最近一次执行时间
PointInTime time.Time `gorm:"column(point_in_time)" json:"pointInTime"` // 指定数据库恢复的时间点
DbBackupId uint64 `gorm:"column(db_backup_id)" json:"dbBackupId"` // 用于恢复的数据库备份任务ID
DbBackupHistoryId uint64 `gorm:"column(db_backup_history_id)" json:"dbBackupHistoryId"` // 用于恢复的数据库备份历史ID
DbBackupHistoryName string `gorm:"column(db_backup_history_name) json:"dbBackupHistoryName"` // 数据库备份历史名称
DbInstanceId uint64 `gorm:"column(db_instance_id)" json:"dbInstanceId"`
Deadline time.Time `gorm:"-" json:"-"`
DbName string `json:"dbName"` // 数据库名
PointInTime timex.NullTime `json:"pointInTime"` // 指定数据库恢复的时间
DbBackupId uint64 `json:"dbBackupId"` // 用于恢复的数据库恢复任务ID
DbBackupHistoryId uint64 `json:"dbBackupHistoryId"` // 用于恢复的数据库恢复历史ID
DbBackupHistoryName string `json:"dbBackupHistoryName"` // 数据库恢复历史名称
DbInstanceId uint64 `json:"dbInstanceId"` // 数据库实例ID
}
func (d *DbRestore) TableName() string {
return "t_db_restore"
}
func (d *DbRestore) GetId() uint64 {
if d == nil {
return 0
}
return d.Id
}
func (d *DbRestore) GetDeadline() time.Time {
return d.Deadline
}
func (d *DbRestore) Schedule() bool {
if d.Finished || !d.Enabled {
return false
}
switch d.LastStatus {
func (*DbRestore) MessageWithStatus(status TaskStatus) string {
var result string
switch status {
case TaskDelay:
result = "等待恢复数据库"
case TaskReady:
result = "准备恢复数据库"
case TaskReserved:
result = "数据库恢复中"
case TaskSuccess:
if d.Interval == 0 {
return false
}
lastTime := d.LastTime
if d.LastTime.Sub(d.StartTime) < 0 {
lastTime = d.StartTime.Add(-d.Interval)
}
d.Deadline = lastTime.Add(d.Interval - d.LastTime.Sub(d.StartTime)%d.Interval)
result = "数据库恢复成功"
case TaskFailed:
d.Deadline = time.Now().Add(time.Minute)
default:
d.Deadline = d.StartTime
result = "数据库恢复失败"
}
return true
}
func (d *DbRestore) IsFinished() bool {
return !d.Repeated && d.LastStatus == TaskSuccess
}
func (d *DbRestore) Update(task DbTask) bool {
switch backup := task.(type) {
case *DbRestore:
d.StartTime = backup.StartTime
d.Interval = backup.Interval
return true
}
return false
return result
}

View File

@@ -0,0 +1,109 @@
package entity
import (
"mayfly-go/pkg/model"
"mayfly-go/pkg/utils/timex"
"time"
)
type TaskStatus int
const (
TaskDelay TaskStatus = iota
TaskReady
TaskReserved
TaskSuccess
TaskFailed
)
const LastResultSize = 256
type DbTask interface {
model.ModelI
GetId() uint64
GetDeadline() time.Time
IsFinished() bool
Schedule() bool
Update(task DbTask)
GetTaskBase() *DbTaskBase
MessageWithStatus(status TaskStatus) string
IsEnabled() bool
}
func NewDbBTaskBase(enabled bool, repeated bool, startTime time.Time, interval time.Duration) *DbTaskBase {
return &DbTaskBase{
Enabled: enabled,
Repeated: repeated,
StartTime: startTime,
Interval: interval,
}
}
type DbTaskBase struct {
model.Model
Enabled bool // 是否启用
StartTime time.Time // 开始时间
Interval time.Duration // 间隔时间
Repeated bool // 是否重复执行
LastStatus TaskStatus // 最近一次执行状态
LastResult string // 最近一次执行结果
LastTime timex.NullTime // 最近一次执行时间
Deadline time.Time `gorm:"-" json:"-"` // 计划执行时间
}
func (d *DbTaskBase) GetId() uint64 {
if d == nil {
return 0
}
return d.Id
}
func (d *DbTaskBase) GetDeadline() time.Time {
return d.Deadline
}
func (d *DbTaskBase) Schedule() bool {
if d.IsFinished() || !d.Enabled {
return false
}
switch d.LastStatus {
case TaskSuccess:
if d.Interval == 0 {
return false
}
lastTime := d.LastTime.Time
if lastTime.Sub(d.StartTime) < 0 {
lastTime = d.StartTime.Add(-d.Interval)
}
d.Deadline = lastTime.Add(d.Interval - lastTime.Sub(d.StartTime)%d.Interval)
case TaskFailed:
d.Deadline = time.Now().Add(time.Minute)
default:
d.Deadline = d.StartTime
}
return true
}
func (d *DbTaskBase) IsFinished() bool {
return !d.Repeated && d.LastStatus == TaskSuccess
}
func (d *DbTaskBase) Update(task DbTask) {
t := task.GetTaskBase()
d.StartTime = t.StartTime
d.Interval = t.Interval
}
func (d *DbTaskBase) GetTaskBase() *DbTaskBase {
return d
}
func (*DbTaskBase) MessageWithStatus(_ TaskStatus) string {
return ""
}
func (d *DbTaskBase) IsEnabled() bool {
return d.Enabled
}

View File

@@ -1,21 +0,0 @@
package entity
import "time"
type TaskStatus int
const (
TaskDelay TaskStatus = iota
TaskReady
TaskReserved
TaskSuccess
TaskFailed
)
type DbTask interface {
GetId() uint64
GetDeadline() time.Time
IsFinished() bool
Schedule() bool
Update(task DbTask) bool
}

View File

@@ -3,17 +3,14 @@ package repository
import (
"context"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/pkg/base"
"mayfly-go/pkg/model"
)
type DbBackup interface {
base.Repo[*entity.DbBackup]
DbTask[*entity.DbBackup]
// GetDbBackupList 分页获取数据信息列表
GetDbBackupList(condition *entity.DbBackupQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
AddTask(ctx context.Context, tasks ...*entity.DbBackup) error
UpdateTaskStatus(ctx context.Context, task *entity.DbBackup) error
GetDbNamesWithoutBackup(instanceId uint64, dbNames []string) ([]string, error)
UpdateEnabled(ctx context.Context, taskId uint64, enabled bool) error
}

View File

@@ -10,6 +10,4 @@ type DbBinlog interface {
base.Repo[*entity.DbBinlog]
AddTaskIfNotExists(ctx context.Context, task *entity.DbBinlog) error
UpdateTaskStatus(ctx context.Context, task *entity.DbBinlog) error
UpdateEnabled(ctx context.Context, taskId uint64, enabled bool) error
}

View File

@@ -12,5 +12,6 @@ type DbBinlogHistory interface {
GetHistories(instanceId uint64, start, target *entity.BinlogInfo) ([]*entity.DbBinlogHistory, error)
GetHistoryByTime(instanceId uint64, targetTime time.Time) (*entity.DbBinlogHistory, error)
GetLatestHistory(instanceId uint64) (*entity.DbBinlogHistory, bool, error)
InsertWithBinlogFiles(ctx context.Context, instanceId uint64, binlogFiles []*entity.BinlogFile) error
Upsert(ctx context.Context, history *entity.DbBinlogHistory) error
}

View File

@@ -3,17 +3,14 @@ package repository
import (
"context"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/pkg/base"
"mayfly-go/pkg/model"
)
type DbRestore interface {
base.Repo[*entity.DbRestore]
DbTask[*entity.DbRestore]
// GetDbRestoreList 分页获取数据信息列表
GetDbRestoreList(condition *entity.DbRestoreQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
AddTask(ctx context.Context, tasks ...*entity.DbRestore) error
UpdateTaskStatus(ctx context.Context, task *entity.DbRestore) error
GetDbNamesWithoutRestore(instanceId uint64, dbNames []string) ([]string, error)
UpdateEnabled(ctx context.Context, taskId uint64, enabled bool) error
}

View File

@@ -0,0 +1,17 @@
package repository
import (
"context"
"mayfly-go/pkg/base"
"mayfly-go/pkg/model"
)
type DbTask[T model.ModelI] interface {
base.Repo[T]
UpdateTaskStatus(ctx context.Context, task T) error
UpdateEnabled(ctx context.Context, taskId uint64, enabled bool) error
ListToDo() ([]T, error)
ListRepeating() ([]T, error)
AddTask(ctx context.Context, tasks ...T) error
}

View File

@@ -1,14 +0,0 @@
package service
import (
"context"
"mayfly-go/internal/db/domain/entity"
)
type DbBackupSvc interface {
AddTask(ctx context.Context, tasks ...*entity.DbBackup) error
UpdateTask(ctx context.Context, task *entity.DbBackup) error
DeleteTask(ctx context.Context, taskId uint64) error
EnableTask(ctx context.Context, taskId uint64) error
DisableTask(ctx context.Context, taskId uint64) error
}

View File

@@ -1,14 +0,0 @@
package service
import (
"context"
"mayfly-go/internal/db/domain/entity"
)
type DbBinlogSvc interface {
AddTaskIfNotExists(ctx context.Context, task *entity.DbBinlog) error
UpdateTask(ctx context.Context, task *entity.DbBinlog) error
DeleteTask(ctx context.Context, taskId uint64) error
EnableTask(ctx context.Context, taskId uint64) error
DisableTask(ctx context.Context, taskId uint64) error
}

View File

@@ -1,12 +0,0 @@
package service
import (
"context"
"mayfly-go/internal/db/domain/entity"
)
type DbInstanceSvc interface {
Backup(ctx context.Context, backupHistory *entity.DbBackupHistory) (*entity.BinlogInfo, error)
Restore(ctx context.Context, task *entity.DbRestore) error
FetchBinlogs(ctx context.Context, downloadLatestBinlogFile bool) error
}

View File

@@ -1,14 +0,0 @@
package service
import (
"context"
"mayfly-go/internal/db/domain/entity"
)
type DbRestoreSvc interface {
AddTask(ctx context.Context, tasks ...*entity.DbRestore) error
UpdateTask(ctx context.Context, task *entity.DbRestore) error
DeleteTask(ctx context.Context, taskId uint64) error
EnableTask(ctx context.Context, taskId uint64) error
DisableTask(ctx context.Context, taskId uint64) error
}

View File

@@ -7,7 +7,6 @@ import (
"gorm.io/gorm"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/base"
"mayfly-go/pkg/gormx"
"mayfly-go/pkg/model"
"slices"
@@ -16,7 +15,8 @@ import (
var _ repository.DbBackup = (*dbBackupRepoImpl)(nil)
type dbBackupRepoImpl struct {
base.RepoImpl[*entity.DbBackup]
//base.RepoImpl[*entity.DbBackup]
dbTaskBase[*entity.DbBackup]
}
func NewDbBackupRepo() repository.DbBackup {
@@ -34,30 +34,6 @@ func (d *dbBackupRepoImpl) GetDbBackupList(condition *entity.DbBackupQuery, page
return gormx.PageQuery(qd, pageParam, toEntity)
}
func (d *dbBackupRepoImpl) UpdateEnabled(ctx context.Context, taskId uint64, enabled bool) error {
cond := map[string]any{
"id": taskId,
}
return d.Updates(cond, map[string]any{
"enabled": enabled,
})
}
func (d *dbBackupRepoImpl) UpdateTaskStatus(ctx context.Context, task *entity.DbBackup) error {
task = &entity.DbBackup{
Model: model.Model{
DeletedModel: model.DeletedModel{
Id: task.Id,
},
},
Finished: task.Finished,
LastStatus: task.LastStatus,
LastResult: task.LastResult,
LastTime: task.LastTime,
}
return d.UpdateById(ctx, task)
}
func (d *dbBackupRepoImpl) AddTask(ctx context.Context, tasks ...*entity.DbBackup) error {
return gormx.Tx(func(db *gorm.DB) error {
var instanceId uint64
@@ -94,7 +70,7 @@ func (d *dbBackupRepoImpl) AddTask(ctx context.Context, tasks ...*entity.DbBacku
func (d *dbBackupRepoImpl) GetDbNamesWithoutBackup(instanceId uint64, dbNames []string) ([]string, error) {
var dbNamesWithBackup []string
query := gormx.NewQuery(d.M).
query := gormx.NewQuery(d.GetModel()).
Eq("db_instance_id", instanceId).
Eq("repeated", true)
if err := query.GenGdb().Pluck("db_name", &dbNamesWithBackup).Error; err != nil {

View File

@@ -57,5 +57,5 @@ func (repo *dbBackupHistoryRepoImpl) GetEarliestHistory(instanceId uint64) (*ent
if err != nil {
return nil, err
}
return history, err
return history, nil
}

View File

@@ -2,12 +2,12 @@ package persistence
import (
"context"
"fmt"
"gorm.io/gorm/clause"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/base"
"mayfly-go/pkg/global"
"mayfly-go/pkg/model"
)
var _ repository.DbBinlog = (*dbBinlogRepoImpl)(nil)
@@ -20,29 +20,9 @@ func NewDbBinlogRepo() repository.DbBinlog {
return &dbBinlogRepoImpl{}
}
func (d *dbBinlogRepoImpl) UpdateEnabled(ctx context.Context, taskId uint64, enabled bool) error {
cond := map[string]any{
"id": taskId,
func (d *dbBinlogRepoImpl) AddTaskIfNotExists(_ context.Context, task *entity.DbBinlog) error {
if err := global.Db.Clauses(clause.OnConflict{DoNothing: true}).Create(task).Error; err != nil {
return fmt.Errorf("启动 binlog 下载失败: %w", err)
}
return d.Updates(cond, map[string]any{
"enabled": enabled,
})
}
func (d *dbBinlogRepoImpl) UpdateTaskStatus(ctx context.Context, task *entity.DbBinlog) error {
task = &entity.DbBinlog{
Model: model.Model{
DeletedModel: model.DeletedModel{
Id: task.Id,
},
},
LastStatus: task.LastStatus,
LastResult: task.LastResult,
LastTime: task.LastTime,
}
return d.UpdateById(ctx, task)
}
func (d *dbBinlogRepoImpl) AddTaskIfNotExists(ctx context.Context, task *entity.DbBinlog) error {
return global.Db.Clauses(clause.OnConflict{DoNothing: true}).Create(task).Error
return nil
}

View File

@@ -56,14 +56,14 @@ func (repo *dbBinlogHistoryRepoImpl) GetHistories(instanceId uint64, start, targ
}
func (repo *dbBinlogHistoryRepoImpl) GetLatestHistory(instanceId uint64) (*entity.DbBinlogHistory, bool, error) {
gdb := gormx.NewQuery(repo.GetModel()).
history := &entity.DbBinlogHistory{}
err := gormx.NewQuery(repo.GetModel()).
Eq("db_instance_id", instanceId).
Undeleted().
OrderByDesc("sequence").
GenGdb()
history := &entity.DbBinlogHistory{}
switch err := gdb.First(history).Error; {
GenGdb().
First(history).Error
switch {
case err == nil:
return history, true, nil
case errors.Is(err, gorm.ErrRecordNotFound):
@@ -89,3 +89,35 @@ func (repo *dbBinlogHistoryRepoImpl) Upsert(_ context.Context, history *entity.D
}
})
}
func (repo *dbBinlogHistoryRepoImpl) InsertWithBinlogFiles(ctx context.Context, instanceId uint64, binlogFiles []*entity.BinlogFile) error {
if len(binlogFiles) == 0 {
return nil
}
histories := make([]*entity.DbBinlogHistory, 0, len(binlogFiles))
for _, fileOnServer := range binlogFiles {
if !fileOnServer.Downloaded {
break
}
history := &entity.DbBinlogHistory{
CreateTime: time.Now(),
FileName: fileOnServer.Name,
FileSize: fileOnServer.Size,
Sequence: fileOnServer.Sequence,
FirstEventTime: fileOnServer.FirstEventTime,
DbInstanceId: instanceId,
}
histories = append(histories, history)
}
if len(histories) > 1 {
if err := repo.BatchInsert(ctx, histories[:len(histories)-1]); err != nil {
return err
}
}
if len(histories) > 0 {
if err := repo.Upsert(ctx, histories[len(histories)-1]); err != nil {
return err
}
}
return nil
}

View File

@@ -7,7 +7,6 @@ import (
"gorm.io/gorm"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/base"
"mayfly-go/pkg/gormx"
"mayfly-go/pkg/model"
"slices"
@@ -16,7 +15,7 @@ import (
var _ repository.DbRestore = (*dbRestoreRepoImpl)(nil)
type dbRestoreRepoImpl struct {
base.RepoImpl[*entity.DbRestore]
dbTaskBase[*entity.DbRestore]
}
func NewDbRestoreRepo() repository.DbRestore {
@@ -34,21 +33,6 @@ func (d *dbRestoreRepoImpl) GetDbRestoreList(condition *entity.DbRestoreQuery, p
return gormx.PageQuery(qd, pageParam, toEntity)
}
func (d *dbRestoreRepoImpl) UpdateTaskStatus(ctx context.Context, task *entity.DbRestore) error {
task = &entity.DbRestore{
Model: model.Model{
DeletedModel: model.DeletedModel{
Id: task.Id,
},
},
Finished: task.Finished,
LastStatus: task.LastStatus,
LastResult: task.LastResult,
LastTime: task.LastTime,
}
return d.UpdateById(ctx, task)
}
func (d *dbRestoreRepoImpl) AddTask(ctx context.Context, tasks ...*entity.DbRestore) error {
return gormx.Tx(func(db *gorm.DB) error {
var instanceId uint64
@@ -85,7 +69,7 @@ func (d *dbRestoreRepoImpl) AddTask(ctx context.Context, tasks ...*entity.DbRest
func (d *dbRestoreRepoImpl) GetDbNamesWithoutRestore(instanceId uint64, dbNames []string) ([]string, error) {
var dbNamesWithRestore []string
query := gormx.NewQuery(d.M).
query := gormx.NewQuery(d.GetModel()).
Eq("db_instance_id", instanceId).
Eq("repeated", true)
if err := query.GenGdb().Pluck("db_name", &dbNamesWithRestore).Error; err != nil {
@@ -99,12 +83,3 @@ func (d *dbRestoreRepoImpl) GetDbNamesWithoutRestore(instanceId uint64, dbNames
}
return result, nil
}
func (d *dbRestoreRepoImpl) UpdateEnabled(ctx context.Context, taskId uint64, enabled bool) error {
cond := map[string]any{
"id": taskId,
}
return d.Updates(cond, map[string]any{
"enabled": enabled,
})
}

View File

@@ -0,0 +1,52 @@
package persistence
import (
"context"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/pkg/base"
"mayfly-go/pkg/global"
"mayfly-go/pkg/gormx"
"mayfly-go/pkg/model"
)
type dbTaskBase[T model.ModelI] struct {
base.RepoImpl[T]
}
func (d *dbTaskBase[T]) UpdateEnabled(_ context.Context, taskId uint64, enabled bool) error {
cond := map[string]any{
"id": taskId,
}
return d.Updates(cond, map[string]any{
"enabled": enabled,
})
}
func (d *dbTaskBase[T]) UpdateTaskStatus(ctx context.Context, task T) error {
return d.UpdateById(ctx, task, "last_status", "last_result", "last_time")
}
func (d *dbTaskBase[T]) ListToDo() ([]T, error) {
var tasks []T
db := global.Db.Model(d.GetModel())
err := db.Where("enabled = ?", true).
Where(db.Where("repeated = ?", true).Or("last_status <> ?", entity.TaskSuccess)).
Scopes(gormx.UndeleteScope).
Find(&tasks).Error
if err != nil {
return nil, err
}
return tasks, nil
}
func (d *dbTaskBase[T]) ListRepeating() ([]T, error) {
cond := map[string]any{
"enabled": true,
"repeated": true,
}
var tasks []T
if err := d.ListByCond(cond, &tasks); err != nil {
return nil, err
}
return tasks, nil
}

View File

@@ -1,212 +0,0 @@
package service
import (
"context"
"encoding/binary"
"fmt"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/internal/db/domain/service"
"mayfly-go/pkg/model"
"time"
"github.com/google/uuid"
)
var _ service.DbBackupSvc = (*DbBackupSvcImpl)(nil)
type DbBackupSvcImpl struct {
repo repository.DbBackup
instanceRepo repository.Instance
scheduler *Scheduler[*entity.DbBackup]
binlogSvc service.DbBinlogSvc
}
func NewIncUUID() (uuid.UUID, error) {
var uid uuid.UUID
now, seq, err := uuid.GetTime()
if err != nil {
return uid, err
}
timeHi := uint32((now >> 28) & 0xffffffff)
timeMid := uint16((now >> 12) & 0xffff)
timeLow := uint16(now & 0x0fff)
timeLow |= 0x1000 // Version 1
binary.BigEndian.PutUint32(uid[0:], timeHi)
binary.BigEndian.PutUint16(uid[4:], timeMid)
binary.BigEndian.PutUint16(uid[6:], timeLow)
binary.BigEndian.PutUint16(uid[8:], seq)
copy(uid[10:], uuid.NodeID())
return uid, nil
}
func withRunBackupTask(repositories *repository.Repositories, binlogSvc service.DbBinlogSvc) SchedulerOption[*entity.DbBackup] {
return func(scheduler *Scheduler[*entity.DbBackup]) {
scheduler.RunTask = func(ctx context.Context, task *entity.DbBackup) error {
instance := new(entity.DbInstance)
if err := repositories.Instance.GetById(instance, task.DbInstanceId); err != nil {
return err
}
instance.PwdDecrypt()
id, err := NewIncUUID()
if err != nil {
return err
}
history := &entity.DbBackupHistory{
Uuid: id.String(),
DbBackupId: task.Id,
DbInstanceId: task.DbInstanceId,
DbName: task.DbName,
}
binlogInfo, err := NewDbInstanceSvc(instance, repositories).Backup(ctx, history)
if err != nil {
return err
}
now := time.Now()
name := task.Name
if len(name) == 0 {
name = task.DbName
}
history.Name = fmt.Sprintf("%s[%s]", name, now.Format(time.DateTime))
history.CreateTime = now
history.BinlogFileName = binlogInfo.FileName
history.BinlogSequence = binlogInfo.Sequence
history.BinlogPosition = binlogInfo.Position
if err := repositories.BackupHistory.Insert(ctx, history); err != nil {
return err
}
if err := binlogSvc.AddTaskIfNotExists(ctx, entity.NewDbBinlog(history)); err != nil {
return err
}
return nil
}
}
}
var (
backupResult = map[entity.TaskStatus]string{
entity.TaskDelay: "等待备份数据库",
entity.TaskReady: "准备备份数据库",
entity.TaskReserved: "数据库备份中",
entity.TaskSuccess: "数据库备份成功",
entity.TaskFailed: "数据库备份失败",
}
)
func withUpdateBackupStatus(repositories *repository.Repositories) SchedulerOption[*entity.DbBackup] {
return func(scheduler *Scheduler[*entity.DbBackup]) {
scheduler.UpdateTaskStatus = func(ctx context.Context, status entity.TaskStatus, lastErr error, task *entity.DbBackup) error {
task.Finished = !task.Repeated && status == entity.TaskSuccess
task.LastStatus = status
var result = backupResult[status]
if lastErr != nil {
result = fmt.Sprintf("%v: %v", backupResult[status], lastErr)
}
task.LastResult = result
task.LastTime = time.Now()
return repositories.Backup.UpdateTaskStatus(ctx, task)
}
}
}
func NewDbBackupSvc(repositories *repository.Repositories, binlogSvc service.DbBinlogSvc) (service.DbBackupSvc, error) {
scheduler, err := NewScheduler[*entity.DbBackup](
withRunBackupTask(repositories, binlogSvc),
withUpdateBackupStatus(repositories))
if err != nil {
return nil, err
}
svc := &DbBackupSvcImpl{
repo: repositories.Backup,
instanceRepo: repositories.Instance,
scheduler: scheduler,
binlogSvc: binlogSvc,
}
err = svc.loadTasks(context.Background())
if err != nil {
return nil, err
}
return svc, nil
}
func (svc *DbBackupSvcImpl) loadTasks(ctx context.Context) error {
tasks := make([]*entity.DbBackup, 0, 64)
cond := map[string]any{
"Enabled": true,
"Finished": false,
}
if err := svc.repo.ListByCond(cond, &tasks); err != nil {
return err
}
for _, task := range tasks {
svc.scheduler.PushTask(ctx, task)
}
return nil
}
func (svc *DbBackupSvcImpl) AddTask(ctx context.Context, tasks ...*entity.DbBackup) error {
for _, task := range tasks {
if err := svc.repo.AddTask(ctx, task); err != nil {
return err
}
svc.scheduler.PushTask(ctx, task)
}
return nil
}
func (svc *DbBackupSvcImpl) UpdateTask(ctx context.Context, task *entity.DbBackup) error {
if err := svc.repo.UpdateById(ctx, task); err != nil {
return err
}
svc.scheduler.UpdateTask(ctx, task)
return nil
}
func (svc *DbBackupSvcImpl) DeleteTask(ctx context.Context, taskId uint64) error {
// todo: 删除数据库备份历史文件
task := new(entity.DbBackup)
if err := svc.repo.GetById(task, taskId); err != nil {
return err
}
if err := svc.binlogSvc.DeleteTask(ctx, task.DbInstanceId); err != nil {
return err
}
if err := svc.repo.DeleteById(ctx, taskId); err != nil {
return err
}
svc.scheduler.RemoveTask(taskId)
return nil
}
func (svc *DbBackupSvcImpl) EnableTask(ctx context.Context, taskId uint64) error {
if err := svc.repo.UpdateEnabled(ctx, taskId, true); err != nil {
return err
}
task := new(entity.DbBackup)
if err := svc.repo.GetById(task, taskId); err != nil {
return err
}
svc.scheduler.UpdateTask(ctx, task)
return nil
}
func (svc *DbBackupSvcImpl) DisableTask(ctx context.Context, taskId uint64) error {
if err := svc.repo.UpdateEnabled(ctx, taskId, false); err != nil {
return err
}
task := new(entity.DbBackup)
if err := svc.repo.GetById(task, taskId); err != nil {
return err
}
svc.scheduler.RemoveTask(taskId)
return nil
}
// GetPageList 分页获取数据库备份任务
func (svc *DbBackupSvcImpl) GetPageList(condition *entity.DbBackupQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return svc.repo.GetDbBackupList(condition, pageParam, toEntity, orderBy...)
}

View File

@@ -1,140 +0,0 @@
package service
import (
"context"
"fmt"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/internal/db/domain/service"
"time"
)
var _ service.DbBinlogSvc = (*DbBinlogSvcImpl)(nil)
type DbBinlogSvcImpl struct {
repo repository.DbBinlog
instanceRepo repository.Instance
scheduler *Scheduler[*entity.DbBinlog]
}
func withDownloadBinlog(repositories *repository.Repositories) SchedulerOption[*entity.DbBinlog] {
return func(scheduler *Scheduler[*entity.DbBinlog]) {
scheduler.RunTask = func(ctx context.Context, task *entity.DbBinlog) error {
instance := new(entity.DbInstance)
if err := repositories.Instance.GetById(instance, task.DbInstanceId); err != nil {
return err
}
instance.PwdDecrypt()
svc := NewDbInstanceSvc(instance, repositories)
err := svc.FetchBinlogs(ctx, false)
if err != nil {
return err
}
return nil
}
}
}
var (
binlogResult = map[entity.TaskStatus]string{
entity.TaskDelay: "等待备份BINLOG",
entity.TaskReady: "准备备份BINLOG",
entity.TaskReserved: "BINLOG备份中",
entity.TaskSuccess: "BINLOG备份成功",
entity.TaskFailed: "BINLOG备份失败",
}
)
func withUpdateBinlogStatus(repositories *repository.Repositories) SchedulerOption[*entity.DbBinlog] {
return func(scheduler *Scheduler[*entity.DbBinlog]) {
scheduler.UpdateTaskStatus = func(ctx context.Context, status entity.TaskStatus, lastErr error, task *entity.DbBinlog) error {
task.LastStatus = status
var result = backupResult[status]
if lastErr != nil {
result = fmt.Sprintf("%v: %v", binlogResult[status], lastErr)
}
task.LastResult = result
task.LastTime = time.Now()
return repositories.Binlog.UpdateTaskStatus(ctx, task)
}
}
}
func NewDbBinlogSvc(repositories *repository.Repositories) (service.DbBinlogSvc, error) {
scheduler, err := NewScheduler[*entity.DbBinlog](withDownloadBinlog(repositories), withUpdateBinlogStatus(repositories))
if err != nil {
return nil, err
}
svc := &DbBinlogSvcImpl{
repo: repositories.Binlog,
instanceRepo: repositories.Instance,
scheduler: scheduler,
}
err = svc.loadTasks(context.Background())
if err != nil {
return nil, err
}
return svc, nil
}
func (svc *DbBinlogSvcImpl) loadTasks(ctx context.Context) error {
tasks := make([]*entity.DbBinlog, 0, 64)
cond := map[string]any{
"Enabled": true,
}
if err := svc.repo.ListByCond(cond, &tasks); err != nil {
return err
}
for _, task := range tasks {
svc.scheduler.PushTask(ctx, task)
}
return nil
}
func (svc *DbBinlogSvcImpl) AddTaskIfNotExists(ctx context.Context, task *entity.DbBinlog) error {
if err := svc.repo.AddTaskIfNotExists(ctx, task); err != nil {
return err
}
if task.GetId() == 0 {
return nil
}
svc.scheduler.PushTask(ctx, task)
return nil
}
func (svc *DbBinlogSvcImpl) UpdateTask(ctx context.Context, task *entity.DbBinlog) error {
if err := svc.repo.UpdateById(ctx, task); err != nil {
return err
}
svc.scheduler.UpdateTask(ctx, task)
return nil
}
func (svc *DbBinlogSvcImpl) DeleteTask(ctx context.Context, taskId uint64) error {
// todo: 删除 Binlog 历史文件
if err := svc.repo.DeleteById(ctx, taskId); err != nil {
return err
}
svc.scheduler.RemoveTask(taskId)
return nil
}
func (svc *DbBinlogSvcImpl) EnableTask(ctx context.Context, taskId uint64) error {
if err := svc.repo.UpdateEnabled(ctx, taskId, true); err != nil {
return err
}
task := new(entity.DbBinlog)
if err := svc.repo.GetById(task, taskId); err != nil {
return err
}
svc.scheduler.UpdateTask(ctx, task)
return nil
}
func (svc *DbBinlogSvcImpl) DisableTask(ctx context.Context, taskId uint64) error {
if err := svc.repo.UpdateEnabled(ctx, taskId, false); err != nil {
return err
}
svc.scheduler.RemoveTask(taskId)
return nil
}

View File

@@ -1,155 +0,0 @@
package service
import (
"context"
"fmt"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/internal/db/domain/service"
"mayfly-go/pkg/model"
"time"
)
var _ service.DbRestoreSvc = (*DbRestoreSvcImpl)(nil)
type DbRestoreSvcImpl struct {
repo repository.DbRestore
instanceRepo repository.Instance
scheduler *Scheduler[*entity.DbRestore]
}
func withRunRestoreTask(repositories *repository.Repositories) SchedulerOption[*entity.DbRestore] {
return func(scheduler *Scheduler[*entity.DbRestore]) {
scheduler.RunTask = func(ctx context.Context, task *entity.DbRestore) error {
instance := new(entity.DbInstance)
if err := repositories.Instance.GetById(instance, task.DbInstanceId); err != nil {
return err
}
instance.PwdDecrypt()
if err := NewDbInstanceSvc(instance, repositories).Restore(ctx, task); err != nil {
return err
}
history := &entity.DbRestoreHistory{
CreateTime: time.Now(),
DbRestoreId: task.Id,
}
if err := repositories.RestoreHistory.Insert(ctx, history); err != nil {
return err
}
return nil
}
}
}
var (
restoreResult = map[entity.TaskStatus]string{
entity.TaskDelay: "等待恢复数据库",
entity.TaskReady: "准备恢复数据库",
entity.TaskReserved: "数据库恢复中",
entity.TaskSuccess: "数据库恢复成功",
entity.TaskFailed: "数据库恢复失败",
}
)
func withUpdateRestoreStatus(repositories *repository.Repositories) SchedulerOption[*entity.DbRestore] {
return func(scheduler *Scheduler[*entity.DbRestore]) {
scheduler.UpdateTaskStatus = func(ctx context.Context, status entity.TaskStatus, lastErr error, task *entity.DbRestore) error {
task.Finished = !task.Repeated && status == entity.TaskSuccess
task.LastStatus = status
var result = restoreResult[status]
if lastErr != nil {
result = fmt.Sprintf("%v: %v", restoreResult[status], lastErr)
}
task.LastResult = result
task.LastTime = time.Now()
return repositories.Restore.UpdateTaskStatus(ctx, task)
}
}
}
func NewDbRestoreSvc(repositories *repository.Repositories) (service.DbRestoreSvc, error) {
scheduler, err := NewScheduler[*entity.DbRestore](
withRunRestoreTask(repositories),
withUpdateRestoreStatus(repositories))
if err != nil {
return nil, err
}
svc := &DbRestoreSvcImpl{
repo: repositories.Restore,
instanceRepo: repositories.Instance,
scheduler: scheduler,
}
if err := svc.loadTasks(context.Background()); err != nil {
return nil, err
}
return svc, nil
}
func (svc *DbRestoreSvcImpl) loadTasks(ctx context.Context) error {
tasks := make([]*entity.DbRestore, 0, 64)
cond := map[string]any{
"Enabled": true,
"Finished": false,
}
if err := svc.repo.ListByCond(cond, &tasks); err != nil {
return err
}
for _, task := range tasks {
svc.scheduler.PushTask(ctx, task)
}
return nil
}
func (svc *DbRestoreSvcImpl) AddTask(ctx context.Context, tasks ...*entity.DbRestore) error {
for _, task := range tasks {
if err := svc.repo.AddTask(ctx, task); err != nil {
return err
}
svc.scheduler.PushTask(ctx, task)
}
return nil
}
func (svc *DbRestoreSvcImpl) UpdateTask(ctx context.Context, task *entity.DbRestore) error {
if err := svc.repo.UpdateById(ctx, task); err != nil {
return err
}
svc.scheduler.UpdateTask(ctx, task)
return nil
}
func (svc *DbRestoreSvcImpl) DeleteTask(ctx context.Context, taskId uint64) error {
// todo: 删除数据库恢复历史文件
if err := svc.repo.DeleteById(ctx, taskId); err != nil {
return err
}
svc.scheduler.RemoveTask(taskId)
return nil
}
func (svc *DbRestoreSvcImpl) EnableTask(ctx context.Context, taskId uint64) error {
if err := svc.repo.UpdateEnabled(ctx, taskId, true); err != nil {
return err
}
task := new(entity.DbRestore)
if err := svc.repo.GetById(task, taskId); err != nil {
return err
}
svc.scheduler.UpdateTask(ctx, task)
return nil
}
func (svc *DbRestoreSvcImpl) DisableTask(ctx context.Context, taskId uint64) error {
if err := svc.repo.UpdateEnabled(ctx, taskId, false); err != nil {
return err
}
svc.scheduler.RemoveTask(taskId)
return nil
}
// GetPageList 分页获取数据库恢复任务
func (svc *DbRestoreSvcImpl) GetPageList(condition *entity.DbRestoreQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return svc.repo.GetDbRestoreList(condition, pageParam, toEntity, orderBy...)
}

View File

@@ -1,160 +0,0 @@
package service
import (
"context"
"errors"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/queue"
"sync"
"time"
)
type Scheduler[T entity.DbTask] struct {
mutex sync.Mutex
wg sync.WaitGroup
queue *queue.DelayQueue[T]
closed bool
curTask T
curTaskContext context.Context
curTaskCancel context.CancelFunc
UpdateTaskStatus func(ctx context.Context, status entity.TaskStatus, lastErr error, task T) error
RunTask func(ctx context.Context, task T) error
}
type SchedulerOption[T entity.DbTask] func(*Scheduler[T])
func NewScheduler[T entity.DbTask](opts ...SchedulerOption[T]) (*Scheduler[T], error) {
scheduler := &Scheduler[T]{
queue: queue.NewDelayQueue[T](0),
}
for _, opt := range opts {
opt(scheduler)
}
if scheduler.RunTask == nil || scheduler.UpdateTaskStatus == nil {
return nil, errors.New("调度器没有设置 RunTask 或 UpdateTaskStatus")
}
scheduler.wg.Add(1)
go scheduler.run()
return scheduler, nil
}
func (m *Scheduler[T]) PushTask(ctx context.Context, task T) bool {
if !task.Schedule() {
return false
}
m.mutex.Lock()
defer m.mutex.Unlock()
return m.queue.Enqueue(ctx, task)
}
func (m *Scheduler[T]) UpdateTask(ctx context.Context, task T) bool {
m.mutex.Lock()
defer m.mutex.Unlock()
if task.GetId() == m.curTask.GetId() {
return m.curTask.Update(task)
}
oldTask, ok := m.queue.Remove(ctx, task.GetId())
if ok {
if !oldTask.Update(task) {
return false
}
} else {
oldTask = task
}
if !oldTask.Schedule() {
return false
}
return m.queue.Enqueue(ctx, oldTask)
}
func (m *Scheduler[T]) updateCurTask(status entity.TaskStatus, lastErr error, task T) bool {
seconds := []time.Duration{time.Second * 1, time.Second * 8, time.Second * 64}
for _, second := range seconds {
if m.closed {
return false
}
ctx, cancel := context.WithTimeout(context.Background(), second)
err := m.UpdateTaskStatus(ctx, status, lastErr, task)
cancel()
if err != nil {
logx.Errorf("保存任务失败: %v", err)
time.Sleep(second)
continue
}
return true
}
return false
}
func (m *Scheduler[T]) run() {
defer m.wg.Done()
var ctx context.Context
var cancel context.CancelFunc
for !m.closed {
m.mutex.Lock()
ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond)
task, ok := m.queue.Dequeue(ctx)
cancel()
if !ok {
m.mutex.Unlock()
time.Sleep(time.Second)
continue
}
m.curTask = task
m.updateCurTask(entity.TaskReserved, nil, task)
m.curTaskContext, m.curTaskCancel = context.WithCancel(context.Background())
m.mutex.Unlock()
err := m.RunTask(m.curTaskContext, task)
m.mutex.Lock()
taskStatus := entity.TaskSuccess
if err != nil {
taskStatus = entity.TaskFailed
}
m.updateCurTask(taskStatus, err, task)
m.cancelCurTask()
task.Schedule()
if !task.IsFinished() {
ctx, cancel = context.WithTimeout(context.Background(), time.Second)
m.queue.Enqueue(ctx, task)
cancel()
}
m.mutex.Unlock()
}
}
func (m *Scheduler[T]) Close() {
if m.closed {
return
}
m.mutex.Lock()
m.cancelCurTask()
m.closed = true
m.mutex.Unlock()
m.wg.Wait()
}
func (m *Scheduler[T]) RemoveTask(taskId uint64) bool {
m.mutex.Lock()
defer m.mutex.Unlock()
m.queue.Remove(context.Background(), taskId)
if taskId == m.curTask.GetId() {
m.cancelCurTask()
}
return true
}
func (m *Scheduler[T]) cancelCurTask() {
if m.curTaskCancel != nil {
m.curTaskCancel()
m.curTaskCancel = nil
}
}

View File

@@ -26,6 +26,8 @@ func InitDbBackupRouter(router *gin.RouterGroup) {
req.NewPut(":dbId/backups/:backupId/enable", d.Enable).Log(req.NewLogSave("db-启用数据库备份任务")),
// 禁用数据库备份任务
req.NewPut(":dbId/backups/:backupId/disable", d.Disable).Log(req.NewLogSave("db-禁用数据库备份任务")),
// 立即启动数据库备份任务
req.NewPut(":dbId/backups/:backupId/start", d.Start).Log(req.NewLogSave("db-立即启动数据库备份任务")),
// 删除数据库备份任务
req.NewDelete(":dbId/backups/:backupId", d.Delete),
// 获取未配置定时备份的数据库名称

View File

@@ -84,7 +84,9 @@ func (m *machineAppImpl) Save(ctx context.Context, me *entity.Machine, tagIds ..
err := m.GetBy(oldMachine)
me.PwdEncrypt()
if errEnc := me.PwdEncrypt(); errEnc != nil {
return errorx.NewBiz(errEnc.Error())
}
if me.Id == 0 {
if err == nil {
return errorx.NewBiz("该机器信息已存在")
@@ -242,13 +244,17 @@ func (m *machineAppImpl) toMachineInfo(me *entity.Machine) (*mcm.MachineInfo, er
return nil, errorx.NewBiz("授权凭证信息已不存在,请重新关联")
}
mi.AuthMethod = ac.AuthMethod
ac.PwdDecrypt()
if err := ac.PwdDecrypt(); err != nil {
return nil, errorx.NewBiz(err.Error())
}
mi.Password = ac.Password
mi.Passphrase = ac.Passphrase
} else {
mi.AuthMethod = entity.AuthCertAuthMethodPassword
if me.Id != 0 {
me.PwdDecrypt()
if err := me.PwdDecrypt(); err != nil {
return nil, errorx.NewBiz(err.Error())
}
}
mi.Password = me.Password
}

View File

@@ -1,6 +1,7 @@
package entity
import (
"errors"
"mayfly-go/internal/common/utils"
"mayfly-go/pkg/model"
)
@@ -16,7 +17,7 @@ type AuthCert struct {
Remark string `json:"remark"`
}
func (a *AuthCert) TableName() string {
func (ac *AuthCert) TableName() string {
return "t_auth_cert"
}
@@ -28,14 +29,32 @@ const (
AuthCertTypePublic int8 = 2
)
// 密码加密
func (ac *AuthCert) PwdEncrypt() {
ac.Password = utils.PwdAesEncrypt(ac.Password)
ac.Passphrase = utils.PwdAesEncrypt(ac.Passphrase)
// PwdEncrypt 密码加密
func (ac *AuthCert) PwdEncrypt() error {
password, err := utils.PwdAesEncrypt(ac.Password)
if err != nil {
return errors.New("加密授权凭证密码失败")
}
passphrase, err := utils.PwdAesEncrypt(ac.Passphrase)
if err != nil {
return errors.New("加密授权凭证私钥失败")
}
ac.Password = password
ac.Passphrase = passphrase
return nil
}
// 密码解密
func (ac *AuthCert) PwdDecrypt() {
ac.Password = utils.PwdAesDecrypt(ac.Password)
ac.Passphrase = utils.PwdAesDecrypt(ac.Passphrase)
// PwdDecrypt 密码解密
func (ac *AuthCert) PwdDecrypt() error {
password, err := utils.PwdAesDecrypt(ac.Password)
if err != nil {
return errors.New("解密授权凭证密码失败")
}
passphrase, err := utils.PwdAesDecrypt(ac.Passphrase)
if err != nil {
return errors.New("解密授权凭证私钥失败")
}
ac.Password = password
ac.Passphrase = passphrase
return nil
}

View File

@@ -1,6 +1,7 @@
package entity
import (
"errors"
"mayfly-go/internal/common/utils"
"mayfly-go/pkg/model"
)
@@ -26,14 +27,24 @@ const (
MachineStatusDisable int8 = -1 // 禁用状态
)
func (m *Machine) PwdEncrypt() {
func (m *Machine) PwdEncrypt() error {
// 密码替换为加密后的密码
m.Password = utils.PwdAesEncrypt(m.Password)
password, err := utils.PwdAesEncrypt(m.Password)
if err != nil {
return errors.New("加密主机密码失败")
}
m.Password = password
return nil
}
func (m *Machine) PwdDecrypt() {
func (m *Machine) PwdDecrypt() error {
// 密码替换为解密后的密码
m.Password = utils.PwdAesDecrypt(m.Password)
password, err := utils.PwdAesDecrypt(m.Password)
if err != nil {
return errors.New("解密主机密码失败")
}
m.Password = password
return nil
}
func (m *Machine) UseAuthCert() bool {

View File

@@ -77,7 +77,9 @@ func (r *Redis) GetRedisPwd(rc *req.Ctx) {
rid := uint64(ginx.PathParamInt(rc.GinCtx, "id"))
re, err := r.RedisApp.GetById(new(entity.Redis), rid, "Password")
biz.ErrIsNil(err, "redis信息不存在")
re.PwdDecrypt()
if err := re.PwdDecrypt(); err != nil {
biz.ErrIsNil(err)
}
rc.ResData = re.Password
}

View File

@@ -80,7 +80,9 @@ func (r *redisAppImpl) Save(ctx context.Context, re *entity.Redis, tagIds ...uin
if err == nil {
return errorx.NewBiz("该实例已存在")
}
re.PwdEncrypt()
if errEnc := re.PwdEncrypt(); errEnc != nil {
return errorx.NewBiz(errEnc.Error())
}
resouceCode := stringx.Rand(16)
re.Code = resouceCode
@@ -108,7 +110,9 @@ func (r *redisAppImpl) Save(ctx context.Context, re *entity.Redis, tagIds ...uin
oldRedis, _ = r.GetById(new(entity.Redis), re.Id)
}
re.PwdEncrypt()
if errEnc := re.PwdEncrypt(); errEnc != nil {
return errorx.NewBiz(errEnc.Error())
}
return r.Tx(ctx, func(ctx context.Context) error {
return r.UpdateById(ctx, re)
}, func(ctx context.Context) error {
@@ -144,8 +148,9 @@ func (r *redisAppImpl) GetRedisConn(id uint64, db int) (*rdm.RedisConn, error) {
if err != nil {
return nil, errorx.NewBiz("redis信息不存在")
}
re.PwdDecrypt()
if err := re.PwdDecrypt(); err != nil {
return nil, errorx.NewBiz(err.Error())
}
return re.ToRedisInfo(db, r.tagApp.ListTagPathByResource(consts.TagResourceTypeRedis, re.Code)...), nil
})
}

View File

@@ -1,6 +1,7 @@
package entity
import (
"errors"
"mayfly-go/internal/common/utils"
"mayfly-go/internal/redis/rdm"
"mayfly-go/pkg/model"
@@ -21,20 +22,30 @@ type Redis struct {
Remark string
}
func (r *Redis) PwdEncrypt() {
func (r *Redis) PwdEncrypt() error {
// 密码替换为加密后的密码
r.Password = utils.PwdAesEncrypt(r.Password)
password, err := utils.PwdAesEncrypt(r.Password)
if err != nil {
return errors.New("加密 Redis 密码失败")
}
r.Password = password
return nil
}
func (r *Redis) PwdDecrypt() {
func (r *Redis) PwdDecrypt() error {
// 密码替换为解密后的密码
r.Password = utils.PwdAesDecrypt(r.Password)
password, err := utils.PwdAesDecrypt(r.Password)
if err != nil {
return errors.New("解密 Redis 密码失败")
}
r.Password = password
return nil
}
// 转换为redisInfo进行连接
func (re *Redis) ToRedisInfo(db int, tagPath ...string) *rdm.RedisInfo {
// ToRedisInfo 转换为redisInfo进行连接
func (r *Redis) ToRedisInfo(db int, tagPath ...string) *rdm.RedisInfo {
redisInfo := new(rdm.RedisInfo)
structx.Copy(redisInfo, re)
_ = structx.Copy(redisInfo, r)
redisInfo.Db = db
redisInfo.TagPath = tagPath
return redisInfo

View File

@@ -1,6 +1,7 @@
package entity
import (
"errors"
"mayfly-go/internal/common/utils"
"mayfly-go/pkg/model"
"time"
@@ -27,15 +28,25 @@ func (a *Account) IsEnable() bool {
return a.Status == AccountEnableStatus
}
func (a *Account) OtpSecretEncrypt() {
a.OtpSecret = utils.PwdAesEncrypt(a.OtpSecret)
func (a *Account) OtpSecretEncrypt() error {
secret, err := utils.PwdAesEncrypt(a.OtpSecret)
if err != nil {
return errors.New("加密账户密码失败")
}
a.OtpSecret = secret
return nil
}
func (a *Account) OtpSecretDecrypt() {
func (a *Account) OtpSecretDecrypt() error {
if a.OtpSecret == "-" {
return
return nil
}
a.OtpSecret = utils.PwdAesDecrypt(a.OtpSecret)
secret, err := utils.PwdAesDecrypt(a.OtpSecret)
if err != nil {
return errors.New("解密账户密码失败")
}
a.OtpSecret = secret
return nil
}
const (

View File

@@ -11,6 +11,9 @@ import (
// 基础repo接口
type Repo[T model.ModelI] interface {
// GetModel 获取表的模型实例
GetModel() T
// 新增一个实体
Insert(ctx context.Context, e T) error
@@ -24,10 +27,10 @@ type Repo[T model.ModelI] interface {
BatchInsertWithDb(ctx context.Context, db *gorm.DB, es []T) error
// 根据实体id更新实体信息
UpdateById(ctx context.Context, e T) error
UpdateById(ctx context.Context, e T, columns ...string) error
// 使用指定gorm db执行主要用于事务执行
UpdateByIdWithDb(ctx context.Context, db *gorm.DB, e T) error
UpdateByIdWithDb(ctx context.Context, db *gorm.DB, e T, columns ...string) error
// 根据实体主键删除实体
DeleteById(ctx context.Context, id uint64) error
@@ -101,16 +104,16 @@ func (br *RepoImpl[T]) BatchInsertWithDb(ctx context.Context, db *gorm.DB, es []
return gormx.BatchInsertWithDb(db, es)
}
func (br *RepoImpl[T]) UpdateById(ctx context.Context, e T) error {
func (br *RepoImpl[T]) UpdateById(ctx context.Context, e T, columns ...string) error {
if db := contextx.GetDb(ctx); db != nil {
return br.UpdateByIdWithDb(ctx, db, e)
return br.UpdateByIdWithDb(ctx, db, e, columns...)
}
return gormx.UpdateById(br.setBaseInfo(ctx, e))
return gormx.UpdateById(br.setBaseInfo(ctx, e), columns...)
}
func (br *RepoImpl[T]) UpdateByIdWithDb(ctx context.Context, db *gorm.DB, e T) error {
return gormx.UpdateByIdWithDb(db, br.setBaseInfo(ctx, e))
func (br *RepoImpl[T]) UpdateByIdWithDb(ctx context.Context, db *gorm.DB, e T, columns ...string) error {
return gormx.UpdateByIdWithDb(db, br.setBaseInfo(ctx, e), columns...)
}
func (br *RepoImpl[T]) Updates(cond any, udpateFields map[string]any) error {

View File

@@ -17,7 +17,7 @@ import (
func ErrIsNil(err error, msgAndParams ...any) {
if err != nil {
if len(msgAndParams) == 0 {
panic(err)
panic(errorx.NewBiz(err.Error()))
}
panic(errorx.NewBiz(fmt.Sprintf(msgAndParams[0].(string), msgAndParams[1:]...)))

View File

@@ -20,11 +20,11 @@ func (a *Aes) DecryptBase64(data string) ([]byte, error) {
return cryptox.AesDecryptBase64(data, []byte(a.Key))
}
func (j *Aes) Valid() {
if j.Key == "" {
func (a *Aes) Valid() {
if a.Key == "" {
return
}
aesKeyLen := len(j.Key)
aesKeyLen := len(a.Key)
assert.IsTrue(aesKeyLen == 16 || aesKeyLen == 24 || aesKeyLen == 32,
fmt.Sprintf("config.yml之 [aes.key] 长度需为16、24、32位长度, 当前为%d位", aesKeyLen))
}

View File

@@ -146,12 +146,12 @@ func BatchInsertWithDb[T any](db *gorm.DB, models []T) error {
// 根据id更新model更新字段为model中不为空的值即int类型不为0ptr类型不为nil这类字段值
// @param model 数据库映射实体模型
func UpdateById(model any) error {
return UpdateByIdWithDb(global.Db, model)
func UpdateById(model any, columns ...string) error {
return UpdateByIdWithDb(global.Db, model, columns...)
}
func UpdateByIdWithDb(db *gorm.DB, model any) error {
return db.Model(model).Updates(model).Error
func UpdateByIdWithDb(db *gorm.DB, model any, columns ...string) error {
return db.Model(model).Select(columns).Updates(model).Error
}
// 根据实体条件更新参数udpateFields指定字段

View File

@@ -42,6 +42,28 @@ func NewDelayQueue[T Delayable](cap int) *DelayQueue[T] {
}
}
func (s *DelayQueue[T]) TryDequeue() (T, bool) {
s.mutex.Lock()
defer s.mutex.Unlock()
if elm, ok := s.priorityQueue.Peek(0); ok {
delay := elm.GetDeadline().Sub(time.Now())
if delay < minTimerDelay {
// 无需延迟,头部元素出队后直接返回
_, _ = s.dequeue()
return elm, true
}
}
return s.zero, false
}
func (s *DelayQueue[T]) TryEnqueue(val T) bool {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.enqueue(val)
}
func (s *DelayQueue[T]) Dequeue(ctx context.Context) (T, bool) {
// 出队锁:避免因重复获取队列头部同一元素降低性能
select {
@@ -64,7 +86,6 @@ func (s *DelayQueue[T]) Dequeue(ctx context.Context) (T, bool) {
// 接收直接转发的不需要延迟的新元素
select {
case elm := <-s.transferChan:
delete(s.elmMap, elm.GetId())
s.mutex.Unlock()
return elm, true
default:
@@ -78,7 +99,6 @@ func (s *DelayQueue[T]) Dequeue(ctx context.Context) (T, bool) {
if delay < minTimerDelay {
// 无需延迟,头部元素出队后直接返回
_, _ = s.dequeue()
delete(s.elmMap, elm.GetId())
s.mutex.Unlock()
return elm, ok
}
@@ -122,6 +142,7 @@ func (s *DelayQueue[T]) dequeue() (T, bool) {
if !ok {
return s.zero, false
}
delete(s.elmMap, elm.GetId())
select {
case s.dequeuedSignal <- struct{}{}:
default:
@@ -133,6 +154,7 @@ func (s *DelayQueue[T]) enqueue(val T) bool {
if ok := s.priorityQueue.Enqueue(val); !ok {
return false
}
s.elmMap[val.GetId()] = val
select {
case s.enqueuedSignal <- struct{}{}:
default:
@@ -156,7 +178,6 @@ func (s *DelayQueue[T]) Enqueue(ctx context.Context, val T) bool {
// 如果队列未满,入队后直接返回
if !s.priorityQueue.IsFull() {
s.elmMap[val.GetId()] = val
s.enqueue(val)
s.mutex.Unlock()
return true

View File

@@ -10,7 +10,6 @@ import (
"mayfly-go/pkg/validatorx"
"os"
"os/signal"
"sync"
"syscall"
)
@@ -25,8 +24,6 @@ func RunWebServer() {
cancel()
}()
runnerWG := &sync.WaitGroup{}
// 初始化config.yml配置文件映射信息或使用环境变量。并初始化系统日志相关配置
config.Init()
@@ -52,6 +49,4 @@ func RunWebServer() {
// 运行web服务
runWebServer(ctx)
runnerWG.Wait()
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"github.com/gin-gonic/gin"
"mayfly-go/initialize"
"mayfly-go/internal/db/application"
"mayfly-go/pkg/config"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/req"
@@ -40,6 +41,8 @@ func runWebServer(ctx context.Context) {
if err != nil {
logx.Errorf("Failed to Shutdown HTTP Server: %v", err)
}
closeDbTasks()
// todo: close backupApp and restoreApp
}()
confSrv := config.Conf.Server
@@ -56,3 +59,18 @@ func runWebServer(ctx context.Context) {
logx.Errorf("Failed to Start HTTP Server: %v", err)
}
}
func closeDbTasks() {
restoreApp := application.GetDbRestoreApp()
if restoreApp != nil {
restoreApp.Close()
}
binlogApp := application.GetDbBinlogApp()
if binlogApp != nil {
binlogApp.Close()
}
backupApp := application.GetDbBackupApp()
if backupApp != nil {
backupApp.Close()
}
}

View File

@@ -120,3 +120,16 @@ func ToString(value any) string {
return string(newValue)
}
}
// DeepZero 初始化对象
// 如 T 为基本类型或结构体,则返回零值
// 如 T 为指向基本类型或结构体的指针,则返回指向零值的指针
func DeepZero[T any]() T {
var data T
typ := reflect.TypeOf(data)
kind := typ.Kind()
if kind == reflect.Pointer {
return reflect.New(typ.Elem()).Interface().(T)
}
return data
}

View File

@@ -0,0 +1,14 @@
package anyx
import (
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestDeepZero(t *testing.T) {
assert.Zero(t, DeepZero[int]())
assert.Zero(t, *DeepZero[*int]())
assert.Zero(t, DeepZero[time.Time]())
assert.Zero(t, *DeepZero[*time.Time]())
}

View File

@@ -272,5 +272,9 @@ func pkcs7UnPadding(data []byte) ([]byte, error) {
}
//获取填充的个数
unPadding := int(data[length-1])
// todo fix: slice bounds out of range
if unPadding > length {
return nil, errors.New("解密字符串时去除填充个数超出字符串长度")
}
return data[:(length - unPadding)], nil
}

View File

@@ -115,3 +115,17 @@ func ReverStrTemplate(temp, str string, res map[string]any) {
ReverStrTemplate(next, Trim(SubString(str, UnicodeIndex(str, value)+Len(value), Len(str))), res)
}
}
func TruncateStr(s string, length int) string {
if length >= len(s) {
return s
}
var last int
for i := range s {
if i > length {
break
}
last = i
}
return s[:last]
}

View File

@@ -0,0 +1,32 @@
package stringx
import (
"github.com/stretchr/testify/require"
"strconv"
"testing"
)
func TestTruncateStr(t *testing.T) {
testCases := []struct {
data string
length int
want string
}{
{"123一二三", 0, ""},
{"123一二三", 1, "1"},
{"123一二三", 3, "123"},
{"123一二三", 4, "123"},
{"123一二三", 5, "123"},
{"123一二三", 6, "123一"},
{"123一二三", 7, "123一"},
{"123一二三", 11, "123一二"},
{"123一二三", 12, "123一二三"},
{"123一二三", 13, "123一二三"},
}
for _, tc := range testCases {
t.Run(strconv.Itoa(tc.length), func(t *testing.T) {
got := TruncateStr(tc.data, tc.length)
require.Equal(t, tc.want, got)
})
}
}

View File

@@ -1,9 +1,60 @@
package timex
import "time"
import (
"context"
"database/sql"
"encoding/json"
"time"
)
const DefaultDateTimeFormat = "2006-01-02 15:04:05"
func DefaultFormat(time time.Time) string {
return time.Format(DefaultDateTimeFormat)
}
func NewNullTime(t time.Time) NullTime {
return NullTime{
NullTime: sql.NullTime{
Time: t,
Valid: !t.IsZero(),
},
}
}
type NullTime struct {
sql.NullTime
}
func (nt *NullTime) UnmarshalJSON(bytes []byte) error {
if len(bytes) == 0 {
nt.NullTime = sql.NullTime{}
return nil
}
var t time.Time
if err := json.Unmarshal(bytes, &t); err != nil {
return err
}
if t.IsZero() {
nt.NullTime = sql.NullTime{}
return nil
}
nt.NullTime = sql.NullTime{
Valid: true,
Time: t,
}
return nil
}
func (nt *NullTime) MarshalJSON() ([]byte, error) {
if !nt.Valid || nt.Time.IsZero() {
return json.Marshal(nil)
}
return json.Marshal(nt.Time)
}
func SleepWithContext(ctx context.Context, d time.Duration) {
ctx, cancel := context.WithTimeout(ctx, d)
<-ctx.Done()
cancel()
}

View File

@@ -0,0 +1,77 @@
package timex
import (
"github.com/stretchr/testify/require"
"testing"
"time"
)
func TestNullTime_UnmarshalJSON(t *testing.T) {
zero := time.Time{}
now := time.Now()
bytesNow, err := now.MarshalJSON()
require.NoError(t, err)
tests := []struct {
name string
want NullTime
bytes []byte
wantErr error
}{
{
name: "zero",
want: NewNullTime(zero),
bytes: nil,
},
{
name: "now",
want: NewNullTime(now),
bytes: bytesNow,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := &NullTime{}
err := got.UnmarshalJSON(tt.bytes)
require.ErrorIs(t, err, tt.wantErr)
if err != nil {
return
}
require.Equal(t, tt.want.Valid, got.Valid)
require.True(t, got.Time.Equal(tt.want.Time))
})
}
}
func TestNullTime_MarshalJSON(t *testing.T) {
zero := time.Time{}
now := time.Now()
bytes, err := now.MarshalJSON()
require.NoError(t, err)
tests := []struct {
name string
nullTime NullTime
want []byte
wantErr error
}{
{
name: "zero",
nullTime: NewNullTime(zero),
want: []byte("null"),
},
{
name: "now",
nullTime: NewNullTime(now),
want: bytes,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.nullTime.MarshalJSON()
require.ErrorIs(t, err, tt.wantErr)
if err != nil {
return
}
require.Equal(t, tt.want, got)
})
}
}

View File

@@ -479,6 +479,8 @@ INSERT INTO `t_sys_config` (name, `key`, params, value, remark, permission, crea
INSERT INTO `t_sys_config` (name, `key`, params, value, remark, create_time, creator_id, creator, update_time, modifier_id, modifier)VALUES ('数据库查询最大结果集', 'DbQueryMaxCount', '[]', '200', '允许sql查询的最大结果集数。注: 0=不限制', '2023-02-11 14:29:03', 1, 'admin', '2023-02-11 14:40:56', 1, 'admin');
INSERT INTO `t_sys_config` (name, `key`, params, value, remark, create_time, creator_id, creator, update_time, modifier_id, modifier)VALUES ('数据库是否记录查询SQL', 'DbSaveQuerySQL', '[]', '0', '1: 记录、0:不记录', '2023-02-11 16:07:14', 1, 'admin', '2023-02-11 16:44:17', 1, 'admin');
INSERT INTO `t_sys_config` (name, `key`, params, value, remark, permission, create_time, creator_id, creator, update_time, modifier_id, modifier, is_deleted, delete_time) VALUES('机器相关配置', 'MachineConfig', '[{"name":"终端回放存储路径","model":"terminalRecPath","placeholder":"终端回放存储路径"},{"name":"uploadMaxFileSize","model":"uploadMaxFileSize","placeholder":"允许上传的最大文件大小(1MB\\\\2GB等)"}]', '{"terminalRecPath":"./rec","uploadMaxFileSize":"1GB"}', '机器相关配置,如终端回放路径等', 'admin,', '2023-07-13 16:26:44', 1, 'admin', '2023-11-09 22:01:31', 1, 'admin', 0, NULL);
INSERT INTO `t_sys_config` (`name`, `key`, `params`, `value`, `remark`, `permission`, `create_time`, `creator_id`, `creator`, `update_time`, `modifier_id`, `modifier`, `is_deleted`, `delete_time`) VALUES('Mysql可执行文件', 'MysqlBin', '[{"model":"path","name":"路径","placeholder":"可执行文件路径","required":true},{"model":"mysql","name":"mysql","placeholder":"mysql命令路径(空则为 路径/mysql)","required":false},{"model":"mysqldump","name":"mysqldump","placeholder":"mysqldump命令路径(空则为 路径/mysqldump)","required":false},{"model":"mysqlbinlog","name":"mysqlbinlog","placeholder":"mysqlbinlog命令路径(空则为 路径/mysqlbinlog)","required":false}]', '{"mysql":"","mysqldump":"","mysqlbinlog":"","path":""}', '', 'admin,', '2023-12-29 10:01:33', 1, 'admin', '2023-12-29 13:34:40', 1, 'admin', 0, NULL);
INSERT INTO `t_sys_config` (`name`, `key`, `params`, `value`, `remark`, `permission`, `create_time`, `creator_id`, `creator`, `update_time`, `modifier_id`, `modifier`, `is_deleted`, `delete_time`) VALUES('数据库备份恢复', 'DbBackupRestore', '[{"model":"backupPath","name":"备份路径","placeholder":"备份文件存储路径"}]', '{"backupPath":"./db/backup"}', '', 'admin,', '2023-12-29 09:55:26', 1, 'admin', '2023-12-29 15:45:24', 1, 'admin', 0, NULL);
COMMIT;
-- ----------------------------
@@ -865,7 +867,6 @@ CREATE TABLE `t_db_backup` (
`interval` bigint(20) DEFAULT NULL COMMENT '备份周期',
`start_time` datetime DEFAULT NULL COMMENT '首次备份时间',
`enabled` tinyint(1) DEFAULT NULL COMMENT '是否启用',
`finished` tinyint(1) DEFAULT NULL COMMENT '是否完成',
`last_status` tinyint(4) DEFAULT NULL COMMENT '上次备份状态',
`last_result` varchar(256) DEFAULT NULL COMMENT '上次备份结果',
`last_time` datetime DEFAULT NULL COMMENT '上次备份时间',
@@ -917,7 +918,6 @@ CREATE TABLE `t_db_restore` (
`interval` bigint(20) DEFAULT NULL COMMENT '恢复周期',
`start_time` datetime DEFAULT NULL COMMENT '首次恢复时间',
`enabled` tinyint(1) DEFAULT NULL COMMENT '是否启用',
`finished` tinyint(1) DEFAULT NULL COMMENT '是否完成',
`last_status` tinyint(4) DEFAULT NULL COMMENT '上次恢复状态',
`last_result` varchar(256) DEFAULT NULL COMMENT '上次恢复结果',
`last_time` datetime DEFAULT NULL COMMENT '上次恢复时间',
@@ -959,9 +959,6 @@ DROP TABLE IF EXISTS `t_db_binlog`;
CREATE TABLE `t_db_binlog` (
`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT,
`db_instance_id` bigint(20) unsigned NOT NULL COMMENT '数据库实例ID',
`interval` bigint(20) DEFAULT NULL COMMENT '下载周期',
`start_time` datetime DEFAULT NULL COMMENT '首次下载时间',
`enabled` tinyint(1) DEFAULT NULL COMMENT '会否启用',
`last_status` bigint(20) DEFAULT NULL COMMENT '上次下载状态',
`last_result` varchar(256) DEFAULT NULL COMMENT '上次下载结果',
`last_time` datetime DEFAULT NULL COMMENT '上次下载时间',

View File

@@ -11,7 +11,6 @@ CREATE TABLE `t_db_backup` (
`interval` bigint(20) DEFAULT NULL COMMENT '备份周期',
`start_time` datetime DEFAULT NULL COMMENT '首次备份时间',
`enabled` tinyint(1) DEFAULT NULL COMMENT '是否启用',
`finished` tinyint(1) DEFAULT NULL COMMENT '是否完成',
`last_status` tinyint(4) DEFAULT NULL COMMENT '上次备份状态',
`last_result` varchar(256) DEFAULT NULL COMMENT '上次备份结果',
`last_time` datetime DEFAULT NULL COMMENT '上次备份时间',
@@ -63,7 +62,6 @@ CREATE TABLE `t_db_restore` (
`interval` bigint(20) DEFAULT NULL COMMENT '恢复周期',
`start_time` datetime DEFAULT NULL COMMENT '首次恢复时间',
`enabled` tinyint(1) DEFAULT NULL COMMENT '是否启用',
`finished` tinyint(1) DEFAULT NULL COMMENT '是否完成',
`last_status` tinyint(4) DEFAULT NULL COMMENT '上次恢复状态',
`last_result` varchar(256) DEFAULT NULL COMMENT '上次恢复结果',
`last_time` datetime DEFAULT NULL COMMENT '上次恢复时间',
@@ -105,9 +103,6 @@ DROP TABLE IF EXISTS `t_db_binlog`;
CREATE TABLE `t_db_binlog` (
`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT,
`db_instance_id` bigint(20) unsigned NOT NULL COMMENT '数据库实例ID',
`interval` bigint(20) DEFAULT NULL COMMENT '下载周期',
`start_time` datetime DEFAULT NULL COMMENT '首次下载时间',
`enabled` tinyint(1) DEFAULT NULL COMMENT '会否启用',
`last_status` bigint(20) DEFAULT NULL COMMENT '上次下载状态',
`last_result` varchar(256) DEFAULT NULL COMMENT '上次下载结果',
`last_time` datetime DEFAULT NULL COMMENT '上次下载时间',
@@ -140,3 +135,7 @@ CREATE TABLE `t_db_binlog_history` (
PRIMARY KEY (`id`),
KEY `idx_db_instance_id` (`db_instance_id`) USING BTREE
) ENGINE=InnoDB AUTO_INCREMENT=17 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;
INSERT INTO `t_sys_config` (`name`, `key`, `params`, `value`, `remark`, `permission`, `create_time`, `creator_id`, `creator`, `update_time`, `modifier_id`, `modifier`, `is_deleted`, `delete_time`) VALUES('Mysql可执行文件', 'MysqlBin', '[{"model":"path","name":"路径","placeholder":"可执行文件路径","required":true},{"model":"mysql","name":"mysql","placeholder":"mysql命令路径(空则为 路径/mysql)","required":false},{"model":"mysqldump","name":"mysqldump","placeholder":"mysqldump命令路径(空则为 路径/mysqldump)","required":false},{"model":"mysqlbinlog","name":"mysqlbinlog","placeholder":"mysqlbinlog命令路径(空则为 路径/mysqlbinlog)","required":false}]', '{"mysql":"","mysqldump":"","mysqlbinlog":"","path":""}', '', 'admin,', '2023-12-29 10:01:33', 1, 'admin', '2023-12-29 13:34:40', 1, 'admin', 0, NULL);
INSERT INTO `t_sys_config` (`name`, `key`, `params`, `value`, `remark`, `permission`, `create_time`, `creator_id`, `creator`, `update_time`, `modifier_id`, `modifier`, `is_deleted`, `delete_time`) VALUES('数据库备份恢复', 'DbBackupRestore', '[{"model":"backupPath","name":"备份路径","placeholder":"备份文件存储路径"}]', '{"backupPath":"./db/backup"}', '', 'admin,', '2023-12-29 09:55:26', 1, 'admin', '2023-12-29 15:45:24', 1, 'admin', 0, NULL);