Files
mayfly-go/server/internal/ai/tools/interrupt.go
meilin.huang 3768cef62d feat: ai助手优化等
Co-authored-by: Copilot <copilot@github.com>
2026-04-28 22:37:10 +08:00

150 lines
3.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package tools
import (
"context"
"mayfly-go/internal/ai/session"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/collx"
)
// InterruptType 定义中断类型
type InterruptType string
const (
InterruptTypeApproval InterruptType = "interrupt_approval" // 人工审批
InterruptTypeParamCompletion InterruptType = "interrupt_param_completion" // 参数补全
)
type ToolInfo struct {
Name string `json:"name"`
Desc string `json:"desc"`
JsonSchema string `json:"jsonSchema"`
}
type InterruptState struct {
Name string `json:"name"`
ToolCallId string `json:"toolCallId"`
Args string `json:"args"`
}
// InterruptMetadata 定义中断元数据接口
// 任何需要触发人机交互中断的信息都应实现此接口
type InterruptMetadata interface {
// GetType 返回中断类型,前端根据此类型渲染不同组件
GetType() InterruptType
// GetTitle 返回标题
GetTitle() string
// GetDescription 返回详细描述
GetDescription() string
// GetPayload 返回中断业务负载数据
GetPayload() any
// GetToolInfo 获取工具信息
GetToolInfo() *ToolInfo
// GetToolCallId 获取工具调用ID
GetToolCallId() string
// GetArgument 获取工具参数
GetArgument() string
}
// BaseInterruptInfo 基础中断信息结构体,实现 InterruptMetadata 接口
// 具体中断类型应嵌入此结构体
type BaseInterruptInfo struct {
Type InterruptType `json:"type"` // 中断类型
Title string `json:"title"` // 中断标题
Description string `json:"description"` // 中断描述
Payload any `json:"payload,omitempty"` // 中断负载数据
ToolCallId string `json:"toolCallId"`
ToolInfo *ToolInfo `json:"toolInfo"`
Arguments string `json:"arguments"` //原始参数字符串
}
var _ InterruptMetadata = (*BaseInterruptInfo)(nil)
// --- 实现 InterruptMetadata 接口 ---
func (b *BaseInterruptInfo) GetType() InterruptType {
return b.Type
}
func (b *BaseInterruptInfo) GetTitle() string {
return b.Title
}
func (b *BaseInterruptInfo) GetDescription() string {
return b.Description
}
func (b *BaseInterruptInfo) GetPayload() any {
return b.Payload
}
func (b *BaseInterruptInfo) GetToolInfo() *ToolInfo {
return b.ToolInfo
}
func (b *BaseInterruptInfo) GetArgument() string {
return b.Arguments
}
func (b *BaseInterruptInfo) GetToolCallId() string {
return b.ToolCallId
}
// ApprovalResume 审批恢复参数
type ApprovalResume struct {
*InterruptResume
}
// ParamCompletionResume 参数补全恢复参数
type ParamCompletionResume struct {
*InterruptResume
}
// InterruptResume 中断恢复的信息
type InterruptResume struct {
TurnId string `json:"turnId" binding:"required"`
InterruptId string `json:"interruptId" binding:"required"` // 中断id
InterruptType InterruptType `json:"interruptType" binding:"required"` // 中断类型
Action string `json:"action" binding:"required"` // 操作
Payload collx.M `json:"payload"` // 操作参数
}
// ToTarget 将 InterruptResume 转换为具体的恢复参数结构体(如 ApprovalResume 或 ParamCompletionResume
func (i *InterruptResume) ToTarget() any {
switch i.InterruptType {
case InterruptTypeApproval:
return &ApprovalResume{
InterruptResume: i,
}
case InterruptTypeParamCompletion:
return &ParamCompletionResume{
InterruptResume: i,
}
default:
return i
}
}
func AppendResumeInfo(ctx context.Context, interruptId string, resumeInfo any) *session.Message {
msgQuery := &session.MessageQuery{
ActionId: interruptId,
}
msgs, err := session.DefaultSessionStore.GetMessage(ctx, msgQuery)
if err != nil || len(msgs) == 0 {
logx.InfofContext(ctx, "not found interrupt message")
return nil
}
msg := msgs[0]
msg.Extra.Set("resumeInfo", resumeInfo)
session.DefaultSessionStore.UpdateMessage(ctx, msg)
return msg
}