orca.ai/test_memory_retrieval.go
2026-05-12 00:09:01 +08:00

228 lines
6.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"context"
"database/sql"
"fmt"
"os"
"time"
"github.com/orca/orca/pkg/bus"
"github.com/orca/orca/pkg/kernel"
"github.com/orca/orca/pkg/actor"
"github.com/orca/orca/pkg/session"
_ "modernc.org/sqlite"
_ "modernc.org/sqlite/vec"
)
func main() {
fmt.Println("=== 记忆系统综合测试 ===")
fmt.Println("测试内容记忆检索、命中率、token节省、日常使用")
k := kernel.New()
if err := k.Start(); err != nil {
fmt.Printf("启动失败: %v\n", err)
os.Exit(1)
}
defer k.Stop()
orch := k.Orchestrator()
mm := k.MemoryManager()
// 阶段1: 建立用户画像(触发长期记忆提取)
fmt.Println("\n--- 阶段1: 建立用户画像 ---")
sendMessage(k, "你好,我叫李四,我在金融科技公司做架构师")
sendMessage(k, "我主要用Java和Kotlin最近在研究微服务拆分")
sendMessage(k, "我喜欢详细的技术解释,带架构图最好")
sendMessage(k, "现在负责支付系统的重构,从单体迁移到微服务")
sendMessage(k, "团队有10个人前端用React后端用Spring Boot")
// 等待长期记忆提取
fmt.Println("\n等待长期记忆提取 (15秒)...")
time.Sleep(15 * time.Second)
// 阶段2: 子Agent调用测试隔离
fmt.Println("\n--- 阶段2: 子Agent调用测试 ---")
callSubAgent(orch, "coder", "写一个JWT认证的工具类Java实现")
callSubAgent(orch, "reviewer", "审查代码public class Auth { public String token; }")
// 阶段3: 查询记忆(测试检索命中率)
fmt.Println("\n--- 阶段3: 记忆检索测试 ---")
fmt.Println("\n查询1: 询问技术偏好(应命中长期记忆)")
sendMessage(k, "你觉得Java和Go哪个更适合做支付系统")
fmt.Println("\n查询2: 询问团队信息(应命中长期记忆)")
sendMessage(k, "我们团队前端用什么框架比较好?")
fmt.Println("\n查询3: 询问个人背景(应命中长期记忆)")
sendMessage(k, "你能根据我的背景给些微服务拆分的建议吗?")
fmt.Println("\n查询4: 无关查询(测试未命中情况)")
sendMessage(k, "今天天气怎么样?")
// 阶段4: 等待并检查统计
fmt.Println("\n--- 阶段4: 等待并收集统计 (5秒) ---")
time.Sleep(5 * time.Second)
// 阶段5: 详细统计
fmt.Println("\n--- 阶段5: 详细统计 ---")
printDetailedStats(mm)
checkDatabase()
}
func sendMessage(k *kernel.Kernel, content string) string {
resp, err := k.SendMessage("user", "llm", content)
if err != nil {
fmt.Printf(" 发送失败: %v\n", err)
return ""
}
fmt.Printf(" 回复: %s\n", truncate(resp, 150))
return resp
}
func callSubAgent(orch *actor.Orchestrator, agentName, task string) {
msg := bus.Message{
Type: bus.MsgTypeTaskRequest,
From: "user",
To: agentName,
Content: task,
}
resp, err := orch.Process(context.Background(), msg)
if err != nil {
fmt.Printf(" %s 调用失败: %v\n", agentName, err)
} else {
fmt.Printf(" %s 回复: %s\n", agentName, truncate(fmt.Sprintf("%v", resp.Content), 150))
}
}
func printDetailedStats(mm *session.MemoryManager) {
if mm == nil {
fmt.Println("MemoryManager 未初始化")
return
}
// Embedding缓存统计
cacheSize, cacheHits, cacheMisses := mm.CacheStats()
total := cacheHits + cacheMisses
hitRate := float64(0)
if total > 0 {
hitRate = float64(cacheHits) * 100 / float64(total)
}
fmt.Printf("\n[Embedding缓存统计]\n")
fmt.Printf(" 缓存大小: %d\n", cacheSize)
fmt.Printf(" 命中次数: %d\n", cacheHits)
fmt.Printf(" 未命中次数: %d\n", cacheMisses)
fmt.Printf(" 命中率: %.1f%%\n", hitRate)
// 记忆上下文统计(模拟查询)
queries := []string{
"Java技术栈",
"团队规模",
"微服务拆分",
"前端框架",
"天气",
}
fmt.Printf("\n[记忆检索测试 - %d个查询]\n", len(queries))
totalMemories := 0
referencedMemories := 0
for _, q := range queries {
ctx, stats := mm.BuildMemoryContextWithStats("default", q)
hasMemory := ctx != ""
fmt.Printf(" 查询 '%s': 短期=%d, 长期=%d, tokens=%d, 有记忆=%v\n",
q, stats.ShortTermCount, stats.LongTermCount, stats.TotalTokens, hasMemory)
if hasMemory {
totalMemories++
if stats.LongTermCount > 0 {
referencedMemories++
}
}
}
fmt.Printf(" 记忆命中率: %d/%d (%.0f%%)\n", referencedMemories, len(queries), float64(referencedMemories)*100/float64(len(queries)))
}
func checkDatabase() {
dbPath := os.ExpandEnv("$HOME/.orca/sessions/orcasession.db")
db, err := sql.Open("sqlite", dbPath)
if err != nil {
fmt.Printf("打开数据库失败: %v\n", err)
return
}
defer db.Close()
fmt.Println("\n=== 数据库统计 ===")
// 表统计
fmt.Println("\n--- 各表记录数 ---")
tables := []string{"main_messages", "short_term_memories", "long_term_memories", "dialogue_buffer", "subagent_messages", "memory_usage_log"}
for _, table := range tables {
var count int
db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", table)).Scan(&count)
fmt.Printf(" %s: %d\n", table, count)
}
// 长期记忆详情
fmt.Println("\n--- 长期记忆详情 ---")
rows, _ := db.Query(`
SELECT id, substr(content, 1, 50), memory_type, weight, confidence, access_count, created_at
FROM long_term_memories
ORDER BY id DESC
`)
defer rows.Close()
for rows.Next() {
var id int
var content, mtype string
var weight, confidence float64
var accessCount int
var createdAt string
rows.Scan(&id, &content, &mtype, &weight, &confidence, &accessCount, &createdAt)
fmt.Printf(" [#%d %s] weight=%.2f conf=%.2f access=%d | %s\n", id, mtype, weight, confidence, accessCount, content)
}
// memory_usage_log 详情
fmt.Println("\n--- 记忆使用日志 ---")
rows2, _ := db.Query(`
SELECT memory_id, substr(query, 1, 30), was_referenced, used_at
FROM memory_usage_log
ORDER BY id DESC
`)
defer rows2.Close()
count := 0
for rows2.Next() {
var memID int
var query string
var referenced int
var usedAt string
rows2.Scan(&memID, &query, &referenced, &usedAt)
fmt.Printf(" memory_id=%d query='%s' referenced=%v at=%s\n", memID, query, referenced == 1, usedAt)
count++
}
if count == 0 {
fmt.Println(" (空)")
}
// 子Agent统计
fmt.Println("\n--- 子Agent消息统计 ---")
rows3, _ := db.Query(`SELECT agent_name, COUNT(*) FROM subagent_messages GROUP BY agent_name`)
defer rows3.Close()
for rows3.Next() {
var agent string
var cnt int
rows3.Scan(&agent, &cnt)
fmt.Printf(" %s: %d\n", agent, cnt)
}
// 对话缓冲
fmt.Println("\n--- 对话缓冲 ---")
var bufCount int
db.QueryRow("SELECT COUNT(*) FROM dialogue_buffer").Scan(&bufCount)
fmt.Printf(" 当前缓冲条数: %d\n", bufCount)
}
func truncate(s string, max int) string {
if len(s) <= max {
return s
}
return s[:max] + "..."
}