package actor

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"strings"
	"sync"
	"time"

	"github.com/orca/orca/pkg/bus"
	"github.com/orca/orca/pkg/llm"
	"github.com/orca/orca/pkg/session"
	"github.com/orca/orca/pkg/skill"
	"github.com/orca/orca/pkg/tool"
)

// LLMAgent implements the Agent interface by integrating an LLM backend
// with the actor system and tool framework.
//
// It receives user messages, retrieves session context, calls the LLM,
// and handles tool call responses by executing tools and feeding results
// back to the LLM for final response generation.
type LLMAgent struct {
	*BaseAgent
	llm           llm.LLM
	sessionMgr    *session.Manager
	sessionID     string
	toolManager   *tool.Manager
	skillManager  *skill.Manager
	toolWorker    *ToolWorker
	windowSize    int
	streamWriter  io.Writer
	systemPrompt  string
	subAgents     map[string]string
	memoryManager *session.MemoryManager
}

// LLMAgentOption is a functional option for configuring the LLMAgent.
type LLMAgentOption func(*LLMAgent)

// WithSessionManager sets the session manager for conversation history.
func WithSessionManager(mgr *session.Manager) LLMAgentOption {
	return func(a *LLMAgent) {
		a.sessionMgr = mgr
	}
}

// WithSessionID sets the session ID for conversation persistence.
func WithSessionID(id string) LLMAgentOption {
	return func(a *LLMAgent) {
		a.sessionID = id
	}
}

// WithToolManager sets the tool manager for executing tools.
func WithToolManager(mgr *tool.Manager) LLMAgentOption {
	return func(a *LLMAgent) {
		a.toolManager = mgr
	}
}

// WithToolWorker sets the tool worker for delegated tool execution.
func WithToolWorker(w *ToolWorker) LLMAgentOption {
	return func(a *LLMAgent) {
		a.toolWorker = w
	}
}

// WithWindowSize sets the context window size for session history.
func WithWindowSize(size int) LLMAgentOption {
	return func(a *LLMAgent) {
		a.windowSize = size
	}
}

// WithSkillManager sets the skill manager used to match and inject skill
// guides into the LLM context.
func WithSkillManager(mgr *skill.Manager) LLMAgentOption {
	return func(a *LLMAgent) {
		a.skillManager = mgr
	}
}

// WithStreamWriter sets the writer for streaming LLM output.
func WithStreamWriter(w io.Writer) LLMAgentOption {
	return func(a *LLMAgent) {
		a.streamWriter = w
	}
}

// WithSystemPrompt sets the user-defined system prompt (identity description).
func WithSystemPrompt(prompt string) LLMAgentOption {
	return func(a *LLMAgent) {
		a.systemPrompt = prompt
	}
}

// WithSubAgents registers sub-agent names and descriptions that the agent may
// delegate to via the agent_call tool.
func WithSubAgents(agents map[string]string) LLMAgentOption {
	return func(a *LLMAgent) {
		a.subAgents = agents
	}
}

// WithMemoryManager sets the memory manager used for short-term and long-term
// memory injection and maintenance.
func WithMemoryManager(mm *session.MemoryManager) LLMAgentOption {
	return func(a *LLMAgent) {
		a.memoryManager = mm
	}
}

// NewLLMAgent creates a new LLMAgent with the given LLM backend and options.
// The agent is started automatically upon creation.
func NewLLMAgent(id string, backend llm.LLM, opts ...LLMAgentOption) *LLMAgent {
	a := &LLMAgent{
		BaseAgent:  NewBaseAgent(id, "llm_agent"),
		llm:        backend,
		windowSize: 20, // Default context window
	}

	for _, opt := range opts {
		opt(a)
	}

	a.SetHandler(a.handleMessage)
	if err := a.Start(); err != nil {
		panic(fmt.Sprintf("llm_agent: failed to start: %v", err))
	}
	return a
}
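
// Example (sketch): wiring up an LLMAgent. The backend, sessionMgr, and
// toolMgr values below are assumed to come from the llm, session, and tool
// packages; their constructors are not shown here and may differ in this
// repository.
//
//	agent := NewLLMAgent("assistant", backend,
//		WithSessionManager(sessionMgr),
//		WithSessionID("sess-001"),
//		WithToolManager(toolMgr),
//		WithWindowSize(40),
//		WithSystemPrompt("You are a helpful assistant."),
//	)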

// handleMessage routes incoming messages to the appropriate handler.
func (a *LLMAgent) handleMessage(ctx context.Context, msg bus.Message) (bus.Message, error) {
	switch msg.Type {
	case bus.MsgTypeTaskRequest:
		return a.handleUserMessage(ctx, msg)
	case bus.MsgTypeSystem:
		return a.handleSystem(ctx, msg)
	default:
		return bus.Message{}, fmt.Errorf("llm_agent %s: unsupported message type %s", a.ID(), msg.Type)
	}
}

// handleUserMessage processes a user message through the LLM.
//
// Flow:
//  1. Persist the user message to session history
//  2. Retrieve recent conversation context
//  3. Convert to LLM message format
//  4. Call LLM.Chat
//  5. If response has tool calls:
//     a. Execute each tool (directly or via ToolWorker)
//     b. Add tool results to conversation
//     c. Call LLM.Chat again with results
//  6. Persist the assistant response
//  7. Return the final response
func (a *LLMAgent) handleUserMessage(ctx context.Context, msg bus.Message) (bus.Message, error) {
	content, ok := msg.Content.(string)
	if !ok {
		return bus.Message{}, fmt.Errorf("llm_agent: expected string content, got %T", msg.Content)
	}

	// Handle special debug commands.
	if content == "/context" || content == "/debug" {
		a.LogContextDetails(content)
		return bus.Message{
			ID:      msg.ID + "-response",
			Type:    bus.MsgTypeTaskResponse,
			From:    a.ID(),
			To:      msg.From,
			Content: "[上下文详情已输出到日志]",
		}, nil
	}

	// Ensure session exists
	if a.sessionMgr != nil && a.sessionID != "" {
		// Check if session exists; create if not
		if _, err := a.sessionMgr.GetSession(a.sessionID); err != nil {
			a.sessionMgr.CreateSession(a.sessionID, map[string]string{
				"source": "llm_agent",
			})
		}

		// Persist user message
		a.sessionMgr.AddMessage(a.sessionID, session.RoleUser, content, nil)
	}

	if a.memoryManager != nil && a.sessionID != "" {
		a.memoryManager.SaveMessage(a.sessionID, session.SessionMessage{
			Role:      session.RoleUser,
			Content:   content,
			Timestamp: time.Now(),
		})
	}

	llmMessages := a.buildLLMMessages(content)

	// Append guides for any skills matched by the user query.
	if a.skillManager != nil {
		matchedSkills := a.skillManager.FindSkill(content)
		for _, s := range matchedSkills {
			if s.Body != "" {
				llmMessages = append(llmMessages, llm.Message{
					Role:    "system",
					Content: fmt.Sprintf("以下是你需要遵循的 %s 技能指南:\n\n%s", s.Name, s.Body),
				})
			}
		}
	}

	// Call LLM (potentially multiple rounds for tool calls)
	finalResponse, err := a.chatWithToolLoop(ctx, llmMessages)
	if err != nil {
		return bus.Message{}, fmt.Errorf("llm_agent: LLM chat failed: %w", err)
	}

	// Persist assistant response
	if a.sessionMgr != nil && a.sessionID != "" {
		a.sessionMgr.AddMessage(a.sessionID, session.RoleAssistant, finalResponse, nil)
	}

	if a.memoryManager != nil && a.sessionID != "" {
		a.memoryManager.SaveMessage(a.sessionID, session.SessionMessage{
			Role:      session.RoleAssistant,
			Content:   finalResponse,
			Timestamp: time.Now(),
		})
		go a.memoryManager.MaintainSessionMemory(a.sessionID, content, finalResponse)
	}

	return bus.Message{
		ID:      msg.ID + "-response",
		Type:    bus.MsgTypeTaskResponse,
		From:    a.ID(),
		To:      msg.From,
		Content: finalResponse,
	}, nil
}
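
// Example (sketch): the shape of a task request this agent handles. Delivery
// normally goes through the actor runtime / message bus; the direct
// handleMessage call below is illustrative only.
//
//	resp, err := agent.handleMessage(ctx, bus.Message{
//		ID:      "msg-1",
//		Type:    bus.MsgTypeTaskRequest,
//		From:    "user",
//		To:      agent.ID(),
//		Content: "Summarize the latest commits",
//	})
//	// resp.Content holds the final assistant reply; resp.Type is bus.MsgTypeTaskResponse.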

// contextStats collects token and message counts for the assembled LLM context.
type contextStats struct {
	systemPromptTokens int
	toolPromptTokens   int
	memoryTokens       int
	historyTokens      int
	memoryShortTerm    int
	memoryLongTerm     int
	historyCount       int
}

// buildLLMMessages assembles the message list sent to the LLM for a query.
func (a *LLMAgent) buildLLMMessages(query string) []llm.Message {
	return a.buildLLMMessagesWithStats(query, nil)
}

// buildLLMMessagesWithStats assembles the message list and, when stats is
// non-nil, records token estimates for each part of the context.
func (a *LLMAgent) buildLLMMessagesWithStats(query string, stats *contextStats) []llm.Message {
	messages := make([]llm.Message, 0)
	if stats == nil {
		stats = &contextStats{}
	}

	// 1. User-defined system prompt (configured identity description)
	if a.systemPrompt != "" {
		messages = append(messages, llm.Message{
			Role:    "system",
			Content: a.systemPrompt,
		})
		stats.systemPromptTokens = estimateTokens(a.systemPrompt)
	}

	// 2. Runtime tool instructions (generated dynamically)
	toolPrompt := a.buildToolPrompt()
	if toolPrompt != "" {
		messages = append(messages, llm.Message{
			Role:    "system",
			Content: toolPrompt,
		})
		stats.toolPromptTokens = estimateTokens(toolPrompt)
	}

	// 3. Relevant memory, if the memory manager decides it should be injected
	if a.memoryManager != nil && a.sessionID != "" {
		if a.memoryManager.ShouldInjectMemory(a.sessionID, query) {
			memoryCtx, memStats := a.memoryManager.BuildMemoryContextWithStats(a.sessionID, query)
			if memoryCtx != "" {
				messages = append(messages, llm.Message{
					Role:    "system",
					Content: memoryCtx,
				})
				stats.memoryTokens = memStats.TotalTokens
				stats.memoryShortTerm = memStats.ShortTermCount
				stats.memoryLongTerm = memStats.LongTermCount
			}
		}
	}

	// 4. Recent conversation history from the session window
	if a.sessionMgr != nil && a.sessionID != "" {
		sessionMsgs, err := a.sessionMgr.GetContext(a.sessionID, a.windowSize)
		if err == nil {
			stats.historyCount = len(sessionMsgs)
			for _, sm := range sessionMsgs {
				msg := llm.Message{
					Role:    string(sm.Role),
					Content: sm.Content,
				}
				if sm.Role == session.RoleTool && sm.Metadata != nil {
					msg.ToolCallID = sm.Metadata["tool_call_id"]
				}
				messages = append(messages, msg)
				stats.historyTokens += estimateTokens(sm.Content)
			}
		}
	}

	return messages
}

// LogContextDetails prints the memory contents, embedding cache statistics,
// and recent session history to stdout. It backs the /context and /debug
// commands handled in handleUserMessage.
func (a *LLMAgent) LogContextDetails(query string) {
	fmt.Println("\n========== 上下文详情 ==========")

	if a.memoryManager != nil && a.sessionID != "" {
		fmt.Println("\n[记忆内容]")
		shortTerm, _ := a.memoryManager.GetShortTermMemory(a.sessionID, query)
		if len(shortTerm) > 0 {
			fmt.Println(" 短期记忆:")
			for i, m := range shortTerm {
				fmt.Printf(" [%d] %s\n", i+1, truncateForDisplay(m, 80))
			}
		}
		longTerm, _ := a.memoryManager.GetLongTermMemory(query)
		if len(longTerm) > 0 {
			fmt.Println(" 长期记忆:")
			for i, m := range longTerm {
				fmt.Printf(" [%d] %s\n", i+1, truncateForDisplay(m.Content, 80))
			}
		}
		if len(shortTerm) == 0 && len(longTerm) == 0 {
			fmt.Println(" (无记忆)")
		}

		// The +1 in the denominator guards against division by zero.
		cacheSize, cacheHits, cacheMisses := a.memoryManager.CacheStats()
		fmt.Printf("\n[Embedding缓存] 大小=%d, 命中=%d, 未命中=%d, 命中率=%.1f%%\n",
			cacheSize, cacheHits, cacheMisses,
			float64(cacheHits)*100/float64(cacheHits+cacheMisses+1))
	}

	if a.sessionMgr != nil && a.sessionID != "" {
		fmt.Println("\n[历史对话]")
		sessionMsgs, err := a.sessionMgr.GetContext(a.sessionID, a.windowSize)
		if err == nil && len(sessionMsgs) > 0 {
			start := 0
			if len(sessionMsgs) > 10 {
				start = len(sessionMsgs) - 10
				fmt.Printf(" (显示最近 10/%d 条)\n", len(sessionMsgs))
			}
			for i := start; i < len(sessionMsgs); i++ {
				sm := sessionMsgs[i]
				role := string(sm.Role)
				if role == "" {
					role = "unknown"
				}
				fmt.Printf(" [%s] %s\n", role, truncateForDisplay(sm.Content, 80))
			}
		} else {
			fmt.Println(" (无历史)")
		}
	}

	fmt.Print("================================\n\n")
}

// truncateForDisplay shortens s to at most maxLen characters for log output,
// appending "..." when truncated. It operates on runes so multi-byte UTF-8
// characters are never split.
func truncateForDisplay(s string, maxLen int) string {
	r := []rune(s)
	if len(r) <= maxLen {
		return s
	}
	return string(r[:maxLen]) + "..."
}

// estimateTokens gives a rough token estimate using a ~4 characters-per-token
// heuristic; it is only used for context statistics, not as a hard limit.
func estimateTokens(text string) int {
	return len([]rune(text)) / 4
}

// buildToolPrompt builds the tool-usage prompt (the identity description is
// handled separately by the system prompt). It injects the available tools,
// sub-agents, and skills together with the calling rules into the LLM
// context, enabling prompt-based tool calling.
func (a *LLMAgent) buildToolPrompt() string {
	var b strings.Builder

	if a.toolManager != nil {
		tools := a.toolManager.List()

		b.WriteString("你可以使用以下工具来完成用户的请求。\n\n")
		b.WriteString("可用工具列表:\n")

		for _, t := range tools {
			b.WriteString(fmt.Sprintf("\n工具名: %s\n", t.Name()))
			b.WriteString(fmt.Sprintf("描述: %s\n", t.Description()))
			paramsJSON, _ := json.Marshal(t.Parameters())
			b.WriteString(fmt.Sprintf("参数: %s\n", string(paramsJSON)))
		}

		b.WriteString("\n规则:\n")
		b.WriteString("1. 当你需要调用工具时,请在回复中**只输出**以下 JSON 格式(不要添加其他文字):\n")
		b.WriteString(` {"tool": "工具名", "arguments": {"参数名": "参数值"}}` + "\n")
		b.WriteString("2. 如果需要同时调用多个工具(并行执行),请输出 JSON 数组格式:\n")
		b.WriteString(` [{"tool": "工具名1", "arguments": {...}}, {"tool": "工具名2", "arguments": {...}}]` + "\n")
		b.WriteString("3. 如果你已经看到了工具返回的结果,请直接根据结果回答用户,不要再次调用工具。\n")
		b.WriteString("4. 当你不需要调用工具时,请直接回复用户。\n")
		b.WriteString("5. 当用户的请求涉及代码、架构、数学计算等专业领域时,你必须调用相应的子Agent,不要自己直接回答。\n")
	}

	if len(a.subAgents) > 0 {
		b.WriteString("\n\n你可以调用以下专业Agent来协助完成特定任务:\n")
		for name, description := range a.subAgents {
			b.WriteString(fmt.Sprintf("- %s: %s\n", name, description))
		}
		b.WriteString("\n调用方式:使用 agent_call 工具,指定 agent 名称和任务描述。\n")
		b.WriteString("示例:{\"tool\": \"agent_call\", \"arguments\": {\"agent\": \"coder\", \"task\": \"写个快速排序\"}}\n")
		b.WriteString("如果用户有多个独立任务,请同时调用多个 agent_call(JSON数组格式),让它们并行执行。\n")
		b.WriteString("\n【强制规则】当用户的请求涉及代码编程、系统架构、数学计算、代码审查等专业领域时,你必须调用相应的子Agent。\n")
		b.WriteString("你绝对不能自己直接回答编程或架构问题,必须通过 agent_call 工具委托给专业Agent处理。\n")
	}

	if a.skillManager != nil {
		skills := a.skillManager.ListSkills()
		if len(skills) > 0 {
			b.WriteString("\n\n你还可以使用以下技能来更好地帮助用户:\n")
			for _, s := range skills {
				b.WriteString(fmt.Sprintf("\n=== 技能: %s ===\n", s.Name))
				b.WriteString(fmt.Sprintf("描述: %s\n", s.Description))
				if len(s.Triggers) > 0 {
					b.WriteString(fmt.Sprintf("触发词: %s\n", strings.Join(s.Triggers, ", ")))
				}
				if s.Body != "" {
					// Truncate long skill bodies by rune so multi-byte UTF-8
					// characters are never split mid-sequence.
					body := s.Body
					if r := []rune(body); len(r) > 4000 {
						body = string(r[:4000]) + "\n...[内容已截断]"
					}
					b.WriteString(fmt.Sprintf("\n详细指南:\n%s\n", body))
				}
				b.WriteString(fmt.Sprintf("=== 结束: %s ===\n", s.Name))
			}
			b.WriteString("\n当用户的请求匹配某个技能的触发词时,请根据该技能的详细指南提供更专业的帮助。\n")
			b.WriteString("你应该主动使用相关技能的专业知识来回答用户问题,而不需要询问用户是否使用技能。\n")
		}
	}

	if b.Len() == 0 {
		return ""
	}
	return b.String()
}

// buildToolDefs converts the registered tools into llm.ToolDef entries in the
// native function-calling format expected by LLM backends.
func (a *LLMAgent) buildToolDefs() []llm.ToolDef {
	var tools []llm.ToolDef

	if a.toolManager != nil {
		for _, t := range a.toolManager.List() {
			params := t.Parameters()
			properties := make(map[string]llm.ToolProperty)
			required := []string{}

			for name, param := range params {
				prop := llm.ToolProperty{
					Type:        param.Type,
					Description: param.Description,
				}
				if len(param.Enum) > 0 {
					prop.Enum = param.Enum
				}
				properties[name] = prop
				if param.Required {
					required = append(required, name)
				}
			}

			tools = append(tools, llm.ToolDef{
				Type: "function",
				Function: llm.ToolFunction{
					Name:        t.Name(),
					Description: t.Description(),
					Parameters: llm.ToolFunctionParameters{
						Type:       "object",
						Required:   required,
						Properties: properties,
					},
				},
			})
		}
	}

	return tools
}
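
// Example (sketch): one entry produced by buildToolDefs, shown as the JSON it
// would marshal to in the OpenAI-style function-calling format. The field
// names assume conventional JSON tags on the llm types, and the "read_file"
// tool with its "path" parameter is purely illustrative.
//
//	{
//	  "type": "function",
//	  "function": {
//	    "name": "read_file",
//	    "description": "Read a file from the workspace",
//	    "parameters": {
//	      "type": "object",
//	      "required": ["path"],
//	      "properties": {
//	        "path": {"type": "string", "description": "Path of the file to read"}
//	      }
//	    }
//	  }
//	}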

// chatWithToolLoop sends the messages to the LLM and, as long as the reply
// parses as one or more prompt-based tool calls, executes the tools and feeds
// their results back for another round. It returns the first reply that is
// not a tool call, or an error once maxRounds is exceeded.
func (a *LLMAgent) chatWithToolLoop(ctx context.Context, messages []llm.Message) (string, error) {
	maxRounds := 10

	for round := 0; round < maxRounds; round++ {
		var content string
		var err error

		if a.streamWriter != nil {
			content, err = a.streamChat(ctx, messages)
		} else {
			content, err = a.syncChat(ctx, messages)
		}
		if err != nil {
			return "", fmt.Errorf("chat round %d failed: %w", round, err)
		}

		toolCalls := a.parseToolCallsFromContent(content)
		if len(toolCalls) == 0 {
			return content, nil
		}

		// Record the model's tool-call request, then append each tool result
		// as a user message so the next round can answer from it.
		messages = append(messages, llm.Message{
			Role:    "assistant",
			Content: content,
		})

		results := a.executeToolCallsParallel(ctx, toolCalls)

		for _, result := range results {
			messages = append(messages, llm.Message{
				Role:    "user",
				Content: result,
			})
		}
	}

	return "", fmt.Errorf("llm_agent: exceeded maximum tool call rounds (%d)", maxRounds)
}
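
// Example (sketch): how the message list grows across one tool round in
// chatWithToolLoop. The read_file tool and its output are illustrative.
//
//	round 0: [...system/tool prompts..., history..., user query]
//	         model replies: {"tool": "read_file", "arguments": {"path": "main.go"}}
//	         -> appended as an assistant message, tool executed
//	         -> result appended as a user message: "工具 read_file 的执行结果:{...}"
//	round 1: model now sees the tool result and replies with a plain answer,
//	         which ends the loop and becomes the final response.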

// syncChat performs a single non-streaming chat call and returns the content.
func (a *LLMAgent) syncChat(ctx context.Context, messages []llm.Message) (string, error) {
	response, err := a.llm.Chat(ctx, messages)
	if err != nil {
		return "", err
	}
	return response.Content, nil
}

// streamChat performs a streaming chat call, forwarding chunks to the
// configured stream writer while accumulating the full content.
func (a *LLMAgent) streamChat(ctx context.Context, messages []llm.Message) (string, error) {
	var content strings.Builder
	if a.streamWriter != nil {
		fmt.Fprint(a.streamWriter, "\n")
	}

	err := a.llm.Stream(ctx, messages, func(chunk string) error {
		content.WriteString(chunk)
		if a.streamWriter != nil {
			fmt.Fprint(a.streamWriter, chunk)
		}
		return nil
	})

	if err != nil {
		return "", err
	}

	if a.streamWriter != nil {
		fmt.Fprintln(a.streamWriter)
	}
	return content.String(), nil
}

// parseToolCallsFromContent extracts prompt-based tool calls from the model's
// reply. It accepts a single JSON object, a JSON array, or several
// comma-separated objects, optionally wrapped in a Markdown code fence, and
// returns nil when the content is not a tool call.
func (a *LLMAgent) parseToolCallsFromContent(content string) []llm.ToolCall {
	cleanContent := a.extractJSONFromMarkdown(content)

	var toolCalls []llm.ToolCall
	var callIndex int

	// Try a JSON array of tool calls first.
	var parsedList []struct {
		Tool      string                 `json:"tool"`
		Arguments map[string]interface{} `json:"arguments"`
	}
	if err := json.Unmarshal([]byte(cleanContent), &parsedList); err == nil && len(parsedList) > 0 {
		for _, parsed := range parsedList {
			if parsed.Tool == "" {
				continue
			}
			argsJSON, _ := json.Marshal(parsed.Arguments)
			toolCalls = append(toolCalls, llm.ToolCall{
				ID:   fmt.Sprintf("call_%d", callIndex),
				Type: "function",
				Function: llm.FunctionCall{
					Name:      parsed.Tool,
					Arguments: string(argsJSON),
				},
			})
			callIndex++
		}
		return toolCalls
	}

	// Then a single JSON object.
	var parsed struct {
		Tool      string                 `json:"tool"`
		Arguments map[string]interface{} `json:"arguments"`
	}
	if err := json.Unmarshal([]byte(cleanContent), &parsed); err == nil && parsed.Tool != "" {
		argsJSON, _ := json.Marshal(parsed.Arguments)
		return []llm.ToolCall{{
			ID:   "call_0",
			Type: "function",
			Function: llm.FunctionCall{
				Name:      parsed.Tool,
				Arguments: string(argsJSON),
			},
		}}
	}

	// Finally, comma-separated objects: wrap them in brackets and retry as an array.
	var commaSeparated []struct {
		Tool      string                 `json:"tool"`
		Arguments map[string]interface{} `json:"arguments"`
	}
	wrapped := "[" + cleanContent + "]"
	if err := json.Unmarshal([]byte(wrapped), &commaSeparated); err == nil && len(commaSeparated) > 0 {
		for _, parsed := range commaSeparated {
			if parsed.Tool == "" {
				continue
			}
			argsJSON, _ := json.Marshal(parsed.Arguments)
			toolCalls = append(toolCalls, llm.ToolCall{
				ID:   fmt.Sprintf("call_%d", callIndex),
				Type: "function",
				Function: llm.FunctionCall{
					Name:      parsed.Tool,
					Arguments: string(argsJSON),
				},
			})
			callIndex++
		}
		return toolCalls
	}

	return nil
}
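
// Example (sketch): model outputs that parseToolCallsFromContent accepts. The
// tool names are illustrative; real names come from the registered tools.
//
//	single call:     {"tool": "web_search", "arguments": {"query": "golang context"}}
//	parallel calls:  [{"tool": "read_file", "arguments": {"path": "a.go"}},
//	                  {"tool": "read_file", "arguments": {"path": "b.go"}}]
//	comma-separated: {"tool": "a", "arguments": {}}, {"tool": "b", "arguments": {}}
//	                 (wrapped into an array before parsing)
//	fenced output:   the same JSON inside a Markdown json code fence is
//	                 unwrapped by extractJSONFromMarkdown first.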

// extractJSONFromMarkdown strips a surrounding Markdown code fence (with or
// without a "json" language tag) from content and returns the inner text. If
// no fence is found, the content is returned unchanged.
func (a *LLMAgent) extractJSONFromMarkdown(content string) string {
	start := strings.Index(content, "`"+"``json")
	if start == -1 {
		start = strings.Index(content, "`"+"``")
		if start == -1 {
			return content
		}
	} else {
		start += 7
	}

	// Skip the remainder of the fence line so start points at the first JSON line.
	if !strings.HasPrefix(content[start:], "`"+"``") {
		newline := strings.Index(content[start:], "\n")
		if newline != -1 {
			start = start + newline + 1
		}
	} else {
		start += 3
		newline := strings.Index(content[start:], "\n")
		if newline != -1 {
			start = start + newline + 1
		}
	}

	end := strings.Index(content[start:], "\n`"+"``")
	if end == -1 {
		end = strings.Index(content[start:], "`"+"``")
	}
	if end == -1 {
		return strings.TrimSpace(content[start:])
	}

	return strings.TrimSpace(content[start : start+end])
}

// executeToolCall runs a single tool call, preferring the ToolWorker when one
// is configured and falling back to the tool manager. The result (or an error
// payload) is returned as a JSON string for the next LLM round.
func (a *LLMAgent) executeToolCall(ctx context.Context, tc llm.ToolCall) string {
	toolName := tc.Function.Name

	if a.streamWriter != nil {
		fmt.Fprintf(a.streamWriter, "\n[正在执行工具: %s...]\n", toolName)
	}

	var args map[string]interface{}
	if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
		args = map[string]interface{}{
			"_raw": tc.Function.Arguments,
		}
	}

	if a.toolWorker != nil {
		toolCallMsg := bus.Message{
			ID:      tc.ID,
			Type:    bus.MsgTypeToolCall,
			From:    a.ID(),
			To:      a.toolWorker.ID(),
			Content: map[string]interface{}{"name": toolName, "arguments": args},
		}

		resultMsg, err := a.toolWorker.Process(ctx, toolCallMsg)
		if a.streamWriter != nil {
			fmt.Fprintf(a.streamWriter, "[工具 %s 执行完成]\n", toolName)
		}
		if err != nil {
			return fmt.Sprintf(`{"error": "tool execution failed: %v"}`, err)
		}

		resultJSON, err := json.Marshal(resultMsg.Content)
		if err != nil {
			return fmt.Sprintf(`{"error": "failed to marshal result: %v"}`, err)
		}
		return string(resultJSON)
	}

	if a.toolManager != nil {
		result, err := a.toolManager.Execute(toolName, ctx, args)
		if a.streamWriter != nil {
			fmt.Fprintf(a.streamWriter, "[工具 %s 执行完成]\n", toolName)
		}
		if err != nil {
			return fmt.Sprintf(`{"error": "tool execution failed: %v"}`, err)
		}

		resultJSON, err := json.Marshal(result)
		if err != nil {
			return fmt.Sprintf(`{"error": "failed to marshal result: %v"}`, err)
		}
		return string(resultJSON)
	}

	return fmt.Sprintf(`{"error": "no tool worker or tool manager available for %q"}`, toolName)
}

// executeToolCallsParallel executes the given tool calls concurrently and
// returns one formatted result string per call, in the original order. A
// single call is executed inline without spawning a goroutine.
func (a *LLMAgent) executeToolCallsParallel(ctx context.Context, toolCalls []llm.ToolCall) []string {
	if len(toolCalls) == 1 {
		result := a.executeToolCall(ctx, toolCalls[0])
		return []string{fmt.Sprintf("工具 %s 的执行结果:%s", toolCalls[0].Function.Name, result)}
	}

	results := make([]string, len(toolCalls))
	var wg sync.WaitGroup

	for i, tc := range toolCalls {
		wg.Add(1)
		go func(idx int, toolCall llm.ToolCall) {
			defer wg.Done()
			res := a.executeToolCall(ctx, toolCall)
			// Each goroutine writes only its own slot, so no extra locking is needed.
			results[idx] = fmt.Sprintf("工具 %s 的执行结果:%s", toolCall.Function.Name, res)
		}(i, tc)
	}

	wg.Wait()
	return results
}

// handleSystem processes internal system messages.
func (a *LLMAgent) handleSystem(ctx context.Context, msg bus.Message) (bus.Message, error) {
	return bus.Message{
		ID:      msg.ID + "-ack",
		Type:    bus.MsgTypeSystem,
		From:    a.ID(),
		To:      msg.From,
		Content: "llm_agent acknowledged",
	}, nil
}

// SetStreamWriter replaces the writer used for streaming LLM output after the
// agent has been constructed.
func (a *LLMAgent) SetStreamWriter(w io.Writer) {
	a.streamWriter = w
}

// Compile-time checks that both agent types satisfy the Agent interface.
var _ Agent = (*LLMAgent)(nil)
var _ Agent = (*ToolWorker)(nil)