orca.ai/pkg/actor/memory_extractor.go
2026-05-12 00:09:01 +08:00

165 lines
4.3 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package actor
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/orca/orca/pkg/bus"
"github.com/orca/orca/pkg/llm"
)
// MemoryExtractorAgent couples a SubAgent with the extraction settings it was
// created with. Instances are built by NewMemoryExtractorAgent.
type MemoryExtractorAgent struct {
	// SubAgent is embedded, so its methods are promoted onto this type.
	*SubAgent

	// config is the ExtractConfig supplied at construction time.
	config ExtractConfig
}
// ExtractConfig carries the settings handed to NewMemoryExtractorAgent.
//
// NOTE(review): the per-field notes below are inferred from the field names;
// confirm them against the code that reads this configuration.
type ExtractConfig struct {
	BatchSize     int     // presumably the number of dialogues handled per batch
	MaxFacts      int     // presumably an upper bound on the facts kept
	MinConfidence float64 // presumably the lowest Fact.Confidence accepted
	AutoTag       bool    // presumably toggles automatic tagging of facts
}
// Dialogue is one user/assistant exchange that is fed to fact extraction.
type Dialogue struct {
	// UserQuery is the text written by the user.
	UserQuery string

	// AssistantResponse is the assistant's reply to UserQuery.
	AssistantResponse string
}
// Fact is a single extracted memory entry, decoded from one element of the
// "facts" array in the model's JSON reply.
type Fact struct {
	// Content is the statement itself.
	Content string `json:"content"`

	// Type is the category label; the default prompt uses "fact",
	// "preference" and "project".
	Type string `json:"type"`

	// Tags are the labels attached to the entry.
	Tags []string `json:"tags"`

	// Confidence is the score reported by the model for this entry.
	Confidence float64 `json:"confidence"`

	// Replace is nil when the JSON value is null; otherwise, per the default
	// prompt, it holds the content of the older fact this entry supersedes.
	Replace *string `json:"replace"`
}
// ExtractResult is the top-level JSON object expected from the model: an
// object whose "facts" key holds the list of extracted entries.
type ExtractResult struct {
	// Facts is decoded from the "facts" array.
	Facts []Fact `json:"facts"`
}
// DefaultMemoryExtractorPrompt is the built-in system prompt, written in
// Chinese, for the memory extractor agent. loadAgentPrompt returns it when
// the prompt file cannot be read.
//
// The prompt asks the model to pull three kinds of information out of a
// dialogue (fact, preference, project) and to reply with JSON only: an object
// whose "facts" array holds entries with the keys content, type, tags,
// confidence and replace, matching the JSON tags of Fact. The example reply
// is wrapped in a fenced json code block, which is spliced in through string
// concatenation because a raw string literal cannot contain backticks.
const DefaultMemoryExtractorPrompt = `# Memory Extractor Agent
你是一个专门从对话中提取用户信息的 Agent。你的工作是将非结构化的对话转化为结构化的长期记忆。
## 任务
分析给定的对话记录,提取以下类型的信息:
1. **事实 (fact)**:客观信息
- 工作:公司、职位、技术栈、行业
- 技术:擅长语言、框架偏好、架构经验
- 个人:教育背景、所在城市(仅用户明确提及)
2. **偏好 (preference)**:主观倾向
- 回答风格:简洁/详细/代码示例/架构图
- 技术偏好:语言、数据库、部署方式
- 沟通偏好:正式/ casual
3. **项目 (project)**:当前工作
- 项目名称、技术方案、当前阶段、遇到的挑战
## 输出格式
只输出 JSON不要任何解释
` + "```json" + `
{
"facts": [
{
"content": "用户在电商公司担任后端工程师",
"type": "fact",
"tags": ["工作", "后端", "电商"],
"confidence": 0.95,
"replace": null
},
{
"content": "用户偏好简洁的技术回答,不要过多解释",
"type": "preference",
"tags": ["沟通风格", "偏好"],
"confidence": 0.85,
"replace": "用户喜欢详细的回答"
}
]
}
` + "```" + `
## 规则
- confidence < 0.6 的事实不输出
- 如果新事实与旧事实冲突:
- 在 replace 字段填入被替换的旧事实 content
- 只替换同一 type + 同一 tags 的事实
- 不猜测、不推断,只提取用户明确表达的信息
- 标签从预设列表选择:工作、技术、偏好、项目、沟通风格、行业`
// loadAgentPrompt returns the system prompt for the built-in agent agentName.
//
// It reads <home>/.orca/agents/_builtin/<agentName>.md, where <home> is $HOME
// or, when that variable is empty, the directory reported by os.UserHomeDir.
// DefaultMemoryExtractorPrompt is returned instead when no home directory can
// be determined, when the file cannot be read, or when the file contains only
// whitespace, so the caller always receives a usable prompt.
func loadAgentPrompt(agentName string) string {
	home := os.Getenv("HOME")
	if home == "" {
		// Joining an empty home would silently resolve the prompt relative to
		// the working directory; ask the OS for the home directory instead.
		dir, err := os.UserHomeDir()
		if err != nil {
			return DefaultMemoryExtractorPrompt
		}
		home = dir
	}

	path := filepath.Join(home, ".orca", "agents", "_builtin", agentName+".md")
	data, err := os.ReadFile(path)
	if err != nil {
		// Best effort: a missing or unreadable override is not an error.
		return DefaultMemoryExtractorPrompt
	}

	// A blank override would leave the agent without any instructions.
	if strings.TrimSpace(string(data)) == "" {
		return DefaultMemoryExtractorPrompt
	}
	return string(data)
}
// NewMemoryExtractorAgent creates a MemoryExtractorAgent with the given id and
// LLM backend. The embedded SubAgent is created with the "memory_extractor"
// role and the system prompt returned by loadAgentPrompt for that name; cfg is
// stored unchanged on the returned agent.
func NewMemoryExtractorAgent(id string, llmBackend llm.LLM, cfg ExtractConfig) *MemoryExtractorAgent {
	const role = "memory_extractor"

	// Resolve the prompt before building the sub-agent options.
	systemPrompt := loadAgentPrompt(role)

	agent := &MemoryExtractorAgent{config: cfg}
	agent.SubAgent = NewSubAgent(
		id,
		llmBackend,
		WithSubAgentRole(role),
		WithSubAgentSystemPrompt(systemPrompt),
	)
	return agent
}
// ExtractFacts asks the underlying agent to extract structured facts from the
// given dialogues.
//
// The dialogues are rendered into one numbered transcript, sent through
// Process as a task request, and the textual reply is decoded by
// parseFactJSON. An empty dialogues slice yields (nil, nil) without calling
// Process. An error is returned when Process fails, when the reply content is
// not a string, or when the reply cannot be parsed.
func (mea *MemoryExtractorAgent) ExtractFacts(dialogues []Dialogue) ([]Fact, error) {
	if len(dialogues) == 0 {
		// Nothing to analyse, so skip the round trip entirely.
		return nil, nil
	}

	var sb strings.Builder
	sb.WriteString("请分析以下对话记录,提取用户的关键信息:\n\n")
	for i, d := range dialogues {
		fmt.Fprintf(&sb, "--- 对话 %d ---\n", i+1)
		fmt.Fprintf(&sb, "用户:%s\n", d.UserQuery)
		fmt.Fprintf(&sb, "助手:%s\n\n", d.AssistantResponse)
	}

	msg := bus.Message{Type: bus.MsgTypeTaskRequest, Content: sb.String()}
	resp, err := mea.Process(context.Background(), msg)
	if err != nil {
		return nil, fmt.Errorf("extract facts failed: %w", err)
	}

	// Content is dynamically typed; a bare type assertion would panic on any
	// non-string reply, so report it as an error instead.
	text, ok := resp.Content.(string)
	if !ok {
		return nil, fmt.Errorf("extract facts failed: unexpected response content type %T", resp.Content)
	}
	return parseFactJSON(text)
}
// parseFactJSON decodes a model reply into a list of facts.
//
// The reply is first passed through extractJSONFromMarkdown to drop any code
// fence around it, then unmarshalled as an ExtractResult. The Facts field of
// that result is returned; a decoding failure is wrapped and returned with a
// nil slice.
func parseFactJSON(content string) ([]Fact, error) {
	payload := []byte(extractJSONFromMarkdown(content))

	var decoded ExtractResult
	err := json.Unmarshal(payload, &decoded)
	if err != nil {
		return nil, fmt.Errorf("parse fact json failed: %w", err)
	}
	return decoded.Facts, nil
}
// extractJSONFromMarkdown returns the payload of the first fenced code block
// in content, with surrounding whitespace trimmed.
//
// A block opened with "```json" is preferred over one opened with a bare
// "```". An opening marker only counts when a closing "```" follows it. When
// no complete block is found, the whole of content is returned, trimmed.
func extractJSONFromMarkdown(content string) string {
	const closing = "```"

	// Openers are tried in priority order: the JSON-tagged fence first.
	for _, opening := range []string{"```json", "```"} {
		_, rest, opened := strings.Cut(content, opening)
		if !opened {
			continue
		}
		if body, _, closed := strings.Cut(rest, closing); closed {
			return strings.TrimSpace(body)
		}
	}
	return strings.TrimSpace(content)
}