diff --git a/cmd/orca/main.go b/cmd/orca/main.go
index cc9feb2..61aeb5c 100644
--- a/cmd/orca/main.go
+++ b/cmd/orca/main.go
@@ -43,6 +43,8 @@ func main() {
 		log.Fatalf("Failed to start kernel: %v", err)
 	}
 
+	k.SetStreamWriter(os.Stdout)
+
 	fmt.Println("Orca Agent Framework")
 	fmt.Println("Kernel started successfully")
 	fmt.Printf(" LLM Model: %s\n", cfg.Ollama.Model)
@@ -76,15 +78,13 @@ func main() {
 			continue
 		}
 
-		// Send message to LLM agent via kernel
-		response, err := k.SendMessage("user", "llm", input)
-		if err != nil {
-			fmt.Printf("Error: %v\n", err)
-			continue
-		}
-
-		fmt.Println(response)
-		fmt.Println()
+		// Send message to LLM agent via kernel
+		_, err := k.SendMessage("user", "llm", input)
+		if err != nil {
+			fmt.Printf("Error: %v\n", err)
+			continue
+		}
+		fmt.Println()
 	}
 
 	if err := scanner.Err(); err != nil {
diff --git a/pkg/actor/llm_agent.go b/pkg/actor/llm_agent.go
index b9916e8..e1790bc 100644
--- a/pkg/actor/llm_agent.go
+++ b/pkg/actor/llm_agent.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
 	"strings"
 
 	"github.com/orca/orca/pkg/bus"
@@ -20,12 +21,13 @@ import (
 // back to the LLM for final response generation.
 type LLMAgent struct {
 	*BaseAgent
-	llm         llm.LLM
-	sessionMgr  *session.Manager
-	sessionID   string
-	toolManager *tool.Manager
-	toolWorker  *ToolWorker
-	windowSize  int
+	llm          llm.LLM
+	sessionMgr   *session.Manager
+	sessionID    string
+	toolManager  *tool.Manager
+	toolWorker   *ToolWorker
+	windowSize   int
+	streamWriter io.Writer
 }
 
 // LLMAgentOption is a functional option for configuring the LLMAgent.
@@ -66,6 +68,13 @@ func WithWindowSize(size int) LLMAgentOption {
 	}
 }
 
+// WithStreamWriter sets the writer for streaming LLM output.
+func WithStreamWriter(w io.Writer) LLMAgentOption {
+	return func(a *LLMAgent) {
+		a.streamWriter = w
+	}
+}
+
 // NewLLMAgent creates a new LLMAgent with the given LLM backend and options.
 // The agent is started automatically upon creation.
 func NewLLMAgent(id string, backend llm.LLM, opts ...LLMAgentOption) *LLMAgent {
@@ -218,23 +227,26 @@ func (a *LLMAgent) chatWithToolLoop(ctx context.Context, messages []llm.Message)
 	maxRounds := 10
 
 	for round := 0; round < maxRounds; round++ {
-		response, err := a.llm.Chat(ctx, messages)
+		var content string
+		var err error
+
+		if a.streamWriter != nil {
+			content, err = a.streamChat(ctx, messages)
+		} else {
+			content, err = a.syncChat(ctx, messages)
+		}
 		if err != nil {
 			return "", fmt.Errorf("chat round %d failed: %w", round, err)
 		}
 
-		toolCalls := response.ToolCalls
+		toolCalls := a.parseToolCallsFromContent(content)
 		if len(toolCalls) == 0 {
-			toolCalls = a.parseToolCallsFromContent(response.Content)
-		}
-
-		if len(toolCalls) == 0 {
-			return response.Content, nil
+			return content, nil
 		}
 
 		messages = append(messages, llm.Message{
 			Role:    "assistant",
-			Content: response.Content,
+			Content: content,
 		})
 
 		for _, tc := range toolCalls {
@@ -249,6 +261,38 @@ func (a *LLMAgent) chatWithToolLoop(ctx context.Context, messages []llm.Message)
 	return "", fmt.Errorf("llm_agent: exceeded maximum tool call rounds (%d)", maxRounds)
 }
 
+func (a *LLMAgent) syncChat(ctx context.Context, messages []llm.Message) (string, error) {
+	response, err := a.llm.Chat(ctx, messages)
+	if err != nil {
+		return "", err
+	}
+	return response.Content, nil
+}
+
+func (a *LLMAgent) streamChat(ctx context.Context, messages []llm.Message) (string, error) {
+	var content strings.Builder
+	if a.streamWriter != nil {
+		fmt.Fprint(a.streamWriter, "\n")
+	}
+
+	err := a.llm.Stream(ctx, messages, func(chunk string) error {
+		content.WriteString(chunk)
+		if a.streamWriter != nil {
+			fmt.Fprint(a.streamWriter, chunk)
+		}
+		return nil
+	})
+
+	if err != nil {
+		return "", err
+	}
+
+	if a.streamWriter != nil {
+		fmt.Fprintln(a.streamWriter)
+	}
+	return content.String(), nil
+}
+
 func (a *LLMAgent) parseToolCallsFromContent(content string) []llm.ToolCall {
 	var parsed struct {
 		Tool string `json:"tool"`
@@ -333,6 +377,9 @@ func (a *LLMAgent) handleSystem(ctx context.Context, msg bus.Message) (bus.Messa
 	}, nil
 }
 
-// Compile-time interface checks.
+func (a *LLMAgent) SetStreamWriter(w io.Writer) {
+	a.streamWriter = w
+}
+
 var _ Agent = (*LLMAgent)(nil)
 var _ Agent = (*ToolWorker)(nil)
diff --git a/pkg/kernel/kernel.go b/pkg/kernel/kernel.go
index 794639c..8c169c1 100644
--- a/pkg/kernel/kernel.go
+++ b/pkg/kernel/kernel.go
@@ -7,6 +7,7 @@ package kernel
 import (
 	"context"
 	"fmt"
+	"io"
 	"log"
 	"os"
 	"sync"
@@ -228,6 +229,13 @@ func (k *Kernel) LLMAgent() *actor.LLMAgent {
 	return k.llmAgent
}
 
+// SetStreamWriter sets the writer for streaming LLM output.
+func (k *Kernel) SetStreamWriter(w io.Writer) {
+	if k.llmAgent != nil {
+		k.llmAgent.SetStreamWriter(w)
+	}
+}
+
 // SendMessage sends a message from a source to the LLM agent.
 //
 // This is the primary public API for interacting with the Orca system.
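For reference, the chunk fan-out that streamChat performs (accumulate the full response while mirroring each chunk to the configured io.Writer) can be exercised in isolation. The sketch below is illustrative only: fakeStream is a hypothetical stand-in for the llm.Stream callback contract seen in the diff and does not use any Orca types.

```go
package main

import (
	"bytes"
	"fmt"
	"strings"
)

// fakeStream is a hypothetical stand-in for the llm.Stream callback contract
// used above: the backend calls fn once per chunk and stops on the first error.
func fakeStream(chunks []string, fn func(chunk string) error) error {
	for _, c := range chunks {
		if err := fn(c); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	var full strings.Builder // accumulates the complete response, as streamChat does
	var ui bytes.Buffer      // any io.Writer can serve as the stream target

	err := fakeStream([]string{"Hello", ", ", "world"}, func(chunk string) error {
		full.WriteString(chunk) // keep the full text for later tool-call parsing
		fmt.Fprint(&ui, chunk)  // mirror the chunk to the writer as it arrives
		return nil
	})
	if err != nil {
		panic(err)
	}

	fmt.Printf("streamed: %q, accumulated: %q\n", ui.String(), full.String())
}
```

Because chatWithToolLoop only takes the streamChat path when streamWriter is non-nil, callers that never call SetStreamWriter (or pass WithStreamWriter) keep the previous synchronous behavior.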