- Add streamWriter to LLMAgent for real-time output
- Support streaming mode in chatWithToolLoop
- Add SetStreamWriter to Kernel and LLMAgent
- CLI displays streaming responses immediately
- Tool calling still works with streaming

package actor

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"strings"

	"github.com/orca/orca/pkg/bus"
	"github.com/orca/orca/pkg/llm"
	"github.com/orca/orca/pkg/session"
	"github.com/orca/orca/pkg/tool"
)

// LLMAgent implements the Agent interface by integrating an LLM backend
// with the actor system and tool framework.
//
// It receives user messages, retrieves session context, calls the LLM,
// and handles tool call responses by executing tools and feeding results
// back to the LLM for final response generation.
type LLMAgent struct {
	*BaseAgent

	llm          llm.LLM
	sessionMgr   *session.Manager
	sessionID    string
	toolManager  *tool.Manager
	toolWorker   *ToolWorker
	windowSize   int
	streamWriter io.Writer
}

// LLMAgentOption is a functional option for configuring the LLMAgent.
type LLMAgentOption func(*LLMAgent)

// WithSessionManager sets the session manager for conversation history.
func WithSessionManager(mgr *session.Manager) LLMAgentOption {
	return func(a *LLMAgent) {
		a.sessionMgr = mgr
	}
}

// WithSessionID sets the session ID for conversation persistence.
func WithSessionID(id string) LLMAgentOption {
	return func(a *LLMAgent) {
		a.sessionID = id
	}
}

// WithToolManager sets the tool manager for executing tools.
func WithToolManager(mgr *tool.Manager) LLMAgentOption {
	return func(a *LLMAgent) {
		a.toolManager = mgr
	}
}

// WithToolWorker sets the tool worker for delegated tool execution.
func WithToolWorker(w *ToolWorker) LLMAgentOption {
	return func(a *LLMAgent) {
		a.toolWorker = w
	}
}

// WithWindowSize sets the context window size for session history.
func WithWindowSize(size int) LLMAgentOption {
	return func(a *LLMAgent) {
		a.windowSize = size
	}
}

// WithStreamWriter sets the writer for streaming LLM output.
func WithStreamWriter(w io.Writer) LLMAgentOption {
	return func(a *LLMAgent) {
		a.streamWriter = w
	}
}

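// A minimal wiring sketch for a streaming CLI (hypothetical caller code:
// backend, sessionMgr, and toolMgr stand in for whatever the host
// application provides):
//
//	agent := NewLLMAgent("assistant", backend,
//		WithSessionManager(sessionMgr),
//		WithSessionID("cli-session"),
//		WithToolManager(toolMgr),
//		WithStreamWriter(os.Stdout), // tokens appear in the terminal as they arrive
//	)
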
// NewLLMAgent creates a new LLMAgent with the given LLM backend and options.
// The agent is started automatically upon creation and panics if the
// underlying actor fails to start.
func NewLLMAgent(id string, backend llm.LLM, opts ...LLMAgentOption) *LLMAgent {
	a := &LLMAgent{
		BaseAgent:  NewBaseAgent(id, "llm_agent"),
		llm:        backend,
		windowSize: 20, // Default context window
	}

	for _, opt := range opts {
		opt(a)
	}

	a.SetHandler(a.handleMessage)
	if err := a.Start(); err != nil {
		panic(fmt.Sprintf("llm_agent: failed to start: %v", err))
	}
	return a
}

// handleMessage routes incoming messages to the appropriate handler.
func (a *LLMAgent) handleMessage(ctx context.Context, msg bus.Message) (bus.Message, error) {
	switch msg.Type {
	case bus.MsgTypeTaskRequest:
		return a.handleUserMessage(ctx, msg)
	case bus.MsgTypeSystem:
		return a.handleSystem(ctx, msg)
	default:
		return bus.Message{}, fmt.Errorf("llm_agent %s: unsupported message type %s", a.ID(), msg.Type)
	}
}

// handleUserMessage processes a user message through the LLM.
//
// Flow:
//  1. Persist the user message to session history
//  2. Retrieve recent conversation context
//  3. Convert to LLM message format
//  4. Call LLM.Chat
//  5. If the response contains tool calls:
//     a. Execute each tool (directly or via ToolWorker)
//     b. Add tool results to the conversation
//     c. Call LLM.Chat again with the results
//  6. Persist the assistant response
//  7. Return the final response
func (a *LLMAgent) handleUserMessage(ctx context.Context, msg bus.Message) (bus.Message, error) {
	content, ok := msg.Content.(string)
	if !ok {
		return bus.Message{}, fmt.Errorf("llm_agent: expected string content, got %T", msg.Content)
	}

	// Ensure session exists
	if a.sessionMgr != nil && a.sessionID != "" {
		// Check if session exists; create it if not
		if _, err := a.sessionMgr.GetSession(a.sessionID); err != nil {
			a.sessionMgr.CreateSession(a.sessionID, map[string]string{
				"source": "llm_agent",
			})
		}

		// Persist user message
		a.sessionMgr.AddMessage(a.sessionID, session.RoleUser, content, nil)
	}

	// Build LLM messages from session context
	llmMessages := a.buildLLMMessages()

	// Call LLM (potentially multiple rounds for tool calls)
	finalResponse, err := a.chatWithToolLoop(ctx, llmMessages)
	if err != nil {
		return bus.Message{}, fmt.Errorf("llm_agent: LLM chat failed: %w", err)
	}

	// Persist assistant response
	if a.sessionMgr != nil && a.sessionID != "" {
		a.sessionMgr.AddMessage(a.sessionID, session.RoleAssistant, finalResponse, nil)
	}

	return bus.Message{
		ID:      msg.ID + "-response",
		Type:    bus.MsgTypeTaskResponse,
		From:    a.ID(),
		To:      msg.From,
		Content: finalResponse,
	}, nil
}

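// A caller submits work as a task request; Content must be a plain string
// (a hypothetical sketch; delivery depends on how the host wires the bus):
//
//	req := bus.Message{
//		ID:      "msg-42",
//		Type:    bus.MsgTypeTaskRequest,
//		From:    "cli",
//		To:      agent.ID(),
//		Content: "What's the weather in Paris?",
//	}
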
func (a *LLMAgent) buildLLMMessages() []llm.Message {
	messages := make([]llm.Message, 0)

	// Prepend the tool system prompt so the model knows which tools exist.
	if a.toolManager != nil {
		messages = append(messages, llm.Message{
			Role:    "system",
			Content: a.buildToolSystemPrompt(),
		})
	}

	if a.sessionMgr == nil || a.sessionID == "" {
		return messages
	}

	sessionMsgs, err := a.sessionMgr.GetContext(a.sessionID, a.windowSize)
	if err != nil {
		// Fall back to just the system prompt if history retrieval fails.
		return messages
	}

	for _, sm := range sessionMsgs {
		msg := llm.Message{
			Role:    string(sm.Role),
			Content: sm.Content,
		}
		if sm.Role == session.RoleTool && sm.Metadata != nil {
			msg.ToolCallID = sm.Metadata["tool_call_id"]
		}
		messages = append(messages, msg)
	}

	return messages
}

// buildToolSystemPrompt creates a system prompt describing all available tools.
// This enables prompt-based tool calling for models without native function
// calling support.
func (a *LLMAgent) buildToolSystemPrompt() string {
	if a.toolManager == nil {
		return ""
	}

	var b strings.Builder
	b.WriteString("You are an AI assistant that can use the following tools to complete the user's requests.\n\n")
	b.WriteString("Available tools:\n")

	for _, t := range a.toolManager.List() {
		b.WriteString(fmt.Sprintf("\nTool name: %s\n", t.Name()))
		b.WriteString(fmt.Sprintf("Description: %s\n", t.Description()))
		paramsJSON, _ := json.Marshal(t.Parameters())
		b.WriteString(fmt.Sprintf("Parameters: %s\n", string(paramsJSON)))
	}

	b.WriteString("\nRules:\n")
	b.WriteString("1. When you need to call a tool, your reply must contain ONLY the following JSON (no other text):\n")
	b.WriteString(`   {"tool": "<tool_name>", "arguments": {"<param_name>": "<param_value>"}}` + "\n")
	b.WriteString("2. If you have already seen a tool's result, answer the user directly from it; do not call the tool again.\n")
	b.WriteString("3. If no tool is needed, reply to the user directly.\n")

	return b.String()
}

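// For a hypothetical "weather" tool, the generated prompt contains a block
// like the following (name, description, and parameter schema are
// illustrative):
//
//	Tool name: weather
//	Description: Look up the current weather for a city
//	Parameters: {"type":"object","properties":{"city":{"type":"string"}}}
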
func (a *LLMAgent) chatWithToolLoop(ctx context.Context, messages []llm.Message) (string, error) {
	maxRounds := 10

	for round := 0; round < maxRounds; round++ {
		var content string
		var err error

		// Stream the response when a writer is configured; otherwise block
		// until the full response is available.
		if a.streamWriter != nil {
			content, err = a.streamChat(ctx, messages)
		} else {
			content, err = a.syncChat(ctx, messages)
		}
		if err != nil {
			return "", fmt.Errorf("chat round %d failed: %w", round, err)
		}

		// A response with no tool calls is the final answer.
		toolCalls := a.parseToolCallsFromContent(content)
		if len(toolCalls) == 0 {
			return content, nil
		}

		messages = append(messages, llm.Message{
			Role:    "assistant",
			Content: content,
		})

		// Execute each tool and feed its result back as a user message so
		// the next round can answer from it.
		for _, tc := range toolCalls {
			resultContent := a.executeToolCall(ctx, tc)
			messages = append(messages, llm.Message{
				Role:    "user",
				Content: fmt.Sprintf("Result of tool %s: %s", tc.Function.Name, resultContent),
			})
		}
	}

	return "", fmt.Errorf("llm_agent: exceeded maximum tool call rounds (%d)", maxRounds)
}

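// One round of the loop, illustrated (tool name and values are hypothetical):
//
//	user:      "What's the weather in Paris?"
//	assistant: {"tool": "weather", "arguments": {"city": "Paris"}}
//	user:      Result of tool weather: {"temp": "18C"}
//	assistant: "It is currently 18C in Paris."
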
// syncChat calls the LLM once and blocks until the complete response is
// available.
func (a *LLMAgent) syncChat(ctx context.Context, messages []llm.Message) (string, error) {
	response, err := a.llm.Chat(ctx, messages)
	if err != nil {
		return "", err
	}
	return response.Content, nil
}

// streamChat calls the LLM in streaming mode, forwarding each chunk to the
// configured writer as it arrives while accumulating the full response for
// the caller.
func (a *LLMAgent) streamChat(ctx context.Context, messages []llm.Message) (string, error) {
	var content strings.Builder
	if a.streamWriter != nil {
		fmt.Fprint(a.streamWriter, "\n")
	}

	err := a.llm.Stream(ctx, messages, func(chunk string) error {
		content.WriteString(chunk)
		if a.streamWriter != nil {
			fmt.Fprint(a.streamWriter, chunk)
		}
		return nil
	})
	if err != nil {
		return "", err
	}

	if a.streamWriter != nil {
		fmt.Fprintln(a.streamWriter)
	}
	return content.String(), nil
}

// parseToolCallsFromContent attempts to interpret the model's reply as a
// prompt-based tool call. It only matches when the entire content is a single
// JSON object of the form {"tool": ..., "arguments": ...}; anything else is
// treated as a plain answer.
func (a *LLMAgent) parseToolCallsFromContent(content string) []llm.ToolCall {
	var parsed struct {
		Tool      string                 `json:"tool"`
		Arguments map[string]interface{} `json:"arguments"`
	}
	if err := json.Unmarshal([]byte(content), &parsed); err != nil || parsed.Tool == "" {
		return nil
	}

	argsJSON, _ := json.Marshal(parsed.Arguments)
	return []llm.ToolCall{{
		ID:   "call_0",
		Type: "function",
		Function: llm.FunctionCall{
			Name:      parsed.Tool,
			Arguments: string(argsJSON),
		},
	}}
}

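// Examples (hypothetical values):
//
//	parseToolCallsFromContent(`{"tool": "weather", "arguments": {"city": "Paris"}}`)
//	// => one ToolCall with Function.Name "weather" and Function.Arguments `{"city":"Paris"}`
//
//	parseToolCallsFromContent("The weather in Paris is mild.")
//	// => nil (plain answer, no tool call)
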
// executeToolCall runs a single tool call and returns the result as a JSON string.
func (a *LLMAgent) executeToolCall(ctx context.Context, tc llm.ToolCall) string {
	toolName := tc.Function.Name

	// Parse arguments; fall back to passing the raw string if they are not valid JSON.
	var args map[string]interface{}
	if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
		args = map[string]interface{}{
			"_raw": tc.Function.Arguments,
		}
	}

	// Execute via ToolWorker (preferred) or directly via tool.Manager
	if a.toolWorker != nil {
		// Create a tool call message for the ToolWorker
		toolCallMsg := bus.Message{
			ID:      tc.ID,
			Type:    bus.MsgTypeToolCall,
			From:    a.ID(),
			To:      a.toolWorker.ID(),
			Content: map[string]interface{}{"name": toolName, "arguments": args},
		}

		resultMsg, err := a.toolWorker.Process(ctx, toolCallMsg)
		if err != nil {
			return fmt.Sprintf(`{"error": "tool execution failed: %v"}`, err)
		}

		// Serialize the result
		resultJSON, err := json.Marshal(resultMsg.Content)
		if err != nil {
			return fmt.Sprintf(`{"error": "failed to marshal result: %v"}`, err)
		}
		return string(resultJSON)
	}

	// Fallback: execute directly via tool.Manager
	if a.toolManager != nil {
		result, err := a.toolManager.Execute(toolName, ctx, args)
		if err != nil {
			return fmt.Sprintf(`{"error": "tool execution failed: %v"}`, err)
		}

		resultJSON, err := json.Marshal(result)
		if err != nil {
			return fmt.Sprintf(`{"error": "failed to marshal result: %v"}`, err)
		}
		return string(resultJSON)
	}

	return fmt.Sprintf(`{"error": "no tool worker or tool manager available for %q"}`, toolName)
}

// handleSystem processes internal system messages.
func (a *LLMAgent) handleSystem(ctx context.Context, msg bus.Message) (bus.Message, error) {
	return bus.Message{
		ID:      msg.ID + "-ack",
		Type:    bus.MsgTypeSystem,
		From:    a.ID(),
		To:      msg.From,
		Content: "llm_agent acknowledged",
	}, nil
}

// SetStreamWriter sets or replaces the writer used for streaming LLM output.
// Passing nil disables streaming and reverts the agent to synchronous chats.
// Access is not synchronized, so call it before the agent starts handling
// messages.
func (a *LLMAgent) SetStreamWriter(w io.Writer) {
	a.streamWriter = w
}

// Compile-time interface checks.
var _ Agent = (*LLMAgent)(nil)
var _ Agent = (*ToolWorker)(nil)