Files
mayfly-go/server/internal/ai/memory/extractor.go
2026-04-21 17:22:21 +08:00

250 lines
8.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package memory
import (
"context"
"fmt"
"mayfly-go/internal/ai/pkg/utils"
"mayfly-go/pkg/gox"
"mayfly-go/pkg/logx"
"strings"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
)
// Extractor 记忆提取器接口
type Extractor interface {
// ExtractFromMessages 从消息历史中提取记忆
ExtractFromMessages(ctx context.Context, userID string, messages []adk.Message) ([]*MemoryItem, error)
}
// LLMExtractorConfig LLM提取器配置
type LLMExtractorConfig struct {
Enabled bool // 是否启用
Temperature float64 // 温度参数 (0-1)
MaxTokens int // 最大token数
MinConfidence float64 // 最小置信度阈值
MaxItemsPerCall int // 单次提取最大记忆数量
ChatModel model.ToolCallingChatModel // ChatModel 实例
}
// DefaultLLMExtractorConfig 返回默认配置
func DefaultLLMExtractorConfig() *LLMExtractorConfig {
return &LLMExtractorConfig{
Enabled: true,
Temperature: 0.3, // 较低温度,保证提取稳定性
MaxTokens: 500,
MinConfidence: 0.7,
MaxItemsPerCall: 10,
ChatModel: nil, // 需要外部设置
}
}
// LLMExtractor 基于LLM的记忆提取器
type LLMExtractor struct {
config *LLMExtractorConfig
}
// MemoryExtractionResult LLM提取结果结构
type MemoryExtractionResult struct {
Type string `json:"type"` // 记忆类型: preference/fact/skill/experience
Content string `json:"content"` // 记忆内容(自然语言描述)
Tags []string `json:"tags"` // 标签
Confidence float64 `json:"confidence"` // 置信度(仅用于过滤,不存储)
Reason string `json:"reason,omitempty"` // 提取原因
}
// NewLLMExtractor 创建LLM提取器
func NewLLMExtractor() *LLMExtractor {
return &LLMExtractor{
config: DefaultLLMExtractorConfig(),
}
}
// WithConfig 设置配置
func (e *LLMExtractor) WithConfig(config *LLMExtractorConfig) *LLMExtractor {
if config != nil {
e.config = config
}
return e
}
// ExtractFromMessages 使用LLM从消息中提取记忆
func (e *LLMExtractor) ExtractFromMessages(ctx context.Context, userID string, messages []adk.Message) ([]*MemoryItem, error) {
if !e.config.Enabled || len(messages) == 0 {
return []*MemoryItem{}, nil
}
// 检查是否配置了 ChatModel
if e.config.ChatModel == nil {
logx.WarnfContext(ctx, "LLM extractor ChatModel not configured, skipping extraction")
return []*MemoryItem{}, nil
}
// 只处理最近的消息最多10条
recentMessages := messages
if len(recentMessages) > 10 {
recentMessages = recentMessages[len(recentMessages)-10:]
}
// 构建提示词
prompt := e.buildExtractionPrompt(recentMessages)
// 尝试使用 LLM 生成提取结果,如果失败则降级返回空列表
memories, err := e.extractWithLLM(ctx, prompt, userID)
if err != nil {
logx.WarnfContext(ctx, "LLM memory extraction failed: %v, skipping", err)
return []*MemoryItem{}, nil // 降级:返回空列表,不阻塞主流程
}
if len(memories) > 0 {
logx.InfofContext(ctx, "extracted %d memories using LLM", len(memories))
for _, m := range memories {
logx.DebugfContext(ctx, " - [%s] %s (tags: %v)", m.Type, m.Content, m.Tags)
}
}
return memories, nil
}
// extractWithLLM 使用 LLM 提取记忆(带 panic 保护)
func (e *LLMExtractor) extractWithLLM(ctx context.Context, prompt string, userID string) (memories []*MemoryItem, err error) {
defer gox.Recover()
// 调用 LLM 生成提取结果
response, err := e.config.ChatModel.Generate(ctx, []*schema.Message{
{
Role: schema.System,
Content: prompt,
},
})
if err != nil {
return nil, fmt.Errorf("LLM generate: %w", err)
}
// 解析LLM返回结果
memories, err = e.parseExtractionResult(response.Content, userID)
if err != nil {
return nil, fmt.Errorf("parse LLM response: %w", err)
}
// 过滤低置信度的记忆
filteredMemories := e.filterByConfidence(memories)
return filteredMemories, nil
}
// buildExtractionPrompt 构建提取提示词
func (e *LLMExtractor) buildExtractionPrompt(messages []adk.Message) string {
var sb strings.Builder
sb.WriteString("你是一个专业的用户信息提取助手。请从以下对话中提取用户的重要信息。\n\n")
sb.WriteString("## 提取目标\n")
sb.WriteString("识别并提取以下类型的信息:\n")
sb.WriteString("1. **用户偏好**: 编辑器、编程语言、工具、工作习惯等\n")
sb.WriteString("2. **事实信息**: 服务器IP、数据库配置、API端点、文件路径等\n")
sb.WriteString("3. **工作环境**: 当前目录、项目结构、技术栈等\n")
sb.WriteString("4. **任务状态**: 正在进行的操作、待办事项等\n\n")
sb.WriteString("## 提取规则\n")
sb.WriteString("- 只提取明确陈述的信息,不要推测\n")
sb.WriteString("- 忽略寒暄、问候等无关内容\n")
sb.WriteString("- 如果同一信息多次出现,选择最新或最详细的版本\n")
sb.WriteString("- 置信度评分标准:\n")
sb.WriteString(" * 0.9-1.0: 用户明确陈述的事实\n")
sb.WriteString(" * 0.7-0.9: 用户暗示或间接提到的信息\n")
sb.WriteString(" * <0.7: 不确定或模糊的信息(不应提取)\n\n")
sb.WriteString("## 输出格式\n")
sb.WriteString("必须以JSON数组格式返回每个元素包含\n")
sb.WriteString("- `type`: 记忆类型preference/fact/skill/experience\n")
sb.WriteString("- `content`: 记忆内容(使用自然语言完整描述,便于后续语义检索)\n")
sb.WriteString("- `tags`: 标签数组3-5个关键词用于分类和快速过滤\n")
sb.WriteString("- `confidence`: 置信度0-1之间的小数仅用于内部过滤\n")
sb.WriteString("- `reason`: 提取原因(可选,简要说明为什么提取这条记忆)\n\n")
sb.WriteString("## 示例\n")
sb.WriteString("用户说:\"我喜欢用 vim 编辑配置文件,服务器是 192.168.1.100\"\n")
sb.WriteString("返回:\n")
sb.WriteString(`[
{"type": "preference", "content": "用户偏好使用 vim 作为配置文件编辑器", "tags": ["editor", "vim", "preference"], "confidence": 0.95, "reason": "用户明确表达偏好"},
{"type": "fact", "content": "用户的服务器IP地址为 192.168.1.100", "tags": ["server", "ip", "infrastructure"], "confidence": 0.9, "reason": "用户提供具体服务器地址"}
]` + "\n\n")
sb.WriteString("## 注意事项\n")
sb.WriteString("- 不要提取敏感信息密码、密钥、token等\n")
sb.WriteString("- content 字段应使用完整的自然语言描述,而非简化的键值对\n")
sb.WriteString("- tags 应使用英文小写,便于统一检索\n")
sb.WriteString(fmt.Sprintf("- 最多提取 %d 条最重要的记忆\n", e.config.MaxItemsPerCall))
sb.WriteString("- 如果没有值得提取的信息,返回空数组 []\n\n")
sb.WriteString("## 对话内容\n")
for i, msg := range messages {
if i > 0 {
sb.WriteString("\n")
}
// 根据消息角色格式化
roleStr := strings.ToUpper(string(msg.Role))
content := msg.Content
// 如果有工具调用,也记录
if len(msg.ToolCalls) > 0 {
for _, tc := range msg.ToolCalls {
content += fmt.Sprintf(" [调用工具: %s]", tc.Function.Name)
}
}
sb.WriteString(fmt.Sprintf("%s: %s", roleStr, content))
}
return sb.String()
}
// parseExtractionResult 解析LLM返回的结果
func (e *LLMExtractor) parseExtractionResult(response string, userID string) ([]*MemoryItem, error) {
results, err := utils.ParseLLMJSON[[]MemoryExtractionResult](response)
if err != nil {
return nil, fmt.Errorf("parse LLM JSON response: %w", err)
}
// 如果解析成功但结果为空,返回空切片而不是错误
if results == nil || len(*results) == 0 {
return []*MemoryItem{}, nil
}
// 转换为 MemoryItem
var memories []*MemoryItem
for _, result := range *results {
// 验证必要字段
if result.Type == "" || result.Content == "" {
continue
}
// 验证置信度范围
if result.Confidence < 0 || result.Confidence > 1 {
logx.Warnf("invalid confidence value: %.2f, skipping", result.Confidence)
continue
}
item := CreateMemory(userID, result.Type, result.Content, result.Tags)
// 添加元数据
if result.Reason != "" {
item.Metadata["extraction_reason"] = result.Reason
}
item.Metadata["extracted_by"] = "llm"
memories = append(memories, item)
}
return memories, nil
}
// filterByConfidence 根据置信度过滤记忆
func (e *LLMExtractor) filterByConfidence(memories []*MemoryItem) []*MemoryItem {
// 注意:当前 MemoryItem 不再存储 Confidence此方法保留用于未来扩展
// 如果需要基于置信度过滤,应在提取阶段完成
return memories
}