mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-02 23:40:24 +08:00
161 lines
3.5 KiB
Go
161 lines
3.5 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"mayfly-go/internal/db/domain/entity"
|
|
"mayfly-go/pkg/logx"
|
|
"mayfly-go/pkg/queue"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
type Scheduler[T entity.DbTask] struct {
|
|
mutex sync.Mutex
|
|
wg sync.WaitGroup
|
|
queue *queue.DelayQueue[T]
|
|
closed bool
|
|
curTask T
|
|
curTaskContext context.Context
|
|
curTaskCancel context.CancelFunc
|
|
UpdateTaskStatus func(ctx context.Context, status entity.TaskStatus, lastErr error, task T) error
|
|
RunTask func(ctx context.Context, task T) error
|
|
}
|
|
|
|
// SchedulerOption customizes a Scheduler while it is being constructed;
// NewScheduler invokes every option on the new scheduler before validating it.
type SchedulerOption[T entity.DbTask] func(*Scheduler[T])
|
|
|
|
func NewScheduler[T entity.DbTask](opts ...SchedulerOption[T]) (*Scheduler[T], error) {
|
|
scheduler := &Scheduler[T]{
|
|
queue: queue.NewDelayQueue[T](0),
|
|
}
|
|
for _, opt := range opts {
|
|
opt(scheduler)
|
|
}
|
|
if scheduler.RunTask == nil || scheduler.UpdateTaskStatus == nil {
|
|
return nil, errors.New("调度器没有设置 RunTask 或 UpdateTaskStatus")
|
|
}
|
|
scheduler.wg.Add(1)
|
|
go scheduler.run()
|
|
return scheduler, nil
|
|
}
|
|
|
|
func (m *Scheduler[T]) PushTask(ctx context.Context, task T) bool {
|
|
if !task.Schedule() {
|
|
return false
|
|
}
|
|
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
return m.queue.Enqueue(ctx, task)
|
|
}
|
|
|
|
func (m *Scheduler[T]) UpdateTask(ctx context.Context, task T) bool {
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
if task.GetId() == m.curTask.GetId() {
|
|
return m.curTask.Update(task)
|
|
}
|
|
oldTask, ok := m.queue.Remove(ctx, task.GetId())
|
|
if ok {
|
|
if !oldTask.Update(task) {
|
|
return false
|
|
}
|
|
} else {
|
|
oldTask = task
|
|
}
|
|
if !oldTask.Schedule() {
|
|
return false
|
|
}
|
|
return m.queue.Enqueue(ctx, oldTask)
|
|
}
|
|
|
|
func (m *Scheduler[T]) updateCurTask(status entity.TaskStatus, lastErr error, task T) bool {
|
|
seconds := []time.Duration{time.Second * 1, time.Second * 8, time.Second * 64}
|
|
for _, second := range seconds {
|
|
if m.closed {
|
|
return false
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), second)
|
|
err := m.UpdateTaskStatus(ctx, status, lastErr, task)
|
|
cancel()
|
|
if err != nil {
|
|
logx.Errorf("保存任务失败: %v", err)
|
|
time.Sleep(second)
|
|
continue
|
|
}
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (m *Scheduler[T]) run() {
|
|
defer m.wg.Done()
|
|
|
|
var ctx context.Context
|
|
var cancel context.CancelFunc
|
|
for !m.closed {
|
|
m.mutex.Lock()
|
|
ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond)
|
|
task, ok := m.queue.Dequeue(ctx)
|
|
cancel()
|
|
if !ok {
|
|
m.mutex.Unlock()
|
|
time.Sleep(time.Second)
|
|
continue
|
|
}
|
|
m.curTask = task
|
|
m.updateCurTask(entity.TaskReserved, nil, task)
|
|
m.curTaskContext, m.curTaskCancel = context.WithCancel(context.Background())
|
|
m.mutex.Unlock()
|
|
|
|
err := m.RunTask(m.curTaskContext, task)
|
|
|
|
m.mutex.Lock()
|
|
taskStatus := entity.TaskSuccess
|
|
if err != nil {
|
|
taskStatus = entity.TaskFailed
|
|
}
|
|
m.updateCurTask(taskStatus, err, task)
|
|
m.cancelCurTask()
|
|
task.Schedule()
|
|
if !task.IsFinished() {
|
|
ctx, cancel = context.WithTimeout(context.Background(), time.Second)
|
|
m.queue.Enqueue(ctx, task)
|
|
cancel()
|
|
}
|
|
m.mutex.Unlock()
|
|
}
|
|
}
|
|
|
|
func (m *Scheduler[T]) Close() {
|
|
if m.closed {
|
|
return
|
|
}
|
|
m.mutex.Lock()
|
|
m.cancelCurTask()
|
|
m.closed = true
|
|
m.mutex.Unlock()
|
|
|
|
m.wg.Wait()
|
|
}
|
|
|
|
func (m *Scheduler[T]) RemoveTask(taskId uint64) bool {
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
m.queue.Remove(context.Background(), taskId)
|
|
if taskId == m.curTask.GetId() {
|
|
m.cancelCurTask()
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (m *Scheduler[T]) cancelCurTask() {
|
|
if m.curTaskCancel != nil {
|
|
m.curTaskCancel()
|
|
m.curTaskCancel = nil
|
|
}
|
|
}
|