orca.ai/pkg/actor/llm_agent.go
2026-05-12 00:09:01 +08:00

778 lines
22 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package actor
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"sync"
"time"
"github.com/orca/orca/pkg/bus"
"github.com/orca/orca/pkg/llm"
"github.com/orca/orca/pkg/session"
"github.com/orca/orca/pkg/skill"
"github.com/orca/orca/pkg/tool"
)
// LLMAgent implements the Agent interface by integrating an LLM backend
// with the actor system and tool framework.
//
// It receives user messages, retrieves session context, calls the LLM,
// and handles tool call responses by executing tools and feeding results
// back to the LLM for final response generation.
type LLMAgent struct {
	*BaseAgent
	llm           llm.LLM            // backing LLM used for Chat/Stream calls
	sessionMgr    *session.Manager   // optional: persists and retrieves conversation history
	sessionID     string             // session key used with sessionMgr and memoryManager
	toolManager   *tool.Manager      // optional: executes tools directly
	skillManager  *skill.Manager     // optional: matches and injects skill guides
	toolWorker    *ToolWorker        // optional: delegated tool execution; takes precedence over toolManager
	windowSize    int                // max history messages pulled into the context (default 20)
	streamWriter  io.Writer          // optional: when set, LLM output is streamed here
	systemPrompt  string             // optional: user-configured identity/system prompt
	subAgents     map[string]string  // optional: sub-agent name -> description, advertised in the tool prompt
	memoryManager *session.MemoryManager // optional: long/short-term memory persistence and retrieval
}
// LLMAgentOption is a functional option for configuring the LLMAgent.
type LLMAgentOption func(*LLMAgent)

// WithSessionManager sets the session manager for conversation history.
func WithSessionManager(mgr *session.Manager) LLMAgentOption {
	return func(a *LLMAgent) {
		a.sessionMgr = mgr
	}
}

// WithSessionID sets the session ID for conversation persistence.
func WithSessionID(id string) LLMAgentOption {
	return func(a *LLMAgent) {
		a.sessionID = id
	}
}

// WithToolManager sets the tool manager for executing tools.
func WithToolManager(mgr *tool.Manager) LLMAgentOption {
	return func(a *LLMAgent) {
		a.toolManager = mgr
	}
}

// WithToolWorker sets the tool worker for delegated tool execution.
func WithToolWorker(w *ToolWorker) LLMAgentOption {
	return func(a *LLMAgent) {
		a.toolWorker = w
	}
}

// WithWindowSize sets the context window size for session history.
func WithWindowSize(size int) LLMAgentOption {
	return func(a *LLMAgent) {
		a.windowSize = size
	}
}

// WithSkillManager sets the skill manager used to match and inject skill guides.
func WithSkillManager(mgr *skill.Manager) LLMAgentOption {
	return func(a *LLMAgent) {
		a.skillManager = mgr
	}
}

// WithStreamWriter sets the writer for streaming LLM output.
func WithStreamWriter(w io.Writer) LLMAgentOption {
	return func(a *LLMAgent) {
		a.streamWriter = w
	}
}

// WithSystemPrompt sets the user-configured system prompt (identity description).
func WithSystemPrompt(prompt string) LLMAgentOption {
	return func(a *LLMAgent) {
		a.systemPrompt = prompt
	}
}

// WithSubAgents registers sub-agents (name -> description) that are advertised
// to the LLM in the tool prompt and invoked via the agent_call tool.
func WithSubAgents(agents map[string]string) LLMAgentOption {
	return func(a *LLMAgent) {
		a.subAgents = agents
	}
}

// WithMemoryManager sets the memory manager for long/short-term memory.
func WithMemoryManager(mm *session.MemoryManager) LLMAgentOption {
	return func(a *LLMAgent) {
		a.memoryManager = mm
	}
}
// NewLLMAgent creates a new LLMAgent with the given LLM backend and options.
// The agent is started automatically upon creation; a Start failure is treated
// as a programmer error and panics.
func NewLLMAgent(id string, backend llm.LLM, opts ...LLMAgentOption) *LLMAgent {
	agent := &LLMAgent{
		BaseAgent:  NewBaseAgent(id, "llm_agent"),
		llm:        backend,
		windowSize: 20, // default context window
	}

	// Apply functional options before wiring the handler so that all
	// dependencies are in place when the first message arrives.
	for _, apply := range opts {
		apply(agent)
	}

	agent.SetHandler(agent.handleMessage)
	if err := agent.Start(); err != nil {
		panic(fmt.Sprintf("llm_agent: failed to start: %v", err))
	}
	return agent
}
// handleMessage routes incoming messages to the appropriate handler.
// Task requests go through the full LLM pipeline; system messages are
// acknowledged; anything else is rejected with an error.
func (a *LLMAgent) handleMessage(ctx context.Context, msg bus.Message) (bus.Message, error) {
	switch msg.Type {
	case bus.MsgTypeTaskRequest:
		return a.handleUserMessage(ctx, msg)
	case bus.MsgTypeSystem:
		return a.handleSystem(ctx, msg)
	}
	return bus.Message{}, fmt.Errorf("llm_agent %s: unsupported message type %s", a.ID(), msg.Type)
}
// handleUserMessage processes a user message through the LLM.
//
// Flow:
//  1. Persist the user message to session history
//  2. Retrieve recent conversation context
//  3. Convert to LLM message format
//  4. Call LLM.Chat
//  5. If response has tool calls:
//     a. Execute each tool (directly or via ToolWorker)
//     b. Add tool results to conversation
//     c. Call LLM.Chat again with results
//  6. Persist the assistant response
//  7. Return the final response
func (a *LLMAgent) handleUserMessage(ctx context.Context, msg bus.Message) (bus.Message, error) {
	content, ok := msg.Content.(string)
	if !ok {
		return bus.Message{}, fmt.Errorf("llm_agent: expected string content, got %T", msg.Content)
	}
	// Handle special debug commands: dump context details to stdout and
	// short-circuit without calling the LLM.
	if content == "/context" || content == "/debug" {
		a.LogContextDetails(content)
		return bus.Message{
			ID:      msg.ID + "-response",
			Type:    bus.MsgTypeTaskResponse,
			From:    a.ID(),
			To:      msg.From,
			Content: "[上下文详情已输出到日志]",
		}, nil
	}
	// Ensure session exists
	if a.sessionMgr != nil && a.sessionID != "" {
		// Check if session exists; create if not
		if _, err := a.sessionMgr.GetSession(a.sessionID); err != nil {
			// NOTE(review): CreateSession error is ignored — persistence is
			// best-effort here; confirm this is intentional.
			a.sessionMgr.CreateSession(a.sessionID, map[string]string{
				"source": "llm_agent",
			})
		}
		// Persist user message (error intentionally ignored, best-effort)
		a.sessionMgr.AddMessage(a.sessionID, session.RoleUser, content, nil)
	}
	// Mirror the user message into the memory store when configured.
	if a.memoryManager != nil && a.sessionID != "" {
		a.memoryManager.SaveMessage(a.sessionID, session.SessionMessage{
			Role:      session.RoleUser,
			Content:   content,
			Timestamp: time.Now(),
		})
	}
	llmMessages := a.buildLLMMessages(content)
	// Inject the full body of any skill matched by the query as an extra
	// system message, so the LLM follows the skill's guide this turn.
	if a.skillManager != nil {
		matchedSkills := a.skillManager.FindSkill(content)
		for _, s := range matchedSkills {
			if s.Body != "" {
				llmMessages = append(llmMessages, llm.Message{
					Role:    "system",
					Content: fmt.Sprintf("以下是你需要遵循的 %s 技能指南:\n\n%s", s.Name, s.Body),
				})
			}
		}
	}
	// Call LLM (potentially multiple rounds for tool calls)
	finalResponse, err := a.chatWithToolLoop(ctx, llmMessages)
	if err != nil {
		return bus.Message{}, fmt.Errorf("llm_agent: LLM chat failed: %w", err)
	}
	// Persist assistant response (best-effort, error ignored)
	if a.sessionMgr != nil && a.sessionID != "" {
		a.sessionMgr.AddMessage(a.sessionID, session.RoleAssistant, finalResponse, nil)
	}
	if a.memoryManager != nil && a.sessionID != "" {
		a.memoryManager.SaveMessage(a.sessionID, session.SessionMessage{
			Role:      session.RoleAssistant,
			Content:   finalResponse,
			Timestamp: time.Now(),
		})
		// NOTE(review): fire-and-forget goroutine with no ctx/WaitGroup —
		// it may outlive the agent; confirm MaintainSessionMemory is safe
		// to run detached.
		go a.memoryManager.MaintainSessionMemory(a.sessionID, content, finalResponse)
	}
	return bus.Message{
		ID:      msg.ID + "-response",
		Type:    bus.MsgTypeTaskResponse,
		From:    a.ID(),
		To:      msg.From,
		Content: finalResponse,
	}, nil
}
// contextStats accumulates rough token/size accounting for one assembled
// LLM context, used for diagnostics (all token values are estimates from
// estimateTokens, not exact tokenizer counts).
type contextStats struct {
	systemPromptTokens int // estimated tokens of the user-configured system prompt
	toolPromptTokens   int // estimated tokens of the generated tool prompt
	memoryTokens       int // token total reported by the memory manager
	historyTokens      int // estimated tokens across injected history messages
	memoryShortTerm    int // short-term memory entries injected
	memoryLongTerm     int // long-term memory entries injected
	historyCount       int // number of history messages injected
}
// buildLLMMessages assembles the full LLM context for query, discarding the
// per-section statistics (see buildLLMMessagesWithStats).
func (a *LLMAgent) buildLLMMessages(query string) []llm.Message {
	return a.buildLLMMessagesWithStats(query, nil)
}
// buildLLMMessagesWithStats assembles the LLM context for query in a fixed
// order — system prompt, tool prompt, memory context, then session history —
// and records size accounting into stats. A nil stats is replaced with a
// throwaway struct so the accounting code needs no nil checks.
func (a *LLMAgent) buildLLMMessagesWithStats(query string, stats *contextStats) []llm.Message {
	messages := make([]llm.Message, 0)
	if stats == nil {
		stats = &contextStats{}
	}
	// 1. User-configured system prompt (static identity description).
	if a.systemPrompt != "" {
		messages = append(messages, llm.Message{
			Role:    "system",
			Content: a.systemPrompt,
		})
		stats.systemPromptTokens = estimateTokens(a.systemPrompt)
	}
	// 2. Runtime tool instructions (dynamically generated).
	toolPrompt := a.buildToolPrompt()
	if toolPrompt != "" {
		messages = append(messages, llm.Message{
			Role:    "system",
			Content: toolPrompt,
		})
		stats.toolPromptTokens = estimateTokens(toolPrompt)
	}
	// 3. Memory context, only when the manager decides it is relevant to query.
	if a.memoryManager != nil && a.sessionID != "" {
		if a.memoryManager.ShouldInjectMemory(a.sessionID, query) {
			memoryCtx, memStats := a.memoryManager.BuildMemoryContextWithStats(a.sessionID, query)
			if memoryCtx != "" {
				messages = append(messages, llm.Message{
					Role:    "system",
					Content: memoryCtx,
				})
				stats.memoryTokens = memStats.TotalTokens
				stats.memoryShortTerm = memStats.ShortTermCount
				stats.memoryLongTerm = memStats.LongTermCount
			}
		}
	}
	// 4. Recent session history, capped at windowSize messages.
	// A retrieval error silently skips history rather than failing the turn.
	if a.sessionMgr != nil && a.sessionID != "" {
		sessionMsgs, err := a.sessionMgr.GetContext(a.sessionID, a.windowSize)
		if err == nil {
			stats.historyCount = len(sessionMsgs)
			for _, sm := range sessionMsgs {
				msg := llm.Message{
					Role:    string(sm.Role),
					Content: sm.Content,
				}
				// Tool messages carry their originating call ID in metadata.
				if sm.Role == session.RoleTool && sm.Metadata != nil {
					msg.ToolCallID = sm.Metadata["tool_call_id"]
				}
				messages = append(messages, msg)
				stats.historyTokens += estimateTokens(sm.Content)
			}
		}
	}
	return messages
}
// LogContextDetails dumps the agent's current memory, embedding-cache stats,
// and recent session history to stdout. It is triggered by the /context and
// /debug commands and is purely diagnostic.
func (a *LLMAgent) LogContextDetails(query string) {
	fmt.Println("\n========== 上下文详情 ==========")
	if a.memoryManager != nil && a.sessionID != "" {
		fmt.Println("\n[记忆内容]")
		shortTerm, _ := a.memoryManager.GetShortTermMemory(a.sessionID, query)
		if len(shortTerm) > 0 {
			fmt.Println(" 短期记忆:")
			for i, m := range shortTerm {
				fmt.Printf(" [%d] %s\n", i+1, truncateForDisplay(m, 80))
			}
		}
		longTerm, _ := a.memoryManager.GetLongTermMemory(query)
		if len(longTerm) > 0 {
			fmt.Println(" 长期记忆:")
			for i, m := range longTerm {
				fmt.Printf(" [%d] %s\n", i+1, truncateForDisplay(m.Content, 80))
			}
		}
		if len(shortTerm) == 0 && len(longTerm) == 0 {
			fmt.Println(" (无记忆)")
		}
		cacheSize, cacheHits, cacheMisses := a.memoryManager.CacheStats()
		// Compute the hit rate with an explicit zero guard. The previous
		// formula divided by (hits+misses+1), which understated the rate
		// for every non-empty sample just to dodge division by zero.
		var hitRate float64
		if total := cacheHits + cacheMisses; total > 0 {
			hitRate = float64(cacheHits) * 100 / float64(total)
		}
		fmt.Printf("\n[Embedding缓存] 大小=%d, 命中=%d, 未命中=%d, 命中率=%.1f%%\n",
			cacheSize, cacheHits, cacheMisses, hitRate)
	}
	if a.sessionMgr != nil && a.sessionID != "" {
		fmt.Println("\n[历史对话]")
		sessionMsgs, err := a.sessionMgr.GetContext(a.sessionID, a.windowSize)
		if err == nil && len(sessionMsgs) > 0 {
			// Show at most the 10 most recent messages.
			start := 0
			if len(sessionMsgs) > 10 {
				start = len(sessionMsgs) - 10
				fmt.Printf(" (显示最近 10/%d 条)\n", len(sessionMsgs))
			}
			for i := start; i < len(sessionMsgs); i++ {
				sm := sessionMsgs[i]
				role := string(sm.Role)
				if role == "" {
					role = "unknown"
				}
				fmt.Printf(" [%s] %s\n", role, truncateForDisplay(sm.Content, 80))
			}
		} else {
			fmt.Println(" (无历史)")
		}
	}
	fmt.Println("================================\n")
}
// truncateForDisplay shortens s to at most maxLen characters (runes) for
// log output, appending "..." when truncated. Truncation is rune-based:
// the previous byte-based slice could split a multi-byte UTF-8 sequence
// (this file logs a lot of Chinese text) and emit invalid UTF-8.
func truncateForDisplay(s string, maxLen int) string {
	runes := []rune(s)
	if len(runes) <= maxLen {
		return s
	}
	return string(runes[:maxLen]) + "..."
}
// estimateTokens returns a rough token estimate for text using the
// heuristic of one token per four characters (runes).
func estimateTokens(text string) int {
	count := 0
	for range text { // ranging a string iterates runes, not bytes
		count++
	}
	return count / 4
}
// buildToolPrompt generates the tool-instruction prompt (identity description
// excluded). It advertises available tools, sub-agents, and skills to the LLM
// and spells out the JSON calling convention, enabling prompt-based tool
// calling. Returns "" when nothing is configured.
func (a *LLMAgent) buildToolPrompt() string {
	var b strings.Builder
	// Tool catalog: name, description, and JSON-encoded parameter schema
	// for every registered tool, followed by the calling rules.
	if a.toolManager != nil {
		tools := a.toolManager.List()
		b.WriteString("你可以使用以下工具来完成用户的请求。\n\n")
		b.WriteString("可用工具列表:\n")
		for _, t := range tools {
			b.WriteString(fmt.Sprintf("\n工具名: %s\n", t.Name()))
			b.WriteString(fmt.Sprintf("描述: %s\n", t.Description()))
			// Marshal error ignored: parameter schemas come from in-process
			// tool definitions and are expected to be JSON-encodable.
			paramsJSON, _ := json.Marshal(t.Parameters())
			b.WriteString(fmt.Sprintf("参数: %s\n", string(paramsJSON)))
		}
		b.WriteString("\n规则\n")
		b.WriteString("1. 当你需要调用工具时,请在回复中**只输出**以下 JSON 格式(不要添加其他文字):\n")
		b.WriteString(` {"tool": "工具名", "arguments": {"参数名": "参数值"}}` + "\n")
		b.WriteString("2. 如果需要同时调用多个工具(并行执行),请输出 JSON 数组格式:\n")
		b.WriteString(` [{"tool": "工具名1", "arguments": {...}}, {"tool": "工具名2", "arguments": {...}}]` + "\n")
		b.WriteString("3. 如果你已经看到了工具返回的结果,请直接根据结果回答用户,不要再次调用工具。\n")
		b.WriteString("4. 当你不需要调用工具时,请直接回复用户。\n")
		b.WriteString("5. 当用户的请求涉及代码、架构、数学计算等专业领域时你必须调用相应的子Agent不要自己直接回答。\n")
	}
	// Sub-agent catalog plus the agent_call delegation protocol.
	if len(a.subAgents) > 0 {
		b.WriteString("\n\n你可以调用以下专业Agent来协助完成特定任务\n")
		for name, description := range a.subAgents {
			b.WriteString(fmt.Sprintf("- %s: %s\n", name, description))
		}
		b.WriteString("\n调用方式使用 agent_call 工具,指定 agent 名称和任务描述。\n")
		b.WriteString("示例:{\"tool\": \"agent_call\", \"arguments\": {\"agent\": \"coder\", \"task\": \"写个快速排序\"}}\n")
		b.WriteString("如果用户有多个独立任务,请同时调用多个 agent_callJSON数组格式让它们并行执行。\n")
		b.WriteString("\n【强制规则】当用户的请求涉及代码编程、系统架构、数学计算、代码审查等专业领域时你必须调用相应的子Agent。\n")
		b.WriteString("你绝对不能自己直接回答编程或架构问题,必须通过 agent_call 工具委托给专业Agent处理。\n")
	}
	// Skill catalog: full bodies are inlined (capped at 4000 bytes each) so
	// the LLM can apply skill guides without a separate lookup.
	if a.skillManager != nil {
		skills := a.skillManager.ListSkills()
		if len(skills) > 0 {
			b.WriteString("\n\n你还可以使用以下技能来更好地帮助用户\n")
			for _, s := range skills {
				b.WriteString(fmt.Sprintf("\n=== 技能: %s ===\n", s.Name))
				b.WriteString(fmt.Sprintf("描述: %s\n", s.Description))
				if len(s.Triggers) > 0 {
					b.WriteString(fmt.Sprintf("触发词: %s\n", strings.Join(s.Triggers, ", ")))
				}
				if s.Body != "" {
					body := s.Body
					// NOTE(review): byte-based cap may split a multi-byte
					// rune at the 4000 boundary — confirm acceptable for
					// prompt text.
					if len(body) > 4000 {
						body = body[:4000] + "\n...[内容已截断]"
					}
					b.WriteString(fmt.Sprintf("\n详细指南:\n%s\n", body))
				}
				b.WriteString(fmt.Sprintf("=== 结束: %s ===\n", s.Name))
			}
			b.WriteString("\n当用户的请求匹配某个技能的触发词时请根据该技能的详细指南提供更专业的帮助。\n")
			b.WriteString("你应该主动使用相关技能的专业知识来回答用户问题,而不需要询问用户是否使用技能。\n")
		}
	}
	if b.Len() == 0 {
		return ""
	}
	return b.String()
}
// buildToolDefs converts every registered tool into the llm.ToolDef
// function-calling schema. Returns nil when no tool manager is configured.
func (a *LLMAgent) buildToolDefs() []llm.ToolDef {
	if a.toolManager == nil {
		return nil
	}
	var defs []llm.ToolDef
	for _, t := range a.toolManager.List() {
		props := make(map[string]llm.ToolProperty)
		required := []string{} // non-nil so it JSON-encodes as [] rather than null
		for paramName, param := range t.Parameters() {
			entry := llm.ToolProperty{
				Type:        param.Type,
				Description: param.Description,
			}
			if len(param.Enum) > 0 {
				entry.Enum = param.Enum
			}
			props[paramName] = entry
			if param.Required {
				required = append(required, paramName)
			}
		}
		defs = append(defs, llm.ToolDef{
			Type: "function",
			Function: llm.ToolFunction{
				Name:        t.Name(),
				Description: t.Description(),
				Parameters: llm.ToolFunctionParameters{
					Type:       "object",
					Required:   required,
					Properties: props,
				},
			},
		})
	}
	return defs
}
// chatWithToolLoop drives the chat/tool-execution cycle: it calls the LLM
// (streaming when a streamWriter is configured), parses any prompt-based
// tool calls out of the reply, executes them in parallel, feeds the results
// back as user messages, and repeats. A reply with no tool calls is the
// final answer. The loop is capped at 10 rounds to prevent the LLM from
// calling tools forever.
func (a *LLMAgent) chatWithToolLoop(ctx context.Context, messages []llm.Message) (string, error) {
	maxRounds := 10
	for round := 0; round < maxRounds; round++ {
		var content string
		var err error
		if a.streamWriter != nil {
			content, err = a.streamChat(ctx, messages)
		} else {
			content, err = a.syncChat(ctx, messages)
		}
		if err != nil {
			return "", fmt.Errorf("chat round %d failed: %w", round, err)
		}
		// No tool calls in the reply -> this is the final answer.
		toolCalls := a.parseToolCallsFromContent(content)
		if len(toolCalls) == 0 {
			return content, nil
		}
		// Record the assistant's tool-call request, then append each tool
		// result as a user message (prompt-based protocol, not the native
		// tool role) so the next round can use them.
		messages = append(messages, llm.Message{
			Role:    "assistant",
			Content: content,
		})
		results := a.executeToolCallsParallel(ctx, toolCalls)
		for _, result := range results {
			messages = append(messages, llm.Message{
				Role:    "user",
				Content: result,
			})
		}
	}
	return "", fmt.Errorf("llm_agent: exceeded maximum tool call rounds (%d)", maxRounds)
}
// syncChat performs a single blocking chat call and returns the response
// content. (The previous version special-cased response.ToolCalls but both
// branches returned the identical value, so the dead branch is removed —
// tool calls are parsed out of the content by the caller.)
func (a *LLMAgent) syncChat(ctx context.Context, messages []llm.Message) (string, error) {
	response, err := a.llm.Chat(ctx, messages)
	if err != nil {
		return "", err
	}
	return response.Content, nil
}
// streamChat performs a streaming chat call, echoing each chunk to the
// configured streamWriter while accumulating the full response, which is
// returned once the stream completes.
func (a *LLMAgent) streamChat(ctx context.Context, messages []llm.Message) (string, error) {
	var accumulated strings.Builder

	// Leading newline separates streamed output from any previous text.
	if a.streamWriter != nil {
		fmt.Fprint(a.streamWriter, "\n")
	}

	onChunk := func(chunk string) error {
		accumulated.WriteString(chunk)
		if a.streamWriter != nil {
			fmt.Fprint(a.streamWriter, chunk)
		}
		return nil
	}

	if err := a.llm.Stream(ctx, messages, onChunk); err != nil {
		return "", err
	}

	// Trailing newline closes out the streamed block.
	if a.streamWriter != nil {
		fmt.Fprintln(a.streamWriter)
	}
	return accumulated.String(), nil
}
func (a *LLMAgent) parseToolCallsFromContent(content string) []llm.ToolCall {
cleanContent := a.extractJSONFromMarkdown(content)
var toolCalls []llm.ToolCall
var callIndex int
var parsedList []struct {
Tool string `json:"tool"`
Arguments map[string]interface{} `json:"arguments"`
}
if err := json.Unmarshal([]byte(cleanContent), &parsedList); err == nil && len(parsedList) > 0 {
for _, parsed := range parsedList {
if parsed.Tool == "" {
continue
}
argsJSON, _ := json.Marshal(parsed.Arguments)
toolCalls = append(toolCalls, llm.ToolCall{
ID: fmt.Sprintf("call_%d", callIndex),
Type: "function",
Function: llm.FunctionCall{
Name: parsed.Tool,
Arguments: string(argsJSON),
},
})
callIndex++
}
return toolCalls
}
var parsed struct {
Tool string `json:"tool"`
Arguments map[string]interface{} `json:"arguments"`
}
if err := json.Unmarshal([]byte(cleanContent), &parsed); err == nil && parsed.Tool != "" {
argsJSON, _ := json.Marshal(parsed.Arguments)
return []llm.ToolCall{{
ID: "call_0",
Type: "function",
Function: llm.FunctionCall{
Name: parsed.Tool,
Arguments: string(argsJSON),
},
}}
}
var commaSeparated []struct {
Tool string `json:"tool"`
Arguments map[string]interface{} `json:"arguments"`
}
wrapped := "[" + cleanContent + "]"
if err := json.Unmarshal([]byte(wrapped), &commaSeparated); err == nil && len(commaSeparated) > 0 {
for _, parsed := range commaSeparated {
if parsed.Tool == "" {
continue
}
argsJSON, _ := json.Marshal(parsed.Arguments)
toolCalls = append(toolCalls, llm.ToolCall{
ID: fmt.Sprintf("call_%d", callIndex),
Type: "function",
Function: llm.FunctionCall{
Name: parsed.Tool,
Arguments: string(argsJSON),
},
})
callIndex++
}
return toolCalls
}
return nil
}
// extractJSONFromMarkdown strips a surrounding markdown code fence
// (```json ... ``` or ``` ... ```) from content and returns the trimmed
// interior. When no fence is found, content is returned unchanged.
// The fence strings are built by concatenation ("`"+"``") so the literals
// do not themselves look like fences in markdown-rendered contexts.
func (a *LLMAgent) extractJSONFromMarkdown(content string) string {
	// Prefer an explicit ```json fence; fall back to a plain ``` fence.
	start := strings.Index(content, "`"+"``json")
	if start == -1 {
		start = strings.Index(content, "`"+"``")
		if start == -1 {
			return content
		}
	} else {
		// Skip past "```json" (7 bytes) to the text that follows it.
		start += 7
	}
	// At this point start is either just past "```json" or at a plain "```".
	// Either way, advance past the rest of the fence line to the first line
	// of the fenced body.
	if !strings.HasPrefix(content[start:], "`"+"``") {
		newline := strings.Index(content[start:], "\n")
		if newline != -1 {
			start = start + newline + 1
		}
	} else {
		start += 3
		newline := strings.Index(content[start:], "\n")
		if newline != -1 {
			start = start + newline + 1
		}
	}
	// Find the closing fence, preferring one at the start of a line.
	end := strings.Index(content[start:], "\n`"+"``")
	if end == -1 {
		end = strings.Index(content[start:], "`"+"``")
	}
	if end == -1 {
		// No closing fence: return everything after the opening fence.
		return strings.TrimSpace(content[start:])
	}
	return strings.TrimSpace(content[start : start+end])
}
// executeToolCall runs a single tool call and returns its result as a JSON
// string. A configured ToolWorker takes precedence over direct execution via
// the tool manager. Failures are reported in-band as {"error": ...} JSON
// rather than as Go errors, so the result can always be fed back to the LLM.
func (a *LLMAgent) executeToolCall(ctx context.Context, tc llm.ToolCall) string {
	toolName := tc.Function.Name
	if a.streamWriter != nil {
		fmt.Fprintf(a.streamWriter, "\n[正在执行工具: %s...]\n", toolName)
	}
	// Decode the argument JSON; on failure, pass the raw string through
	// under "_raw" instead of dropping the call.
	var args map[string]interface{}
	if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
		args = map[string]interface{}{
			"_raw": tc.Function.Arguments,
		}
	}
	// Path 1: delegate to the ToolWorker over the message bus.
	if a.toolWorker != nil {
		toolCallMsg := bus.Message{
			ID:      tc.ID,
			Type:    bus.MsgTypeToolCall,
			From:    a.ID(),
			To:      a.toolWorker.ID(),
			Content: map[string]interface{}{"name": toolName, "arguments": args},
		}
		resultMsg, err := a.toolWorker.Process(ctx, toolCallMsg)
		if a.streamWriter != nil {
			fmt.Fprintf(a.streamWriter, "[工具 %s 执行完成]\n", toolName)
		}
		if err != nil {
			return fmt.Sprintf(`{"error": "tool execution failed: %v"}`, err)
		}
		resultJSON, err := json.Marshal(resultMsg.Content)
		if err != nil {
			return fmt.Sprintf(`{"error": "failed to marshal result: %v"}`, err)
		}
		return string(resultJSON)
	}
	// Path 2: execute directly through the tool manager.
	if a.toolManager != nil {
		result, err := a.toolManager.Execute(toolName, ctx, args)
		if a.streamWriter != nil {
			fmt.Fprintf(a.streamWriter, "[工具 %s 执行完成]\n", toolName)
		}
		if err != nil {
			return fmt.Sprintf(`{"error": "tool execution failed: %v"}`, err)
		}
		resultJSON, err := json.Marshal(result)
		if err != nil {
			return fmt.Sprintf(`{"error": "failed to marshal result: %v"}`, err)
		}
		return string(resultJSON)
	}
	// Neither execution path is configured.
	return fmt.Sprintf(`{"error": "no tool worker or tool manager available for %q"}`, toolName)
}
// executeToolCallsParallel executes the given tool calls concurrently (one
// goroutine each) and returns their formatted results in the same order as
// toolCalls. A single call is executed inline without goroutine overhead.
//
// Fixes: the previous version declared a local variable named `strings`,
// shadowing the strings package import, and routed results through a
// redundant intermediate struct; each goroutine now writes directly into
// its own slice slot (safe: slots are disjoint and wg.Wait orders the reads).
func (a *LLMAgent) executeToolCallsParallel(ctx context.Context, toolCalls []llm.ToolCall) []string {
	if len(toolCalls) == 1 {
		result := a.executeToolCall(ctx, toolCalls[0])
		return []string{fmt.Sprintf("工具 %s 的执行结果:%s", toolCalls[0].Function.Name, result)}
	}

	outputs := make([]string, len(toolCalls))
	var wg sync.WaitGroup
	for i, tc := range toolCalls {
		wg.Add(1)
		go func(idx int, toolCall llm.ToolCall) {
			defer wg.Done()
			res := a.executeToolCall(ctx, toolCall)
			outputs[idx] = fmt.Sprintf("工具 %s 的执行结果:%s", toolCall.Function.Name, res)
		}(i, tc)
	}
	wg.Wait()
	return outputs
}
// handleSystem processes internal system messages by returning a simple
// acknowledgement addressed back to the sender.
func (a *LLMAgent) handleSystem(ctx context.Context, msg bus.Message) (bus.Message, error) {
	ack := bus.Message{
		ID:      msg.ID + "-ack",
		Type:    bus.MsgTypeSystem,
		From:    a.ID(),
		To:      msg.From,
		Content: "llm_agent acknowledged",
	}
	return ack, nil
}
// SetStreamWriter replaces the writer used for streaming LLM output.
// NOTE(review): the write is not synchronized; calling this while the agent
// is processing a message races with readers of streamWriter — confirm it is
// only called during setup.
func (a *LLMAgent) SetStreamWriter(w io.Writer) {
	a.streamWriter = w
}
// Compile-time checks that both actor types satisfy the Agent interface.
var _ Agent = (*LLMAgent)(nil)
var _ Agent = (*ToolWorker)(nil)