228 lines
6.7 KiB
Go
228 lines
6.7 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"fmt"
|
||
"os"
|
||
"time"
|
||
|
||
"github.com/orca/orca/pkg/bus"
|
||
"github.com/orca/orca/pkg/kernel"
|
||
"github.com/orca/orca/pkg/actor"
|
||
"github.com/orca/orca/pkg/session"
|
||
_ "modernc.org/sqlite"
|
||
_ "modernc.org/sqlite/vec"
|
||
)
|
||
|
||
func main() {
|
||
fmt.Println("=== 记忆系统综合测试 ===")
|
||
fmt.Println("测试内容:记忆检索、命中率、token节省、日常使用")
|
||
|
||
k := kernel.New()
|
||
if err := k.Start(); err != nil {
|
||
fmt.Printf("启动失败: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
defer k.Stop()
|
||
|
||
orch := k.Orchestrator()
|
||
mm := k.MemoryManager()
|
||
|
||
// 阶段1: 建立用户画像(触发长期记忆提取)
|
||
fmt.Println("\n--- 阶段1: 建立用户画像 ---")
|
||
sendMessage(k, "你好,我叫李四,我在金融科技公司做架构师")
|
||
sendMessage(k, "我主要用Java和Kotlin,最近在研究微服务拆分")
|
||
sendMessage(k, "我喜欢详细的技术解释,带架构图最好")
|
||
sendMessage(k, "现在负责支付系统的重构,从单体迁移到微服务")
|
||
sendMessage(k, "团队有10个人,前端用React,后端用Spring Boot")
|
||
|
||
// 等待长期记忆提取
|
||
fmt.Println("\n等待长期记忆提取 (15秒)...")
|
||
time.Sleep(15 * time.Second)
|
||
|
||
// 阶段2: 子Agent调用(测试隔离)
|
||
fmt.Println("\n--- 阶段2: 子Agent调用测试 ---")
|
||
callSubAgent(orch, "coder", "写一个JWT认证的工具类,Java实现")
|
||
callSubAgent(orch, "reviewer", "审查代码:public class Auth { public String token; }")
|
||
|
||
// 阶段3: 查询记忆(测试检索命中率)
|
||
fmt.Println("\n--- 阶段3: 记忆检索测试 ---")
|
||
fmt.Println("\n查询1: 询问技术偏好(应命中长期记忆)")
|
||
sendMessage(k, "你觉得Java和Go哪个更适合做支付系统?")
|
||
|
||
fmt.Println("\n查询2: 询问团队信息(应命中长期记忆)")
|
||
sendMessage(k, "我们团队前端用什么框架比较好?")
|
||
|
||
fmt.Println("\n查询3: 询问个人背景(应命中长期记忆)")
|
||
sendMessage(k, "你能根据我的背景给些微服务拆分的建议吗?")
|
||
|
||
fmt.Println("\n查询4: 无关查询(测试未命中情况)")
|
||
sendMessage(k, "今天天气怎么样?")
|
||
|
||
// 阶段4: 等待并检查统计
|
||
fmt.Println("\n--- 阶段4: 等待并收集统计 (5秒) ---")
|
||
time.Sleep(5 * time.Second)
|
||
|
||
// 阶段5: 详细统计
|
||
fmt.Println("\n--- 阶段5: 详细统计 ---")
|
||
printDetailedStats(mm)
|
||
checkDatabase()
|
||
}
|
||
|
||
func sendMessage(k *kernel.Kernel, content string) string {
|
||
resp, err := k.SendMessage("user", "llm", content)
|
||
if err != nil {
|
||
fmt.Printf(" 发送失败: %v\n", err)
|
||
return ""
|
||
}
|
||
fmt.Printf(" 回复: %s\n", truncate(resp, 150))
|
||
return resp
|
||
}
|
||
|
||
func callSubAgent(orch *actor.Orchestrator, agentName, task string) {
|
||
msg := bus.Message{
|
||
Type: bus.MsgTypeTaskRequest,
|
||
From: "user",
|
||
To: agentName,
|
||
Content: task,
|
||
}
|
||
resp, err := orch.Process(context.Background(), msg)
|
||
if err != nil {
|
||
fmt.Printf(" %s 调用失败: %v\n", agentName, err)
|
||
} else {
|
||
fmt.Printf(" %s 回复: %s\n", agentName, truncate(fmt.Sprintf("%v", resp.Content), 150))
|
||
}
|
||
}
|
||
|
||
func printDetailedStats(mm *session.MemoryManager) {
|
||
if mm == nil {
|
||
fmt.Println("MemoryManager 未初始化")
|
||
return
|
||
}
|
||
|
||
// Embedding缓存统计
|
||
cacheSize, cacheHits, cacheMisses := mm.CacheStats()
|
||
total := cacheHits + cacheMisses
|
||
hitRate := float64(0)
|
||
if total > 0 {
|
||
hitRate = float64(cacheHits) * 100 / float64(total)
|
||
}
|
||
fmt.Printf("\n[Embedding缓存统计]\n")
|
||
fmt.Printf(" 缓存大小: %d\n", cacheSize)
|
||
fmt.Printf(" 命中次数: %d\n", cacheHits)
|
||
fmt.Printf(" 未命中次数: %d\n", cacheMisses)
|
||
fmt.Printf(" 命中率: %.1f%%\n", hitRate)
|
||
|
||
// 记忆上下文统计(模拟查询)
|
||
queries := []string{
|
||
"Java技术栈",
|
||
"团队规模",
|
||
"微服务拆分",
|
||
"前端框架",
|
||
"天气",
|
||
}
|
||
|
||
fmt.Printf("\n[记忆检索测试 - %d个查询]\n", len(queries))
|
||
totalMemories := 0
|
||
referencedMemories := 0
|
||
for _, q := range queries {
|
||
ctx, stats := mm.BuildMemoryContextWithStats("default", q)
|
||
hasMemory := ctx != ""
|
||
fmt.Printf(" 查询 '%s': 短期=%d, 长期=%d, tokens=%d, 有记忆=%v\n",
|
||
q, stats.ShortTermCount, stats.LongTermCount, stats.TotalTokens, hasMemory)
|
||
if hasMemory {
|
||
totalMemories++
|
||
if stats.LongTermCount > 0 {
|
||
referencedMemories++
|
||
}
|
||
}
|
||
}
|
||
fmt.Printf(" 记忆命中率: %d/%d (%.0f%%)\n", referencedMemories, len(queries), float64(referencedMemories)*100/float64(len(queries)))
|
||
}
|
||
|
||
func checkDatabase() {
|
||
dbPath := os.ExpandEnv("$HOME/.orca/sessions/orcasession.db")
|
||
db, err := sql.Open("sqlite", dbPath)
|
||
if err != nil {
|
||
fmt.Printf("打开数据库失败: %v\n", err)
|
||
return
|
||
}
|
||
defer db.Close()
|
||
|
||
fmt.Println("\n=== 数据库统计 ===")
|
||
|
||
// 表统计
|
||
fmt.Println("\n--- 各表记录数 ---")
|
||
tables := []string{"main_messages", "short_term_memories", "long_term_memories", "dialogue_buffer", "subagent_messages", "memory_usage_log"}
|
||
for _, table := range tables {
|
||
var count int
|
||
db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", table)).Scan(&count)
|
||
fmt.Printf(" %s: %d\n", table, count)
|
||
}
|
||
|
||
// 长期记忆详情
|
||
fmt.Println("\n--- 长期记忆详情 ---")
|
||
rows, _ := db.Query(`
|
||
SELECT id, substr(content, 1, 50), memory_type, weight, confidence, access_count, created_at
|
||
FROM long_term_memories
|
||
ORDER BY id DESC
|
||
`)
|
||
defer rows.Close()
|
||
for rows.Next() {
|
||
var id int
|
||
var content, mtype string
|
||
var weight, confidence float64
|
||
var accessCount int
|
||
var createdAt string
|
||
rows.Scan(&id, &content, &mtype, &weight, &confidence, &accessCount, &createdAt)
|
||
fmt.Printf(" [#%d %s] weight=%.2f conf=%.2f access=%d | %s\n", id, mtype, weight, confidence, accessCount, content)
|
||
}
|
||
|
||
// memory_usage_log 详情
|
||
fmt.Println("\n--- 记忆使用日志 ---")
|
||
rows2, _ := db.Query(`
|
||
SELECT memory_id, substr(query, 1, 30), was_referenced, used_at
|
||
FROM memory_usage_log
|
||
ORDER BY id DESC
|
||
`)
|
||
defer rows2.Close()
|
||
count := 0
|
||
for rows2.Next() {
|
||
var memID int
|
||
var query string
|
||
var referenced int
|
||
var usedAt string
|
||
rows2.Scan(&memID, &query, &referenced, &usedAt)
|
||
fmt.Printf(" memory_id=%d query='%s' referenced=%v at=%s\n", memID, query, referenced == 1, usedAt)
|
||
count++
|
||
}
|
||
if count == 0 {
|
||
fmt.Println(" (空)")
|
||
}
|
||
|
||
// 子Agent统计
|
||
fmt.Println("\n--- 子Agent消息统计 ---")
|
||
rows3, _ := db.Query(`SELECT agent_name, COUNT(*) FROM subagent_messages GROUP BY agent_name`)
|
||
defer rows3.Close()
|
||
for rows3.Next() {
|
||
var agent string
|
||
var cnt int
|
||
rows3.Scan(&agent, &cnt)
|
||
fmt.Printf(" %s: %d\n", agent, cnt)
|
||
}
|
||
|
||
// 对话缓冲
|
||
fmt.Println("\n--- 对话缓冲 ---")
|
||
var bufCount int
|
||
db.QueryRow("SELECT COUNT(*) FROM dialogue_buffer").Scan(&bufCount)
|
||
fmt.Printf(" 当前缓冲条数: %d\n", bufCount)
|
||
}
|
||
|
||
func truncate(s string, max int) string {
|
||
if len(s) <= max {
|
||
return s
|
||
}
|
||
return s[:max] + "..."
|
||
}
|