orca.ai/test_memory_system.go
2026-05-12 00:09:01 +08:00

248 lines
6.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"context"
"database/sql"
"fmt"
"os"
"time"
"github.com/orca/orca/pkg/bus"
"github.com/orca/orca/pkg/kernel"
_ "modernc.org/sqlite"
_ "modernc.org/sqlite/vec"
)
func main() {
fmt.Println("=== 子Agent调用 + 记忆系统测试 ===")
k := kernel.New()
if err := k.Start(); err != nil {
fmt.Printf("启动失败: %v\n", err)
os.Exit(1)
}
defer k.Stop()
orch := k.Orchestrator()
mm := k.MemoryManager()
fmt.Println("\n1. 发送普通消息给 LLM...")
resp1, err := k.SendMessage("user", "llm", "你好我叫张三我喜欢用Go语言写后端服务")
if err != nil {
fmt.Printf("发送失败: %v\n", err)
os.Exit(1)
}
fmt.Printf("回复: %s\n", truncate(resp1, 200))
fmt.Println("\n2. 发送第二条消息给 LLM...")
resp2, err := k.SendMessage("user", "llm", "你觉得Go和Python哪个更适合做Web后端")
if err != nil {
fmt.Printf("发送失败: %v\n", err)
os.Exit(1)
}
fmt.Printf("回复: %s\n", truncate(resp2, 200))
fmt.Println("\n3. 直接调用 coder 子Agent...")
msg := bus.Message{
Type: bus.MsgTypeTaskRequest,
From: "user",
To: "coder",
Content: "请写一个快速排序算法用Go语言实现",
}
resp3, err := orch.Process(context.Background(), msg)
if err != nil {
fmt.Printf("coder 调用失败: %v\n", err)
} else {
fmt.Printf("coder 回复: %s\n", truncate(fmt.Sprintf("%v", resp3.Content), 300))
}
fmt.Println("\n4. 直接调用 reviewer 子Agent...")
msg2 := bus.Message{
Type: bus.MsgTypeTaskRequest,
From: "user",
To: "reviewer",
Content: "请审查以下代码的质量func main() { fmt.Println(\"hello\") }",
}
resp4, err := orch.Process(context.Background(), msg2)
if err != nil {
fmt.Printf("reviewer 调用失败: %v\n", err)
} else {
fmt.Printf("reviewer 回复: %s\n", truncate(fmt.Sprintf("%v", resp4.Content), 300))
}
fmt.Println("\n5. 发送第三条消息继续对话积累dialogue_buffer...")
resp5, err := k.SendMessage("user", "llm", "我在一家电商公司做后端开发平时用Go处理高并发订单系统")
if err != nil {
fmt.Printf("发送失败: %v\n", err)
} else {
fmt.Printf("回复: %s\n", truncate(resp5, 200))
}
fmt.Println("\n6. 发送第四条消息(继续对话,触发长期记忆提取阈值)...")
resp6, err := k.SendMessage("user", "llm", "我希望回答简洁一些,直接给代码示例,不要太多解释")
if err != nil {
fmt.Printf("发送失败: %v\n", err)
} else {
fmt.Printf("回复: %s\n", truncate(resp6, 200))
}
fmt.Println("\n7. 发送第五条消息超过5条阈值应触发自动提取...")
resp7, err := k.SendMessage("user", "llm", "最近在做一个库存管理模块用Redis做缓存MySQL做持久化")
if err != nil {
fmt.Printf("发送失败: %v\n", err)
} else {
fmt.Printf("回复: %s\n", truncate(resp7, 200))
}
fmt.Println("\n8. 手动维护短期记忆...")
if mm != nil {
mm.AddShortTermMemory("default", "用户张三喜欢Go语言")
mm.AddShortTermMemory("default", "用户询问过Web后端技术选型")
fmt.Println("短期记忆已添加")
}
fmt.Println("\n9. 手动添加长期记忆...")
if mm != nil {
mm.AddLongTermMemory("用户偏好Go语言", "preference")
mm.AddLongTermMemory("用户关注后端开发", "fact")
fmt.Println("长期记忆已添加")
}
fmt.Println("\n10. 等待异步长期记忆提取 + embedding 处理 (10秒)...")
time.Sleep(10 * time.Second)
fmt.Println("\n11. 查询数据库...")
checkDatabase()
}
func checkDatabase() {
dbPath := os.ExpandEnv("$HOME/.orca/sessions/orcasession.db")
db, err := sql.Open("sqlite", dbPath)
if err != nil {
fmt.Printf("打开数据库失败: %v\n", err)
return
}
defer db.Close()
fmt.Println("\n--- main_messages 统计 ---")
var count int
db.QueryRow("SELECT COUNT(*) FROM main_messages").Scan(&count)
fmt.Printf("总消息数: %d\n", count)
if count > 0 {
fmt.Println(" 最近5条消息:")
rows, _ := db.Query(`
SELECT role, substr(content, 1, 60), timestamp
FROM main_messages
ORDER BY id DESC LIMIT 5
`)
defer rows.Close()
for rows.Next() {
var role, content, ts string
rows.Scan(&role, &content, &ts)
fmt.Printf(" [%s] %s... (%s)\n", role, content, ts)
}
}
fmt.Println("\n--- subagent_messages 统计 ---")
var subCount int
db.QueryRow("SELECT COUNT(*) FROM subagent_messages").Scan(&subCount)
fmt.Printf("子Agent消息数: %d\n", subCount)
if subCount > 0 {
fmt.Println(" 子Agent消息明细:")
rows, _ := db.Query(`
SELECT agent_name, role, substr(content, 1, 50), parent_session_id, session_id
FROM subagent_messages
ORDER BY id DESC LIMIT 10
`)
defer rows.Close()
for rows.Next() {
var agent, role, content, parentSID, sid string
rows.Scan(&agent, &role, &content, &parentSID, &sid)
fmt.Printf(" [%s/%s] %s... (parent=%s, sid=%s)\n", agent, role, content, parentSID, sid)
}
}
fmt.Println("\n--- 向量表统计 ---")
var vecCount int
err = db.QueryRow("SELECT COUNT(*) FROM vec_main_messages").Scan(&vecCount)
if err != nil {
fmt.Printf("向量表查询失败: %v\n", err)
} else {
fmt.Printf("向量数量: %d\n", vecCount)
}
if vecCount > 0 {
fmt.Println(" 最近5条向量:")
rows, _ := db.Query(`
SELECT v.msg_id, m.role, substr(m.content, 1, 40)
FROM vec_main_messages v
JOIN main_messages m ON v.msg_id = m.id
LIMIT 5
`)
defer rows.Close()
for rows.Next() {
var msgID int
var role, content string
rows.Scan(&msgID, &role, &content)
fmt.Printf(" msg_id=%d role=%s content=%s\n", msgID, role, content)
}
}
fmt.Println("\n--- 短期记忆 ---")
rows, _ := db.Query("SELECT substr(content, 1, 60) FROM short_term_memories LIMIT 5")
defer rows.Close()
stmCount := 0
for rows.Next() {
var content string
rows.Scan(&content)
fmt.Printf(" %s\n", content)
stmCount++
}
if stmCount == 0 {
fmt.Println(" (空)")
}
fmt.Println("\n--- 长期记忆 ---")
rows2, _ := db.Query("SELECT substr(content, 1, 60), memory_type, weight FROM long_term_memories LIMIT 5")
defer rows2.Close()
ltmCount := 0
for rows2.Next() {
var content, mtype string
var weight float64
rows2.Scan(&content, &mtype, &weight)
fmt.Printf(" [%s|weight=%.2f] %s\n", mtype, weight, content)
ltmCount++
}
if ltmCount == 0 {
fmt.Println(" (空)")
}
fmt.Println("\n--- dialogue_buffer ---")
var bufCount int
db.QueryRow("SELECT COUNT(*) FROM dialogue_buffer").Scan(&bufCount)
fmt.Printf("对话缓冲条数: %d\n", bufCount)
fmt.Println("\n--- 按 agent_name 统计子Agent消息 ---")
rows3, _ := db.Query(`
SELECT agent_name, COUNT(*) as cnt
FROM subagent_messages
GROUP BY agent_name
`)
defer rows3.Close()
for rows3.Next() {
var agent string
var cnt int
rows3.Scan(&agent, &cnt)
fmt.Printf(" %s: %d 条消息\n", agent, cnt)
}
}
func truncate(s string, max int) string {
if len(s) <= max {
return s
}
return s[:max] + "..."
}