feat: add streaming output for CLI
- Add streamWriter to LLMAgent for real-time output
- Support streaming mode in chatWithToolLoop
- Add SetStreamWriter to Kernel and LLMAgent
- CLI displays streaming responses immediately
- Tool calling still works with streaming
This commit is contained in:
parent
6b94476347
commit
04c7ea5e39
@ -43,6 +43,8 @@ func main() {
|
|||||||
log.Fatalf("Failed to start kernel: %v", err)
|
log.Fatalf("Failed to start kernel: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
k.SetStreamWriter(os.Stdout)
|
||||||
|
|
||||||
fmt.Println("Orca Agent Framework")
|
fmt.Println("Orca Agent Framework")
|
||||||
fmt.Println("Kernel started successfully")
|
fmt.Println("Kernel started successfully")
|
||||||
fmt.Printf(" LLM Model: %s\n", cfg.Ollama.Model)
|
fmt.Printf(" LLM Model: %s\n", cfg.Ollama.Model)
|
||||||
@ -76,15 +78,13 @@ func main() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send message to LLM agent via kernel
|
// Send message to LLM agent via kernel
|
||||||
response, err := k.SendMessage("user", "llm", input)
|
_, err := k.SendMessage("user", "llm", input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Error: %v\n", err)
|
fmt.Printf("Error: %v\n", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
fmt.Println()
|
||||||
fmt.Println(response)
|
|
||||||
fmt.Println()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/orca/orca/pkg/bus"
|
"github.com/orca/orca/pkg/bus"
|
||||||
@ -20,12 +21,13 @@ import (
|
|||||||
// back to the LLM for final response generation.
|
// back to the LLM for final response generation.
|
||||||
type LLMAgent struct {
|
type LLMAgent struct {
|
||||||
*BaseAgent
|
*BaseAgent
|
||||||
llm llm.LLM
|
llm llm.LLM
|
||||||
sessionMgr *session.Manager
|
sessionMgr *session.Manager
|
||||||
sessionID string
|
sessionID string
|
||||||
toolManager *tool.Manager
|
toolManager *tool.Manager
|
||||||
toolWorker *ToolWorker
|
toolWorker *ToolWorker
|
||||||
windowSize int
|
windowSize int
|
||||||
|
streamWriter io.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
// LLMAgentOption is a functional option for configuring the LLMAgent.
|
// LLMAgentOption is a functional option for configuring the LLMAgent.
|
||||||
@ -66,6 +68,13 @@ func WithWindowSize(size int) LLMAgentOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithStreamWriter sets the writer for streaming LLM output.
|
||||||
|
func WithStreamWriter(w io.Writer) LLMAgentOption {
|
||||||
|
return func(a *LLMAgent) {
|
||||||
|
a.streamWriter = w
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewLLMAgent creates a new LLMAgent with the given LLM backend and options.
|
// NewLLMAgent creates a new LLMAgent with the given LLM backend and options.
|
||||||
// The agent is started automatically upon creation.
|
// The agent is started automatically upon creation.
|
||||||
func NewLLMAgent(id string, backend llm.LLM, opts ...LLMAgentOption) *LLMAgent {
|
func NewLLMAgent(id string, backend llm.LLM, opts ...LLMAgentOption) *LLMAgent {
|
||||||
@ -218,23 +227,26 @@ func (a *LLMAgent) chatWithToolLoop(ctx context.Context, messages []llm.Message)
|
|||||||
maxRounds := 10
|
maxRounds := 10
|
||||||
|
|
||||||
for round := 0; round < maxRounds; round++ {
|
for round := 0; round < maxRounds; round++ {
|
||||||
response, err := a.llm.Chat(ctx, messages)
|
var content string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if a.streamWriter != nil {
|
||||||
|
content, err = a.streamChat(ctx, messages)
|
||||||
|
} else {
|
||||||
|
content, err = a.syncChat(ctx, messages)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("chat round %d failed: %w", round, err)
|
return "", fmt.Errorf("chat round %d failed: %w", round, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
toolCalls := response.ToolCalls
|
toolCalls := a.parseToolCallsFromContent(content)
|
||||||
if len(toolCalls) == 0 {
|
if len(toolCalls) == 0 {
|
||||||
toolCalls = a.parseToolCallsFromContent(response.Content)
|
return content, nil
|
||||||
}
|
|
||||||
|
|
||||||
if len(toolCalls) == 0 {
|
|
||||||
return response.Content, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
messages = append(messages, llm.Message{
|
messages = append(messages, llm.Message{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: response.Content,
|
Content: content,
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, tc := range toolCalls {
|
for _, tc := range toolCalls {
|
||||||
@ -249,6 +261,42 @@ func (a *LLMAgent) chatWithToolLoop(ctx context.Context, messages []llm.Message)
|
|||||||
return "", fmt.Errorf("llm_agent: exceeded maximum tool call rounds (%d)", maxRounds)
|
return "", fmt.Errorf("llm_agent: exceeded maximum tool call rounds (%d)", maxRounds)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *LLMAgent) syncChat(ctx context.Context, messages []llm.Message) (string, error) {
|
||||||
|
response, err := a.llm.Chat(ctx, messages)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(response.ToolCalls) > 0 {
|
||||||
|
return response.Content, nil
|
||||||
|
}
|
||||||
|
return response.Content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *LLMAgent) streamChat(ctx context.Context, messages []llm.Message) (string, error) {
|
||||||
|
var content strings.Builder
|
||||||
|
if a.streamWriter != nil {
|
||||||
|
fmt.Fprint(a.streamWriter, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := a.llm.Stream(ctx, messages, func(chunk string) error {
|
||||||
|
content.WriteString(chunk)
|
||||||
|
if a.streamWriter != nil {
|
||||||
|
fmt.Fprint(a.streamWriter, chunk)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.streamWriter != nil {
|
||||||
|
fmt.Fprintln(a.streamWriter)
|
||||||
|
}
|
||||||
|
return content.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *LLMAgent) parseToolCallsFromContent(content string) []llm.ToolCall {
|
func (a *LLMAgent) parseToolCallsFromContent(content string) []llm.ToolCall {
|
||||||
var parsed struct {
|
var parsed struct {
|
||||||
Tool string `json:"tool"`
|
Tool string `json:"tool"`
|
||||||
@ -333,6 +381,9 @@ func (a *LLMAgent) handleSystem(ctx context.Context, msg bus.Message) (bus.Messa
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compile-time interface checks.
|
// SetStreamWriter sets the writer that receives streamed LLM output.
// A nil writer disables streaming: chatWithToolLoop checks streamWriter
// and falls back to the synchronous Chat path when it is nil.
//
// NOTE(review): this writes streamWriter without synchronization; if the
// agent can be processing messages concurrently with this call it is a
// data race — confirm callers set the writer before any messages flow.
func (a *LLMAgent) SetStreamWriter(w io.Writer) {
	a.streamWriter = w
}
|
||||||
|
|
||||||
var _ Agent = (*LLMAgent)(nil)
|
var _ Agent = (*LLMAgent)(nil)
|
||||||
var _ Agent = (*ToolWorker)(nil)
|
var _ Agent = (*ToolWorker)(nil)
|
||||||
|
|||||||
@ -7,6 +7,7 @@ package kernel
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
@ -228,6 +229,13 @@ func (k *Kernel) LLMAgent() *actor.LLMAgent {
|
|||||||
return k.llmAgent
|
return k.llmAgent
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetStreamWriter sets the writer for streaming LLM output.
|
||||||
|
func (k *Kernel) SetStreamWriter(w io.Writer) {
|
||||||
|
if k.llmAgent != nil {
|
||||||
|
k.llmAgent.SetStreamWriter(w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SendMessage sends a message from a source to the LLM agent.
|
// SendMessage sends a message from a source to the LLM agent.
|
||||||
//
|
//
|
||||||
// This is the primary public API for interacting with the Orca system.
|
// This is the primary public API for interacting with the Orca system.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user