205 lines
4.6 KiB
Go
205 lines
4.6 KiB
Go
package actor
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/orca/orca/pkg/bus"
|
|
"github.com/orca/orca/pkg/llm"
|
|
"github.com/orca/orca/pkg/session"
|
|
)
|
|
|
|
type SubAgentStore interface {
|
|
SaveSubAgentMessage(parentSessionID, sessionID, agentName string, msg session.SessionMessage) error
|
|
LoadSubAgentMessages(sessionID string) ([]session.SessionMessage, error)
|
|
}
|
|
|
|
type SubAgent struct {
|
|
*BaseAgent
|
|
llmBackend llm.LLM
|
|
systemPrompt string
|
|
role string
|
|
streamWriter io.Writer
|
|
store SubAgentStore
|
|
parentSessionID string
|
|
}
|
|
|
|
// SubAgentOption customises a SubAgent inside NewSubAgent. Options are applied
// in order, after the defaults are set and before the agent is started.
type SubAgentOption func(*SubAgent)
|
|
|
|
func WithSubAgentSystemPrompt(prompt string) SubAgentOption {
|
|
return func(a *SubAgent) {
|
|
a.systemPrompt = prompt
|
|
}
|
|
}
|
|
|
|
func WithSubAgentRole(role string) SubAgentOption {
|
|
return func(a *SubAgent) {
|
|
a.role = role
|
|
}
|
|
}
|
|
|
|
func WithSubAgentStreamWriter(w io.Writer) SubAgentOption {
|
|
return func(a *SubAgent) {
|
|
a.streamWriter = w
|
|
}
|
|
}
|
|
|
|
func WithSubAgentStore(store SubAgentStore) SubAgentOption {
|
|
return func(a *SubAgent) {
|
|
a.store = store
|
|
}
|
|
}
|
|
|
|
func WithSubAgentParentSessionID(parentSessionID string) SubAgentOption {
|
|
return func(a *SubAgent) {
|
|
a.parentSessionID = parentSessionID
|
|
}
|
|
}
|
|
|
|
func NewSubAgent(id string, llmBackend llm.LLM, opts ...SubAgentOption) *SubAgent {
|
|
sa := &SubAgent{
|
|
BaseAgent: NewBaseAgent(id, "subagent"),
|
|
llmBackend: llmBackend,
|
|
systemPrompt: "你是一个专业的AI助手。",
|
|
role: "assistant",
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(sa)
|
|
}
|
|
|
|
sa.SetHandler(sa.handleMessage)
|
|
|
|
if err := sa.Start(); err != nil {
|
|
panic(fmt.Sprintf("subagent: failed to start %s: %v", id, err))
|
|
}
|
|
|
|
return sa
|
|
}
|
|
|
|
func (sa *SubAgent) Role() string {
|
|
return sa.role
|
|
}
|
|
|
|
func (sa *SubAgent) SystemPrompt() string {
|
|
return sa.systemPrompt
|
|
}
|
|
|
|
func (sa *SubAgent) SetStreamWriter(w io.Writer) {
|
|
sa.streamWriter = w
|
|
}
|
|
|
|
func (sa *SubAgent) SetParentSessionID(parentSessionID string) {
|
|
sa.parentSessionID = parentSessionID
|
|
}
|
|
|
|
func (sa *SubAgent) handleMessage(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
|
switch msg.Type {
|
|
case bus.MsgTypeTaskRequest:
|
|
return sa.handleTask(ctx, msg)
|
|
case bus.MsgTypeSystem:
|
|
return sa.handleSystem(ctx, msg)
|
|
default:
|
|
return bus.Message{}, fmt.Errorf("subagent %s: unsupported message type %s", sa.ID(), msg.Type)
|
|
}
|
|
}
|
|
|
|
func (sa *SubAgent) handleTask(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
|
sessionID := uuid.New().String()
|
|
parentSessionID := sa.parentSessionID
|
|
if parentSessionID == "" {
|
|
parentSessionID = "unknown"
|
|
}
|
|
|
|
if sa.store != nil {
|
|
sa.store.SaveSubAgentMessage(parentSessionID, sessionID, sa.ID(), session.SessionMessage{
|
|
Role: session.RoleSystem,
|
|
Content: sa.systemPrompt,
|
|
Timestamp: time.Now(),
|
|
})
|
|
sa.store.SaveSubAgentMessage(parentSessionID, sessionID, sa.ID(), session.SessionMessage{
|
|
Role: session.RoleUser,
|
|
Content: fmt.Sprintf("%v", msg.Content),
|
|
Timestamp: time.Now(),
|
|
})
|
|
}
|
|
|
|
messages := []llm.Message{
|
|
{
|
|
Role: "system",
|
|
Content: sa.systemPrompt,
|
|
},
|
|
{
|
|
Role: "user",
|
|
Content: fmt.Sprintf("%v", msg.Content),
|
|
},
|
|
}
|
|
|
|
content, err := sa.streamChat(ctx, messages)
|
|
if err != nil {
|
|
return bus.Message{}, fmt.Errorf("subagent %s: LLM call failed: %w", sa.ID(), err)
|
|
}
|
|
|
|
if sa.store != nil {
|
|
sa.store.SaveSubAgentMessage(parentSessionID, sessionID, sa.ID(), session.SessionMessage{
|
|
Role: session.RoleAssistant,
|
|
Content: content,
|
|
Timestamp: time.Now(),
|
|
})
|
|
}
|
|
|
|
return bus.Message{
|
|
ID: msg.ID + "-response",
|
|
Type: bus.MsgTypeTaskResponse,
|
|
From: sa.ID(),
|
|
To: msg.From,
|
|
Content: content,
|
|
Metadata: map[string]string{
|
|
"processed_by": sa.ID(),
|
|
"agent_role": sa.role,
|
|
"session_id": sessionID,
|
|
"parent_session_id": parentSessionID,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (sa *SubAgent) streamChat(ctx context.Context, messages []llm.Message) (string, error) {
|
|
var content strings.Builder
|
|
|
|
if sa.streamWriter != nil {
|
|
fmt.Fprintf(sa.streamWriter, "\n[%s] ", sa.ID())
|
|
}
|
|
|
|
err := sa.llmBackend.Stream(ctx, messages, func(chunk string) error {
|
|
content.WriteString(chunk)
|
|
if sa.streamWriter != nil {
|
|
fmt.Fprint(sa.streamWriter, chunk)
|
|
}
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if sa.streamWriter != nil {
|
|
fmt.Fprintln(sa.streamWriter)
|
|
}
|
|
|
|
return content.String(), nil
|
|
}
|
|
|
|
func (sa *SubAgent) handleSystem(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
|
return bus.Message{
|
|
ID: msg.ID + "-ack",
|
|
Type: bus.MsgTypeSystem,
|
|
From: sa.ID(),
|
|
To: msg.From,
|
|
Content: fmt.Sprintf("subagent %s acknowledged", sa.ID()),
|
|
}, nil
|
|
}
|