orca.ai/pkg/actor/llm_agent.go
大森 04c7ea5e39 feat: add streaming output for CLI
- Add streamWriter to LLMAgent for real-time output
- Support streaming mode in chatWithToolLoop
- Add SetStreamWriter to Kernel and LLMAgent
- CLI displays streaming responses immediately
- Tool calling still works with streaming
2026-05-08 01:19:07 +08:00

390 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package actor
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"github.com/orca/orca/pkg/bus"
"github.com/orca/orca/pkg/llm"
"github.com/orca/orca/pkg/session"
"github.com/orca/orca/pkg/tool"
)
// LLMAgent implements the Agent interface by integrating an LLM backend
// with the actor system and tool framework.
//
// It receives user messages, retrieves session context, calls the LLM,
// and handles tool call responses by executing tools and feeding results
// back to the LLM for final response generation.
type LLMAgent struct {
*BaseAgent // embedded actor identity/lifecycle (ID, Start, SetHandler)
llm llm.LLM // backend used for Chat (sync) and Stream calls
sessionMgr *session.Manager // optional: persists and retrieves conversation history
sessionID string // session key; persistence is skipped when empty
toolManager *tool.Manager // optional: tool registry for prompt-based tool calling
toolWorker *ToolWorker // optional: preferred delegate for tool execution
windowSize int // number of history messages pulled into context (defaults to 20)
streamWriter io.Writer // when non-nil, LLM output is streamed here in real time
}
// LLMAgentOption is a functional option for configuring the LLMAgent.
// Options are applied in order by NewLLMAgent before the agent starts.
type LLMAgentOption func(*LLMAgent)
// WithSessionManager sets the session manager used to persist and retrieve
// conversation history.
func WithSessionManager(mgr *session.Manager) LLMAgentOption {
	return func(agent *LLMAgent) { agent.sessionMgr = mgr }
}
// WithSessionID sets the session ID under which the conversation is persisted.
func WithSessionID(id string) LLMAgentOption {
	return func(agent *LLMAgent) { agent.sessionID = id }
}
// WithToolManager sets the tool manager used to describe and execute tools.
func WithToolManager(mgr *tool.Manager) LLMAgentOption {
	return func(agent *LLMAgent) { agent.toolManager = mgr }
}
// WithToolWorker sets the tool worker that tool calls are delegated to.
// When set, it takes precedence over direct execution via the tool manager.
func WithToolWorker(w *ToolWorker) LLMAgentOption {
	return func(agent *LLMAgent) { agent.toolWorker = w }
}
// WithWindowSize sets how many recent session messages are included in the
// LLM context window.
func WithWindowSize(size int) LLMAgentOption {
	return func(agent *LLMAgent) { agent.windowSize = size }
}
// WithStreamWriter sets the writer that streaming LLM output is mirrored to.
// When nil (the default), the agent uses blocking (non-streaming) chat calls.
func WithStreamWriter(w io.Writer) LLMAgentOption {
	return func(agent *LLMAgent) { agent.streamWriter = w }
}
// NewLLMAgent creates a new LLMAgent with the given LLM backend and options.
// The agent is started automatically upon creation; a Start failure panics,
// in the spirit of Must-style constructors.
func NewLLMAgent(id string, backend llm.LLM, opts ...LLMAgentOption) *LLMAgent {
	agent := &LLMAgent{
		BaseAgent:  NewBaseAgent(id, "llm_agent"),
		llm:        backend,
		windowSize: 20, // default context window; override with WithWindowSize
	}
	for _, apply := range opts {
		apply(agent)
	}
	agent.SetHandler(agent.handleMessage)
	if err := agent.Start(); err != nil {
		panic(fmt.Sprintf("llm_agent: failed to start: %v", err))
	}
	return agent
}
// handleMessage routes incoming messages to the appropriate handler based on
// message type; unrecognized types are rejected with an error.
func (a *LLMAgent) handleMessage(ctx context.Context, msg bus.Message) (bus.Message, error) {
	switch msg.Type {
	case bus.MsgTypeTaskRequest:
		return a.handleUserMessage(ctx, msg)
	case bus.MsgTypeSystem:
		return a.handleSystem(ctx, msg)
	}
	return bus.Message{}, fmt.Errorf("llm_agent %s: unsupported message type %s", a.ID(), msg.Type)
}
// handleUserMessage processes a user message through the LLM.
//
// The user text is persisted to the session (when a session manager and
// session ID are configured), the recent conversation window is converted to
// LLM messages, and the chat/tool loop produces the final assistant reply,
// which is persisted and returned to the original sender.
func (a *LLMAgent) handleUserMessage(ctx context.Context, msg bus.Message) (bus.Message, error) {
	content, ok := msg.Content.(string)
	if !ok {
		return bus.Message{}, fmt.Errorf("llm_agent: expected string content, got %T", msg.Content)
	}

	persist := a.sessionMgr != nil && a.sessionID != ""
	if persist {
		// Lazily create the session on first use. Persistence is
		// best-effort by design: errors are deliberately not propagated.
		if _, err := a.sessionMgr.GetSession(a.sessionID); err != nil {
			a.sessionMgr.CreateSession(a.sessionID, map[string]string{
				"source": "llm_agent",
			})
		}
		a.sessionMgr.AddMessage(a.sessionID, session.RoleUser, content, nil)
	}

	// Run the LLM, executing any prompt-based tool calls it emits.
	reply, err := a.chatWithToolLoop(ctx, a.buildLLMMessages())
	if err != nil {
		return bus.Message{}, fmt.Errorf("llm_agent: LLM chat failed: %w", err)
	}

	if persist {
		a.sessionMgr.AddMessage(a.sessionID, session.RoleAssistant, reply, nil)
	}

	return bus.Message{
		ID:      msg.ID + "-response",
		Type:    bus.MsgTypeTaskResponse,
		From:    a.ID(),
		To:      msg.From,
		Content: reply,
	}, nil
}
// buildLLMMessages assembles the LLM conversation: an optional tool-usage
// system prompt followed by the recent session history window.
func (a *LLMAgent) buildLLMMessages() []llm.Message {
	out := make([]llm.Message, 0)
	if a.toolManager != nil {
		out = append(out, llm.Message{
			Role:    "system",
			Content: a.buildToolSystemPrompt(),
		})
	}
	if a.sessionMgr == nil || a.sessionID == "" {
		return out
	}
	history, err := a.sessionMgr.GetContext(a.sessionID, a.windowSize)
	if err != nil {
		// History retrieval is best-effort; fall back to the system
		// prompt alone rather than failing the whole turn.
		return out
	}
	for _, h := range history {
		m := llm.Message{
			Role:    string(h.Role),
			Content: h.Content,
		}
		// Tool results carry their originating call ID in metadata.
		if h.Role == session.RoleTool && h.Metadata != nil {
			m.ToolCallID = h.Metadata["tool_call_id"]
		}
		out = append(out, m)
	}
	return out
}
// buildToolSystemPrompt creates a system prompt describing all available tools.
// This enables prompt-based tool calling for models without native function
// calling support. Returns "" when no tool manager is configured.
func (a *LLMAgent) buildToolSystemPrompt() string {
	if a.toolManager == nil {
		return ""
	}
	var sb strings.Builder
	sb.WriteString("你是一个 AI 助手,可以使用以下工具来完成用户的请求。\n\n")
	sb.WriteString("可用工具列表:\n")
	for _, t := range a.toolManager.List() {
		params, _ := json.Marshal(t.Parameters())
		fmt.Fprintf(&sb, "\n工具名: %s\n", t.Name())
		fmt.Fprintf(&sb, "描述: %s\n", t.Description())
		fmt.Fprintf(&sb, "参数: %s\n", string(params))
	}
	sb.WriteString("\n规则\n")
	sb.WriteString("1. 当你需要调用工具时,请在回复中**只输出**以下 JSON 格式(不要添加其他文字):\n")
	sb.WriteString(` {"tool": "工具名", "arguments": {"参数名": "参数值"}}` + "\n")
	sb.WriteString("2. 如果你已经看到了工具返回的结果,请直接根据结果回答用户,不要再次调用工具。\n")
	sb.WriteString("3. 如果你不需要调用工具,请直接回复用户。\n")
	return sb.String()
}
// chatWithToolLoop drives the LLM conversation, executing prompt-based tool
// calls until the model produces a plain answer or maxRounds is exhausted.
// Streaming is used whenever a streamWriter is configured.
func (a *LLMAgent) chatWithToolLoop(ctx context.Context, messages []llm.Message) (string, error) {
	const maxRounds = 10
	for round := 0; round < maxRounds; round++ {
		var (
			content string
			err     error
		)
		if a.streamWriter == nil {
			content, err = a.syncChat(ctx, messages)
		} else {
			content, err = a.streamChat(ctx, messages)
		}
		if err != nil {
			return "", fmt.Errorf("chat round %d failed: %w", round, err)
		}

		calls := a.parseToolCallsFromContent(content)
		if len(calls) == 0 {
			// Plain answer — the loop is done.
			return content, nil
		}

		// Record the assistant's tool request, then feed each result back
		// as a user turn so the model can compose its final answer.
		messages = append(messages, llm.Message{
			Role:    "assistant",
			Content: content,
		})
		for _, call := range calls {
			result := a.executeToolCall(ctx, call)
			messages = append(messages, llm.Message{
				Role:    "user",
				Content: fmt.Sprintf("工具 %s 的执行结果:%s", call.Function.Name, result),
			})
		}
	}
	return "", fmt.Errorf("llm_agent: exceeded maximum tool call rounds (%d)", maxRounds)
}
// syncChat performs a single blocking chat call and returns the model's text.
//
// Note: response.ToolCalls is intentionally not inspected here — this agent
// uses prompt-based tool calling, so tool invocations are parsed out of the
// returned content by parseToolCallsFromContent. (The previous version had
// an explicit ToolCalls branch whose body was identical to the fall-through;
// that dead branch has been removed.)
func (a *LLMAgent) syncChat(ctx context.Context, messages []llm.Message) (string, error) {
	response, err := a.llm.Chat(ctx, messages)
	if err != nil {
		return "", err
	}
	return response.Content, nil
}
// streamChat performs a streaming chat call, mirroring each chunk to the
// configured streamWriter as it arrives while accumulating the full reply.
// A leading and trailing newline frame the streamed output for readability.
func (a *LLMAgent) streamChat(ctx context.Context, messages []llm.Message) (string, error) {
	var reply strings.Builder

	// emit mirrors text to the stream writer, tolerating a nil writer.
	emit := func(s string) {
		if a.streamWriter != nil {
			fmt.Fprint(a.streamWriter, s)
		}
	}

	emit("\n")
	err := a.llm.Stream(ctx, messages, func(chunk string) error {
		reply.WriteString(chunk)
		emit(chunk)
		return nil
	})
	if err != nil {
		return "", err
	}
	emit("\n")
	return reply.String(), nil
}
// parseToolCallsFromContent attempts to interpret the model's reply as a
// prompt-based tool call of the form {"tool": ..., "arguments": {...}}.
// It returns nil when the content is not a tool call, which callers treat
// as a final answer.
//
// Models frequently wrap the JSON in markdown code fences or surrounding
// whitespace despite the system prompt's instructions, so both are stripped
// before decoding. Bare valid JSON parses exactly as before.
func (a *LLMAgent) parseToolCallsFromContent(content string) []llm.ToolCall {
	payload := strings.TrimSpace(content)
	if strings.HasPrefix(payload, "```") {
		payload = strings.TrimPrefix(payload, "```json")
		payload = strings.TrimPrefix(payload, "```")
		payload = strings.TrimSuffix(strings.TrimSpace(payload), "```")
		payload = strings.TrimSpace(payload)
	}

	var parsed struct {
		Tool      string                 `json:"tool"`
		Arguments map[string]interface{} `json:"arguments"`
	}
	if err := json.Unmarshal([]byte(payload), &parsed); err != nil || parsed.Tool == "" {
		return nil
	}
	// Arguments came from json.Unmarshal, so re-marshaling cannot fail.
	argsJSON, _ := json.Marshal(parsed.Arguments)
	return []llm.ToolCall{{
		ID:   "call_0",
		Type: "function",
		Function: llm.FunctionCall{
			Name:      parsed.Tool,
			Arguments: string(argsJSON),
		},
	}}
}
// executeToolCall runs a single tool call and returns the result as a JSON
// string. Failures are reported as a JSON object with an "error" field.
//
// Execution prefers the ToolWorker (actor-based delegation) and falls back
// to direct execution via the tool.Manager.
func (a *LLMAgent) executeToolCall(ctx context.Context, tc llm.ToolCall) string {
	toolName := tc.Function.Name

	// Parse arguments; on malformed JSON, preserve the raw text under
	// "_raw" so the tool still sees what the model sent.
	var args map[string]interface{}
	if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
		args = map[string]interface{}{
			"_raw": tc.Function.Arguments,
		}
	}

	// Preferred path: delegate to the ToolWorker.
	if a.toolWorker != nil {
		toolCallMsg := bus.Message{
			ID:      tc.ID,
			Type:    bus.MsgTypeToolCall,
			From:    a.ID(),
			To:      a.toolWorker.ID(),
			Content: map[string]interface{}{"name": toolName, "arguments": args},
		}
		resultMsg, err := a.toolWorker.Process(ctx, toolCallMsg)
		if err != nil {
			return jsonError("tool execution failed: %v", err)
		}
		resultJSON, err := json.Marshal(resultMsg.Content)
		if err != nil {
			return jsonError("failed to marshal result: %v", err)
		}
		return string(resultJSON)
	}

	// Fallback: execute directly via the tool.Manager.
	if a.toolManager != nil {
		result, err := a.toolManager.Execute(toolName, ctx, args)
		if err != nil {
			return jsonError("tool execution failed: %v", err)
		}
		resultJSON, err := json.Marshal(result)
		if err != nil {
			return jsonError("failed to marshal result: %v", err)
		}
		return string(resultJSON)
	}

	return jsonError("no tool worker or tool manager available for %q", toolName)
}

// jsonError encodes a formatted error message as a {"error": ...} JSON
// object. Unlike the previous Sprintf-into-a-template approach, this stays
// valid JSON even when the message contains quotes or backslashes.
func jsonError(format string, args ...interface{}) string {
	payload, err := json.Marshal(map[string]string{"error": fmt.Sprintf(format, args...)})
	if err != nil {
		// map[string]string always marshals; this branch is unreachable.
		return `{"error": "internal error"}`
	}
	return string(payload)
}
// handleSystem processes internal system messages by returning a simple
// acknowledgement to the sender.
func (a *LLMAgent) handleSystem(_ context.Context, msg bus.Message) (bus.Message, error) {
	ack := bus.Message{
		ID:      msg.ID + "-ack",
		Type:    bus.MsgTypeSystem,
		From:    a.ID(),
		To:      msg.From,
		Content: "llm_agent acknowledged",
	}
	return ack, nil
}
// SetStreamWriter sets the writer for streaming LLM output after the agent
// has been constructed (used by the Kernel/CLI to attach output late).
// NOTE(review): the field is written without synchronization; this assumes
// callers set it before concurrent message handling begins — confirm.
func (a *LLMAgent) SetStreamWriter(w io.Writer) {
a.streamWriter = w
}
// Compile-time checks that both actor types satisfy the Agent interface.
var _ Agent = (*LLMAgent)(nil)
var _ Agent = (*ToolWorker)(nil)