package main import ( "context" "database/sql" "fmt" "os" "time" "github.com/orca/orca/pkg/bus" "github.com/orca/orca/pkg/kernel" "github.com/orca/orca/pkg/actor" "github.com/orca/orca/pkg/session" _ "modernc.org/sqlite" _ "modernc.org/sqlite/vec" ) func main() { fmt.Println("=== 记忆系统综合测试 ===") fmt.Println("测试内容:记忆检索、命中率、token节省、日常使用") k := kernel.New() if err := k.Start(); err != nil { fmt.Printf("启动失败: %v\n", err) os.Exit(1) } defer k.Stop() orch := k.Orchestrator() mm := k.MemoryManager() // 阶段1: 建立用户画像(触发长期记忆提取) fmt.Println("\n--- 阶段1: 建立用户画像 ---") sendMessage(k, "你好,我叫李四,我在金融科技公司做架构师") sendMessage(k, "我主要用Java和Kotlin,最近在研究微服务拆分") sendMessage(k, "我喜欢详细的技术解释,带架构图最好") sendMessage(k, "现在负责支付系统的重构,从单体迁移到微服务") sendMessage(k, "团队有10个人,前端用React,后端用Spring Boot") // 等待长期记忆提取 fmt.Println("\n等待长期记忆提取 (15秒)...") time.Sleep(15 * time.Second) // 阶段2: 子Agent调用(测试隔离) fmt.Println("\n--- 阶段2: 子Agent调用测试 ---") callSubAgent(orch, "coder", "写一个JWT认证的工具类,Java实现") callSubAgent(orch, "reviewer", "审查代码:public class Auth { public String token; }") // 阶段3: 查询记忆(测试检索命中率) fmt.Println("\n--- 阶段3: 记忆检索测试 ---") fmt.Println("\n查询1: 询问技术偏好(应命中长期记忆)") sendMessage(k, "你觉得Java和Go哪个更适合做支付系统?") fmt.Println("\n查询2: 询问团队信息(应命中长期记忆)") sendMessage(k, "我们团队前端用什么框架比较好?") fmt.Println("\n查询3: 询问个人背景(应命中长期记忆)") sendMessage(k, "你能根据我的背景给些微服务拆分的建议吗?") fmt.Println("\n查询4: 无关查询(测试未命中情况)") sendMessage(k, "今天天气怎么样?") // 阶段4: 等待并检查统计 fmt.Println("\n--- 阶段4: 等待并收集统计 (5秒) ---") time.Sleep(5 * time.Second) // 阶段5: 详细统计 fmt.Println("\n--- 阶段5: 详细统计 ---") printDetailedStats(mm) checkDatabase() } func sendMessage(k *kernel.Kernel, content string) string { resp, err := k.SendMessage("user", "llm", content) if err != nil { fmt.Printf(" 发送失败: %v\n", err) return "" } fmt.Printf(" 回复: %s\n", truncate(resp, 150)) return resp } func callSubAgent(orch *actor.Orchestrator, agentName, task string) { msg := bus.Message{ Type: bus.MsgTypeTaskRequest, From: "user", To: agentName, Content: task, } resp, err := orch.Process(context.Background(), msg) if err != nil { fmt.Printf(" %s 调用失败: %v\n", agentName, err) } else { fmt.Printf(" %s 回复: %s\n", agentName, truncate(fmt.Sprintf("%v", resp.Content), 150)) } } func printDetailedStats(mm *session.MemoryManager) { if mm == nil { fmt.Println("MemoryManager 未初始化") return } // Embedding缓存统计 cacheSize, cacheHits, cacheMisses := mm.CacheStats() total := cacheHits + cacheMisses hitRate := float64(0) if total > 0 { hitRate = float64(cacheHits) * 100 / float64(total) } fmt.Printf("\n[Embedding缓存统计]\n") fmt.Printf(" 缓存大小: %d\n", cacheSize) fmt.Printf(" 命中次数: %d\n", cacheHits) fmt.Printf(" 未命中次数: %d\n", cacheMisses) fmt.Printf(" 命中率: %.1f%%\n", hitRate) // 记忆上下文统计(模拟查询) queries := []string{ "Java技术栈", "团队规模", "微服务拆分", "前端框架", "天气", } fmt.Printf("\n[记忆检索测试 - %d个查询]\n", len(queries)) totalMemories := 0 referencedMemories := 0 for _, q := range queries { ctx, stats := mm.BuildMemoryContextWithStats("default", q) hasMemory := ctx != "" fmt.Printf(" 查询 '%s': 短期=%d, 长期=%d, tokens=%d, 有记忆=%v\n", q, stats.ShortTermCount, stats.LongTermCount, stats.TotalTokens, hasMemory) if hasMemory { totalMemories++ if stats.LongTermCount > 0 { referencedMemories++ } } } fmt.Printf(" 记忆命中率: %d/%d (%.0f%%)\n", referencedMemories, len(queries), float64(referencedMemories)*100/float64(len(queries))) } func checkDatabase() { dbPath := os.ExpandEnv("$HOME/.orca/sessions/orcasession.db") db, err := sql.Open("sqlite", dbPath) if err != nil { fmt.Printf("打开数据库失败: %v\n", err) return } defer db.Close() fmt.Println("\n=== 数据库统计 ===") // 表统计 fmt.Println("\n--- 各表记录数 ---") tables := []string{"main_messages", "short_term_memories", "long_term_memories", "dialogue_buffer", "subagent_messages", "memory_usage_log"} for _, table := range tables { var count int db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", table)).Scan(&count) fmt.Printf(" %s: %d\n", table, count) } // 长期记忆详情 fmt.Println("\n--- 长期记忆详情 ---") rows, _ := db.Query(` SELECT id, substr(content, 1, 50), memory_type, weight, confidence, access_count, created_at FROM long_term_memories ORDER BY id DESC `) defer rows.Close() for rows.Next() { var id int var content, mtype string var weight, confidence float64 var accessCount int var createdAt string rows.Scan(&id, &content, &mtype, &weight, &confidence, &accessCount, &createdAt) fmt.Printf(" [#%d %s] weight=%.2f conf=%.2f access=%d | %s\n", id, mtype, weight, confidence, accessCount, content) } // memory_usage_log 详情 fmt.Println("\n--- 记忆使用日志 ---") rows2, _ := db.Query(` SELECT memory_id, substr(query, 1, 30), was_referenced, used_at FROM memory_usage_log ORDER BY id DESC `) defer rows2.Close() count := 0 for rows2.Next() { var memID int var query string var referenced int var usedAt string rows2.Scan(&memID, &query, &referenced, &usedAt) fmt.Printf(" memory_id=%d query='%s' referenced=%v at=%s\n", memID, query, referenced == 1, usedAt) count++ } if count == 0 { fmt.Println(" (空)") } // 子Agent统计 fmt.Println("\n--- 子Agent消息统计 ---") rows3, _ := db.Query(`SELECT agent_name, COUNT(*) FROM subagent_messages GROUP BY agent_name`) defer rows3.Close() for rows3.Next() { var agent string var cnt int rows3.Scan(&agent, &cnt) fmt.Printf(" %s: %d\n", agent, cnt) } // 对话缓冲 fmt.Println("\n--- 对话缓冲 ---") var bufCount int db.QueryRow("SELECT COUNT(*) FROM dialogue_buffer").Scan(&bufCount) fmt.Printf(" 当前缓冲条数: %d\n", bufCount) } func truncate(s string, max int) string { if len(s) <= max { return s } return s[:max] + "..." }