package session import ( "fmt" "os" "path/filepath" "testing" "time" "github.com/orca/orca/pkg/embedding" ) func setupMemoryManager(t *testing.T) (*MemoryManager, func()) { t.Helper() dir, err := os.MkdirTemp("", "orca-memory-test-*") if err != nil { t.Fatalf("failed to create temp dir: %v", err) } cfg := MemoryConfig{ DBPath: filepath.Join(dir, "memory.db"), ModelWindow: 8192, EmbedConfig: embedding.Config{ APIKey: os.Getenv("SILICONFLOW_API_KEY"), BaseURL: "https://api.siliconflow.cn/v1", Model: "Pro/BAAI/bge-m3", Timeout: 5000, }, } mm, err := NewMemoryManager(cfg) if err != nil { os.RemoveAll(dir) t.Fatalf("NewMemoryManager failed: %v", err) } cleanup := func() { mm.Close() os.RemoveAll(dir) } return mm, cleanup } func TestSQLiteStore_SaveAndLoad(t *testing.T) { mm, cleanup := setupMemoryManager(t) defer cleanup() sessionID := "test-session-1" msg := SessionMessage{ Role: RoleUser, Content: "你好,请介绍一下自己", Timestamp: time.Now(), } err := mm.SaveMessage(sessionID, msg) if err != nil { t.Fatalf("SaveMessage failed: %v", err) } messages, err := mm.GetWorkingMemory(sessionID) if err != nil { t.Fatalf("GetWorkingMemory failed: %v", err) } if len(messages) != 1 { t.Fatalf("expected 1 message, got %d", len(messages)) } if messages[0].Content != msg.Content { t.Errorf("expected content %q, got %q", msg.Content, messages[0].Content) } } func TestSQLiteStore_MultipleMessages(t *testing.T) { mm, cleanup := setupMemoryManager(t) defer cleanup() sessionID := "test-session-multi" baseTime := time.Now() messages := []SessionMessage{ {Role: RoleUser, Content: "什么是机器学习?", Timestamp: baseTime}, {Role: RoleAssistant, Content: "机器学习是人工智能的一个分支...", Timestamp: baseTime.Add(time.Second)}, {Role: RoleUser, Content: "能举个例子吗?", Timestamp: baseTime.Add(2 * time.Second)}, {Role: RoleAssistant, Content: "比如垃圾邮件过滤器...", Timestamp: baseTime.Add(3 * time.Second)}, } for _, msg := range messages { if err := mm.SaveMessage(sessionID, msg); err != nil { t.Fatalf("SaveMessage failed: %v", err) } } loaded, err := mm.GetWorkingMemory(sessionID) if err != nil { t.Fatalf("GetWorkingMemory failed: %v", err) } if len(loaded) != len(messages) { t.Fatalf("expected %d messages, got %d", len(messages), len(loaded)) } for i, msg := range loaded { if msg.Content != messages[i].Content { t.Errorf("message %d: expected %q, got %q", i, messages[i].Content, msg.Content) } } } func TestCalculateBudget(t *testing.T) { tests := []struct { modelWindow int wantTotal int wantWorking int }{ } for _, tt := range tests { budget := calculateBudget(tt.modelWindow) if budget.Total != tt.wantTotal { t.Errorf("modelWindow=%d: expected total %d, got %d", tt.modelWindow, tt.wantTotal, budget.Total) } if budget.Working != tt.wantWorking { t.Errorf("modelWindow=%d: expected working %d, got %d", tt.modelWindow, tt.wantWorking, budget.Working) } if budget.ShortTerm != int(float64(budget.Total)*0.3) { t.Errorf("ShortTerm budget incorrect: %d", budget.ShortTerm) } if budget.LongTerm != int(float64(budget.Total)*0.2) { t.Errorf("LongTerm budget incorrect: %d", budget.LongTerm) } } } func TestGetWorkingMemory_TokenBudget(t *testing.T) { mm, cleanup := setupMemoryManager(t) defer cleanup() sessionID := "test-budget" longContent := "这是一个很长的消息。" + "重复多次以消耗token预算。" + "重复多次以消耗token预算。" + "重复多次以消耗token预算。" + "重复多次以消耗token预算。" + "重复多次以消耗token预算。" + "重复多次以消耗token预算。" + "重复多次以消耗token预算。" + "重复多次以消耗token预算。" + "重复多次以消耗token预算。" messages := []SessionMessage{ {Role: RoleUser, Content: "第一条消息", Timestamp: time.Now()}, {Role: RoleAssistant, Content: longContent, Timestamp: time.Now()}, {Role: RoleUser, Content: "第三条消息", Timestamp: time.Now()}, } for _, msg := range messages { if err := mm.SaveMessage(sessionID, msg); err != nil { t.Fatalf("SaveMessage failed: %v", err) } } // Get working memory with budget loaded, err := mm.GetWorkingMemory(sessionID) if err != nil { t.Fatalf("GetWorkingMemory failed: %v", err) } // Should return messages within budget // The long message should cause earlier messages to be excluded t.Logf("Loaded %d messages within budget", len(loaded)) for i, msg := range loaded { t.Logf("Message %d: role=%s, len=%d", i, msg.Role, len(msg.Content)) } } func TestVectorStore_SaveAndSearch(t *testing.T) { apiKey := os.Getenv("SILICONFLOW_API_KEY") if apiKey == "" { t.Skip("Skipping vector test: SILICONFLOW_API_KEY not set") } mm, cleanup := setupMemoryManager(t) defer cleanup() sessionID := "test-vectors" messages := []SessionMessage{ {Role: RoleUser, Content: "Python 是什么编程语言?", Timestamp: time.Now()}, {Role: RoleAssistant, Content: "Python 是一种高级编程语言,以其简洁的语法而闻名", Timestamp: time.Now()}, {Role: RoleUser, Content: "Go 语言的特点是什么?", Timestamp: time.Now()}, {Role: RoleAssistant, Content: "Go 语言由 Google 开发,强调并发和性能", Timestamp: time.Now()}, {Role: RoleUser, Content: "今天天气怎么样?", Timestamp: time.Now()}, } for _, msg := range messages { if err := mm.SaveMessage(sessionID, msg); err != nil { t.Fatalf("SaveMessage failed: %v", err) } } time.Sleep(2 * time.Second) results, err := mm.GetShortTermMemory(sessionID, "编程语言") if err != nil { t.Fatalf("GetShortTermMemory with vector search failed: %v", err) } t.Logf("Vector search returned %d results", len(results)) for i, r := range results { t.Logf("Result %d: %s", i, r) } } func TestVectorStore_CrossSessionSearch(t *testing.T) { apiKey := os.Getenv("SILICONFLOW_API_KEY") if apiKey == "" { t.Skip("Skipping vector test: SILICONFLOW_API_KEY not set") } mm, cleanup := setupMemoryManager(t) defer cleanup() session1 := "session-ai" for _, msg := range []SessionMessage{ {Role: RoleUser, Content: "什么是深度学习?", Timestamp: time.Now()}, {Role: RoleAssistant, Content: "深度学习是机器学习的一个子集,使用神经网络", Timestamp: time.Now()}, } { if err := mm.SaveMessage(session1, msg); err != nil { t.Fatalf("SaveMessage failed: %v", err) } } session2 := "session-cooking" for _, msg := range []SessionMessage{ {Role: RoleUser, Content: "如何做红烧肉?", Timestamp: time.Now()}, {Role: RoleAssistant, Content: "红烧肉需要五花肉、酱油、糖等材料", Timestamp: time.Now()}, } { if err := mm.SaveMessage(session2, msg); err != nil { t.Fatalf("SaveMessage failed: %v", err) } } time.Sleep(2 * time.Second) results, err := mm.GetLongTermMemory("神经网络") if err != nil { t.Fatalf("GetLongTermMemory failed: %v", err) } t.Logf("Cross-session search returned %d results", len(results)) for i, r := range results { t.Logf("Result %d: %s (weight=%.2f)", i, r.Content, r.Weight) } } func TestMaintainSessionMemory(t *testing.T) { mm, cleanup := setupMemoryManager(t) defer cleanup() sessionID := "test-maintenance" userQuery := "什么是REST API?" assistantResponse := "REST API(Representational State Transfer)是一种软件架构风格,用于设计网络应用程序。它使用HTTP方法(GET、POST、PUT、DELETE)来操作资源。" mm.MaintainSessionMemory(sessionID, userQuery, assistantResponse) memories, err := mm.GetShortTermMemory(sessionID, "") if err != nil { t.Fatalf("GetShortTermMemory failed: %v", err) } if len(memories) == 0 { t.Fatal("Expected short-term memory to be created") } t.Logf("Created %d short-term memories", len(memories)) for i, m := range memories { t.Logf("Memory %d: %s", i, m) } } func TestAddLongTermMemory(t *testing.T) { mm, cleanup := setupMemoryManager(t) defer cleanup() memories := []struct { content string mType string }{ {"用户喜欢使用Python进行数据分析", "preference"}, {"项目使用Go语言开发", "project"}, {"用户偏好简洁的代码风格", "preference"}, } for _, m := range memories { if err := mm.AddLongTermMemory(m.content, m.mType); err != nil { t.Fatalf("AddLongTermMemory failed: %v", err) } } results, err := mm.GetLongTermMemory("") if err != nil { t.Fatalf("GetLongTermMemory failed: %v", err) } if len(results) == 0 { t.Fatal("Expected long-term memories to be retrieved") } t.Logf("Retrieved %d long-term memories", len(results)) for i, r := range results { t.Logf("Memory %d: %s (weight=%.2f)", i, r.Content, r.Weight) } } func TestBuildMemoryContext(t *testing.T) { mm, cleanup := setupMemoryManager(t) defer cleanup() sessionID := "test-context" mm.AddShortTermMemory(sessionID, "用户正在学习Go语言") mm.AddShortTermMemory(sessionID, "用户之前问过关于goroutine的问题") mm.AddLongTermMemory("用户是后端开发工程师", "fact") mm.AddLongTermMemory("用户偏好技术文档", "preference") context := mm.BuildMemoryContext(sessionID, "并发编程") if context == "" { t.Fatal("Expected memory context to be built") } t.Logf("Memory context:\n%s", context) } func TestMemoryManager_WithoutAPIKey(t *testing.T) { dir, err := os.MkdirTemp("", "orca-memory-test-nokey-*") if err != nil { t.Fatalf("failed to create temp dir: %v", err) } defer os.RemoveAll(dir) cfg := MemoryConfig{ DBPath: filepath.Join(dir, "memory.db"), ModelWindow: 8192, EmbedConfig: embedding.Config{ }, } mm, err := NewMemoryManager(cfg) if err != nil { t.Fatalf("NewMemoryManager should work without API key: %v", err) } defer mm.Close() sessionID := "test-nokey" msg := SessionMessage{ Role: RoleUser, Content: "测试无API Key模式", Timestamp: time.Now(), } if err := mm.SaveMessage(sessionID, msg); err != nil { t.Fatalf("SaveMessage should work without API key: %v", err) } messages, err := mm.GetWorkingMemory(sessionID) if err != nil { t.Fatalf("GetWorkingMemory should work without API key: %v", err) } if len(messages) != 1 { t.Fatalf("expected 1 message, got %d", len(messages)) } memories, err := mm.GetShortTermMemory(sessionID, "") if err != nil { t.Fatalf("GetShortTermMemory should fallback to SQL: %v", err) } t.Logf("Short-term memories (no API key): %d", len(memories)) } func TestCleanup(t *testing.T) { mm, cleanup := setupMemoryManager(t) defer cleanup() sessionID := "test-cleanup" for i := 0; i < 15; i++ { content := fmt.Sprintf("Short-term memory %d", i) if err := mm.AddShortTermMemory(sessionID, content); err != nil { t.Fatalf("AddShortTermMemory failed: %v", err) } } if err := mm.Cleanup(); err != nil { t.Fatalf("Cleanup failed: %v", err) } memories, err := mm.GetShortTermMemory(sessionID, "") if err != nil { t.Fatalf("GetShortTermMemory failed: %v", err) } if len(memories) > 10 { t.Errorf("Expected at most 10 memories after cleanup, got %d", len(memories)) } } func TestEstimateTokens(t *testing.T) { tests := []struct { input string expected int }{ {"", 0}, } for _, tt := range tests { got := estimateTokens(tt.input) if got != tt.expected { t.Errorf("estimateTokens(%q) = %d, want %d", tt.input, got, tt.expected) } } } func TestReverseMessages(t *testing.T) { msgs := []SessionMessage{ {Content: "first"}, {Content: "second"}, {Content: "third"}, } reverseMessages(msgs) expected := []string{"third", "second", "first"} for i, msg := range msgs { if msg.Content != expected[i] { t.Errorf("reverseMessages: position %d expected %q, got %q", i, expected[i], msg.Content) } } } func TestTruncateString(t *testing.T) { tests := []struct { input string maxLen int expected string }{ {"hello", 10, "hello"}, {"hello world", 5, "hello..."}, {"", 5, ""}, {"short", 5, "short"}, } for _, tt := range tests { got := truncateString(tt.input, tt.maxLen) if got != tt.expected { t.Errorf("truncateString(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.expected) } } } func TestFullConversationFlow(t *testing.T) { mm, cleanup := setupMemoryManager(t) defer cleanup() sessionID := "test-conversation" conversation := []struct { role MessageRole content string }{ {RoleUser, "你好,我想学习Go语言"}, {RoleAssistant, "Go语言是一种由Google开发的开源编程语言,以其简洁、高效和强大的并发支持而闻名。它特别适合构建网络服务和分布式系统。"}, {RoleUser, "Go语言的并发是怎么实现的?"}, {RoleAssistant, "Go语言使用goroutine和channel实现并发。Goroutine是轻量级线程,由Go运行时管理。Channel用于goroutine之间的通信和同步。"}, {RoleUser, "能推荐一些学习资源吗?"}, {RoleAssistant, "推荐以下学习资源:1. Go官方文档 2. 《Go程序设计语言》 3. Go by Example 网站"}, } for _, msg := range conversation { sessionMsg := SessionMessage{ Role: msg.role, Content: msg.content, Timestamp: time.Now(), } if err := mm.SaveMessage(sessionID, sessionMsg); err != nil { t.Fatalf("SaveMessage failed: %v", err) } } workingMem, err := mm.GetWorkingMemory(sessionID) if err != nil { t.Fatalf("GetWorkingMemory failed: %v", err) } t.Logf("Working memory: %d messages", len(workingMem)) for i := 1; i < len(conversation); i += 2 { mm.MaintainSessionMemory(sessionID, conversation[i-1].content, conversation[i].content) } shortTerm, err := mm.GetShortTermMemory(sessionID, "") if err != nil { t.Fatalf("GetShortTermMemory failed: %v", err) } t.Logf("Short-term memories: %d", len(shortTerm)) for i, mem := range shortTerm { t.Logf(" Memory %d: %s", i, mem) } context := mm.BuildMemoryContext(sessionID, "学习资源") if context != "" { t.Logf("Memory context for '学习资源':\n%s", context) } } func BenchmarkSaveMessage(b *testing.B) { mm, cleanup := setupMemoryManager(&testing.T{}) defer cleanup() sessionID := "bench-session" msg := SessionMessage{ Role: RoleUser, Content: "这是一条测试消息", Timestamp: time.Now(), } b.ResetTimer() for i := 0; i < b.N; i++ { if err := mm.SaveMessage(sessionID, msg); err != nil { b.Fatalf("SaveMessage failed: %v", err) } } } func BenchmarkGetWorkingMemory(b *testing.B) { mm, cleanup := setupMemoryManager(&testing.T{}) defer cleanup() sessionID := "bench-session" for i := 0; i < 100; i++ { msg := SessionMessage{ Role: RoleUser, Content: fmt.Sprintf("Message %d", i), Timestamp: time.Now(), } mm.SaveMessage(sessionID, msg) } b.ResetTimer() for i := 0; i < b.N; i++ { if _, err := mm.GetWorkingMemory(sessionID); err != nil { b.Fatalf("GetWorkingMemory failed: %v", err) } } } func BenchmarkBuildMemoryContext(b *testing.B) { mm, cleanup := setupMemoryManager(&testing.T{}) defer cleanup() sessionID := "bench-context" for i := 0; i < 10; i++ { mm.AddShortTermMemory(sessionID, fmt.Sprintf("Memory %d", i)) } b.ResetTimer() for i := 0; i < b.N; i++ { _ = mm.BuildMemoryContext(sessionID, "test query") } }