orca.ai/pkg/session/memory_test.go
2026-05-12 00:09:01 +08:00

578 lines
15 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package session
import (
"fmt"
"os"
"path/filepath"
"testing"
"time"
"github.com/orca/orca/pkg/embedding"
)
// setupMemoryManager builds a MemoryManager whose database file lives in a
// freshly created temporary directory. It returns the manager together with a
// cleanup function that closes the manager and removes that directory; callers
// are expected to defer the cleanup.
//
// The parameter is a testing.TB rather than a *testing.T so that benchmarks
// can hand in their own *testing.B and have setup failures reported against
// the benchmark that is actually running. Existing *testing.T callers are
// unaffected because *testing.T satisfies testing.TB.
//
// The embedding API key is taken from the SILICONFLOW_API_KEY environment
// variable and may be empty.
func setupMemoryManager(t testing.TB) (*MemoryManager, func()) {
	t.Helper()

	dir, err := os.MkdirTemp("", "orca-memory-test-*")
	if err != nil {
		t.Fatalf("failed to create temp dir: %v", err)
	}

	cfg := MemoryConfig{
		DBPath:      filepath.Join(dir, "memory.db"),
		ModelWindow: 8192,
		EmbedConfig: embedding.Config{
			APIKey:  os.Getenv("SILICONFLOW_API_KEY"),
			BaseURL: "https://api.siliconflow.cn/v1",
			Model:   "Pro/BAAI/bge-m3",
			// NOTE(review): unit is not visible here; presumably milliseconds — confirm.
			Timeout: 5000,
		},
	}

	mm, err := NewMemoryManager(cfg)
	if err != nil {
		// Best-effort removal: the construction error reported below is the
		// one that matters, so a failure to delete the directory is ignored.
		os.RemoveAll(dir)
		t.Fatalf("NewMemoryManager failed: %v", err)
	}

	cleanup := func() {
		// Close first so the database file is released before it is deleted.
		mm.Close()
		os.RemoveAll(dir)
	}
	return mm, cleanup
}
// TestSQLiteStore_SaveAndLoad stores a single user message and verifies that
// GetWorkingMemory returns exactly one message carrying the same content.
func TestSQLiteStore_SaveAndLoad(t *testing.T) {
	mm, cleanup := setupMemoryManager(t)
	defer cleanup()

	const sessionID = "test-session-1"
	want := SessionMessage{
		Role:      RoleUser,
		Content:   "你好,请介绍一下自己",
		Timestamp: time.Now(),
	}

	if err := mm.SaveMessage(sessionID, want); err != nil {
		t.Fatalf("SaveMessage failed: %v", err)
	}

	got, err := mm.GetWorkingMemory(sessionID)
	if err != nil {
		t.Fatalf("GetWorkingMemory failed: %v", err)
	}
	if n := len(got); n != 1 {
		t.Fatalf("expected 1 message, got %d", n)
	}
	if first := got[0]; first.Content != want.Content {
		t.Errorf("expected content %q, got %q", want.Content, first.Content)
	}
}
// TestSQLiteStore_MultipleMessages saves a four-message exchange with
// timestamps one second apart and verifies that GetWorkingMemory returns the
// same number of messages with matching content at every position.
func TestSQLiteStore_MultipleMessages(t *testing.T) {
	mm, cleanup := setupMemoryManager(t)
	defer cleanup()

	const sessionID = "test-session-multi"
	start := time.Now()
	want := []SessionMessage{
		{Role: RoleUser, Content: "什么是机器学习?", Timestamp: start},
		{Role: RoleAssistant, Content: "机器学习是人工智能的一个分支...", Timestamp: start.Add(time.Second)},
		{Role: RoleUser, Content: "能举个例子吗?", Timestamp: start.Add(2 * time.Second)},
		{Role: RoleAssistant, Content: "比如垃圾邮件过滤器...", Timestamp: start.Add(3 * time.Second)},
	}

	for i := range want {
		err := mm.SaveMessage(sessionID, want[i])
		if err != nil {
			t.Fatalf("SaveMessage failed: %v", err)
		}
	}

	got, err := mm.GetWorkingMemory(sessionID)
	if err != nil {
		t.Fatalf("GetWorkingMemory failed: %v", err)
	}
	if len(got) != len(want) {
		t.Fatalf("expected %d messages, got %d", len(want), len(got))
	}

	// Lengths are equal at this point, so indexing both slices is safe.
	for i := range got {
		wantContent, gotContent := want[i].Content, got[i].Content
		if gotContent != wantContent {
			t.Errorf("message %d: expected %q, got %q", i, wantContent, gotContent)
		}
	}
}
// TestCalculateBudget is a table-driven check of calculateBudget: for each
// model window it compares the Total and Working budgets with the expected
// values and checks that ShortTerm and LongTerm equal 30% and 20% of Total
// (truncated to int).
//
// NOTE(review): the case table is empty, so the loop body never executes and
// this test currently asserts nothing. Cases need to be (re)added.
func TestCalculateBudget(t *testing.T) {
	cases := []struct {
		modelWindow int
		wantTotal   int
		wantWorking int
	}{}

	for _, tc := range cases {
		got := calculateBudget(tc.modelWindow)

		if got.Total != tc.wantTotal {
			t.Errorf("modelWindow=%d: expected total %d, got %d",
				tc.modelWindow, tc.wantTotal, got.Total)
		}
		if got.Working != tc.wantWorking {
			t.Errorf("modelWindow=%d: expected working %d, got %d",
				tc.modelWindow, tc.wantWorking, got.Working)
		}

		wantShort := int(float64(got.Total) * 0.3)
		if got.ShortTerm != wantShort {
			t.Errorf("ShortTerm budget incorrect: %d", got.ShortTerm)
		}
		wantLong := int(float64(got.Total) * 0.2)
		if got.LongTerm != wantLong {
			t.Errorf("LongTerm budget incorrect: %d", got.LongTerm)
		}
	}
}
// TestGetWorkingMemory_TokenBudget saves a short message, a long message and
// another short message, then loads working memory and logs what came back.
//
// NOTE(review): the test only logs the result; it fails solely when saving or
// loading returns an error. The intended effect (the long message pushing
// earlier messages out of the budget) is not asserted.
func TestGetWorkingMemory_TokenBudget(t *testing.T) {
	mm, cleanup := setupMemoryManager(t)
	defer cleanup()

	const sessionID = "test-budget"

	// One lead sentence followed by nine copies of the filler sentence.
	const filler = "重复多次以消耗token预算。"
	longContent := "这是一个很长的消息。" +
		filler + filler + filler +
		filler + filler + filler +
		filler + filler + filler

	input := []SessionMessage{
		{Role: RoleUser, Content: "第一条消息", Timestamp: time.Now()},
		{Role: RoleAssistant, Content: longContent, Timestamp: time.Now()},
		{Role: RoleUser, Content: "第三条消息", Timestamp: time.Now()},
	}
	for i := range input {
		err := mm.SaveMessage(sessionID, input[i])
		if err != nil {
			t.Fatalf("SaveMessage failed: %v", err)
		}
	}

	// Load working memory; the manager applies its own budget.
	got, err := mm.GetWorkingMemory(sessionID)
	if err != nil {
		t.Fatalf("GetWorkingMemory failed: %v", err)
	}

	// Expected (but not asserted): only messages that fit the budget are
	// returned, with the long message crowding out earlier ones.
	t.Logf("Loaded %d messages within budget", len(got))
	for idx := range got {
		m := got[idx]
		t.Logf("Message %d: role=%s, len=%d", idx, m.Role, len(m.Content))
	}
}
// TestVectorStore_SaveAndSearch saves five messages into one session, waits
// two seconds, then queries short-term memory for "编程语言" and logs the
// results. It is skipped when SILICONFLOW_API_KEY is not set. Only errors fail
// the test; the returned results are logged, not asserted.
func TestVectorStore_SaveAndSearch(t *testing.T) {
	if os.Getenv("SILICONFLOW_API_KEY") == "" {
		t.Skip("Skipping vector test: SILICONFLOW_API_KEY not set")
	}

	mm, cleanup := setupMemoryManager(t)
	defer cleanup()

	const sessionID = "test-vectors"
	input := []SessionMessage{
		{Role: RoleUser, Content: "Python 是什么编程语言?", Timestamp: time.Now()},
		{Role: RoleAssistant, Content: "Python 是一种高级编程语言,以其简洁的语法而闻名", Timestamp: time.Now()},
		{Role: RoleUser, Content: "Go 语言的特点是什么?", Timestamp: time.Now()},
		{Role: RoleAssistant, Content: "Go 语言由 Google 开发,强调并发和性能", Timestamp: time.Now()},
		{Role: RoleUser, Content: "今天天气怎么样?", Timestamp: time.Now()},
	}
	for i := range input {
		err := mm.SaveMessage(sessionID, input[i])
		if err != nil {
			t.Fatalf("SaveMessage failed: %v", err)
		}
	}

	// NOTE(review): presumably gives background embedding work time to finish
	// before searching — confirm against SaveMessage.
	time.Sleep(2 * time.Second)

	hits, err := mm.GetShortTermMemory(sessionID, "编程语言")
	if err != nil {
		t.Fatalf("GetShortTermMemory with vector search failed: %v", err)
	}

	t.Logf("Vector search returned %d results", len(hits))
	for idx, hit := range hits {
		t.Logf("Result %d: %s", idx, hit)
	}
}
// TestVectorStore_CrossSessionSearch saves messages into two unrelated
// sessions, waits two seconds, then queries long-term memory for "神经网络" and
// logs each result with its weight. It is skipped when SILICONFLOW_API_KEY is
// not set. Only errors fail the test; the results are logged, not asserted.
func TestVectorStore_CrossSessionSearch(t *testing.T) {
	if os.Getenv("SILICONFLOW_API_KEY") == "" {
		t.Skip("Skipping vector test: SILICONFLOW_API_KEY not set")
	}

	mm, cleanup := setupMemoryManager(t)
	defer cleanup()

	// save stores msgs in order under sessionID and aborts the test on the
	// first failure.
	save := func(sessionID string, msgs ...SessionMessage) {
		for i := range msgs {
			if err := mm.SaveMessage(sessionID, msgs[i]); err != nil {
				t.Fatalf("SaveMessage failed: %v", err)
			}
		}
	}

	save("session-ai",
		SessionMessage{Role: RoleUser, Content: "什么是深度学习?", Timestamp: time.Now()},
		SessionMessage{Role: RoleAssistant, Content: "深度学习是机器学习的一个子集,使用神经网络", Timestamp: time.Now()},
	)
	save("session-cooking",
		SessionMessage{Role: RoleUser, Content: "如何做红烧肉?", Timestamp: time.Now()},
		SessionMessage{Role: RoleAssistant, Content: "红烧肉需要五花肉、酱油、糖等材料", Timestamp: time.Now()},
	)

	// NOTE(review): presumably gives background embedding work time to finish
	// before searching — confirm against SaveMessage.
	time.Sleep(2 * time.Second)

	hits, err := mm.GetLongTermMemory("神经网络")
	if err != nil {
		t.Fatalf("GetLongTermMemory failed: %v", err)
	}

	t.Logf("Cross-session search returned %d results", len(hits))
	for idx := range hits {
		hit := hits[idx]
		t.Logf("Result %d: %s (weight=%.2f)", idx, hit.Content, hit.Weight)
	}
}
// TestMaintainSessionMemory feeds one question/answer pair to
// MaintainSessionMemory and verifies that at least one short-term memory can
// afterwards be read back for the session using an empty query.
func TestMaintainSessionMemory(t *testing.T) {
	mm, cleanup := setupMemoryManager(t)
	defer cleanup()

	const (
		sessionID = "test-maintenance"
		question  = "什么是REST API"
		answer    = "REST APIRepresentational State Transfer是一种软件架构风格用于设计网络应用程序。它使用HTTP方法GET、POST、PUT、DELETE来操作资源。"
	)
	mm.MaintainSessionMemory(sessionID, question, answer)

	got, err := mm.GetShortTermMemory(sessionID, "")
	switch {
	case err != nil:
		t.Fatalf("GetShortTermMemory failed: %v", err)
	case len(got) == 0:
		t.Fatal("Expected short-term memory to be created")
	}

	t.Logf("Created %d short-term memories", len(got))
	for idx, entry := range got {
		t.Logf("Memory %d: %s", idx, entry)
	}
}
// TestAddLongTermMemory stores three typed long-term memories and verifies
// that GetLongTermMemory with an empty query returns at least one of them.
// Retrieved entries are logged together with their weight.
func TestAddLongTermMemory(t *testing.T) {
	mm, cleanup := setupMemoryManager(t)
	defer cleanup()

	seeds := []struct {
		text string
		kind string
	}{
		{text: "用户喜欢使用Python进行数据分析", kind: "preference"},
		{text: "项目使用Go语言开发", kind: "project"},
		{text: "用户偏好简洁的代码风格", kind: "preference"},
	}
	for i := range seeds {
		err := mm.AddLongTermMemory(seeds[i].text, seeds[i].kind)
		if err != nil {
			t.Fatalf("AddLongTermMemory failed: %v", err)
		}
	}

	got, err := mm.GetLongTermMemory("")
	switch {
	case err != nil:
		t.Fatalf("GetLongTermMemory failed: %v", err)
	case len(got) == 0:
		t.Fatal("Expected long-term memories to be retrieved")
	}

	t.Logf("Retrieved %d long-term memories", len(got))
	for idx := range got {
		entry := got[idx]
		t.Logf("Memory %d: %s (weight=%.2f)", idx, entry.Content, entry.Weight)
	}
}
// TestBuildMemoryContext seeds two short-term and two long-term memories and
// verifies that BuildMemoryContext returns a non-empty context string for the
// session and query.
//
// Every seeding call is checked: if a seed fails to store, the test stops with
// that error instead of later reporting a misleading "context not built"
// failure.
func TestBuildMemoryContext(t *testing.T) {
	mm, cleanup := setupMemoryManager(t)
	defer cleanup()

	sessionID := "test-context"

	shortTerm := []string{
		"用户正在学习Go语言",
		"用户之前问过关于goroutine的问题",
	}
	for _, content := range shortTerm {
		if err := mm.AddShortTermMemory(sessionID, content); err != nil {
			t.Fatalf("AddShortTermMemory failed: %v", err)
		}
	}

	longTerm := []struct {
		content string
		mType   string
	}{
		{"用户是后端开发工程师", "fact"},
		{"用户偏好技术文档", "preference"},
	}
	for _, m := range longTerm {
		if err := mm.AddLongTermMemory(m.content, m.mType); err != nil {
			t.Fatalf("AddLongTermMemory failed: %v", err)
		}
	}

	// Named memCtx rather than context to avoid shadowing the standard
	// library package name should it ever be imported here.
	memCtx := mm.BuildMemoryContext(sessionID, "并发编程")
	if memCtx == "" {
		t.Fatal("Expected memory context to be built")
	}
	t.Logf("Memory context:\n%s", memCtx)
}
func TestMemoryManager_WithoutAPIKey(t *testing.T) {
dir, err := os.MkdirTemp("", "orca-memory-test-nokey-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(dir)
cfg := MemoryConfig{
DBPath: filepath.Join(dir, "memory.db"),
ModelWindow: 8192,
EmbedConfig: embedding.Config{
},
}
mm, err := NewMemoryManager(cfg)
if err != nil {
t.Fatalf("NewMemoryManager should work without API key: %v", err)
}
defer mm.Close()
sessionID := "test-nokey"
msg := SessionMessage{
Role: RoleUser,
Content: "测试无API Key模式",
Timestamp: time.Now(),
}
if err := mm.SaveMessage(sessionID, msg); err != nil {
t.Fatalf("SaveMessage should work without API key: %v", err)
}
messages, err := mm.GetWorkingMemory(sessionID)
if err != nil {
t.Fatalf("GetWorkingMemory should work without API key: %v", err)
}
if len(messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(messages))
}
memories, err := mm.GetShortTermMemory(sessionID, "")
if err != nil {
t.Fatalf("GetShortTermMemory should fallback to SQL: %v", err)
}
t.Logf("Short-term memories (no API key): %d", len(memories))
}
// TestCleanup adds fifteen short-term memories to one session, runs Cleanup,
// and verifies that no more than ten short-term memories are returned for the
// session afterwards.
func TestCleanup(t *testing.T) {
	mm, cleanup := setupMemoryManager(t)
	defer cleanup()

	const (
		sessionID = "test-cleanup"
		seeded    = 15
	)
	for n := 0; n < seeded; n++ {
		err := mm.AddShortTermMemory(sessionID, fmt.Sprintf("Short-term memory %d", n))
		if err != nil {
			t.Fatalf("AddShortTermMemory failed: %v", err)
		}
	}

	if err := mm.Cleanup(); err != nil {
		t.Fatalf("Cleanup failed: %v", err)
	}

	remaining, err := mm.GetShortTermMemory(sessionID, "")
	if err != nil {
		t.Fatalf("GetShortTermMemory failed: %v", err)
	}
	if count := len(remaining); count > 10 {
		t.Errorf("Expected at most 10 memories after cleanup, got %d", count)
	}
}
// TestEstimateTokens is a table-driven check of estimateTokens. The only case
// present asserts that the empty string is estimated at zero tokens.
func TestEstimateTokens(t *testing.T) {
	cases := []struct {
		input    string
		expected int
	}{
		{input: "", expected: 0},
	}

	for i := range cases {
		tc := cases[i]
		if got := estimateTokens(tc.input); got != tc.expected {
			t.Errorf("estimateTokens(%q) = %d, want %d", tc.input, got, tc.expected)
		}
	}
}
// TestReverseMessages verifies that reverseMessages reverses a three-element
// slice in place, by checking the content found at each position afterwards.
func TestReverseMessages(t *testing.T) {
	msgs := []SessionMessage{
		{Content: "first"},
		{Content: "second"},
		{Content: "third"},
	}
	want := []string{"third", "second", "first"}

	reverseMessages(msgs)

	for pos := range msgs {
		got := msgs[pos].Content
		if got != want[pos] {
			t.Errorf("reverseMessages: position %d expected %q, got %q",
				pos, want[pos], got)
		}
	}
}
// TestTruncateString is a table-driven check of truncateString. The cases
// cover input shorter than the limit, input longer than the limit (expected
// to be cut at the limit and suffixed with "..."), empty input, and input
// whose length equals the limit.
func TestTruncateString(t *testing.T) {
	cases := []struct {
		input    string
		maxLen   int
		expected string
	}{
		{input: "hello", maxLen: 10, expected: "hello"},
		{input: "hello world", maxLen: 5, expected: "hello..."},
		{input: "", maxLen: 5, expected: ""},
		{input: "short", maxLen: 5, expected: "short"},
	}

	for i := range cases {
		tc := cases[i]
		if got := truncateString(tc.input, tc.maxLen); got != tc.expected {
			t.Errorf("truncateString(%q, %d) = %q, want %q",
				tc.input, tc.maxLen, got, tc.expected)
		}
	}
}
// TestFullConversationFlow walks a six-turn user/assistant conversation
// through the memory manager: every turn is saved, working memory is loaded,
// each question/answer pair is passed to MaintainSessionMemory, short-term
// memory is read back, and a memory context is built for a follow-up query.
// Only returned errors fail the test; everything else is logged.
func TestFullConversationFlow(t *testing.T) {
	mm, cleanup := setupMemoryManager(t)
	defer cleanup()

	const sessionID = "test-conversation"
	turns := []struct {
		role    MessageRole
		content string
	}{
		{role: RoleUser, content: "你好我想学习Go语言"},
		{role: RoleAssistant, content: "Go语言是一种由Google开发的开源编程语言以其简洁、高效和强大的并发支持而闻名。它特别适合构建网络服务和分布式系统。"},
		{role: RoleUser, content: "Go语言的并发是怎么实现的"},
		{role: RoleAssistant, content: "Go语言使用goroutine和channel实现并发。Goroutine是轻量级线程由Go运行时管理。Channel用于goroutine之间的通信和同步。"},
		{role: RoleUser, content: "能推荐一些学习资源吗?"},
		{role: RoleAssistant, content: "推荐以下学习资源1. Go官方文档 2. 《Go程序设计语言》 3. Go by Example 网站"},
	}

	// Persist every turn in order.
	for i := range turns {
		entry := SessionMessage{
			Role:      turns[i].role,
			Content:   turns[i].content,
			Timestamp: time.Now(),
		}
		err := mm.SaveMessage(sessionID, entry)
		if err != nil {
			t.Fatalf("SaveMessage failed: %v", err)
		}
	}

	working, err := mm.GetWorkingMemory(sessionID)
	if err != nil {
		t.Fatalf("GetWorkingMemory failed: %v", err)
	}
	t.Logf("Working memory: %d messages", len(working))

	// Turns alternate user/assistant, so each even index and its successor
	// form one question/answer pair.
	for q := 0; q+1 < len(turns); q += 2 {
		mm.MaintainSessionMemory(sessionID, turns[q].content, turns[q+1].content)
	}

	recalled, err := mm.GetShortTermMemory(sessionID, "")
	if err != nil {
		t.Fatalf("GetShortTermMemory failed: %v", err)
	}
	t.Logf("Short-term memories: %d", len(recalled))
	for idx := range recalled {
		t.Logf("  Memory %d: %s", idx, recalled[idx])
	}

	if memCtx := mm.BuildMemoryContext(sessionID, "学习资源"); memCtx != "" {
		t.Logf("Memory context for '学习资源':\n%s", memCtx)
	}
}
// BenchmarkSaveMessage measures SaveMessage by writing the same message to one
// session b.N times. Setup is excluded from the timing via b.ResetTimer.
//
// NOTE(review): setup is driven by a zero-value testing.T, which is not
// connected to this benchmark, so a setup failure is not reported through b.
// Passing b instead would require the helper to accept testing.TB.
func BenchmarkSaveMessage(b *testing.B) {
	mm, cleanup := setupMemoryManager(&testing.T{})
	defer cleanup()

	const sessionID = "bench-session"
	payload := SessionMessage{
		Role:      RoleUser,
		Content:   "这是一条测试消息",
		Timestamp: time.Now(),
	}

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		err := mm.SaveMessage(sessionID, payload)
		if err != nil {
			b.Fatalf("SaveMessage failed: %v", err)
		}
	}
}
// BenchmarkGetWorkingMemory measures GetWorkingMemory on a session that has
// been pre-populated with 100 messages. Seeding is excluded from the timing
// via b.ResetTimer.
//
// Seeding errors abort the benchmark: without the fixture in place the
// measured numbers would describe a different (emptier) session.
//
// NOTE(review): setup is driven by a zero-value testing.T, which is not
// connected to this benchmark, so a setup failure is not reported through b.
// Passing b instead would require the helper to accept testing.TB.
func BenchmarkGetWorkingMemory(b *testing.B) {
	mm, cleanup := setupMemoryManager(&testing.T{})
	defer cleanup()

	sessionID := "bench-session"
	for i := 0; i < 100; i++ {
		msg := SessionMessage{
			Role:      RoleUser,
			Content:   fmt.Sprintf("Message %d", i),
			Timestamp: time.Now(),
		}
		if err := mm.SaveMessage(sessionID, msg); err != nil {
			b.Fatalf("SaveMessage failed: %v", err)
		}
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := mm.GetWorkingMemory(sessionID); err != nil {
			b.Fatalf("GetWorkingMemory failed: %v", err)
		}
	}
}
// BenchmarkBuildMemoryContext measures BuildMemoryContext on a session that
// has been pre-populated with ten short-term memories. Seeding is excluded
// from the timing via b.ResetTimer.
//
// Seeding errors abort the benchmark: without the fixture in place the
// measured numbers would describe a session with fewer memories than intended.
//
// NOTE(review): setup is driven by a zero-value testing.T, which is not
// connected to this benchmark, so a setup failure is not reported through b.
// Passing b instead would require the helper to accept testing.TB.
func BenchmarkBuildMemoryContext(b *testing.B) {
	mm, cleanup := setupMemoryManager(&testing.T{})
	defer cleanup()

	sessionID := "bench-context"
	for i := 0; i < 10; i++ {
		if err := mm.AddShortTermMemory(sessionID, fmt.Sprintf("Memory %d", i)); err != nil {
			b.Fatalf("AddShortTermMemory failed: %v", err)
		}
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// The result is discarded on purpose; only the call cost is measured.
		_ = mm.BuildMemoryContext(sessionID, "test query")
	}
}