Files
mayfly-go/server/internal/db/infrastructure/service/scheduler.go
2023-12-29 08:30:10 +08:00

161 lines
3.5 KiB
Go

package service
import (
"context"
"errors"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/queue"
"sync"
"time"
)
// Scheduler runs entity.DbTask values one at a time, taking them from a delay
// queue in a single background goroutine (run) started by NewScheduler.
// Tasks are added with PushTask, changed with UpdateTask and dropped with
// RemoveTask; status transitions are persisted through UpdateTaskStatus.
type Scheduler[T entity.DbTask] struct {
// mutex is held around queue operations and around reads/writes of curTask,
// curTaskCancel and closed.
mutex sync.Mutex
// wg tracks the run goroutine so Close can wait for it to exit.
wg sync.WaitGroup
// queue holds the tasks waiting to be run.
queue *queue.DelayQueue[T]
// closed is set by Close; run stops looping once it is true.
closed bool
// curTask is the task most recently dequeued by run. It is not cleared when
// that task finishes.
curTask T
// curTaskContext is the context passed to RunTask for curTask.
curTaskContext context.Context
// curTaskCancel cancels curTaskContext. It is nil before the first task runs,
// after a run completes, and after cancelCurTask has been called.
curTaskCancel context.CancelFunc
// UpdateTaskStatus persists a status change for task. run calls it (via
// updateCurTask) with TaskReserved before running a task and with
// TaskSuccess/TaskFailed afterwards. Required: NewScheduler fails if nil.
UpdateTaskStatus func(ctx context.Context, status entity.TaskStatus, lastErr error, task T) error
// RunTask executes task. ctx is cancelled by RemoveTask (for that task) and by
// Close. Required: NewScheduler fails if nil.
RunTask func(ctx context.Context, task T) error
}
// SchedulerOption configures a Scheduler inside NewScheduler. Options are
// applied in order before the required callbacks are validated and before the
// run goroutine is started.
type SchedulerOption[T entity.DbTask] func(*Scheduler[T])
// NewScheduler creates a Scheduler backed by a fresh delay queue, applies opts
// in the order given, and then checks that both RunTask and UpdateTaskStatus
// were set. On success it starts the background run goroutine (tracked by wg
// so Close can wait for it) and returns the scheduler; otherwise it returns an
// error and no goroutine is started.
func NewScheduler[T entity.DbTask](opts ...SchedulerOption[T]) (*Scheduler[T], error) {
	s := new(Scheduler[T])
	s.queue = queue.NewDelayQueue[T](0)

	for i := range opts {
		apply := opts[i]
		apply(s)
	}

	missingCallback := s.RunTask == nil || s.UpdateTaskStatus == nil
	if missingCallback {
		return nil, errors.New("调度器没有设置 RunTask 或 UpdateTaskStatus")
	}

	s.wg.Add(1)
	go s.run()
	return s, nil
}
// PushTask asks task to compute its next run via Schedule (outside the lock).
// If Schedule reports true, the task is enqueued under the scheduler mutex and
// the result of Enqueue is returned; otherwise PushTask returns false without
// touching the queue.
func (m *Scheduler[T]) PushTask(ctx context.Context, task T) bool {
	scheduled := task.Schedule()
	if scheduled {
		m.mutex.Lock()
		defer m.mutex.Unlock()
		return m.queue.Enqueue(ctx, task)
	}
	return false
}
// UpdateTask applies task's new settings to the scheduled task with the same id.
//
// If that id belongs to the task currently executing (and not yet cancelled),
// the running instance is updated in place via Update; run re-schedules it
// when it finishes. Otherwise any queued copy is removed from the queue,
// updated, re-scheduled and re-enqueued so its queue position reflects the new
// settings; if no copy is queued, task itself is scheduled and enqueued.
//
// It returns false when Update or Schedule returns false, or when Enqueue
// returns false.
func (m *Scheduler[T]) UpdateTask(ctx context.Context, task T) bool {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	// curTaskCancel is non-nil only while a task is executing and has not been
	// cancelled (cancelCurTask clears it, and run clears it after each run).
	// Checking it first prevents a stale curTask from a finished run from
	// swallowing the update without re-enqueueing, and avoids calling GetId on
	// a zero-valued T before any task has run.
	if m.curTaskCancel != nil && task.GetId() == m.curTask.GetId() {
		return m.curTask.Update(task)
	}

	oldTask, ok := m.queue.Remove(ctx, task.GetId())
	if !ok {
		oldTask = task
	} else if !oldTask.Update(task) {
		// NOTE(review): the old task has already been removed from the queue at
		// this point and is not re-enqueued; confirm that dropping it is the
		// intended outcome when Update returns false.
		return false
	}
	if !oldTask.Schedule() {
		return false
	}
	return m.queue.Enqueue(ctx, oldTask)
}

// updateCurTask persists status (and lastErr) for task via UpdateTaskStatus.
// It makes up to three attempts with per-attempt timeouts of 1s, 8s and 64s,
// sleeping for the same duration between consecutive attempts, and logs each
// failure. It returns true on the first successful save, and false if every
// attempt fails or the scheduler is closed.
//
// run calls this with m.mutex held, so the mutex stays held for the whole
// retry sequence; that is why no sleep is performed after the final attempt.
func (m *Scheduler[T]) updateCurTask(status entity.TaskStatus, lastErr error, task T) bool {
	timeouts := []time.Duration{time.Second, 8 * time.Second, 64 * time.Second}
	for i, timeout := range timeouts {
		if m.closed {
			return false
		}
		ctx, cancel := context.WithTimeout(context.Background(), timeout)
		err := m.UpdateTaskStatus(ctx, status, lastErr, task)
		cancel()
		if err == nil {
			return true
		}
		logx.Errorf("保存任务失败: %v", err)
		// Back off before the next attempt only; sleeping after the last one
		// would just keep the mutex locked with nothing left to retry.
		if i < len(timeouts)-1 {
			time.Sleep(timeout)
		}
	}
	return false
}

// run is the scheduler's single worker goroutine, started by NewScheduler and
// awaited by Close through wg.
//
// Each iteration, under m.mutex, it returns if the scheduler is closed,
// otherwise tries to dequeue a task using a 1ms timeout context; if nothing is
// dequeued it releases the lock and sleeps one second. A dequeued task becomes
// curTask, is marked TaskReserved, and is executed by RunTask outside the lock
// with a cancellable context. Afterwards its final status (TaskSuccess or
// TaskFailed) is saved and, unless the run was cancelled by RemoveTask or
// Close, the task is re-scheduled and re-enqueued if it is not finished.
func (m *Scheduler[T]) run() {
	defer m.wg.Done()
	for {
		m.mutex.Lock()
		// closed is read under m.mutex because Close writes it under the same
		// lock; reading it unlocked is a data race.
		if m.closed {
			m.mutex.Unlock()
			return
		}

		ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
		task, ok := m.queue.Dequeue(ctx)
		cancel()
		if !ok {
			m.mutex.Unlock()
			time.Sleep(time.Second)
			continue
		}

		m.curTask = task
		m.updateCurTask(entity.TaskReserved, nil, task)
		m.curTaskContext, m.curTaskCancel = context.WithCancel(context.Background())
		taskCtx := m.curTaskContext
		m.mutex.Unlock()

		err := m.RunTask(taskCtx, task)

		m.mutex.Lock()
		// Only cancelCurTask (called by RemoveTask or Close) clears curTaskCancel
		// while a task is running, so a nil value here means the run was
		// cancelled externally and the task must not be put back in the queue.
		cancelled := m.curTaskCancel == nil

		taskStatus := entity.TaskSuccess
		if err != nil {
			taskStatus = entity.TaskFailed
		}
		m.updateCurTask(taskStatus, err, task)
		// Release the context's resources and mark that no task is running.
		m.cancelCurTask()

		if !cancelled {
			task.Schedule()
			if !task.IsFinished() {
				enqueueCtx, enqueueCancel := context.WithTimeout(context.Background(), time.Second)
				m.queue.Enqueue(enqueueCtx, task)
				enqueueCancel()
			}
		}
		m.mutex.Unlock()
	}
}
// Close stops the scheduler. It marks the scheduler closed, cancels the
// context of the task currently running (if any), and then waits for the run
// goroutine to exit. Calls after the first return immediately.
//
// Close takes m.mutex before looking at closed, so it may block while run holds
// the lock (for example during updateCurTask's retries).
func (m *Scheduler[T]) Close() {
	m.mutex.Lock()
	// Check and set closed under the same lock run uses; the previous unlocked
	// read raced with this write.
	if m.closed {
		m.mutex.Unlock()
		return
	}
	m.cancelCurTask()
	m.closed = true
	m.mutex.Unlock()

	// Wait outside the lock: run needs m.mutex to observe closed and exit.
	m.wg.Wait()
}
// RemoveTask removes the task with taskId from the queue and, if that task is
// currently executing, cancels the context passed to RunTask. It always
// returns true, whether or not a matching task was found.
func (m *Scheduler[T]) RemoveTask(taskId uint64) bool {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	m.queue.Remove(context.Background(), taskId)

	// curTaskCancel is non-nil only while a task is running and not yet
	// cancelled; otherwise cancelling would be a no-op. Checking it first
	// also avoids calling GetId on a stale or zero-valued curTask.
	if m.curTaskCancel != nil && taskId == m.curTask.GetId() {
		m.cancelCurTask()
	}
	return true
}
// cancelCurTask cancels the current task's context, if one is set, and clears
// curTaskCancel so later calls are no-ops. Callers hold m.mutex.
func (m *Scheduler[T]) cancelCurTask() {
	cancel := m.curTaskCancel
	if cancel == nil {
		return
	}
	cancel()
	m.curTaskCancel = nil
}