578 lines
15 KiB
Go
578 lines
15 KiB
Go
package session
|
||
|
||
import (
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/orca/orca/pkg/embedding"
|
||
)
|
||
|
||
func setupMemoryManager(t *testing.T) (*MemoryManager, func()) {
|
||
t.Helper()
|
||
dir, err := os.MkdirTemp("", "orca-memory-test-*")
|
||
if err != nil {
|
||
t.Fatalf("failed to create temp dir: %v", err)
|
||
}
|
||
|
||
cfg := MemoryConfig{
|
||
DBPath: filepath.Join(dir, "memory.db"),
|
||
ModelWindow: 8192,
|
||
EmbedConfig: embedding.Config{
|
||
APIKey: os.Getenv("SILICONFLOW_API_KEY"),
|
||
BaseURL: "https://api.siliconflow.cn/v1",
|
||
Model: "Pro/BAAI/bge-m3",
|
||
Timeout: 5000,
|
||
},
|
||
}
|
||
|
||
mm, err := NewMemoryManager(cfg)
|
||
if err != nil {
|
||
os.RemoveAll(dir)
|
||
t.Fatalf("NewMemoryManager failed: %v", err)
|
||
}
|
||
|
||
cleanup := func() {
|
||
mm.Close()
|
||
os.RemoveAll(dir)
|
||
}
|
||
return mm, cleanup
|
||
}
|
||
|
||
|
||
func TestSQLiteStore_SaveAndLoad(t *testing.T) {
|
||
mm, cleanup := setupMemoryManager(t)
|
||
defer cleanup()
|
||
|
||
sessionID := "test-session-1"
|
||
msg := SessionMessage{
|
||
Role: RoleUser,
|
||
Content: "你好,请介绍一下自己",
|
||
Timestamp: time.Now(),
|
||
}
|
||
|
||
err := mm.SaveMessage(sessionID, msg)
|
||
if err != nil {
|
||
t.Fatalf("SaveMessage failed: %v", err)
|
||
}
|
||
|
||
messages, err := mm.GetWorkingMemory(sessionID)
|
||
if err != nil {
|
||
t.Fatalf("GetWorkingMemory failed: %v", err)
|
||
}
|
||
if len(messages) != 1 {
|
||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||
}
|
||
if messages[0].Content != msg.Content {
|
||
t.Errorf("expected content %q, got %q", msg.Content, messages[0].Content)
|
||
}
|
||
}
|
||
|
||
func TestSQLiteStore_MultipleMessages(t *testing.T) {
|
||
mm, cleanup := setupMemoryManager(t)
|
||
defer cleanup()
|
||
|
||
sessionID := "test-session-multi"
|
||
baseTime := time.Now()
|
||
messages := []SessionMessage{
|
||
{Role: RoleUser, Content: "什么是机器学习?", Timestamp: baseTime},
|
||
{Role: RoleAssistant, Content: "机器学习是人工智能的一个分支...", Timestamp: baseTime.Add(time.Second)},
|
||
{Role: RoleUser, Content: "能举个例子吗?", Timestamp: baseTime.Add(2 * time.Second)},
|
||
{Role: RoleAssistant, Content: "比如垃圾邮件过滤器...", Timestamp: baseTime.Add(3 * time.Second)},
|
||
}
|
||
|
||
for _, msg := range messages {
|
||
if err := mm.SaveMessage(sessionID, msg); err != nil {
|
||
t.Fatalf("SaveMessage failed: %v", err)
|
||
}
|
||
}
|
||
|
||
loaded, err := mm.GetWorkingMemory(sessionID)
|
||
if err != nil {
|
||
t.Fatalf("GetWorkingMemory failed: %v", err)
|
||
}
|
||
if len(loaded) != len(messages) {
|
||
t.Fatalf("expected %d messages, got %d", len(messages), len(loaded))
|
||
}
|
||
|
||
for i, msg := range loaded {
|
||
if msg.Content != messages[i].Content {
|
||
t.Errorf("message %d: expected %q, got %q", i, messages[i].Content, msg.Content)
|
||
}
|
||
}
|
||
}
|
||
|
||
|
||
func TestCalculateBudget(t *testing.T) {
|
||
tests := []struct {
|
||
modelWindow int
|
||
wantTotal int
|
||
wantWorking int
|
||
}{
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
budget := calculateBudget(tt.modelWindow)
|
||
if budget.Total != tt.wantTotal {
|
||
t.Errorf("modelWindow=%d: expected total %d, got %d",
|
||
tt.modelWindow, tt.wantTotal, budget.Total)
|
||
}
|
||
if budget.Working != tt.wantWorking {
|
||
t.Errorf("modelWindow=%d: expected working %d, got %d",
|
||
tt.modelWindow, tt.wantWorking, budget.Working)
|
||
}
|
||
if budget.ShortTerm != int(float64(budget.Total)*0.3) {
|
||
t.Errorf("ShortTerm budget incorrect: %d", budget.ShortTerm)
|
||
}
|
||
if budget.LongTerm != int(float64(budget.Total)*0.2) {
|
||
t.Errorf("LongTerm budget incorrect: %d", budget.LongTerm)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestGetWorkingMemory_TokenBudget(t *testing.T) {
|
||
mm, cleanup := setupMemoryManager(t)
|
||
defer cleanup()
|
||
|
||
sessionID := "test-budget"
|
||
|
||
longContent := "这是一个很长的消息。" +
|
||
"重复多次以消耗token预算。" +
|
||
"重复多次以消耗token预算。" +
|
||
"重复多次以消耗token预算。" +
|
||
"重复多次以消耗token预算。" +
|
||
"重复多次以消耗token预算。" +
|
||
"重复多次以消耗token预算。" +
|
||
"重复多次以消耗token预算。" +
|
||
"重复多次以消耗token预算。" +
|
||
"重复多次以消耗token预算。"
|
||
|
||
messages := []SessionMessage{
|
||
{Role: RoleUser, Content: "第一条消息", Timestamp: time.Now()},
|
||
{Role: RoleAssistant, Content: longContent, Timestamp: time.Now()},
|
||
{Role: RoleUser, Content: "第三条消息", Timestamp: time.Now()},
|
||
}
|
||
|
||
for _, msg := range messages {
|
||
if err := mm.SaveMessage(sessionID, msg); err != nil {
|
||
t.Fatalf("SaveMessage failed: %v", err)
|
||
}
|
||
}
|
||
|
||
// Get working memory with budget
|
||
loaded, err := mm.GetWorkingMemory(sessionID)
|
||
if err != nil {
|
||
t.Fatalf("GetWorkingMemory failed: %v", err)
|
||
}
|
||
|
||
// Should return messages within budget
|
||
// The long message should cause earlier messages to be excluded
|
||
t.Logf("Loaded %d messages within budget", len(loaded))
|
||
for i, msg := range loaded {
|
||
t.Logf("Message %d: role=%s, len=%d", i, msg.Role, len(msg.Content))
|
||
}
|
||
}
|
||
|
||
|
||
func TestVectorStore_SaveAndSearch(t *testing.T) {
|
||
apiKey := os.Getenv("SILICONFLOW_API_KEY")
|
||
if apiKey == "" {
|
||
t.Skip("Skipping vector test: SILICONFLOW_API_KEY not set")
|
||
}
|
||
|
||
mm, cleanup := setupMemoryManager(t)
|
||
defer cleanup()
|
||
|
||
sessionID := "test-vectors"
|
||
|
||
messages := []SessionMessage{
|
||
{Role: RoleUser, Content: "Python 是什么编程语言?", Timestamp: time.Now()},
|
||
{Role: RoleAssistant, Content: "Python 是一种高级编程语言,以其简洁的语法而闻名", Timestamp: time.Now()},
|
||
{Role: RoleUser, Content: "Go 语言的特点是什么?", Timestamp: time.Now()},
|
||
{Role: RoleAssistant, Content: "Go 语言由 Google 开发,强调并发和性能", Timestamp: time.Now()},
|
||
{Role: RoleUser, Content: "今天天气怎么样?", Timestamp: time.Now()},
|
||
}
|
||
|
||
for _, msg := range messages {
|
||
if err := mm.SaveMessage(sessionID, msg); err != nil {
|
||
t.Fatalf("SaveMessage failed: %v", err)
|
||
}
|
||
}
|
||
|
||
time.Sleep(2 * time.Second)
|
||
|
||
results, err := mm.GetShortTermMemory(sessionID, "编程语言")
|
||
if err != nil {
|
||
t.Fatalf("GetShortTermMemory with vector search failed: %v", err)
|
||
}
|
||
|
||
t.Logf("Vector search returned %d results", len(results))
|
||
for i, r := range results {
|
||
t.Logf("Result %d: %s", i, r)
|
||
}
|
||
}
|
||
|
||
func TestVectorStore_CrossSessionSearch(t *testing.T) {
|
||
apiKey := os.Getenv("SILICONFLOW_API_KEY")
|
||
if apiKey == "" {
|
||
t.Skip("Skipping vector test: SILICONFLOW_API_KEY not set")
|
||
}
|
||
|
||
mm, cleanup := setupMemoryManager(t)
|
||
defer cleanup()
|
||
|
||
session1 := "session-ai"
|
||
for _, msg := range []SessionMessage{
|
||
{Role: RoleUser, Content: "什么是深度学习?", Timestamp: time.Now()},
|
||
{Role: RoleAssistant, Content: "深度学习是机器学习的一个子集,使用神经网络", Timestamp: time.Now()},
|
||
} {
|
||
if err := mm.SaveMessage(session1, msg); err != nil {
|
||
t.Fatalf("SaveMessage failed: %v", err)
|
||
}
|
||
}
|
||
|
||
session2 := "session-cooking"
|
||
for _, msg := range []SessionMessage{
|
||
{Role: RoleUser, Content: "如何做红烧肉?", Timestamp: time.Now()},
|
||
{Role: RoleAssistant, Content: "红烧肉需要五花肉、酱油、糖等材料", Timestamp: time.Now()},
|
||
} {
|
||
if err := mm.SaveMessage(session2, msg); err != nil {
|
||
t.Fatalf("SaveMessage failed: %v", err)
|
||
}
|
||
}
|
||
|
||
time.Sleep(2 * time.Second)
|
||
|
||
results, err := mm.GetLongTermMemory("神经网络")
|
||
if err != nil {
|
||
t.Fatalf("GetLongTermMemory failed: %v", err)
|
||
}
|
||
|
||
t.Logf("Cross-session search returned %d results", len(results))
|
||
for i, r := range results {
|
||
t.Logf("Result %d: %s (weight=%.2f)", i, r.Content, r.Weight)
|
||
}
|
||
}
|
||
|
||
|
||
func TestMaintainSessionMemory(t *testing.T) {
|
||
mm, cleanup := setupMemoryManager(t)
|
||
defer cleanup()
|
||
|
||
sessionID := "test-maintenance"
|
||
userQuery := "什么是REST API?"
|
||
assistantResponse := "REST API(Representational State Transfer)是一种软件架构风格,用于设计网络应用程序。它使用HTTP方法(GET、POST、PUT、DELETE)来操作资源。"
|
||
|
||
mm.MaintainSessionMemory(sessionID, userQuery, assistantResponse)
|
||
|
||
memories, err := mm.GetShortTermMemory(sessionID, "")
|
||
if err != nil {
|
||
t.Fatalf("GetShortTermMemory failed: %v", err)
|
||
}
|
||
if len(memories) == 0 {
|
||
t.Fatal("Expected short-term memory to be created")
|
||
}
|
||
|
||
t.Logf("Created %d short-term memories", len(memories))
|
||
for i, m := range memories {
|
||
t.Logf("Memory %d: %s", i, m)
|
||
}
|
||
}
|
||
|
||
func TestAddLongTermMemory(t *testing.T) {
|
||
mm, cleanup := setupMemoryManager(t)
|
||
defer cleanup()
|
||
|
||
memories := []struct {
|
||
content string
|
||
mType string
|
||
}{
|
||
{"用户喜欢使用Python进行数据分析", "preference"},
|
||
{"项目使用Go语言开发", "project"},
|
||
{"用户偏好简洁的代码风格", "preference"},
|
||
}
|
||
|
||
for _, m := range memories {
|
||
if err := mm.AddLongTermMemory(m.content, m.mType); err != nil {
|
||
t.Fatalf("AddLongTermMemory failed: %v", err)
|
||
}
|
||
}
|
||
|
||
results, err := mm.GetLongTermMemory("")
|
||
if err != nil {
|
||
t.Fatalf("GetLongTermMemory failed: %v", err)
|
||
}
|
||
if len(results) == 0 {
|
||
t.Fatal("Expected long-term memories to be retrieved")
|
||
}
|
||
|
||
t.Logf("Retrieved %d long-term memories", len(results))
|
||
for i, r := range results {
|
||
t.Logf("Memory %d: %s (weight=%.2f)", i, r.Content, r.Weight)
|
||
}
|
||
}
|
||
|
||
func TestBuildMemoryContext(t *testing.T) {
|
||
mm, cleanup := setupMemoryManager(t)
|
||
defer cleanup()
|
||
|
||
sessionID := "test-context"
|
||
|
||
mm.AddShortTermMemory(sessionID, "用户正在学习Go语言")
|
||
mm.AddShortTermMemory(sessionID, "用户之前问过关于goroutine的问题")
|
||
|
||
mm.AddLongTermMemory("用户是后端开发工程师", "fact")
|
||
mm.AddLongTermMemory("用户偏好技术文档", "preference")
|
||
|
||
context := mm.BuildMemoryContext(sessionID, "并发编程")
|
||
if context == "" {
|
||
t.Fatal("Expected memory context to be built")
|
||
}
|
||
|
||
t.Logf("Memory context:\n%s", context)
|
||
}
|
||
|
||
|
||
func TestMemoryManager_WithoutAPIKey(t *testing.T) {
|
||
dir, err := os.MkdirTemp("", "orca-memory-test-nokey-*")
|
||
if err != nil {
|
||
t.Fatalf("failed to create temp dir: %v", err)
|
||
}
|
||
defer os.RemoveAll(dir)
|
||
|
||
cfg := MemoryConfig{
|
||
DBPath: filepath.Join(dir, "memory.db"),
|
||
ModelWindow: 8192,
|
||
EmbedConfig: embedding.Config{
|
||
},
|
||
}
|
||
|
||
mm, err := NewMemoryManager(cfg)
|
||
if err != nil {
|
||
t.Fatalf("NewMemoryManager should work without API key: %v", err)
|
||
}
|
||
defer mm.Close()
|
||
|
||
sessionID := "test-nokey"
|
||
msg := SessionMessage{
|
||
Role: RoleUser,
|
||
Content: "测试无API Key模式",
|
||
Timestamp: time.Now(),
|
||
}
|
||
|
||
if err := mm.SaveMessage(sessionID, msg); err != nil {
|
||
t.Fatalf("SaveMessage should work without API key: %v", err)
|
||
}
|
||
|
||
messages, err := mm.GetWorkingMemory(sessionID)
|
||
if err != nil {
|
||
t.Fatalf("GetWorkingMemory should work without API key: %v", err)
|
||
}
|
||
if len(messages) != 1 {
|
||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||
}
|
||
|
||
memories, err := mm.GetShortTermMemory(sessionID, "")
|
||
if err != nil {
|
||
t.Fatalf("GetShortTermMemory should fallback to SQL: %v", err)
|
||
}
|
||
t.Logf("Short-term memories (no API key): %d", len(memories))
|
||
}
|
||
|
||
|
||
func TestCleanup(t *testing.T) {
|
||
mm, cleanup := setupMemoryManager(t)
|
||
defer cleanup()
|
||
|
||
sessionID := "test-cleanup"
|
||
|
||
for i := 0; i < 15; i++ {
|
||
content := fmt.Sprintf("Short-term memory %d", i)
|
||
if err := mm.AddShortTermMemory(sessionID, content); err != nil {
|
||
t.Fatalf("AddShortTermMemory failed: %v", err)
|
||
}
|
||
}
|
||
|
||
if err := mm.Cleanup(); err != nil {
|
||
t.Fatalf("Cleanup failed: %v", err)
|
||
}
|
||
|
||
memories, err := mm.GetShortTermMemory(sessionID, "")
|
||
if err != nil {
|
||
t.Fatalf("GetShortTermMemory failed: %v", err)
|
||
}
|
||
if len(memories) > 10 {
|
||
t.Errorf("Expected at most 10 memories after cleanup, got %d", len(memories))
|
||
}
|
||
}
|
||
|
||
|
||
func TestEstimateTokens(t *testing.T) {
|
||
tests := []struct {
|
||
input string
|
||
expected int
|
||
}{
|
||
{"", 0},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
got := estimateTokens(tt.input)
|
||
if got != tt.expected {
|
||
t.Errorf("estimateTokens(%q) = %d, want %d", tt.input, got, tt.expected)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestReverseMessages(t *testing.T) {
|
||
msgs := []SessionMessage{
|
||
{Content: "first"},
|
||
{Content: "second"},
|
||
{Content: "third"},
|
||
}
|
||
|
||
reverseMessages(msgs)
|
||
|
||
expected := []string{"third", "second", "first"}
|
||
for i, msg := range msgs {
|
||
if msg.Content != expected[i] {
|
||
t.Errorf("reverseMessages: position %d expected %q, got %q",
|
||
i, expected[i], msg.Content)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestTruncateString(t *testing.T) {
|
||
tests := []struct {
|
||
input string
|
||
maxLen int
|
||
expected string
|
||
}{
|
||
{"hello", 10, "hello"},
|
||
{"hello world", 5, "hello..."},
|
||
{"", 5, ""},
|
||
{"short", 5, "short"},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
got := truncateString(tt.input, tt.maxLen)
|
||
if got != tt.expected {
|
||
t.Errorf("truncateString(%q, %d) = %q, want %q",
|
||
tt.input, tt.maxLen, got, tt.expected)
|
||
}
|
||
}
|
||
}
|
||
|
||
|
||
func TestFullConversationFlow(t *testing.T) {
|
||
mm, cleanup := setupMemoryManager(t)
|
||
defer cleanup()
|
||
|
||
sessionID := "test-conversation"
|
||
|
||
conversation := []struct {
|
||
role MessageRole
|
||
content string
|
||
}{
|
||
{RoleUser, "你好,我想学习Go语言"},
|
||
{RoleAssistant, "Go语言是一种由Google开发的开源编程语言,以其简洁、高效和强大的并发支持而闻名。它特别适合构建网络服务和分布式系统。"},
|
||
{RoleUser, "Go语言的并发是怎么实现的?"},
|
||
{RoleAssistant, "Go语言使用goroutine和channel实现并发。Goroutine是轻量级线程,由Go运行时管理。Channel用于goroutine之间的通信和同步。"},
|
||
{RoleUser, "能推荐一些学习资源吗?"},
|
||
{RoleAssistant, "推荐以下学习资源:1. Go官方文档 2. 《Go程序设计语言》 3. Go by Example 网站"},
|
||
}
|
||
|
||
for _, msg := range conversation {
|
||
sessionMsg := SessionMessage{
|
||
Role: msg.role,
|
||
Content: msg.content,
|
||
Timestamp: time.Now(),
|
||
}
|
||
if err := mm.SaveMessage(sessionID, sessionMsg); err != nil {
|
||
t.Fatalf("SaveMessage failed: %v", err)
|
||
}
|
||
}
|
||
|
||
workingMem, err := mm.GetWorkingMemory(sessionID)
|
||
if err != nil {
|
||
t.Fatalf("GetWorkingMemory failed: %v", err)
|
||
}
|
||
t.Logf("Working memory: %d messages", len(workingMem))
|
||
|
||
for i := 1; i < len(conversation); i += 2 {
|
||
mm.MaintainSessionMemory(sessionID, conversation[i-1].content, conversation[i].content)
|
||
}
|
||
|
||
shortTerm, err := mm.GetShortTermMemory(sessionID, "")
|
||
if err != nil {
|
||
t.Fatalf("GetShortTermMemory failed: %v", err)
|
||
}
|
||
t.Logf("Short-term memories: %d", len(shortTerm))
|
||
for i, mem := range shortTerm {
|
||
t.Logf(" Memory %d: %s", i, mem)
|
||
}
|
||
|
||
context := mm.BuildMemoryContext(sessionID, "学习资源")
|
||
if context != "" {
|
||
t.Logf("Memory context for '学习资源':\n%s", context)
|
||
}
|
||
}
|
||
|
||
|
||
func BenchmarkSaveMessage(b *testing.B) {
|
||
mm, cleanup := setupMemoryManager(&testing.T{})
|
||
defer cleanup()
|
||
|
||
sessionID := "bench-session"
|
||
msg := SessionMessage{
|
||
Role: RoleUser,
|
||
Content: "这是一条测试消息",
|
||
Timestamp: time.Now(),
|
||
}
|
||
|
||
b.ResetTimer()
|
||
for i := 0; i < b.N; i++ {
|
||
if err := mm.SaveMessage(sessionID, msg); err != nil {
|
||
b.Fatalf("SaveMessage failed: %v", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
func BenchmarkGetWorkingMemory(b *testing.B) {
|
||
mm, cleanup := setupMemoryManager(&testing.T{})
|
||
defer cleanup()
|
||
|
||
sessionID := "bench-session"
|
||
for i := 0; i < 100; i++ {
|
||
msg := SessionMessage{
|
||
Role: RoleUser,
|
||
Content: fmt.Sprintf("Message %d", i),
|
||
Timestamp: time.Now(),
|
||
}
|
||
mm.SaveMessage(sessionID, msg)
|
||
}
|
||
|
||
b.ResetTimer()
|
||
for i := 0; i < b.N; i++ {
|
||
if _, err := mm.GetWorkingMemory(sessionID); err != nil {
|
||
b.Fatalf("GetWorkingMemory failed: %v", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
func BenchmarkBuildMemoryContext(b *testing.B) {
|
||
mm, cleanup := setupMemoryManager(&testing.T{})
|
||
defer cleanup()
|
||
|
||
sessionID := "bench-context"
|
||
for i := 0; i < 10; i++ {
|
||
mm.AddShortTermMemory(sessionID, fmt.Sprintf("Memory %d", i))
|
||
}
|
||
|
||
b.ResetTimer()
|
||
for i := 0; i < b.N; i++ {
|
||
_ = mm.BuildMemoryContext(sessionID, "test query")
|
||
}
|
||
}
|