Files
mayfly-go/server/internal/ai/tools/interrupt_param_completion.go

140 lines
4.5 KiB
Go
Raw Normal View History

package tools
import (
"context"
"encoding/json"
"errors"
"fmt"
"mayfly-go/internal/ai/imsg"
"mayfly-go/internal/ai/session"
"mayfly-go/pkg/i18n"
"mayfly-go/pkg/utils/jsonx"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
)
// CompletionParamInfo describes a single parameter that the user is asked to
// fill in during parameter completion.
type CompletionParamInfo struct {
	// Param is the parameter's name.
	Param string `json:"param"`
	// Name is the parameter's description (serialized under the "name" key).
	Name string `json:"name"`
}
// ParamCompletionInterruptInfo is the interrupt payload raised when a tool
// needs the user to supply parameters it is missing.
type ParamCompletionInterruptInfo struct {
	BaseInterruptInfo

	// ParamType is the kind of parameter being completed, e.g. "db", "machine", "table".
	ParamType string `json:"paramType"`
	// MissingParams lists the parameters that still need a value.
	MissingParams []CompletionParamInfo `json:"missingParams"`
}
// InterruptOrResumeParamCompletion is the single entry point a tool calls when
// required parameters are missing.
//
// If the current run is not resuming a parameter-completion interrupt, it raises
// a new one via InterruptParamCompletion and returns that result. If it is
// resuming, args has been filled from the user's input and nil is returned. When
// the resume failed (for example the user cancelled), the failure is wrapped with
// NewToolError(..., RecoverNone).
func InterruptOrResumeParamCompletion(ctx context.Context, toolDesc string, args any, reason string, paramType string, missingParams []CompletionParamInfo) error {
	resumed, resumeErr := ResumeParamCompletion(ctx, args)

	switch {
	case !resumed:
		// Nothing to resume from: ask the user for the missing parameters.
		return InterruptParamCompletion(ctx, toolDesc, args, reason, paramType, missingParams)
	case resumeErr != nil:
		return NewToolError(resumeErr, RecoverNone)
	default:
		return nil
	}
}
// InterruptParamCompletion raises a stateful interrupt asking the user to supply
// missingParams for the tool identified by toolDesc.
//
// The interrupt info carries reason as its description, the missing parameters
// both as the generic Payload and as MissingParams, the current tool call ID, and
// the JSON-encoded args. The same JSON string is passed to tool.StatefulInterrupt
// as the interrupt state.
func InterruptParamCompletion(ctx context.Context, toolDesc string, args any, reason string, paramType string, missingParams []CompletionParamInfo) error {
	encodedArgs := jsonx.ToStr(args)

	base := BaseInterruptInfo{
		Type:        InterruptTypeParamCompletion,
		Title:       i18n.T(imsg.InfoIncomplete),
		Description: reason,
		Payload:     missingParams,
		ToolCallId:  compose.GetToolCallID(ctx),
		ToolInfo:    &ToolInfo{Name: toolDesc},
		Arguments:   encodedArgs,
	}

	info := &ParamCompletionInterruptInfo{
		BaseInterruptInfo: base,
		ParamType:         paramType,
		MissingParams:     missingParams,
	}

	// The encoded args double as the interrupt state (read back on resume as a string).
	return tool.StatefulInterrupt(ctx, info, encodedArgs)
}
// ResumeParamCompletion 恢复参数完善
func ResumeParamCompletion(ctx context.Context, args any) (bool, error) {
// 首先检查是否有参数补全过的中断消息,并从中提取参数值进行恢复
messages, _ := session.DefaultSessionStore.GetMessage(ctx, &session.MessageQuery{MessageType: string(InterruptTypeParamCompletion), ToolCallId: compose.GetToolCallID(ctx)})
if len(messages) > 0 {
for _, msg := range messages {
var resumeInfo ParamCompletionResume
if err := msg.Extra.Unmarshal("resumeInfo", &resumeInfo); err != nil {
continue
}
return true, handleParamCompletion(ctx, &resumeInfo, args)
}
}
// 检查是否是从中断恢复
wasInterrupted, _, _ := tool.GetInterruptState[string](ctx)
if !wasInterrupted {
return false, nil
}
// 直接使用 GetResumeContext 检查参数补全恢复
isTarget, hasData, data := tool.GetResumeContext[*ParamCompletionResume](ctx)
if !isTarget || !hasData {
// 不是参数补全目标,继续执行
return false, nil
}
// 修改参数调用的消息体,更新参数后
msg := AppendResumeInfo(ctx, data.InterruptId, data)
if msg == nil {
return true, nil
}
if err := handleParamCompletion(ctx, data, args); err != nil {
return true, err
}
// 对同一 TurnId 的 RMW 操作加锁,防止并发覆盖
session.WithTurnLock(ctx, func() {
toolCallMsgs, err := session.DefaultSessionStore.GetMessage(ctx, &session.MessageQuery{TurnId: data.TurnId, MessageType: "tool_call"})
if err != nil || len(toolCallMsgs) == 0 {
return
}
for _, toolCallMsg := range toolCallMsgs {
for i := range toolCallMsg.ToolCalls {
if toolCallMsg.ToolCalls[i].ID == msg.ToolCallId {
toolCallMsg.ToolCalls[i].Function.Arguments = jsonx.ToStr(args)
session.DefaultSessionStore.UpdateMessage(ctx, toolCallMsg)
break
}
}
}
})
return true, nil
}
// handleParamCompletion applies a parameter-completion resume payload to args.
//
// If data.Action is anything other than "complete", the user cancelled the
// completion. A [PARAM_COMPLETION_CANCELLED] error is returned that tells the model
// not to retry on its own.
//
// Otherwise data.Payload["params"] must be a map[string]any. It is round-tripped
// through JSON into args, which must be a pointer accepted by json.Unmarshal.
// ctx is currently unused.
func handleParamCompletion(ctx context.Context, data *ParamCompletionResume, args any) error {
	if data.Action != "complete" {
		return errors.New("[PARAM_COMPLETION_CANCELLED] The user has cancelled the parameter completion for this tool.\nPlease do not retry parameter completion automatically. Ask the user for further instructions if needed.")
	}

	paramValues, ok := data.Payload["params"].(map[string]any)
	if !ok {
		return errors.New("missing params in payload")
	}

	// Marshal with encoding/json directly: jsonx.ToStr returns only a string, so an
	// encoding failure (e.g. an unsupported value) could not be reported and would
	// instead surface as a confusing decode error below.
	raw, err := json.Marshal(paramValues)
	if err != nil {
		return fmt.Errorf("marshal completed params failed: %w", err)
	}
	if err := json.Unmarshal(raw, args); err != nil {
		return fmt.Errorf("unmarshal stored args failed: %w", err)
	}
	return nil
}