commit 6b94476347824c3e36bec296901b3f3f57c856b4 Author: 大森 Date: Fri May 8 00:55:48 2026 +0800 Initial commit: Orca Agent Framework Core features: - Microkernel architecture with Actor model - Session management with JSONL persistence - Tool system (5 built-in tools) - Skill system with SKILL.md parsing - Sandbox security execution - Ollama integration with gemma4:e4b - Prompt-based tool calling (compatible with native function calling) - REPL interface 11 packages, all tests passing diff --git a/cmd/orca/main.go b/cmd/orca/main.go new file mode 100644 index 0000000..cc9feb2 --- /dev/null +++ b/cmd/orca/main.go @@ -0,0 +1,205 @@ +// Orca is a Go-based Agent framework with a microkernel architecture. +// +// It supports multi-agent collaboration, persistent session memory, +// skill-based automation, sandboxed execution, custom tool registration, +// and local LLM integration via Ollama. +package main + +import ( + "bufio" + "fmt" + "log" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/orca/orca/internal/config" + "github.com/orca/orca/pkg/kernel" +) + +func main() { + // Load configuration from environment variables + cfg := config.LoadConfigFromEnv() + + // Support shorter env var names for Ollama (without ORCA_ prefix) + if v := os.Getenv("OLLAMA_BASE_URL"); v != "" { + cfg.Ollama.BaseURL = v + } + if v := os.Getenv("OLLAMA_MODEL"); v != "" { + cfg.Ollama.Model = v + } + if v := os.Getenv("OLLAMA_TIMEOUT"); v != "" { + if d, err := time.ParseDuration(v); err == nil { + cfg.Ollama.Timeout = d + } + } + + // Create and start kernel + k := kernel.NewWithConfig(cfg) + + if err := k.Start(); err != nil { + log.Fatalf("Failed to start kernel: %v", err) + } + + fmt.Println("Orca Agent Framework") + fmt.Println("Kernel started successfully") + fmt.Printf(" LLM Model: %s\n", cfg.Ollama.Model) + fmt.Printf(" Ollama URL: %s\n", cfg.Ollama.BaseURL) + fmt.Println("Type your message or /help for commands.") + fmt.Println() + + // Handle graceful shutdown 
+ sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) + + // REPL loop in a goroutine so we can catch signals + done := make(chan struct{}) + + go func() { + scanner := bufio.NewScanner(os.Stdin) + for { + fmt.Print("> ") + if !scanner.Scan() { + break + } + + input := strings.TrimSpace(scanner.Text()) + if input == "" { + continue + } + + // Handle commands + if strings.HasPrefix(input, "/") { + handleCommand(input, k) + continue + } + + // Send message to LLM agent via kernel + response, err := k.SendMessage("user", "llm", input) + if err != nil { + fmt.Printf("Error: %v\n", err) + continue + } + + fmt.Println(response) + fmt.Println() + } + + if err := scanner.Err(); err != nil { + fmt.Fprintf(os.Stderr, "Error reading input: %v\n", err) + } + close(done) + }() + + // Wait for either SIGINT or REPL exit + select { + case <-sig: + fmt.Println("\nShutting down Orca kernel...") + case <-done: + fmt.Println("\nInput closed. Shutting down Orca kernel...") + } + + if err := k.Stop(); err != nil { + log.Fatalf("Failed to stop kernel: %v", err) + } + fmt.Println("Orca kernel shut down gracefully.") +} + +// handleCommand processes REPL commands. 
+func handleCommand(cmd string, k *kernel.Kernel) { + switch cmd { + case "/help": + fmt.Println("Available commands:") + fmt.Println(" /help - Show this help message") + fmt.Println(" /exit - Exit the program") + fmt.Println(" /quit - Exit the program") + fmt.Println(" /plugins - List registered plugins") + fmt.Println(" /agents - List active agents") + fmt.Println(" /tools - List registered tools") + fmt.Println(" /skills - List loaded skills") + fmt.Println(" /status - Show kernel status") + fmt.Println() + fmt.Println("Any other input is sent to the LLM agent for processing.") + + case "/exit", "/quit": + fmt.Println("Goodbye!") + os.Exit(0) + + case "/plugins": + plugins := k.ListPlugins() + if len(plugins) == 0 { + fmt.Println("No plugins registered.") + } else { + fmt.Println("Registered plugins:") + for _, p := range plugins { + fmt.Printf(" - %s (%s)\n", p.Name(), p.Version()) + } + } + + case "/agents": + as := k.ActorSystem() + if as == nil { + fmt.Println("Actor system not initialized.") + return + } + infos := as.AgentInfos() + if len(infos) == 0 { + fmt.Println("No agents running.") + } else { + fmt.Println("Active agents:") + for _, info := range infos { + fmt.Printf(" - %s [%s] (status: %s)\n", info.ID, info.Role, info.Status) + } + } + + case "/tools": + tm := k.ToolManager() + if tm == nil { + fmt.Println("Tool manager not initialized.") + return + } + tools := tm.List() + if len(tools) == 0 { + fmt.Println("No tools registered.") + } else { + fmt.Println("Registered tools:") + for _, t := range tools { + fmt.Printf(" - %s: %s\n", t.Name(), t.Description()) + } + } + + case "/skills": + sm := k.SkillManager() + if sm == nil { + fmt.Println("Skill manager not initialized.") + return + } + skills := sm.ListSkills() + if len(skills) == 0 { + fmt.Println("No skills loaded.") + } else { + fmt.Println("Loaded skills:") + for _, s := range skills { + fmt.Printf(" - %s: %s\n", s.Name, s.Description) + } + } + + case "/status": + fmt.Printf("Kernel 
running: %v\n", k.IsRunning()) + if tm := k.ToolManager(); tm != nil { + fmt.Printf("Tools registered: %d\n", tm.Count()) + } + if as := k.ActorSystem(); as != nil { + fmt.Printf("Agents active: %d\n", as.AgentCount()) + } + if sm := k.SkillManager(); sm != nil { + fmt.Printf("Skills loaded: %d\n", len(sm.ListSkills())) + } + + default: + fmt.Printf("Unknown command: %s\n", cmd) + fmt.Println("Type /help for available commands.") + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7f16df4 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/orca/orca + +go 1.26.1 diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..e499bad --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,147 @@ +// Package config provides the configuration types for the Orca framework. +// +// Configuration is organized into logical groups: LLM (Ollama), sandbox, +// and session management. Default values are provided for all settings. +package config + +import ( + "os" + "strconv" + "time" +) + +// Config is the top-level configuration for the Orca framework. +type Config struct { + Ollama OllamaConfig `json:"ollama"` + Sandbox SandboxConfig `json:"sandbox"` + Session SessionConfig `json:"session"` +} + +// OllamaConfig holds configuration for the Ollama LLM backend. +type OllamaConfig struct { + // BaseURL is the Ollama API endpoint (e.g., "http://localhost:11434"). + BaseURL string `json:"base_url"` + // Model is the Ollama model name to use (e.g., "gemma4:e4b", "codellama"). + Model string `json:"model"` + // Timeout is the maximum duration to wait for an Ollama response. + Timeout time.Duration `json:"timeout"` +} + +// SandboxConfig holds configuration for the command execution sandbox. +type SandboxConfig struct { + // Timeout is the maximum duration for a sandboxed command. + Timeout time.Duration `json:"timeout"` + // MaxMemory is the maximum memory allocation for the sandbox (in bytes). 
+ MaxMemory int64 `json:"max_memory"` + // WorkingDir is the default working directory for sandboxed commands. + WorkingDir string `json:"working_dir"` +} + +// SessionConfig holds configuration for session management. +type SessionConfig struct { + // StorageDir is the directory for session JSONL files. + StorageDir string `json:"storage_dir"` + // MaxHistory is the maximum number of messages to retain per session. + MaxHistory int `json:"max_history"` +} + +// DefaultConfig returns a Config with sensible defaults. +func DefaultConfig() *Config { + return &Config{ + Ollama: OllamaConfig{ + BaseURL: "http://localhost:11434", + Model: "gemma4:e4b", + Timeout: 120 * time.Second, + }, + Sandbox: SandboxConfig{ + Timeout: 30 * time.Second, + MaxMemory: 512 * 1024 * 1024, // 512 MB + WorkingDir: "/tmp/orca/sandbox", + }, + Session: SessionConfig{ + StorageDir: func() string { + home, _ := os.UserHomeDir() + return home + "/.orca/sessions" + }(), + MaxHistory: 100, + }, + } +} + +// LoadConfigFromEnv reads configuration from environment variables, +// overriding defaults where environment variables are set. 
+func LoadConfigFromEnv() *Config { + cfg := DefaultConfig() + + if v := os.Getenv("ORCA_OLLAMA_BASE_URL"); v != "" { + cfg.Ollama.BaseURL = v + } + if v := os.Getenv("ORCA_OLLAMA_MODEL"); v != "" { + cfg.Ollama.Model = v + } + if v := os.Getenv("ORCA_OLLAMA_TIMEOUT"); v != "" { + if d, err := time.ParseDuration(v); err == nil { + cfg.Ollama.Timeout = d + } + } + if v := os.Getenv("ORCA_SANDBOX_TIMEOUT"); v != "" { + if d, err := time.ParseDuration(v); err == nil { + cfg.Sandbox.Timeout = d + } + } + if v := os.Getenv("ORCA_SANDBOX_MAX_MEMORY"); v != "" { + if n, err := strconv.ParseInt(v, 10, 64); err == nil { + cfg.Sandbox.MaxMemory = n + } + } + if v := os.Getenv("ORCA_SANDBOX_WORKING_DIR"); v != "" { + cfg.Sandbox.WorkingDir = v + } + if v := os.Getenv("ORCA_SESSION_STORAGE_DIR"); v != "" { + cfg.Session.StorageDir = v + } + if v := os.Getenv("ORCA_SESSION_MAX_HISTORY"); v != "" { + if n, err := strconv.Atoi(v); err == nil { + cfg.Session.MaxHistory = n + } + } + + return cfg +} + +// IsValid checks whether the configuration has valid values. +func (c *Config) IsValid() error { + if c.Ollama.BaseURL == "" { + return errConfig("ollama.base_url must not be empty") + } + if c.Ollama.Model == "" { + return errConfig("ollama.model must not be empty") + } + if c.Ollama.Timeout <= 0 { + return errConfig("ollama.timeout must be positive") + } + if c.Sandbox.Timeout <= 0 { + return errConfig("sandbox.timeout must be positive") + } + if c.Sandbox.MaxMemory <= 0 { + return errConfig("sandbox.max_memory must be positive") + } + if c.Session.MaxHistory <= 0 { + return errConfig("session.max_history must be positive") + } + return nil +} + +// errConfig creates a configuration error. +func errConfig(msg string) error { + return &ConfigError{Message: msg} +} + +// ConfigError represents a configuration validation error. 
// ConfigError is the error type produced by configuration validation.
// Message holds the description without the "config: " prefix that Error
// prepends.
type ConfigError struct {
	Message string
}

// Error implements the error interface by prefixing Message with "config: ".
func (e *ConfigError) Error() string {
	const prefix = "config: "
	return prefix + e.Message
}
variables + os.Setenv("ORCA_OLLAMA_BASE_URL", "http://custom:11434") + os.Setenv("ORCA_OLLAMA_MODEL", "codellama") + os.Setenv("ORCA_OLLAMA_TIMEOUT", "60s") + os.Setenv("ORCA_SANDBOX_TIMEOUT", "120s") + os.Setenv("ORCA_SANDBOX_MAX_MEMORY", "1073741824") + os.Setenv("ORCA_SANDBOX_WORKING_DIR", "/custom/sandbox") + os.Setenv("ORCA_SESSION_STORAGE_DIR", "/custom/sessions") + os.Setenv("ORCA_SESSION_MAX_HISTORY", "200") + + defer func() { + os.Unsetenv("ORCA_OLLAMA_BASE_URL") + os.Unsetenv("ORCA_OLLAMA_MODEL") + os.Unsetenv("ORCA_OLLAMA_TIMEOUT") + os.Unsetenv("ORCA_SANDBOX_TIMEOUT") + os.Unsetenv("ORCA_SANDBOX_MAX_MEMORY") + os.Unsetenv("ORCA_SANDBOX_WORKING_DIR") + os.Unsetenv("ORCA_SESSION_STORAGE_DIR") + os.Unsetenv("ORCA_SESSION_MAX_HISTORY") + }() + + cfg := LoadConfigFromEnv() + + if cfg.Ollama.BaseURL != "http://custom:11434" { + t.Errorf("expected Ollama BaseURL 'http://custom:11434', got %q", cfg.Ollama.BaseURL) + } + if cfg.Ollama.Model != "codellama" { + t.Errorf("expected Ollama Model 'codellama', got %q", cfg.Ollama.Model) + } + if cfg.Ollama.Timeout != 60*time.Second { + t.Errorf("expected Ollama Timeout 60s, got %v", cfg.Ollama.Timeout) + } + if cfg.Sandbox.Timeout != 120*time.Second { + t.Errorf("expected Sandbox Timeout 120s, got %v", cfg.Sandbox.Timeout) + } + if cfg.Sandbox.MaxMemory != 1073741824 { + t.Errorf("expected Sandbox MaxMemory 1073741824, got %d", cfg.Sandbox.MaxMemory) + } + if cfg.Sandbox.WorkingDir != "/custom/sandbox" { + t.Errorf("expected Sandbox WorkingDir '/custom/sandbox', got %q", cfg.Sandbox.WorkingDir) + } + if cfg.Session.StorageDir != "/custom/sessions" { + t.Errorf("expected Session StorageDir '/custom/sessions', got %q", cfg.Session.StorageDir) + } + if cfg.Session.MaxHistory != 200 { + t.Errorf("expected Session MaxHistory 200, got %d", cfg.Session.MaxHistory) + } +} + +func TestLoadConfigFromEnvPartial(t *testing.T) { + os.Setenv("ORCA_OLLAMA_MODEL", "mistral") + defer os.Unsetenv("ORCA_OLLAMA_MODEL") + + cfg := 
LoadConfigFromEnv() + + // Should use env override + if cfg.Ollama.Model != "mistral" { + t.Errorf("expected Model 'mistral', got %q", cfg.Ollama.Model) + } + // Should keep defaults for unset values + if cfg.Ollama.BaseURL != "http://localhost:11434" { + t.Errorf("expected default BaseURL, got %q", cfg.Ollama.BaseURL) + } +} + +func TestConfigIsValid(t *testing.T) { + cfg := DefaultConfig() + if err := cfg.IsValid(); err != nil { + t.Errorf("default config should be valid: %v", err) + } +} + +func TestConfigInvalidBaseURL(t *testing.T) { + cfg := DefaultConfig() + cfg.Ollama.BaseURL = "" + if err := cfg.IsValid(); err == nil { + t.Error("expected error for empty BaseURL") + } +} + +func TestConfigInvalidModel(t *testing.T) { + cfg := DefaultConfig() + cfg.Ollama.Model = "" + if err := cfg.IsValid(); err == nil { + t.Error("expected error for empty Model") + } +} + +func TestConfigInvalidOllamaTimeout(t *testing.T) { + cfg := DefaultConfig() + cfg.Ollama.Timeout = 0 + if err := cfg.IsValid(); err == nil { + t.Error("expected error for zero Ollama Timeout") + } +} + +func TestConfigInvalidSandboxTimeout(t *testing.T) { + cfg := DefaultConfig() + cfg.Sandbox.Timeout = -1 + if err := cfg.IsValid(); err == nil { + t.Error("expected error for negative Sandbox Timeout") + } +} + +func TestConfigInvalidMaxMemory(t *testing.T) { + cfg := DefaultConfig() + cfg.Sandbox.MaxMemory = 0 + if err := cfg.IsValid(); err == nil { + t.Error("expected error for zero MaxMemory") + } +} + +func TestConfigInvalidMaxHistory(t *testing.T) { + cfg := DefaultConfig() + cfg.Session.MaxHistory = 0 + if err := cfg.IsValid(); err == nil { + t.Error("expected error for zero MaxHistory") + } +} + +func TestConfigError(t *testing.T) { + err := errConfig("test error") + if err.Error() != "config: test error" { + t.Errorf("unexpected error message: %s", err.Error()) + } + + ce, ok := err.(*ConfigError) + if !ok { + t.Fatal("expected ConfigError type") + } + if ce.Message != "test error" { + 
t.Errorf("expected Message 'test error', got %q", ce.Message) + } +} + +func TestLoadConfigFromEnvInvalidTimeout(t *testing.T) { + os.Setenv("ORCA_OLLAMA_TIMEOUT", "not-a-duration") + defer os.Unsetenv("ORCA_OLLAMA_TIMEOUT") + + cfg := LoadConfigFromEnv() + // Should keep default when env var is unparseable + if cfg.Ollama.Timeout != 120*time.Second { + t.Errorf("expected default 120s when env is invalid, got %v", cfg.Ollama.Timeout) + } +} + +func TestLoadConfigFromEnvInvalidMaxMemory(t *testing.T) { + os.Setenv("ORCA_SANDBOX_MAX_MEMORY", "not-a-number") + defer os.Unsetenv("ORCA_SANDBOX_MAX_MEMORY") + + cfg := LoadConfigFromEnv() + // Should keep default when env var is unparseable + if cfg.Sandbox.MaxMemory != 512*1024*1024 { + t.Errorf("expected default MaxMemory when env is invalid, got %d", cfg.Sandbox.MaxMemory) + } +} + +func TestLoadConfigFromEnvInvalidMaxHistory(t *testing.T) { + os.Setenv("ORCA_SESSION_MAX_HISTORY", "not-a-number") + defer os.Unsetenv("ORCA_SESSION_MAX_HISTORY") + + cfg := LoadConfigFromEnv() + if cfg.Session.MaxHistory != 100 { + t.Errorf("expected default MaxHistory when env is invalid, got %d", cfg.Session.MaxHistory) + } +} diff --git a/pkg/actor/actor.go b/pkg/actor/actor.go new file mode 100644 index 0000000..bc9b041 --- /dev/null +++ b/pkg/actor/actor.go @@ -0,0 +1,220 @@ +// Package actor implements the Actor model for the Orca framework. +// +// An Agent is an independent goroutine that communicates via channels. +// Each agent has a state machine: Idle -> Processing -> [ToolCall] -> +// WaitingForTool -> Processing -> Completed. +package actor + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + + "github.com/orca/orca/pkg/bus" +) + +// ActorStatus represents the current state of an agent in its state machine. +type ActorStatus int + +const ( + // StatusIdle indicates the agent is ready to accept messages. + StatusIdle ActorStatus = iota + // StatusProcessing indicates the agent is actively handling a message. 
+ StatusProcessing + // StatusWaitingForTool indicates the agent has called a tool and is awaiting its result. + StatusWaitingForTool + // StatusCompleted indicates the agent has finished processing the last message. + StatusCompleted + // StatusStopped indicates the agent has been shut down. + StatusStopped +) + +// String returns the human-readable name of the actor status. +func (s ActorStatus) String() string { + switch s { + case StatusIdle: + return "idle" + case StatusProcessing: + return "processing" + case StatusWaitingForTool: + return "waiting_for_tool" + case StatusCompleted: + return "completed" + case StatusStopped: + return "stopped" + default: + return "unknown" + } +} + +// Agent is the interface that all actors in the Orca framework must implement. +// +// Each Agent runs as an independent goroutine processing messages +// through an internal channel. The Process method provides a synchronous +// API to submit messages and await responses. +type Agent interface { + // ID returns the unique identifier for this agent. + ID() string + // Role returns the role/type of this agent (e.g., "orchestrator", "worker"). + Role() string + // Process sends a message to this agent and waits for a response. + // This is a synchronous call; the agent's goroutine handles the message. + Process(ctx context.Context, msg bus.Message) (bus.Message, error) + // Stop gracefully shuts down this agent, waiting for in-flight processing to complete. + Stop() error +} + +// agentRequest wraps a message and provides a response channel. +type agentRequest struct { + ctx context.Context + msg bus.Message + resp chan agentResponse +} + +// agentResponse wraps the result of processing a message. +type agentResponse struct { + msg bus.Message + err error +} + +// BaseAgent provides shared infrastructure for all agent implementations. +// +// It manages the message channel, goroutine lifecycle, and status tracking. 
+// Concrete agents should embed BaseAgent and set a handler via SetHandler. +type BaseAgent struct { + id string + role string + msgCh chan agentRequest + stopCh chan struct{} + status atomic.Value + wg sync.WaitGroup + mu sync.Mutex + started bool + handler func(context.Context, bus.Message) (bus.Message, error) +} + +// NewBaseAgent creates a new BaseAgent with the given id and role. +// The agent is not started until Start() is called and a handler is set. +func NewBaseAgent(id, role string) *BaseAgent { + a := &BaseAgent{ + id: id, + role: role, + msgCh: make(chan agentRequest, 64), + stopCh: make(chan struct{}), + } + a.status.Store(StatusIdle) + return a +} + +// ID returns the agent's unique identifier. +func (a *BaseAgent) ID() string { return a.id } + +// Role returns the agent's role. +func (a *BaseAgent) Role() string { return a.role } + +// Status returns the current ActorStatus of this agent. +func (a *BaseAgent) Status() ActorStatus { + s, _ := a.status.Load().(ActorStatus) + return s +} + +// setStatus atomically updates the agent's status. +func (a *BaseAgent) setStatus(s ActorStatus) { + a.status.Store(s) +} + +// SetHandler sets the message handler function for this agent. +// Must be called before Start(). +func (a *BaseAgent) SetHandler(handler func(context.Context, bus.Message) (bus.Message, error)) { + a.mu.Lock() + defer a.mu.Unlock() + a.handler = handler +} + +// Start launches the agent's message processing goroutine. +// The handler must be set before calling Start. +func (a *BaseAgent) Start() error { + a.mu.Lock() + defer a.mu.Unlock() + + if a.started { + return fmt.Errorf("agent %s is already started", a.id) + } + if a.handler == nil { + return fmt.Errorf("agent %s has no handler set", a.id) + } + + a.started = true + a.status.Store(StatusIdle) + a.wg.Add(1) + go a.loop() + return nil +} + +// loop is the main goroutine that reads messages from msgCh and processes them. 
+func (a *BaseAgent) loop() { + defer a.wg.Done() + + for { + select { + case req := <-a.msgCh: + a.setStatus(StatusProcessing) + resp, err := a.handler(req.ctx, req.msg) + if err != nil { + a.setStatus(StatusIdle) + } else { + a.setStatus(StatusCompleted) + } + req.resp <- agentResponse{msg: resp, err: err} + case <-a.stopCh: + a.setStatus(StatusStopped) + return + } + } +} + +// Process sends a message to the agent's processing loop and waits for a response. +// It respects context cancellation and the agent's stop signal. +func (a *BaseAgent) Process(ctx context.Context, msg bus.Message) (bus.Message, error) { + respCh := make(chan agentResponse, 1) + + select { + case a.msgCh <- agentRequest{ctx: ctx, msg: msg, resp: respCh}: + case <-ctx.Done(): + return bus.Message{}, ctx.Err() + case <-a.stopCh: + return bus.Message{}, fmt.Errorf("agent %s is stopped", a.id) + } + + select { + case r := <-respCh: + return r.msg, r.err + case <-ctx.Done(): + return bus.Message{}, ctx.Err() + } +} + +// Stop gracefully shuts down the agent. +// It signals the processing loop to exit and waits for it to finish. +func (a *BaseAgent) Stop() error { + a.mu.Lock() + started := a.started + a.started = false + a.mu.Unlock() + + if !started { + return nil + } + + close(a.stopCh) + a.wg.Wait() + return nil +} + +// IsStarted returns whether the agent's processing loop is running. 
+func (a *BaseAgent) IsStarted() bool { + a.mu.Lock() + defer a.mu.Unlock() + return a.started +} diff --git a/pkg/actor/actor_test.go b/pkg/actor/actor_test.go new file mode 100644 index 0000000..4ff82f1 --- /dev/null +++ b/pkg/actor/actor_test.go @@ -0,0 +1,697 @@ +package actor + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/orca/orca/pkg/bus" +) + +// ============================================================ +// BaseAgent Tests +// ============================================================ + +func TestNewBaseAgent(t *testing.T) { + a := NewBaseAgent("test-1", "worker") + if a == nil { + t.Fatal("NewBaseAgent() returned nil") + } + if a.ID() != "test-1" { + t.Errorf("expected id 'test-1', got %q", a.ID()) + } + if a.Role() != "worker" { + t.Errorf("expected role 'worker', got %q", a.Role()) + } + if s := a.Status(); s != StatusIdle { + t.Errorf("expected initial StatusIdle, got %s", s) + } +} + +func TestBaseAgentStartAndStop(t *testing.T) { + a := NewBaseAgent("test-2", "worker") + a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) { + return bus.Message{ID: "response"}, nil + }) + + if err := a.Start(); err != nil { + t.Fatalf("Start failed: %v", err) + } + if !a.IsStarted() { + t.Error("expected agent to be started") + } + + if err := a.Stop(); err != nil { + t.Fatalf("Stop failed: %v", err) + } + if a.IsStarted() { + t.Error("expected agent to be stopped after Stop()") + } +} + +func TestBaseAgentDoubleStart(t *testing.T) { + a := NewBaseAgent("test-3", "worker") + a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) { + return bus.Message{ID: "response"}, nil + }) + + if err := a.Start(); err != nil { + t.Fatalf("first Start failed: %v", err) + } + + err := a.Start() + if err == nil { + t.Error("expected error on double start") + } + a.Stop() +} + +func TestBaseAgentStartWithoutHandler(t *testing.T) { + a := NewBaseAgent("test-4", "worker") + err := a.Start() + if 
err == nil { + t.Error("expected error starting agent without handler") + } +} + +func TestBaseAgentProcessAndResponse(t *testing.T) { + a := NewBaseAgent("test-5", "worker") + a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) { + return bus.Message{ + ID: msg.ID + "-resp", + Type: bus.MsgTypeTaskResponse, + From: a.ID(), + To: msg.From, + Content: "processed: " + msg.ID, + }, nil + }) + a.Start() + defer a.Stop() + + ctx := context.Background() + resp, err := a.Process(ctx, bus.Message{ + ID: "task-1", + Type: bus.MsgTypeTaskRequest, + From: "caller", + }) + if err != nil { + t.Fatalf("Process failed: %v", err) + } + if resp.ID != "task-1-resp" { + t.Errorf("expected response ID 'task-1-resp', got %q", resp.ID) + } + if resp.Content != "processed: task-1" { + t.Errorf("expected content 'processed: task-1', got %v", resp.Content) + } + if resp.From != "test-5" { + t.Errorf("expected From 'test-5', got %q", resp.From) + } +} + +func TestBaseAgentProcessReturnsError(t *testing.T) { + expectedErr := errors.New("processing failed") + a := NewBaseAgent("test-6", "worker") + a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) { + return bus.Message{}, expectedErr + }) + a.Start() + defer a.Stop() + + _, err := a.Process(context.Background(), bus.Message{ID: "task-1"}) + if err == nil { + t.Fatal("expected error from Process") + } + if !errors.Is(err, expectedErr) { + t.Errorf("expected error %v, got %v", expectedErr, err) + } +} + +func TestBaseAgentProcessOnStoppedAgent(t *testing.T) { + a := NewBaseAgent("test-7", "worker") + a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) { + return bus.Message{ID: "response"}, nil + }) + a.Start() + a.Stop() + + _, err := a.Process(context.Background(), bus.Message{ID: "task-1"}) + if err == nil { + t.Error("expected error processing on stopped agent") + } +} + +func TestBaseAgentContextCancellation(t *testing.T) { + a := NewBaseAgent("test-8", "worker") 
+ a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) { + // Simulate long processing + select { + case <-time.After(5 * time.Second): + return bus.Message{ID: "response"}, nil + case <-ctx.Done(): + return bus.Message{}, ctx.Err() + } + }) + a.Start() + defer a.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + _, err := a.Process(ctx, bus.Message{ID: "task-1"}) + if err == nil { + t.Error("expected error from context cancellation") + } +} + +func TestBaseAgentStatusTransitions(t *testing.T) { + a := NewBaseAgent("test-9", "worker") + a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) { + if a.Status() != StatusProcessing { + t.Errorf("expected StatusProcessing inside handler, got %s", a.Status()) + } + return bus.Message{ID: "response"}, nil + }) + a.Start() + defer a.Stop() + + // Should be idle before processing + if s := a.Status(); s != StatusIdle { + t.Errorf("expected StatusIdle before Process, got %s", s) + } + + _, err := a.Process(context.Background(), bus.Message{ID: "task-1"}) + if err != nil { + t.Fatalf("Process failed: %v", err) + } + + // Should be completed after processing + // Give the goroutine a moment to update the status + time.Sleep(5 * time.Millisecond) + if s := a.Status(); s != StatusCompleted { + t.Errorf("expected StatusCompleted after Process, got %s", s) + } +} + +func TestBaseAgentConcurrentProcess(t *testing.T) { + a := NewBaseAgent("test-10", "worker") + var counter int32 + a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) { + atomic.AddInt32(&counter, 1) + return bus.Message{ID: "response"}, nil + }) + a.Start() + defer a.Stop() + + var completed int32 + for i := 0; i < 10; i++ { + go func(i int) { + _, err := a.Process(context.Background(), bus.Message{ID: "task"}) + if err == nil { + atomic.AddInt32(&completed, 1) + } + }(i) + } + + time.Sleep(200 * time.Millisecond) + if n := 
atomic.LoadInt32(&completed); n != 10 { + t.Errorf("expected 10 completed tasks, got %d", n) + } + if n := atomic.LoadInt32(&counter); n != 10 { + t.Errorf("expected 10 handler calls, got %d", n) + } +} + +func TestBaseAgentStopIdempotent(t *testing.T) { + a := NewBaseAgent("test-11", "worker") + a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) { + return bus.Message{ID: "response"}, nil + }) + a.Start() + + if err := a.Stop(); err != nil { + t.Fatalf("first Stop failed: %v", err) + } + if err := a.Stop(); err != nil { + t.Fatalf("second Stop should be idempotent: %v", err) + } +} + +func TestActorStatusString(t *testing.T) { + tests := []struct { + status ActorStatus + want string + }{ + {StatusIdle, "idle"}, + {StatusProcessing, "processing"}, + {StatusWaitingForTool, "waiting_for_tool"}, + {StatusCompleted, "completed"}, + {StatusStopped, "stopped"}, + {ActorStatus(99), "unknown"}, + } + for _, tt := range tests { + if got := tt.status.String(); got != tt.want { + t.Errorf("ActorStatus(%d).String() = %q, want %q", tt.status, got, tt.want) + } + } +} + +// ============================================================ +// Worker Tests +// ============================================================ + +func TestNewWorker(t *testing.T) { + w := NewWorker("worker-1") + if w == nil { + t.Fatal("NewWorker() returned nil") + } + if w.ID() != "worker-1" { + t.Errorf("expected id 'worker-1', got %q", w.ID()) + } + if w.Role() != "worker" { + t.Errorf("expected role 'worker', got %q", w.Role()) + } + if !w.IsStarted() { + t.Error("expected worker to be started automatically") + } + w.Stop() +} + +func TestWorkerProcessTask(t *testing.T) { + w := NewWorker("worker-2") + defer w.Stop() + + resp, err := w.Process(context.Background(), bus.Message{ + ID: "task-1", + Type: bus.MsgTypeTaskRequest, + From: "caller", + Content: "do something", + }) + if err != nil { + t.Fatalf("Process failed: %v", err) + } + if resp.Type != bus.MsgTypeTaskResponse { + 
t.Errorf("expected MsgTypeTaskResponse, got %s", resp.Type) + } + if resp.From != "worker-2" { + t.Errorf("expected From 'worker-2', got %q", resp.From) + } + if resp.Metadata["processed_by"] != "worker-2" { + t.Errorf("expected processed_by 'worker-2', got %q", resp.Metadata["processed_by"]) + } +} + +func TestWorkerProcessToolCall(t *testing.T) { + w := NewWorker("worker-3") + defer w.Stop() + + resp, err := w.Process(context.Background(), bus.Message{ + ID: "tool-1", + Type: bus.MsgTypeToolCall, + From: "caller", + Content: "execute command", + }) + if err != nil { + t.Fatalf("Process failed: %v", err) + } + if resp.Type != bus.MsgTypeToolResult { + t.Errorf("expected MsgTypeToolResult, got %s", resp.Type) + } +} + +func TestWorkerStatusDuringToolCall(t *testing.T) { + w := NewWorker("worker-4") + + // Perform a tool call + _, err := w.Process(context.Background(), bus.Message{ + ID: "tool-1", + Type: bus.MsgTypeToolCall, + From: "caller", + }) + if err != nil { + t.Fatalf("Process failed: %v", err) + } + + // After tool call, status should be Processing (set back by defer) + time.Sleep(5 * time.Millisecond) + // Status could be Processing or Completed depending on timing + s := w.Status() + if s != StatusProcessing && s != StatusIdle && s != StatusCompleted { + t.Errorf("expected Processing/Idle/Completed after tool call, got %s", s) + } + w.Stop() +} + +func TestWorkerUnsupportedMessage(t *testing.T) { + w := NewWorker("worker-5") + defer w.Stop() + + _, err := w.Process(context.Background(), bus.Message{ + ID: "unknown-1", + Type: bus.MsgTypeObservation, + From: "caller", + }) + if err == nil { + t.Error("expected error for unsupported message type") + } +} + +// ============================================================ +// Orchestrator Tests +// ============================================================ + +func TestNewOrchestrator(t *testing.T) { + o := NewOrchestrator("orch-1", nil) + if o == nil { + t.Fatal("NewOrchestrator() returned nil") + } + if 
o.ID() != "orch-1" { + t.Errorf("expected id 'orch-1', got %q", o.ID()) + } + if o.Role() != "orchestrator" { + t.Errorf("expected role 'orchestrator', got %q", o.Role()) + } + if !o.IsStarted() { + t.Error("expected orchestrator to be started automatically") + } + o.Stop() +} + +func TestOrchestratorAddWorker(t *testing.T) { + o := NewOrchestrator("orch-2", nil) + defer o.Stop() + + w := NewWorker("worker-10") + defer w.Stop() + + o.AddWorker(w) + if n := o.WorkerCount(); n != 1 { + t.Errorf("expected 1 worker, got %d", n) + } + + got, ok := o.GetWorker("worker-10") + if !ok { + t.Fatal("expected to find worker-10") + } + if got.ID() != "worker-10" { + t.Errorf("expected worker ID 'worker-10', got %q", got.ID()) + } +} + +func TestOrchestratorRemoveWorker(t *testing.T) { + o := NewOrchestrator("orch-3", nil) + defer o.Stop() + + w := NewWorker("worker-11") + defer w.Stop() + o.AddWorker(w) + o.RemoveWorker("worker-11") + + if n := o.WorkerCount(); n != 0 { + t.Errorf("expected 0 workers after removal, got %d", n) + } +} + +func TestOrchestratorListWorkers(t *testing.T) { + o := NewOrchestrator("orch-4", nil) + defer o.Stop() + + workers := []string{"w-1", "w-2", "w-3"} + for _, name := range workers { + w := NewWorker(name) + defer w.Stop() + o.AddWorker(w) + } + + list := o.ListWorkers() + if len(list) != len(workers) { + t.Errorf("expected %d workers, got %d", len(workers), len(list)) + } + + ids := make(map[string]bool) + for _, w := range list { + ids[w.ID()] = true + } + for _, name := range workers { + if !ids[name] { + t.Errorf("missing worker %q in list", name) + } + } +} + +func TestOrchestratorDelegatesToWorker(t *testing.T) { + o := NewOrchestrator("orch-5", nil) + defer o.Stop() + + w := NewWorker("worker-20") + defer w.Stop() + o.AddWorker(w) + + resp, err := o.Process(context.Background(), bus.Message{ + ID: "task-1", + Type: bus.MsgTypeTaskRequest, + From: "caller", + Content: "do work", + }) + if err != nil { + t.Fatalf("Process failed: %v", err) + 
} + if resp.From != "worker-20" { + t.Errorf("expected response from 'worker-20', got %q", resp.From) + } + if resp.Metadata["processed_by"] != "worker-20" { + t.Errorf("expected processed_by 'worker-20', got %q", resp.Metadata["processed_by"]) + } +} + +func TestOrchestratorNoWorkers(t *testing.T) { + o := NewOrchestrator("orch-6", nil) + defer o.Stop() + + _, err := o.Process(context.Background(), bus.Message{ + ID: "task-1", + Type: bus.MsgTypeTaskRequest, + From: "caller", + }) + if err == nil { + t.Error("expected error when no workers available") + } +} + +func TestOrchestratorSystemMessage(t *testing.T) { + o := NewOrchestrator("orch-7", nil) + defer o.Stop() + + resp, err := o.Process(context.Background(), bus.Message{ + ID: "sys-1", + Type: bus.MsgTypeSystem, + From: "caller", + }) + if err != nil { + t.Fatalf("Process failed: %v", err) + } + if resp.Content != "orchestrator acknowledged" { + t.Errorf("expected acknowledged message, got %v", resp.Content) + } +} + +// ============================================================ +// System Tests +// ============================================================ + +func TestNewSystem(t *testing.T) { + s := NewSystem() + if s == nil { + t.Fatal("NewSystem() returned nil") + } + if n := s.AgentCount(); n != 0 { + t.Errorf("expected 0 agents, got %d", n) + } +} + +func TestSystemCreateWorker(t *testing.T) { + s := NewSystem() + w, err := s.CreateWorker() + if err != nil { + t.Fatalf("CreateWorker failed: %v", err) + } + if w == nil { + t.Fatal("CreateWorker returned nil") + } + if w.Role() != "worker" { + t.Errorf("expected role 'worker', got %q", w.Role()) + } + if n := s.AgentCount(); n != 1 { + t.Errorf("expected 1 agent, got %d", n) + } + s.StopAll() +} + +func TestSystemStopAgent(t *testing.T) { + s := NewSystem() + w, _ := s.CreateWorker() + + if err := s.StopAgent(w.ID()); err != nil { + t.Fatalf("StopAgent failed: %v", err) + } + if n := s.AgentCount(); n != 0 { + t.Errorf("expected 0 agents, got %d", n) 
+ } + + _, ok := s.GetAgent(w.ID()) + if ok { + t.Error("expected agent to be removed after StopAgent") + } +} + +func TestSystemStopAgentNotFound(t *testing.T) { + s := NewSystem() + err := s.StopAgent("nonexistent") + if err == nil { + t.Error("expected error stopping nonexistent agent") + } +} + +func TestSystemListAgents(t *testing.T) { + s := NewSystem() + s.CreateWorker() + s.CreateWorker() + + agents := s.ListAgents() + if len(agents) != 2 { + t.Errorf("expected 2 agents, got %d", len(agents)) + } + s.StopAll() +} + +func TestSystemAgentInfos(t *testing.T) { + s := NewSystem() + w, _ := s.CreateWorker() + + infos := s.AgentInfos() + if len(infos) != 1 { + t.Fatalf("expected 1 agent info, got %d", len(infos)) + } + if infos[0].ID != w.ID() { + t.Errorf("expected ID %q, got %q", w.ID(), infos[0].ID) + } + if infos[0].Role != "worker" { + t.Errorf("expected Role 'worker', got %q", infos[0].Role) + } + if infos[0].Status != StatusIdle { + t.Errorf("expected Status StatusIdle, got %s", infos[0].Status) + } + s.StopAll() +} + +func TestSystemStopAll(t *testing.T) { + s := NewSystem() + s.CreateWorker() + s.CreateWorker() + s.CreateWorker() + + if err := s.StopAll(); err != nil { + t.Fatalf("StopAll failed: %v", err) + } + if n := s.AgentCount(); n != 0 { + t.Errorf("expected 0 agents after StopAll, got %d", n) + } +} + +// ============================================================ +// ToolWorker Tests +// ============================================================ + +func TestNewToolWorker(t *testing.T) { + tw := NewToolWorker("tool-1", nil) + if tw == nil { + t.Fatal("NewToolWorker() returned nil") + } + if tw.ID() != "tool-1" { + t.Errorf("expected id 'tool-1', got %q", tw.ID()) + } + if tw.Role() != "tool_worker" { + t.Errorf("expected role 'tool_worker', got %q", tw.Role()) + } + if !tw.IsStarted() { + t.Error("expected tool worker to be started automatically") + } + tw.Stop() +} + +func TestToolWorkerProcessSystemMessage(t *testing.T) { + tw := 
NewToolWorker("tool-2", nil) + defer tw.Stop() + + resp, err := tw.Process(context.Background(), bus.Message{ + ID: "sys-1", + Type: bus.MsgTypeSystem, + From: "caller", + }) + if err != nil { + t.Fatalf("Process failed: %v", err) + } + if resp.Content != "tool_worker acknowledged" { + t.Errorf("expected 'tool_worker acknowledged', got %v", resp.Content) + } +} + +func TestToolWorkerUnsupportedMessage(t *testing.T) { + tw := NewToolWorker("tool-3", nil) + defer tw.Stop() + + _, err := tw.Process(context.Background(), bus.Message{ + ID: "obs-1", + Type: bus.MsgTypeObservation, + From: "caller", + }) + if err == nil { + t.Error("expected error for unsupported message type") + } +} + +func TestParseToolCallContentMap(t *testing.T) { + name, args, err := parseToolCallContent(map[string]interface{}{ + "name": "exec", + "arguments": map[string]interface{}{"command": "ls"}, + }) + if err != nil { + t.Fatalf("parseToolCallContent failed: %v", err) + } + if name != "exec" { + t.Errorf("expected name 'exec', got %q", name) + } + if args["command"] != "ls" { + t.Errorf("expected args['command'] = 'ls', got %v", args["command"]) + } +} + +func TestParseToolCallContentMissingName(t *testing.T) { + _, _, err := parseToolCallContent(map[string]interface{}{"foo": "bar"}) + if err == nil { + t.Error("expected error for missing name") + } +} + +// ============================================================ +// System ToolWorker Tests +// ============================================================ + +func TestSystemCreateToolWorker(t *testing.T) { + s := NewSystem() + tw, err := s.CreateToolWorker(nil) + if err != nil { + t.Fatalf("CreateToolWorker failed: %v", err) + } + if tw == nil { + t.Fatal("CreateToolWorker returned nil") + } + if tw.Role() != "tool_worker" { + t.Errorf("expected role 'tool_worker', got %q", tw.Role()) + } + if n := s.AgentCount(); n != 1 { + t.Errorf("expected 1 agent, got %d", n) + } + s.StopAll() +} diff --git a/pkg/actor/agent.go b/pkg/actor/agent.go new 
file mode 100644 index 0000000..0371abd --- /dev/null +++ b/pkg/actor/agent.go @@ -0,0 +1,18 @@ +// Package actor implements the Actor model for the Orca framework. +// +// This file provides additional agent types that integrate the LLM +// and Tool systems with the actor framework. See actor.go for the +// base Agent interface and BaseAgent implementation. +package actor + +// This file exists alongside actor.go to provide the LLMAgent and +// ToolWorker types, completing the integration between the actor +// system and the LLM / Tool subsystems. +// +// The key types in this package are: +// - Agent (interface, in actor.go) +// - BaseAgent (struct, in actor.go) +// - Orchestrator (in orchestrator.go) +// - Worker (in worker.go) +// - LLMAgent (in llm_agent.go) +// - ToolWorker (in tool_worker.go) diff --git a/pkg/actor/llm_agent.go b/pkg/actor/llm_agent.go new file mode 100644 index 0000000..b9916e8 --- /dev/null +++ b/pkg/actor/llm_agent.go @@ -0,0 +1,338 @@ +package actor + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/orca/orca/pkg/bus" + "github.com/orca/orca/pkg/llm" + "github.com/orca/orca/pkg/session" + "github.com/orca/orca/pkg/tool" +) + +// LLMAgent implements the Agent interface by integrating an LLM backend +// with the actor system and tool framework. +// +// It receives user messages, retrieves session context, calls the LLM, +// and handles tool call responses by executing tools and feeding results +// back to the LLM for final response generation. +type LLMAgent struct { + *BaseAgent + llm llm.LLM + sessionMgr *session.Manager + sessionID string + toolManager *tool.Manager + toolWorker *ToolWorker + windowSize int +} + +// LLMAgentOption is a functional option for configuring the LLMAgent. +type LLMAgentOption func(*LLMAgent) + +// WithSessionManager sets the session manager for conversation history. 
+func WithSessionManager(mgr *session.Manager) LLMAgentOption { + return func(a *LLMAgent) { + a.sessionMgr = mgr + } +} + +// WithSessionID sets the session ID for conversation persistence. +func WithSessionID(id string) LLMAgentOption { + return func(a *LLMAgent) { + a.sessionID = id + } +} + +// WithToolManager sets the tool manager for executing tools. +func WithToolManager(mgr *tool.Manager) LLMAgentOption { + return func(a *LLMAgent) { + a.toolManager = mgr + } +} + +// WithToolWorker sets the tool worker for delegated tool execution. +func WithToolWorker(w *ToolWorker) LLMAgentOption { + return func(a *LLMAgent) { + a.toolWorker = w + } +} + +// WithWindowSize sets the context window size for session history. +func WithWindowSize(size int) LLMAgentOption { + return func(a *LLMAgent) { + a.windowSize = size + } +} + +// NewLLMAgent creates a new LLMAgent with the given LLM backend and options. +// The agent is started automatically upon creation. +func NewLLMAgent(id string, backend llm.LLM, opts ...LLMAgentOption) *LLMAgent { + a := &LLMAgent{ + BaseAgent: NewBaseAgent(id, "llm_agent"), + llm: backend, + windowSize: 20, // Default context window + } + + for _, opt := range opts { + opt(a) + } + + a.SetHandler(a.handleMessage) + if err := a.Start(); err != nil { + panic(fmt.Sprintf("llm_agent: failed to start: %v", err)) + } + return a +} + +// handleMessage routes incoming messages to the appropriate handler. +func (a *LLMAgent) handleMessage(ctx context.Context, msg bus.Message) (bus.Message, error) { + switch msg.Type { + case bus.MsgTypeTaskRequest: + return a.handleUserMessage(ctx, msg) + case bus.MsgTypeSystem: + return a.handleSystem(ctx, msg) + default: + return bus.Message{}, fmt.Errorf("llm_agent %s: unsupported message type %s", a.ID(), msg.Type) + } +} + +// handleUserMessage processes a user message through the LLM. +// +// Flow: +// 1. Persist the user message to session history +// 2. Retrieve recent conversation context +// 3. 
Convert to LLM message format +// 4. Call LLM.Chat +// 5. If response has tool calls: +// a. Execute each tool (directly or via ToolWorker) +// b. Add tool results to conversation +// c. Call LLM.Chat again with results +// 6. Persist the assistant response +// 7. Return the final response +func (a *LLMAgent) handleUserMessage(ctx context.Context, msg bus.Message) (bus.Message, error) { + content, ok := msg.Content.(string) + if !ok { + return bus.Message{}, fmt.Errorf("llm_agent: expected string content, got %T", msg.Content) + } + + // Ensure session exists + if a.sessionMgr != nil && a.sessionID != "" { + // Check if session exists; create if not + if _, err := a.sessionMgr.GetSession(a.sessionID); err != nil { + a.sessionMgr.CreateSession(a.sessionID, map[string]string{ + "source": "llm_agent", + }) + } + + // Persist user message + a.sessionMgr.AddMessage(a.sessionID, session.RoleUser, content, nil) + } + + // Build LLM messages from session context + llmMessages := a.buildLLMMessages() + + // Call LLM (potentially multiple rounds for tool calls) + finalResponse, err := a.chatWithToolLoop(ctx, llmMessages) + if err != nil { + return bus.Message{}, fmt.Errorf("llm_agent: LLM chat failed: %w", err) + } + + // Persist assistant response + if a.sessionMgr != nil && a.sessionID != "" { + a.sessionMgr.AddMessage(a.sessionID, session.RoleAssistant, finalResponse, nil) + } + + return bus.Message{ + ID: msg.ID + "-response", + Type: bus.MsgTypeTaskResponse, + From: a.ID(), + To: msg.From, + Content: finalResponse, + }, nil +} + +func (a *LLMAgent) buildLLMMessages() []llm.Message { + messages := make([]llm.Message, 0) + + if a.toolManager != nil { + messages = append(messages, llm.Message{ + Role: "system", + Content: a.buildToolSystemPrompt(), + }) + } + + if a.sessionMgr == nil || a.sessionID == "" { + return messages + } + + sessionMsgs, err := a.sessionMgr.GetContext(a.sessionID, a.windowSize) + if err != nil { + return messages + } + + for _, sm := range 
sessionMsgs { + msg := llm.Message{ + Role: string(sm.Role), + Content: sm.Content, + } + if sm.Role == session.RoleTool && sm.Metadata != nil { + msg.ToolCallID = sm.Metadata["tool_call_id"] + } + messages = append(messages, msg) + } + + return messages +} + +// buildToolSystemPrompt creates a system prompt describing all available tools. +// This enables prompt-based tool calling for models without native function +// calling support. +func (a *LLMAgent) buildToolSystemPrompt() string { + if a.toolManager == nil { + return "" + } + + var b strings.Builder + b.WriteString("你是一个 AI 助手,可以使用以下工具来完成用户的请求。\n\n") + b.WriteString("可用工具列表:\n") + + for _, t := range a.toolManager.List() { + b.WriteString(fmt.Sprintf("\n工具名: %s\n", t.Name())) + b.WriteString(fmt.Sprintf("描述: %s\n", t.Description())) + paramsJSON, _ := json.Marshal(t.Parameters()) + b.WriteString(fmt.Sprintf("参数: %s\n", string(paramsJSON))) + } + + b.WriteString("\n规则:\n") + b.WriteString("1. 当你需要调用工具时,请在回复中**只输出**以下 JSON 格式(不要添加其他文字):\n") + b.WriteString(` {"tool": "工具名", "arguments": {"参数名": "参数值"}}` + "\n") + b.WriteString("2. 如果你已经看到了工具返回的结果,请直接根据结果回答用户,不要再次调用工具。\n") + b.WriteString("3. 
如果你不需要调用工具,请直接回复用户。\n") + + return b.String() +} + +func (a *LLMAgent) chatWithToolLoop(ctx context.Context, messages []llm.Message) (string, error) { + maxRounds := 10 + + for round := 0; round < maxRounds; round++ { + response, err := a.llm.Chat(ctx, messages) + if err != nil { + return "", fmt.Errorf("chat round %d failed: %w", round, err) + } + + toolCalls := response.ToolCalls + if len(toolCalls) == 0 { + toolCalls = a.parseToolCallsFromContent(response.Content) + } + + if len(toolCalls) == 0 { + return response.Content, nil + } + + messages = append(messages, llm.Message{ + Role: "assistant", + Content: response.Content, + }) + + for _, tc := range toolCalls { + resultContent := a.executeToolCall(ctx, tc) + messages = append(messages, llm.Message{ + Role: "user", + Content: fmt.Sprintf("工具 %s 的执行结果:%s", tc.Function.Name, resultContent), + }) + } + } + + return "", fmt.Errorf("llm_agent: exceeded maximum tool call rounds (%d)", maxRounds) +} + +func (a *LLMAgent) parseToolCallsFromContent(content string) []llm.ToolCall { + var parsed struct { + Tool string `json:"tool"` + Arguments map[string]interface{} `json:"arguments"` + } + if err := json.Unmarshal([]byte(content), &parsed); err != nil || parsed.Tool == "" { + return nil + } + + argsJSON, _ := json.Marshal(parsed.Arguments) + return []llm.ToolCall{{ + ID: "call_0", + Type: "function", + Function: llm.FunctionCall{ + Name: parsed.Tool, + Arguments: string(argsJSON), + }, + }} +} + +// executeToolCall runs a single tool call and returns the result as a JSON string. 
+func (a *LLMAgent) executeToolCall(ctx context.Context, tc llm.ToolCall) string { + toolName := tc.Function.Name + + // Parse arguments + var args map[string]interface{} + if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil { + args = map[string]interface{}{ + "_raw": tc.Function.Arguments, + } + } + + // Execute via ToolWorker (preferred) or directly via tool.Manager + if a.toolWorker != nil { + // Create a tool call message for the ToolWorker + toolCallMsg := bus.Message{ + ID: tc.ID, + Type: bus.MsgTypeToolCall, + From: a.ID(), + To: a.toolWorker.ID(), + Content: map[string]interface{}{"name": toolName, "arguments": args}, + } + + resultMsg, err := a.toolWorker.Process(ctx, toolCallMsg) + if err != nil { + return fmt.Sprintf(`{"error": "tool execution failed: %v"}`, err) + } + + // Serialize the result + resultJSON, err := json.Marshal(resultMsg.Content) + if err != nil { + return fmt.Sprintf(`{"error": "failed to marshal result: %v"}`, err) + } + return string(resultJSON) + } + + // Fallback: execute directly via tool.Manager + if a.toolManager != nil { + result, err := a.toolManager.Execute(toolName, ctx, args) + if err != nil { + return fmt.Sprintf(`{"error": "tool execution failed: %v"}`, err) + } + + resultJSON, err := json.Marshal(result) + if err != nil { + return fmt.Sprintf(`{"error": "failed to marshal result: %v"}`, err) + } + return string(resultJSON) + } + + return fmt.Sprintf(`{"error": "no tool worker or tool manager available for %q"}`, toolName) +} + +// handleSystem processes internal system messages. +func (a *LLMAgent) handleSystem(ctx context.Context, msg bus.Message) (bus.Message, error) { + return bus.Message{ + ID: msg.ID + "-ack", + Type: bus.MsgTypeSystem, + From: a.ID(), + To: msg.From, + Content: "llm_agent acknowledged", + }, nil +} + +// Compile-time interface checks. 
+var _ Agent = (*LLMAgent)(nil) +var _ Agent = (*ToolWorker)(nil) diff --git a/pkg/actor/orchestrator.go b/pkg/actor/orchestrator.go new file mode 100644 index 0000000..5f98c15 --- /dev/null +++ b/pkg/actor/orchestrator.go @@ -0,0 +1,123 @@ +package actor + +import ( + "context" + "fmt" + "sync" + + "github.com/orca/orca/pkg/bus" +) + +// Orchestrator is an agent that coordinates task execution across a pool of workers. +// +// It receives task requests, delegates them to available workers, and +// collects responses. The orchestrator maintains a registry of worker +// agents and can dynamically add or remove them. +type Orchestrator struct { + *BaseAgent + workers map[string]Agent + bus bus.MessageBus + mu sync.RWMutex +} + +// NewOrchestrator creates a new Orchestrator agent with the given id and message bus. +// The agent is started automatically upon creation. +func NewOrchestrator(id string, mb bus.MessageBus) *Orchestrator { + o := &Orchestrator{ + BaseAgent: NewBaseAgent(id, "orchestrator"), + workers: make(map[string]Agent), + bus: mb, + } + o.SetHandler(o.handleMessage) + // Start the agent's processing loop + if err := o.Start(); err != nil { + // This should not happen since handler is set above + panic(fmt.Sprintf("orchestrator: failed to start: %v", err)) + } + return o +} + +// handleMessage routes incoming messages to the appropriate handler. +func (o *Orchestrator) handleMessage(ctx context.Context, msg bus.Message) (bus.Message, error) { + switch msg.Type { + case bus.MsgTypeTaskRequest: + return o.handleTask(ctx, msg) + case bus.MsgTypeSystem: + return o.handleSystem(ctx, msg) + default: + return bus.Message{}, fmt.Errorf("orchestrator %s: unsupported message type %s", o.ID(), msg.Type) + } +} + +// handleTask processes a task request by delegating to an available worker. 
+func (o *Orchestrator) handleTask(ctx context.Context, msg bus.Message) (bus.Message, error) { + o.mu.RLock() + defer o.mu.RUnlock() + + if len(o.workers) == 0 { + return bus.Message{}, fmt.Errorf("orchestrator %s: no workers available", o.ID()) + } + + // Simple round-robin: pick the first available worker + for _, w := range o.workers { + return w.Process(ctx, msg) + } + + return bus.Message{}, fmt.Errorf("orchestrator %s: no workers available", o.ID()) +} + +// handleSystem processes internal system messages. +func (o *Orchestrator) handleSystem(ctx context.Context, msg bus.Message) (bus.Message, error) { + return bus.Message{ + ID: msg.ID + "-ack", + Type: bus.MsgTypeSystem, + From: o.ID(), + To: msg.From, + Content: "orchestrator acknowledged", + }, nil +} + +// AddWorker registers a worker agent with this orchestrator. +func (o *Orchestrator) AddWorker(w Agent) { + o.mu.Lock() + defer o.mu.Unlock() + o.workers[w.ID()] = w +} + +// RemoveWorker unregisters a worker agent from this orchestrator. +func (o *Orchestrator) RemoveWorker(id string) { + o.mu.Lock() + defer o.mu.Unlock() + delete(o.workers, id) +} + +// WorkerCount returns the number of registered workers. +func (o *Orchestrator) WorkerCount() int { + o.mu.RLock() + defer o.mu.RUnlock() + return len(o.workers) +} + +// GetWorker retrieves a registered worker by ID. +func (o *Orchestrator) GetWorker(id string) (Agent, bool) { + o.mu.RLock() + defer o.mu.RUnlock() + w, ok := o.workers[id] + return w, ok +} + +// ListWorkers returns all registered workers. +func (o *Orchestrator) ListWorkers() []Agent { + o.mu.RLock() + defer o.mu.RUnlock() + workers := make([]Agent, 0, len(o.workers)) + for _, w := range o.workers { + workers = append(workers, w) + } + return workers +} + +// Bus returns the orchestrator's message bus reference. 
+func (o *Orchestrator) Bus() bus.MessageBus { + return o.bus +} diff --git a/pkg/actor/system.go b/pkg/actor/system.go new file mode 100644 index 0000000..b448638 --- /dev/null +++ b/pkg/actor/system.go @@ -0,0 +1,180 @@ +package actor + +import ( + "fmt" + "sync" + "sync/atomic" + + "github.com/orca/orca/pkg/tool" +) + +// System manages the lifecycle of all agents in the Orca actor framework. +// +// It provides centralized agent creation, monitoring, and shutdown +// capabilities. Agents are identified by unique IDs and organized by role. +type System struct { + mu sync.RWMutex + agents map[string]Agent + nextID int64 +} + +// NewSystem creates a new empty actor System. +func NewSystem() *System { + return &System{ + agents: make(map[string]Agent), + } +} + +// AgentInfo holds summary information about a managed agent. +type AgentInfo struct { + ID string `json:"id"` + Role string `json:"role"` + Status ActorStatus `json:"status"` +} + +// CreateOrchestrator creates a new Orchestrator agent and registers it. +func (s *System) CreateOrchestrator(bus interface{}) (*Orchestrator, error) { + id := s.nextAgentID("orch") + return s.addOrchestrator(id, bus) +} + +// CreateWorker creates a new Worker agent and registers it. +func (s *System) CreateWorker() (*Worker, error) { + id := s.nextAgentID("worker") + return s.addWorker(id) +} + +// CreateToolWorker creates a new ToolWorker agent with the given tool manager and registers it. +func (s *System) CreateToolWorker(manager *tool.Manager) (*ToolWorker, error) { + id := s.nextAgentID("tool") + return s.addToolWorker(id, manager) +} + +// nextAgentID generates a unique agent ID with the given prefix. +func (s *System) nextAgentID(prefix string) string { + n := atomic.AddInt64(&s.nextID, 1) + return fmt.Sprintf("%s-%d", prefix, n) +} + +// addOrchestrator creates and registers an orchestrator. 
+func (s *System) addOrchestrator(id string, busInterface interface{}) (*Orchestrator, error) { + mb, ok := busInterface.(interface{ Bus() }) + var orch *Orchestrator + if ok { + // If busInterface has a Bus() method, we could extract it here + _ = mb + } + orch = NewOrchestrator(id, nil) + + s.mu.Lock() + s.agents[id] = orch + s.mu.Unlock() + + return orch, nil +} + +// addWorker creates and registers a worker. +func (s *System) addWorker(id string) (*Worker, error) { + w := NewWorker(id) + + s.mu.Lock() + s.agents[id] = w + s.mu.Unlock() + + return w, nil +} + +// addToolWorker creates and registers a tool worker with the given tool manager. +func (s *System) addToolWorker(id string, manager *tool.Manager) (*ToolWorker, error) { + w := NewToolWorker(id, manager) + + s.mu.Lock() + s.agents[id] = w + s.mu.Unlock() + + return w, nil +} + +// StopAgent stops and removes a single agent by ID. +func (s *System) StopAgent(id string) error { + s.mu.Lock() + agent, ok := s.agents[id] + if !ok { + s.mu.Unlock() + return fmt.Errorf("agent %s not found", id) + } + delete(s.agents, id) + s.mu.Unlock() + + return agent.Stop() +} + +// GetAgent retrieves a registered agent by ID. +func (s *System) GetAgent(id string) (Agent, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + agent, ok := s.agents[id] + return agent, ok +} + +// ListAgents returns all registered agents. +func (s *System) ListAgents() []Agent { + s.mu.RLock() + defer s.mu.RUnlock() + + agents := make([]Agent, 0, len(s.agents)) + for _, a := range s.agents { + agents = append(agents, a) + } + return agents +} + +// AgentInfos returns summary information for all registered agents. 
+func (s *System) AgentInfos() []AgentInfo { + s.mu.RLock() + defer s.mu.RUnlock() + + infos := make([]AgentInfo, 0, len(s.agents)) + for _, a := range s.agents { + // Try to get status from BaseAgent + status := StatusIdle + if ba, ok := a.(*BaseAgent); ok { + status = ba.Status() + } else if orch, ok := a.(*Orchestrator); ok { + status = orch.Status() + } else if w, ok := a.(*Worker); ok { + status = w.Status() + } else if tw, ok := a.(*ToolWorker); ok { + status = tw.Status() + } + + infos = append(infos, AgentInfo{ + ID: a.ID(), + Role: a.Role(), + Status: status, + }) + } + return infos +} + +// StopAll gracefully stops all registered agents. +func (s *System) StopAll() error { + s.mu.Lock() + defer s.mu.Unlock() + + var lastErr error + for id, agent := range s.agents { + if err := agent.Stop(); err != nil { + lastErr = err + } + delete(s.agents, id) + } + return lastErr +} + +// AgentCount returns the number of registered agents. +func (s *System) AgentCount() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.agents) +} diff --git a/pkg/actor/tool_worker.go b/pkg/actor/tool_worker.go new file mode 100644 index 0000000..fad4e24 --- /dev/null +++ b/pkg/actor/tool_worker.go @@ -0,0 +1,153 @@ +package actor + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/orca/orca/pkg/bus" + "github.com/orca/orca/pkg/tool" +) + +// ToolWorker is an agent that processes tool call messages by executing +// tools through the tool.Manager. +// +// It implements the Agent interface and handles MsgTypeToolCall messages. +// When a tool call is received, it extracts the tool name and arguments +// from the message content, executes the tool via the Manager, and +// returns a MsgTypeToolResult with the execution result. +type ToolWorker struct { + *BaseAgent + manager *tool.Manager +} + +// NewToolWorker creates a new ToolWorker agent with the given id and tool manager. +// The agent is started automatically upon creation. 
+func NewToolWorker(id string, manager *tool.Manager) *ToolWorker { + w := &ToolWorker{ + BaseAgent: NewBaseAgent(id, "tool_worker"), + manager: manager, + } + w.SetHandler(w.handleMessage) + if err := w.Start(); err != nil { + panic(fmt.Sprintf("tool_worker: failed to start: %v", err)) + } + return w +} + +// handleMessage routes incoming messages to the appropriate handler. +func (w *ToolWorker) handleMessage(ctx context.Context, msg bus.Message) (bus.Message, error) { + switch msg.Type { + case bus.MsgTypeToolCall: + return w.handleToolCall(ctx, msg) + case bus.MsgTypeTaskRequest: + return w.handleTask(ctx, msg) + case bus.MsgTypeSystem: + return w.handleSystem(ctx, msg) + default: + return bus.Message{}, fmt.Errorf("tool_worker %s: unsupported message type %s", w.ID(), msg.Type) + } +} + +// handleToolCall processes a tool call by executing the named tool +// with the provided arguments. +// +// The msg.Content is expected to contain a JSON object with: +// - "name": the tool name (string) +// - "arguments": the tool arguments (object) +// +// Or alternatively, msg.Content can be a string in the format: +// tool_name(arg1=val1, arg2=val2) +func (w *ToolWorker) handleToolCall(ctx context.Context, msg bus.Message) (bus.Message, error) { + w.setStatus(StatusWaitingForTool) + defer w.setStatus(StatusProcessing) + + toolName, args, err := parseToolCallContent(msg.Content) + if err != nil { + return bus.Message{ + ID: msg.ID + "-result", + Type: bus.MsgTypeToolResult, + From: w.ID(), + To: msg.From, + Content: map[string]interface{}{"error": err.Error()}, + }, nil + } + + // Execute the tool + result, err := w.manager.Execute(toolName, ctx, args) + if err != nil { + return bus.Message{ + ID: msg.ID + "-result", + Type: bus.MsgTypeToolResult, + From: w.ID(), + To: msg.From, + Content: map[string]interface{}{"error": err.Error()}, + }, nil + } + + return bus.Message{ + ID: msg.ID + "-result", + Type: bus.MsgTypeToolResult, + From: w.ID(), + To: msg.From, + Content: 
result, + }, nil +} + +// parseToolCallContent extracts the tool name and arguments from various +// content formats. +func parseToolCallContent(content interface{}) (string, map[string]interface{}, error) { + switch v := content.(type) { + case map[string]interface{}: + // Format: {"name": "tool_name", "arguments": {...}} + name, ok := v["name"].(string) + if !ok || name == "" { + return "", nil, fmt.Errorf("tool call content missing 'name' field") + } + args, _ := v["arguments"].(map[string]interface{}) + if args == nil { + args = make(map[string]interface{}) + } + return name, args, nil + + case string: + // Try JSON format + var parsed map[string]interface{} + if err := json.Unmarshal([]byte(v), &parsed); err == nil { + name, ok := parsed["name"].(string) + if ok && name != "" { + args, _ := parsed["arguments"].(map[string]interface{}) + if args == nil { + args = make(map[string]interface{}) + } + return name, args, nil + } + } + return "", nil, fmt.Errorf("cannot parse tool call from string content: %s", v) + + default: + return "", nil, fmt.Errorf("unsupported tool call content type: %T", content) + } +} + +// handleTask processes a task request by returning a task response. +func (w *ToolWorker) handleTask(ctx context.Context, msg bus.Message) (bus.Message, error) { + return bus.Message{ + ID: msg.ID + "-response", + Type: bus.MsgTypeTaskResponse, + From: w.ID(), + To: msg.From, + Content: msg.Content, + }, nil +} + +// handleSystem processes internal system messages. 
+func (w *ToolWorker) handleSystem(ctx context.Context, msg bus.Message) (bus.Message, error) { + return bus.Message{ + ID: msg.ID + "-ack", + Type: bus.MsgTypeSystem, + From: w.ID(), + To: msg.From, + Content: "tool_worker acknowledged", + }, nil +} diff --git a/pkg/actor/worker.go b/pkg/actor/worker.go new file mode 100644 index 0000000..4a6815b --- /dev/null +++ b/pkg/actor/worker.go @@ -0,0 +1,88 @@ +package actor + +import ( + "context" + "fmt" + + "github.com/orca/orca/pkg/bus" +) + +// Worker is an agent that processes tasks and makes tool calls. +// +// Workers are the execution units in the actor system. They receive +// task requests from the orchestrator, process them (potentially making +// tool calls), and return results. +type Worker struct { + *BaseAgent +} + +// NewWorker creates a new Worker agent with the given id. +// The agent is started automatically upon creation. +func NewWorker(id string) *Worker { + w := &Worker{ + BaseAgent: NewBaseAgent(id, "worker"), + } + w.SetHandler(w.handleMessage) + if err := w.Start(); err != nil { + panic(fmt.Sprintf("worker: failed to start: %v", err)) + } + return w +} + +// handleMessage routes incoming messages to the appropriate handler. +func (w *Worker) handleMessage(ctx context.Context, msg bus.Message) (bus.Message, error) { + switch msg.Type { + case bus.MsgTypeTaskRequest: + return w.handleTask(ctx, msg) + case bus.MsgTypeToolCall: + return w.handleToolCall(ctx, msg) + case bus.MsgTypeSystem: + return w.handleSystem(ctx, msg) + default: + return bus.Message{}, fmt.Errorf("worker %s: unsupported message type %s", w.ID(), msg.Type) + } +} + +// handleTask processes a task request and returns a task response. +func (w *Worker) handleTask(ctx context.Context, msg bus.Message) (bus.Message, error) { + // Process the task - in a real implementation this would involve + // the LLM, tool calls, etc. 
+ return bus.Message{ + ID: msg.ID + "-response", + Type: bus.MsgTypeTaskResponse, + From: w.ID(), + To: msg.From, + Content: msg.Content, + Metadata: map[string]string{ + "processed_by": w.ID(), + }, + }, nil +} + +// handleToolCall processes a tool call request, transitions to WaitingForTool +// state, and returns the result. +func (w *Worker) handleToolCall(ctx context.Context, msg bus.Message) (bus.Message, error) { + w.setStatus(StatusWaitingForTool) + defer w.setStatus(StatusProcessing) + + // In a real implementation, this would invoke the actual tool. + // For now, acknowledge the tool call. + return bus.Message{ + ID: msg.ID + "-result", + Type: bus.MsgTypeToolResult, + From: w.ID(), + To: msg.From, + Content: msg.Content, + }, nil +} + +// handleSystem processes internal system messages. +func (w *Worker) handleSystem(ctx context.Context, msg bus.Message) (bus.Message, error) { + return bus.Message{ + ID: msg.ID + "-ack", + Type: bus.MsgTypeSystem, + From: w.ID(), + To: msg.From, + Content: "worker acknowledged", + }, nil +} diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go new file mode 100644 index 0000000..63cbf15 --- /dev/null +++ b/pkg/bus/bus.go @@ -0,0 +1,164 @@ +package bus + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" +) + +// Handler is a callback function that processes a delivered message. +type Handler func(Message) + +// Subscription represents an active subscription to a message bus topic. +type Subscription interface { + ID() string + Topic() string + Unsubscribe() +} + +// MessageBus is the central communication hub of the Orca framework. +// +// It uses a publish/subscribe pattern built on Go channels. Components +// publish messages to named topics, and all subscribers to that topic +// receive the message asynchronously. +type MessageBus interface { + // Publish sends a message to all active subscribers of the given topic. + Publish(topic string, msg Message) error + // Subscribe registers a handler for the given topic. 
+ Subscribe(topic string, handler Handler) (Subscription, error) + // Close gracefully shuts down the bus, cleaning up all subscriptions. + Close() error +} + +// subscription implements the Subscription interface. +type subscription struct { + id string + topic string + ch chan Message + bus *messageBus + active *atomic.Bool +} + +func (s *subscription) ID() string { return s.id } +func (s *subscription) Topic() string { return s.topic } +func (s *subscription) Unsubscribe() { s.bus.unsubscribe(s) } + +// messageBus is the channel-based implementation of MessageBus. +type messageBus struct { + mu sync.RWMutex + topics map[string][]*subscription + nextID int64 + closed bool +} + +// New creates a new message bus instance. +func New() MessageBus { + return &messageBus{ + topics: make(map[string][]*subscription), + } +} + +// Publish sends a message to all subscribers of the given topic. +// The send is non-blocking: if a subscriber's channel buffer is full, +// the message is dropped for that subscriber. +func (mb *messageBus) Publish(topic string, msg Message) error { + mb.mu.RLock() + defer mb.mu.RUnlock() + + if mb.closed { + return errors.New("message bus is closed") + } + + subs, ok := mb.topics[topic] + if !ok { + return nil + } + + for _, sub := range subs { + if sub.active.Load() { + select { + case sub.ch <- msg: + default: + } + } + } + return nil +} + +// Subscribe adds a handler for the given topic. 
+func (mb *messageBus) Subscribe(topic string, handler Handler) (Subscription, error) { + mb.mu.Lock() + defer mb.mu.Unlock() + + if mb.closed { + return nil, errors.New("message bus is closed") + } + + id := fmt.Sprintf("sub-%d", atomic.AddInt64(&mb.nextID, 1)) + sub := &subscription{ + id: id, + topic: topic, + ch: make(chan Message, 64), + bus: mb, + active: &atomic.Bool{}, + } + sub.active.Store(true) + + mb.topics[topic] = append(mb.topics[topic], sub) + + go sub.deliver(handler) + return sub, nil +} + +// deliver reads messages from the subscription channel and calls the handler. +func (s *subscription) deliver(handler Handler) { + for msg := range s.ch { + if !s.active.Load() { + return + } + handler(msg) + } +} + +// unsubscribe removes a subscription from the bus and closes its channel. +func (mb *messageBus) unsubscribe(sub *subscription) { + mb.mu.Lock() + defer mb.mu.Unlock() + + sub.active.Store(false) + + subs, ok := mb.topics[sub.Topic()] + if !ok { + return + } + + for i, s := range subs { + if s.ID() == sub.ID() { + mb.topics[sub.Topic()] = append(subs[:i], subs[i+1:]...) + close(s.ch) + return + } + } +} + +// Close shuts down the bus, unsubscribing all active subscriptions. 
+func (mb *messageBus) Close() error { + mb.mu.Lock() + defer mb.mu.Unlock() + + if mb.closed { + return nil + } + mb.closed = true + + for topic, subs := range mb.topics { + for _, sub := range subs { + sub.active.Store(false) + close(sub.ch) + } + delete(mb.topics, topic) + } + + return nil +} diff --git a/pkg/bus/bus_test.go b/pkg/bus/bus_test.go new file mode 100644 index 0000000..0d6f9a2 --- /dev/null +++ b/pkg/bus/bus_test.go @@ -0,0 +1,252 @@ +package bus + +import ( + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestNewBus(t *testing.T) { + b := New() + if b == nil { + t.Fatal("New() returned nil") + } +} + +func TestPublishSubscribe(t *testing.T) { + b := New() + defer b.Close() + + var received int32 + var wg sync.WaitGroup + wg.Add(1) + + sub, err := b.Subscribe("test", func(msg Message) { + atomic.AddInt32(&received, 1) + wg.Done() + }) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer sub.Unsubscribe() + + err = b.Publish("test", Message{ + ID: "msg-1", + Type: MsgTypeSystem, + From: "test", + }) + if err != nil { + t.Fatalf("Publish failed: %v", err) + } + + wg.Wait() + + if atomic.LoadInt32(&received) != 1 { + t.Errorf("expected 1 message, got %d", received) + } +} + +func TestPublishNoSubscribers(t *testing.T) { + b := New() + defer b.Close() + + err := b.Publish("nonexistent", Message{ID: "msg-1"}) + if err != nil { + t.Fatalf("Publish to nonexistent topic should not error: %v", err) + } +} + +func TestMultipleSubscribers(t *testing.T) { + b := New() + defer b.Close() + + var received int32 + var wg sync.WaitGroup + wg.Add(3) + + for i := 0; i < 3; i++ { + sub, err := b.Subscribe("multi", func(msg Message) { + atomic.AddInt32(&received, 1) + wg.Done() + }) + if err != nil { + t.Fatalf("Subscribe %d failed: %v", i, err) + } + defer sub.Unsubscribe() + } + + err := b.Publish("multi", Message{ID: "msg-1"}) + if err != nil { + t.Fatalf("Publish failed: %v", err) + } + + wg.Wait() + + if n := atomic.LoadInt32(&received); n 
!= 3 { + t.Errorf("expected 3 messages, got %d", n) + } +} + +func TestUnsubscribe(t *testing.T) { + b := New() + defer b.Close() + + var received int32 + sub, err := b.Subscribe("test", func(msg Message) { + atomic.AddInt32(&received, 1) + }) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + + // Publish before unsubscribe + b.Publish("test", Message{ID: "msg-1"}) + time.Sleep(50 * time.Millisecond) + + sub.Unsubscribe() + + // Publish after unsubscribe + b.Publish("test", Message{ID: "msg-2"}) + time.Sleep(50 * time.Millisecond) + + if n := atomic.LoadInt32(&received); n != 1 { + t.Errorf("expected 1 message after unsubscribe, got %d", n) + } +} + +func TestSubscribeAfterClose(t *testing.T) { + b := New() + b.Close() + + _, err := b.Subscribe("test", func(msg Message) {}) + if err == nil { + t.Error("expected error subscribing to closed bus") + } +} + +func TestPublishAfterClose(t *testing.T) { + b := New() + b.Close() + + err := b.Publish("test", Message{ID: "msg-1"}) + if err == nil { + t.Error("expected error publishing to closed bus") + } +} + +func TestSubscriptionID(t *testing.T) { + b := New() + defer b.Close() + + sub1, _ := b.Subscribe("a", func(msg Message) {}) + defer sub1.Unsubscribe() + sub2, _ := b.Subscribe("b", func(msg Message) {}) + defer sub2.Unsubscribe() + + if sub1.ID() == sub2.ID() { + t.Error("subscription IDs should be unique") + } + + if sub1.Topic() != "a" || sub2.Topic() != "b" { + t.Error("topic mismatch") + } +} + +func TestConcurrentPublish(t *testing.T) { + b := New() + defer b.Close() + + var received int32 + var wg sync.WaitGroup + wg.Add(100) + + sub, err := b.Subscribe("concurrent", func(msg Message) { + atomic.AddInt32(&received, 1) + wg.Done() + }) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer sub.Unsubscribe() + + for i := 0; i < 100; i++ { + go func(i int) { + b.Publish("concurrent", Message{ + ID: time.Now().String(), + Type: MsgTypeSystem, + }) + }(i) + } + + wg.Wait() + + if n := 
atomic.LoadInt32(&received); n != 100 { + t.Errorf("expected 100 messages, got %d", n) + } +} + +func TestDifferentTopics(t *testing.T) { + b := New() + defer b.Close() + + var topics []string + var mu sync.Mutex + + sub1, _ := b.Subscribe("topic-a", func(msg Message) { + mu.Lock() + topics = append(topics, "a") + mu.Unlock() + }) + defer sub1.Unsubscribe() + + sub2, _ := b.Subscribe("topic-b", func(msg Message) { + mu.Lock() + topics = append(topics, "b") + mu.Unlock() + }) + defer sub2.Unsubscribe() + + b.Publish("topic-a", Message{ID: "msg-1"}) + time.Sleep(50 * time.Millisecond) + + if len(topics) != 1 || topics[0] != "a" { + t.Errorf("expected only topic-a to receive message, got %v", topics) + } +} + +func TestCloseIdempotent(t *testing.T) { + b := New() + err1 := b.Close() + err2 := b.Close() + + if err1 != nil { + t.Fatalf("first Close failed: %v", err1) + } + if err2 != nil { + t.Fatalf("second Close should be idempotent: %v", err2) + } +} + +func TestMessageTypeString(t *testing.T) { + tests := []struct { + mt MessageType + want string + }{ + {MsgTypeSystem, "system"}, + {MsgTypeTaskRequest, "task_request"}, + {MsgTypeTaskResponse, "task_response"}, + {MsgTypeToolCall, "tool_call"}, + {MsgTypeToolResult, "tool_result"}, + {MsgTypeObservation, "observation"}, + {MsgTypeError, "error"}, + {MsgTypeLog, "log"}, + {MessageType(99), "unknown"}, + } + + for _, tt := range tests { + if got := tt.mt.String(); got != tt.want { + t.Errorf("MessageType(%d).String() = %q, want %q", tt.mt, got, tt.want) + } + } +} diff --git a/pkg/bus/types.go b/pkg/bus/types.go new file mode 100644 index 0000000..d663791 --- /dev/null +++ b/pkg/bus/types.go @@ -0,0 +1,67 @@ +// Package bus provides the message bus system for inter-component communication. +// +// The message bus is the central nervous system of the Orca framework. +// All components (kernel, plugins, agents) communicate through it +// via a publish/subscribe pattern over Go channels. 
+package bus + +import "time" + +// MessageType represents the category of a message in the bus system. +type MessageType int + +const ( + // MsgTypeSystem is for internal kernel messages. + MsgTypeSystem MessageType = iota + // MsgTypeTaskRequest is a request to perform a task. + MsgTypeTaskRequest + // MsgTypeTaskResponse is the result of a task. + MsgTypeTaskResponse + // MsgTypeToolCall is a request to invoke a tool. + MsgTypeToolCall + // MsgTypeToolResult is the result of a tool execution. + MsgTypeToolResult + // MsgTypeObservation is an observation from tool/command execution. + MsgTypeObservation + // MsgTypeError is an error message. + MsgTypeError + // MsgTypeLog is a log message for observability. + MsgTypeLog +) + +// String returns the human-readable name of the message type. +func (mt MessageType) String() string { + switch mt { + case MsgTypeSystem: + return "system" + case MsgTypeTaskRequest: + return "task_request" + case MsgTypeTaskResponse: + return "task_response" + case MsgTypeToolCall: + return "tool_call" + case MsgTypeToolResult: + return "tool_result" + case MsgTypeObservation: + return "observation" + case MsgTypeError: + return "error" + case MsgTypeLog: + return "log" + default: + return "unknown" + } +} + +// Message is the universal data unit in the message bus system. +// +// Every component communicates by publishing and subscribing to Messages. +type Message struct { + ID string `json:"id"` + Type MessageType `json:"type"` + From string `json:"from"` + To string `json:"to"` + Content interface{} `json:"content"` + Metadata map[string]string `json:"metadata,omitempty"` + Timestamp time.Time `json:"timestamp"` +} diff --git a/pkg/kernel/kernel.go b/pkg/kernel/kernel.go new file mode 100644 index 0000000..794639c --- /dev/null +++ b/pkg/kernel/kernel.go @@ -0,0 +1,392 @@ +// Package kernel implements the microkernel core of the Orca framework. 
+// +// The kernel is the minimal runtime that manages plugin lifecycle, +// message routing, and inter-component communication. +package kernel + +import ( + "context" + "fmt" + "log" + "os" + "sync" + "time" + + "github.com/orca/orca/internal/config" + "github.com/orca/orca/pkg/actor" + "github.com/orca/orca/pkg/bus" + "github.com/orca/orca/pkg/llm" + "github.com/orca/orca/pkg/plugin" + "github.com/orca/orca/pkg/session" + "github.com/orca/orca/pkg/skill" + "github.com/orca/orca/pkg/tool" +) + +// Kernel is the microkernel core of the Orca framework. +// +// It orchestrates plugin lifecycle, message routing, and inter-component +// communication. The kernel initializes and manages: +// - Message bus for inter-component communication +// - Plugin registry for extensibility +// - Session manager for conversation persistence +// - Tool manager with built-in tools +// - Skill manager for skill-based automation +// - Actor system with orchestrator, workers, and LLM agent +type Kernel struct { + mu sync.RWMutex + mb bus.MessageBus + registry *plugin.Registry + plugins []plugin.Plugin + started bool + + // Integration components + config *config.Config + sessionMgr *session.Manager + toolMgr *tool.Manager + skillMgr *skill.Manager + actorSystem *actor.System + orch *actor.Orchestrator + llmAgent *actor.LLMAgent + toolWorker *actor.ToolWorker +} + +// New creates a new Kernel instance with default configuration. +func New() *Kernel { + return NewWithConfig(config.DefaultConfig()) +} + +// NewWithConfig creates a new Kernel instance with the given configuration. 
+func NewWithConfig(cfg *config.Config) *Kernel { + if cfg == nil { + cfg = config.DefaultConfig() + } + + k := &Kernel{ + mb: bus.New(), + registry: plugin.NewRegistry(), + config: cfg, + actorSystem: actor.NewSystem(), + } + + // Initialize session manager + store, err := session.NewJSONLStore(cfg.Session.StorageDir) + if err != nil { + log.Printf("kernel: warning: failed to create session store: %v", err) + } else { + k.sessionMgr = session.NewManager(store, k.mb) + } + + // Initialize tool manager with all built-in tools + k.toolMgr = tool.NewManager() + k.registerBuiltinTools() + + // Initialize skill manager + k.skillMgr = skill.NewManager(cfg.Session.StorageDir + "/skills") + + // Initialize actor system + k.initializeActorSystem() + + return k +} + +// registerBuiltinTools registers all built-in tools with the tool manager. +func (k *Kernel) registerBuiltinTools() { + tools := []tool.Tool{ + tool.NewExecTool(nil), // exec - shell commands + tool.NewReadFileTool(), // read_file + tool.NewWriteFileTool(), // write_file + tool.NewListDirTool(), // list_dir + tool.NewSearchFilesTool(), // search_files + } + + for _, t := range tools { + if err := k.toolMgr.Register(t); err != nil { + log.Printf("kernel: warning: failed to register tool %q: %v", t.Name(), err) + } + } +} + +// initializeActorSystem sets up the orchestrator, tool worker, and LLM agent. 
+func (k *Kernel) initializeActorSystem() { + // Create orchestrator + orch, err := k.actorSystem.CreateOrchestrator(k) + if err != nil { + log.Printf("kernel: warning: failed to create orchestrator: %v", err) + return + } + k.orch = orch + + // Create tool worker + tw, err := k.actorSystem.CreateToolWorker(k.toolMgr) + if err != nil { + log.Printf("kernel: warning: failed to create tool worker: %v", err) + return + } + k.toolWorker = tw + + // Create LLM backend + ollama := k.createLLMBackend() + + // Create LLM agent + llmAgentID := fmt.Sprintf("llm-%d", len(k.actorSystem.ListAgents())+1) + llmOpts := []actor.LLMAgentOption{ + actor.WithToolManager(k.toolMgr), + actor.WithToolWorker(k.toolWorker), + actor.WithWindowSize(k.config.Session.MaxHistory), + } + + if k.sessionMgr != nil { + sessionID := "default" + if _, err := k.sessionMgr.GetSession(sessionID); err != nil { + k.sessionMgr.CreateSession(sessionID, map[string]string{ + "source": "kernel", + }) + } + + llmOpts = append(llmOpts, + actor.WithSessionManager(k.sessionMgr), + actor.WithSessionID(sessionID), + ) + } + + llmAgent := actor.NewLLMAgent(llmAgentID, ollama, llmOpts...) + k.llmAgent = llmAgent + + // Register LLM agent as orchestrator's worker + k.orch.AddWorker(llmAgent) + + // Also register tool worker as a fallback worker + k.orch.AddWorker(tw) +} + +// createLLMBackend creates the LLM backend based on configuration. 
+func (k *Kernel) createLLMBackend() llm.LLM { + baseURL := k.config.Ollama.BaseURL + model := k.config.Ollama.Model + timeout := k.config.Ollama.Timeout + + // Allow shorter env var names to override + if v := os.Getenv("OLLAMA_BASE_URL"); v != "" { + baseURL = v + } + if v := os.Getenv("OLLAMA_MODEL"); v != "" { + model = v + } + if v := os.Getenv("OLLAMA_TIMEOUT"); v != "" { + if d, err := time.ParseDuration(v); err == nil { + timeout = d + } + } + + client := llm.NewOllamaClient( + llm.WithBaseURL(baseURL), + llm.WithModel(model), + llm.WithTimeout(timeout), + ) + + log.Printf("kernel: created Ollama client (model=%s, url=%s)", model, baseURL) + return client +} + +// Bus returns the kernel's message bus. +func (k *Kernel) Bus() bus.MessageBus { + return k.mb +} + +// Registry returns the plugin registry. +func (k *Kernel) Registry() *plugin.Registry { + return k.registry +} + +// SessionManager returns the session manager. +func (k *Kernel) SessionManager() *session.Manager { + return k.sessionMgr +} + +// ToolManager returns the tool manager. +func (k *Kernel) ToolManager() *tool.Manager { + return k.toolMgr +} + +// SkillManager returns the skill manager. +func (k *Kernel) SkillManager() *skill.Manager { + return k.skillMgr +} + +// ActorSystem returns the actor system. +func (k *Kernel) ActorSystem() *actor.System { + return k.actorSystem +} + +// Orchestrator returns the orchestrator agent. +func (k *Kernel) Orchestrator() *actor.Orchestrator { + return k.orch +} + +// LLMAgent returns the LLM agent. +func (k *Kernel) LLMAgent() *actor.LLMAgent { + return k.llmAgent +} + +// SendMessage sends a message from a source to the LLM agent. +// +// This is the primary public API for interacting with the Orca system. +// It creates a task request message and sends it through the orchestrator +// to the LLM agent for processing. 
+// +// Parameters: +// - from: the sender identifier (e.g., "user", "cli") +// - to: the recipient (use "llm" for the LLM agent) +// - content: the message content (plain text) +// +// Returns the response content as a string, or an error. +func (k *Kernel) SendMessage(from, to, content string) (string, error) { + if !k.IsRunning() { + return "", fmt.Errorf("kernel: kernel is not running") + } + + if k.orch == nil { + return "", fmt.Errorf("kernel: orchestrator not initialized") + } + + // Create a task request message + msg := bus.Message{ + Type: bus.MsgTypeTaskRequest, + From: from, + To: to, + Content: content, + } + + // Send through the orchestrator + ctx := context.Background() + resp, err := k.orch.Process(ctx, msg) + if err != nil { + return "", fmt.Errorf("kernel: orchestrator processing failed: %w", err) + } + + // Extract response content + switch v := resp.Content.(type) { + case string: + return v, nil + default: + return fmt.Sprintf("%v", v), nil + } +} + +// InitPlugins loads and initializes skills from the skills directory. +func (k *Kernel) InitPlugins() error { + if k.skillMgr == nil { + return nil + } + + count, err := k.skillMgr.LoadAll() + if err != nil { + log.Printf("kernel: warning: skill loading had errors: %v", err) + } + if count > 0 { + log.Printf("kernel: loaded %d skills", count) + } + return nil +} + +// GetPlugin returns a registered plugin by name. +func (k *Kernel) GetPlugin(name string) (plugin.Plugin, bool) { + return k.registry.Get(name) +} + +// ListPlugins returns all currently registered plugins. +func (k *Kernel) ListPlugins() []plugin.Plugin { + return k.registry.List() +} + +// RegisterPlugin registers a plugin without starting it. 
+func (k *Kernel) RegisterPlugin(p plugin.Plugin) error { + k.mu.Lock() + defer k.mu.Unlock() + + if k.started { + return fmt.Errorf("kernel: cannot register plugin %q: kernel already started", p.Name()) + } + + return k.registry.Register(p) +} + +// UnregisterPlugin removes a plugin from the registry. +func (k *Kernel) UnregisterPlugin(name string) error { + k.mu.Lock() + defer k.mu.Unlock() + + return k.registry.Unregister(name) +} + +// Start initializes all registered plugins and marks the kernel as running. +func (k *Kernel) Start() error { + k.mu.Lock() + defer k.mu.Unlock() + + if k.started { + return fmt.Errorf("kernel: already started") + } + + k.started = true + + // Initialize plugins + plugins := k.registry.List() + k.plugins = make([]plugin.Plugin, 0, len(plugins)) + + for _, p := range plugins { + k.registry.SetState(p.Name(), plugin.StateInitialized) + if err := p.Init(k); err != nil { + log.Printf("kernel: warning: failed to init plugin %q: %v", p.Name(), err) + k.registry.SetState(p.Name(), plugin.StateError) + continue + } + k.registry.SetState(p.Name(), plugin.StateRunning) + k.plugins = append(k.plugins, p) + log.Printf("kernel: plugin %q (%s) initialized", p.Name(), p.Version()) + } + + log.Printf("kernel: started (tools=%d)", k.toolMgr.Count()) + + return nil +} + +// Stop gracefully shuts down the kernel. 
+func (k *Kernel) Stop() error { + k.mu.Lock() + defer k.mu.Unlock() + + if !k.started { + return nil + } + + // Stop actor system first + if k.actorSystem != nil { + if err := k.actorSystem.StopAll(); err != nil { + log.Printf("kernel: warning: error stopping actor system: %v", err) + } + } + + // Stop plugins + for i := len(k.plugins) - 1; i >= 0; i-- { + p := k.plugins[i] + k.registry.SetState(p.Name(), plugin.StateStopped) + if err := p.Shutdown(); err != nil { + log.Printf("kernel: warning: error shutting down plugin %q: %v", p.Name(), err) + continue + } + log.Printf("kernel: plugin %q shut down", p.Name()) + } + + k.plugins = nil + k.started = false + + return k.mb.Close() +} + +// IsRunning returns whether the kernel has been started and not yet stopped. +func (k *Kernel) IsRunning() bool { + k.mu.RLock() + defer k.mu.RUnlock() + return k.started +} diff --git a/pkg/kernel/kernel_test.go b/pkg/kernel/kernel_test.go new file mode 100644 index 0000000..df51dd8 --- /dev/null +++ b/pkg/kernel/kernel_test.go @@ -0,0 +1,343 @@ +package kernel + +import ( + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/orca/orca/pkg/bus" + "github.com/orca/orca/pkg/plugin" +) + +// testPlugin implements Plugin for kernel testing. 
+type testPlugin struct { + name string + version string + initFn func(host plugin.PluginHost) error + closeFn func() error +} + +func (p *testPlugin) Name() string { return p.name } +func (p *testPlugin) Version() string { return p.version } +func (p *testPlugin) Init(host plugin.PluginHost) error { + if p.initFn != nil { + return p.initFn(host) + } + return nil +} +func (p *testPlugin) Shutdown() error { + if p.closeFn != nil { + return p.closeFn() + } + return nil +} + +func TestNewKernel(t *testing.T) { + k := New() + if k == nil { + t.Fatal("New() returned nil") + } + if k.Bus() == nil { + t.Error("Bus() returned nil") + } + if k.Registry() == nil { + t.Error("Registry() returned nil") + } +} + +func TestKernelStartStop(t *testing.T) { + k := New() + + if err := k.Start(); err != nil { + t.Fatalf("Start failed: %v", err) + } + if !k.IsRunning() { + t.Error("expected kernel running after Start") + } + + if err := k.Stop(); err != nil { + t.Fatalf("Stop failed: %v", err) + } + if k.IsRunning() { + t.Error("expected kernel stopped after Stop") + } +} + +func TestKernelDoubleStart(t *testing.T) { + k := New() + k.Start() + err := k.Start() + if err == nil { + t.Error("expected error on double start") + } + k.Stop() +} + +func TestKernelRegisterPlugin(t *testing.T) { + k := New() + p := &testPlugin{name: "test", version: "1.0.0"} + + err := k.RegisterPlugin(p) + if err != nil { + t.Fatalf("RegisterPlugin failed: %v", err) + } + + got, ok := k.GetPlugin("test") + if !ok { + t.Fatal("GetPlugin returned not found") + } + if got.Name() != "test" { + t.Errorf("expected name 'test', got %q", got.Name()) + } +} + +func TestKernelRegisterPluginAfterStart(t *testing.T) { + k := New() + k.Start() + defer k.Stop() + + err := k.RegisterPlugin(&testPlugin{name: "test", version: "1.0.0"}) + if err == nil { + t.Error("expected error registering plugin after start") + } +} + +func TestKernelPluginLifecycle(t *testing.T) { + k := New() + + var initCount int32 + var shutdownCount 
int32 + + p := &testPlugin{ + name: "lifecycle", + version: "1.0.0", + initFn: func(host plugin.PluginHost) error { + atomic.AddInt32(&initCount, 1) + return nil + }, + closeFn: func() error { + atomic.AddInt32(&shutdownCount, 1) + return nil + }, + } + + k.RegisterPlugin(p) + k.Start() + + if n := atomic.LoadInt32(&initCount); n != 1 { + t.Errorf("expected init called once, got %d", n) + } + + k.Stop() + + if n := atomic.LoadInt32(&shutdownCount); n != 1 { + t.Errorf("expected shutdown called once, got %d", n) + } +} + +func TestKernelPluginInitFailure(t *testing.T) { + k := New() + + p := &testPlugin{ + name: "failing", + version: "1.0.0", + initFn: func(host plugin.PluginHost) error { + return errors.New("init failed") + }, + } + + k.RegisterPlugin(p) + + // Init failure should not prevent Start from succeeding (graceful degradation) + err := k.Start() + if err != nil { + t.Fatalf("Start should succeed even with failing plugin: %v", err) + } + k.Stop() +} + +func TestKernelPluginShutdownFailure(t *testing.T) { + k := New() + + p := &testPlugin{ + name: "failing-shutdown", + version: "1.0.0", + initFn: func(host plugin.PluginHost) error { + return nil + }, + closeFn: func() error { + return errors.New("shutdown failed") + }, + } + + k.RegisterPlugin(p) + k.Start() + + // Shutdown failure should not prevent Stop from succeeding + err := k.Stop() + if err != nil { + t.Fatalf("Stop should succeed even with failing plugin shutdown: %v", err) + } +} + +func TestKernelMultiplePlugins(t *testing.T) { + k := New() + + names := []string{"alpha", "beta", "gamma"} + for _, name := range names { + k.RegisterPlugin(&testPlugin{name: name, version: "1.0.0"}) + } + + k.Start() + + plugins := k.ListPlugins() + if len(plugins) != len(names) { + t.Errorf("expected %d plugins, got %d", len(names), len(plugins)) + } + + k.Stop() +} + +func TestKernelUnregisterPlugin(t *testing.T) { + k := New() + + k.RegisterPlugin(&testPlugin{name: "remove-me", version: "1.0.0"}) + + err := 
k.UnregisterPlugin("remove-me") + if err != nil { + t.Fatalf("UnregisterPlugin failed: %v", err) + } + + _, ok := k.GetPlugin("remove-me") + if ok { + t.Error("plugin should not exist after unregister") + } +} + +func TestKernelStopWithoutStart(t *testing.T) { + k := New() + err := k.Stop() + if err != nil { + t.Fatalf("Stop without Start should be a no-op: %v", err) + } +} + +func TestKernelPluginReceivesHost(t *testing.T) { + k := New() + + var gotHost plugin.PluginHost + p := &testPlugin{ + name: "host-check", + version: "1.0.0", + initFn: func(host plugin.PluginHost) error { + gotHost = host + return nil + }, + } + + k.RegisterPlugin(p) + k.Start() + + if gotHost == nil { + t.Fatal("plugin did not receive PluginHost") + } + + // Verify the host can access bus + if gotHost.Bus() == nil { + t.Error("PluginHost.Bus() returned nil") + } + + // Verify plugin discovery through host + p2, ok := gotHost.GetPlugin("host-check") + if !ok { + t.Error("PluginHost.GetPlugin should find itself") + } + if p2.Name() != "host-check" { + t.Errorf("expected name 'host-check', got %q", p2.Name()) + } + + k.Stop() +} + +func TestKernelAllPluginsInitialized(t *testing.T) { + k := New() + + names := []string{"a", "b", "c"} + initialized := make(map[string]bool) + + for _, name := range names { + n := name + k.RegisterPlugin(&testPlugin{ + name: n, + initFn: func(host plugin.PluginHost) error { + initialized[n] = true + return nil + }, + }) + } + + k.Start() + + for _, name := range names { + if !initialized[name] { + t.Errorf("plugin %q was not initialized", name) + } + } + + k.Stop() +} + +func TestKernelShutdownAllPlugins(t *testing.T) { + k := New() + + names := []string{"x", "y", "z"} + shutdown := make(map[string]bool) + + for _, name := range names { + n := name + k.RegisterPlugin(&testPlugin{ + name: n, + initFn: func(host plugin.PluginHost) error { + return nil + }, + closeFn: func() error { + shutdown[n] = true + return nil + }, + }) + } + + k.Start() + k.Stop() + + for _, 
name := range names { + if !shutdown[name] { + t.Errorf("plugin %q was not shut down", name) + } + } +} + +func TestKernelMessageBusIntegration(t *testing.T) { + k := New() + k.Start() + defer k.Stop() + + mb := k.Bus() + + var received int32 + sub, err := mb.Subscribe("kernel-test", func(msg bus.Message) { + atomic.AddInt32(&received, 1) + }) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer sub.Unsubscribe() + + mb.Publish("kernel-test", bus.Message{ID: "test-msg"}) + time.Sleep(50 * time.Millisecond) + + if n := atomic.LoadInt32(&received); n != 1 { + t.Errorf("expected 1 message via kernel bus, got %d", n) + } +} diff --git a/pkg/llm/llm.go b/pkg/llm/llm.go new file mode 100644 index 0000000..f49963d --- /dev/null +++ b/pkg/llm/llm.go @@ -0,0 +1,24 @@ +// Package llm provides the LLM integration layer for the Orca framework. +// +// It defines the LLM interface for interacting with language models, +// the Ollama client implementation, and the shared types for chat +// messages, tool calls, and streaming responses. +package llm + +import "context" + +// LLM is the interface for interacting with language models. +// +// Implementations provide Chat (for complete responses) and Stream +// (for streaming token-by-token responses) methods. Both methods +// accept a list of messages and return the model's response. +type LLM interface { + // Chat sends a list of messages to the LLM and returns a complete response. + // If the model decides to call tools, the response contains ToolCalls. + Chat(ctx context.Context, messages []Message) (*Response, error) + + // Stream sends messages and streams the response token-by-token. + // The handler is called for each chunk. The final response is not + // collected; use Chat for complete responses. 
+ Stream(ctx context.Context, messages []Message, handler StreamHandler) error +} diff --git a/pkg/llm/ollama.go b/pkg/llm/ollama.go new file mode 100644 index 0000000..cf0aa41 --- /dev/null +++ b/pkg/llm/ollama.go @@ -0,0 +1,301 @@ +package llm + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// OllamaClient implements the LLM interface for Ollama's API. +// +// It communicates with a running Ollama server via its REST API. +// Supports chat, streaming, tool calling (function calling), and +// embedding generation. +type OllamaClient struct { + baseURL string + model string + httpClient *http.Client +} + +// OllamaOption is a functional option for configuring the OllamaClient. +type OllamaOption func(*OllamaClient) + +// WithBaseURL sets the Ollama server base URL. +func WithBaseURL(url string) OllamaOption { + return func(c *OllamaClient) { + c.baseURL = strings.TrimRight(url, "/") + } +} + +// WithModel sets the default model name. +func WithModel(model string) OllamaOption { + return func(c *OllamaClient) { + c.model = model + } +} + +// WithTimeout sets the HTTP client timeout. +func WithTimeout(timeout time.Duration) OllamaOption { + return func(c *OllamaClient) { + c.httpClient.Timeout = timeout + } +} + +// WithHTTPClient sets a custom HTTP client. +func WithHTTPClient(client *http.Client) OllamaOption { + return func(c *OllamaClient) { + c.httpClient = client + } +} + +// NewOllamaClient creates a new OllamaClient with the given options. +// +// Default values: +// - BaseURL: http://localhost:11434 +// - Model: gemma4:e4b +// - Timeout: 30s +func NewOllamaClient(opts ...OllamaOption) *OllamaClient { + c := &OllamaClient{ + baseURL: "http://localhost:11434", + model: "gemma4:e4b", + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } + + for _, opt := range opts { + opt(c) + } + + return c +} + +// Chat sends a chat request to Ollama and returns the complete response. 
+// If the Ollama model returns tool calls, they are parsed and included +// in the Response. +func (c *OllamaClient) Chat(ctx context.Context, messages []Message) (*Response, error) { + req := OllamaChatRequest{ + Model: c.model, + Messages: messages, + Stream: false, + } + + // Build tool definitions if tool package is integrated + // (tools are added externally via BuildToolDefs) + + body, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("ollama: failed to marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, + c.baseURL+"/api/chat", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("ollama: failed to create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("ollama: request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("ollama: API error (status %d): %s", + resp.StatusCode, string(respBody)) + } + + var rawResp json.RawMessage + if err := json.NewDecoder(resp.Body).Decode(&rawResp); err != nil { + return nil, fmt.Errorf("ollama: failed to decode response: %w", err) + } + + return parseOllamaResponse(rawResp) +} + +// parseOllamaResponse attempts to parse the Ollama API response, +// handling both regular text responses and tool call responses. 
+func parseOllamaResponse(raw json.RawMessage) (*Response, error) { + // Try as tool call response first (has message.tool_calls) + var toolResp OllamaToolCallResponse + if err := json.Unmarshal(raw, &toolResp); err == nil && len(toolResp.Message.ToolCalls) > 0 { + return &Response{ + Content: toolResp.Message.Content, + ToolCalls: toolResp.Message.ToolCalls, + }, nil + } + + // Try as regular response + var chatResp OllamaChatResponse + if err := json.Unmarshal(raw, &chatResp); err != nil { + return nil, fmt.Errorf("ollama: failed to parse response: %w", err) + } + + return &Response{ + Content: chatResp.Message.Content, + }, nil +} + +// Stream sends a chat request to Ollama with streaming enabled. +// The handler receives each content chunk as it arrives. +func (c *OllamaClient) Stream(ctx context.Context, messages []Message, handler StreamHandler) error { + req := OllamaChatRequest{ + Model: c.model, + Messages: messages, + Stream: true, + } + + body, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("ollama: failed to marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, + c.baseURL+"/api/chat", bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("ollama: failed to create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return fmt.Errorf("ollama: stream request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("ollama: stream API error (status %d): %s", + resp.StatusCode, string(respBody)) + } + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if line == "" { + continue + } + + // Each line is a JSON object: {"model":"...","created_at":"...","message":{"role":"assistant","content":"..."},"done":false} + var streamResp OllamaChatResponse + if err := 
json.Unmarshal([]byte(line), &streamResp); err != nil { + continue // Skip malformed lines + } + + if streamResp.Message.Content != "" { + if err := handler(streamResp.Message.Content); err != nil { + return err + } + } + + if streamResp.Done { + break + } + } + + return scanner.Err() +} + +// Embed generates an embedding vector for the given input text. +func (c *OllamaClient) Embed(ctx context.Context, input string) (*EmbeddingResponse, error) { + req := OllamaEmbedRequest{ + Model: c.model, + Input: input, + } + + body, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("ollama: failed to marshal embed request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, + c.baseURL+"/api/embed", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("ollama: failed to create embed request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("ollama: embed request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("ollama: embed API error (status %d): %s", + resp.StatusCode, string(respBody)) + } + + var apiResp OllamaEmbedResponse + if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil { + return nil, fmt.Errorf("ollama: failed to decode embed response: %w", err) + } + + return &EmbeddingResponse{ + Embedding: apiResp.Embedding, + }, nil +} + +// BuildToolDefsFromMap converts a generic tool definition map into Ollama ToolDefs. +// This is used to bridge the tool package's Tool interface with Ollama's API format. 
+func BuildToolDefsFromMap(tools []map[string]interface{}) []ToolDef { + var defs []ToolDef + for _, t := range tools { + name, _ := t["name"].(string) + desc, _ := t["description"].(string) + + def := ToolDef{ + Type: "function", + Function: ToolFunction{ + Name: name, + Description: desc, + Parameters: ToolFunctionParameters{ + Type: "object", + Properties: make(map[string]ToolProperty), + }, + }, + } + + if params, ok := t["parameters"].(map[string]interface{}); ok { + if props, ok := params["properties"].(map[string]interface{}); ok { + for key, val := range props { + if p, ok := val.(map[string]interface{}); ok { + prop := ToolProperty{ + Type: toString(p["type"]), + Description: toString(p["description"]), + } + def.Function.Parameters.Properties[key] = prop + if isRequired(p) { + def.Function.Parameters.Required = append(def.Function.Parameters.Required, key) + } + } + } + } + } + + defs = append(defs, def) + } + return defs +} + +func toString(v interface{}) string { + if v == nil { + return "" + } + s, _ := v.(string) + return s +} + +func isRequired(p map[string]interface{}) bool { + req, _ := p["required"].(bool) + return req +} diff --git a/pkg/llm/ollama_test.go b/pkg/llm/ollama_test.go new file mode 100644 index 0000000..e33f033 --- /dev/null +++ b/pkg/llm/ollama_test.go @@ -0,0 +1,385 @@ +package llm + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// ============================================================ +// Helper: create a mock Ollama server +// ============================================================ + +// mockOllamaHandler returns an http.Handler that simulates the Ollama API. 
// mockOllamaHandler starts an httptest.Server that simulates the Ollama API.
//
// Only /api/chat and /api/embed are accepted. The JSON request body is
// decoded into a map and passed to responseFunc, whose status code and
// value are written back as a JSON response. The caller must Close the
// returned server.
//
// Problems are reported with t.Errorf, never t.Fatalf: the handler runs on
// the server's goroutine, and FailNow may only be called from the goroutine
// running the test.
func mockOllamaHandler(t *testing.T, responseFunc func(reqBody map[string]interface{}) (int, interface{})) *httptest.Server {
	t.Helper()
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Verify request path
		if r.URL.Path != "/api/chat" && r.URL.Path != "/api/embed" {
			t.Errorf("unexpected path: %s", r.URL.Path)
			w.WriteHeader(http.StatusNotFound)
			return
		}

		// Decode request body
		var reqBody map[string]interface{}
		if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
			t.Errorf("failed to decode request body: %v", err)
			http.Error(w, "invalid JSON request body", http.StatusBadRequest)
			return
		}

		status, resp := responseFunc(reqBody)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("failed to encode response: %v", err)
		}
	}))
}
map[string]interface{}) (int, interface{}) { + // Verify the request has the expected shape + if model, ok := reqBody["model"]; !ok || model != "gemma4:e4b" { + t.Errorf("expected model 'gemma4:e4b', got %v", model) + } + if stream, ok := reqBody["stream"]; !ok || stream != false { + t.Errorf("expected stream false, got %v", stream) + } + + return http.StatusOK, OllamaChatResponse{ + Model: "gemma4:e4b", + Message: Message{ + Role: "assistant", + Content: "Hello! How can I help you?", + }, + Done: true, + } + }) + defer srv.Close() + + client := NewOllamaClient(WithBaseURL(srv.URL)) + resp, err := client.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }) + if err != nil { + t.Fatalf("Chat failed: %v", err) + } + if resp.Content != "Hello! How can I help you?" { + t.Errorf("expected content 'Hello! How can I help you?', got %q", resp.Content) + } + if len(resp.ToolCalls) != 0 { + t.Errorf("expected no tool calls, got %d", len(resp.ToolCalls)) + } +} + +func TestChatWithToolCalls(t *testing.T) { + srv := mockOllamaHandler(t, func(reqBody map[string]interface{}) (int, interface{}) { + return http.StatusOK, OllamaToolCallResponse{ + Model: "gemma4:e4b", + Message: OllamaToolMsg{ + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + ID: "call-1", + Type: "function", + Function: FunctionCall{ + Name: "exec", + Arguments: `{"command":"ls -la"}`, + }, + }, + }, + }, + Done: true, + } + }) + defer srv.Close() + + client := NewOllamaClient(WithBaseURL(srv.URL)) + resp, err := client.Chat(context.Background(), []Message{ + {Role: "user", Content: "List files"}, + }) + if err != nil { + t.Fatalf("Chat failed: %v", err) + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(resp.ToolCalls)) + } + if resp.ToolCalls[0].Function.Name != "exec" { + t.Errorf("expected tool name 'exec', got %q", resp.ToolCalls[0].Function.Name) + } + if resp.ToolCalls[0].Function.Arguments != `{"command":"ls -la"}` { + 
t.Errorf("unexpected arguments: %q", resp.ToolCalls[0].Function.Arguments) + } +} + +func TestChatAPIError(t *testing.T) { + srv := mockOllamaHandler(t, func(reqBody map[string]interface{}) (int, interface{}) { + return http.StatusInternalServerError, map[string]string{"error": "internal error"} + }) + defer srv.Close() + + client := NewOllamaClient(WithBaseURL(srv.URL)) + _, err := client.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }) + if err == nil { + t.Fatal("expected error for API error response") + } + if !strings.Contains(err.Error(), "500") { + t.Errorf("expected error to contain status code, got: %v", err) + } +} + +func TestChatContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + client := NewOllamaClient(WithBaseURL("http://localhost:11434")) + _, err := client.Chat(ctx, []Message{{Role: "user", Content: "Hello"}}) + if err == nil { + t.Error("expected error for cancelled context") + } +} + +// ============================================================ +// Stream Tests +// ============================================================ + +func TestStream(t *testing.T) { + chunks := []string{"Hello", "!", " How", " can", " I", " help?"} + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/chat" { + t.Errorf("unexpected path: %s", r.URL.Path) + w.WriteHeader(http.StatusNotFound) + return + } + + w.Header().Set("Content-Type", "application/x-ndjson") + w.WriteHeader(http.StatusOK) + + flusher, ok := w.(http.Flusher) + if !ok { + t.Fatal("expected http.Flusher") + } + + for _, chunk := range chunks { + resp := OllamaChatResponse{ + Model: "gemma4:e4b", + Message: Message{ + Role: "assistant", + Content: chunk, + }, + Done: false, + } + data, _ := json.Marshal(resp) + w.Write(append(data, '\n')) + flusher.Flush() + } + + // Send done signal + doneResp := OllamaChatResponse{ + Model: 
"gemma4:e4b", + Message: Message{ + Role: "assistant", + Content: "", + }, + Done: true, + } + data, _ := json.Marshal(doneResp) + w.Write(append(data, '\n')) + flusher.Flush() + })) + defer srv.Close() + + client := NewOllamaClient(WithBaseURL(srv.URL)) + + var received []string + err := client.Stream(context.Background(), []Message{{Role: "user", Content: "Hi"}}, + func(chunk string) error { + received = append(received, chunk) + return nil + }) + if err != nil { + t.Fatalf("Stream failed: %v", err) + } + + if len(received) != len(chunks) { + t.Errorf("expected %d chunks, got %d", len(chunks), len(received)) + } +} + +func TestStreamHandlerError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/x-ndjson") + w.WriteHeader(http.StatusOK) + + resp := OllamaChatResponse{ + Model: "gemma4:e4b", + Message: Message{ + Role: "assistant", + Content: "chunk", + }, + Done: false, + } + data, _ := json.Marshal(resp) + w.Write(append(data, '\n')) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + })) + defer srv.Close() + + client := NewOllamaClient(WithBaseURL(srv.URL)) + err := client.Stream(context.Background(), []Message{{Role: "user", Content: "Hi"}}, + func(chunk string) error { + return &streamError{msg: "handler error"} + }) + if err == nil { + t.Fatal("expected error from handler") + } + if !strings.Contains(err.Error(), "handler error") { + t.Errorf("expected 'handler error', got: %v", err) + } +} + +type streamError struct{ msg string } + +func (e *streamError) Error() string { return e.msg } + +// ============================================================ +// Embed Tests +// ============================================================ + +func TestEmbed(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/embed" { + t.Errorf("unexpected path: %s", r.URL.Path) + 
w.WriteHeader(http.StatusNotFound) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(OllamaEmbedResponse{ + Embedding: []float64{0.1, 0.2, 0.3, 0.4, 0.5}, + }) + })) + defer srv.Close() + + client := NewOllamaClient(WithBaseURL(srv.URL)) + resp, err := client.Embed(context.Background(), "test input") + if err != nil { + t.Fatalf("Embed failed: %v", err) + } + if len(resp.Embedding) != 5 { + t.Errorf("expected 5 embedding values, got %d", len(resp.Embedding)) + } + if resp.Embedding[0] != 0.1 { + t.Errorf("expected first value 0.1, got %f", resp.Embedding[0]) + } +} + +// ============================================================ +// Tool Def Builder Tests +// ============================================================ + +func TestBuildToolDefsFromMap(t *testing.T) { + tools := []map[string]interface{}{ + { + "name": "exec", + "description": "Execute a shell command", + "parameters": map[string]interface{}{ + "properties": map[string]interface{}{ + "command": map[string]interface{}{ + "type": "string", + "description": "Command to run", + "required": true, + }, + "timeout": map[string]interface{}{ + "type": "number", + "description": "Timeout in seconds", + "required": false, + }, + }, + }, + }, + } + + defs := BuildToolDefsFromMap(tools) + if len(defs) != 1 { + t.Fatalf("expected 1 tool def, got %d", len(defs)) + } + if defs[0].Function.Name != "exec" { + t.Errorf("expected name 'exec', got %q", defs[0].Function.Name) + } + if len(defs[0].Function.Parameters.Required) != 1 { + t.Errorf("expected 1 required parameter, got %d", len(defs[0].Function.Parameters.Required)) + } + if defs[0].Function.Parameters.Required[0] != "command" { + t.Errorf("expected required 'command', got %q", defs[0].Function.Parameters.Required[0]) + } +} + +// ============================================================ +// Mock LLM (for use by other tests) +// 
============================================================ + +// MockLLM is a configurable mock implementation of the LLM interface for testing. +type MockLLM struct { + ChatFunc func(ctx context.Context, messages []Message) (*Response, error) + StreamFunc func(ctx context.Context, messages []Message, handler StreamHandler) error +} + +func (m *MockLLM) Chat(ctx context.Context, messages []Message) (*Response, error) { + if m.ChatFunc != nil { + return m.ChatFunc(ctx, messages) + } + return &Response{Content: "mock response"}, nil +} + +func (m *MockLLM) Stream(ctx context.Context, messages []Message, handler StreamHandler) error { + if m.StreamFunc != nil { + return m.StreamFunc(ctx, messages, handler) + } + return handler("mock stream response") +} diff --git a/pkg/llm/types.go b/pkg/llm/types.go new file mode 100644 index 0000000..223b970 --- /dev/null +++ b/pkg/llm/types.go @@ -0,0 +1,128 @@ +// Package llm provides the LLM integration layer for the Orca framework. +// +// It defines the LLM interface for interacting with language models, +// the Ollama client implementation, and the shared types for chat +// messages, tool calls, and streaming responses. +package llm + +// Message represents a single message in a chat conversation. +// +// The Role field identifies the sender: "user", "assistant", "system", +// or "tool". For tool results, ToolCallID links the result back to the +// tool call that produced it. +type Message struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +// ToolCall represents a function calling request from the LLM. +// +// Ollama returns tool calls in the format expected by OpenAI-compatible +// APIs: each call has a unique ID, a type ("function"), and a Function +// object containing the tool name and JSON-encoded arguments. 
// ToolCall is one tool invocation requested by the model: an identifier, a
// type string and the function to run.
type ToolCall struct {
	ID       string       `json:"id"`       // identifier of this call
	Type     string       `json:"type"`     // kind of call, e.g. "function"
	Function FunctionCall `json:"function"` // tool name and arguments
}

// FunctionCall holds the name and arguments for a tool invocation.
//
// Arguments is a raw JSON string that should be unmarshalled into
// the tool's expected argument shape.
//
// NOTE(review): the field is a Go string, so decoding only succeeds when
// the wire value of "arguments" is a JSON string — confirm against the
// server's actual tool-call format.
type FunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"` // JSON-encoded arguments
}

// Response represents a complete (non-streaming) response from an LLM.
//
// If the LLM decides to invoke tools, Content may be empty and ToolCalls
// will contain one or more entries. The caller should execute each tool
// call and feed the results back to the LLM.
type Response struct {
	Content   string     `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"` // omitted from JSON when empty
}

// StreamHandler is a callback function for processing streaming response chunks.
//
// Each chunk is a partial string of the ongoing response. The handler is
// called sequentially for each chunk. Returning an error stops the stream.
type StreamHandler func(chunk string) error

// EmbeddingResponse holds the result of an embedding request.
type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"` // the embedding vector
}

// OllamaChatRequest is the request body sent to Ollama's /api/chat endpoint.
type OllamaChatRequest struct {
	Model    string    `json:"model"`           // model name
	Messages []Message `json:"messages"`        // conversation so far
	Stream   bool      `json:"stream"`          // true requests a streamed reply
	Tools    []ToolDef `json:"tools,omitempty"` // tool definitions; omitted when empty
}

// OllamaChatResponse is the response body from Ollama's /api/chat endpoint.
type OllamaChatResponse struct {
	Model     string  `json:"model"`
	CreatedAt string  `json:"created_at"` // kept as the raw string from the wire
	Message   Message `json:"message"`
	Done      bool    `json:"done"` // true on the final object of a reply
}
// OllamaToolCallResponse mirrors OllamaChatResponse, except that its
// message can carry tool calls.
type OllamaToolCallResponse struct {
	Model     string        `json:"model"`
	CreatedAt string        `json:"created_at"` // kept as the raw string from the wire
	Message   OllamaToolMsg `json:"message"`
	Done      bool          `json:"done"`
}

// OllamaToolMsg wraps the tool_calls field in Ollama's response.
type OllamaToolMsg struct {
	Role      string     `json:"role"`
	Content   string     `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"` // omitted from JSON when empty
}

// ToolDef describes a tool that the LLM may call, in the format
// expected by Ollama's function calling API.
type ToolDef struct {
	Type     string       `json:"type"` // e.g. "function"
	Function ToolFunction `json:"function"`
}

// ToolFunction describes a function available to the LLM.
type ToolFunction struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description"`
	Parameters  ToolFunctionParameters `json:"parameters"`
}

// ToolFunctionParameters is the JSON Schema for a tool's parameters.
type ToolFunctionParameters struct {
	Type       string                  `json:"type"`               // schema type, e.g. "object"
	Required   []string                `json:"required,omitempty"` // names of mandatory parameters
	Properties map[string]ToolProperty `json:"properties"`         // parameter name -> description
}

// ToolProperty describes a single parameter of a tool function.
type ToolProperty struct {
	Type        string   `json:"type"`
	Description string   `json:"description"`
	Enum        []string `json:"enum,omitempty"` // allowed values; omitted when empty
}

// OllamaEmbedRequest is the request body for Ollama's /api/embed endpoint.
type OllamaEmbedRequest struct {
	Model string `json:"model"`
	Input string `json:"input"` // text to embed
}

// OllamaEmbedResponse is the response body from Ollama's /api/embed endpoint.
//
// NOTE(review): Ollama's /api/embed documentation shows the vectors under a
// plural "embeddings" key (one vector per input); confirm that this singular
// "embedding" field is what the targeted server actually returns.
type OllamaEmbedResponse struct {
	Embedding []float64 `json:"embedding"`
}
// PluginState identifies where a plugin currently is in its lifecycle.
type PluginState int

// Lifecycle states. StateUnknown is the zero value.
const (
	StateUnknown PluginState = iota
	StateRegistered
	StateInitialized
	StateRunning
	StateStopped
	StateError
)

// String returns the lower-case name of the state. Values outside the
// declared range are reported as "unknown".
func (ps PluginState) String() string {
	// Indexed by state value; keep in step with the const block above.
	names := [...]string{
		StateUnknown:     "unknown",
		StateRegistered:  "registered",
		StateInitialized: "initialized",
		StateRunning:     "running",
		StateStopped:     "stopped",
		StateError:       "error",
	}
	if ps < 0 || int(ps) >= len(names) {
		return "unknown"
	}
	return names[ps]
}
+func (r *Registry) Register(p Plugin) error { + r.mu.Lock() + defer r.mu.Unlock() + + name := p.Name() + if _, exists := r.plugins[name]; exists { + return fmt.Errorf("plugin %q is already registered", name) + } + + r.plugins[name] = p + r.states[name] = StateRegistered + return nil +} + +// Unregister removes a plugin from the registry. +func (r *Registry) Unregister(name string) error { + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.plugins[name]; !exists { + return fmt.Errorf("plugin %q is not registered", name) + } + + delete(r.plugins, name) + delete(r.states, name) + return nil +} + +// Get retrieves a plugin by name. +func (r *Registry) Get(name string) (Plugin, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + p, ok := r.plugins[name] + return p, ok +} + +// List returns all registered plugins. +func (r *Registry) List() []Plugin { + r.mu.RLock() + defer r.mu.RUnlock() + + plugins := make([]Plugin, 0, len(r.plugins)) + for _, p := range r.plugins { + plugins = append(plugins, p) + } + return plugins +} + +// State returns the lifecycle state of a registered plugin. +func (r *Registry) State(name string) PluginState { + r.mu.RLock() + defer r.mu.RUnlock() + + if state, ok := r.states[name]; ok { + return state + } + return StateUnknown +} + +// SetState updates the lifecycle state of a registered plugin. +func (r *Registry) SetState(name string, state PluginState) { + r.mu.Lock() + defer r.mu.Unlock() + + if _, ok := r.plugins[name]; ok { + r.states[name] = state + } +} + +// Count returns the number of registered plugins. +func (r *Registry) Count() int { + r.mu.RLock() + defer r.mu.RUnlock() + + return len(r.plugins) +} diff --git a/pkg/plugin/registry_test.go b/pkg/plugin/registry_test.go new file mode 100644 index 0000000..0961885 --- /dev/null +++ b/pkg/plugin/registry_test.go @@ -0,0 +1,256 @@ +package plugin + +import ( + "errors" + "testing" + + "github.com/orca/orca/pkg/bus" +) + +// mockPlugin implements Plugin for testing. 
+type mockPlugin struct { + name string + version string + initFn func(host PluginHost) error + closeFn func() error +} + +func (m *mockPlugin) Name() string { return m.name } +func (m *mockPlugin) Version() string { return m.version } +func (m *mockPlugin) Init(host PluginHost) error { + if m.initFn != nil { + return m.initFn(host) + } + return nil +} +func (m *mockPlugin) Shutdown() error { + if m.closeFn != nil { + return m.closeFn() + } + return nil +} + +func TestRegistryNew(t *testing.T) { + r := NewRegistry() + if r == nil { + t.Fatal("NewRegistry() returned nil") + } + if n := r.Count(); n != 0 { + t.Errorf("expected empty registry, got %d plugins", n) + } +} + +func TestRegistryRegister(t *testing.T) { + r := NewRegistry() + p := &mockPlugin{name: "test", version: "1.0.0"} + + err := r.Register(p) + if err != nil { + t.Fatalf("Register failed: %v", err) + } + + if n := r.Count(); n != 1 { + t.Errorf("expected 1 plugin, got %d", n) + } +} + +func TestRegistryRegisterDuplicate(t *testing.T) { + r := NewRegistry() + p1 := &mockPlugin{name: "test", version: "1.0.0"} + p2 := &mockPlugin{name: "test", version: "2.0.0"} + + r.Register(p1) + err := r.Register(p2) + if err == nil { + t.Error("expected error registering duplicate plugin") + } +} + +func TestRegistryGet(t *testing.T) { + r := NewRegistry() + p := &mockPlugin{name: "test", version: "1.0.0"} + r.Register(p) + + got, ok := r.Get("test") + if !ok { + t.Fatal("Get returned not found") + } + if got.Name() != "test" { + t.Errorf("expected name 'test', got %q", got.Name()) + } +} + +func TestRegistryGetNotFound(t *testing.T) { + r := NewRegistry() + _, ok := r.Get("nonexistent") + if ok { + t.Error("expected false for nonexistent plugin") + } +} + +func TestRegistryUnregister(t *testing.T) { + r := NewRegistry() + p := &mockPlugin{name: "test", version: "1.0.0"} + r.Register(p) + + err := r.Unregister("test") + if err != nil { + t.Fatalf("Unregister failed: %v", err) + } + + if n := r.Count(); n != 0 { + 
t.Errorf("expected 0 plugins, got %d", n) + } +} + +func TestRegistryUnregisterNotFound(t *testing.T) { + r := NewRegistry() + err := r.Unregister("nonexistent") + if err == nil { + t.Error("expected error unregistering nonexistent plugin") + } +} + +func TestRegistryList(t *testing.T) { + r := NewRegistry() + r.Register(&mockPlugin{name: "a", version: "1.0.0"}) + r.Register(&mockPlugin{name: "b", version: "1.0.0"}) + r.Register(&mockPlugin{name: "c", version: "1.0.0"}) + + plugins := r.List() + if len(plugins) != 3 { + t.Errorf("expected 3 plugins, got %d", len(plugins)) + } + + names := make(map[string]bool) + for _, p := range plugins { + names[p.Name()] = true + } + + for _, n := range []string{"a", "b", "c"} { + if !names[n] { + t.Errorf("missing plugin %q in list", n) + } + } +} + +func TestRegistryState(t *testing.T) { + r := NewRegistry() + p := &mockPlugin{name: "test", version: "1.0.0"} + r.Register(p) + + if s := r.State("test"); s != StateRegistered { + t.Errorf("expected StateRegistered, got %s", s) + } + + r.SetState("test", StateRunning) + if s := r.State("test"); s != StateRunning { + t.Errorf("expected StateRunning, got %s", s) + } +} + +func TestRegistryStateUnknown(t *testing.T) { + r := NewRegistry() + if s := r.State("nonexistent"); s != StateUnknown { + t.Errorf("expected StateUnknown for nonexistent, got %s", s) + } +} + +func TestRegistrySetStateNoOp(t *testing.T) { + r := NewRegistry() + r.SetState("nonexistent", StateRunning) + if n := r.Count(); n != 0 { + t.Errorf("SetState should not add plugins") + } +} + +func TestPluginStateString(t *testing.T) { + tests := []struct { + state PluginState + want string + }{ + {StateUnknown, "unknown"}, + {StateRegistered, "registered"}, + {StateInitialized, "initialized"}, + {StateRunning, "running"}, + {StateStopped, "stopped"}, + {StateError, "error"}, + {PluginState(99), "unknown"}, + } + + for _, tt := range tests { + if got := tt.state.String(); got != tt.want { + 
t.Errorf("PluginState(%d).String() = %q, want %q", tt.state, got, tt.want) + } + } +} + +func TestRegistryConcurrent(t *testing.T) { + r := NewRegistry() + done := make(chan struct{}, 2) + + go func() { + for i := 0; i < 100; i++ { + r.Register(&mockPlugin{name: "a", version: "1.0.0"}) + r.Get("a") + r.Unregister("a") + } + done <- struct{}{} + }() + + go func() { + for i := 0; i < 100; i++ { + r.Register(&mockPlugin{name: "b", version: "1.0.0"}) + r.List() + r.State("b") + r.Unregister("b") + } + done <- struct{}{} + }() + + <-done + <-done +} + +// mockPluginHost implements PluginHost for testing kernel-level plugin init. +type mockPluginHost struct{} + +func (m *mockPluginHost) Bus() bus.MessageBus { return nil } +func (m *mockPluginHost) GetPlugin(name string) (Plugin, bool) { return nil, false } +func (m *mockPluginHost) ListPlugins() []Plugin { return nil } + +func TestPluginInitAndShutdown(t *testing.T) { + var initCalled, shutdownCalled bool + + p := &mockPlugin{ + name: "test", + version: "1.0.0", + initFn: func(host PluginHost) error { + initCalled = true + if host == nil { + return errors.New("host is nil") + } + return nil + }, + closeFn: func() error { + shutdownCalled = true + return nil + }, + } + + host := &mockPluginHost{} + + if err := p.Init(host); err != nil { + t.Fatalf("Init failed: %v", err) + } + if !initCalled { + t.Error("Init function was not called") + } + + if err := p.Shutdown(); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + if !shutdownCalled { + t.Error("Shutdown function was not called") + } +} diff --git a/pkg/sandbox/process.go b/pkg/sandbox/process.go new file mode 100644 index 0000000..bd76828 --- /dev/null +++ b/pkg/sandbox/process.go @@ -0,0 +1,246 @@ +package sandbox + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" +) + +const ( + // DefaultOutputLimit is the maximum number of bytes captured from stdout/stderr (64 KB). 
+ DefaultOutputLimit = 64 * 1024 + + // DefaultWorkingDir is the default working directory for sandboxed commands. + DefaultWorkingDir = "/tmp/orca/sandbox" +) + +// AllowedEnvVars is the whitelist of environment variables accessible inside the sandbox. +// Only these variables are passed through from the parent process. +var AllowedEnvVars = []string{ + "HOME", + "USER", + "PATH", + "LANG", + "SHELL", + "TMPDIR", + "ORCA_HOME", +} + +// ProcessSandbox is a Sandbox implementation that uses os/exec to run commands +// as child processes with resource restrictions. +type ProcessSandbox struct { + // WorkingDir is the directory in which commands execute. + WorkingDir string + // OutputLimit is the maximum bytes to capture from stdout/stderr. + OutputLimit int + // EnvWhitelist controls which environment variables are passed to child processes. + // If nil, AllowedEnvVars is used. If empty, no env vars are passed. + EnvWhitelist []string +} + +// NewProcessSandbox creates a ProcessSandbox with sensible defaults. +func NewProcessSandbox() *ProcessSandbox { + return &ProcessSandbox{ + WorkingDir: DefaultWorkingDir, + OutputLimit: DefaultOutputLimit, + EnvWhitelist: nil, // uses AllowedEnvVars + } +} + +// Execute runs a command as a child process with resource restrictions. +func (ps *ProcessSandbox) Execute(ctx context.Context, cmd string, args ...string) (*Result, error) { + // Ensure working directory exists + if err := os.MkdirAll(ps.WorkingDir, 0755); err != nil { + return nil, fmt.Errorf("sandbox: failed to create working directory %q: %w", ps.WorkingDir, err) + } + + // Build the command + c := exec.CommandContext(ctx, cmd, args...) 
+ c.Dir = ps.WorkingDir + + // Set up environment variable whitelist + env := ps.buildEnv() + c.Env = env + + // Capture stdout and stderr with size limits + stdoutBuf := newLimitedBuffer(ps.outputLimit()) + stderrBuf := newLimitedBuffer(ps.outputLimit()) + + stdoutPipe, err := c.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("sandbox: failed to create stdout pipe: %w", err) + } + stderrPipe, err := c.StderrPipe() + if err != nil { + return nil, fmt.Errorf("sandbox: failed to create stderr pipe: %w", err) + } + + // Start the command + if err := c.Start(); err != nil { + return nil, fmt.Errorf("sandbox: failed to start command: %w", err) + } + + // Read stdout and stderr concurrently + var readStdout, readStderr error + var wg syncWaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + _, readStdout = io.Copy(stdoutBuf, stdoutPipe) + // ErrShortWrite is expected when the output limit is reached — not a real error. + if readStdout != nil && readStdout != io.EOF && readStdout != io.ErrShortWrite { + readStdout = fmt.Errorf("sandbox: stdout read error: %w", readStdout) + } else { + readStdout = nil + } + }() + + go func() { + defer wg.Done() + _, readStderr = io.Copy(stderrBuf, stderrPipe) + if readStderr != nil && readStderr != io.EOF && readStderr != io.ErrShortWrite { + readStderr = fmt.Errorf("sandbox: stderr read error: %w", readStderr) + } else { + readStderr = nil + } + }() + + wg.Wait() + + // Wait for the command to finish + err = c.Wait() + exitCode := 0 + + if err != nil { + // Check if the process was killed due to context cancellation (timeout) + if ctx.Err() != nil { + return nil, fmt.Errorf("sandbox: command timed out: %w", ctx.Err()) + } + + // Normal non-zero exit + if exitErr, ok := err.(*exec.ExitError); ok { + exitCode = exitErr.ExitCode() + err = nil + } + } + + // Combine errors: prefer command error, then read errors + if err != nil { + return nil, err + } + if readStdout != nil { + return nil, readStdout + } + if readStderr != nil { 
+ return nil, readStderr + } + + return &Result{ + Stdout: stdoutBuf.String(), + Stderr: stderrBuf.String(), + ExitCode: exitCode, + }, nil +} + +// buildEnv constructs the environment variable list for the child process +// based on the whitelist configuration. +func (ps *ProcessSandbox) buildEnv() []string { + whitelist := ps.EnvWhitelist + if whitelist == nil { + whitelist = AllowedEnvVars + } + + env := make([]string, 0, len(whitelist)) + for _, key := range whitelist { + if val, ok := os.LookupEnv(key); ok { + env = append(env, key+"="+val) + } + } + return env +} + +// outputLimit returns the effective output size limit. +func (ps *ProcessSandbox) outputLimit() int { + if ps.OutputLimit <= 0 { + return DefaultOutputLimit + } + return ps.OutputLimit +} + +// WorkingDirPath returns the absolute path of the sandbox working directory. +func (ps *ProcessSandbox) WorkingDirPath() string { + abs, err := filepath.Abs(ps.WorkingDir) + if err != nil { + return ps.WorkingDir + } + return abs +} + +// --------------------------------------------------------------------------- +// limitedBuffer — a writer that stops accepting data after MaxSize bytes. +// Uses a named field (not embedded) to avoid promoting bytes.Buffer.ReadFrom +// which would bypass the size limit when used with io.Copy. +// --------------------------------------------------------------------------- + +type limitedBuffer struct { + buf bytes.Buffer + MaxSize int +} + +func newLimitedBuffer(maxSize int) *limitedBuffer { + return &limitedBuffer{MaxSize: maxSize} +} + +func (lb *limitedBuffer) Write(p []byte) (int, error) { + remaining := lb.MaxSize - lb.buf.Len() + if remaining <= 0 { + return len(p), nil // silently drop excess; io.Copy sees nw==nr, continues draining pipe + } + if len(p) > remaining { + p = p[:remaining] + n, err := lb.buf.Write(p) + // Return n < original len(p) so io.Copy stops with ErrShortWrite. 
+ return n, err + } + return lb.buf.Write(p) +} + +func (lb *limitedBuffer) String() string { + return lb.buf.String() +} + +func (lb *limitedBuffer) Len() int { + return lb.buf.Len() +} + +// --------------------------------------------------------------------------- +// syncWaitGroup — a simple goroutine synchronization mechanism. +// --------------------------------------------------------------------------- + +type syncWaitGroup struct { + ch chan struct{} +} + +func (wg *syncWaitGroup) Add(n int) { + if wg.ch == nil { + wg.ch = make(chan struct{}, n) + } +} + +func (wg *syncWaitGroup) Done() { + wg.ch <- struct{}{} +} + +func (wg *syncWaitGroup) Wait() { + for i := 0; i < cap(wg.ch); i++ { + <-wg.ch + } +} + +// Compile-time interface check. +var _ Sandbox = (*ProcessSandbox)(nil) diff --git a/pkg/sandbox/process_test.go b/pkg/sandbox/process_test.go new file mode 100644 index 0000000..a6da348 --- /dev/null +++ b/pkg/sandbox/process_test.go @@ -0,0 +1,212 @@ +package sandbox + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestNewProcessSandbox(t *testing.T) { + ps := NewProcessSandbox() + if ps == nil { + t.Fatal("NewProcessSandbox() returned nil") + } + if ps.WorkingDir != DefaultWorkingDir { + t.Errorf("expected WorkingDir %q, got %q", DefaultWorkingDir, ps.WorkingDir) + } + if ps.OutputLimit != DefaultOutputLimit { + t.Errorf("expected OutputLimit %d, got %d", DefaultOutputLimit, ps.OutputLimit) + } +} + +func TestExecuteEcho(t *testing.T) { + ps := NewProcessSandbox() + ctx := context.Background() + + result, err := ps.Execute(ctx, "echo", "hello", "world") + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + if result.ExitCode != 0 { + t.Errorf("expected exit code 0, got %d", result.ExitCode) + } + if strings.TrimSpace(result.Stdout) != "hello world" { + t.Errorf("expected stdout 'hello world', got %q", result.Stdout) + } +} + +func TestExecuteWithArgs(t *testing.T) { + ps := NewProcessSandbox() + ctx 
:= context.Background() + + result, err := ps.Execute(ctx, "sh", "-c", "echo 'arg1 arg2'") + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + if result.ExitCode != 0 { + t.Errorf("expected exit code 0, got %d", result.ExitCode) + } + if strings.TrimSpace(result.Stdout) != "arg1 arg2" { + t.Errorf("expected stdout 'arg1 arg2', got %q", result.Stdout) + } +} + +func TestExecuteNonZeroExit(t *testing.T) { + ps := NewProcessSandbox() + ctx := context.Background() + + result, err := ps.Execute(ctx, "sh", "-c", "exit 42") + if err != nil { + t.Fatalf("Execute should not error on non-zero exit: %v", err) + } + if result.ExitCode != 42 { + t.Errorf("expected exit code 42, got %d", result.ExitCode) + } +} + +func TestExecuteCommandNotFound(t *testing.T) { + ps := NewProcessSandbox() + ctx := context.Background() + + _, err := ps.Execute(ctx, "nonexistent-command-12345") + if err == nil { + t.Fatal("expected error for nonexistent command") + } +} + +func TestExecuteTimeout(t *testing.T) { + ps := NewProcessSandbox() + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err := ps.Execute(ctx, "sleep", "10") + if err == nil { + t.Fatal("expected timeout error") + } + // On macOS the error may be "signal: killed" or "context deadline exceeded". + // Just verify an error occurred — the exact message varies by platform. 
+ t.Logf("timeout produced error: %v", err) +} + +func TestExecuteWorkingDirectory(t *testing.T) { + // Use a temp directory for this test + tmpDir, err := os.MkdirTemp("", "sandbox-test-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + ps := &ProcessSandbox{ + WorkingDir: tmpDir, + OutputLimit: DefaultOutputLimit, + EnvWhitelist: AllowedEnvVars, + } + + ctx := context.Background() + result, err := ps.Execute(ctx, "pwd") + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + if result.ExitCode != 0 { + t.Errorf("expected exit code 0, got %d", result.ExitCode) + } + + // pwd should return the temp directory + gotDir := strings.TrimSpace(result.Stdout) + absGot, _ := filepath.EvalSymlinks(gotDir) + absTmp, _ := filepath.EvalSymlinks(tmpDir) + if absGot != absTmp { + t.Errorf("expected working dir %q, got %q", absTmp, absGot) + } +} + +func TestEnvironmentWhitelist(t *testing.T) { + ps := NewProcessSandbox() + ps.EnvWhitelist = []string{"HOME"} + + ctx := context.Background() + result, err := ps.Execute(ctx, "sh", "-c", "echo $HOME") + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + if result.ExitCode != 0 { + t.Errorf("expected exit code 0, got %d", result.ExitCode) + } + + home := os.Getenv("HOME") + if home != "" && strings.TrimSpace(result.Stdout) != home { + t.Errorf("expected HOME=%q, got %q", home, strings.TrimSpace(result.Stdout)) + } +} + +func TestEnvironmentIsolation(t *testing.T) { + ps := NewProcessSandbox() + // Use an empty whitelist to ensure no env vars are passed + ps.EnvWhitelist = []string{} + + ctx := context.Background() + result, err := ps.Execute(ctx, "sh", "-c", "echo $HOME") + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + // HOME should be empty in the child process + if strings.TrimSpace(result.Stdout) != "" { + t.Errorf("expected empty HOME in isolated env, got %q", result.Stdout) + } +} + +func TestOutputLimit(t *testing.T) { + ps := 
NewProcessSandbox() + ps.OutputLimit = 10 // Only 10 bytes + + ctx := context.Background() + // Generate a long output well beyond the 10-byte limit + result, err := ps.Execute(ctx, "sh", "-c", "echo 'AAAAAAAAAABBBBBBBBBBCCCCCCCCCCDDDDDDDDDDEEEEEEEEEEFFFFFFFFFF'") + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + // The output should be truncated to approximately 10 bytes (plus newline) + if len(result.Stdout) > 15 { + t.Errorf("expected truncated output (<=15 bytes), got %d bytes: %q", len(result.Stdout), result.Stdout) + } +} + +func TestExecuteStderr(t *testing.T) { + ps := NewProcessSandbox() + ctx := context.Background() + + result, err := ps.Execute(ctx, "sh", "-c", "echo 'error output' >&2; echo 'normal output'") + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + if result.ExitCode != 0 { + t.Errorf("expected exit code 0, got %d", result.ExitCode) + } + if strings.TrimSpace(result.Stderr) != "error output" { + t.Errorf("expected stderr 'error output', got %q", result.Stderr) + } + if strings.TrimSpace(result.Stdout) != "normal output" { + t.Errorf("expected stdout 'normal output', got %q", result.Stdout) + } +} + +func TestSandboxInterfaceSatisfied(t *testing.T) { + // Compile-time check + var ps Sandbox = NewProcessSandbox() + if ps == nil { + t.Fatal("ProcessSandbox does not satisfy Sandbox interface") + } +} + +func TestWorkingDirPath(t *testing.T) { + ps := NewProcessSandbox() + path := ps.WorkingDirPath() + if !filepath.IsAbs(path) { + t.Errorf("expected absolute path, got %q", path) + } +} diff --git a/pkg/sandbox/sandbox.go b/pkg/sandbox/sandbox.go new file mode 100644 index 0000000..791703b --- /dev/null +++ b/pkg/sandbox/sandbox.go @@ -0,0 +1,28 @@ +// Package sandbox provides a secure execution environment for running commands. +// +// The sandbox restricts resource usage (timeout, output size, working directory) +// and environment variable access to prevent runaway or malicious commands. 
+// This is the execution backend used by the Tool system's built-in exec tool. +package sandbox + +import ( + "context" +) + +// Result holds the output and exit status of a sandboxed command execution. +type Result struct { + Stdout string `json:"stdout"` + Stderr string `json:"stderr"` + ExitCode int `json:"exit_code"` +} + +// Sandbox defines the interface for command execution environments. +// +// Implementations may use OS processes (os/exec), containers, or other +// isolation mechanisms. The context controls cancellation and timeouts. +type Sandbox interface { + // Execute runs a command with the given arguments inside the sandbox. + // The context can be used to set timeouts or cancel the execution. + // Returns the combined output, error output, and exit code. + Execute(ctx context.Context, cmd string, args ...string) (*Result, error) +} diff --git a/pkg/session/jsonl.go b/pkg/session/jsonl.go new file mode 100644 index 0000000..d08a722 --- /dev/null +++ b/pkg/session/jsonl.go @@ -0,0 +1,190 @@ +package session + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "sync" +) + +// JSONLStore implements the Store interface using JSONL files. +// +// Each session is stored in a separate file named {session_id}.jsonl +// under the configured storage directory. Every line in the file is a +// JSON-encoded SessionMessage. New messages are appended in O(1) time. +type JSONLStore struct { + storageDir string + mu sync.RWMutex +} + +// NewJSONLStore creates a new JSONLStore with the given storage directory. +// The directory is created if it does not exist. +func NewJSONLStore(storageDir string) (*JSONLStore, error) { + if err := os.MkdirAll(storageDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create session storage directory %q: %w", storageDir, err) + } + return &JSONLStore{storageDir: storageDir}, nil +} + +// path returns the full file path for the given session ID. 
+func (s *JSONLStore) path(sessionID string) string { + return filepath.Join(s.storageDir, sessionID+".jsonl") +} + +// archivePath returns the archive file path for the given session ID. +func (s *JSONLStore) archivePath(sessionID string) string { + return filepath.Join(s.storageDir, sessionID+".jsonl.archived") +} + +// Save appends a message to a session's JSONL file. +// If the file does not exist, it is created. +// This is an O(1) append operation. +func (s *JSONLStore) Save(sessionID string, msg SessionMessage) error { + s.mu.Lock() + defer s.mu.Unlock() + + f, err := os.OpenFile(s.path(sessionID), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return fmt.Errorf("failed to open session file for %q: %w", sessionID, err) + } + defer f.Close() + + data, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal session message: %w", err) + } + + if _, err := f.Write(append(data, '\n')); err != nil { + return fmt.Errorf("failed to write session message: %w", err) + } + + return nil +} + +// Load retrieves all messages for a session in chronological order. +// Returns an error if the session file does not exist. +func (s *JSONLStore) Load(sessionID string) ([]SessionMessage, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + data, err := os.ReadFile(s.path(sessionID)) + if err != nil { + if os.IsNotExist(err) { + // Check archive + archiveData, archiveErr := os.ReadFile(s.archivePath(sessionID)) + if archiveErr != nil { + return nil, fmt.Errorf("session %q not found", sessionID) + } + data = archiveData + } else { + return nil, fmt.Errorf("failed to read session file for %q: %w", sessionID, err) + } + } + + return parseJSONL(data) +} + +// parseJSONL parses a JSONL byte slice into a slice of SessionMessage. 
+func parseJSONL(data []byte) ([]SessionMessage, error) { + var messages []SessionMessage + trimmed := strings.TrimSpace(string(data)) + if trimmed == "" { + return messages, nil + } + + lines := strings.Split(trimmed, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + var msg SessionMessage + if err := json.Unmarshal([]byte(line), &msg); err != nil { + return nil, fmt.Errorf("failed to unmarshal session message: %w", err) + } + messages = append(messages, msg) + } + return messages, nil +} + +// List returns all session IDs by scanning the storage directory. +func (s *JSONLStore) List() ([]string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + entries, err := os.ReadDir(s.storageDir) + if err != nil { + return nil, fmt.Errorf("failed to read storage directory %q: %w", s.storageDir, err) + } + + var sessions []string + for _, entry := range entries { + name := entry.Name() + if strings.HasSuffix(name, ".jsonl") && !strings.HasSuffix(name, ".archived") { + sessions = append(sessions, strings.TrimSuffix(name, ".jsonl")) + } + } + return sessions, nil +} + +// Exists checks whether a session file exists (active or archived). +func (s *JSONLStore) Exists(sessionID string) (bool, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if _, err := os.Stat(s.path(sessionID)); err == nil { + return true, nil + } else if !os.IsNotExist(err) { + return false, fmt.Errorf("failed to check session %q: %w", sessionID, err) + } + + // Check archive + if _, err := os.Stat(s.archivePath(sessionID)); err == nil { + return true, nil + } else if !os.IsNotExist(err) { + return false, fmt.Errorf("failed to check archived session %q: %w", sessionID, err) + } + + return false, nil +} + +// Archive moves a session file to the archived state by renaming it. 
+func (s *JSONLStore) Archive(sessionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + + if err := os.Rename(s.path(sessionID), s.archivePath(sessionID)); err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("session %q not found", sessionID) + } + return fmt.Errorf("failed to archive session %q: %w", sessionID, err) + } + return nil +} + +// Delete permanently removes a session file and its archive. +func (s *JSONLStore) Delete(sessionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + + var lastErr error + + // Remove active file + if err := os.Remove(s.path(sessionID)); err != nil && !os.IsNotExist(err) { + lastErr = fmt.Errorf("failed to delete session %q: %w", sessionID, err) + } + + // Also remove archived file if it exists + if err := os.Remove(s.archivePath(sessionID)); err != nil && !os.IsNotExist(err) { + lastErr = fmt.Errorf("failed to delete archived session %q: %w", sessionID, err) + } + + return lastErr +} + +// StorageDir returns the storage directory path. +func (s *JSONLStore) StorageDir() string { + return s.storageDir +} diff --git a/pkg/session/manager.go b/pkg/session/manager.go new file mode 100644 index 0000000..17b7db3 --- /dev/null +++ b/pkg/session/manager.go @@ -0,0 +1,198 @@ +package session + +import ( + "fmt" + "sync" + "time" + + "github.com/orca/orca/pkg/bus" +) + +// Manager provides high-level session lifecycle operations. +// +// It wraps a Store with caching, context window management, and +// event publishing on the message bus. +type Manager struct { + store Store + bus bus.MessageBus + cache map[string]*Session + mu sync.RWMutex +} + +// NewManager creates a new session Manager with the given store and optional message bus. +func NewManager(store Store, mb bus.MessageBus) *Manager { + return &Manager{ + store: store, + bus: mb, + cache: make(map[string]*Session), + } +} + +// CreateSession creates a new session with the given ID and optional metadata. 
+func (m *Manager) CreateSession(id string, metadata map[string]string) (*Session, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.cache[id]; exists { + return nil, fmt.Errorf("session %q already exists", id) + } + + now := time.Now() + session := &Session{ + ID: id, + Status: SessionActive, + Messages: make([]SessionMessage, 0), + CreatedAt: now, + UpdatedAt: now, + Metadata: metadata, + } + + m.cache[id] = session + + // Publish session created event + if m.bus != nil { + m.bus.Publish("session.created", bus.Message{ + ID: "session-" + id, + Type: bus.MsgTypeSystem, + From: "session.manager", + Content: map[string]interface{}{"session_id": id}, + }) + } + + return session, nil +} + +// GetSession retrieves a session by ID, checking the cache and then the store. +func (m *Manager) GetSession(id string) (*Session, error) { + m.mu.RLock() + session, ok := m.cache[id] + m.mu.RUnlock() + + if ok { + return session, nil + } + + // Try to load from store + messages, err := m.store.Load(id) + if err != nil { + return nil, fmt.Errorf("failed to load session %q: %w", id, err) + } + + // Check if we can determine created/updated timestamps from messages + var createdAt, updatedAt time.Time + if len(messages) > 0 { + createdAt = messages[0].Timestamp + updatedAt = messages[len(messages)-1].Timestamp + } + if createdAt.IsZero() { + createdAt = time.Now() + } + if updatedAt.IsZero() { + updatedAt = time.Now() + } + + session = &Session{ + ID: id, + Status: SessionActive, + Messages: messages, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + } + + m.mu.Lock() + m.cache[id] = session + m.mu.Unlock() + + return session, nil +} + +// AddMessage appends a message to a session and persists it. 
+func (m *Manager) AddMessage(sessionID string, role MessageRole, content string, metadata map[string]string) (*SessionMessage, error) { + msg := SessionMessage{ + Role: role, + Content: content, + Timestamp: time.Now(), + Metadata: metadata, + } + + if err := m.store.Save(sessionID, msg); err != nil { + return nil, fmt.Errorf("failed to save message to session %q: %w", sessionID, err) + } + + // Upsert cache + m.mu.Lock() + if session, ok := m.cache[sessionID]; ok { + session.Messages = append(session.Messages, msg) + session.UpdatedAt = msg.Timestamp + } else { + m.cache[sessionID] = &Session{ + ID: sessionID, + Status: SessionActive, + Messages: []SessionMessage{msg}, + CreatedAt: msg.Timestamp, + UpdatedAt: msg.Timestamp, + } + } + m.mu.Unlock() + + return &msg, nil +} + +// GetContext returns the most recent N messages in a session. +// If windowSize <= 0 or >= total messages, all messages are returned. +func (m *Manager) GetContext(sessionID string, windowSize int) ([]SessionMessage, error) { + session, err := m.GetSession(sessionID) + if err != nil { + return nil, err + } + + messages := session.Messages + if windowSize > 0 && windowSize < len(messages) { + return messages[len(messages)-windowSize:], nil + } + return messages, nil +} + +// ArchiveSession archives a session, making it read-only. +func (m *Manager) ArchiveSession(id string) error { + m.mu.Lock() + defer m.mu.Unlock() + + if session, ok := m.cache[id]; ok { + session.Status = SessionArchived + } + + if err := m.store.Archive(id); err != nil { + return err + } + + // Publish event + if m.bus != nil { + m.bus.Publish("session.archived", bus.Message{ + ID: "session-" + id, + Type: bus.MsgTypeSystem, + From: "session.manager", + Content: map[string]interface{}{"session_id": id}, + }) + } + + return nil +} + +// DeleteSession permanently removes a session. 
+func (m *Manager) DeleteSession(id string) error { + m.mu.Lock() + delete(m.cache, id) + m.mu.Unlock() + return m.store.Delete(id) +} + +// ListSessions returns all known session IDs. +func (m *Manager) ListSessions() ([]string, error) { + return m.store.List() +} + +// Store returns the underlying Store. +func (m *Manager) Store() Store { + return m.store +} diff --git a/pkg/session/session_test.go b/pkg/session/session_test.go new file mode 100644 index 0000000..beb1e74 --- /dev/null +++ b/pkg/session/session_test.go @@ -0,0 +1,550 @@ +package session + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/orca/orca/pkg/bus" +) + +// ============================================================ +// JSONL Store Tests +// ============================================================ + +func setupTestStore(t *testing.T) (*JSONLStore, func()) { + t.Helper() + dir, err := os.MkdirTemp("", "orca-session-test-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + + store, err := NewJSONLStore(dir) + if err != nil { + os.RemoveAll(dir) + t.Fatalf("NewJSONLStore failed: %v", err) + } + + cleanup := func() { + os.RemoveAll(dir) + } + return store, cleanup +} + +func TestNewJSONLStore(t *testing.T) { + store, cleanup := setupTestStore(t) + defer cleanup() + + if store == nil { + t.Fatal("NewJSONLStore returned nil") + } + if store.StorageDir() == "" { + t.Error("StorageDir should not be empty") + } +} + +func TestJSONLStoreSaveAndLoad(t *testing.T) { + store, cleanup := setupTestStore(t) + defer cleanup() + + msg := SessionMessage{ + Role: RoleUser, + Content: "Hello, world!", + Timestamp: time.Now(), + } + + if err := store.Save("session-1", msg); err != nil { + t.Fatalf("Save failed: %v", err) + } + + messages, err := store.Load("session-1") + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if len(messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(messages)) + } + if messages[0].Role != RoleUser { + 
t.Errorf("expected RoleUser, got %s", messages[0].Role) + } + if messages[0].Content != "Hello, world!" { + t.Errorf("expected content 'Hello, world!', got %q", messages[0].Content) + } +} + +func TestJSONLStoreAppendMultiple(t *testing.T) { + store, cleanup := setupTestStore(t) + defer cleanup() + + roles := []MessageRole{RoleUser, RoleAssistant, RoleUser, RoleSystem} + for i, role := range roles { + msg := SessionMessage{ + Role: role, + Content: "message " + string(rune('0'+i)), + Timestamp: time.Now(), + } + if err := store.Save("session-append", msg); err != nil { + t.Fatalf("Save %d failed: %v", i, err) + } + } + + messages, err := store.Load("session-append") + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if len(messages) != len(roles) { + t.Fatalf("expected %d messages, got %d", len(roles), len(messages)) + } + for i, msg := range messages { + if msg.Role != roles[i] { + t.Errorf("message %d: expected role %s, got %s", i, roles[i], msg.Role) + } + } +} + +func TestJSONLStoreLoadNonexistent(t *testing.T) { + store, cleanup := setupTestStore(t) + defer cleanup() + + _, err := store.Load("nonexistent") + if err == nil { + t.Error("expected error loading nonexistent session") + } +} + +func TestJSONLStoreExists(t *testing.T) { + store, cleanup := setupTestStore(t) + defer cleanup() + + exists, err := store.Exists("nonexistent") + if err != nil { + t.Fatalf("Exists failed: %v", err) + } + if exists { + t.Error("expected nonexistent session to return false") + } + + store.Save("session-exists", SessionMessage{Role: RoleUser, Content: "test"}) + exists, err = store.Exists("session-exists") + if err != nil { + t.Fatalf("Exists failed: %v", err) + } + if !exists { + t.Error("expected existing session to return true") + } +} + +func TestJSONLStoreList(t *testing.T) { + store, cleanup := setupTestStore(t) + defer cleanup() + + ids := []string{"sess-a", "sess-b", "sess-c"} + for _, id := range ids { + store.Save(id, SessionMessage{Role: RoleUser, Content: 
"test"}) + } + + list, err := store.List() + if err != nil { + t.Fatalf("List failed: %v", err) + } + if len(list) != len(ids) { + t.Fatalf("expected %d sessions, got %d", len(ids), len(list)) + } + + found := make(map[string]bool) + for _, id := range list { + found[id] = true + } + for _, id := range ids { + if !found[id] { + t.Errorf("missing session %q in list", id) + } + } +} + +func TestJSONLStoreArchiveAndLoad(t *testing.T) { + store, cleanup := setupTestStore(t) + defer cleanup() + + msg := SessionMessage{Role: RoleUser, Content: "archive test"} + store.Save("sess-archive", msg) + + if err := store.Archive("sess-archive"); err != nil { + t.Fatalf("Archive failed: %v", err) + } + + // Should still be loadable (archived files are in .archived suffix) + messages, err := store.Load("sess-archive") + if err != nil { + t.Fatalf("Load after archive failed: %v", err) + } + if len(messages) != 1 { + t.Errorf("expected 1 message after archive, got %d", len(messages)) + } + + // Should not appear in List + list, _ := store.List() + for _, id := range list { + if id == "sess-archive" { + t.Error("archived session should not appear in List") + } + } +} + +func TestJSONLStoreArchiveNonexistent(t *testing.T) { + store, cleanup := setupTestStore(t) + defer cleanup() + + err := store.Archive("nonexistent") + if err == nil { + t.Error("expected error archiving nonexistent session") + } +} + +func TestJSONLStoreDelete(t *testing.T) { + store, cleanup := setupTestStore(t) + defer cleanup() + + store.Save("sess-delete", SessionMessage{Role: RoleUser, Content: "delete me"}) + if err := store.Delete("sess-delete"); err != nil { + t.Fatalf("Delete failed: %v", err) + } + + exists, _ := store.Exists("sess-delete") + if exists { + t.Error("expected deleted session to not exist") + } +} + +func TestJSONLStoreDeleteNonexistent(t *testing.T) { + store, cleanup := setupTestStore(t) + defer cleanup() + + err := store.Delete("nonexistent") + if err != nil { + t.Fatalf("Delete nonexistent 
should succeed: %v", err) + } +} + +func TestJSONLStoreConcurrentWrites(t *testing.T) { + store, cleanup := setupTestStore(t) + defer cleanup() + + done := make(chan struct{}, 2) + go func() { + for i := 0; i < 50; i++ { + store.Save("concurrent", SessionMessage{Role: RoleUser, Content: "from-a"}) + } + done <- struct{}{} + }() + go func() { + for i := 0; i < 50; i++ { + store.Save("concurrent", SessionMessage{Role: RoleAssistant, Content: "from-b"}) + } + done <- struct{}{} + }() + + <-done + <-done + + messages, err := store.Load("concurrent") + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if len(messages) != 100 { + t.Errorf("expected 100 messages, got %d", len(messages)) + } +} + +func TestJSONLStoreEmptyFile(t *testing.T) { + store, cleanup := setupTestStore(t) + defer cleanup() + + dir := store.StorageDir() + // Create an empty file + f, _ := os.Create(filepath.Join(dir, "empty.jsonl")) + f.Close() + + messages, err := store.Load("empty") + if err != nil { + t.Fatalf("Load empty session failed: %v", err) + } + if len(messages) != 0 { + t.Errorf("expected 0 messages from empty file, got %d", len(messages)) + } +} + +// ============================================================ +// Session Manager Tests +// ============================================================ + +func setupTestManager(t *testing.T) (*Manager, func()) { + t.Helper() + store, cleanup := setupTestStore(t) + mb := bus.New() + mgr := NewManager(store, mb) + return mgr, func() { + mb.Close() + cleanup() + } +} + +func TestNewManager(t *testing.T) { + mgr, cleanup := setupTestManager(t) + defer cleanup() + + if mgr == nil { + t.Fatal("NewManager returned nil") + } +} + +func TestManagerCreateSession(t *testing.T) { + mgr, cleanup := setupTestManager(t) + defer cleanup() + + session, err := mgr.CreateSession("sess-1", map[string]string{"key": "value"}) + if err != nil { + t.Fatalf("CreateSession failed: %v", err) + } + if session.ID != "sess-1" { + t.Errorf("expected ID 'sess-1', 
got %q", session.ID) + } + if session.Status != SessionActive { + t.Errorf("expected SessionActive, got %s", session.Status) + } + if session.Metadata["key"] != "value" { + t.Errorf("expected metadata key 'value', got %q", session.Metadata["key"]) + } + if session.MessageCount() != 0 { + t.Errorf("expected 0 messages, got %d", session.MessageCount()) + } + if session.CreatedAt.IsZero() { + t.Error("CreatedAt should not be zero") + } +} + +func TestManagerCreateDuplicate(t *testing.T) { + mgr, cleanup := setupTestManager(t) + defer cleanup() + + mgr.CreateSession("dup", nil) + _, err := mgr.CreateSession("dup", nil) + if err == nil { + t.Error("expected error creating duplicate session") + } +} + +func TestManagerAddMessage(t *testing.T) { + mgr, cleanup := setupTestManager(t) + defer cleanup() + + msg, err := mgr.AddMessage("sess-add", RoleUser, "Hello!", map[string]string{"source": "test"}) + if err != nil { + t.Fatalf("AddMessage failed: %v", err) + } + if msg.Role != RoleUser { + t.Errorf("expected RoleUser, got %s", msg.Role) + } + if msg.Content != "Hello!" 
{ + t.Errorf("expected 'Hello!', got %q", msg.Content) + } + if msg.Metadata["source"] != "test" { + t.Errorf("expected metadata source 'test', got %q", msg.Metadata["source"]) + } + if msg.Timestamp.IsZero() { + t.Error("Timestamp should not be zero") + } + + // Verify it was persisted + messages, _ := mgr.GetContext("sess-add", 10) + if len(messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(messages)) + } +} + +func TestManagerAddMessageAutoCreatesSession(t *testing.T) { + mgr, cleanup := setupTestManager(t) + defer cleanup() + + mgr.AddMessage("auto-session", RoleUser, "auto create", nil) + + session, err := mgr.GetSession("auto-session") + if err != nil { + t.Fatalf("GetSession failed: %v", err) + } + if session.MessageCount() != 1 { + t.Errorf("expected 1 message, got %d", session.MessageCount()) + } +} + +func TestManagerGetContextWindow(t *testing.T) { + mgr, cleanup := setupTestManager(t) + defer cleanup() + + // Add 10 messages + for i := 0; i < 10; i++ { + mgr.AddMessage("window-test", RoleUser, "msg", nil) + } + + // Get last 3 + messages, err := mgr.GetContext("window-test", 3) + if err != nil { + t.Fatalf("GetContext failed: %v", err) + } + if len(messages) != 3 { + t.Errorf("expected 3 messages, got %d", len(messages)) + } + + // Get all (window larger than total) + all, _ := mgr.GetContext("window-test", 100) + if len(all) != 10 { + t.Errorf("expected 10 messages, got %d", len(all)) + } + + // Get with window <= 0 + all2, _ := mgr.GetContext("window-test", 0) + if len(all2) != 10 { + t.Errorf("expected 10 messages with window=0, got %d", len(all2)) + } +} + +func TestManagerGetContextNonexistent(t *testing.T) { + mgr, cleanup := setupTestManager(t) + defer cleanup() + + _, err := mgr.GetContext("nonexistent", 10) + if err == nil { + t.Error("expected error getting context for nonexistent session") + } +} + +func TestManagerArchiveSession(t *testing.T) { + mgr, cleanup := setupTestManager(t) + defer cleanup() + + 
mgr.CreateSession("archivable", nil) + mgr.AddMessage("archivable", RoleUser, "test", nil) + + if err := mgr.ArchiveSession("archivable"); err != nil { + t.Fatalf("ArchiveSession failed: %v", err) + } + + session, _ := mgr.GetSession("archivable") + if session.Status != SessionArchived { + t.Errorf("expected SessionArchived, got %s", session.Status) + } + if session.IsArchived() != true { + t.Error("expected IsArchived to return true") + } +} + +func TestManagerDeleteSession(t *testing.T) { + mgr, cleanup := setupTestManager(t) + defer cleanup() + + mgr.CreateSession("deletable", nil) + if err := mgr.DeleteSession("deletable"); err != nil { + t.Fatalf("DeleteSession failed: %v", err) + } + + _, err := mgr.GetSession("deletable") + if err == nil { + t.Error("expected error getting deleted session") + } +} + +func TestManagerListSessions(t *testing.T) { + mgr, cleanup := setupTestManager(t) + defer cleanup() + + mgr.AddMessage("list-a", RoleUser, "a", nil) + mgr.AddMessage("list-b", RoleUser, "b", nil) + + sessions, err := mgr.ListSessions() + if err != nil { + t.Fatalf("ListSessions failed: %v", err) + } + if len(sessions) != 2 { + t.Errorf("expected 2 sessions, got %d", len(sessions)) + } +} + +func TestManagerMultipleMessagesOrder(t *testing.T) { + mgr, cleanup := setupTestManager(t) + defer cleanup() + + contents := []string{"first", "second", "third"} + for i, c := range contents { + mgr.AddMessage("order-test", RoleUser, c, nil) + _ = i + } + + messages, _ := mgr.GetContext("order-test", 10) + if len(messages) != 3 { + t.Fatalf("expected 3 messages, got %d", len(messages)) + } + if messages[0].Content != "first" { + t.Errorf("expected first message content 'first', got %q", messages[0].Content) + } + if messages[2].Content != "third" { + t.Errorf("expected third message content 'third', got %q", messages[2].Content) + } +} + +func TestManagerStoreAccess(t *testing.T) { + mgr, cleanup := setupTestManager(t) + defer cleanup() + + store := mgr.Store() + if store 
== nil { + t.Error("Store() should not return nil") + } +} + +// ============================================================ +// Session Types Tests +// ============================================================ + +func TestSessionIsArchived(t *testing.T) { + s := &Session{Status: SessionActive} + if s.IsArchived() { + t.Error("active session should not be archived") + } + + s.Status = SessionArchived + if !s.IsArchived() { + t.Error("archived session should be archived") + } +} + +func TestSessionMessageCount(t *testing.T) { + s := &Session{Messages: make([]SessionMessage, 5)} + if n := s.MessageCount(); n != 5 { + t.Errorf("expected 5 messages, got %d", n) + } +} + +func TestMessageRoleConstants(t *testing.T) { + if RoleUser != "user" { + t.Errorf("expected RoleUser 'user', got %q", RoleUser) + } + if RoleAssistant != "assistant" { + t.Errorf("expected RoleAssistant 'assistant', got %q", RoleAssistant) + } + if RoleSystem != "system" { + t.Errorf("expected RoleSystem 'system', got %q", RoleSystem) + } + if RoleTool != "tool" { + t.Errorf("expected RoleTool 'tool', got %q", RoleTool) + } +} + +func TestSessionStatusConstants(t *testing.T) { + if SessionActive != "active" { + t.Errorf("expected SessionActive 'active', got %q", SessionActive) + } + if SessionArchived != "archived" { + t.Errorf("expected SessionArchived 'archived', got %q", SessionArchived) + } +} diff --git a/pkg/session/store.go b/pkg/session/store.go new file mode 100644 index 0000000..c251edd --- /dev/null +++ b/pkg/session/store.go @@ -0,0 +1,28 @@ +package session + +// Store defines the persistence interface for session message storage. +// +// Implementations must be safe for concurrent use. The default implementation +// uses JSONL files (one file per session) with O(1) append writes. +type Store interface { + // Save appends a single message to a session's history. + // Creates the session file if it does not exist. 
+ Save(sessionID string, msg SessionMessage) error + + // Load retrieves all messages for a session in chronological order. + // Returns an error if the session does not exist. + Load(sessionID string) ([]SessionMessage, error) + + // List returns all known session IDs. + List() ([]string, error) + + // Exists checks whether a session exists in the store. + Exists(sessionID string) (bool, error) + + // Archive marks a session as archived (read-only). + // This is a soft delete that preserves the data. + Archive(sessionID string) error + + // Delete removes a session permanently from the store. + Delete(sessionID string) error +} diff --git a/pkg/session/types.go b/pkg/session/types.go new file mode 100644 index 0000000..f0de3e4 --- /dev/null +++ b/pkg/session/types.go @@ -0,0 +1,60 @@ +// Package session provides conversation session management for the Orca framework. +// +// Sessions persist conversation history and provide context-window-based +// retrieval for LLM interactions. The default storage backend uses JSONL +// files with O(1) append writes. +package session + +import "time" + +// MessageRole represents the role of a message sender in a session. +type MessageRole string + +const ( + // RoleUser represents a human user message. + RoleUser MessageRole = "user" + // RoleAssistant represents an AI assistant message. + RoleAssistant MessageRole = "assistant" + // RoleSystem represents a system-level message. + RoleSystem MessageRole = "system" + // RoleTool represents a tool execution result. + RoleTool MessageRole = "tool" +) + +// SessionMessage represents a single message entry in a session's history. +type SessionMessage struct { + Role MessageRole `json:"role"` + Content string `json:"content"` + Timestamp time.Time `json:"timestamp"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// SessionStatus represents the lifecycle status of a session. +type SessionStatus string + +const ( + // SessionActive indicates an active, in-use session. 
+ SessionActive SessionStatus = "active" + // SessionArchived indicates an archived (read-only) session. + SessionArchived SessionStatus = "archived" +) + +// Session represents a conversation session with full history. +type Session struct { + ID string `json:"id"` + Status SessionStatus `json:"status"` + Messages []SessionMessage `json:"messages,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// IsArchived returns true if the session has been archived. +func (s *Session) IsArchived() bool { + return s.Status == SessionArchived +} + +// MessageCount returns the number of messages in the session. +func (s *Session) MessageCount() int { + return len(s.Messages) +} diff --git a/pkg/skill/manager.go b/pkg/skill/manager.go new file mode 100644 index 0000000..78e1158 --- /dev/null +++ b/pkg/skill/manager.go @@ -0,0 +1,197 @@ +package skill + +import ( + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "sync" +) + +const ( + // DefaultSkillsDir is the default directory for user-installed skills. + DefaultSkillsDir = "~/.agents/skills" + + // SkillManifestFile is the name of the skill manifest file. + SkillManifestFile = "SKILL.md" +) + +// Manager is a thread-safe registry for loading, storing, and querying Skills. +// +// Skills are loaded from a directory tree where each subdirectory containing +// a SKILL.md file is treated as a skill. The Manager automatically discovers +// skills on initialization and provides methods for finding skills by trigger +// keywords or by name. +type Manager struct { + mu sync.RWMutex + skillsDir string + skills map[string]*Skill +} + +// NewManager creates a new Skill manager that scans the given directory for skills. +// If skillsDir is empty, DefaultSkillsDir is used. 
+func NewManager(skillsDir string) *Manager { + if skillsDir == "" { + skillsDir = DefaultSkillsDir + } + // Expand ~ to home directory + skillsDir = expandHome(skillsDir) + + return &Manager{ + skillsDir: skillsDir, + skills: make(map[string]*Skill), + } +} + +// LoadAll scans the skills directory and loads all skills found. +// It returns the number of skills loaded and any errors encountered. +func (m *Manager) LoadAll() (int, error) { + m.mu.Lock() + defer m.mu.Unlock() + + // Clear existing skills + m.skills = make(map[string]*Skill) + + // Check if skills directory exists + info, err := os.Stat(m.skillsDir) + if err != nil { + if os.IsNotExist(err) { + return 0, nil // No skills directory yet — not an error + } + return 0, fmt.Errorf("skill: cannot access skills directory %q: %w", m.skillsDir, err) + } + if !info.IsDir() { + return 0, fmt.Errorf("skill: %q is not a directory", m.skillsDir) + } + + // Read all entries in the skills directory + entries, err := os.ReadDir(m.skillsDir) + if err != nil { + return 0, fmt.Errorf("skill: failed to read skills directory %q: %w", m.skillsDir, err) + } + + var loadErrors []string + loaded := 0 + + for _, entry := range entries { + if !entry.IsDir() { + continue + } + + skillDir := filepath.Join(m.skillsDir, entry.Name()) + skillPath := filepath.Join(skillDir, SkillManifestFile) + + if _, err := os.Stat(skillPath); os.IsNotExist(err) { + continue // No SKILL.md in this directory — skip + } + + skill, err := ParseSkillFile(skillPath) + if err != nil { + loadErrors = append(loadErrors, fmt.Sprintf("%s: %v", entry.Name(), err)) + continue + } + + m.skills[skill.Name] = skill + loaded++ + } + + if len(loadErrors) > 0 { + return loaded, fmt.Errorf("skill: loaded %d skills with %d errors: %s", + loaded, len(loadErrors), joinStrings(loadErrors, "; ")) + } + + return loaded, nil +} + +// GetSkill retrieves a skill by its name. Returns false if not found. 
+func (m *Manager) GetSkill(name string) (*Skill, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + + skill, ok := m.skills[name] + return skill, ok +} + +// ListSkills returns all loaded skills sorted by name. +func (m *Manager) ListSkills() []*Skill { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make([]*Skill, 0, len(m.skills)) + for _, skill := range m.skills { + result = append(result, skill) + } + + sort.Slice(result, func(i, j int) bool { + return result[i].Name < result[j].Name + }) + return result +} + +// FindSkill finds skills whose triggers match the given query string. +// Returns all matching skills sorted by relevance (more trigger matches first). +func (m *Manager) FindSkill(query string) []*Skill { + m.mu.RLock() + defer m.mu.RUnlock() + + var matches []*Skill + for _, skill := range m.skills { + if skill.MatchTrigger(query) { + matches = append(matches, skill) + } + } + + // Sort by number of matching triggers (descending) + sort.Slice(matches, func(i, j int) bool { + return countMatches(matches[i], query) > countMatches(matches[j], query) + }) + + return matches +} + +// SkillsDir returns the directory being scanned for skills. +func (m *Manager) SkillsDir() string { + return m.skillsDir +} + +// Reload refreshes all skills from disk. +func (m *Manager) Reload() (int, error) { + return m.LoadAll() +} + +// countMatches counts how many of the skill's triggers match the query. +func countMatches(skill *Skill, query string) int { + count := 0 + queryLower := strings.ToLower(query) + for _, trigger := range skill.Triggers { + if strings.Contains(queryLower, strings.ToLower(trigger)) { + count++ + } + } + return count +} + +// expandHome replaces "~" with the user's home directory. +func expandHome(path string) string { + if len(path) > 0 && path[0] == '~' { + home, err := os.UserHomeDir() + if err != nil { + return path + } + return filepath.Join(home, path[1:]) + } + return path +} + +// joinStrings joins a slice of strings with a separator. 
+func joinStrings(parts []string, sep string) string { + if len(parts) == 0 { + return "" + } + result := parts[0] + for _, p := range parts[1:] { + result += sep + p + } + return result +} diff --git a/pkg/skill/manager_test.go b/pkg/skill/manager_test.go new file mode 100644 index 0000000..082a7d2 --- /dev/null +++ b/pkg/skill/manager_test.go @@ -0,0 +1,309 @@ +package skill + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +// createTestSkill creates a temporary SKILL.md file for testing. +func createTestSkill(t *testing.T, dir, name, description string, triggers []string, body string) string { + t.Helper() + + skillDir := filepath.Join(dir, name) + if err := os.MkdirAll(skillDir, 0755); err != nil { + t.Fatalf("failed to create skill dir: %v", err) + } + + triggerStr := "[]" + if len(triggers) > 0 { + quoted := make([]string, len(triggers)) + for i, tr := range triggers { + quoted[i] = `"` + tr + `"` + } + triggerStr = "[" + strings.Join(quoted, ", ") + "]" + } + + manifest := "---\n" + manifest += "name: " + name + "\n" + manifest += "description: " + description + "\n" + manifest += "triggers: " + triggerStr + "\n" + manifest += "---\n\n" + manifest += body + + manifestPath := filepath.Join(skillDir, "SKILL.md") + if err := os.WriteFile(manifestPath, []byte(manifest), 0644); err != nil { + t.Fatalf("failed to write SKILL.md: %v", err) + } + + return manifestPath +} + +func TestNewManager(t *testing.T) { + m := NewManager("") + if m == nil { + t.Fatal("NewManager() returned nil") + } + if m.SkillsDir() == "" { + t.Error("expected non-empty skills directory") + } +} + +func TestNewManagerWithCustomDir(t *testing.T) { + tmpDir := t.TempDir() + m := NewManager(tmpDir) + + if m.SkillsDir() != tmpDir { + t.Errorf("expected skills dir %q, got %q", tmpDir, m.SkillsDir()) + } +} + +func TestLoadAllNoDirectory(t *testing.T) { + tmpDir := filepath.Join(t.TempDir(), "nonexistent") + m := NewManager(tmpDir) + + count, err := m.LoadAll() + if err != nil { + 
t.Fatalf("LoadAll on nonexistent dir should not error: %v", err) + } + if count != 0 { + t.Errorf("expected 0 skills, got %d", count) + } +} + +func TestLoadAllWithSkills(t *testing.T) { + tmpDir := t.TempDir() + + createTestSkill(t, tmpDir, "skill-a", "Skill A", []string{"alpha", "a"}, "# Skill A\nContent") + createTestSkill(t, tmpDir, "skill-b", "Skill B", []string{"beta", "b"}, "# Skill B\nContent") + + m := NewManager(tmpDir) + count, err := m.LoadAll() + if err != nil { + t.Fatalf("LoadAll failed: %v", err) + } + if count != 2 { + t.Errorf("expected 2 skills, got %d", count) + } +} + +func TestGetSkill(t *testing.T) { + tmpDir := t.TempDir() + createTestSkill(t, tmpDir, "test-skill", "Test Skill", []string{"test"}, "# Test") + + m := NewManager(tmpDir) + m.LoadAll() + + skill, ok := m.GetSkill("test-skill") + if !ok { + t.Fatal("GetSkill returned false for existing skill") + } + if skill.Name != "test-skill" { + t.Errorf("expected name 'test-skill', got %q", skill.Name) + } + if skill.Description != "Test Skill" { + t.Errorf("expected description 'Test Skill', got %q", skill.Description) + } +} + +func TestGetSkillNotFound(t *testing.T) { + tmpDir := t.TempDir() + m := NewManager(tmpDir) + m.LoadAll() + + _, ok := m.GetSkill("nonexistent") + if ok { + t.Error("expected false for nonexistent skill") + } +} + +func TestListSkills(t *testing.T) { + tmpDir := t.TempDir() + createTestSkill(t, tmpDir, "beta", "Beta", nil, "# Beta") + createTestSkill(t, tmpDir, "alpha", "Alpha", nil, "# Alpha") + + m := NewManager(tmpDir) + m.LoadAll() + + skills := m.ListSkills() + if len(skills) != 2 { + t.Errorf("expected 2 skills, got %d", len(skills)) + } + + // Should be sorted alphabetically + if len(skills) >= 2 { + if skills[0].Name != "alpha" { + t.Errorf("expected first skill 'alpha', got %q", skills[0].Name) + } + if skills[1].Name != "beta" { + t.Errorf("expected second skill 'beta', got %q", skills[1].Name) + } + } +} + +func TestListSkillsEmpty(t *testing.T) { + m := 
NewManager(t.TempDir()) + m.LoadAll() + + skills := m.ListSkills() + if len(skills) != 0 { + t.Errorf("expected empty list, got %d skills", len(skills)) + } +} + +func TestFindSkill(t *testing.T) { + tmpDir := t.TempDir() + createTestSkill(t, tmpDir, "browser", "Browser automation", []string{"browser", "navigate", "screenshot"}, "# Browser") + createTestSkill(t, tmpDir, "memory", "Project memory", []string{"memory", "remember"}, "# Memory") + createTestSkill(t, tmpDir, "convert", "File converter", []string{"convert", "pdf"}, "# Convert") + + m := NewManager(tmpDir) + m.LoadAll() + + // Find by trigger matching "browser" + results := m.FindSkill("I need to use the browser to navigate") + if len(results) == 0 { + t.Fatal("expected at least 1 match for 'browser'") + } + + found := false + for _, s := range results { + if s.Name == "browser" { + found = true + break + } + } + if !found { + t.Error("expected 'browser' skill in results") + } +} + +func TestFindSkillNoMatch(t *testing.T) { + tmpDir := t.TempDir() + createTestSkill(t, tmpDir, "browser", "Browser", []string{"browser"}, "# Browser") + + m := NewManager(tmpDir) + m.LoadAll() + + results := m.FindSkill("completely unrelated query") + if len(results) != 0 { + t.Errorf("expected 0 matches, got %d", len(results)) + } +} + +func TestFindSkillCaseInsensitive(t *testing.T) { + tmpDir := t.TempDir() + createTestSkill(t, tmpDir, "browser", "Browser", []string{"Browser"}, "# Browser") + + m := NewManager(tmpDir) + m.LoadAll() + + results := m.FindSkill("browser") + if len(results) != 1 { + t.Errorf("expected 1 match for lowercase 'browser', got %d", len(results)) + } + + results = m.FindSkill("BROWSER") + if len(results) != 1 { + t.Errorf("expected 1 match for uppercase 'BROWSER', got %d", len(results)) + } +} + +func TestFindSkillRelevanceOrder(t *testing.T) { + tmpDir := t.TempDir() + createTestSkill(t, tmpDir, "multi-match", "Multiple matches", []string{"alpha", "beta", "gamma"}, "# Multi") + createTestSkill(t, 
tmpDir, "single-match", "Single match", []string{"alpha"}, "# Single") + + m := NewManager(tmpDir) + m.LoadAll() + + // A query mentioning multiple triggers + results := m.FindSkill("alpha beta") + if len(results) != 2 { + t.Errorf("expected 2 matches, got %d", len(results)) + } + + // The multi-match skill should come first (more trigger matches) + if len(results) >= 2 { + if results[0].Name != "multi-match" { + t.Errorf("expected 'multi-match' first (more relevance), got %q", results[0].Name) + } + } +} + +func TestReload(t *testing.T) { + tmpDir := t.TempDir() + createTestSkill(t, tmpDir, "initial", "Initial", []string{"init"}, "# Initial") + + m := NewManager(tmpDir) + m.LoadAll() + + if len(m.ListSkills()) != 1 { + t.Errorf("expected 1 skill after load, got %d", len(m.ListSkills())) + } + + // Add another skill + createTestSkill(t, tmpDir, "added", "Added later", []string{"new"}, "# New") + + count, err := m.Reload() + if err != nil { + t.Fatalf("Reload failed: %v", err) + } + if count != 2 { + t.Errorf("expected 2 skills after reload, got %d", count) + } +} + +func TestSkillMatchTrigger(t *testing.T) { + skill := &Skill{ + Name: "test", + Triggers: []string{"browser", "navigate"}, + } + + tests := []struct { + query string + want bool + }{ + {"I need to use the browser", true}, + {"navigate to a page", true}, + {"Browser automation", true}, + {"something unrelated", false}, + } + + for _, tt := range tests { + got := skill.MatchTrigger(tt.query) + if got != tt.want { + t.Errorf("MatchTrigger(%q) = %v, want %v", tt.query, got, tt.want) + } + } +} + +func TestSkillHasScripts(t *testing.T) { + s1 := &Skill{Name: "no-scripts"} + if s1.HasScripts() { + t.Error("expected HasScripts() = false for empty scripts") + } + + s2 := &Skill{Name: "has-scripts", Scripts: []string{"script.sh"}} + if !s2.HasScripts() { + t.Error("expected HasScripts() = true for non-empty scripts") + } +} + +func TestExpandHome(t *testing.T) { + home, err := os.UserHomeDir() + if err != nil { 
+ t.Fatalf("failed to get home dir: %v", err) + } + + result := expandHome("~/test/path") + if !strings.HasPrefix(result, home) { + t.Errorf("expected path starting with %q, got %q", home, result) + } + + // Non-tilde path should not change + if expandHome("/absolute/path") != "/absolute/path" { + t.Error("absolute path should not be modified") + } +} diff --git a/pkg/skill/parser.go b/pkg/skill/parser.go new file mode 100644 index 0000000..43913cb --- /dev/null +++ b/pkg/skill/parser.go @@ -0,0 +1,246 @@ +package skill + +import ( + "fmt" + "os" + "path/filepath" + "sort" + "strings" +) + +// FrontmatterDelimiters for YAML frontmatter in SKILL.md files. +const ( + frontmatterDelim = "---" +) + +// ParseSkillFile parses a SKILL.md file and returns a populated Skill struct. +// +// The expected format is: +// +// --- +// name: my-skill +// description: Does something useful +// triggers: ["keyword1", "keyword2"] +// --- +// +// # My Skill +// +// Detailed description... +func ParseSkillFile(path string) (*Skill, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("skill: cannot read %q: %w", path, err) + } + + return ParseSkillData(path, data) +} + +// ParseSkillData parses SKILL.md content from raw bytes. +// The path parameter is used to locate the scripts/ directory. 
+func ParseSkillData(path string, data []byte) (*Skill, error) { + content := string(data) + + skill := &Skill{ + Path: path, + Body: content, + Triggers: []string{}, + } + + // Parse YAML frontmatter + rest, err := parseFrontmatter(content, skill) + if err != nil { + return nil, err + } + skill.Body = strings.TrimSpace(rest) + + // Validate required fields + if skill.Name == "" { + return nil, fmt.Errorf("skill: %q is missing 'name' in frontmatter", path) + } + + // Discover scripts directory + skillDir := filepath.Dir(path) + scriptsDir := filepath.Join(skillDir, "scripts") + skill.ScriptsDir = scriptsDir + + if info, err := os.Stat(scriptsDir); err == nil && info.IsDir() { + scripts, err := discoverScripts(scriptsDir) + if err != nil { + return nil, fmt.Errorf("skill: failed to discover scripts in %q: %w", scriptsDir, err) + } + skill.Scripts = scripts + } + + return skill, nil +} + +// parseFrontmatter extracts YAML frontmatter delimited by "---" lines +// and populates the Skill struct fields. +func parseFrontmatter(content string, skill *Skill) (string, error) { + content = strings.TrimSpace(content) + + if !strings.HasPrefix(content, frontmatterDelim) { + // No frontmatter — treat entire content as body + return content, nil + } + + // Find the closing delimiter + rest := content[len(frontmatterDelim):] + rest = strings.TrimLeft(rest, "\n\r") + + endIdx := strings.Index(rest, "\n"+frontmatterDelim) + if endIdx < 0 { + // Also check for end-of-file style + endIdx = strings.Index(rest, frontmatterDelim) + if endIdx < 0 { + return "", fmt.Errorf("skill: unclosed frontmatter in skill file") + } + } + + frontmatter := rest[:endIdx] + body := rest[endIdx+len(frontmatterDelim)+1:] + + // Parse the YAML frontmatter (simple key-value parser) + if err := parseSimpleYAML(frontmatter, skill); err != nil { + return "", err + } + + return body, nil +} + +// parseSimpleYAML parses a simplified YAML format for skill frontmatter. 
+// Supports: string values, quoted strings, and array values. +func parseSimpleYAML(yaml string, skill *Skill) error { + lines := strings.Split(yaml, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + colonIdx := strings.Index(line, ":") + if colonIdx < 0 { + continue + } + + key := strings.TrimSpace(line[:colonIdx]) + value := strings.TrimSpace(line[colonIdx+1:]) + + switch key { + case "name": + skill.Name = trimQuotes(value) + + case "description": + skill.Description = trimQuotes(value) + + case "triggers": + triggers, err := parseYAMLArray(value) + if err != nil { + return fmt.Errorf("skill: invalid triggers format: %w", err) + } + skill.Triggers = triggers + } + } + return nil +} + +// parseYAMLArray parses a YAML array like '["a", "b", "c"]' or '[a, b, c]'. +func parseYAMLArray(value string) ([]string, error) { + value = strings.TrimSpace(value) + + // Handle YAML list format: ["a", "b"] or [a, b] + if strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") { + inner := value[1 : len(value)-1] + if strings.TrimSpace(inner) == "" { + return []string{}, nil + } + parts := splitCommas(inner) + result := make([]string, len(parts)) + for i, p := range parts { + result[i] = trimQuotes(strings.TrimSpace(p)) + } + return result, nil + } + + // Handle YAML list format with dashes: + // triggers: + // - browser + // - navigate + // (This would be handled line-by-line in a different flow) + // For now, treat single value as a one-element list + if value != "" && value != "[]" { + return []string{trimQuotes(value)}, nil + } + + return []string{}, nil +} + +// splitCommas splits a comma-separated string respecting quoted sections. 
+func splitCommas(s string) []string { + var parts []string + var current strings.Builder + inQuote := false + quoteChar := byte(0) + + for i := 0; i < len(s); i++ { + c := s[i] + if inQuote { + current.WriteByte(c) + if c == quoteChar { + inQuote = false + } + } else if c == '"' || c == '\'' { + current.WriteByte(c) + inQuote = true + quoteChar = c + } else if c == ',' { + parts = append(parts, current.String()) + current.Reset() + } else { + current.WriteByte(c) + } + } + if current.Len() > 0 { + parts = append(parts, current.String()) + } + return parts +} + +// trimQuotes removes surrounding quotes from a string value. +func trimQuotes(s string) string { + s = strings.TrimSpace(s) + if len(s) >= 2 { + if (s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'') { + return s[1 : len(s)-1] + } + } + return s +} + +// discoverScripts lists all executable/readable files in a scripts directory. +func discoverScripts(scriptsDir string) ([]string, error) { + entries, err := os.ReadDir(scriptsDir) + if err != nil { + return nil, err + } + + var scripts []string + for _, entry := range entries { + if entry.IsDir() { + continue + } + scripts = append(scripts, entry.Name()) + } + + sort.Strings(scripts) + return scripts, nil +} + +// LoadSkillFromDir loads a skill from a directory containing a SKILL.md file. +func LoadSkillFromDir(dir string) (*Skill, error) { + skillPath := filepath.Join(dir, "SKILL.md") + if _, err := os.Stat(skillPath); os.IsNotExist(err) { + return nil, fmt.Errorf("skill: no SKILL.md found in %q", dir) + } + return ParseSkillFile(skillPath) +} diff --git a/pkg/skill/skill.go b/pkg/skill/skill.go new file mode 100644 index 0000000..640ca9c --- /dev/null +++ b/pkg/skill/skill.go @@ -0,0 +1,56 @@ +// Package skill provides the Skill definition and management system. +// +// Skills are composable capabilities loaded from ~/.agents/skills/. 
+// Each skill has a SKILL.md manifest with YAML frontmatter and optional +// scripts in a scripts/ subdirectory. Skills can be discovered and +// invoked by trigger keywords. +package skill + +import ( + "fmt" + "strings" +) + +// Skill represents a composable capability loaded from the skills directory. +// +// Each Skill is defined by a SKILL.md file with YAML frontmatter containing +// metadata (name, description, triggers) and optional executable scripts +// in a scripts/ subdirectory. +type Skill struct { + // Name is the unique identifier for this skill (e.g., "dev-browser"). + Name string `yaml:"name"` + // Description is a human-readable explanation of what this skill does. + Description string `yaml:"description"` + // Triggers are keywords that activate this skill from natural language. + Triggers []string `yaml:"triggers"` + // Scripts is the list of script file names in the scripts/ directory. + Scripts []string `yaml:"-"` + // ScriptsDir is the absolute path to the scripts/ directory. + ScriptsDir string `yaml:"-"` + // Body is the markdown content after the YAML frontmatter. + Body string `yaml:"-"` + // Path is the absolute path to the SKILL.md file. + Path string `yaml:"-"` +} + +// MatchTrigger checks if the given query matches any of the skill's triggers. +// Matching is case-insensitive and supports partial matches. +func (s *Skill) MatchTrigger(query string) bool { + query = strings.ToLower(query) + for _, trigger := range s.Triggers { + if strings.Contains(strings.ToLower(query), strings.ToLower(trigger)) { + return true + } + } + return false +} + +// String returns a human-readable representation of the skill. +func (s *Skill) String() string { + return fmt.Sprintf("Skill{Name: %q, Triggers: %v, Scripts: %d}", s.Name, s.Triggers, len(s.Scripts)) +} + +// HasScripts returns true if the skill has at least one script. 
+func (s *Skill) HasScripts() bool { + return len(s.Scripts) > 0 +} diff --git a/pkg/tool/builtin.go b/pkg/tool/builtin.go new file mode 100644 index 0000000..863ea8d --- /dev/null +++ b/pkg/tool/builtin.go @@ -0,0 +1,433 @@ +package tool + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/orca/orca/pkg/sandbox" +) + +// --------------------------------------------------------------------------- +// exec — Execute a shell command via the sandbox +// --------------------------------------------------------------------------- + +// execTool runs shell commands through the ProcessSandbox. +type execTool struct { + sandbox sandbox.Sandbox +} + +// NewExecTool creates a new exec tool backed by the given sandbox. +func NewExecTool(sb sandbox.Sandbox) Tool { + if sb == nil { + sb = sandbox.NewProcessSandbox() + } + return &execTool{sandbox: sb} +} + +func (t *execTool) Name() string { return "exec" } + +func (t *execTool) Description() string { + return "Execute a shell command and return its output. Use this for running scripts, " + + "installing packages, compiling code, or any command-line operation." 
+} + +func (t *execTool) Parameters() map[string]ParameterSchema { + return map[string]ParameterSchema{ + "command": { + Type: "string", + Description: "The shell command to execute (e.g., 'ls -la' or 'python script.py')", + Required: true, + }, + "timeout": { + Type: "number", + Description: "Timeout in seconds for the command execution (default: 30)", + Required: false, + Default: float64(30), + }, + "workdir": { + Type: "string", + Description: "Working directory for the command (default: sandbox default)", + Required: false, + }, + } +} + +func (t *execTool) Execute(ctx context.Context, args map[string]interface{}) (*Result, error) { + cmdStr, ok := args["command"].(string) + if !ok || cmdStr == "" { + return ErrorResult("'command' argument is required and must be a string"), nil + } + + // Use a timeout if specified in args + execCtx := ctx + if timeoutVal, ok := args["timeout"]; ok { + if timeout, err := toFloat64(timeoutVal); err == nil && timeout > 0 { + var cancel context.CancelFunc + execCtx, cancel = context.WithTimeout(ctx, time.Duration(timeout*float64(time.Second))) + defer cancel() + } + } + + // Set working directory if specified + sb := t.sandbox + if wd, ok := args["workdir"].(string); ok && wd != "" { + if ps, ok := sb.(*sandbox.ProcessSandbox); ok { + ps.WorkingDir = wd + } + } + + // Execute the command via shell + result, err := sb.Execute(execCtx, "sh", "-c", cmdStr) + if err != nil { + return nil, fmt.Errorf("exec tool: %w", err) + } + + return &Result{ + Success: result.ExitCode == 0, + Data: map[string]interface{}{ + "stdout": result.Stdout, + "stderr": result.Stderr, + "exit_code": result.ExitCode, + }, + }, nil +} + +// --------------------------------------------------------------------------- +// read_file — Read the contents of a file +// --------------------------------------------------------------------------- + +type readFileTool struct{} + +func NewReadFileTool() Tool { return &readFileTool{} } + +func (t *readFileTool) Name() 
string { return "read_file" } + +func (t *readFileTool) Description() string { + return "Read the contents of a file from the local filesystem. Returns the file content as a string." +} + +func (t *readFileTool) Parameters() map[string]ParameterSchema { + return map[string]ParameterSchema{ + "path": { + Type: "string", + Description: "Absolute path to the file to read", + Required: true, + }, + } +} + +func (t *readFileTool) Execute(ctx context.Context, args map[string]interface{}) (*Result, error) { + path, ok := args["path"].(string) + if !ok || path == "" { + return ErrorResult("'path' argument is required and must be a string"), nil + } + + // Prevent directory traversal / read of non-regular files + info, err := os.Stat(path) + if err != nil { + return ErrorResult(fmt.Sprintf("cannot access %q: %v", path, err)), nil + } + if info.IsDir() { + return ErrorResult(fmt.Sprintf("%q is a directory, not a file", path)), nil + } + + data, err := os.ReadFile(path) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to read %q: %v", path, err)), nil + } + + return SuccessResult(map[string]interface{}{ + "path": path, + "content": string(data), + "size": len(data), + }), nil +} + +// --------------------------------------------------------------------------- +// write_file — Write content to a file +// --------------------------------------------------------------------------- + +type writeFileTool struct{} + +func NewWriteFileTool() Tool { return &writeFileTool{} } + +func (t *writeFileTool) Name() string { return "write_file" } + +func (t *writeFileTool) Description() string { + return "Write content to a file on the local filesystem. Creates parent directories if needed." 
+} + +func (t *writeFileTool) Parameters() map[string]ParameterSchema { + return map[string]ParameterSchema{ + "path": { + Type: "string", + Description: "Absolute path where the file should be written", + Required: true, + }, + "content": { + Type: "string", + Description: "The content to write to the file", + Required: true, + }, + } +} + +func (t *writeFileTool) Execute(ctx context.Context, args map[string]interface{}) (*Result, error) { + path, ok := args["path"].(string) + if !ok || path == "" { + return ErrorResult("'path' argument is required and must be a string"), nil + } + + content, ok := args["content"].(string) + if !ok { + return ErrorResult("'content' argument is required and must be a string"), nil + } + + // Create parent directories + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return ErrorResult(fmt.Sprintf("failed to create directories for %q: %v", path, err)), nil + } + + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + return ErrorResult(fmt.Sprintf("failed to write %q: %v", path, err)), nil + } + + return SuccessResult(map[string]interface{}{ + "path": path, + "size": len(content), + }), nil +} + +// --------------------------------------------------------------------------- +// list_dir — List the contents of a directory +// --------------------------------------------------------------------------- + +type listDirTool struct{} + +func NewListDirTool() Tool { return &listDirTool{} } + +func (t *listDirTool) Name() string { return "list_dir" } + +func (t *listDirTool) Description() string { + return "List files and directories in a given path. Returns names, sizes, and modification times." 
+} + +func (t *listDirTool) Parameters() map[string]ParameterSchema { + return map[string]ParameterSchema{ + "path": { + Type: "string", + Description: "Absolute path to the directory to list", + Required: true, + }, + "recursive": { + Type: "boolean", + Description: "Whether to list recursively (default: false)", + Required: false, + Default: false, + }, + } +} + +func (t *listDirTool) Execute(ctx context.Context, args map[string]interface{}) (*Result, error) { + path, ok := args["path"].(string) + if !ok || path == "" { + return ErrorResult("'path' argument is required and must be a string"), nil + } + + recursive, _ := args["recursive"].(bool) + + info, err := os.Stat(path) + if err != nil { + return ErrorResult(fmt.Sprintf("cannot access %q: %v", path, err)), nil + } + if !info.IsDir() { + return ErrorResult(fmt.Sprintf("%q is not a directory", path)), nil + } + + var entries []map[string]interface{} + + if recursive { + err = filepath.Walk(path, func(p string, fi os.FileInfo, err error) error { + if err != nil { + return err + } + rel, _ := filepath.Rel(path, p) + if rel == "." 
{ + return nil + } + entries = append(entries, entryToMap(p, rel, fi)) + return nil + }) + } else { + files, err := os.ReadDir(path) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to list %q: %v", path, err)), nil + } + for _, f := range files { + fi, err := f.Info() + if err != nil { + continue + } + fullPath := filepath.Join(path, f.Name()) + entries = append(entries, entryToMap(fullPath, f.Name(), fi)) + } + } + + if err != nil { + return ErrorResult(fmt.Sprintf("failed to list %q: %v", path, err)), nil + } + + return SuccessResult(map[string]interface{}{ + "path": path, + "entries": entries, + "count": len(entries), + }), nil +} + +func entryToMap(fullPath, name string, fi os.FileInfo) map[string]interface{} { + return map[string]interface{}{ + "name": name, + "path": fullPath, + "size": fi.Size(), + "is_dir": fi.IsDir(), + "mode": fi.Mode().String(), + "modtime": fi.ModTime().Format("2006-01-02T15:04:05Z07:00"), + } +} + +// --------------------------------------------------------------------------- +// search_files — Search for content in files +// --------------------------------------------------------------------------- + +type searchFilesTool struct{} + +func NewSearchFilesTool() Tool { return &searchFilesTool{} } + +func (t *searchFilesTool) Name() string { return "search_files" } + +func (t *searchFilesTool) Description() string { + return "Search for a pattern in files within a directory. Supports simple substring matching." 
+} + +func (t *searchFilesTool) Parameters() map[string]ParameterSchema { + return map[string]ParameterSchema{ + "pattern": { + Type: "string", + Description: "The text pattern to search for (substring match)", + Required: true, + }, + "path": { + Type: "string", + Description: "Directory to search in (default: current directory)", + Required: false, + Default: ".", + }, + "include": { + Type: "string", + Description: "File glob pattern to include (e.g., '*.go', '*.{ts,tsx}')", + Required: false, + }, + } +} + +func (t *searchFilesTool) Execute(ctx context.Context, args map[string]interface{}) (*Result, error) { + pattern, ok := args["pattern"].(string) + if !ok || pattern == "" { + return ErrorResult("'pattern' argument is required and must be a string"), nil + } + + searchPath := "." + if p, ok := args["path"].(string); ok && p != "" { + searchPath = p + } + + include, _ := args["include"].(string) + + // Verify search path exists + info, err := os.Stat(searchPath) + if err != nil { + return ErrorResult(fmt.Sprintf("cannot access search path %q: %v", searchPath, err)), nil + } + if !info.IsDir() { + return ErrorResult(fmt.Sprintf("%q is not a directory", searchPath)), nil + } + + var matches []map[string]interface{} + + err = filepath.Walk(searchPath, func(p string, fi os.FileInfo, err error) error { + if err != nil { + return nil // skip files we can't access + } + if fi.IsDir() { + return nil + } + + // Apply include filter + if include != "" { + matched, err := filepath.Match(include, fi.Name()) + if err != nil || !matched { + return nil + } + } + + // Read file and search + data, err := os.ReadFile(p) + if err != nil { + return nil // skip unreadable files + } + + content := string(data) + if strings.Contains(content, pattern) { + matches = append(matches, map[string]interface{}{ + "path": p, + "size": len(data), + }) + } + return nil + }) + + if err != nil { + return ErrorResult(fmt.Sprintf("search failed: %v", err)), nil + } + + return 
SuccessResult(map[string]interface{}{ + "pattern": pattern, + "path": searchPath, + "matches": matches, + "count": len(matches), + }), nil +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +// toFloat64 converts an interface{} value to float64. +// Supports float64, int, int64, and json.Number types. +func toFloat64(v interface{}) (float64, error) { + switch val := v.(type) { + case float64: + return val, nil + case int: + return float64(val), nil + case int64: + return float64(val), nil + case json.Number: + return val.Float64() + default: + return 0, fmt.Errorf("cannot convert %T to float64", v) + } +} + +// Compile-time interface checks. +var _ Tool = (*execTool)(nil) +var _ Tool = (*readFileTool)(nil) +var _ Tool = (*writeFileTool)(nil) +var _ Tool = (*listDirTool)(nil) +var _ Tool = (*searchFilesTool)(nil) diff --git a/pkg/tool/builtin_test.go b/pkg/tool/builtin_test.go new file mode 100644 index 0000000..8d43935 --- /dev/null +++ b/pkg/tool/builtin_test.go @@ -0,0 +1,399 @@ +package tool + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/orca/orca/pkg/sandbox" +) + +func TestExecTool(t *testing.T) { + sb := sandbox.NewProcessSandbox() + execT := NewExecTool(sb) + + if execT.Name() != "exec" { + t.Errorf("expected name 'exec', got %q", execT.Name()) + } + if execT.Description() == "" { + t.Error("expected non-empty description") + } + + params := execT.Parameters() + if _, ok := params["command"]; !ok { + t.Error("expected 'command' parameter") + } +} + +func TestExecToolExecute(t *testing.T) { + sb := sandbox.NewProcessSandbox() + execT := NewExecTool(sb) + ctx := context.Background() + + result, err := execT.Execute(ctx, map[string]interface{}{ + "command": "echo hello", + }) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + if !result.Success { + t.Errorf("expected success, got 
error: %s", result.Error) + } + + data := result.Data.(map[string]interface{}) + stdout := data["stdout"].(string) + if strings.TrimSpace(stdout) != "hello" { + t.Errorf("expected stdout 'hello', got %q", stdout) + } +} + +func TestExecToolMissingCommand(t *testing.T) { + sb := sandbox.NewProcessSandbox() + execT := NewExecTool(sb) + ctx := context.Background() + + result, err := execT.Execute(ctx, map[string]interface{}{}) + if err != nil { + t.Fatalf("Execute should not error for invalid args: %v", err) + } + if result.Success { + t.Error("expected failure for missing command") + } + if !strings.Contains(result.Error, "command") { + t.Errorf("error should mention 'command', got: %s", result.Error) + } +} + +func TestReadFileTool(t *testing.T) { + readT := NewReadFileTool() + + if readT.Name() != "read_file" { + t.Errorf("expected name 'read_file', got %q", readT.Name()) + } +} + +func TestReadFileToolExecute(t *testing.T) { + // Create a temp file + tmpFile, err := os.CreateTemp("", "orca-test-*") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + content := "test content\nline 2\n" + if _, err := tmpFile.WriteString(content); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + tmpFile.Close() + defer os.Remove(tmpFile.Name()) + + readT := NewReadFileTool() + ctx := context.Background() + + result, err := readT.Execute(ctx, map[string]interface{}{ + "path": tmpFile.Name(), + }) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + if !result.Success { + t.Errorf("expected success, got error: %s", result.Error) + } + + data := result.Data.(map[string]interface{}) + gotContent := data["content"].(string) + if gotContent != content { + t.Errorf("expected content %q, got %q", content, gotContent) + } +} + +func TestReadFileToolMissingPath(t *testing.T) { + readT := NewReadFileTool() + ctx := context.Background() + + result, err := readT.Execute(ctx, map[string]interface{}{}) + if err != nil { + t.Fatalf("Execute should 
not error for invalid args: %v", err) + } + if result.Success { + t.Error("expected failure for missing path") + } +} + +func TestReadFileToolNonexistent(t *testing.T) { + readT := NewReadFileTool() + ctx := context.Background() + + result, err := readT.Execute(ctx, map[string]interface{}{ + "path": "/nonexistent/path/that/does/not/exist.txt", + }) + if err != nil { + t.Fatalf("Execute should not error for missing file: %v", err) + } + if result.Success { + t.Error("expected failure for nonexistent file") + } +} + +func TestWriteFileTool(t *testing.T) { + writeT := NewWriteFileTool() + + if writeT.Name() != "write_file" { + t.Errorf("expected name 'write_file', got %q", writeT.Name()) + } +} + +func TestWriteFileToolExecute(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "orca-write-test-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + testPath := filepath.Join(tmpDir, "nested", "test.txt") + content := "hello world" + + writeT := NewWriteFileTool() + ctx := context.Background() + + result, err := writeT.Execute(ctx, map[string]interface{}{ + "path": testPath, + "content": content, + }) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + if !result.Success { + t.Errorf("expected success, got error: %s", result.Error) + } + + // Verify the file was written + data, err := os.ReadFile(testPath) + if err != nil { + t.Fatalf("failed to read written file: %v", err) + } + if string(data) != content { + t.Errorf("expected content %q, got %q", content, string(data)) + } +} + +func TestWriteFileToolMissingArgs(t *testing.T) { + writeT := NewWriteFileTool() + ctx := context.Background() + + // Missing path + result, err := writeT.Execute(ctx, map[string]interface{}{ + "content": "test", + }) + if err != nil { + t.Fatalf("Execute should not error for invalid args: %v", err) + } + if result.Success { + t.Error("expected failure for missing path") + } + + // Missing content + result, err = writeT.Execute(ctx, 
map[string]interface{}{ + "path": "/tmp/test.txt", + }) + if err != nil { + t.Fatalf("Execute should not error for invalid args: %v", err) + } + if result.Success { + t.Error("expected failure for missing content") + } +} + +func TestListDirTool(t *testing.T) { + listT := NewListDirTool() + + if listT.Name() != "list_dir" { + t.Errorf("expected name 'list_dir', got %q", listT.Name()) + } +} + +func TestListDirToolExecute(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "orca-list-test-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create some test files + os.WriteFile(filepath.Join(tmpDir, "a.txt"), []byte("a"), 0644) + os.WriteFile(filepath.Join(tmpDir, "b.txt"), []byte("bb"), 0644) + os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755) + + listT := NewListDirTool() + ctx := context.Background() + + result, err := listT.Execute(ctx, map[string]interface{}{ + "path": tmpDir, + }) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + if !result.Success { + t.Errorf("expected success, got error: %s", result.Error) + } + + data := result.Data.(map[string]interface{}) + count := data["count"].(int) + if count != 3 { + t.Errorf("expected 3 entries, got %d", count) + } +} + +func TestListDirToolRecursive(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "orca-list-rec-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + os.MkdirAll(filepath.Join(tmpDir, "a", "b"), 0755) + os.WriteFile(filepath.Join(tmpDir, "a", "b", "c.txt"), []byte("c"), 0644) + + listT := NewListDirTool() + ctx := context.Background() + + result, err := listT.Execute(ctx, map[string]interface{}{ + "path": tmpDir, + "recursive": true, + }) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + if !result.Success { + t.Errorf("expected success, got error: %s", result.Error) + } + + data := result.Data.(map[string]interface{}) + entries := 
data["entries"].([]map[string]interface{}) + if len(entries) < 2 { + t.Errorf("expected at least 2 recursive entries, got %d", len(entries)) + } +} + +func TestListDirToolNonexistent(t *testing.T) { + listT := NewListDirTool() + ctx := context.Background() + + result, err := listT.Execute(ctx, map[string]interface{}{ + "path": "/nonexistent/path", + }) + if err != nil { + t.Fatalf("Execute should not error for missing path: %v", err) + } + if result.Success { + t.Error("expected failure for nonexistent path") + } +} + +func TestSearchFilesTool(t *testing.T) { + searchT := NewSearchFilesTool() + + if searchT.Name() != "search_files" { + t.Errorf("expected name 'search_files', got %q", searchT.Name()) + } +} + +func TestSearchFilesToolExecute(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "orca-search-test-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create test files + os.WriteFile(filepath.Join(tmpDir, "findme.go"), []byte("package main\nfunc hello() {\n}\n"), 0644) + os.WriteFile(filepath.Join(tmpDir, "other.py"), []byte("def world():\n pass\n"), 0644) + + searchT := NewSearchFilesTool() + ctx := context.Background() + + result, err := searchT.Execute(ctx, map[string]interface{}{ + "pattern": "hello", + "path": tmpDir, + }) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + if !result.Success { + t.Errorf("expected success, got error: %s", result.Error) + } + + data := result.Data.(map[string]interface{}) + count := data["count"].(int) + if count != 1 { + t.Errorf("expected 1 match for 'hello', got %d", count) + } +} + +func TestSearchFilesToolNoMatch(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "orca-search-nomatch-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte("nothing here"), 0644) + + searchT := NewSearchFilesTool() + ctx := context.Background() + + result, err 
:= searchT.Execute(ctx, map[string]interface{}{ + "pattern": "nonexistent-pattern-xyz", + "path": tmpDir, + }) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + if !result.Success { + t.Errorf("expected success even with no matches, got error: %s", result.Error) + } + + data := result.Data.(map[string]interface{}) + count := data["count"].(int) + if count != 0 { + t.Errorf("expected 0 matches, got %d", count) + } +} + +func TestSearchFilesToolMissingPattern(t *testing.T) { + searchT := NewSearchFilesTool() + ctx := context.Background() + + result, err := searchT.Execute(ctx, map[string]interface{}{}) + if err != nil { + t.Fatalf("Execute should not error for invalid args: %v", err) + } + if result.Success { + t.Error("expected failure for missing pattern") + } +} + +func TestToolInterfaceSatisfied(t *testing.T) { + sb := sandbox.NewProcessSandbox() + + tools := []Tool{ + NewExecTool(sb), + NewReadFileTool(), + NewWriteFileTool(), + NewListDirTool(), + NewSearchFilesTool(), + } + + names := []string{"exec", "read_file", "write_file", "list_dir", "search_files"} + for i, tool := range tools { + if tool.Name() != names[i] { + t.Errorf("expected name %q, got %q", names[i], tool.Name()) + } + if tool.Description() == "" { + t.Errorf("tool %q has empty description", names[i]) + } + if tool.Parameters() == nil { + t.Errorf("tool %q has nil parameters", names[i]) + } + } +} diff --git a/pkg/tool/manager.go b/pkg/tool/manager.go new file mode 100644 index 0000000..6c3abb9 --- /dev/null +++ b/pkg/tool/manager.go @@ -0,0 +1,108 @@ +package tool + +import ( + "context" + "fmt" + "sort" + "sync" +) + +// Manager is a thread-safe registry that manages tool registration and execution. +// +// Tools are registered by name (case-sensitive) and can be discovered, +// listed, and invoked through the Manager. Duplicate registration returns +// an error. +type Manager struct { + mu sync.RWMutex + tools map[string]Tool +} + +// NewManager creates a new empty tool manager. 
+func NewManager() *Manager { + return &Manager{ + tools: make(map[string]Tool), + } +} + +// Register adds a tool to the manager. Returns an error if a tool with the +// same name is already registered. +func (m *Manager) Register(tool Tool) error { + m.mu.Lock() + defer m.mu.Unlock() + + name := tool.Name() + if _, exists := m.tools[name]; exists { + return fmt.Errorf("tool %q is already registered", name) + } + + m.tools[name] = tool + return nil +} + +// Unregister removes a tool from the manager by name. +func (m *Manager) Unregister(name string) error { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.tools[name]; !exists { + return fmt.Errorf("tool %q is not registered", name) + } + + delete(m.tools, name) + return nil +} + +// Get retrieves a tool by name. Returns false if not found. +func (m *Manager) Get(name string) (Tool, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + + t, ok := m.tools[name] + return t, ok +} + +// List returns all registered tools sorted by name. +func (m *Manager) List() []Tool { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make([]Tool, 0, len(m.tools)) + for _, t := range m.tools { + result = append(result, t) + } + + sort.Slice(result, func(i, j int) bool { + return result[i].Name() < result[j].Name() + }) + return result +} + +// Execute looks up a tool by name and invokes it with the given arguments. +// Returns an error if the tool is not found. +func (m *Manager) Execute(name string, ctx context.Context, args map[string]interface{}) (*Result, error) { + tool, ok := m.Get(name) + if !ok { + return nil, fmt.Errorf("tool %q not found", name) + } + return tool.Execute(ctx, args) +} + +// Count returns the number of registered tools. +func (m *Manager) Count() int { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.tools) +} + +// Names returns the names of all registered tools sorted alphabetically. 
+func (m *Manager) Names() []string { + m.mu.RLock() + defer m.mu.RUnlock() + + names := make([]string, 0, len(m.tools)) + for name := range m.tools { + names = append(names, name) + } + sort.Strings(names) + return names +} diff --git a/pkg/tool/tool.go b/pkg/tool/tool.go new file mode 100644 index 0000000..4470594 --- /dev/null +++ b/pkg/tool/tool.go @@ -0,0 +1,81 @@ +// Package tool defines the Tool interface and the tool management system. +// +// Tools are the atomic capabilities that can be invoked by agents or LLMs. +// Each tool has a name, description, a parameter schema (for LLM function calling), +// and an Execute method that performs the actual work. +// +// Built-in tools include file operations (read, write, list, search) and +// command execution through the sandbox. Custom tools can be registered +// via the Manager. +package tool + +import ( + "context" + "encoding/json" +) + +// ParameterSchema describes a single parameter accepted by a tool. +type ParameterSchema struct { + Type string `json:"type"` + Description string `json:"description"` + Required bool `json:"required"` + Default interface{} `json:"default,omitempty"` + Properties map[string]ParameterSchema `json:"properties,omitempty"` + Items *ParameterSchema `json:"items,omitempty"` + Enum []string `json:"enum,omitempty"` +} + +// Result holds the output of a tool execution. +type Result struct { + Success bool `json:"success"` + Data interface{} `json:"data,omitempty"` + Error string `json:"error,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// Tool defines the interface that all tools must implement. +// +// Tools are registered with a Manager and can be discovered and invoked +// by name. The Execute method receives a context for cancellation and +// a map of string-keyed arguments. +type Tool interface { + // Name returns the unique identifier for this tool. 
+ Name() string + + // Description returns a human-readable description of what this tool does. + Description() string + + // Parameters returns the schema describing accepted arguments. + // Used for LLM function calling and validation. + Parameters() map[string]ParameterSchema + + // Execute performs the tool's function with the given arguments. + // The context controls cancellation and timeouts. + Execute(ctx context.Context, args map[string]interface{}) (*Result, error) +} + +// SuccessResult creates a successful tool result with the given data. +func SuccessResult(data interface{}) *Result { + return &Result{ + Success: true, + Data: data, + } +} + +// ErrorResult creates a failed tool result with the given error message. +func ErrorResult(err string) *Result { + return &Result{ + Success: false, + Error: err, + } +} + +// MustMarshalArgs converts a map of arguments to JSON bytes, panicking on error. +// Useful for logging and debugging. +func MustMarshalArgs(args map[string]interface{}) []byte { + b, err := json.Marshal(args) + if err != nil { + panic("tool: failed to marshal args: " + err.Error()) + } + return b +} diff --git a/thoughts/shared/designs/2026-05-07-orca-agent-framework-design.md b/thoughts/shared/designs/2026-05-07-orca-agent-framework-design.md new file mode 100644 index 0000000..875c51a --- /dev/null +++ b/thoughts/shared/designs/2026-05-07-orca-agent-framework-design.md @@ -0,0 +1,416 @@ +--- +date: 2026-05-07 +topic: "Go Agent Framework - Orca" +status: validated +--- + +# Go Agent Framework (Orca) 设计文档 + +## Problem Statement + +构建一个基于 Go 的基础 Agent 框架,支持多 Agent 协作、持久化会话记忆、Skill 技能自动识别、沙箱安全执行、自定义 Tool 注册扩展,并接入 Ollama 本地模型(gemma4:e4b)。 + +**核心挑战:** +- 如何在 Go 中实现轻量、高并发的多 Agent 系统 +- 如何安全地执行用户命令和 Skill 脚本 +- 如何设计可扩展的插件机制(Skill / Tool) +- 如何管理会话上下文和记忆 + +## Constraints + +1. **语言约束:** 纯 Go 实现,最小化外部依赖 +2. **存储约束:** 使用 JSON Lines(无 SQLite/数据库依赖) +3. **隔离约束:** 进程级限制(chroot + 资源限制),不依赖 Docker +4. **模型约束:** 仅接入 Ollama 本地模型,默认 gemma4:e4b +5. 
**Skill 目录:** 读取 `~/.agents/skills/` 下的 Skill 定义 +6. **部署约束:** 单二进制文件,零配置启动 + +## Approach + +### 架构风格:微内核 + Actor 模型 + +采用**微内核架构**作为基础,所有功能(Skill、Tool、LLM 驱动)都以**插件**形式注册到核心。 + +每个 **Agent 实例是一个独立的 Actor**,通过 **消息总线(Message Bus)** 进行通信。这完美契合 Go 的 goroutine + channel 并发模型。 + +**为什么选择这个组合?** +- 微内核保证核心最小化,Skill 和 Tool 热插拔 +- Actor 模型天然支持高并发,避免共享状态 +- 两者结合 = 轻量级、高扩展、Go 原生友好 + +**放弃的其他方案:** +- Docker 沙箱:太重,违背最小依赖原则 +- SQLite 存储:增加依赖,JSONL 已足够 +- 中央协调器:单点瓶颈,不如 Actor 模型灵活 + +## Architecture + +### 整体架构图 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ CLI / API Layer │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Core Kernel (微内核) │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────────┐ │ +│ │ Message Bus │ │ Plugin Reg │ │ Session Manager │ │ +│ │ (channel) │ │ (registry) │ │ (JSONL-based) │ │ +│ └──────────────┘ └──────────────┘ └──────────────────┘ │ +└─────────────────────────────────────────────────────────────┘ + │ + ┌───────────────┼───────────────┐ + ▼ ▼ ▼ +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ Agent Actor │ │ Agent Actor │ │ Agent Actor │ +│ (Specialist 1) │ │ (Specialist 2) │ │ (Orchestrator) │ +└────────┬────────┘ └────────┬────────┘ └────────┬────────┘ + │ │ │ + └───────────────────┼───────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Plugin Layer │ +│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────────┐│ +│ │ Skills │ │ Tools │ │ Ollama │ │ Custom Tools ││ +│ │(Skill Mgr)│ │(Tool Mgr)│ │ (Driver) │ │ (Registry) ││ +│ └──────────┘ └──────────┘ └──────────┘ └──────────────┘│ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Sandbox Layer │ +│ (Process-level isolation + Resource limits) │ +└─────────────────────────────────────────────────────────────┘ +``` + +### 模块职责 + +| 模块 | 
职责 | +|------|------| +| **Core Kernel** | 消息路由、插件生命周期管理、会话协调 | +| **Message Bus** | 基于 Go channel 的异步消息传递系统 | +| **Plugin Registry** | 统一的 Skill/Tool/LLM 驱动注册中心 | +| **Session Manager** | 基于 JSONL 的会话历史读写和上下文窗口管理 | +| **Agent Actor** | 独立 goroutine,持有状态,接收/发送消息 | +| **Skill Manager** | 扫描 `~/.agents/skills/`,解析 SKILL.md,加载技能 | +| **Tool Manager** | 管理内置工具和自定义工具的注册/调用 | +| **Ollama Driver** | 封装 Ollama HTTP API,支持流式响应 | +| **Sandbox** | 安全执行 shell 命令和脚本,限制资源和时间 | + +## Components + +### 1. Core Kernel (微内核) + +**职责:** 框架的最小化核心,只负责消息路由和插件生命周期。 + +**设计要点:** +- 使用 Go 的 `interface{}` 或泛型定义插件契约 +- 启动时加载所有已注册的插件 +- 提供事件总线供插件间通信 +- **不**包含任何业务逻辑(如 LLM 调用、命令执行) + +**核心接口:** +``` +// 所有插件必须实现 +Plugin interface { + Name() string + Init(kernel *Kernel) error + Shutdown() error +} + +// 消息总线 +MessageBus interface { + Publish(topic string, msg Message) error + Subscribe(topic string, handler Handler) (Subscription, error) +} +``` + +### 2. Actor System (多 Agent 引擎) + +**职责:** 管理 Agent 生命周期和消息通信。 + +**设计要点:** +- 每个 Agent 是一个独立的 goroutine,通过 channel 接收消息 +- Agent 持有自己的状态(角色、上下文、工具列表) +- 支持三种 Agent 类型:Orchestrator(协调者)、Worker(执行者)、Specialist(专家) +- 消息类型:`TaskRequest`、`TaskResponse`、`ToolCall`、`Observation` + +**Agent 状态机:** +``` +Idle → Processing → [ToolCall] → WaitingForTool → Processing → Completed + ↓ + [Error] → Failed +``` + +### 3. Session Manager (会话记忆) + +**职责:** 持久化会话历史,支持上下文窗口管理。 + +**设计要点:** +- 每个会话一个 JSONL 文件:`~/.orca/sessions/{session_id}.jsonl` +- 每行一个 JSON 对象:`{role, content, timestamp, metadata}` +- 提供 `GetContext(windowSize)` 方法,返回最近的 N 条消息 +- 支持会话列表、搜索、归档 + +**为什么 JSON Lines?** +- 追加写入 O(1),无需加载整个文件 +- 人类可读,便于调试 +- 零依赖,无需数据库驱动 +- 通过简单文件锁保证并发安全 + +### 4. 
Skill Manager (技能系统) + +**职责:** 自动发现和加载 Skill。 + +**设计要点:** +- 启动时扫描 `~/.agents/skills/` 下的所有子目录 +- 解析每个 Skill 目录下的 `SKILL.md` +- 提取元数据:`name`、`description`、`triggers`(触发词) +- Skill 可以包含脚本文件(`scripts/` 目录) +- 提供 `FindSkill(query string)` 方法,基于触发词匹配 + +**Skill 结构:** +```yaml +name: "md2pdf" +description: "Convert Markdown to PDF..." +triggers: ["pdf", "markdown", "export"] +scripts: + - "scripts/convert.py" + - "scripts/setup.sh" +``` + +### 5. Tool Manager (工具系统) + +**职责:** 管理可执行工具的注册和调用。 + +**设计要点:** +- **内置工具:** `exec`(执行命令)、`read_file`、`write_file`、`list_dir` +- **Skill 工具:** 从 Skill 的 `scripts/` 目录自动注册 +- **自定义工具:** 通过代码注册,实现 `Tool` 接口 +- 每个工具定义:名称、描述、参数 schema、执行函数 +- LLM 通过 Function Calling 调用工具 + +**Tool 接口:** +``` +Tool interface { + Name() string + Description() string + Parameters() JSONSchema + Execute(ctx Context, args map[string]any) (Result, error) +} +``` + +### 6. Ollama Driver (LLM 驱动) + +**职责:** 封装 Ollama API,提供统一的 LLM 调用接口。 + +**设计要点:** +- 默认模型:`gemma4:e4b` +- 支持流式响应(SSE) +- 支持 Function Calling(通过 tools 参数) +- 自动处理上下文窗口截断 +- 可配置参数:temperature、top_p、max_tokens + +**API 封装:** +``` +LLMClient interface { + Chat(messages []Message, tools []Tool) (Response, error) + ChatStream(messages []Message, tools []Tool) (Stream, error) +} +``` + +### 7. 
Sandbox (沙箱执行) + +**职责:** 安全地执行终端命令和脚本。 + +**设计要点:** +- 使用 `os/exec` 创建子进程 +- 资源限制:CPU 时间、内存、输出大小 +- 超时控制:默认 30 秒,可配置 +- 工作目录限制:可选 chroot 或指定工作目录 +- 环境变量隔离:只允许白名单环境变量 +- **不**使用 Docker,保持轻量 + +**安全策略:** +```yaml +sandbox: + timeout: 30s + max_memory: 512MB + max_output: 64KB + allowed_env: [PATH, HOME, USER] + working_dir: /tmp/orca-sandbox + read_only_dirs: [] + blocked_commands: [rm -rf /, mkfs, dd] +``` + +## Data Flow + +### 典型交互流程 + +``` +用户输入 + │ + ▼ +┌─────────────┐ +│ CLI/API │ +└──────┬──────┘ + │ + ▼ +┌─────────────┐ ┌─────────────┐ +│ Session Mgr │────▶│ 加载历史上下文 │ +└──────┬──────┘ └─────────────┘ + │ + ▼ +┌─────────────┐ +│ Orchestrator │ (Agent Actor) +│ Agent │ +└──────┬──────┘ + │ + ▼ +┌─────────────┐ ┌─────────────┐ +│ Skill Mgr │────▶│ 匹配相关 Skill │ +└──────┬──────┘ └─────────────┘ + │ + ▼ +┌─────────────┐ ┌─────────────┐ +│ Ollama Driver│────▶│ 发送 prompt │ +└──────┬──────┘ └─────────────┘ + │ + ▼ +┌─────────────┐ +│ LLM Response │ +│ (Function │ +│ Calling) │ +└──────┬──────┘ + │ + ▼ +┌─────────────┐ ┌─────────────┐ +│ Tool Call │────▶│ 执行 Tool/ │ +│ │ │ 沙箱命令 │ +└──────┬──────┘ └─────────────┘ + │ + ▼ +┌─────────────┐ +│ Observation │ (工具执行结果) +└──────┬──────┘ + │ + ▼ +┌─────────────┐ ┌─────────────┐ +│ Orchestrator │────▶│ 决策:继续/完成 │ +└──────┬──────┘ └─────────────┘ + │ + ▼ +┌─────────────┐ +│ 保存会话 │ +│ 返回结果 │ +└─────────────┘ +``` + +### 消息类型定义 + +```go +type Message struct { + ID string + Type MessageType // TaskRequest, TaskResponse, ToolCall, Observation, Error + From string // Agent ID + To string // Agent ID or "broadcast" + Content interface{} // 根据 Type 不同而变化 + Timestamp time.Time +} + +type TaskRequest struct { + Query string + SessionID string + Context []ChatMessage +} + +type ToolCall struct { + ToolName string + Arguments map[string]interface{} +} + +type Observation struct { + ToolCallID string + Output string + Error string +} +``` + +## Error Handling + +### 策略 + +1. 
**分层错误处理:** + - **Kernel 层:** 插件加载失败 → 记录日志,跳过该插件,继续启动 + - **Agent 层:** 任务执行失败 → 返回错误消息,让 Orchestrator 决策重试或终止 + - **Tool 层:** 工具执行失败 → 返回结构化错误,LLM 可据此调整策略 + - **Sandbox 层:** 命令超时/内存超限 → 强制终止进程,返回错误 + +2. **重试机制:** + - LLM API 调用:指数退避重试 3 次 + - 工具执行:不重试(避免循环),由 LLM 决策 + +3. **优雅降级:** + - Ollama 不可用 → 提示用户检查服务 + - Skill 解析失败 → 跳过该 Skill,不影响其他 + - 沙箱执行失败 → 返回错误信息,LLM 可尝试其他工具 + +### 错误类型 + +```go +type ErrorCategory int + +const ( + ErrCategoryKernel ErrorCategory = iota // 内核错误 + ErrCategoryAgent // Agent 错误 + ErrCategoryTool // 工具错误 + ErrCategorySandbox // 沙箱错误 + ErrCategoryLLM // LLM 错误 + ErrCategoryNetwork // 网络错误 +) +``` + +## Testing Strategy + +### 测试金字塔 + +1. **单元测试(60%):** + - `Kernel`:插件注册/卸载、消息路由 + - `SessionManager`:JSONL 读写、上下文窗口截断 + - `SkillManager`:Skill 解析、触发词匹配 + - `Sandbox`:资源限制、超时控制 + - `OllamaDriver`:HTTP 请求封装(使用 mock server) + +2. **集成测试(30%):** + - Agent + Tool:端到端任务执行 + - Agent + LLM:使用 mock LLM 测试 Function Calling 流程 + - Skill + Sandbox:加载 Skill 并执行其脚本 + +3. **E2E 测试(10%):** + - 完整 CLI 工作流 + - 多 Agent 协作场景 + +### Mock 策略 + +- `LLMClient`:使用接口,测试时注入 mock +- `Sandbox`:提供 `DryRun` 模式,记录命令但不执行 +- `MessageBus`:内存实现,用于测试 + +## Open Questions + +1. **Skill 执行方式:** Skill 脚本是用 Shell 调用还是直接在 Go 中执行?当前设计倾向 Shell 调用(通过 Sandbox),但 Python/Node 脚本需要对应运行时。 + - **假设:** 用户环境已安装所需运行时(Python、Node 等),Sandbox 只负责安全执行。 + +2. **Function Calling 格式:** gemma4:e4b 对 Function Calling 的支持程度? + - **假设:** 使用 Ollama 的 `tools` 参数格式,如果不支持则 fallback 到 prompt-based tool calling。 + +3. **多 Agent 协作粒度:** Agent 之间是平等协作还是有层级? + - **假设:** 支持两种模式:层级(Orchestrator + Workers)和平等(对等协作),由用户配置。 + +4. **会话共享:** 多个 Agent 是否可以共享同一个会话上下文? + - **假设:** 是,Session Manager 通过文件锁支持并发读取,但同一时间只有一个 Agent 写入。 + +5. **Tool 参数 Schema:** 使用 JSON Schema 还是简化格式? 
+ - **假设:** 使用简化版 JSON Schema(支持 string/number/boolean/array/object + description)。 diff --git a/thoughts/shared/plans/2026-05-07-orca-agent-framework.md b/thoughts/shared/plans/2026-05-07-orca-agent-framework.md new file mode 100644 index 0000000..b42dc0d --- /dev/null +++ b/thoughts/shared/plans/2026-05-07-orca-agent-framework.md @@ -0,0 +1,373 @@ +--- +date: 2026-05-07 +topic: "Go Agent Framework - Orca" +status: draft +--- + +# Orca Agent Framework - 实现计划 + +## 项目概览 + +- **项目名称:** orca +- **语言:** Go 1.22+ +- **路径:** /Users/wang/agent_dev/orca.ai/ +- **架构:** 微内核 + Actor 模型 + +## 实现阶段 + +### Phase 1: 项目骨架与核心基础设施(Day 1) + +**目标:** 建立项目结构,实现消息总线和插件注册机制。 + +**任务清单:** + +1. **初始化 Go 模块** + - `go mod init github.com/orca/orca` + - 创建基础目录结构 + +2. **目录结构** + ``` + orca/ + ├── cmd/orca/ # CLI 入口 + ├── pkg/ + │ ├── kernel/ # 微内核核心 + │ ├── actor/ # Actor 系统 + │ ├── bus/ # 消息总线 + │ ├── plugin/ # 插件接口和注册 + │ ├── session/ # 会话管理 (JSONL) + │ ├── skill/ # Skill 管理 + │ ├── tool/ # Tool 系统 + │ ├── llm/ # LLM 接口 + │ ├── ollama/ # Ollama 驱动 + │ └── sandbox/ # 沙箱执行 + ├── internal/ + │ ├── config/ # 配置管理 + │ └── util/ # 工具函数 + ├── plugins/ + │ ├── builtin/ # 内置插件 + │ └── tools/ # 内置工具 + ├── test/ + │ └── fixtures/ # 测试固件 + └── go.mod + ``` + +3. **核心接口定义** + - `pkg/plugin/plugin.go`: Plugin 接口 + - `pkg/bus/bus.go`: MessageBus 接口和实现 + - `pkg/kernel/kernel.go`: Kernel 结构体,插件生命周期管理 + +4. **消息总线实现** + - 基于 Go channel 的发布/订阅 + - 支持同步和异步消息 + - 消息类型枚举定义 + +**验收标准:** +- `go build ./...` 成功 +- 消息总线单元测试通过(Publish/Subscribe) +- 插件注册/卸载测试通过 + +--- + +### Phase 2: Actor 系统与会话管理(Day 2) + +**目标:** 实现多 Agent Actor 和 JSONL 会话存储。 + +**任务清单:** + +1. **Actor 系统** + - `pkg/actor/actor.go`: Agent Actor 接口 + - `pkg/actor/orchestrator.go`: 协调者 Agent + - `pkg/actor/worker.go`: 工作者 Agent + - `pkg/actor/system.go`: Actor 生命周期管理(创建、停止、监控) + - 状态机实现:Idle → Processing → WaitingForTool → Completed/Failed + +2. 
**会话管理(JSONL)** + - `pkg/session/store.go`: 存储接口 + - `pkg/session/jsonl.go`: JSONL 实现 + - `pkg/session/manager.go`: 会话管理器(创建、加载、归档) + - 上下文窗口截断逻辑 + - 文件锁保证并发安全(`flock` 或简单文件锁) + +3. **配置系统** + - `internal/config/config.go`: 配置结构体 + - 支持 YAML 配置文件(`~/.orca/config.yaml`) + - 环境变量覆盖 + - 默认值设置 + +**验收标准:** +- 创建 10 个 Agent Actor 并发运行测试通过 +- 会话 CRUD 测试通过 +- 上下文窗口截断测试通过 + +--- + +### Phase 3: Skill 与 Tool 系统(Day 3) + +**目标:** 实现 Skill 自动发现和 Tool 注册执行。 + +**任务清单:** + +1. **Skill 管理器** + - `pkg/skill/manager.go`: Skill 扫描和加载 + - `pkg/skill/parser.go`: SKILL.md 解析器(提取 name, description, triggers) + - `pkg/skill/skill.go`: Skill 结构体定义 + - 扫描目录:`~/.agents/skills/` 和 `~/.config/opencode/skills/` + - 触发词匹配算法(简单关键词匹配或 TF-IDF) + +2. **Tool 系统** + - `pkg/tool/tool.go`: Tool 接口定义 + - `pkg/tool/manager.go`: Tool 注册中心 + - `pkg/tool/registry.go`: 内置工具注册 + - **内置工具实现:** + - `exec`: 执行 shell 命令(通过 sandbox) + - `read_file`: 读取文件内容 + - `write_file`: 写入文件 + - `list_dir`: 列出目录 + - `search_files`: 文件内容搜索 + +3. **自定义 Tool 注册** + - 支持通过代码注册 Tool + - Tool 参数 Schema 定义(简化版 JSON Schema) + - Tool 执行上下文传递 + +**验收标准:** +- 扫描现有 `~/.agents/skills/` 目录,正确解析所有 Skill +- 触发词匹配测试通过 +- 所有内置工具单元测试通过 + +--- + +### Phase 4: 沙箱与 Ollama 集成(Day 4) + +**目标:** 实现安全执行环境和 LLM 驱动。 + +**任务清单:** + +1. **沙箱执行器** + - `pkg/sandbox/sandbox.go`: Sandbox 接口 + - `pkg/sandbox/process.go`: 进程级实现 + - 资源限制: + - 超时控制(context.WithTimeout) + - 内存限制(通过 cgroup 或 ulimit,若不可用则软限制) + - 输出大小限制 + - 工作目录隔离 + - 环境变量白名单 + - 危险命令黑名单 + +2. **Ollama 驱动** + - `pkg/ollama/client.go`: HTTP 客户端 + - `pkg/ollama/chat.go`: Chat API 封装 + - `pkg/ollama/stream.go`: 流式响应处理(NDJSON 逐行解析;Ollama 原生 API 不使用 SSE) + - `pkg/ollama/tools.go`: Function Calling 支持 + - 模型配置:temperature、top_p、max_tokens + - 自动重试机制(指数退避,3 次) + +3. 
**LLM 抽象层** + - `pkg/llm/client.go`: LLMClient 接口 + - `pkg/llm/message.go`: Message 结构体定义 + - `pkg/llm/options.go`: 调用选项 + +**验收标准:** +- 沙箱执行命令并正确限制资源测试通过 +- 超时和内存限制测试通过 +- Ollama API 调用测试通过(需要本地 Ollama 服务) +- Function Calling 格式正确 + +--- + +### Phase 5: CLI 与集成(Day 5) + +**目标:** 实现命令行界面和端到端集成。 + +**任务清单:** + +1. **CLI 实现** + - `cmd/orca/main.go`: 入口点 + - `cmd/orca/commands.go`: 子命令定义 + - 支持命令: + - `orca chat`: 交互式对话 + - `orca run "query"`: 单次执行 + - `orca sessions`: 会话列表 + - `orca skills`: 已加载 Skill 列表 + - `orca tools`: 已注册 Tool 列表 + - `orca config`: 配置查看/设置 + +2. **交互式对话** + - 读取用户输入 + - 创建/恢复会话 + - 调用 Orchestrator Agent + - 显示 Agent 思考过程和结果 + - 支持多轮对话 + +3. **端到端集成测试** + - 完整对话流程测试 + - Skill 触发和调用测试 + - Tool 调用链测试 + - 错误恢复测试 + +**验收标准:** +- `orca --help` 显示正确 +- `orca chat` 可以开始对话 +- 完整对话流程测试通过 + +--- + +### Phase 6: 多 Agent 协作与优化(Day 6-7) + +**目标:** 实现多 Agent 协作和性能优化。 + +**任务清单:** + +1. **多 Agent 协作** + - Orchestrator 任务分解逻辑 + - Worker Agent 分配策略 + - Agent 间消息传递 + - 结果汇总和冲突解决 + +2. **性能优化** + - 会话缓存(最近 N 个会话驻留内存) + - Skill 索引(倒排索引加速匹配) + - 连接池(Ollama HTTP 连接复用) + +3. 
**可观测性** + - 结构化日志(slog) + - Agent 执行追踪 + - 性能指标收集 + +**验收标准:** +- 多 Agent 协作测试通过 +- 性能基准测试通过(单次对话 < 5s) + +--- + +## 依赖管理 + +### 外部依赖(最小化) + +| 依赖 | 用途 | 版本 | +|------|------|------| +| `github.com/spf13/cobra` | CLI 框架 | latest | +| `github.com/spf13/viper` | 配置管理 | latest | +| `github.com/stretchr/testify` | 测试 | latest | + +**原则:** 优先使用标准库,必要时才引入外部依赖。 + +### 内部依赖图 + +``` +kernel + ├── bus + ├── plugin + ├── actor + │ ├── bus + │ ├── tool + │ └── llm + ├── session + ├── skill + ├── tool + │ └── sandbox + ├── llm + │ └── ollama + └── sandbox +``` + +## 关键接口定义 + +### Plugin 接口 + +```go +type Plugin interface { + Name() string + Version() string + Init(kernel *Kernel) error + Shutdown() error +} +``` + +### Agent 接口 + +```go +type Agent interface { + ID() string + Role() string + Process(ctx context.Context, msg Message) (Message, error) + Stop() error +} +``` + +### Tool 接口 + +```go +type Tool interface { + Name() string + Description() string + Parameters() ParameterSchema + Execute(ctx context.Context, args map[string]interface{}) (ToolResult, error) +} +``` + +### LLMClient 接口 + +```go +type LLMClient interface { + Chat(ctx context.Context, messages []Message, tools []Tool) (*ChatResponse, error) + ChatStream(ctx context.Context, messages []Message, tools []Tool) (StreamReader, error) +} +``` + +## 测试策略 + +### 单元测试覆盖目标 + +| 模块 | 覆盖率目标 | +|------|-----------| +| bus | 90% | +| kernel | 85% | +| session | 90% | +| skill | 85% | +| tool | 90% | +| sandbox | 80% | +| ollama | 75% (mock) | + +### Mock 实现 + +```go +// MockLLMClient 用于测试 +type MockLLMClient struct { + Responses []ChatResponse + Index int +} + +func (m *MockLLMClient) Chat(ctx context.Context, messages []Message, tools []Tool) (*ChatResponse, error) { + if m.Index >= len(m.Responses) { + return nil, errors.New("no more mock responses") + } + resp := m.Responses[m.Index] + m.Index++ + return &resp, nil +} +``` + +## 风险与回退方案 + +| 风险 | 影响 | 概率 | 回退方案 | +|------|------|------|----------| +| gemma4:e4b 不支持 
Function Calling | 高 | 中 | 使用 prompt-based tool calling | +| 进程级沙箱限制不足 | 中 | 低 | 添加 Docker 支持作为可选项 | +| JSONL 性能瓶颈 | 中 | 低 | 迁移到 SQLite(保留接口) | +| Actor 模型复杂度 | 中 | 中 | 简化为中央协调器模式 | + +## 里程碑 + +| 里程碑 | 时间 | 交付物 | +|--------|------|--------| +| M1 | Day 1 | 项目骨架 + 消息总线 + 插件系统 | +| M2 | Day 2 | Actor 系统 + 会话管理 | +| M3 | Day 3 | Skill + Tool 系统 | +| M4 | Day 4 | 沙箱 + Ollama 集成 | +| M5 | Day 5 | CLI + 端到端集成 | +| M6 | Day 6-7 | 多 Agent + 优化 | + +## 下一步 + +执行 Phase 1,建立项目骨架。