mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-02 23:40:24 +08:00
* fix: 保存 LastResult 时截断字符串过长部分,以避免数据库报错 * refactor: 新增 entity.DbTaskBase 和 persistence.dbTaskBase, 用于实现数据库备份和恢复任务处理相关部分 * fix: aeskey变更后,解密密码出现数组越界访问错误 * fix: 时间属性为零值时,保存到 mysql 数据库报错 * refactor db.infrastructure.service.scheduler * feat: 实现立即备份功能 * refactor db.infrastructure.service.db_instance * refactor: 从数据库中获取数据库备份目录、mysql文件路径等配置信息 * fix: 数据库备份和恢复问题 * fix: 修改 .gitignore 文件,忽略数据库备份目录和数据库程序目录
236 lines
5.2 KiB
Go
236 lines
5.2 KiB
Go
package application
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"mayfly-go/internal/db/domain/entity"
|
|
"mayfly-go/internal/db/domain/repository"
|
|
"mayfly-go/pkg/queue"
|
|
"mayfly-go/pkg/utils/anyx"
|
|
"mayfly-go/pkg/utils/stringx"
|
|
"mayfly-go/pkg/utils/timex"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// sleepAfterError is how long the worker loop backs off after it fails to
// persist a task status.
const sleepAfterError = 60 * time.Second
|
|
|
|
type dbScheduler[T entity.DbTask] struct {
|
|
mutex sync.Mutex
|
|
waitGroup sync.WaitGroup
|
|
queue *queue.DelayQueue[T]
|
|
context context.Context
|
|
cancel context.CancelFunc
|
|
RunTask func(ctx context.Context, task T) error
|
|
taskRepo repository.DbTask[T]
|
|
}
|
|
|
|
// dbSchedulerOption customises a scheduler inside newDbScheduler before it is
// validated (for example, to set RunTask).
type dbSchedulerOption[T entity.DbTask] func(s *dbScheduler[T])
|
|
|
|
func newDbScheduler[T entity.DbTask](taskRepo repository.DbTask[T], opts ...dbSchedulerOption[T]) (*dbScheduler[T], error) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
scheduler := &dbScheduler[T]{
|
|
taskRepo: taskRepo,
|
|
queue: queue.NewDelayQueue[T](0),
|
|
context: ctx,
|
|
cancel: cancel,
|
|
}
|
|
for _, opt := range opts {
|
|
opt(scheduler)
|
|
}
|
|
if scheduler.RunTask == nil {
|
|
return nil, errors.New("数据库任务调度器没有设置 RunTask")
|
|
}
|
|
if err := scheduler.loadTask(context.Background()); err != nil {
|
|
return nil, err
|
|
}
|
|
scheduler.waitGroup.Add(1)
|
|
go scheduler.run()
|
|
return scheduler, nil
|
|
}
|
|
|
|
func (s *dbScheduler[T]) updateTaskStatus(ctx context.Context, status entity.TaskStatus, lastErr error, task T) error {
|
|
base := task.GetTaskBase()
|
|
base.LastStatus = status
|
|
var result = task.MessageWithStatus(status)
|
|
if lastErr != nil {
|
|
result = fmt.Sprintf("%v: %v", result, lastErr)
|
|
}
|
|
base.LastResult = stringx.TruncateStr(result, entity.LastResultSize)
|
|
base.LastTime = timex.NewNullTime(time.Now())
|
|
return s.taskRepo.UpdateTaskStatus(ctx, task)
|
|
}
|
|
|
|
func (s *dbScheduler[T]) UpdateTask(ctx context.Context, task T) error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if err := s.taskRepo.UpdateById(ctx, task); err != nil {
|
|
return err
|
|
}
|
|
|
|
oldTask, ok := s.queue.Remove(ctx, task.GetId())
|
|
if !ok {
|
|
return errors.New("任务不存在")
|
|
}
|
|
oldTask.Update(task)
|
|
if !oldTask.Schedule() {
|
|
return nil
|
|
}
|
|
if !s.queue.Enqueue(ctx, oldTask) {
|
|
return errors.New("任务入队失败")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *dbScheduler[T]) run() {
|
|
defer s.waitGroup.Done()
|
|
|
|
for !s.closed() {
|
|
time.Sleep(time.Second)
|
|
|
|
s.mutex.Lock()
|
|
task, ok := s.queue.TryDequeue()
|
|
if !ok {
|
|
s.mutex.Unlock()
|
|
continue
|
|
}
|
|
if err := s.updateTaskStatus(s.context, entity.TaskReserved, nil, task); err != nil {
|
|
s.mutex.Unlock()
|
|
timex.SleepWithContext(s.context, sleepAfterError)
|
|
continue
|
|
}
|
|
s.mutex.Unlock()
|
|
|
|
errRun := s.RunTask(s.context, task)
|
|
taskStatus := entity.TaskSuccess
|
|
if errRun != nil {
|
|
taskStatus = entity.TaskFailed
|
|
}
|
|
s.mutex.Lock()
|
|
if err := s.updateTaskStatus(s.context, taskStatus, errRun, task); err != nil {
|
|
s.mutex.Unlock()
|
|
timex.SleepWithContext(s.context, sleepAfterError)
|
|
continue
|
|
}
|
|
task.Schedule()
|
|
if !task.IsFinished() {
|
|
s.queue.Enqueue(s.context, task)
|
|
}
|
|
s.mutex.Unlock()
|
|
}
|
|
}
|
|
|
|
func (s *dbScheduler[T]) Close() {
|
|
s.cancel()
|
|
s.waitGroup.Wait()
|
|
}
|
|
|
|
func (s *dbScheduler[T]) closed() bool {
|
|
return s.context.Err() != nil
|
|
}
|
|
|
|
func (s *dbScheduler[T]) loadTask(ctx context.Context) error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
tasks, err := s.taskRepo.ListToDo()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, task := range tasks {
|
|
if !task.Schedule() {
|
|
continue
|
|
}
|
|
s.queue.Enqueue(ctx, task)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *dbScheduler[T]) AddTask(ctx context.Context, tasks ...T) error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
for _, task := range tasks {
|
|
if err := s.taskRepo.AddTask(ctx, task); err != nil {
|
|
return err
|
|
}
|
|
if !task.Schedule() {
|
|
continue
|
|
}
|
|
s.queue.Enqueue(ctx, task)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *dbScheduler[T]) DeleteTask(ctx context.Context, taskId uint64) error {
|
|
// todo: 删除数据库备份历史文件
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if err := s.taskRepo.DeleteById(ctx, taskId); err != nil {
|
|
return err
|
|
}
|
|
s.queue.Remove(ctx, taskId)
|
|
return nil
|
|
}
|
|
|
|
func (s *dbScheduler[T]) EnableTask(ctx context.Context, taskId uint64) error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
task := anyx.DeepZero[T]()
|
|
if err := s.taskRepo.GetById(task, taskId); err != nil {
|
|
return err
|
|
}
|
|
if task.IsEnabled() {
|
|
return nil
|
|
}
|
|
task.GetTaskBase().Enabled = true
|
|
if err := s.taskRepo.UpdateEnabled(ctx, taskId, true); err != nil {
|
|
return err
|
|
}
|
|
s.queue.Remove(ctx, taskId)
|
|
if !task.Schedule() {
|
|
return nil
|
|
}
|
|
s.queue.Enqueue(ctx, task)
|
|
return nil
|
|
}
|
|
|
|
func (s *dbScheduler[T]) DisableTask(ctx context.Context, taskId uint64) error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
task := anyx.DeepZero[T]()
|
|
if err := s.taskRepo.GetById(task, taskId); err != nil {
|
|
return err
|
|
}
|
|
if !task.IsEnabled() {
|
|
return nil
|
|
}
|
|
if err := s.taskRepo.UpdateEnabled(ctx, taskId, false); err != nil {
|
|
return err
|
|
}
|
|
s.queue.Remove(ctx, taskId)
|
|
return nil
|
|
}
|
|
|
|
func (s *dbScheduler[T]) StartTask(ctx context.Context, taskId uint64) error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
task := anyx.DeepZero[T]()
|
|
if err := s.taskRepo.GetById(task, taskId); err != nil {
|
|
return err
|
|
}
|
|
if !task.IsEnabled() {
|
|
return errors.New("任务未启用")
|
|
}
|
|
s.queue.Remove(ctx, taskId)
|
|
task.GetTaskBase().Deadline = time.Now()
|
|
s.queue.Enqueue(ctx, task)
|
|
return nil
|
|
}
|