248 lines
6.9 KiB
Go
248 lines
6.9 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"fmt"
|
||
"os"
|
||
"time"
|
||
|
||
"github.com/orca/orca/pkg/bus"
|
||
"github.com/orca/orca/pkg/kernel"
|
||
_ "modernc.org/sqlite"
|
||
_ "modernc.org/sqlite/vec"
|
||
)
|
||
|
||
func main() {
|
||
fmt.Println("=== 子Agent调用 + 记忆系统测试 ===")
|
||
|
||
k := kernel.New()
|
||
if err := k.Start(); err != nil {
|
||
fmt.Printf("启动失败: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
defer k.Stop()
|
||
|
||
orch := k.Orchestrator()
|
||
mm := k.MemoryManager()
|
||
|
||
fmt.Println("\n1. 发送普通消息给 LLM...")
|
||
resp1, err := k.SendMessage("user", "llm", "你好,我叫张三,我喜欢用Go语言写后端服务")
|
||
if err != nil {
|
||
fmt.Printf("发送失败: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
fmt.Printf("回复: %s\n", truncate(resp1, 200))
|
||
|
||
fmt.Println("\n2. 发送第二条消息给 LLM...")
|
||
resp2, err := k.SendMessage("user", "llm", "你觉得Go和Python哪个更适合做Web后端?")
|
||
if err != nil {
|
||
fmt.Printf("发送失败: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
fmt.Printf("回复: %s\n", truncate(resp2, 200))
|
||
|
||
fmt.Println("\n3. 直接调用 coder 子Agent...")
|
||
msg := bus.Message{
|
||
Type: bus.MsgTypeTaskRequest,
|
||
From: "user",
|
||
To: "coder",
|
||
Content: "请写一个快速排序算法,用Go语言实现",
|
||
}
|
||
resp3, err := orch.Process(context.Background(), msg)
|
||
if err != nil {
|
||
fmt.Printf("coder 调用失败: %v\n", err)
|
||
} else {
|
||
fmt.Printf("coder 回复: %s\n", truncate(fmt.Sprintf("%v", resp3.Content), 300))
|
||
}
|
||
|
||
fmt.Println("\n4. 直接调用 reviewer 子Agent...")
|
||
msg2 := bus.Message{
|
||
Type: bus.MsgTypeTaskRequest,
|
||
From: "user",
|
||
To: "reviewer",
|
||
Content: "请审查以下代码的质量:func main() { fmt.Println(\"hello\") }",
|
||
}
|
||
resp4, err := orch.Process(context.Background(), msg2)
|
||
if err != nil {
|
||
fmt.Printf("reviewer 调用失败: %v\n", err)
|
||
} else {
|
||
fmt.Printf("reviewer 回复: %s\n", truncate(fmt.Sprintf("%v", resp4.Content), 300))
|
||
}
|
||
|
||
fmt.Println("\n5. 发送第三条消息(继续对话,积累dialogue_buffer)...")
|
||
resp5, err := k.SendMessage("user", "llm", "我在一家电商公司做后端开发,平时用Go处理高并发订单系统")
|
||
if err != nil {
|
||
fmt.Printf("发送失败: %v\n", err)
|
||
} else {
|
||
fmt.Printf("回复: %s\n", truncate(resp5, 200))
|
||
}
|
||
|
||
fmt.Println("\n6. 发送第四条消息(继续对话,触发长期记忆提取阈值)...")
|
||
resp6, err := k.SendMessage("user", "llm", "我希望回答简洁一些,直接给代码示例,不要太多解释")
|
||
if err != nil {
|
||
fmt.Printf("发送失败: %v\n", err)
|
||
} else {
|
||
fmt.Printf("回复: %s\n", truncate(resp6, 200))
|
||
}
|
||
|
||
fmt.Println("\n7. 发送第五条消息(超过5条阈值,应触发自动提取)...")
|
||
resp7, err := k.SendMessage("user", "llm", "最近在做一个库存管理模块,用Redis做缓存,MySQL做持久化")
|
||
if err != nil {
|
||
fmt.Printf("发送失败: %v\n", err)
|
||
} else {
|
||
fmt.Printf("回复: %s\n", truncate(resp7, 200))
|
||
}
|
||
|
||
fmt.Println("\n8. 手动维护短期记忆...")
|
||
if mm != nil {
|
||
mm.AddShortTermMemory("default", "用户张三喜欢Go语言")
|
||
mm.AddShortTermMemory("default", "用户询问过Web后端技术选型")
|
||
fmt.Println("短期记忆已添加")
|
||
}
|
||
|
||
fmt.Println("\n9. 手动添加长期记忆...")
|
||
if mm != nil {
|
||
mm.AddLongTermMemory("用户偏好Go语言", "preference")
|
||
mm.AddLongTermMemory("用户关注后端开发", "fact")
|
||
fmt.Println("长期记忆已添加")
|
||
}
|
||
|
||
fmt.Println("\n10. 等待异步长期记忆提取 + embedding 处理 (10秒)...")
|
||
time.Sleep(10 * time.Second)
|
||
|
||
fmt.Println("\n11. 查询数据库...")
|
||
checkDatabase()
|
||
}
|
||
|
||
func checkDatabase() {
|
||
dbPath := os.ExpandEnv("$HOME/.orca/sessions/orcasession.db")
|
||
db, err := sql.Open("sqlite", dbPath)
|
||
if err != nil {
|
||
fmt.Printf("打开数据库失败: %v\n", err)
|
||
return
|
||
}
|
||
defer db.Close()
|
||
|
||
fmt.Println("\n--- main_messages 统计 ---")
|
||
var count int
|
||
db.QueryRow("SELECT COUNT(*) FROM main_messages").Scan(&count)
|
||
fmt.Printf("总消息数: %d\n", count)
|
||
|
||
if count > 0 {
|
||
fmt.Println(" 最近5条消息:")
|
||
rows, _ := db.Query(`
|
||
SELECT role, substr(content, 1, 60), timestamp
|
||
FROM main_messages
|
||
ORDER BY id DESC LIMIT 5
|
||
`)
|
||
defer rows.Close()
|
||
for rows.Next() {
|
||
var role, content, ts string
|
||
rows.Scan(&role, &content, &ts)
|
||
fmt.Printf(" [%s] %s... (%s)\n", role, content, ts)
|
||
}
|
||
}
|
||
|
||
fmt.Println("\n--- subagent_messages 统计 ---")
|
||
var subCount int
|
||
db.QueryRow("SELECT COUNT(*) FROM subagent_messages").Scan(&subCount)
|
||
fmt.Printf("子Agent消息数: %d\n", subCount)
|
||
|
||
if subCount > 0 {
|
||
fmt.Println(" 子Agent消息明细:")
|
||
rows, _ := db.Query(`
|
||
SELECT agent_name, role, substr(content, 1, 50), parent_session_id, session_id
|
||
FROM subagent_messages
|
||
ORDER BY id DESC LIMIT 10
|
||
`)
|
||
defer rows.Close()
|
||
for rows.Next() {
|
||
var agent, role, content, parentSID, sid string
|
||
rows.Scan(&agent, &role, &content, &parentSID, &sid)
|
||
fmt.Printf(" [%s/%s] %s... (parent=%s, sid=%s)\n", agent, role, content, parentSID, sid)
|
||
}
|
||
}
|
||
|
||
fmt.Println("\n--- 向量表统计 ---")
|
||
var vecCount int
|
||
err = db.QueryRow("SELECT COUNT(*) FROM vec_main_messages").Scan(&vecCount)
|
||
if err != nil {
|
||
fmt.Printf("向量表查询失败: %v\n", err)
|
||
} else {
|
||
fmt.Printf("向量数量: %d\n", vecCount)
|
||
}
|
||
|
||
if vecCount > 0 {
|
||
fmt.Println(" 最近5条向量:")
|
||
rows, _ := db.Query(`
|
||
SELECT v.msg_id, m.role, substr(m.content, 1, 40)
|
||
FROM vec_main_messages v
|
||
JOIN main_messages m ON v.msg_id = m.id
|
||
LIMIT 5
|
||
`)
|
||
defer rows.Close()
|
||
for rows.Next() {
|
||
var msgID int
|
||
var role, content string
|
||
rows.Scan(&msgID, &role, &content)
|
||
fmt.Printf(" msg_id=%d role=%s content=%s\n", msgID, role, content)
|
||
}
|
||
}
|
||
|
||
fmt.Println("\n--- 短期记忆 ---")
|
||
rows, _ := db.Query("SELECT substr(content, 1, 60) FROM short_term_memories LIMIT 5")
|
||
defer rows.Close()
|
||
stmCount := 0
|
||
for rows.Next() {
|
||
var content string
|
||
rows.Scan(&content)
|
||
fmt.Printf(" %s\n", content)
|
||
stmCount++
|
||
}
|
||
if stmCount == 0 {
|
||
fmt.Println(" (空)")
|
||
}
|
||
|
||
fmt.Println("\n--- 长期记忆 ---")
|
||
rows2, _ := db.Query("SELECT substr(content, 1, 60), memory_type, weight FROM long_term_memories LIMIT 5")
|
||
defer rows2.Close()
|
||
ltmCount := 0
|
||
for rows2.Next() {
|
||
var content, mtype string
|
||
var weight float64
|
||
rows2.Scan(&content, &mtype, &weight)
|
||
fmt.Printf(" [%s|weight=%.2f] %s\n", mtype, weight, content)
|
||
ltmCount++
|
||
}
|
||
if ltmCount == 0 {
|
||
fmt.Println(" (空)")
|
||
}
|
||
|
||
fmt.Println("\n--- dialogue_buffer ---")
|
||
var bufCount int
|
||
db.QueryRow("SELECT COUNT(*) FROM dialogue_buffer").Scan(&bufCount)
|
||
fmt.Printf("对话缓冲条数: %d\n", bufCount)
|
||
|
||
fmt.Println("\n--- 按 agent_name 统计子Agent消息 ---")
|
||
rows3, _ := db.Query(`
|
||
SELECT agent_name, COUNT(*) as cnt
|
||
FROM subagent_messages
|
||
GROUP BY agent_name
|
||
`)
|
||
defer rows3.Close()
|
||
for rows3.Next() {
|
||
var agent string
|
||
var cnt int
|
||
rows3.Scan(&agent, &cnt)
|
||
fmt.Printf(" %s: %d 条消息\n", agent, cnt)
|
||
}
|
||
}
|
||
|
||
func truncate(s string, max int) string {
|
||
if len(s) <= max {
|
||
return s
|
||
}
|
||
return s[:max] + "..."
|
||
}
|