orca.ai/pkg/actor/subagent.go
2026-05-12 00:09:01 +08:00

205 lines
4.6 KiB
Go

package actor
import (
"context"
"fmt"
"io"
"strings"
"time"
"github.com/google/uuid"
"github.com/orca/orca/pkg/bus"
"github.com/orca/orca/pkg/llm"
"github.com/orca/orca/pkg/session"
)
// SubAgentStore is the optional persistence backend a SubAgent uses to record
// the transcript of every task it runs (see WithSubAgentStore).
type SubAgentStore interface {
	// LoadSubAgentMessages presumably returns the messages previously saved
	// for sessionID. NOTE(review): it is not called in this file — confirm
	// ordering and not-found semantics against the implementations.
	LoadSubAgentMessages(sessionID string) ([]session.SessionMessage, error)

	// SaveSubAgentMessage records m as part of sub-agent session sessionID,
	// produced by the agent named agentName and linked to the parent session
	// parentID. SubAgent calls it with the system, user and assistant
	// messages of a task, in that order.
	SaveSubAgentMessage(parentID, sessionID, agentName string, m session.SessionMessage) error
}
// SubAgent is an LLM-backed agent that answers MsgTypeTaskRequest messages by
// sending its system prompt plus the request text to an LLM and replying with
// the completion, and acknowledges MsgTypeSystem messages. NewSubAgent fills
// in defaults, registers the handler and starts it.
type SubAgent struct {
	// BaseAgent supplies the identity (ID), handler registration (SetHandler)
	// and lifecycle (Start) that SubAgent relies on.
	*BaseAgent
	// llmBackend is the model each task request is streamed through.
	llmBackend llm.LLM
	// systemPrompt is sent as the system message of every task and persisted
	// as the first message of each sub-agent session.
	systemPrompt string
	// role is reported in task responses under the "agent_role" metadata key.
	role string
	// streamWriter, when non-nil, receives a live echo of the LLM output.
	streamWriter io.Writer
	// store, when non-nil, persists the messages of every task.
	store SubAgentStore
	// parentSessionID links persisted sub-agent sessions to the session that
	// spawned them; an empty value is recorded as "unknown".
	parentSessionID string
}
// SubAgentOption customizes a SubAgent during NewSubAgent. Options are applied
// in order after the defaults are set and before the agent is started.
type SubAgentOption func(*SubAgent)
// WithSubAgentSystemPrompt replaces the default system prompt that is sent to
// the LLM, and persisted, at the start of every task.
func WithSubAgentSystemPrompt(prompt string) SubAgentOption {
	return func(sa *SubAgent) { sa.systemPrompt = prompt }
}
// WithSubAgentRole sets the role label the agent reports in task responses
// (metadata key "agent_role"); the default is "assistant".
func WithSubAgentRole(role string) SubAgentOption {
	return func(sa *SubAgent) { sa.role = role }
}
// WithSubAgentStreamWriter makes the agent echo LLM output to w as it streams
// in. A nil w leaves echoing disabled.
func WithSubAgentStreamWriter(w io.Writer) SubAgentOption {
	return func(sa *SubAgent) { sa.streamWriter = w }
}
// WithSubAgentStore enables persistence of each task's system, user and
// assistant messages through store. A nil store leaves persistence disabled.
func WithSubAgentStore(store SubAgentStore) SubAgentOption {
	return func(sa *SubAgent) { sa.store = store }
}
// WithSubAgentParentSessionID sets the parent session that persisted
// sub-agent sessions, and task responses, are linked to. When left empty the
// agent records "unknown".
func WithSubAgentParentSessionID(parentSessionID string) SubAgentOption {
	return func(sa *SubAgent) { sa.parentSessionID = parentSessionID }
}
// NewSubAgent creates a SubAgent named id that answers tasks with llmBackend,
// applies opts in order, registers its message handler and starts it.
//
// Defaults before options are applied: role "assistant" and a Chinese system
// prompt meaning "You are a professional AI assistant."
//
// NewSubAgent panics if the underlying BaseAgent fails to start.
func NewSubAgent(id string, llmBackend llm.LLM, opts ...SubAgentOption) *SubAgent {
	agent := &SubAgent{
		BaseAgent:    NewBaseAgent(id, "subagent"),
		llmBackend:   llmBackend,
		systemPrompt: "你是一个专业的AI助手。",
		role:         "assistant",
	}
	for i := range opts {
		opts[i](agent)
	}

	// The handler must be in place before Start so no message is missed.
	agent.SetHandler(agent.handleMessage)
	err := agent.Start()
	if err != nil {
		panic(fmt.Sprintf("subagent: failed to start %s: %v", id, err))
	}
	return agent
}
// SetStreamWriter replaces the writer that receives the live echo of LLM
// output for subsequent tasks; nil disables echoing.
//
// NOTE(review): the field is written without synchronization; if a task can
// be running concurrently this is a data race — confirm against BaseAgent's
// dispatch model.
func (sa *SubAgent) SetStreamWriter(w io.Writer) { sa.streamWriter = w }

// SetParentSessionID changes the parent session that subsequent tasks are
// linked to; an empty value is recorded as "unknown".
//
// NOTE(review): unsynchronized write, same caveat as SetStreamWriter.
func (sa *SubAgent) SetParentSessionID(parentSessionID string) { sa.parentSessionID = parentSessionID }

// Role returns the role label reported in task responses (default "assistant").
func (sa *SubAgent) Role() string { return sa.role }

// SystemPrompt returns the system prompt sent at the start of every task.
func (sa *SubAgent) SystemPrompt() string { return sa.systemPrompt }
// handleMessage is the bus handler registered by NewSubAgent. Task requests go
// to handleTask and system messages to handleSystem; any other message type is
// rejected with an error naming the agent and the type.
func (sa *SubAgent) handleMessage(ctx context.Context, msg bus.Message) (bus.Message, error) {
	if msg.Type == bus.MsgTypeTaskRequest {
		return sa.handleTask(ctx, msg)
	}
	if msg.Type == bus.MsgTypeSystem {
		return sa.handleSystem(ctx, msg)
	}
	return bus.Message{}, fmt.Errorf("subagent %s: unsupported message type %s", sa.ID(), msg.Type)
}
func (sa *SubAgent) handleTask(ctx context.Context, msg bus.Message) (bus.Message, error) {
sessionID := uuid.New().String()
parentSessionID := sa.parentSessionID
if parentSessionID == "" {
parentSessionID = "unknown"
}
if sa.store != nil {
sa.store.SaveSubAgentMessage(parentSessionID, sessionID, sa.ID(), session.SessionMessage{
Role: session.RoleSystem,
Content: sa.systemPrompt,
Timestamp: time.Now(),
})
sa.store.SaveSubAgentMessage(parentSessionID, sessionID, sa.ID(), session.SessionMessage{
Role: session.RoleUser,
Content: fmt.Sprintf("%v", msg.Content),
Timestamp: time.Now(),
})
}
messages := []llm.Message{
{
Role: "system",
Content: sa.systemPrompt,
},
{
Role: "user",
Content: fmt.Sprintf("%v", msg.Content),
},
}
content, err := sa.streamChat(ctx, messages)
if err != nil {
return bus.Message{}, fmt.Errorf("subagent %s: LLM call failed: %w", sa.ID(), err)
}
if sa.store != nil {
sa.store.SaveSubAgentMessage(parentSessionID, sessionID, sa.ID(), session.SessionMessage{
Role: session.RoleAssistant,
Content: content,
Timestamp: time.Now(),
})
}
return bus.Message{
ID: msg.ID + "-response",
Type: bus.MsgTypeTaskResponse,
From: sa.ID(),
To: msg.From,
Content: content,
Metadata: map[string]string{
"processed_by": sa.ID(),
"agent_role": sa.role,
"session_id": sessionID,
"parent_session_id": parentSessionID,
},
}, nil
}
// streamChat sends messages to the LLM backend in streaming mode and returns
// the full concatenated completion.
//
// When a stream writer is configured, the output is echoed live as
// "\n[<agent id>] <chunks...>\n". The closing newline is written even if the
// stream fails, so a partial reply never leaves the writer mid-line. Echo
// writes are best-effort: their errors are ignored and never abort the LLM
// call.
//
// On a backend error, any partial text is discarded and "" is returned along
// with the error.
func (sa *SubAgent) streamChat(ctx context.Context, messages []llm.Message) (string, error) {
	// Snapshot the writer so the header, chunks and trailer all go to the same
	// destination even if SetStreamWriter is called while streaming.
	w := sa.streamWriter
	if w != nil {
		fmt.Fprintf(w, "\n[%s] ", sa.ID())
		// Terminate the echoed line on success and failure alike.
		defer fmt.Fprintln(w)
	}

	var content strings.Builder
	err := sa.llmBackend.Stream(ctx, messages, func(chunk string) error {
		content.WriteString(chunk)
		if w != nil {
			_, _ = io.WriteString(w, chunk) // best-effort echo
		}
		return nil
	})
	if err != nil {
		return "", err
	}
	return content.String(), nil
}
// handleSystem acknowledges a system message. It replies to the sender with a
// MsgTypeSystem message whose ID is the original ID plus "-ack". It never
// fails, and the context is not consulted.
func (sa *SubAgent) handleSystem(_ context.Context, msg bus.Message) (bus.Message, error) {
	self := sa.ID()
	ack := bus.Message{
		ID:      msg.ID + "-ack",
		Type:    bus.MsgTypeSystem,
		From:    self,
		To:      msg.From,
		Content: fmt.Sprintf("subagent %s acknowledged", self),
	}
	return ack, nil
}