165 lines
4.3 KiB
Go
165 lines
4.3 KiB
Go
package actor
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
|
||
"github.com/orca/orca/pkg/bus"
|
||
"github.com/orca/orca/pkg/llm"
|
||
)
|
||
|
||
type MemoryExtractorAgent struct {
|
||
*SubAgent
|
||
config ExtractConfig
|
||
}
|
||
|
||
// ExtractConfig carries the settings handed to NewMemoryExtractorAgent.
//
// NOTE(review): none of these fields is read in the visible part of this file,
// so the per-field descriptions below are inferred from the names only and
// should be confirmed against the code that consumes them.
type ExtractConfig struct {
	// BatchSize is presumably the number of dialogues handled per extraction call.
	BatchSize int
	// MaxFacts is presumably an upper bound on the number of extracted facts.
	MaxFacts int
	// MinConfidence is presumably the lowest confidence a fact may have to be kept.
	MinConfidence float64
	// AutoTag presumably toggles automatic tagging of extracted facts.
	AutoTag bool
}
|
||
|
||
// Dialogue is a single exchange: what the user asked and what the assistant
// answered. ExtractFacts renders both fields into the text sent to the model.
type Dialogue struct {
	// UserQuery is the user's side of the exchange.
	UserQuery string
	// AssistantResponse is the assistant's reply to UserQuery.
	AssistantResponse string
}
|
||
|
||
// Fact is one extracted memory item, decoded from an element of the "facts"
// array in the model's JSON reply.
type Fact struct {
	// Content is the statement itself.
	Content string `json:"content"`
	// Type is the category; the default prompt defines "fact", "preference"
	// and "project".
	Type string `json:"type"`
	// Tags are the labels attached to the fact.
	Tags []string `json:"tags"`
	// Confidence is the score the model assigned to the fact.
	Confidence float64 `json:"confidence"`
	// Replace holds the content of an older fact that this one supersedes.
	// It is nil when the JSON value is null or the key is absent.
	Replace *string `json:"replace"`
}
|
||
|
||
type ExtractResult struct {
|
||
Facts []Fact `json:"facts"`
|
||
}
|
||
|
||
// DefaultMemoryExtractorPrompt is the built-in system prompt for the memory
// extractor agent. loadAgentPrompt returns it when the on-disk prompt file
// cannot be read.
//
// The prompt (written in Chinese) tells the model to extract three kinds of
// information from dialogues — facts, preferences and projects — and to answer
// with JSON only: an object whose "facts" array holds items with the keys
// content, type, tags, confidence and replace. It also lists the rules the
// model should follow (confidence threshold, how to fill replace, allowed
// tags). The Markdown code fences are spliced in with string concatenation
// because a raw string literal cannot contain a backtick.
//
// NOTE(review): the lone "|" and "||" lines inside the literal look like
// artifacts of how this file was copied; as written they are part of the
// string — confirm against the upstream file.
const DefaultMemoryExtractorPrompt = `# Memory Extractor Agent
|
||
|
||
你是一个专门从对话中提取用户信息的 Agent。你的工作是将非结构化的对话转化为结构化的长期记忆。
|
||
|
||
## 任务
|
||
|
||
分析给定的对话记录,提取以下类型的信息:
|
||
|
||
1. **事实 (fact)**:客观信息
|
||
- 工作:公司、职位、技术栈、行业
|
||
- 技术:擅长语言、框架偏好、架构经验
|
||
- 个人:教育背景、所在城市(仅用户明确提及)
|
||
|
||
2. **偏好 (preference)**:主观倾向
|
||
- 回答风格:简洁/详细/代码示例/架构图
|
||
- 技术偏好:语言、数据库、部署方式
|
||
- 沟通偏好:正式/ casual
|
||
|
||
3. **项目 (project)**:当前工作
|
||
- 项目名称、技术方案、当前阶段、遇到的挑战
|
||
|
||
## 输出格式
|
||
|
||
只输出 JSON,不要任何解释:
|
||
|
||
` + "```json" + `
|
||
{
|
||
"facts": [
|
||
{
|
||
"content": "用户在电商公司担任后端工程师",
|
||
"type": "fact",
|
||
"tags": ["工作", "后端", "电商"],
|
||
"confidence": 0.95,
|
||
"replace": null
|
||
},
|
||
{
|
||
"content": "用户偏好简洁的技术回答,不要过多解释",
|
||
"type": "preference",
|
||
"tags": ["沟通风格", "偏好"],
|
||
"confidence": 0.85,
|
||
"replace": "用户喜欢详细的回答"
|
||
}
|
||
]
|
||
}
|
||
` + "```" + `
|
||
|
||
## 规则
|
||
|
||
- confidence < 0.6 的事实不输出
|
||
- 如果新事实与旧事实冲突:
|
||
- 在 replace 字段填入被替换的旧事实 content
|
||
- 只替换同一 type + 同一 tags 的事实
|
||
- 不猜测、不推断,只提取用户明确表达的信息
|
||
- 标签从预设列表选择:工作、技术、偏好、项目、沟通风格、行业`
|
||
|
||
func loadAgentPrompt(agentName string) string {
|
||
path := filepath.Join(os.Getenv("HOME"), ".orca", "agents", "_builtin", agentName+".md")
|
||
data, err := os.ReadFile(path)
|
||
if err != nil {
|
||
return DefaultMemoryExtractorPrompt
|
||
}
|
||
return string(data)
|
||
}
|
||
|
||
func NewMemoryExtractorAgent(id string, llmBackend llm.LLM, cfg ExtractConfig) *MemoryExtractorAgent {
|
||
prompt := loadAgentPrompt("memory_extractor")
|
||
|
||
sa := NewSubAgent(id, llmBackend,
|
||
WithSubAgentRole("memory_extractor"),
|
||
WithSubAgentSystemPrompt(prompt),
|
||
)
|
||
|
||
return &MemoryExtractorAgent{
|
||
SubAgent: sa,
|
||
config: cfg,
|
||
}
|
||
}
|
||
|
||
func (mea *MemoryExtractorAgent) ExtractFacts(dialogues []Dialogue) ([]Fact, error) {
|
||
var sb strings.Builder
|
||
sb.WriteString("请分析以下对话记录,提取用户的关键信息:\n\n")
|
||
for i, d := range dialogues {
|
||
sb.WriteString(fmt.Sprintf("--- 对话 %d ---\n", i+1))
|
||
sb.WriteString(fmt.Sprintf("用户:%s\n", d.UserQuery))
|
||
sb.WriteString(fmt.Sprintf("助手:%s\n\n", d.AssistantResponse))
|
||
}
|
||
|
||
msg := bus.Message{Type: bus.MsgTypeTaskRequest, Content: sb.String()}
|
||
resp, err := mea.Process(context.Background(), msg)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("extract facts failed: %w", err)
|
||
}
|
||
|
||
return parseFactJSON(resp.Content.(string))
|
||
}
|
||
|
||
func parseFactJSON(content string) ([]Fact, error) {
|
||
content = extractJSONFromMarkdown(content)
|
||
|
||
var result ExtractResult
|
||
if err := json.Unmarshal([]byte(content), &result); err != nil {
|
||
return nil, fmt.Errorf("parse fact json failed: %w", err)
|
||
}
|
||
|
||
return result.Facts, nil
|
||
}
|
||
|
||
// extractJSONFromMarkdown strips a Markdown code fence from content and
// returns the enclosed text with surrounding whitespace removed.
//
// Lookup order:
//  1. the first closed block opened with ```json;
//  2. the first closed block opened with a bare ```;
//  3. if only a single fence marker is present (an opening fence that is never
//     closed, or a stray closing fence), the marker and an optional "json" tag
//     are dropped and the non-empty side is returned;
//  4. otherwise the whole content, trimmed.
func extractJSONFromMarkdown(content string) string {
	const fence = "```"

	// Closed blocks. The tagged opener is tried first so that a ```json block
	// wins over an earlier untagged block.
	for _, opener := range []string{fence + "json", fence} {
		open := strings.Index(content, opener)
		if open == -1 {
			continue
		}
		rest := content[open+len(opener):]
		if end := strings.Index(rest, fence); end != -1 {
			return strings.TrimSpace(rest[:end])
		}
	}

	// No closed block. If a lone fence marker exists, returning the content
	// unchanged would keep the marker in the text and guarantee a JSON parse
	// failure, so cut it out.
	if at := strings.Index(content, fence); at != -1 {
		after := strings.TrimPrefix(content[at+len(fence):], "json")
		if body := strings.TrimSpace(after); body != "" {
			return body
		}
		// Nothing follows the marker: it was a trailing fence.
		return strings.TrimSpace(content[:at])
	}

	return strings.TrimSpace(content)
}