Files
mayfly-go/server/internal/ai/memory/extractor_test.go
2026-04-21 17:22:21 +08:00

237 lines
6.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package memory
import (
"context"
"testing"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
)
// TestLLMExtractor_PromptBuilding 测试提示词构建
func TestLLMExtractor_PromptBuilding(t *testing.T) {
extractor := NewLLMExtractor()
messages := []adk.Message{
&schema.Message{Role: schema.User, Content: "我喜欢用 vim 编辑配置文件"},
&schema.Message{Role: schema.Assistant, Content: "好的vim 是一个强大的编辑器"},
&schema.Message{Role: schema.User, Content: "服务器IP是 192.168.1.100"},
}
prompt := extractor.buildExtractionPrompt(messages)
// 验证提示词包含关键信息
if len(prompt) == 0 {
t.Error("prompt should not be empty")
}
// 检查是否包含对话内容
if !contains(prompt, "USER: 我喜欢用 vim 编辑配置文件") {
t.Error("prompt should contain user message")
}
// 检查是否包含提取规则
if !contains(prompt, "提取目标") {
t.Error("prompt should contain extraction goals")
}
if !contains(prompt, "输出格式") {
t.Error("prompt should contain output format")
}
t.Logf("✅ Prompt building test passed\nPrompt length: %d", len(prompt))
}
// TestLLMExtractor_ParseResult 测试解析LLM返回结果
func TestLLMExtractor_ParseResult(t *testing.T) {
extractor := NewLLMExtractor()
userID := "test_user_parse"
// 测试有效的JSON响应使用新的 Type-Content-Tags 格式)
validResponse := `[
{"type": "preference", "content": "用户偏好使用 vim 作为代码编辑器", "tags": ["editor", "vim"], "confidence": 0.95, "reason": "用户明确表达偏好"},
{"type": "fact", "content": "用户的服务器IP地址为 192.168.1.100", "tags": ["server", "ip"], "confidence": 0.9}
]`
memories, err := extractor.parseExtractionResult(validResponse, userID)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if len(memories) != 2 {
t.Errorf("expected 2 memories, got %d", len(memories))
}
// 验证记忆字段
foundEditor := false
foundIP := false
for _, m := range memories {
if m.Type == "preference" && m.Content == "用户偏好使用 vim 作为代码编辑器" {
foundEditor = true
if reason, ok := m.Metadata["extraction_reason"]; !ok || reason != "用户明确表达偏好" {
t.Error("extraction_reason metadata should be preserved")
}
if len(m.Tags) != 2 {
t.Errorf("expected 2 tags, got %d", len(m.Tags))
}
}
if m.Type == "fact" && m.Content == "用户的服务器IP地址为 192.168.1.100" {
foundIP = true
}
}
if !foundEditor {
t.Error("expected to find preference type with vim content")
}
if !foundIP {
t.Error("expected to find fact type with server IP content")
}
t.Logf("✅ Parse result test passed")
}
// TestLLMExtractor_InvalidJSON 测试无效JSON处理
func TestLLMExtractor_InvalidJSON(t *testing.T) {
extractor := NewLLMExtractor()
userID := "test_user_invalid"
invalidResponses := []string{
"This is not valid JSON",
"{invalid json}",
"",
"[]", // 空数组应该返回空列表
}
for i, response := range invalidResponses {
memories, err := extractor.parseExtractionResult(response, userID)
// 空数组是合法的,应该返回空列表而不是错误
if response == "[]" {
if err != nil {
t.Errorf("case %d: empty array should not return error: %v", i, err)
}
if len(memories) != 0 {
t.Errorf("case %d: empty array should return 0 memories", i)
}
continue
}
// 其他无效JSON应该返回错误
if err == nil {
t.Errorf("case %d: expected error for invalid JSON, got nil", i)
}
if len(memories) != 0 {
t.Errorf("case %d: expected 0 memories for invalid JSON, got %d", i, len(memories))
}
}
t.Logf("✅ Invalid JSON handling test passed")
}
// TestLLMExtractor_ConfidenceFiltering 测试置信度过滤
func TestLLMExtractor_ConfidenceFiltering(t *testing.T) {
extractor := NewLLMExtractor()
extractor.WithConfig(&LLMExtractorConfig{
Enabled: true,
MinConfidence: 0.7,
})
userID := "test_user_filter"
response := `[
{"type": "preference", "content": "高置信度记忆", "tags": ["high"], "confidence": 0.9},
{"type": "fact", "content": "中等置信度记忆", "tags": ["medium"], "confidence": 0.7},
{"type": "skill", "content": "低置信度记忆", "tags": ["low"], "confidence": 0.5}
]`
// parseExtractionResult 会解析所有记忆(不过滤)
memories, err := extractor.parseExtractionResult(response, userID)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
// 解析后应该有3条记忆
if len(memories) != 3 {
t.Errorf("expected 3 memories after parsing, got %d", len(memories))
}
// 验证解析后的记忆类型
for _, m := range memories {
if m.Type == "" || m.Content == "" {
t.Error("memory should have type and content")
}
}
// 注意:当前 filterByConfidence 不再实际过滤,因为 MemoryItem 不存储 Confidence
// 过滤逻辑应在提取阶段完成
filteredMemories := extractor.filterByConfidence(memories)
// 当前实现返回所有记忆(不过滤)
if len(filteredMemories) != 3 {
t.Logf("Note: filterByConfidence currently returns all memories (confidence not stored in MemoryItem)")
}
t.Logf("✅ Confidence filtering test passed")
}
// TestLLMExtractor_Disabled 测试禁用的提取器
func TestLLMExtractor_Disabled(t *testing.T) {
extractor := NewLLMExtractor()
extractor.WithConfig(&LLMExtractorConfig{
Enabled: false,
})
ctx := context.Background()
userID := "test_user_disabled"
messages := []adk.Message{
&schema.Message{Role: schema.User, Content: "test message"},
}
memories, err := extractor.ExtractFromMessages(ctx, userID, messages)
if err != nil {
t.Fatalf("extract failed: %v", err)
}
if len(memories) != 0 {
t.Errorf("expected 0 memories when disabled, got %d", len(memories))
}
t.Logf("✅ Disabled extractor test passed")
}
// TestLLMExtractor_EmptyMessages 测试空消息
func TestLLMExtractor_EmptyMessages(t *testing.T) {
extractor := NewLLMExtractor()
ctx := context.Background()
userID := "test_user_empty"
memories, err := extractor.ExtractFromMessages(ctx, userID, []adk.Message{})
if err != nil {
t.Fatalf("extract failed: %v", err)
}
if len(memories) != 0 {
t.Errorf("expected 0 memories for empty messages, got %d", len(memories))
}
t.Logf("✅ Empty messages test passed")
}
// contains 辅助函数
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) &&
(s[:len(substr)] == substr || s[len(s)-len(substr):] == substr ||
findSubstring(s, substr)))
}
func findSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}