Initial commit: Orca Agent Framework
Core features: - Microkernel architecture with Actor model - Session management with JSONL persistence - Tool system (5 built-in tools) - Skill system with SKILL.md parsing - Sandbox security execution - Ollama integration with gemma4:e4b - Prompt-based tool calling (compatible with native function calling) - REPL interface 11 packages, all tests passing
This commit is contained in:
commit
6b94476347
205
cmd/orca/main.go
Normal file
205
cmd/orca/main.go
Normal file
@ -0,0 +1,205 @@
|
||||
// Orca is a Go-based Agent framework with a microkernel architecture.
|
||||
//
|
||||
// It supports multi-agent collaboration, persistent session memory,
|
||||
// skill-based automation, sandboxed execution, custom tool registration,
|
||||
// and local LLM integration via Ollama.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/orca/orca/internal/config"
|
||||
"github.com/orca/orca/pkg/kernel"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load configuration from environment variables
|
||||
cfg := config.LoadConfigFromEnv()
|
||||
|
||||
// Support shorter env var names for Ollama (without ORCA_ prefix)
|
||||
if v := os.Getenv("OLLAMA_BASE_URL"); v != "" {
|
||||
cfg.Ollama.BaseURL = v
|
||||
}
|
||||
if v := os.Getenv("OLLAMA_MODEL"); v != "" {
|
||||
cfg.Ollama.Model = v
|
||||
}
|
||||
if v := os.Getenv("OLLAMA_TIMEOUT"); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil {
|
||||
cfg.Ollama.Timeout = d
|
||||
}
|
||||
}
|
||||
|
||||
// Create and start kernel
|
||||
k := kernel.NewWithConfig(cfg)
|
||||
|
||||
if err := k.Start(); err != nil {
|
||||
log.Fatalf("Failed to start kernel: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println("Orca Agent Framework")
|
||||
fmt.Println("Kernel started successfully")
|
||||
fmt.Printf(" LLM Model: %s\n", cfg.Ollama.Model)
|
||||
fmt.Printf(" Ollama URL: %s\n", cfg.Ollama.BaseURL)
|
||||
fmt.Println("Type your message or /help for commands.")
|
||||
fmt.Println()
|
||||
|
||||
// Handle graceful shutdown
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// REPL loop in a goroutine so we can catch signals
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
for {
|
||||
fmt.Print("> ")
|
||||
if !scanner.Scan() {
|
||||
break
|
||||
}
|
||||
|
||||
input := strings.TrimSpace(scanner.Text())
|
||||
if input == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle commands
|
||||
if strings.HasPrefix(input, "/") {
|
||||
handleCommand(input, k)
|
||||
continue
|
||||
}
|
||||
|
||||
// Send message to LLM agent via kernel
|
||||
response, err := k.SendMessage("user", "llm", input)
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Println(response)
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading input: %v\n", err)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Wait for either SIGINT or REPL exit
|
||||
select {
|
||||
case <-sig:
|
||||
fmt.Println("\nShutting down Orca kernel...")
|
||||
case <-done:
|
||||
fmt.Println("\nInput closed. Shutting down Orca kernel...")
|
||||
}
|
||||
|
||||
if err := k.Stop(); err != nil {
|
||||
log.Fatalf("Failed to stop kernel: %v", err)
|
||||
}
|
||||
fmt.Println("Orca kernel shut down gracefully.")
|
||||
}
|
||||
|
||||
// handleCommand processes REPL commands.
|
||||
func handleCommand(cmd string, k *kernel.Kernel) {
|
||||
switch cmd {
|
||||
case "/help":
|
||||
fmt.Println("Available commands:")
|
||||
fmt.Println(" /help - Show this help message")
|
||||
fmt.Println(" /exit - Exit the program")
|
||||
fmt.Println(" /quit - Exit the program")
|
||||
fmt.Println(" /plugins - List registered plugins")
|
||||
fmt.Println(" /agents - List active agents")
|
||||
fmt.Println(" /tools - List registered tools")
|
||||
fmt.Println(" /skills - List loaded skills")
|
||||
fmt.Println(" /status - Show kernel status")
|
||||
fmt.Println()
|
||||
fmt.Println("Any other input is sent to the LLM agent for processing.")
|
||||
|
||||
case "/exit", "/quit":
|
||||
fmt.Println("Goodbye!")
|
||||
os.Exit(0)
|
||||
|
||||
case "/plugins":
|
||||
plugins := k.ListPlugins()
|
||||
if len(plugins) == 0 {
|
||||
fmt.Println("No plugins registered.")
|
||||
} else {
|
||||
fmt.Println("Registered plugins:")
|
||||
for _, p := range plugins {
|
||||
fmt.Printf(" - %s (%s)\n", p.Name(), p.Version())
|
||||
}
|
||||
}
|
||||
|
||||
case "/agents":
|
||||
as := k.ActorSystem()
|
||||
if as == nil {
|
||||
fmt.Println("Actor system not initialized.")
|
||||
return
|
||||
}
|
||||
infos := as.AgentInfos()
|
||||
if len(infos) == 0 {
|
||||
fmt.Println("No agents running.")
|
||||
} else {
|
||||
fmt.Println("Active agents:")
|
||||
for _, info := range infos {
|
||||
fmt.Printf(" - %s [%s] (status: %s)\n", info.ID, info.Role, info.Status)
|
||||
}
|
||||
}
|
||||
|
||||
case "/tools":
|
||||
tm := k.ToolManager()
|
||||
if tm == nil {
|
||||
fmt.Println("Tool manager not initialized.")
|
||||
return
|
||||
}
|
||||
tools := tm.List()
|
||||
if len(tools) == 0 {
|
||||
fmt.Println("No tools registered.")
|
||||
} else {
|
||||
fmt.Println("Registered tools:")
|
||||
for _, t := range tools {
|
||||
fmt.Printf(" - %s: %s\n", t.Name(), t.Description())
|
||||
}
|
||||
}
|
||||
|
||||
case "/skills":
|
||||
sm := k.SkillManager()
|
||||
if sm == nil {
|
||||
fmt.Println("Skill manager not initialized.")
|
||||
return
|
||||
}
|
||||
skills := sm.ListSkills()
|
||||
if len(skills) == 0 {
|
||||
fmt.Println("No skills loaded.")
|
||||
} else {
|
||||
fmt.Println("Loaded skills:")
|
||||
for _, s := range skills {
|
||||
fmt.Printf(" - %s: %s\n", s.Name, s.Description)
|
||||
}
|
||||
}
|
||||
|
||||
case "/status":
|
||||
fmt.Printf("Kernel running: %v\n", k.IsRunning())
|
||||
if tm := k.ToolManager(); tm != nil {
|
||||
fmt.Printf("Tools registered: %d\n", tm.Count())
|
||||
}
|
||||
if as := k.ActorSystem(); as != nil {
|
||||
fmt.Printf("Agents active: %d\n", as.AgentCount())
|
||||
}
|
||||
if sm := k.SkillManager(); sm != nil {
|
||||
fmt.Printf("Skills loaded: %d\n", len(sm.ListSkills()))
|
||||
}
|
||||
|
||||
default:
|
||||
fmt.Printf("Unknown command: %s\n", cmd)
|
||||
fmt.Println("Type /help for available commands.")
|
||||
}
|
||||
}
|
||||
147
internal/config/config.go
Normal file
147
internal/config/config.go
Normal file
@ -0,0 +1,147 @@
|
||||
// Package config provides the configuration types for the Orca framework.
|
||||
//
|
||||
// Configuration is organized into logical groups: LLM (Ollama), sandbox,
|
||||
// and session management. Default values are provided for all settings.
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config is the top-level configuration for the Orca framework.
|
||||
type Config struct {
|
||||
Ollama OllamaConfig `json:"ollama"`
|
||||
Sandbox SandboxConfig `json:"sandbox"`
|
||||
Session SessionConfig `json:"session"`
|
||||
}
|
||||
|
||||
// OllamaConfig describes how to reach and use the Ollama LLM backend.
type OllamaConfig struct {
	// BaseURL is the Ollama API endpoint, for example
	// "http://localhost:11434".
	BaseURL string `json:"base_url"`

	// Model names the Ollama model to use, for example "gemma4:e4b" or
	// "codellama".
	Model string `json:"model"`

	// Timeout bounds how long to wait for an Ollama response.
	Timeout time.Duration `json:"timeout"`
}
|
||||
|
||||
// SandboxConfig describes the limits and location of the command execution
// sandbox.
type SandboxConfig struct {
	// Timeout bounds how long a sandboxed command may run.
	Timeout time.Duration `json:"timeout"`

	// MaxMemory is the sandbox memory ceiling, in bytes.
	MaxMemory int64 `json:"max_memory"`

	// WorkingDir is the directory sandboxed commands run in by default.
	WorkingDir string `json:"working_dir"`
}
|
||||
|
||||
// SessionConfig describes where sessions are stored and how much history
// each one keeps.
type SessionConfig struct {
	// StorageDir is the directory that holds the session JSONL files.
	StorageDir string `json:"storage_dir"`

	// MaxHistory caps the number of messages retained per session.
	MaxHistory int `json:"max_history"`
}
|
||||
|
||||
// DefaultConfig returns a Config with sensible defaults.
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Ollama: OllamaConfig{
|
||||
BaseURL: "http://localhost:11434",
|
||||
Model: "gemma4:e4b",
|
||||
Timeout: 120 * time.Second,
|
||||
},
|
||||
Sandbox: SandboxConfig{
|
||||
Timeout: 30 * time.Second,
|
||||
MaxMemory: 512 * 1024 * 1024, // 512 MB
|
||||
WorkingDir: "/tmp/orca/sandbox",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
StorageDir: func() string {
|
||||
home, _ := os.UserHomeDir()
|
||||
return home + "/.orca/sessions"
|
||||
}(),
|
||||
MaxHistory: 100,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// LoadConfigFromEnv reads configuration from environment variables,
|
||||
// overriding defaults where environment variables are set.
|
||||
func LoadConfigFromEnv() *Config {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if v := os.Getenv("ORCA_OLLAMA_BASE_URL"); v != "" {
|
||||
cfg.Ollama.BaseURL = v
|
||||
}
|
||||
if v := os.Getenv("ORCA_OLLAMA_MODEL"); v != "" {
|
||||
cfg.Ollama.Model = v
|
||||
}
|
||||
if v := os.Getenv("ORCA_OLLAMA_TIMEOUT"); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil {
|
||||
cfg.Ollama.Timeout = d
|
||||
}
|
||||
}
|
||||
if v := os.Getenv("ORCA_SANDBOX_TIMEOUT"); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil {
|
||||
cfg.Sandbox.Timeout = d
|
||||
}
|
||||
}
|
||||
if v := os.Getenv("ORCA_SANDBOX_MAX_MEMORY"); v != "" {
|
||||
if n, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
cfg.Sandbox.MaxMemory = n
|
||||
}
|
||||
}
|
||||
if v := os.Getenv("ORCA_SANDBOX_WORKING_DIR"); v != "" {
|
||||
cfg.Sandbox.WorkingDir = v
|
||||
}
|
||||
if v := os.Getenv("ORCA_SESSION_STORAGE_DIR"); v != "" {
|
||||
cfg.Session.StorageDir = v
|
||||
}
|
||||
if v := os.Getenv("ORCA_SESSION_MAX_HISTORY"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil {
|
||||
cfg.Session.MaxHistory = n
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// IsValid checks whether the configuration has valid values.
|
||||
func (c *Config) IsValid() error {
|
||||
if c.Ollama.BaseURL == "" {
|
||||
return errConfig("ollama.base_url must not be empty")
|
||||
}
|
||||
if c.Ollama.Model == "" {
|
||||
return errConfig("ollama.model must not be empty")
|
||||
}
|
||||
if c.Ollama.Timeout <= 0 {
|
||||
return errConfig("ollama.timeout must be positive")
|
||||
}
|
||||
if c.Sandbox.Timeout <= 0 {
|
||||
return errConfig("sandbox.timeout must be positive")
|
||||
}
|
||||
if c.Sandbox.MaxMemory <= 0 {
|
||||
return errConfig("sandbox.max_memory must be positive")
|
||||
}
|
||||
if c.Session.MaxHistory <= 0 {
|
||||
return errConfig("session.max_history must be positive")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// errConfig creates a configuration error.
|
||||
func errConfig(msg string) error {
|
||||
return &ConfigError{Message: msg}
|
||||
}
|
||||
|
||||
// ConfigError is the error type reported for an invalid configuration.
type ConfigError struct {
	// Message describes the problem, without the "config: " prefix.
	Message string
}

// Error implements the error interface, returning Message prefixed with
// "config: ".
func (e *ConfigError) Error() string {
	const prefix = "config: "
	return prefix + e.Message
}
|
||||
221
internal/config/config_test.go
Normal file
221
internal/config/config_test.go
Normal file
@ -0,0 +1,221 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDefaultConfig(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if cfg == nil {
|
||||
t.Fatal("DefaultConfig() returned nil")
|
||||
}
|
||||
|
||||
// Check Ollama defaults
|
||||
if cfg.Ollama.BaseURL != "http://localhost:11434" {
|
||||
t.Errorf("expected default Ollama BaseURL 'http://localhost:11434', got %q", cfg.Ollama.BaseURL)
|
||||
}
|
||||
if cfg.Ollama.Model != "gemma4:e4b" {
|
||||
t.Errorf("expected default Ollama Model 'gemma4:e4b', got %q", cfg.Ollama.Model)
|
||||
}
|
||||
if cfg.Ollama.Timeout != 120*time.Second {
|
||||
t.Errorf("expected default Ollama Timeout 120s, got %v", cfg.Ollama.Timeout)
|
||||
}
|
||||
|
||||
// Check Sandbox defaults
|
||||
if cfg.Sandbox.Timeout != 30*time.Second {
|
||||
t.Errorf("expected default Sandbox Timeout 30s, got %v", cfg.Sandbox.Timeout)
|
||||
}
|
||||
if cfg.Sandbox.MaxMemory != 512*1024*1024 {
|
||||
t.Errorf("expected default Sandbox MaxMemory 512MB, got %d", cfg.Sandbox.MaxMemory)
|
||||
}
|
||||
if cfg.Sandbox.WorkingDir != "/tmp/orca/sandbox" {
|
||||
t.Errorf("expected default Sandbox WorkingDir '/tmp/orca/sandbox', got %q", cfg.Sandbox.WorkingDir)
|
||||
}
|
||||
|
||||
// Check Session defaults
|
||||
if cfg.Session.MaxHistory != 100 {
|
||||
t.Errorf("expected default Session MaxHistory 100, got %d", cfg.Session.MaxHistory)
|
||||
}
|
||||
if cfg.Session.StorageDir == "" {
|
||||
t.Error("expected non-empty Session StorageDir")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfigStorageDir(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
home, _ := os.UserHomeDir()
|
||||
expected := home + "/.orca/sessions"
|
||||
if cfg.Session.StorageDir != expected {
|
||||
t.Errorf("expected StorageDir %q, got %q", expected, cfg.Session.StorageDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigFromEnv(t *testing.T) {
|
||||
// Set environment variables
|
||||
os.Setenv("ORCA_OLLAMA_BASE_URL", "http://custom:11434")
|
||||
os.Setenv("ORCA_OLLAMA_MODEL", "codellama")
|
||||
os.Setenv("ORCA_OLLAMA_TIMEOUT", "60s")
|
||||
os.Setenv("ORCA_SANDBOX_TIMEOUT", "120s")
|
||||
os.Setenv("ORCA_SANDBOX_MAX_MEMORY", "1073741824")
|
||||
os.Setenv("ORCA_SANDBOX_WORKING_DIR", "/custom/sandbox")
|
||||
os.Setenv("ORCA_SESSION_STORAGE_DIR", "/custom/sessions")
|
||||
os.Setenv("ORCA_SESSION_MAX_HISTORY", "200")
|
||||
|
||||
defer func() {
|
||||
os.Unsetenv("ORCA_OLLAMA_BASE_URL")
|
||||
os.Unsetenv("ORCA_OLLAMA_MODEL")
|
||||
os.Unsetenv("ORCA_OLLAMA_TIMEOUT")
|
||||
os.Unsetenv("ORCA_SANDBOX_TIMEOUT")
|
||||
os.Unsetenv("ORCA_SANDBOX_MAX_MEMORY")
|
||||
os.Unsetenv("ORCA_SANDBOX_WORKING_DIR")
|
||||
os.Unsetenv("ORCA_SESSION_STORAGE_DIR")
|
||||
os.Unsetenv("ORCA_SESSION_MAX_HISTORY")
|
||||
}()
|
||||
|
||||
cfg := LoadConfigFromEnv()
|
||||
|
||||
if cfg.Ollama.BaseURL != "http://custom:11434" {
|
||||
t.Errorf("expected Ollama BaseURL 'http://custom:11434', got %q", cfg.Ollama.BaseURL)
|
||||
}
|
||||
if cfg.Ollama.Model != "codellama" {
|
||||
t.Errorf("expected Ollama Model 'codellama', got %q", cfg.Ollama.Model)
|
||||
}
|
||||
if cfg.Ollama.Timeout != 60*time.Second {
|
||||
t.Errorf("expected Ollama Timeout 60s, got %v", cfg.Ollama.Timeout)
|
||||
}
|
||||
if cfg.Sandbox.Timeout != 120*time.Second {
|
||||
t.Errorf("expected Sandbox Timeout 120s, got %v", cfg.Sandbox.Timeout)
|
||||
}
|
||||
if cfg.Sandbox.MaxMemory != 1073741824 {
|
||||
t.Errorf("expected Sandbox MaxMemory 1073741824, got %d", cfg.Sandbox.MaxMemory)
|
||||
}
|
||||
if cfg.Sandbox.WorkingDir != "/custom/sandbox" {
|
||||
t.Errorf("expected Sandbox WorkingDir '/custom/sandbox', got %q", cfg.Sandbox.WorkingDir)
|
||||
}
|
||||
if cfg.Session.StorageDir != "/custom/sessions" {
|
||||
t.Errorf("expected Session StorageDir '/custom/sessions', got %q", cfg.Session.StorageDir)
|
||||
}
|
||||
if cfg.Session.MaxHistory != 200 {
|
||||
t.Errorf("expected Session MaxHistory 200, got %d", cfg.Session.MaxHistory)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigFromEnvPartial(t *testing.T) {
|
||||
os.Setenv("ORCA_OLLAMA_MODEL", "mistral")
|
||||
defer os.Unsetenv("ORCA_OLLAMA_MODEL")
|
||||
|
||||
cfg := LoadConfigFromEnv()
|
||||
|
||||
// Should use env override
|
||||
if cfg.Ollama.Model != "mistral" {
|
||||
t.Errorf("expected Model 'mistral', got %q", cfg.Ollama.Model)
|
||||
}
|
||||
// Should keep defaults for unset values
|
||||
if cfg.Ollama.BaseURL != "http://localhost:11434" {
|
||||
t.Errorf("expected default BaseURL, got %q", cfg.Ollama.BaseURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigIsValid(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if err := cfg.IsValid(); err != nil {
|
||||
t.Errorf("default config should be valid: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigInvalidBaseURL(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
cfg.Ollama.BaseURL = ""
|
||||
if err := cfg.IsValid(); err == nil {
|
||||
t.Error("expected error for empty BaseURL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigInvalidModel(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
cfg.Ollama.Model = ""
|
||||
if err := cfg.IsValid(); err == nil {
|
||||
t.Error("expected error for empty Model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigInvalidOllamaTimeout(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
cfg.Ollama.Timeout = 0
|
||||
if err := cfg.IsValid(); err == nil {
|
||||
t.Error("expected error for zero Ollama Timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigInvalidSandboxTimeout(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
cfg.Sandbox.Timeout = -1
|
||||
if err := cfg.IsValid(); err == nil {
|
||||
t.Error("expected error for negative Sandbox Timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigInvalidMaxMemory(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
cfg.Sandbox.MaxMemory = 0
|
||||
if err := cfg.IsValid(); err == nil {
|
||||
t.Error("expected error for zero MaxMemory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigInvalidMaxHistory(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
cfg.Session.MaxHistory = 0
|
||||
if err := cfg.IsValid(); err == nil {
|
||||
t.Error("expected error for zero MaxHistory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigError(t *testing.T) {
|
||||
err := errConfig("test error")
|
||||
if err.Error() != "config: test error" {
|
||||
t.Errorf("unexpected error message: %s", err.Error())
|
||||
}
|
||||
|
||||
ce, ok := err.(*ConfigError)
|
||||
if !ok {
|
||||
t.Fatal("expected ConfigError type")
|
||||
}
|
||||
if ce.Message != "test error" {
|
||||
t.Errorf("expected Message 'test error', got %q", ce.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigFromEnvInvalidTimeout(t *testing.T) {
|
||||
os.Setenv("ORCA_OLLAMA_TIMEOUT", "not-a-duration")
|
||||
defer os.Unsetenv("ORCA_OLLAMA_TIMEOUT")
|
||||
|
||||
cfg := LoadConfigFromEnv()
|
||||
// Should keep default when env var is unparseable
|
||||
if cfg.Ollama.Timeout != 120*time.Second {
|
||||
t.Errorf("expected default 120s when env is invalid, got %v", cfg.Ollama.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigFromEnvInvalidMaxMemory(t *testing.T) {
|
||||
os.Setenv("ORCA_SANDBOX_MAX_MEMORY", "not-a-number")
|
||||
defer os.Unsetenv("ORCA_SANDBOX_MAX_MEMORY")
|
||||
|
||||
cfg := LoadConfigFromEnv()
|
||||
// Should keep default when env var is unparseable
|
||||
if cfg.Sandbox.MaxMemory != 512*1024*1024 {
|
||||
t.Errorf("expected default MaxMemory when env is invalid, got %d", cfg.Sandbox.MaxMemory)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigFromEnvInvalidMaxHistory(t *testing.T) {
|
||||
os.Setenv("ORCA_SESSION_MAX_HISTORY", "not-a-number")
|
||||
defer os.Unsetenv("ORCA_SESSION_MAX_HISTORY")
|
||||
|
||||
cfg := LoadConfigFromEnv()
|
||||
if cfg.Session.MaxHistory != 100 {
|
||||
t.Errorf("expected default MaxHistory when env is invalid, got %d", cfg.Session.MaxHistory)
|
||||
}
|
||||
}
|
||||
220
pkg/actor/actor.go
Normal file
220
pkg/actor/actor.go
Normal file
@ -0,0 +1,220 @@
|
||||
// Package actor implements the Actor model for the Orca framework.
|
||||
//
|
||||
// An Agent is an independent goroutine that communicates via channels.
|
||||
// Each agent has a state machine: Idle -> Processing -> [ToolCall] ->
|
||||
// WaitingForTool -> Processing -> Completed.
|
||||
package actor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/orca/orca/pkg/bus"
|
||||
)
|
||||
|
||||
// ActorStatus identifies the state an agent's state machine is currently in.
type ActorStatus int

const (
	// StatusIdle means the agent is ready to accept messages.
	StatusIdle ActorStatus = iota
	// StatusProcessing means the agent is actively handling a message.
	StatusProcessing
	// StatusWaitingForTool means the agent has called a tool and is awaiting
	// its result.
	StatusWaitingForTool
	// StatusCompleted means the agent has finished processing the last
	// message.
	StatusCompleted
	// StatusStopped means the agent has been shut down.
	StatusStopped
)

// String returns the human-readable name of the status, or "unknown" for a
// value outside the declared constants.
func (s ActorStatus) String() string {
	names := [...]string{
		StatusIdle:           "idle",
		StatusProcessing:     "processing",
		StatusWaitingForTool: "waiting_for_tool",
		StatusCompleted:      "completed",
		StatusStopped:        "stopped",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
|
||||
|
||||
// Agent is the interface that all actors in the Orca framework must implement.
|
||||
//
|
||||
// Each Agent runs as an independent goroutine processing messages
|
||||
// through an internal channel. The Process method provides a synchronous
|
||||
// API to submit messages and await responses.
|
||||
type Agent interface {
|
||||
// ID returns the unique identifier for this agent.
|
||||
ID() string
|
||||
// Role returns the role/type of this agent (e.g., "orchestrator", "worker").
|
||||
Role() string
|
||||
// Process sends a message to this agent and waits for a response.
|
||||
// This is a synchronous call; the agent's goroutine handles the message.
|
||||
Process(ctx context.Context, msg bus.Message) (bus.Message, error)
|
||||
// Stop gracefully shuts down this agent, waiting for in-flight processing to complete.
|
||||
Stop() error
|
||||
}
|
||||
|
||||
// agentRequest wraps a message and provides a response channel.
|
||||
type agentRequest struct {
|
||||
ctx context.Context
|
||||
msg bus.Message
|
||||
resp chan agentResponse
|
||||
}
|
||||
|
||||
// agentResponse wraps the result of processing a message.
|
||||
type agentResponse struct {
|
||||
msg bus.Message
|
||||
err error
|
||||
}
|
||||
|
||||
// BaseAgent provides shared infrastructure for all agent implementations.
|
||||
//
|
||||
// It manages the message channel, goroutine lifecycle, and status tracking.
|
||||
// Concrete agents should embed BaseAgent and set a handler via SetHandler.
|
||||
type BaseAgent struct {
|
||||
id string
|
||||
role string
|
||||
msgCh chan agentRequest
|
||||
stopCh chan struct{}
|
||||
status atomic.Value
|
||||
wg sync.WaitGroup
|
||||
mu sync.Mutex
|
||||
started bool
|
||||
handler func(context.Context, bus.Message) (bus.Message, error)
|
||||
}
|
||||
|
||||
// NewBaseAgent creates a new BaseAgent with the given id and role.
|
||||
// The agent is not started until Start() is called and a handler is set.
|
||||
func NewBaseAgent(id, role string) *BaseAgent {
|
||||
a := &BaseAgent{
|
||||
id: id,
|
||||
role: role,
|
||||
msgCh: make(chan agentRequest, 64),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
a.status.Store(StatusIdle)
|
||||
return a
|
||||
}
|
||||
|
||||
// ID returns the agent's unique identifier.
|
||||
func (a *BaseAgent) ID() string { return a.id }
|
||||
|
||||
// Role returns the agent's role.
|
||||
func (a *BaseAgent) Role() string { return a.role }
|
||||
|
||||
// Status returns the current ActorStatus of this agent.
|
||||
func (a *BaseAgent) Status() ActorStatus {
|
||||
s, _ := a.status.Load().(ActorStatus)
|
||||
return s
|
||||
}
|
||||
|
||||
// setStatus atomically updates the agent's status.
|
||||
func (a *BaseAgent) setStatus(s ActorStatus) {
|
||||
a.status.Store(s)
|
||||
}
|
||||
|
||||
// SetHandler sets the message handler function for this agent.
|
||||
// Must be called before Start().
|
||||
func (a *BaseAgent) SetHandler(handler func(context.Context, bus.Message) (bus.Message, error)) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.handler = handler
|
||||
}
|
||||
|
||||
// Start launches the agent's message processing goroutine.
|
||||
// The handler must be set before calling Start.
|
||||
func (a *BaseAgent) Start() error {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if a.started {
|
||||
return fmt.Errorf("agent %s is already started", a.id)
|
||||
}
|
||||
if a.handler == nil {
|
||||
return fmt.Errorf("agent %s has no handler set", a.id)
|
||||
}
|
||||
|
||||
a.started = true
|
||||
a.status.Store(StatusIdle)
|
||||
a.wg.Add(1)
|
||||
go a.loop()
|
||||
return nil
|
||||
}
|
||||
|
||||
// loop is the main goroutine that reads messages from msgCh and processes them.
|
||||
func (a *BaseAgent) loop() {
|
||||
defer a.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case req := <-a.msgCh:
|
||||
a.setStatus(StatusProcessing)
|
||||
resp, err := a.handler(req.ctx, req.msg)
|
||||
if err != nil {
|
||||
a.setStatus(StatusIdle)
|
||||
} else {
|
||||
a.setStatus(StatusCompleted)
|
||||
}
|
||||
req.resp <- agentResponse{msg: resp, err: err}
|
||||
case <-a.stopCh:
|
||||
a.setStatus(StatusStopped)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process sends a message to the agent's processing loop and waits for a response.
|
||||
// It respects context cancellation and the agent's stop signal.
|
||||
func (a *BaseAgent) Process(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
respCh := make(chan agentResponse, 1)
|
||||
|
||||
select {
|
||||
case a.msgCh <- agentRequest{ctx: ctx, msg: msg, resp: respCh}:
|
||||
case <-ctx.Done():
|
||||
return bus.Message{}, ctx.Err()
|
||||
case <-a.stopCh:
|
||||
return bus.Message{}, fmt.Errorf("agent %s is stopped", a.id)
|
||||
}
|
||||
|
||||
select {
|
||||
case r := <-respCh:
|
||||
return r.msg, r.err
|
||||
case <-ctx.Done():
|
||||
return bus.Message{}, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the agent.
|
||||
// It signals the processing loop to exit and waits for it to finish.
|
||||
func (a *BaseAgent) Stop() error {
|
||||
a.mu.Lock()
|
||||
started := a.started
|
||||
a.started = false
|
||||
a.mu.Unlock()
|
||||
|
||||
if !started {
|
||||
return nil
|
||||
}
|
||||
|
||||
close(a.stopCh)
|
||||
a.wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsStarted returns whether the agent's processing loop is running.
|
||||
func (a *BaseAgent) IsStarted() bool {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return a.started
|
||||
}
|
||||
697
pkg/actor/actor_test.go
Normal file
697
pkg/actor/actor_test.go
Normal file
@ -0,0 +1,697 @@
|
||||
package actor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/orca/orca/pkg/bus"
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// BaseAgent Tests
|
||||
// ============================================================
|
||||
|
||||
func TestNewBaseAgent(t *testing.T) {
|
||||
a := NewBaseAgent("test-1", "worker")
|
||||
if a == nil {
|
||||
t.Fatal("NewBaseAgent() returned nil")
|
||||
}
|
||||
if a.ID() != "test-1" {
|
||||
t.Errorf("expected id 'test-1', got %q", a.ID())
|
||||
}
|
||||
if a.Role() != "worker" {
|
||||
t.Errorf("expected role 'worker', got %q", a.Role())
|
||||
}
|
||||
if s := a.Status(); s != StatusIdle {
|
||||
t.Errorf("expected initial StatusIdle, got %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAgentStartAndStop(t *testing.T) {
|
||||
a := NewBaseAgent("test-2", "worker")
|
||||
a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
return bus.Message{ID: "response"}, nil
|
||||
})
|
||||
|
||||
if err := a.Start(); err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
if !a.IsStarted() {
|
||||
t.Error("expected agent to be started")
|
||||
}
|
||||
|
||||
if err := a.Stop(); err != nil {
|
||||
t.Fatalf("Stop failed: %v", err)
|
||||
}
|
||||
if a.IsStarted() {
|
||||
t.Error("expected agent to be stopped after Stop()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAgentDoubleStart(t *testing.T) {
|
||||
a := NewBaseAgent("test-3", "worker")
|
||||
a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
return bus.Message{ID: "response"}, nil
|
||||
})
|
||||
|
||||
if err := a.Start(); err != nil {
|
||||
t.Fatalf("first Start failed: %v", err)
|
||||
}
|
||||
|
||||
err := a.Start()
|
||||
if err == nil {
|
||||
t.Error("expected error on double start")
|
||||
}
|
||||
a.Stop()
|
||||
}
|
||||
|
||||
func TestBaseAgentStartWithoutHandler(t *testing.T) {
|
||||
a := NewBaseAgent("test-4", "worker")
|
||||
err := a.Start()
|
||||
if err == nil {
|
||||
t.Error("expected error starting agent without handler")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAgentProcessAndResponse(t *testing.T) {
|
||||
a := NewBaseAgent("test-5", "worker")
|
||||
a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
return bus.Message{
|
||||
ID: msg.ID + "-resp",
|
||||
Type: bus.MsgTypeTaskResponse,
|
||||
From: a.ID(),
|
||||
To: msg.From,
|
||||
Content: "processed: " + msg.ID,
|
||||
}, nil
|
||||
})
|
||||
a.Start()
|
||||
defer a.Stop()
|
||||
|
||||
ctx := context.Background()
|
||||
resp, err := a.Process(ctx, bus.Message{
|
||||
ID: "task-1",
|
||||
Type: bus.MsgTypeTaskRequest,
|
||||
From: "caller",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Process failed: %v", err)
|
||||
}
|
||||
if resp.ID != "task-1-resp" {
|
||||
t.Errorf("expected response ID 'task-1-resp', got %q", resp.ID)
|
||||
}
|
||||
if resp.Content != "processed: task-1" {
|
||||
t.Errorf("expected content 'processed: task-1', got %v", resp.Content)
|
||||
}
|
||||
if resp.From != "test-5" {
|
||||
t.Errorf("expected From 'test-5', got %q", resp.From)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAgentProcessReturnsError(t *testing.T) {
|
||||
expectedErr := errors.New("processing failed")
|
||||
a := NewBaseAgent("test-6", "worker")
|
||||
a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
return bus.Message{}, expectedErr
|
||||
})
|
||||
a.Start()
|
||||
defer a.Stop()
|
||||
|
||||
_, err := a.Process(context.Background(), bus.Message{ID: "task-1"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from Process")
|
||||
}
|
||||
if !errors.Is(err, expectedErr) {
|
||||
t.Errorf("expected error %v, got %v", expectedErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAgentProcessOnStoppedAgent(t *testing.T) {
|
||||
a := NewBaseAgent("test-7", "worker")
|
||||
a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
return bus.Message{ID: "response"}, nil
|
||||
})
|
||||
a.Start()
|
||||
a.Stop()
|
||||
|
||||
_, err := a.Process(context.Background(), bus.Message{ID: "task-1"})
|
||||
if err == nil {
|
||||
t.Error("expected error processing on stopped agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAgentContextCancellation(t *testing.T) {
|
||||
a := NewBaseAgent("test-8", "worker")
|
||||
a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
// Simulate long processing
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
return bus.Message{ID: "response"}, nil
|
||||
case <-ctx.Done():
|
||||
return bus.Message{}, ctx.Err()
|
||||
}
|
||||
})
|
||||
a.Start()
|
||||
defer a.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := a.Process(ctx, bus.Message{ID: "task-1"})
|
||||
if err == nil {
|
||||
t.Error("expected error from context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAgentStatusTransitions(t *testing.T) {
|
||||
a := NewBaseAgent("test-9", "worker")
|
||||
a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
if a.Status() != StatusProcessing {
|
||||
t.Errorf("expected StatusProcessing inside handler, got %s", a.Status())
|
||||
}
|
||||
return bus.Message{ID: "response"}, nil
|
||||
})
|
||||
a.Start()
|
||||
defer a.Stop()
|
||||
|
||||
// Should be idle before processing
|
||||
if s := a.Status(); s != StatusIdle {
|
||||
t.Errorf("expected StatusIdle before Process, got %s", s)
|
||||
}
|
||||
|
||||
_, err := a.Process(context.Background(), bus.Message{ID: "task-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("Process failed: %v", err)
|
||||
}
|
||||
|
||||
// Should be completed after processing
|
||||
// Give the goroutine a moment to update the status
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
if s := a.Status(); s != StatusCompleted {
|
||||
t.Errorf("expected StatusCompleted after Process, got %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAgentConcurrentProcess(t *testing.T) {
|
||||
a := NewBaseAgent("test-10", "worker")
|
||||
var counter int32
|
||||
a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
atomic.AddInt32(&counter, 1)
|
||||
return bus.Message{ID: "response"}, nil
|
||||
})
|
||||
a.Start()
|
||||
defer a.Stop()
|
||||
|
||||
var completed int32
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(i int) {
|
||||
_, err := a.Process(context.Background(), bus.Message{ID: "task"})
|
||||
if err == nil {
|
||||
atomic.AddInt32(&completed, 1)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
if n := atomic.LoadInt32(&completed); n != 10 {
|
||||
t.Errorf("expected 10 completed tasks, got %d", n)
|
||||
}
|
||||
if n := atomic.LoadInt32(&counter); n != 10 {
|
||||
t.Errorf("expected 10 handler calls, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAgentStopIdempotent(t *testing.T) {
|
||||
a := NewBaseAgent("test-11", "worker")
|
||||
a.SetHandler(func(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
return bus.Message{ID: "response"}, nil
|
||||
})
|
||||
a.Start()
|
||||
|
||||
if err := a.Stop(); err != nil {
|
||||
t.Fatalf("first Stop failed: %v", err)
|
||||
}
|
||||
if err := a.Stop(); err != nil {
|
||||
t.Fatalf("second Stop should be idempotent: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestActorStatusString(t *testing.T) {
|
||||
tests := []struct {
|
||||
status ActorStatus
|
||||
want string
|
||||
}{
|
||||
{StatusIdle, "idle"},
|
||||
{StatusProcessing, "processing"},
|
||||
{StatusWaitingForTool, "waiting_for_tool"},
|
||||
{StatusCompleted, "completed"},
|
||||
{StatusStopped, "stopped"},
|
||||
{ActorStatus(99), "unknown"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := tt.status.String(); got != tt.want {
|
||||
t.Errorf("ActorStatus(%d).String() = %q, want %q", tt.status, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Worker Tests
|
||||
// ============================================================
|
||||
|
||||
func TestNewWorker(t *testing.T) {
|
||||
w := NewWorker("worker-1")
|
||||
if w == nil {
|
||||
t.Fatal("NewWorker() returned nil")
|
||||
}
|
||||
if w.ID() != "worker-1" {
|
||||
t.Errorf("expected id 'worker-1', got %q", w.ID())
|
||||
}
|
||||
if w.Role() != "worker" {
|
||||
t.Errorf("expected role 'worker', got %q", w.Role())
|
||||
}
|
||||
if !w.IsStarted() {
|
||||
t.Error("expected worker to be started automatically")
|
||||
}
|
||||
w.Stop()
|
||||
}
|
||||
|
||||
func TestWorkerProcessTask(t *testing.T) {
|
||||
w := NewWorker("worker-2")
|
||||
defer w.Stop()
|
||||
|
||||
resp, err := w.Process(context.Background(), bus.Message{
|
||||
ID: "task-1",
|
||||
Type: bus.MsgTypeTaskRequest,
|
||||
From: "caller",
|
||||
Content: "do something",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Process failed: %v", err)
|
||||
}
|
||||
if resp.Type != bus.MsgTypeTaskResponse {
|
||||
t.Errorf("expected MsgTypeTaskResponse, got %s", resp.Type)
|
||||
}
|
||||
if resp.From != "worker-2" {
|
||||
t.Errorf("expected From 'worker-2', got %q", resp.From)
|
||||
}
|
||||
if resp.Metadata["processed_by"] != "worker-2" {
|
||||
t.Errorf("expected processed_by 'worker-2', got %q", resp.Metadata["processed_by"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerProcessToolCall(t *testing.T) {
|
||||
w := NewWorker("worker-3")
|
||||
defer w.Stop()
|
||||
|
||||
resp, err := w.Process(context.Background(), bus.Message{
|
||||
ID: "tool-1",
|
||||
Type: bus.MsgTypeToolCall,
|
||||
From: "caller",
|
||||
Content: "execute command",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Process failed: %v", err)
|
||||
}
|
||||
if resp.Type != bus.MsgTypeToolResult {
|
||||
t.Errorf("expected MsgTypeToolResult, got %s", resp.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerStatusDuringToolCall(t *testing.T) {
|
||||
w := NewWorker("worker-4")
|
||||
|
||||
// Perform a tool call
|
||||
_, err := w.Process(context.Background(), bus.Message{
|
||||
ID: "tool-1",
|
||||
Type: bus.MsgTypeToolCall,
|
||||
From: "caller",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Process failed: %v", err)
|
||||
}
|
||||
|
||||
// After tool call, status should be Processing (set back by defer)
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
// Status could be Processing or Completed depending on timing
|
||||
s := w.Status()
|
||||
if s != StatusProcessing && s != StatusIdle && s != StatusCompleted {
|
||||
t.Errorf("expected Processing/Idle/Completed after tool call, got %s", s)
|
||||
}
|
||||
w.Stop()
|
||||
}
|
||||
|
||||
func TestWorkerUnsupportedMessage(t *testing.T) {
|
||||
w := NewWorker("worker-5")
|
||||
defer w.Stop()
|
||||
|
||||
_, err := w.Process(context.Background(), bus.Message{
|
||||
ID: "unknown-1",
|
||||
Type: bus.MsgTypeObservation,
|
||||
From: "caller",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for unsupported message type")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Orchestrator Tests
|
||||
// ============================================================
|
||||
|
||||
func TestNewOrchestrator(t *testing.T) {
|
||||
o := NewOrchestrator("orch-1", nil)
|
||||
if o == nil {
|
||||
t.Fatal("NewOrchestrator() returned nil")
|
||||
}
|
||||
if o.ID() != "orch-1" {
|
||||
t.Errorf("expected id 'orch-1', got %q", o.ID())
|
||||
}
|
||||
if o.Role() != "orchestrator" {
|
||||
t.Errorf("expected role 'orchestrator', got %q", o.Role())
|
||||
}
|
||||
if !o.IsStarted() {
|
||||
t.Error("expected orchestrator to be started automatically")
|
||||
}
|
||||
o.Stop()
|
||||
}
|
||||
|
||||
func TestOrchestratorAddWorker(t *testing.T) {
|
||||
o := NewOrchestrator("orch-2", nil)
|
||||
defer o.Stop()
|
||||
|
||||
w := NewWorker("worker-10")
|
||||
defer w.Stop()
|
||||
|
||||
o.AddWorker(w)
|
||||
if n := o.WorkerCount(); n != 1 {
|
||||
t.Errorf("expected 1 worker, got %d", n)
|
||||
}
|
||||
|
||||
got, ok := o.GetWorker("worker-10")
|
||||
if !ok {
|
||||
t.Fatal("expected to find worker-10")
|
||||
}
|
||||
if got.ID() != "worker-10" {
|
||||
t.Errorf("expected worker ID 'worker-10', got %q", got.ID())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrchestratorRemoveWorker(t *testing.T) {
|
||||
o := NewOrchestrator("orch-3", nil)
|
||||
defer o.Stop()
|
||||
|
||||
w := NewWorker("worker-11")
|
||||
defer w.Stop()
|
||||
o.AddWorker(w)
|
||||
o.RemoveWorker("worker-11")
|
||||
|
||||
if n := o.WorkerCount(); n != 0 {
|
||||
t.Errorf("expected 0 workers after removal, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrchestratorListWorkers(t *testing.T) {
|
||||
o := NewOrchestrator("orch-4", nil)
|
||||
defer o.Stop()
|
||||
|
||||
workers := []string{"w-1", "w-2", "w-3"}
|
||||
for _, name := range workers {
|
||||
w := NewWorker(name)
|
||||
defer w.Stop()
|
||||
o.AddWorker(w)
|
||||
}
|
||||
|
||||
list := o.ListWorkers()
|
||||
if len(list) != len(workers) {
|
||||
t.Errorf("expected %d workers, got %d", len(workers), len(list))
|
||||
}
|
||||
|
||||
ids := make(map[string]bool)
|
||||
for _, w := range list {
|
||||
ids[w.ID()] = true
|
||||
}
|
||||
for _, name := range workers {
|
||||
if !ids[name] {
|
||||
t.Errorf("missing worker %q in list", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrchestratorDelegatesToWorker(t *testing.T) {
|
||||
o := NewOrchestrator("orch-5", nil)
|
||||
defer o.Stop()
|
||||
|
||||
w := NewWorker("worker-20")
|
||||
defer w.Stop()
|
||||
o.AddWorker(w)
|
||||
|
||||
resp, err := o.Process(context.Background(), bus.Message{
|
||||
ID: "task-1",
|
||||
Type: bus.MsgTypeTaskRequest,
|
||||
From: "caller",
|
||||
Content: "do work",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Process failed: %v", err)
|
||||
}
|
||||
if resp.From != "worker-20" {
|
||||
t.Errorf("expected response from 'worker-20', got %q", resp.From)
|
||||
}
|
||||
if resp.Metadata["processed_by"] != "worker-20" {
|
||||
t.Errorf("expected processed_by 'worker-20', got %q", resp.Metadata["processed_by"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrchestratorNoWorkers(t *testing.T) {
|
||||
o := NewOrchestrator("orch-6", nil)
|
||||
defer o.Stop()
|
||||
|
||||
_, err := o.Process(context.Background(), bus.Message{
|
||||
ID: "task-1",
|
||||
Type: bus.MsgTypeTaskRequest,
|
||||
From: "caller",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error when no workers available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrchestratorSystemMessage(t *testing.T) {
|
||||
o := NewOrchestrator("orch-7", nil)
|
||||
defer o.Stop()
|
||||
|
||||
resp, err := o.Process(context.Background(), bus.Message{
|
||||
ID: "sys-1",
|
||||
Type: bus.MsgTypeSystem,
|
||||
From: "caller",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Process failed: %v", err)
|
||||
}
|
||||
if resp.Content != "orchestrator acknowledged" {
|
||||
t.Errorf("expected acknowledged message, got %v", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// System Tests
|
||||
// ============================================================
|
||||
|
||||
func TestNewSystem(t *testing.T) {
|
||||
s := NewSystem()
|
||||
if s == nil {
|
||||
t.Fatal("NewSystem() returned nil")
|
||||
}
|
||||
if n := s.AgentCount(); n != 0 {
|
||||
t.Errorf("expected 0 agents, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemCreateWorker(t *testing.T) {
|
||||
s := NewSystem()
|
||||
w, err := s.CreateWorker()
|
||||
if err != nil {
|
||||
t.Fatalf("CreateWorker failed: %v", err)
|
||||
}
|
||||
if w == nil {
|
||||
t.Fatal("CreateWorker returned nil")
|
||||
}
|
||||
if w.Role() != "worker" {
|
||||
t.Errorf("expected role 'worker', got %q", w.Role())
|
||||
}
|
||||
if n := s.AgentCount(); n != 1 {
|
||||
t.Errorf("expected 1 agent, got %d", n)
|
||||
}
|
||||
s.StopAll()
|
||||
}
|
||||
|
||||
func TestSystemStopAgent(t *testing.T) {
|
||||
s := NewSystem()
|
||||
w, _ := s.CreateWorker()
|
||||
|
||||
if err := s.StopAgent(w.ID()); err != nil {
|
||||
t.Fatalf("StopAgent failed: %v", err)
|
||||
}
|
||||
if n := s.AgentCount(); n != 0 {
|
||||
t.Errorf("expected 0 agents, got %d", n)
|
||||
}
|
||||
|
||||
_, ok := s.GetAgent(w.ID())
|
||||
if ok {
|
||||
t.Error("expected agent to be removed after StopAgent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemStopAgentNotFound(t *testing.T) {
|
||||
s := NewSystem()
|
||||
err := s.StopAgent("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error stopping nonexistent agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemListAgents(t *testing.T) {
|
||||
s := NewSystem()
|
||||
s.CreateWorker()
|
||||
s.CreateWorker()
|
||||
|
||||
agents := s.ListAgents()
|
||||
if len(agents) != 2 {
|
||||
t.Errorf("expected 2 agents, got %d", len(agents))
|
||||
}
|
||||
s.StopAll()
|
||||
}
|
||||
|
||||
func TestSystemAgentInfos(t *testing.T) {
|
||||
s := NewSystem()
|
||||
w, _ := s.CreateWorker()
|
||||
|
||||
infos := s.AgentInfos()
|
||||
if len(infos) != 1 {
|
||||
t.Fatalf("expected 1 agent info, got %d", len(infos))
|
||||
}
|
||||
if infos[0].ID != w.ID() {
|
||||
t.Errorf("expected ID %q, got %q", w.ID(), infos[0].ID)
|
||||
}
|
||||
if infos[0].Role != "worker" {
|
||||
t.Errorf("expected Role 'worker', got %q", infos[0].Role)
|
||||
}
|
||||
if infos[0].Status != StatusIdle {
|
||||
t.Errorf("expected Status StatusIdle, got %s", infos[0].Status)
|
||||
}
|
||||
s.StopAll()
|
||||
}
|
||||
|
||||
func TestSystemStopAll(t *testing.T) {
|
||||
s := NewSystem()
|
||||
s.CreateWorker()
|
||||
s.CreateWorker()
|
||||
s.CreateWorker()
|
||||
|
||||
if err := s.StopAll(); err != nil {
|
||||
t.Fatalf("StopAll failed: %v", err)
|
||||
}
|
||||
if n := s.AgentCount(); n != 0 {
|
||||
t.Errorf("expected 0 agents after StopAll, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// ToolWorker Tests
|
||||
// ============================================================
|
||||
|
||||
func TestNewToolWorker(t *testing.T) {
|
||||
tw := NewToolWorker("tool-1", nil)
|
||||
if tw == nil {
|
||||
t.Fatal("NewToolWorker() returned nil")
|
||||
}
|
||||
if tw.ID() != "tool-1" {
|
||||
t.Errorf("expected id 'tool-1', got %q", tw.ID())
|
||||
}
|
||||
if tw.Role() != "tool_worker" {
|
||||
t.Errorf("expected role 'tool_worker', got %q", tw.Role())
|
||||
}
|
||||
if !tw.IsStarted() {
|
||||
t.Error("expected tool worker to be started automatically")
|
||||
}
|
||||
tw.Stop()
|
||||
}
|
||||
|
||||
func TestToolWorkerProcessSystemMessage(t *testing.T) {
|
||||
tw := NewToolWorker("tool-2", nil)
|
||||
defer tw.Stop()
|
||||
|
||||
resp, err := tw.Process(context.Background(), bus.Message{
|
||||
ID: "sys-1",
|
||||
Type: bus.MsgTypeSystem,
|
||||
From: "caller",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Process failed: %v", err)
|
||||
}
|
||||
if resp.Content != "tool_worker acknowledged" {
|
||||
t.Errorf("expected 'tool_worker acknowledged', got %v", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolWorkerUnsupportedMessage(t *testing.T) {
|
||||
tw := NewToolWorker("tool-3", nil)
|
||||
defer tw.Stop()
|
||||
|
||||
_, err := tw.Process(context.Background(), bus.Message{
|
||||
ID: "obs-1",
|
||||
Type: bus.MsgTypeObservation,
|
||||
From: "caller",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for unsupported message type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallContentMap(t *testing.T) {
|
||||
name, args, err := parseToolCallContent(map[string]interface{}{
|
||||
"name": "exec",
|
||||
"arguments": map[string]interface{}{"command": "ls"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("parseToolCallContent failed: %v", err)
|
||||
}
|
||||
if name != "exec" {
|
||||
t.Errorf("expected name 'exec', got %q", name)
|
||||
}
|
||||
if args["command"] != "ls" {
|
||||
t.Errorf("expected args['command'] = 'ls', got %v", args["command"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallContentMissingName(t *testing.T) {
|
||||
_, _, err := parseToolCallContent(map[string]interface{}{"foo": "bar"})
|
||||
if err == nil {
|
||||
t.Error("expected error for missing name")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// System ToolWorker Tests
|
||||
// ============================================================
|
||||
|
||||
func TestSystemCreateToolWorker(t *testing.T) {
|
||||
s := NewSystem()
|
||||
tw, err := s.CreateToolWorker(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateToolWorker failed: %v", err)
|
||||
}
|
||||
if tw == nil {
|
||||
t.Fatal("CreateToolWorker returned nil")
|
||||
}
|
||||
if tw.Role() != "tool_worker" {
|
||||
t.Errorf("expected role 'tool_worker', got %q", tw.Role())
|
||||
}
|
||||
if n := s.AgentCount(); n != 1 {
|
||||
t.Errorf("expected 1 agent, got %d", n)
|
||||
}
|
||||
s.StopAll()
|
||||
}
|
||||
18
pkg/actor/agent.go
Normal file
18
pkg/actor/agent.go
Normal file
@ -0,0 +1,18 @@
|
||||
// Package actor implements the Actor model for the Orca framework.
|
||||
//
|
||||
// This file carries package-level documentation only and declares no
// types of its own. See actor.go for the base Agent interface and the
// BaseAgent implementation.
|
||||
package actor
|
||||
|
||||
// The LLMAgent and ToolWorker types, which connect the actor system to
// the LLM and Tool subsystems, live in their own files (listed below)
// rather than in this one.
|
||||
//
|
||||
// The key types in this package are:
|
||||
// - Agent (interface, in actor.go)
|
||||
// - BaseAgent (struct, in actor.go)
|
||||
// - Orchestrator (in orchestrator.go)
|
||||
// - Worker (in worker.go)
|
||||
// - LLMAgent (in llm_agent.go)
|
||||
// - ToolWorker (in tool_worker.go)
|
||||
338
pkg/actor/llm_agent.go
Normal file
338
pkg/actor/llm_agent.go
Normal file
@ -0,0 +1,338 @@
|
||||
package actor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/orca/orca/pkg/bus"
|
||||
"github.com/orca/orca/pkg/llm"
|
||||
"github.com/orca/orca/pkg/session"
|
||||
"github.com/orca/orca/pkg/tool"
|
||||
)
|
||||
|
||||
// LLMAgent implements the Agent interface by integrating an LLM backend
// with the actor system and tool framework.
//
// It receives user messages, retrieves session context, calls the LLM,
// and handles tool call responses by executing tools and feeding results
// back to the LLM for final response generation.
type LLMAgent struct {
	*BaseAgent
	llm         llm.LLM          // backend whose Chat method produces responses
	sessionMgr  *session.Manager // optional; history is used only when this and sessionID are both set
	sessionID   string           // optional; an empty ID disables session persistence
	toolManager *tool.Manager    // optional; supplies the tool system prompt and direct tool execution
	toolWorker  *ToolWorker      // optional; when set it is preferred over toolManager for execution
	windowSize  int              // passed to the session manager's GetContext; NewLLMAgent defaults it to 20
}
|
||||
|
||||
// LLMAgentOption is a functional option for configuring the LLMAgent.
// NewLLMAgent applies options in the order given, before the agent's
// handler is installed and the agent is started.
type LLMAgentOption func(*LLMAgent)
|
||||
|
||||
// WithSessionManager sets the session manager for conversation history.
|
||||
func WithSessionManager(mgr *session.Manager) LLMAgentOption {
|
||||
return func(a *LLMAgent) {
|
||||
a.sessionMgr = mgr
|
||||
}
|
||||
}
|
||||
|
||||
// WithSessionID sets the session ID for conversation persistence.
|
||||
func WithSessionID(id string) LLMAgentOption {
|
||||
return func(a *LLMAgent) {
|
||||
a.sessionID = id
|
||||
}
|
||||
}
|
||||
|
||||
// WithToolManager sets the tool manager for executing tools.
|
||||
func WithToolManager(mgr *tool.Manager) LLMAgentOption {
|
||||
return func(a *LLMAgent) {
|
||||
a.toolManager = mgr
|
||||
}
|
||||
}
|
||||
|
||||
// WithToolWorker sets the tool worker for delegated tool execution.
|
||||
func WithToolWorker(w *ToolWorker) LLMAgentOption {
|
||||
return func(a *LLMAgent) {
|
||||
a.toolWorker = w
|
||||
}
|
||||
}
|
||||
|
||||
// WithWindowSize sets the context window size for session history.
|
||||
func WithWindowSize(size int) LLMAgentOption {
|
||||
return func(a *LLMAgent) {
|
||||
a.windowSize = size
|
||||
}
|
||||
}
|
||||
|
||||
// NewLLMAgent creates a new LLMAgent with the given LLM backend and options.
|
||||
// The agent is started automatically upon creation.
|
||||
func NewLLMAgent(id string, backend llm.LLM, opts ...LLMAgentOption) *LLMAgent {
|
||||
a := &LLMAgent{
|
||||
BaseAgent: NewBaseAgent(id, "llm_agent"),
|
||||
llm: backend,
|
||||
windowSize: 20, // Default context window
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(a)
|
||||
}
|
||||
|
||||
a.SetHandler(a.handleMessage)
|
||||
if err := a.Start(); err != nil {
|
||||
panic(fmt.Sprintf("llm_agent: failed to start: %v", err))
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// handleMessage routes incoming messages to the appropriate handler.
|
||||
func (a *LLMAgent) handleMessage(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
switch msg.Type {
|
||||
case bus.MsgTypeTaskRequest:
|
||||
return a.handleUserMessage(ctx, msg)
|
||||
case bus.MsgTypeSystem:
|
||||
return a.handleSystem(ctx, msg)
|
||||
default:
|
||||
return bus.Message{}, fmt.Errorf("llm_agent %s: unsupported message type %s", a.ID(), msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// handleUserMessage processes a user message through the LLM.
|
||||
//
|
||||
// Flow:
|
||||
// 1. Persist the user message to session history
|
||||
// 2. Retrieve recent conversation context
|
||||
// 3. Convert to LLM message format
|
||||
// 4. Call LLM.Chat
|
||||
// 5. If response has tool calls:
|
||||
// a. Execute each tool (directly or via ToolWorker)
|
||||
// b. Add tool results to conversation
|
||||
// c. Call LLM.Chat again with results
|
||||
// 6. Persist the assistant response
|
||||
// 7. Return the final response
|
||||
func (a *LLMAgent) handleUserMessage(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
content, ok := msg.Content.(string)
|
||||
if !ok {
|
||||
return bus.Message{}, fmt.Errorf("llm_agent: expected string content, got %T", msg.Content)
|
||||
}
|
||||
|
||||
// Ensure session exists
|
||||
if a.sessionMgr != nil && a.sessionID != "" {
|
||||
// Check if session exists; create if not
|
||||
if _, err := a.sessionMgr.GetSession(a.sessionID); err != nil {
|
||||
a.sessionMgr.CreateSession(a.sessionID, map[string]string{
|
||||
"source": "llm_agent",
|
||||
})
|
||||
}
|
||||
|
||||
// Persist user message
|
||||
a.sessionMgr.AddMessage(a.sessionID, session.RoleUser, content, nil)
|
||||
}
|
||||
|
||||
// Build LLM messages from session context
|
||||
llmMessages := a.buildLLMMessages()
|
||||
|
||||
// Call LLM (potentially multiple rounds for tool calls)
|
||||
finalResponse, err := a.chatWithToolLoop(ctx, llmMessages)
|
||||
if err != nil {
|
||||
return bus.Message{}, fmt.Errorf("llm_agent: LLM chat failed: %w", err)
|
||||
}
|
||||
|
||||
// Persist assistant response
|
||||
if a.sessionMgr != nil && a.sessionID != "" {
|
||||
a.sessionMgr.AddMessage(a.sessionID, session.RoleAssistant, finalResponse, nil)
|
||||
}
|
||||
|
||||
return bus.Message{
|
||||
ID: msg.ID + "-response",
|
||||
Type: bus.MsgTypeTaskResponse,
|
||||
From: a.ID(),
|
||||
To: msg.From,
|
||||
Content: finalResponse,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *LLMAgent) buildLLMMessages() []llm.Message {
|
||||
messages := make([]llm.Message, 0)
|
||||
|
||||
if a.toolManager != nil {
|
||||
messages = append(messages, llm.Message{
|
||||
Role: "system",
|
||||
Content: a.buildToolSystemPrompt(),
|
||||
})
|
||||
}
|
||||
|
||||
if a.sessionMgr == nil || a.sessionID == "" {
|
||||
return messages
|
||||
}
|
||||
|
||||
sessionMsgs, err := a.sessionMgr.GetContext(a.sessionID, a.windowSize)
|
||||
if err != nil {
|
||||
return messages
|
||||
}
|
||||
|
||||
for _, sm := range sessionMsgs {
|
||||
msg := llm.Message{
|
||||
Role: string(sm.Role),
|
||||
Content: sm.Content,
|
||||
}
|
||||
if sm.Role == session.RoleTool && sm.Metadata != nil {
|
||||
msg.ToolCallID = sm.Metadata["tool_call_id"]
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// buildToolSystemPrompt creates a system prompt describing all available tools.
|
||||
// This enables prompt-based tool calling for models without native function
|
||||
// calling support.
|
||||
func (a *LLMAgent) buildToolSystemPrompt() string {
|
||||
if a.toolManager == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("你是一个 AI 助手,可以使用以下工具来完成用户的请求。\n\n")
|
||||
b.WriteString("可用工具列表:\n")
|
||||
|
||||
for _, t := range a.toolManager.List() {
|
||||
b.WriteString(fmt.Sprintf("\n工具名: %s\n", t.Name()))
|
||||
b.WriteString(fmt.Sprintf("描述: %s\n", t.Description()))
|
||||
paramsJSON, _ := json.Marshal(t.Parameters())
|
||||
b.WriteString(fmt.Sprintf("参数: %s\n", string(paramsJSON)))
|
||||
}
|
||||
|
||||
b.WriteString("\n规则:\n")
|
||||
b.WriteString("1. 当你需要调用工具时,请在回复中**只输出**以下 JSON 格式(不要添加其他文字):\n")
|
||||
b.WriteString(` {"tool": "工具名", "arguments": {"参数名": "参数值"}}` + "\n")
|
||||
b.WriteString("2. 如果你已经看到了工具返回的结果,请直接根据结果回答用户,不要再次调用工具。\n")
|
||||
b.WriteString("3. 如果你不需要调用工具,请直接回复用户。\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (a *LLMAgent) chatWithToolLoop(ctx context.Context, messages []llm.Message) (string, error) {
|
||||
maxRounds := 10
|
||||
|
||||
for round := 0; round < maxRounds; round++ {
|
||||
response, err := a.llm.Chat(ctx, messages)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("chat round %d failed: %w", round, err)
|
||||
}
|
||||
|
||||
toolCalls := response.ToolCalls
|
||||
if len(toolCalls) == 0 {
|
||||
toolCalls = a.parseToolCallsFromContent(response.Content)
|
||||
}
|
||||
|
||||
if len(toolCalls) == 0 {
|
||||
return response.Content, nil
|
||||
}
|
||||
|
||||
messages = append(messages, llm.Message{
|
||||
Role: "assistant",
|
||||
Content: response.Content,
|
||||
})
|
||||
|
||||
for _, tc := range toolCalls {
|
||||
resultContent := a.executeToolCall(ctx, tc)
|
||||
messages = append(messages, llm.Message{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("工具 %s 的执行结果:%s", tc.Function.Name, resultContent),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("llm_agent: exceeded maximum tool call rounds (%d)", maxRounds)
|
||||
}
|
||||
|
||||
func (a *LLMAgent) parseToolCallsFromContent(content string) []llm.ToolCall {
|
||||
var parsed struct {
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(content), &parsed); err != nil || parsed.Tool == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
argsJSON, _ := json.Marshal(parsed.Arguments)
|
||||
return []llm.ToolCall{{
|
||||
ID: "call_0",
|
||||
Type: "function",
|
||||
Function: llm.FunctionCall{
|
||||
Name: parsed.Tool,
|
||||
Arguments: string(argsJSON),
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
// executeToolCall runs a single tool call and returns the result as a JSON string.
|
||||
func (a *LLMAgent) executeToolCall(ctx context.Context, tc llm.ToolCall) string {
|
||||
toolName := tc.Function.Name
|
||||
|
||||
// Parse arguments
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
|
||||
args = map[string]interface{}{
|
||||
"_raw": tc.Function.Arguments,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute via ToolWorker (preferred) or directly via tool.Manager
|
||||
if a.toolWorker != nil {
|
||||
// Create a tool call message for the ToolWorker
|
||||
toolCallMsg := bus.Message{
|
||||
ID: tc.ID,
|
||||
Type: bus.MsgTypeToolCall,
|
||||
From: a.ID(),
|
||||
To: a.toolWorker.ID(),
|
||||
Content: map[string]interface{}{"name": toolName, "arguments": args},
|
||||
}
|
||||
|
||||
resultMsg, err := a.toolWorker.Process(ctx, toolCallMsg)
|
||||
if err != nil {
|
||||
return fmt.Sprintf(`{"error": "tool execution failed: %v"}`, err)
|
||||
}
|
||||
|
||||
// Serialize the result
|
||||
resultJSON, err := json.Marshal(resultMsg.Content)
|
||||
if err != nil {
|
||||
return fmt.Sprintf(`{"error": "failed to marshal result: %v"}`, err)
|
||||
}
|
||||
return string(resultJSON)
|
||||
}
|
||||
|
||||
// Fallback: execute directly via tool.Manager
|
||||
if a.toolManager != nil {
|
||||
result, err := a.toolManager.Execute(toolName, ctx, args)
|
||||
if err != nil {
|
||||
return fmt.Sprintf(`{"error": "tool execution failed: %v"}`, err)
|
||||
}
|
||||
|
||||
resultJSON, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return fmt.Sprintf(`{"error": "failed to marshal result: %v"}`, err)
|
||||
}
|
||||
return string(resultJSON)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`{"error": "no tool worker or tool manager available for %q"}`, toolName)
|
||||
}
|
||||
|
||||
// handleSystem processes internal system messages.
|
||||
func (a *LLMAgent) handleSystem(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
return bus.Message{
|
||||
ID: msg.ID + "-ack",
|
||||
Type: bus.MsgTypeSystem,
|
||||
From: a.ID(),
|
||||
To: msg.From,
|
||||
Content: "llm_agent acknowledged",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Compile-time interface checks.
|
||||
var _ Agent = (*LLMAgent)(nil)
|
||||
var _ Agent = (*ToolWorker)(nil)
|
||||
123
pkg/actor/orchestrator.go
Normal file
123
pkg/actor/orchestrator.go
Normal file
@ -0,0 +1,123 @@
|
||||
package actor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/orca/orca/pkg/bus"
|
||||
)
|
||||
|
||||
// Orchestrator is an agent that coordinates task execution across a pool of workers.
//
// It receives task requests, delegates them to available workers, and
// collects responses. The orchestrator maintains a registry of worker
// agents and can dynamically add or remove them.
type Orchestrator struct {
	*BaseAgent
	workers map[string]Agent // registered workers keyed by their ID; guarded by mu
	bus     bus.MessageBus   // bus passed to NewOrchestrator; exposed through Bus()
	mu      sync.RWMutex     // guards workers
}
|
||||
|
||||
// NewOrchestrator creates a new Orchestrator agent with the given id and message bus.
|
||||
// The agent is started automatically upon creation.
|
||||
func NewOrchestrator(id string, mb bus.MessageBus) *Orchestrator {
|
||||
o := &Orchestrator{
|
||||
BaseAgent: NewBaseAgent(id, "orchestrator"),
|
||||
workers: make(map[string]Agent),
|
||||
bus: mb,
|
||||
}
|
||||
o.SetHandler(o.handleMessage)
|
||||
// Start the agent's processing loop
|
||||
if err := o.Start(); err != nil {
|
||||
// This should not happen since handler is set above
|
||||
panic(fmt.Sprintf("orchestrator: failed to start: %v", err))
|
||||
}
|
||||
return o
|
||||
}
|
||||
|
||||
// handleMessage routes incoming messages to the appropriate handler.
|
||||
func (o *Orchestrator) handleMessage(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
switch msg.Type {
|
||||
case bus.MsgTypeTaskRequest:
|
||||
return o.handleTask(ctx, msg)
|
||||
case bus.MsgTypeSystem:
|
||||
return o.handleSystem(ctx, msg)
|
||||
default:
|
||||
return bus.Message{}, fmt.Errorf("orchestrator %s: unsupported message type %s", o.ID(), msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// handleTask processes a task request by delegating to an available worker.
|
||||
func (o *Orchestrator) handleTask(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
|
||||
if len(o.workers) == 0 {
|
||||
return bus.Message{}, fmt.Errorf("orchestrator %s: no workers available", o.ID())
|
||||
}
|
||||
|
||||
// Simple round-robin: pick the first available worker
|
||||
for _, w := range o.workers {
|
||||
return w.Process(ctx, msg)
|
||||
}
|
||||
|
||||
return bus.Message{}, fmt.Errorf("orchestrator %s: no workers available", o.ID())
|
||||
}
|
||||
|
||||
// handleSystem processes internal system messages.
|
||||
func (o *Orchestrator) handleSystem(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
return bus.Message{
|
||||
ID: msg.ID + "-ack",
|
||||
Type: bus.MsgTypeSystem,
|
||||
From: o.ID(),
|
||||
To: msg.From,
|
||||
Content: "orchestrator acknowledged",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AddWorker registers a worker agent with this orchestrator.
|
||||
func (o *Orchestrator) AddWorker(w Agent) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
o.workers[w.ID()] = w
|
||||
}
|
||||
|
||||
// RemoveWorker unregisters a worker agent from this orchestrator.
|
||||
func (o *Orchestrator) RemoveWorker(id string) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
delete(o.workers, id)
|
||||
}
|
||||
|
||||
// WorkerCount returns the number of registered workers.
|
||||
func (o *Orchestrator) WorkerCount() int {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
return len(o.workers)
|
||||
}
|
||||
|
||||
// GetWorker retrieves a registered worker by ID.
|
||||
func (o *Orchestrator) GetWorker(id string) (Agent, bool) {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
w, ok := o.workers[id]
|
||||
return w, ok
|
||||
}
|
||||
|
||||
// ListWorkers returns all registered workers.
|
||||
func (o *Orchestrator) ListWorkers() []Agent {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
workers := make([]Agent, 0, len(o.workers))
|
||||
for _, w := range o.workers {
|
||||
workers = append(workers, w)
|
||||
}
|
||||
return workers
|
||||
}
|
||||
|
||||
// Bus returns the orchestrator's message bus reference.
|
||||
func (o *Orchestrator) Bus() bus.MessageBus {
|
||||
return o.bus
|
||||
}
|
||||
180
pkg/actor/system.go
Normal file
180
pkg/actor/system.go
Normal file
@ -0,0 +1,180 @@
|
||||
package actor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/orca/orca/pkg/tool"
|
||||
)
|
||||
|
||||
// System manages the lifecycle of all agents in the Orca actor framework.
|
||||
//
|
||||
// It provides centralized agent creation, monitoring, and shutdown
|
||||
// capabilities. Agents are identified by unique IDs and organized by role.
|
||||
type System struct {
|
||||
mu sync.RWMutex
|
||||
agents map[string]Agent
|
||||
nextID int64
|
||||
}
|
||||
|
||||
// NewSystem creates a new empty actor System.
|
||||
func NewSystem() *System {
|
||||
return &System{
|
||||
agents: make(map[string]Agent),
|
||||
}
|
||||
}
|
||||
|
||||
// AgentInfo holds summary information about a managed agent.
|
||||
type AgentInfo struct {
|
||||
ID string `json:"id"`
|
||||
Role string `json:"role"`
|
||||
Status ActorStatus `json:"status"`
|
||||
}
|
||||
|
||||
// CreateOrchestrator creates a new Orchestrator agent and registers it.
|
||||
func (s *System) CreateOrchestrator(bus interface{}) (*Orchestrator, error) {
|
||||
id := s.nextAgentID("orch")
|
||||
return s.addOrchestrator(id, bus)
|
||||
}
|
||||
|
||||
// CreateWorker creates a new Worker agent and registers it.
|
||||
func (s *System) CreateWorker() (*Worker, error) {
|
||||
id := s.nextAgentID("worker")
|
||||
return s.addWorker(id)
|
||||
}
|
||||
|
||||
// CreateToolWorker creates a new ToolWorker agent with the given tool manager and registers it.
|
||||
func (s *System) CreateToolWorker(manager *tool.Manager) (*ToolWorker, error) {
|
||||
id := s.nextAgentID("tool")
|
||||
return s.addToolWorker(id, manager)
|
||||
}
|
||||
|
||||
// nextAgentID generates a unique agent ID with the given prefix.
|
||||
func (s *System) nextAgentID(prefix string) string {
|
||||
n := atomic.AddInt64(&s.nextID, 1)
|
||||
return fmt.Sprintf("%s-%d", prefix, n)
|
||||
}
|
||||
|
||||
// addOrchestrator creates and registers an orchestrator.
|
||||
func (s *System) addOrchestrator(id string, busInterface interface{}) (*Orchestrator, error) {
|
||||
mb, ok := busInterface.(interface{ Bus() })
|
||||
var orch *Orchestrator
|
||||
if ok {
|
||||
// If busInterface has a Bus() method, we could extract it here
|
||||
_ = mb
|
||||
}
|
||||
orch = NewOrchestrator(id, nil)
|
||||
|
||||
s.mu.Lock()
|
||||
s.agents[id] = orch
|
||||
s.mu.Unlock()
|
||||
|
||||
return orch, nil
|
||||
}
|
||||
|
||||
// addWorker creates and registers a worker.
|
||||
func (s *System) addWorker(id string) (*Worker, error) {
|
||||
w := NewWorker(id)
|
||||
|
||||
s.mu.Lock()
|
||||
s.agents[id] = w
|
||||
s.mu.Unlock()
|
||||
|
||||
return w, nil
|
||||
}
|
||||
|
||||
// addToolWorker creates and registers a tool worker with the given tool manager.
|
||||
func (s *System) addToolWorker(id string, manager *tool.Manager) (*ToolWorker, error) {
|
||||
w := NewToolWorker(id, manager)
|
||||
|
||||
s.mu.Lock()
|
||||
s.agents[id] = w
|
||||
s.mu.Unlock()
|
||||
|
||||
return w, nil
|
||||
}
|
||||
|
||||
// StopAgent stops and removes a single agent by ID.
|
||||
func (s *System) StopAgent(id string) error {
|
||||
s.mu.Lock()
|
||||
agent, ok := s.agents[id]
|
||||
if !ok {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("agent %s not found", id)
|
||||
}
|
||||
delete(s.agents, id)
|
||||
s.mu.Unlock()
|
||||
|
||||
return agent.Stop()
|
||||
}
|
||||
|
||||
// GetAgent retrieves a registered agent by ID.
|
||||
func (s *System) GetAgent(id string) (Agent, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
agent, ok := s.agents[id]
|
||||
return agent, ok
|
||||
}
|
||||
|
||||
// ListAgents returns all registered agents.
|
||||
func (s *System) ListAgents() []Agent {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
agents := make([]Agent, 0, len(s.agents))
|
||||
for _, a := range s.agents {
|
||||
agents = append(agents, a)
|
||||
}
|
||||
return agents
|
||||
}
|
||||
|
||||
// AgentInfos returns summary information for all registered agents.
|
||||
func (s *System) AgentInfos() []AgentInfo {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
infos := make([]AgentInfo, 0, len(s.agents))
|
||||
for _, a := range s.agents {
|
||||
// Try to get status from BaseAgent
|
||||
status := StatusIdle
|
||||
if ba, ok := a.(*BaseAgent); ok {
|
||||
status = ba.Status()
|
||||
} else if orch, ok := a.(*Orchestrator); ok {
|
||||
status = orch.Status()
|
||||
} else if w, ok := a.(*Worker); ok {
|
||||
status = w.Status()
|
||||
} else if tw, ok := a.(*ToolWorker); ok {
|
||||
status = tw.Status()
|
||||
}
|
||||
|
||||
infos = append(infos, AgentInfo{
|
||||
ID: a.ID(),
|
||||
Role: a.Role(),
|
||||
Status: status,
|
||||
})
|
||||
}
|
||||
return infos
|
||||
}
|
||||
|
||||
// StopAll gracefully stops all registered agents.
|
||||
func (s *System) StopAll() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var lastErr error
|
||||
for id, agent := range s.agents {
|
||||
if err := agent.Stop(); err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
delete(s.agents, id)
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// AgentCount returns the number of registered agents.
|
||||
func (s *System) AgentCount() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.agents)
|
||||
}
|
||||
153
pkg/actor/tool_worker.go
Normal file
153
pkg/actor/tool_worker.go
Normal file
@ -0,0 +1,153 @@
|
||||
package actor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/orca/orca/pkg/bus"
|
||||
"github.com/orca/orca/pkg/tool"
|
||||
)
|
||||
|
||||
// ToolWorker is an agent that processes tool call messages by executing
|
||||
// tools through the tool.Manager.
|
||||
//
|
||||
// It implements the Agent interface and handles MsgTypeToolCall messages.
|
||||
// When a tool call is received, it extracts the tool name and arguments
|
||||
// from the message content, executes the tool via the Manager, and
|
||||
// returns a MsgTypeToolResult with the execution result.
|
||||
type ToolWorker struct {
|
||||
*BaseAgent
|
||||
manager *tool.Manager
|
||||
}
|
||||
|
||||
// NewToolWorker creates a new ToolWorker agent with the given id and tool manager.
|
||||
// The agent is started automatically upon creation.
|
||||
func NewToolWorker(id string, manager *tool.Manager) *ToolWorker {
|
||||
w := &ToolWorker{
|
||||
BaseAgent: NewBaseAgent(id, "tool_worker"),
|
||||
manager: manager,
|
||||
}
|
||||
w.SetHandler(w.handleMessage)
|
||||
if err := w.Start(); err != nil {
|
||||
panic(fmt.Sprintf("tool_worker: failed to start: %v", err))
|
||||
}
|
||||
return w
|
||||
}
|
||||
|
||||
// handleMessage routes incoming messages to the appropriate handler.
|
||||
func (w *ToolWorker) handleMessage(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
switch msg.Type {
|
||||
case bus.MsgTypeToolCall:
|
||||
return w.handleToolCall(ctx, msg)
|
||||
case bus.MsgTypeTaskRequest:
|
||||
return w.handleTask(ctx, msg)
|
||||
case bus.MsgTypeSystem:
|
||||
return w.handleSystem(ctx, msg)
|
||||
default:
|
||||
return bus.Message{}, fmt.Errorf("tool_worker %s: unsupported message type %s", w.ID(), msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// handleToolCall processes a tool call by executing the named tool
|
||||
// with the provided arguments.
|
||||
//
|
||||
// The msg.Content is expected to contain a JSON object with:
|
||||
// - "name": the tool name (string)
|
||||
// - "arguments": the tool arguments (object)
|
||||
//
|
||||
// Or alternatively, msg.Content can be a string in the format:
|
||||
// tool_name(arg1=val1, arg2=val2)
|
||||
func (w *ToolWorker) handleToolCall(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
w.setStatus(StatusWaitingForTool)
|
||||
defer w.setStatus(StatusProcessing)
|
||||
|
||||
toolName, args, err := parseToolCallContent(msg.Content)
|
||||
if err != nil {
|
||||
return bus.Message{
|
||||
ID: msg.ID + "-result",
|
||||
Type: bus.MsgTypeToolResult,
|
||||
From: w.ID(),
|
||||
To: msg.From,
|
||||
Content: map[string]interface{}{"error": err.Error()},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
result, err := w.manager.Execute(toolName, ctx, args)
|
||||
if err != nil {
|
||||
return bus.Message{
|
||||
ID: msg.ID + "-result",
|
||||
Type: bus.MsgTypeToolResult,
|
||||
From: w.ID(),
|
||||
To: msg.From,
|
||||
Content: map[string]interface{}{"error": err.Error()},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return bus.Message{
|
||||
ID: msg.ID + "-result",
|
||||
Type: bus.MsgTypeToolResult,
|
||||
From: w.ID(),
|
||||
To: msg.From,
|
||||
Content: result,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseToolCallContent pulls the tool name and argument map out of a
// tool-call payload.
//
// Accepted payloads:
//   - map[string]interface{} with a non-empty string "name" and an optional
//     object "arguments";
//   - a string containing that same object as JSON.
//
// A missing, nil, or non-object "arguments" value becomes an empty (non-nil)
// map. Errors are returned for a map without a usable "name", for a string
// that is not a JSON object with a usable "name", and for any other content
// type.
func parseToolCallContent(content interface{}) (string, map[string]interface{}, error) {
	// extract reads name/arguments from a decoded call object; ok is false
	// when "name" is absent, not a string, or empty.
	extract := func(obj map[string]interface{}) (name string, args map[string]interface{}, ok bool) {
		name, isString := obj["name"].(string)
		if !isString || name == "" {
			return "", nil, false
		}
		args, _ = obj["arguments"].(map[string]interface{})
		if args == nil {
			args = map[string]interface{}{}
		}
		return name, args, true
	}

	if obj, isMap := content.(map[string]interface{}); isMap {
		name, args, ok := extract(obj)
		if !ok {
			return "", nil, fmt.Errorf("tool call content missing 'name' field")
		}
		return name, args, nil
	}

	raw, isString := content.(string)
	if !isString {
		return "", nil, fmt.Errorf("unsupported tool call content type: %T", content)
	}

	var decoded map[string]interface{}
	if json.Unmarshal([]byte(raw), &decoded) == nil {
		if name, args, ok := extract(decoded); ok {
			return name, args, nil
		}
	}
	return "", nil, fmt.Errorf("cannot parse tool call from string content: %s", raw)
}
|
||||
|
||||
// handleTask processes a task request by returning a task response.
|
||||
func (w *ToolWorker) handleTask(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
return bus.Message{
|
||||
ID: msg.ID + "-response",
|
||||
Type: bus.MsgTypeTaskResponse,
|
||||
From: w.ID(),
|
||||
To: msg.From,
|
||||
Content: msg.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleSystem processes internal system messages.
|
||||
func (w *ToolWorker) handleSystem(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
return bus.Message{
|
||||
ID: msg.ID + "-ack",
|
||||
Type: bus.MsgTypeSystem,
|
||||
From: w.ID(),
|
||||
To: msg.From,
|
||||
Content: "tool_worker acknowledged",
|
||||
}, nil
|
||||
}
|
||||
88
pkg/actor/worker.go
Normal file
88
pkg/actor/worker.go
Normal file
@ -0,0 +1,88 @@
|
||||
package actor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/orca/orca/pkg/bus"
|
||||
)
|
||||
|
||||
// Worker is an agent that processes tasks and makes tool calls.
|
||||
//
|
||||
// Workers are the execution units in the actor system. They receive
|
||||
// task requests from the orchestrator, process them (potentially making
|
||||
// tool calls), and return results.
|
||||
type Worker struct {
|
||||
*BaseAgent
|
||||
}
|
||||
|
||||
// NewWorker creates a new Worker agent with the given id.
|
||||
// The agent is started automatically upon creation.
|
||||
func NewWorker(id string) *Worker {
|
||||
w := &Worker{
|
||||
BaseAgent: NewBaseAgent(id, "worker"),
|
||||
}
|
||||
w.SetHandler(w.handleMessage)
|
||||
if err := w.Start(); err != nil {
|
||||
panic(fmt.Sprintf("worker: failed to start: %v", err))
|
||||
}
|
||||
return w
|
||||
}
|
||||
|
||||
// handleMessage routes incoming messages to the appropriate handler.
|
||||
func (w *Worker) handleMessage(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
switch msg.Type {
|
||||
case bus.MsgTypeTaskRequest:
|
||||
return w.handleTask(ctx, msg)
|
||||
case bus.MsgTypeToolCall:
|
||||
return w.handleToolCall(ctx, msg)
|
||||
case bus.MsgTypeSystem:
|
||||
return w.handleSystem(ctx, msg)
|
||||
default:
|
||||
return bus.Message{}, fmt.Errorf("worker %s: unsupported message type %s", w.ID(), msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// handleTask processes a task request and returns a task response.
|
||||
func (w *Worker) handleTask(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
// Process the task - in a real implementation this would involve
|
||||
// the LLM, tool calls, etc.
|
||||
return bus.Message{
|
||||
ID: msg.ID + "-response",
|
||||
Type: bus.MsgTypeTaskResponse,
|
||||
From: w.ID(),
|
||||
To: msg.From,
|
||||
Content: msg.Content,
|
||||
Metadata: map[string]string{
|
||||
"processed_by": w.ID(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleToolCall processes a tool call request, transitions to WaitingForTool
|
||||
// state, and returns the result.
|
||||
func (w *Worker) handleToolCall(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
w.setStatus(StatusWaitingForTool)
|
||||
defer w.setStatus(StatusProcessing)
|
||||
|
||||
// In a real implementation, this would invoke the actual tool.
|
||||
// For now, acknowledge the tool call.
|
||||
return bus.Message{
|
||||
ID: msg.ID + "-result",
|
||||
Type: bus.MsgTypeToolResult,
|
||||
From: w.ID(),
|
||||
To: msg.From,
|
||||
Content: msg.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleSystem processes internal system messages.
|
||||
func (w *Worker) handleSystem(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
return bus.Message{
|
||||
ID: msg.ID + "-ack",
|
||||
Type: bus.MsgTypeSystem,
|
||||
From: w.ID(),
|
||||
To: msg.From,
|
||||
Content: "worker acknowledged",
|
||||
}, nil
|
||||
}
|
||||
164
pkg/bus/bus.go
Normal file
164
pkg/bus/bus.go
Normal file
@ -0,0 +1,164 @@
|
||||
package bus
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Handler is a callback function that processes a delivered message.
|
||||
type Handler func(Message)
|
||||
|
||||
// Subscription is the handle returned by MessageBus.Subscribe for one
// registered handler.
type Subscription interface {
	// ID returns the subscription's identifier.
	ID() string
	// Topic returns the topic this subscription was registered for.
	Topic() string
	// Unsubscribe detaches the subscription from the bus.
	Unsubscribe()
}
|
||||
|
||||
// MessageBus is the central communication hub of the Orca framework.
|
||||
//
|
||||
// It uses a publish/subscribe pattern built on Go channels. Components
|
||||
// publish messages to named topics, and all subscribers to that topic
|
||||
// receive the message asynchronously.
|
||||
type MessageBus interface {
|
||||
// Publish sends a message to all active subscribers of the given topic.
|
||||
Publish(topic string, msg Message) error
|
||||
// Subscribe registers a handler for the given topic.
|
||||
Subscribe(topic string, handler Handler) (Subscription, error)
|
||||
// Close gracefully shuts down the bus, cleaning up all subscriptions.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// subscription implements the Subscription interface.
|
||||
type subscription struct {
|
||||
id string
|
||||
topic string
|
||||
ch chan Message
|
||||
bus *messageBus
|
||||
active *atomic.Bool
|
||||
}
|
||||
|
||||
func (s *subscription) ID() string { return s.id }
|
||||
func (s *subscription) Topic() string { return s.topic }
|
||||
func (s *subscription) Unsubscribe() { s.bus.unsubscribe(s) }
|
||||
|
||||
// messageBus is the channel-based implementation of MessageBus.
|
||||
type messageBus struct {
|
||||
mu sync.RWMutex
|
||||
topics map[string][]*subscription
|
||||
nextID int64
|
||||
closed bool
|
||||
}
|
||||
|
||||
// New creates a new message bus instance.
|
||||
func New() MessageBus {
|
||||
return &messageBus{
|
||||
topics: make(map[string][]*subscription),
|
||||
}
|
||||
}
|
||||
|
||||
// Publish sends a message to all subscribers of the given topic.
|
||||
// The send is non-blocking: if a subscriber's channel buffer is full,
|
||||
// the message is dropped for that subscriber.
|
||||
func (mb *messageBus) Publish(topic string, msg Message) error {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
|
||||
if mb.closed {
|
||||
return errors.New("message bus is closed")
|
||||
}
|
||||
|
||||
subs, ok := mb.topics[topic]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, sub := range subs {
|
||||
if sub.active.Load() {
|
||||
select {
|
||||
case sub.ch <- msg:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Subscribe adds a handler for the given topic.
|
||||
func (mb *messageBus) Subscribe(topic string, handler Handler) (Subscription, error) {
|
||||
mb.mu.Lock()
|
||||
defer mb.mu.Unlock()
|
||||
|
||||
if mb.closed {
|
||||
return nil, errors.New("message bus is closed")
|
||||
}
|
||||
|
||||
id := fmt.Sprintf("sub-%d", atomic.AddInt64(&mb.nextID, 1))
|
||||
sub := &subscription{
|
||||
id: id,
|
||||
topic: topic,
|
||||
ch: make(chan Message, 64),
|
||||
bus: mb,
|
||||
active: &atomic.Bool{},
|
||||
}
|
||||
sub.active.Store(true)
|
||||
|
||||
mb.topics[topic] = append(mb.topics[topic], sub)
|
||||
|
||||
go sub.deliver(handler)
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
// deliver reads messages from the subscription channel and calls the handler.
|
||||
func (s *subscription) deliver(handler Handler) {
|
||||
for msg := range s.ch {
|
||||
if !s.active.Load() {
|
||||
return
|
||||
}
|
||||
handler(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// unsubscribe removes a subscription from the bus and closes its channel.
|
||||
func (mb *messageBus) unsubscribe(sub *subscription) {
|
||||
mb.mu.Lock()
|
||||
defer mb.mu.Unlock()
|
||||
|
||||
sub.active.Store(false)
|
||||
|
||||
subs, ok := mb.topics[sub.Topic()]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for i, s := range subs {
|
||||
if s.ID() == sub.ID() {
|
||||
mb.topics[sub.Topic()] = append(subs[:i], subs[i+1:]...)
|
||||
close(s.ch)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the bus, unsubscribing all active subscriptions.
|
||||
func (mb *messageBus) Close() error {
|
||||
mb.mu.Lock()
|
||||
defer mb.mu.Unlock()
|
||||
|
||||
if mb.closed {
|
||||
return nil
|
||||
}
|
||||
mb.closed = true
|
||||
|
||||
for topic, subs := range mb.topics {
|
||||
for _, sub := range subs {
|
||||
sub.active.Store(false)
|
||||
close(sub.ch)
|
||||
}
|
||||
delete(mb.topics, topic)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
252
pkg/bus/bus_test.go
Normal file
252
pkg/bus/bus_test.go
Normal file
@ -0,0 +1,252 @@
|
||||
package bus
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewBus(t *testing.T) {
|
||||
b := New()
|
||||
if b == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishSubscribe(t *testing.T) {
|
||||
b := New()
|
||||
defer b.Close()
|
||||
|
||||
var received int32
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
sub, err := b.Subscribe("test", func(msg Message) {
|
||||
atomic.AddInt32(&received, 1)
|
||||
wg.Done()
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Subscribe failed: %v", err)
|
||||
}
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
err = b.Publish("test", Message{
|
||||
ID: "msg-1",
|
||||
Type: MsgTypeSystem,
|
||||
From: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Publish failed: %v", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if atomic.LoadInt32(&received) != 1 {
|
||||
t.Errorf("expected 1 message, got %d", received)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishNoSubscribers(t *testing.T) {
|
||||
b := New()
|
||||
defer b.Close()
|
||||
|
||||
err := b.Publish("nonexistent", Message{ID: "msg-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("Publish to nonexistent topic should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleSubscribers(t *testing.T) {
|
||||
b := New()
|
||||
defer b.Close()
|
||||
|
||||
var received int32
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(3)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
sub, err := b.Subscribe("multi", func(msg Message) {
|
||||
atomic.AddInt32(&received, 1)
|
||||
wg.Done()
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Subscribe %d failed: %v", i, err)
|
||||
}
|
||||
defer sub.Unsubscribe()
|
||||
}
|
||||
|
||||
err := b.Publish("multi", Message{ID: "msg-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("Publish failed: %v", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if n := atomic.LoadInt32(&received); n != 3 {
|
||||
t.Errorf("expected 3 messages, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
b := New()
|
||||
defer b.Close()
|
||||
|
||||
var received int32
|
||||
sub, err := b.Subscribe("test", func(msg Message) {
|
||||
atomic.AddInt32(&received, 1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Subscribe failed: %v", err)
|
||||
}
|
||||
|
||||
// Publish before unsubscribe
|
||||
b.Publish("test", Message{ID: "msg-1"})
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
sub.Unsubscribe()
|
||||
|
||||
// Publish after unsubscribe
|
||||
b.Publish("test", Message{ID: "msg-2"})
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if n := atomic.LoadInt32(&received); n != 1 {
|
||||
t.Errorf("expected 1 message after unsubscribe, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribeAfterClose(t *testing.T) {
|
||||
b := New()
|
||||
b.Close()
|
||||
|
||||
_, err := b.Subscribe("test", func(msg Message) {})
|
||||
if err == nil {
|
||||
t.Error("expected error subscribing to closed bus")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishAfterClose(t *testing.T) {
|
||||
b := New()
|
||||
b.Close()
|
||||
|
||||
err := b.Publish("test", Message{ID: "msg-1"})
|
||||
if err == nil {
|
||||
t.Error("expected error publishing to closed bus")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionID(t *testing.T) {
|
||||
b := New()
|
||||
defer b.Close()
|
||||
|
||||
sub1, _ := b.Subscribe("a", func(msg Message) {})
|
||||
defer sub1.Unsubscribe()
|
||||
sub2, _ := b.Subscribe("b", func(msg Message) {})
|
||||
defer sub2.Unsubscribe()
|
||||
|
||||
if sub1.ID() == sub2.ID() {
|
||||
t.Error("subscription IDs should be unique")
|
||||
}
|
||||
|
||||
if sub1.Topic() != "a" || sub2.Topic() != "b" {
|
||||
t.Error("topic mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentPublish(t *testing.T) {
|
||||
b := New()
|
||||
defer b.Close()
|
||||
|
||||
var received int32
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(100)
|
||||
|
||||
sub, err := b.Subscribe("concurrent", func(msg Message) {
|
||||
atomic.AddInt32(&received, 1)
|
||||
wg.Done()
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Subscribe failed: %v", err)
|
||||
}
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
go func(i int) {
|
||||
b.Publish("concurrent", Message{
|
||||
ID: time.Now().String(),
|
||||
Type: MsgTypeSystem,
|
||||
})
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if n := atomic.LoadInt32(&received); n != 100 {
|
||||
t.Errorf("expected 100 messages, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifferentTopics(t *testing.T) {
|
||||
b := New()
|
||||
defer b.Close()
|
||||
|
||||
var topics []string
|
||||
var mu sync.Mutex
|
||||
|
||||
sub1, _ := b.Subscribe("topic-a", func(msg Message) {
|
||||
mu.Lock()
|
||||
topics = append(topics, "a")
|
||||
mu.Unlock()
|
||||
})
|
||||
defer sub1.Unsubscribe()
|
||||
|
||||
sub2, _ := b.Subscribe("topic-b", func(msg Message) {
|
||||
mu.Lock()
|
||||
topics = append(topics, "b")
|
||||
mu.Unlock()
|
||||
})
|
||||
defer sub2.Unsubscribe()
|
||||
|
||||
b.Publish("topic-a", Message{ID: "msg-1"})
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if len(topics) != 1 || topics[0] != "a" {
|
||||
t.Errorf("expected only topic-a to receive message, got %v", topics)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseIdempotent(t *testing.T) {
|
||||
b := New()
|
||||
err1 := b.Close()
|
||||
err2 := b.Close()
|
||||
|
||||
if err1 != nil {
|
||||
t.Fatalf("first Close failed: %v", err1)
|
||||
}
|
||||
if err2 != nil {
|
||||
t.Fatalf("second Close should be idempotent: %v", err2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTypeString(t *testing.T) {
|
||||
tests := []struct {
|
||||
mt MessageType
|
||||
want string
|
||||
}{
|
||||
{MsgTypeSystem, "system"},
|
||||
{MsgTypeTaskRequest, "task_request"},
|
||||
{MsgTypeTaskResponse, "task_response"},
|
||||
{MsgTypeToolCall, "tool_call"},
|
||||
{MsgTypeToolResult, "tool_result"},
|
||||
{MsgTypeObservation, "observation"},
|
||||
{MsgTypeError, "error"},
|
||||
{MsgTypeLog, "log"},
|
||||
{MessageType(99), "unknown"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := tt.mt.String(); got != tt.want {
|
||||
t.Errorf("MessageType(%d).String() = %q, want %q", tt.mt, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
67
pkg/bus/types.go
Normal file
67
pkg/bus/types.go
Normal file
@ -0,0 +1,67 @@
|
||||
// Package bus provides the message bus system for inter-component communication.
|
||||
//
|
||||
// The message bus is the central nervous system of the Orca framework.
|
||||
// All components (kernel, plugins, agents) communicate through it
|
||||
// via a publish/subscribe pattern over Go channels.
|
||||
package bus
|
||||
|
||||
import "time"
|
||||
|
||||
// MessageType classifies a Message travelling over the bus.
type MessageType int

// Message categories. Values are assigned sequentially from zero, so the zero
// value of MessageType is MsgTypeSystem.
const (
	// MsgTypeSystem marks internal kernel/system messages.
	MsgTypeSystem MessageType = iota
	// MsgTypeTaskRequest asks an agent to perform a task.
	MsgTypeTaskRequest
	// MsgTypeTaskResponse carries the outcome of a task.
	MsgTypeTaskResponse
	// MsgTypeToolCall asks for a tool to be invoked.
	MsgTypeToolCall
	// MsgTypeToolResult carries the outcome of a tool invocation.
	MsgTypeToolResult
	// MsgTypeObservation carries an observation from tool/command execution.
	MsgTypeObservation
	// MsgTypeError reports an error.
	MsgTypeError
	// MsgTypeLog carries a log entry for observability.
	MsgTypeLog
)
|
||||
|
||||
// String returns the human-readable name of the message type.
|
||||
func (mt MessageType) String() string {
|
||||
switch mt {
|
||||
case MsgTypeSystem:
|
||||
return "system"
|
||||
case MsgTypeTaskRequest:
|
||||
return "task_request"
|
||||
case MsgTypeTaskResponse:
|
||||
return "task_response"
|
||||
case MsgTypeToolCall:
|
||||
return "tool_call"
|
||||
case MsgTypeToolResult:
|
||||
return "tool_result"
|
||||
case MsgTypeObservation:
|
||||
return "observation"
|
||||
case MsgTypeError:
|
||||
return "error"
|
||||
case MsgTypeLog:
|
||||
return "log"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Message is the universal data unit in the message bus system.
|
||||
//
|
||||
// Every component communicates by publishing and subscribing to Messages.
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
Type MessageType `json:"type"`
|
||||
From string `json:"from"`
|
||||
To string `json:"to"`
|
||||
Content interface{} `json:"content"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
392
pkg/kernel/kernel.go
Normal file
392
pkg/kernel/kernel.go
Normal file
@ -0,0 +1,392 @@
|
||||
// Package kernel implements the microkernel core of the Orca framework.
|
||||
//
|
||||
// The kernel is the minimal runtime that manages plugin lifecycle,
|
||||
// message routing, and inter-component communication.
|
||||
package kernel
|
||||
|
||||
import (
	"context"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"sync"
	"time"

	"github.com/orca/orca/internal/config"
	"github.com/orca/orca/pkg/actor"
	"github.com/orca/orca/pkg/bus"
	"github.com/orca/orca/pkg/llm"
	"github.com/orca/orca/pkg/plugin"
	"github.com/orca/orca/pkg/session"
	"github.com/orca/orca/pkg/skill"
	"github.com/orca/orca/pkg/tool"
)
|
||||
|
||||
// Kernel is the microkernel core of the Orca framework.
|
||||
//
|
||||
// It orchestrates plugin lifecycle, message routing, and inter-component
|
||||
// communication. The kernel initializes and manages:
|
||||
// - Message bus for inter-component communication
|
||||
// - Plugin registry for extensibility
|
||||
// - Session manager for conversation persistence
|
||||
// - Tool manager with built-in tools
|
||||
// - Skill manager for skill-based automation
|
||||
// - Actor system with orchestrator, workers, and LLM agent
|
||||
type Kernel struct {
|
||||
mu sync.RWMutex
|
||||
mb bus.MessageBus
|
||||
registry *plugin.Registry
|
||||
plugins []plugin.Plugin
|
||||
started bool
|
||||
|
||||
// Integration components
|
||||
config *config.Config
|
||||
sessionMgr *session.Manager
|
||||
toolMgr *tool.Manager
|
||||
skillMgr *skill.Manager
|
||||
actorSystem *actor.System
|
||||
orch *actor.Orchestrator
|
||||
llmAgent *actor.LLMAgent
|
||||
toolWorker *actor.ToolWorker
|
||||
}
|
||||
|
||||
// New creates a new Kernel instance with default configuration.
|
||||
func New() *Kernel {
|
||||
return NewWithConfig(config.DefaultConfig())
|
||||
}
|
||||
|
||||
// NewWithConfig creates a new Kernel instance with the given configuration.
|
||||
func NewWithConfig(cfg *config.Config) *Kernel {
|
||||
if cfg == nil {
|
||||
cfg = config.DefaultConfig()
|
||||
}
|
||||
|
||||
k := &Kernel{
|
||||
mb: bus.New(),
|
||||
registry: plugin.NewRegistry(),
|
||||
config: cfg,
|
||||
actorSystem: actor.NewSystem(),
|
||||
}
|
||||
|
||||
// Initialize session manager
|
||||
store, err := session.NewJSONLStore(cfg.Session.StorageDir)
|
||||
if err != nil {
|
||||
log.Printf("kernel: warning: failed to create session store: %v", err)
|
||||
} else {
|
||||
k.sessionMgr = session.NewManager(store, k.mb)
|
||||
}
|
||||
|
||||
// Initialize tool manager with all built-in tools
|
||||
k.toolMgr = tool.NewManager()
|
||||
k.registerBuiltinTools()
|
||||
|
||||
// Initialize skill manager
|
||||
k.skillMgr = skill.NewManager(cfg.Session.StorageDir + "/skills")
|
||||
|
||||
// Initialize actor system
|
||||
k.initializeActorSystem()
|
||||
|
||||
return k
|
||||
}
|
||||
|
||||
// registerBuiltinTools registers all built-in tools with the tool manager.
|
||||
func (k *Kernel) registerBuiltinTools() {
|
||||
tools := []tool.Tool{
|
||||
tool.NewExecTool(nil), // exec - shell commands
|
||||
tool.NewReadFileTool(), // read_file
|
||||
tool.NewWriteFileTool(), // write_file
|
||||
tool.NewListDirTool(), // list_dir
|
||||
tool.NewSearchFilesTool(), // search_files
|
||||
}
|
||||
|
||||
for _, t := range tools {
|
||||
if err := k.toolMgr.Register(t); err != nil {
|
||||
log.Printf("kernel: warning: failed to register tool %q: %v", t.Name(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// initializeActorSystem sets up the orchestrator, tool worker, and LLM agent.
|
||||
func (k *Kernel) initializeActorSystem() {
|
||||
// Create orchestrator
|
||||
orch, err := k.actorSystem.CreateOrchestrator(k)
|
||||
if err != nil {
|
||||
log.Printf("kernel: warning: failed to create orchestrator: %v", err)
|
||||
return
|
||||
}
|
||||
k.orch = orch
|
||||
|
||||
// Create tool worker
|
||||
tw, err := k.actorSystem.CreateToolWorker(k.toolMgr)
|
||||
if err != nil {
|
||||
log.Printf("kernel: warning: failed to create tool worker: %v", err)
|
||||
return
|
||||
}
|
||||
k.toolWorker = tw
|
||||
|
||||
// Create LLM backend
|
||||
ollama := k.createLLMBackend()
|
||||
|
||||
// Create LLM agent
|
||||
llmAgentID := fmt.Sprintf("llm-%d", len(k.actorSystem.ListAgents())+1)
|
||||
llmOpts := []actor.LLMAgentOption{
|
||||
actor.WithToolManager(k.toolMgr),
|
||||
actor.WithToolWorker(k.toolWorker),
|
||||
actor.WithWindowSize(k.config.Session.MaxHistory),
|
||||
}
|
||||
|
||||
if k.sessionMgr != nil {
|
||||
sessionID := "default"
|
||||
if _, err := k.sessionMgr.GetSession(sessionID); err != nil {
|
||||
k.sessionMgr.CreateSession(sessionID, map[string]string{
|
||||
"source": "kernel",
|
||||
})
|
||||
}
|
||||
|
||||
llmOpts = append(llmOpts,
|
||||
actor.WithSessionManager(k.sessionMgr),
|
||||
actor.WithSessionID(sessionID),
|
||||
)
|
||||
}
|
||||
|
||||
llmAgent := actor.NewLLMAgent(llmAgentID, ollama, llmOpts...)
|
||||
k.llmAgent = llmAgent
|
||||
|
||||
// Register LLM agent as orchestrator's worker
|
||||
k.orch.AddWorker(llmAgent)
|
||||
|
||||
// Also register tool worker as a fallback worker
|
||||
k.orch.AddWorker(tw)
|
||||
}
|
||||
|
||||
// createLLMBackend creates the LLM backend based on configuration.
|
||||
func (k *Kernel) createLLMBackend() llm.LLM {
|
||||
baseURL := k.config.Ollama.BaseURL
|
||||
model := k.config.Ollama.Model
|
||||
timeout := k.config.Ollama.Timeout
|
||||
|
||||
// Allow shorter env var names to override
|
||||
if v := os.Getenv("OLLAMA_BASE_URL"); v != "" {
|
||||
baseURL = v
|
||||
}
|
||||
if v := os.Getenv("OLLAMA_MODEL"); v != "" {
|
||||
model = v
|
||||
}
|
||||
if v := os.Getenv("OLLAMA_TIMEOUT"); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil {
|
||||
timeout = d
|
||||
}
|
||||
}
|
||||
|
||||
client := llm.NewOllamaClient(
|
||||
llm.WithBaseURL(baseURL),
|
||||
llm.WithModel(model),
|
||||
llm.WithTimeout(timeout),
|
||||
)
|
||||
|
||||
log.Printf("kernel: created Ollama client (model=%s, url=%s)", model, baseURL)
|
||||
return client
|
||||
}
|
||||
|
||||
// Bus returns the kernel's message bus.
|
||||
func (k *Kernel) Bus() bus.MessageBus {
|
||||
return k.mb
|
||||
}
|
||||
|
||||
// Registry returns the plugin registry.
|
||||
func (k *Kernel) Registry() *plugin.Registry {
|
||||
return k.registry
|
||||
}
|
||||
|
||||
// SessionManager returns the session manager.
|
||||
func (k *Kernel) SessionManager() *session.Manager {
|
||||
return k.sessionMgr
|
||||
}
|
||||
|
||||
// ToolManager returns the tool manager.
|
||||
func (k *Kernel) ToolManager() *tool.Manager {
|
||||
return k.toolMgr
|
||||
}
|
||||
|
||||
// SkillManager returns the skill manager.
|
||||
func (k *Kernel) SkillManager() *skill.Manager {
|
||||
return k.skillMgr
|
||||
}
|
||||
|
||||
// ActorSystem returns the actor system.
|
||||
func (k *Kernel) ActorSystem() *actor.System {
|
||||
return k.actorSystem
|
||||
}
|
||||
|
||||
// Orchestrator returns the orchestrator agent.
|
||||
func (k *Kernel) Orchestrator() *actor.Orchestrator {
|
||||
return k.orch
|
||||
}
|
||||
|
||||
// LLMAgent returns the LLM agent.
|
||||
func (k *Kernel) LLMAgent() *actor.LLMAgent {
|
||||
return k.llmAgent
|
||||
}
|
||||
|
||||
// SendMessage sends a message from a source to the LLM agent.
|
||||
//
|
||||
// This is the primary public API for interacting with the Orca system.
|
||||
// It creates a task request message and sends it through the orchestrator
|
||||
// to the LLM agent for processing.
|
||||
//
|
||||
// Parameters:
|
||||
// - from: the sender identifier (e.g., "user", "cli")
|
||||
// - to: the recipient (use "llm" for the LLM agent)
|
||||
// - content: the message content (plain text)
|
||||
//
|
||||
// Returns the response content as a string, or an error.
|
||||
func (k *Kernel) SendMessage(from, to, content string) (string, error) {
|
||||
if !k.IsRunning() {
|
||||
return "", fmt.Errorf("kernel: kernel is not running")
|
||||
}
|
||||
|
||||
if k.orch == nil {
|
||||
return "", fmt.Errorf("kernel: orchestrator not initialized")
|
||||
}
|
||||
|
||||
// Create a task request message
|
||||
msg := bus.Message{
|
||||
Type: bus.MsgTypeTaskRequest,
|
||||
From: from,
|
||||
To: to,
|
||||
Content: content,
|
||||
}
|
||||
|
||||
// Send through the orchestrator
|
||||
ctx := context.Background()
|
||||
resp, err := k.orch.Process(ctx, msg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("kernel: orchestrator processing failed: %w", err)
|
||||
}
|
||||
|
||||
// Extract response content
|
||||
switch v := resp.Content.(type) {
|
||||
case string:
|
||||
return v, nil
|
||||
default:
|
||||
return fmt.Sprintf("%v", v), nil
|
||||
}
|
||||
}
|
||||
|
||||
// InitPlugins loads and initializes skills from the skills directory.
|
||||
func (k *Kernel) InitPlugins() error {
|
||||
if k.skillMgr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
count, err := k.skillMgr.LoadAll()
|
||||
if err != nil {
|
||||
log.Printf("kernel: warning: skill loading had errors: %v", err)
|
||||
}
|
||||
if count > 0 {
|
||||
log.Printf("kernel: loaded %d skills", count)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPlugin returns a registered plugin by name.
|
||||
func (k *Kernel) GetPlugin(name string) (plugin.Plugin, bool) {
|
||||
return k.registry.Get(name)
|
||||
}
|
||||
|
||||
// ListPlugins returns all currently registered plugins.
|
||||
func (k *Kernel) ListPlugins() []plugin.Plugin {
|
||||
return k.registry.List()
|
||||
}
|
||||
|
||||
// RegisterPlugin registers a plugin without starting it.
|
||||
func (k *Kernel) RegisterPlugin(p plugin.Plugin) error {
|
||||
k.mu.Lock()
|
||||
defer k.mu.Unlock()
|
||||
|
||||
if k.started {
|
||||
return fmt.Errorf("kernel: cannot register plugin %q: kernel already started", p.Name())
|
||||
}
|
||||
|
||||
return k.registry.Register(p)
|
||||
}
|
||||
|
||||
// UnregisterPlugin removes a plugin from the registry.
|
||||
func (k *Kernel) UnregisterPlugin(name string) error {
|
||||
k.mu.Lock()
|
||||
defer k.mu.Unlock()
|
||||
|
||||
return k.registry.Unregister(name)
|
||||
}
|
||||
|
||||
// Start initializes all registered plugins and marks the kernel as running.
|
||||
func (k *Kernel) Start() error {
|
||||
k.mu.Lock()
|
||||
defer k.mu.Unlock()
|
||||
|
||||
if k.started {
|
||||
return fmt.Errorf("kernel: already started")
|
||||
}
|
||||
|
||||
k.started = true
|
||||
|
||||
// Initialize plugins
|
||||
plugins := k.registry.List()
|
||||
k.plugins = make([]plugin.Plugin, 0, len(plugins))
|
||||
|
||||
for _, p := range plugins {
|
||||
k.registry.SetState(p.Name(), plugin.StateInitialized)
|
||||
if err := p.Init(k); err != nil {
|
||||
log.Printf("kernel: warning: failed to init plugin %q: %v", p.Name(), err)
|
||||
k.registry.SetState(p.Name(), plugin.StateError)
|
||||
continue
|
||||
}
|
||||
k.registry.SetState(p.Name(), plugin.StateRunning)
|
||||
k.plugins = append(k.plugins, p)
|
||||
log.Printf("kernel: plugin %q (%s) initialized", p.Name(), p.Version())
|
||||
}
|
||||
|
||||
log.Printf("kernel: started (tools=%d)", k.toolMgr.Count())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the kernel.
|
||||
func (k *Kernel) Stop() error {
|
||||
k.mu.Lock()
|
||||
defer k.mu.Unlock()
|
||||
|
||||
if !k.started {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop actor system first
|
||||
if k.actorSystem != nil {
|
||||
if err := k.actorSystem.StopAll(); err != nil {
|
||||
log.Printf("kernel: warning: error stopping actor system: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop plugins
|
||||
for i := len(k.plugins) - 1; i >= 0; i-- {
|
||||
p := k.plugins[i]
|
||||
k.registry.SetState(p.Name(), plugin.StateStopped)
|
||||
if err := p.Shutdown(); err != nil {
|
||||
log.Printf("kernel: warning: error shutting down plugin %q: %v", p.Name(), err)
|
||||
continue
|
||||
}
|
||||
log.Printf("kernel: plugin %q shut down", p.Name())
|
||||
}
|
||||
|
||||
k.plugins = nil
|
||||
k.started = false
|
||||
|
||||
return k.mb.Close()
|
||||
}
|
||||
|
||||
// IsRunning returns whether the kernel has been started and not yet stopped.
|
||||
func (k *Kernel) IsRunning() bool {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
return k.started
|
||||
}
|
||||
343
pkg/kernel/kernel_test.go
Normal file
343
pkg/kernel/kernel_test.go
Normal file
@ -0,0 +1,343 @@
|
||||
package kernel
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/orca/orca/pkg/bus"
|
||||
"github.com/orca/orca/pkg/plugin"
|
||||
)
|
||||
|
||||
// testPlugin implements Plugin for kernel testing.
|
||||
type testPlugin struct {
|
||||
name string
|
||||
version string
|
||||
initFn func(host plugin.PluginHost) error
|
||||
closeFn func() error
|
||||
}
|
||||
|
||||
func (p *testPlugin) Name() string { return p.name }
|
||||
func (p *testPlugin) Version() string { return p.version }
|
||||
func (p *testPlugin) Init(host plugin.PluginHost) error {
|
||||
if p.initFn != nil {
|
||||
return p.initFn(host)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (p *testPlugin) Shutdown() error {
|
||||
if p.closeFn != nil {
|
||||
return p.closeFn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewKernel(t *testing.T) {
|
||||
k := New()
|
||||
if k == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
if k.Bus() == nil {
|
||||
t.Error("Bus() returned nil")
|
||||
}
|
||||
if k.Registry() == nil {
|
||||
t.Error("Registry() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKernelStartStop(t *testing.T) {
|
||||
k := New()
|
||||
|
||||
if err := k.Start(); err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
if !k.IsRunning() {
|
||||
t.Error("expected kernel running after Start")
|
||||
}
|
||||
|
||||
if err := k.Stop(); err != nil {
|
||||
t.Fatalf("Stop failed: %v", err)
|
||||
}
|
||||
if k.IsRunning() {
|
||||
t.Error("expected kernel stopped after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKernelDoubleStart(t *testing.T) {
|
||||
k := New()
|
||||
k.Start()
|
||||
err := k.Start()
|
||||
if err == nil {
|
||||
t.Error("expected error on double start")
|
||||
}
|
||||
k.Stop()
|
||||
}
|
||||
|
||||
func TestKernelRegisterPlugin(t *testing.T) {
|
||||
k := New()
|
||||
p := &testPlugin{name: "test", version: "1.0.0"}
|
||||
|
||||
err := k.RegisterPlugin(p)
|
||||
if err != nil {
|
||||
t.Fatalf("RegisterPlugin failed: %v", err)
|
||||
}
|
||||
|
||||
got, ok := k.GetPlugin("test")
|
||||
if !ok {
|
||||
t.Fatal("GetPlugin returned not found")
|
||||
}
|
||||
if got.Name() != "test" {
|
||||
t.Errorf("expected name 'test', got %q", got.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestKernelRegisterPluginAfterStart(t *testing.T) {
|
||||
k := New()
|
||||
k.Start()
|
||||
defer k.Stop()
|
||||
|
||||
err := k.RegisterPlugin(&testPlugin{name: "test", version: "1.0.0"})
|
||||
if err == nil {
|
||||
t.Error("expected error registering plugin after start")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKernelPluginLifecycle(t *testing.T) {
|
||||
k := New()
|
||||
|
||||
var initCount int32
|
||||
var shutdownCount int32
|
||||
|
||||
p := &testPlugin{
|
||||
name: "lifecycle",
|
||||
version: "1.0.0",
|
||||
initFn: func(host plugin.PluginHost) error {
|
||||
atomic.AddInt32(&initCount, 1)
|
||||
return nil
|
||||
},
|
||||
closeFn: func() error {
|
||||
atomic.AddInt32(&shutdownCount, 1)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
k.RegisterPlugin(p)
|
||||
k.Start()
|
||||
|
||||
if n := atomic.LoadInt32(&initCount); n != 1 {
|
||||
t.Errorf("expected init called once, got %d", n)
|
||||
}
|
||||
|
||||
k.Stop()
|
||||
|
||||
if n := atomic.LoadInt32(&shutdownCount); n != 1 {
|
||||
t.Errorf("expected shutdown called once, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKernelPluginInitFailure(t *testing.T) {
|
||||
k := New()
|
||||
|
||||
p := &testPlugin{
|
||||
name: "failing",
|
||||
version: "1.0.0",
|
||||
initFn: func(host plugin.PluginHost) error {
|
||||
return errors.New("init failed")
|
||||
},
|
||||
}
|
||||
|
||||
k.RegisterPlugin(p)
|
||||
|
||||
// Init failure should not prevent Start from succeeding (graceful degradation)
|
||||
err := k.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Start should succeed even with failing plugin: %v", err)
|
||||
}
|
||||
k.Stop()
|
||||
}
|
||||
|
||||
func TestKernelPluginShutdownFailure(t *testing.T) {
|
||||
k := New()
|
||||
|
||||
p := &testPlugin{
|
||||
name: "failing-shutdown",
|
||||
version: "1.0.0",
|
||||
initFn: func(host plugin.PluginHost) error {
|
||||
return nil
|
||||
},
|
||||
closeFn: func() error {
|
||||
return errors.New("shutdown failed")
|
||||
},
|
||||
}
|
||||
|
||||
k.RegisterPlugin(p)
|
||||
k.Start()
|
||||
|
||||
// Shutdown failure should not prevent Stop from succeeding
|
||||
err := k.Stop()
|
||||
if err != nil {
|
||||
t.Fatalf("Stop should succeed even with failing plugin shutdown: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKernelMultiplePlugins(t *testing.T) {
|
||||
k := New()
|
||||
|
||||
names := []string{"alpha", "beta", "gamma"}
|
||||
for _, name := range names {
|
||||
k.RegisterPlugin(&testPlugin{name: name, version: "1.0.0"})
|
||||
}
|
||||
|
||||
k.Start()
|
||||
|
||||
plugins := k.ListPlugins()
|
||||
if len(plugins) != len(names) {
|
||||
t.Errorf("expected %d plugins, got %d", len(names), len(plugins))
|
||||
}
|
||||
|
||||
k.Stop()
|
||||
}
|
||||
|
||||
func TestKernelUnregisterPlugin(t *testing.T) {
|
||||
k := New()
|
||||
|
||||
k.RegisterPlugin(&testPlugin{name: "remove-me", version: "1.0.0"})
|
||||
|
||||
err := k.UnregisterPlugin("remove-me")
|
||||
if err != nil {
|
||||
t.Fatalf("UnregisterPlugin failed: %v", err)
|
||||
}
|
||||
|
||||
_, ok := k.GetPlugin("remove-me")
|
||||
if ok {
|
||||
t.Error("plugin should not exist after unregister")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKernelStopWithoutStart(t *testing.T) {
|
||||
k := New()
|
||||
err := k.Stop()
|
||||
if err != nil {
|
||||
t.Fatalf("Stop without Start should be a no-op: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKernelPluginReceivesHost(t *testing.T) {
|
||||
k := New()
|
||||
|
||||
var gotHost plugin.PluginHost
|
||||
p := &testPlugin{
|
||||
name: "host-check",
|
||||
version: "1.0.0",
|
||||
initFn: func(host plugin.PluginHost) error {
|
||||
gotHost = host
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
k.RegisterPlugin(p)
|
||||
k.Start()
|
||||
|
||||
if gotHost == nil {
|
||||
t.Fatal("plugin did not receive PluginHost")
|
||||
}
|
||||
|
||||
// Verify the host can access bus
|
||||
if gotHost.Bus() == nil {
|
||||
t.Error("PluginHost.Bus() returned nil")
|
||||
}
|
||||
|
||||
// Verify plugin discovery through host
|
||||
p2, ok := gotHost.GetPlugin("host-check")
|
||||
if !ok {
|
||||
t.Error("PluginHost.GetPlugin should find itself")
|
||||
}
|
||||
if p2.Name() != "host-check" {
|
||||
t.Errorf("expected name 'host-check', got %q", p2.Name())
|
||||
}
|
||||
|
||||
k.Stop()
|
||||
}
|
||||
|
||||
func TestKernelAllPluginsInitialized(t *testing.T) {
|
||||
k := New()
|
||||
|
||||
names := []string{"a", "b", "c"}
|
||||
initialized := make(map[string]bool)
|
||||
|
||||
for _, name := range names {
|
||||
n := name
|
||||
k.RegisterPlugin(&testPlugin{
|
||||
name: n,
|
||||
initFn: func(host plugin.PluginHost) error {
|
||||
initialized[n] = true
|
||||
return nil
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
k.Start()
|
||||
|
||||
for _, name := range names {
|
||||
if !initialized[name] {
|
||||
t.Errorf("plugin %q was not initialized", name)
|
||||
}
|
||||
}
|
||||
|
||||
k.Stop()
|
||||
}
|
||||
|
||||
func TestKernelShutdownAllPlugins(t *testing.T) {
|
||||
k := New()
|
||||
|
||||
names := []string{"x", "y", "z"}
|
||||
shutdown := make(map[string]bool)
|
||||
|
||||
for _, name := range names {
|
||||
n := name
|
||||
k.RegisterPlugin(&testPlugin{
|
||||
name: n,
|
||||
initFn: func(host plugin.PluginHost) error {
|
||||
return nil
|
||||
},
|
||||
closeFn: func() error {
|
||||
shutdown[n] = true
|
||||
return nil
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
k.Start()
|
||||
k.Stop()
|
||||
|
||||
for _, name := range names {
|
||||
if !shutdown[name] {
|
||||
t.Errorf("plugin %q was not shut down", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestKernelMessageBusIntegration(t *testing.T) {
|
||||
k := New()
|
||||
k.Start()
|
||||
defer k.Stop()
|
||||
|
||||
mb := k.Bus()
|
||||
|
||||
var received int32
|
||||
sub, err := mb.Subscribe("kernel-test", func(msg bus.Message) {
|
||||
atomic.AddInt32(&received, 1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Subscribe failed: %v", err)
|
||||
}
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
mb.Publish("kernel-test", bus.Message{ID: "test-msg"})
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if n := atomic.LoadInt32(&received); n != 1 {
|
||||
t.Errorf("expected 1 message via kernel bus, got %d", n)
|
||||
}
|
||||
}
|
||||
24
pkg/llm/llm.go
Normal file
24
pkg/llm/llm.go
Normal file
@ -0,0 +1,24 @@
|
||||
// Package llm provides the LLM integration layer for the Orca framework.
|
||||
//
|
||||
// It defines the LLM interface for interacting with language models,
|
||||
// the Ollama client implementation, and the shared types for chat
|
||||
// messages, tool calls, and streaming responses.
|
||||
package llm
|
||||
|
||||
import "context"
|
||||
|
||||
// LLM is the interface for interacting with language models.
|
||||
//
|
||||
// Implementations provide Chat (for complete responses) and Stream
|
||||
// (for streaming token-by-token responses) methods. Both methods
|
||||
// accept a list of messages and return the model's response.
|
||||
type LLM interface {
|
||||
// Chat sends a list of messages to the LLM and returns a complete response.
|
||||
// If the model decides to call tools, the response contains ToolCalls.
|
||||
Chat(ctx context.Context, messages []Message) (*Response, error)
|
||||
|
||||
// Stream sends messages and streams the response token-by-token.
|
||||
// The handler is called for each chunk. The final response is not
|
||||
// collected; use Chat for complete responses.
|
||||
Stream(ctx context.Context, messages []Message, handler StreamHandler) error
|
||||
}
|
||||
301
pkg/llm/ollama.go
Normal file
301
pkg/llm/ollama.go
Normal file
@ -0,0 +1,301 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OllamaClient talks to an Ollama server over its REST API and implements
// the LLM interface. Besides Chat and Stream it offers Embed for embedding
// generation. Construct it with NewOllamaClient.
type OllamaClient struct {
	baseURL    string       // server root, without a trailing slash
	model      string       // model name sent with every request
	httpClient *http.Client // client used for all API calls
}
|
||||
|
||||
// OllamaOption is a functional option for configuring the OllamaClient.
|
||||
type OllamaOption func(*OllamaClient)
|
||||
|
||||
// WithBaseURL sets the Ollama server base URL.
|
||||
func WithBaseURL(url string) OllamaOption {
|
||||
return func(c *OllamaClient) {
|
||||
c.baseURL = strings.TrimRight(url, "/")
|
||||
}
|
||||
}
|
||||
|
||||
// WithModel sets the default model name.
|
||||
func WithModel(model string) OllamaOption {
|
||||
return func(c *OllamaClient) {
|
||||
c.model = model
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout sets the HTTP client timeout.
|
||||
func WithTimeout(timeout time.Duration) OllamaOption {
|
||||
return func(c *OllamaClient) {
|
||||
c.httpClient.Timeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClient sets a custom HTTP client.
|
||||
func WithHTTPClient(client *http.Client) OllamaOption {
|
||||
return func(c *OllamaClient) {
|
||||
c.httpClient = client
|
||||
}
|
||||
}
|
||||
|
||||
// NewOllamaClient creates a new OllamaClient with the given options.
|
||||
//
|
||||
// Default values:
|
||||
// - BaseURL: http://localhost:11434
|
||||
// - Model: gemma4:e4b
|
||||
// - Timeout: 30s
|
||||
func NewOllamaClient(opts ...OllamaOption) *OllamaClient {
|
||||
c := &OllamaClient{
|
||||
baseURL: "http://localhost:11434",
|
||||
model: "gemma4:e4b",
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(c)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Chat sends a chat request to Ollama and returns the complete response.
|
||||
// If the Ollama model returns tool calls, they are parsed and included
|
||||
// in the Response.
|
||||
func (c *OllamaClient) Chat(ctx context.Context, messages []Message) (*Response, error) {
|
||||
req := OllamaChatRequest{
|
||||
Model: c.model,
|
||||
Messages: messages,
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
// Build tool definitions if tool package is integrated
|
||||
// (tools are added externally via BuildToolDefs)
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama: failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost,
|
||||
c.baseURL+"/api/chat", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama: failed to create request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama: request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("ollama: API error (status %d): %s",
|
||||
resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var rawResp json.RawMessage
|
||||
if err := json.NewDecoder(resp.Body).Decode(&rawResp); err != nil {
|
||||
return nil, fmt.Errorf("ollama: failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
return parseOllamaResponse(rawResp)
|
||||
}
|
||||
|
||||
// parseOllamaResponse attempts to parse the Ollama API response,
|
||||
// handling both regular text responses and tool call responses.
|
||||
func parseOllamaResponse(raw json.RawMessage) (*Response, error) {
|
||||
// Try as tool call response first (has message.tool_calls)
|
||||
var toolResp OllamaToolCallResponse
|
||||
if err := json.Unmarshal(raw, &toolResp); err == nil && len(toolResp.Message.ToolCalls) > 0 {
|
||||
return &Response{
|
||||
Content: toolResp.Message.Content,
|
||||
ToolCalls: toolResp.Message.ToolCalls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Try as regular response
|
||||
var chatResp OllamaChatResponse
|
||||
if err := json.Unmarshal(raw, &chatResp); err != nil {
|
||||
return nil, fmt.Errorf("ollama: failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return &Response{
|
||||
Content: chatResp.Message.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Stream sends a chat request to Ollama with streaming enabled.
|
||||
// The handler receives each content chunk as it arrives.
|
||||
func (c *OllamaClient) Stream(ctx context.Context, messages []Message, handler StreamHandler) error {
|
||||
req := OllamaChatRequest{
|
||||
Model: c.model,
|
||||
Messages: messages,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ollama: failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost,
|
||||
c.baseURL+"/api/chat", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("ollama: failed to create request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ollama: stream request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("ollama: stream API error (status %d): %s",
|
||||
resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Each line is a JSON object: {"model":"...","created_at":"...","message":{"role":"assistant","content":"..."},"done":false}
|
||||
var streamResp OllamaChatResponse
|
||||
if err := json.Unmarshal([]byte(line), &streamResp); err != nil {
|
||||
continue // Skip malformed lines
|
||||
}
|
||||
|
||||
if streamResp.Message.Content != "" {
|
||||
if err := handler(streamResp.Message.Content); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if streamResp.Done {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// Embed generates an embedding vector for the given input text.
|
||||
func (c *OllamaClient) Embed(ctx context.Context, input string) (*EmbeddingResponse, error) {
|
||||
req := OllamaEmbedRequest{
|
||||
Model: c.model,
|
||||
Input: input,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama: failed to marshal embed request: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost,
|
||||
c.baseURL+"/api/embed", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama: failed to create embed request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama: embed request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("ollama: embed API error (status %d): %s",
|
||||
resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var apiResp OllamaEmbedResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil {
|
||||
return nil, fmt.Errorf("ollama: failed to decode embed response: %w", err)
|
||||
}
|
||||
|
||||
return &EmbeddingResponse{
|
||||
Embedding: apiResp.Embedding,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BuildToolDefsFromMap converts a generic tool definition map into Ollama ToolDefs.
|
||||
// This is used to bridge the tool package's Tool interface with Ollama's API format.
|
||||
func BuildToolDefsFromMap(tools []map[string]interface{}) []ToolDef {
|
||||
var defs []ToolDef
|
||||
for _, t := range tools {
|
||||
name, _ := t["name"].(string)
|
||||
desc, _ := t["description"].(string)
|
||||
|
||||
def := ToolDef{
|
||||
Type: "function",
|
||||
Function: ToolFunction{
|
||||
Name: name,
|
||||
Description: desc,
|
||||
Parameters: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: make(map[string]ToolProperty),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if params, ok := t["parameters"].(map[string]interface{}); ok {
|
||||
if props, ok := params["properties"].(map[string]interface{}); ok {
|
||||
for key, val := range props {
|
||||
if p, ok := val.(map[string]interface{}); ok {
|
||||
prop := ToolProperty{
|
||||
Type: toString(p["type"]),
|
||||
Description: toString(p["description"]),
|
||||
}
|
||||
def.Function.Parameters.Properties[key] = prop
|
||||
if isRequired(p) {
|
||||
def.Function.Parameters.Required = append(def.Function.Parameters.Required, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
defs = append(defs, def)
|
||||
}
|
||||
return defs
|
||||
}
|
||||
|
||||
// toString returns v when it holds a string, and "" for nil or any other
// type.
func toString(v interface{}) string {
	if text, ok := v.(string); ok {
		return text
	}
	return ""
}
|
||||
|
||||
// isRequired reports whether the property map p has a boolean "required"
// entry set to true. A missing or non-boolean entry counts as false.
func isRequired(p map[string]interface{}) bool {
	if flag, ok := p["required"].(bool); ok {
		return flag
	}
	return false
}
|
||||
385
pkg/llm/ollama_test.go
Normal file
385
pkg/llm/ollama_test.go
Normal file
@ -0,0 +1,385 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// Helper: create a mock Ollama server
|
||||
// ============================================================
|
||||
|
||||
// mockOllamaHandler starts an httptest server that simulates the Ollama API
// for the /api/chat and /api/embed paths. The caller must Close the server.
//
// For every request the JSON body is decoded into a map and handed to
// responseFunc, whose returned status and value are written back as JSON.
// Any other path is reported as a test error and answered with 404.
//
// The handler runs on the server's goroutine, where t.Fatalf must not be
// called, so failures are reported with t.Errorf and the request is ended
// with an error status.
func mockOllamaHandler(t *testing.T, responseFunc func(reqBody map[string]interface{}) (int, interface{})) *httptest.Server {
	t.Helper()
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Verify request path
		if r.URL.Path != "/api/chat" && r.URL.Path != "/api/embed" {
			t.Errorf("unexpected path: %s", r.URL.Path)
			w.WriteHeader(http.StatusNotFound)
			return
		}

		// Decode request body
		var reqBody map[string]interface{}
		if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
			t.Errorf("failed to decode request body: %v", err)
			http.Error(w, "bad request body", http.StatusBadRequest)
			return
		}

		status, resp := responseFunc(reqBody)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("failed to encode response: %v", err)
		}
	}))
}
|
||||
|
||||
// ============================================================
|
||||
// NewOllamaClient Tests
|
||||
// ============================================================
|
||||
|
||||
func TestNewOllamaClientDefaults(t *testing.T) {
|
||||
c := NewOllamaClient()
|
||||
if c == nil {
|
||||
t.Fatal("NewOllamaClient() returned nil")
|
||||
}
|
||||
if c.baseURL != "http://localhost:11434" {
|
||||
t.Errorf("expected default base URL 'http://localhost:11434', got %q", c.baseURL)
|
||||
}
|
||||
if c.model != "gemma4:e4b" {
|
||||
t.Errorf("expected default model 'gemma4:e4b', got %q", c.model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewOllamaClientWithOptions(t *testing.T) {
|
||||
c := NewOllamaClient(
|
||||
WithBaseURL("http://custom:11434"),
|
||||
WithModel("codellama"),
|
||||
WithTimeout(60),
|
||||
)
|
||||
if c.baseURL != "http://custom:11434" {
|
||||
t.Errorf("expected base URL 'http://custom:11434', got %q", c.baseURL)
|
||||
}
|
||||
if c.model != "codellama" {
|
||||
t.Errorf("expected model 'codellama', got %q", c.model)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Chat Tests
|
||||
// ============================================================
|
||||
|
||||
func TestChat(t *testing.T) {
|
||||
srv := mockOllamaHandler(t, func(reqBody map[string]interface{}) (int, interface{}) {
|
||||
// Verify the request has the expected shape
|
||||
if model, ok := reqBody["model"]; !ok || model != "gemma4:e4b" {
|
||||
t.Errorf("expected model 'gemma4:e4b', got %v", model)
|
||||
}
|
||||
if stream, ok := reqBody["stream"]; !ok || stream != false {
|
||||
t.Errorf("expected stream false, got %v", stream)
|
||||
}
|
||||
|
||||
return http.StatusOK, OllamaChatResponse{
|
||||
Model: "gemma4:e4b",
|
||||
Message: Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello! How can I help you?",
|
||||
},
|
||||
Done: true,
|
||||
}
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
client := NewOllamaClient(WithBaseURL(srv.URL))
|
||||
resp, err := client.Chat(context.Background(), []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat failed: %v", err)
|
||||
}
|
||||
if resp.Content != "Hello! How can I help you?" {
|
||||
t.Errorf("expected content 'Hello! How can I help you?', got %q", resp.Content)
|
||||
}
|
||||
if len(resp.ToolCalls) != 0 {
|
||||
t.Errorf("expected no tool calls, got %d", len(resp.ToolCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatWithToolCalls(t *testing.T) {
|
||||
srv := mockOllamaHandler(t, func(reqBody map[string]interface{}) (int, interface{}) {
|
||||
return http.StatusOK, OllamaToolCallResponse{
|
||||
Model: "gemma4:e4b",
|
||||
Message: OllamaToolMsg{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
ToolCalls: []ToolCall{
|
||||
{
|
||||
ID: "call-1",
|
||||
Type: "function",
|
||||
Function: FunctionCall{
|
||||
Name: "exec",
|
||||
Arguments: `{"command":"ls -la"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
}
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
client := NewOllamaClient(WithBaseURL(srv.URL))
|
||||
resp, err := client.Chat(context.Background(), []Message{
|
||||
{Role: "user", Content: "List files"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat failed: %v", err)
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(resp.ToolCalls))
|
||||
}
|
||||
if resp.ToolCalls[0].Function.Name != "exec" {
|
||||
t.Errorf("expected tool name 'exec', got %q", resp.ToolCalls[0].Function.Name)
|
||||
}
|
||||
if resp.ToolCalls[0].Function.Arguments != `{"command":"ls -la"}` {
|
||||
t.Errorf("unexpected arguments: %q", resp.ToolCalls[0].Function.Arguments)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatAPIError(t *testing.T) {
|
||||
srv := mockOllamaHandler(t, func(reqBody map[string]interface{}) (int, interface{}) {
|
||||
return http.StatusInternalServerError, map[string]string{"error": "internal error"}
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
client := NewOllamaClient(WithBaseURL(srv.URL))
|
||||
_, err := client.Chat(context.Background(), []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for API error response")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "500") {
|
||||
t.Errorf("expected error to contain status code, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
client := NewOllamaClient(WithBaseURL("http://localhost:11434"))
|
||||
_, err := client.Chat(ctx, []Message{{Role: "user", Content: "Hello"}})
|
||||
if err == nil {
|
||||
t.Error("expected error for cancelled context")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Stream Tests
|
||||
// ============================================================
|
||||
|
||||
func TestStream(t *testing.T) {
|
||||
chunks := []string{"Hello", "!", " How", " can", " I", " help?"}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/chat" {
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
t.Fatal("expected http.Flusher")
|
||||
}
|
||||
|
||||
for _, chunk := range chunks {
|
||||
resp := OllamaChatResponse{
|
||||
Model: "gemma4:e4b",
|
||||
Message: Message{
|
||||
Role: "assistant",
|
||||
Content: chunk,
|
||||
},
|
||||
Done: false,
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(append(data, '\n'))
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// Send done signal
|
||||
doneResp := OllamaChatResponse{
|
||||
Model: "gemma4:e4b",
|
||||
Message: Message{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
},
|
||||
Done: true,
|
||||
}
|
||||
data, _ := json.Marshal(doneResp)
|
||||
w.Write(append(data, '\n'))
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := NewOllamaClient(WithBaseURL(srv.URL))
|
||||
|
||||
var received []string
|
||||
err := client.Stream(context.Background(), []Message{{Role: "user", Content: "Hi"}},
|
||||
func(chunk string) error {
|
||||
received = append(received, chunk)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Stream failed: %v", err)
|
||||
}
|
||||
|
||||
if len(received) != len(chunks) {
|
||||
t.Errorf("expected %d chunks, got %d", len(chunks), len(received))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamHandlerError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
resp := OllamaChatResponse{
|
||||
Model: "gemma4:e4b",
|
||||
Message: Message{
|
||||
Role: "assistant",
|
||||
Content: "chunk",
|
||||
},
|
||||
Done: false,
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(append(data, '\n'))
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := NewOllamaClient(WithBaseURL(srv.URL))
|
||||
err := client.Stream(context.Background(), []Message{{Role: "user", Content: "Hi"}},
|
||||
func(chunk string) error {
|
||||
return &streamError{msg: "handler error"}
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from handler")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "handler error") {
|
||||
t.Errorf("expected 'handler error', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// streamError is a minimal error type used by the stream tests to make the
// handler fail with a recognisable message.
type streamError struct {
	msg string
}

// Error returns the stored message.
func (e *streamError) Error() string {
	return e.msg
}
|
||||
|
||||
// ============================================================
|
||||
// Embed Tests
|
||||
// ============================================================
|
||||
|
||||
func TestEmbed(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/embed" {
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(OllamaEmbedResponse{
|
||||
Embedding: []float64{0.1, 0.2, 0.3, 0.4, 0.5},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := NewOllamaClient(WithBaseURL(srv.URL))
|
||||
resp, err := client.Embed(context.Background(), "test input")
|
||||
if err != nil {
|
||||
t.Fatalf("Embed failed: %v", err)
|
||||
}
|
||||
if len(resp.Embedding) != 5 {
|
||||
t.Errorf("expected 5 embedding values, got %d", len(resp.Embedding))
|
||||
}
|
||||
if resp.Embedding[0] != 0.1 {
|
||||
t.Errorf("expected first value 0.1, got %f", resp.Embedding[0])
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Tool Def Builder Tests
|
||||
// ============================================================
|
||||
|
||||
func TestBuildToolDefsFromMap(t *testing.T) {
|
||||
tools := []map[string]interface{}{
|
||||
{
|
||||
"name": "exec",
|
||||
"description": "Execute a shell command",
|
||||
"parameters": map[string]interface{}{
|
||||
"properties": map[string]interface{}{
|
||||
"command": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Command to run",
|
||||
"required": true,
|
||||
},
|
||||
"timeout": map[string]interface{}{
|
||||
"type": "number",
|
||||
"description": "Timeout in seconds",
|
||||
"required": false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
defs := BuildToolDefsFromMap(tools)
|
||||
if len(defs) != 1 {
|
||||
t.Fatalf("expected 1 tool def, got %d", len(defs))
|
||||
}
|
||||
if defs[0].Function.Name != "exec" {
|
||||
t.Errorf("expected name 'exec', got %q", defs[0].Function.Name)
|
||||
}
|
||||
if len(defs[0].Function.Parameters.Required) != 1 {
|
||||
t.Errorf("expected 1 required parameter, got %d", len(defs[0].Function.Parameters.Required))
|
||||
}
|
||||
if defs[0].Function.Parameters.Required[0] != "command" {
|
||||
t.Errorf("expected required 'command', got %q", defs[0].Function.Parameters.Required[0])
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Mock LLM (for use by other tests)
|
||||
// ============================================================
|
||||
|
||||
// MockLLM is a configurable mock implementation of the LLM interface for testing.
|
||||
type MockLLM struct {
|
||||
ChatFunc func(ctx context.Context, messages []Message) (*Response, error)
|
||||
StreamFunc func(ctx context.Context, messages []Message, handler StreamHandler) error
|
||||
}
|
||||
|
||||
func (m *MockLLM) Chat(ctx context.Context, messages []Message) (*Response, error) {
|
||||
if m.ChatFunc != nil {
|
||||
return m.ChatFunc(ctx, messages)
|
||||
}
|
||||
return &Response{Content: "mock response"}, nil
|
||||
}
|
||||
|
||||
func (m *MockLLM) Stream(ctx context.Context, messages []Message, handler StreamHandler) error {
|
||||
if m.StreamFunc != nil {
|
||||
return m.StreamFunc(ctx, messages, handler)
|
||||
}
|
||||
return handler("mock stream response")
|
||||
}
|
||||
128
pkg/llm/types.go
Normal file
128
pkg/llm/types.go
Normal file
@ -0,0 +1,128 @@
|
||||
// Package llm provides the LLM integration layer for the Orca framework.
|
||||
//
|
||||
// It defines the LLM interface for interacting with language models,
|
||||
// the Ollama client implementation, and the shared types for chat
|
||||
// messages, tool calls, and streaming responses.
|
||||
package llm
|
||||
|
||||
// Message is one entry in a chat transcript.
//
// Role names the sender ("user", "assistant", "system" or "tool"). For tool
// results, ToolCallID ties the message to the tool call it answers; it is
// omitted from the JSON encoding when empty.
type Message struct {
	Role       string `json:"role"`
	Content    string `json:"content"`
	ToolCallID string `json:"tool_call_id,omitempty"`
}
|
||||
|
||||
// ToolCall represents a function calling request from the LLM.
|
||||
//
|
||||
// Ollama returns tool calls in the format expected by OpenAI-compatible
|
||||
// APIs: each call has a unique ID, a type ("function"), and a Function
|
||||
// object containing the tool name and JSON-encoded arguments.
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function FunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
// FunctionCall names the tool to invoke and carries its arguments.
//
// Arguments holds JSON text; callers unmarshal it into whatever shape the
// target tool expects.
//
// NOTE(review): Ollama's native /api/chat documents "arguments" as a JSON
// object rather than a string — confirm that decoding into a string works
// against a live server.
type FunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"` // JSON-encoded arguments
}
|
||||
|
||||
// Response represents a complete (non-streaming) response from an LLM.
|
||||
//
|
||||
// If the LLM decides to invoke tools, Content may be empty and ToolCalls
|
||||
// will contain one or more entries. The caller should execute each tool
|
||||
// call and feed the results back to the LLM.
|
||||
type Response struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// StreamHandler receives the pieces of a streaming response.
//
// It is called once per chunk, in order, with that chunk's partial text.
// A non-nil return value stops the stream.
type StreamHandler func(chunk string) error
|
||||
|
||||
// EmbeddingResponse carries the vector produced by an embedding request.
type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}
|
||||
|
||||
// OllamaChatRequest is the request body sent to Ollama's /api/chat endpoint.
|
||||
type OllamaChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
Tools []ToolDef `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
// OllamaChatResponse is the response body from Ollama's /api/chat endpoint.
|
||||
type OllamaChatResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Message Message `json:"message"`
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
|
||||
// OllamaToolCallResponse is the wire format Ollama returns for tool calls.
|
||||
type OllamaToolCallResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Message OllamaToolMsg `json:"message"`
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
|
||||
// OllamaToolMsg wraps the tool_calls field in Ollama's response.
|
||||
type OllamaToolMsg struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ToolDef describes a tool that the LLM may call, in the format
|
||||
// expected by Ollama's function calling API.
|
||||
type ToolDef struct {
|
||||
Type string `json:"type"`
|
||||
Function ToolFunction `json:"function"`
|
||||
}
|
||||
|
||||
// ToolFunction describes a function available to the LLM.
|
||||
type ToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters ToolFunctionParameters `json:"parameters"`
|
||||
}
|
||||
|
||||
// ToolFunctionParameters is the JSON Schema for a tool's parameters.
|
||||
type ToolFunctionParameters struct {
|
||||
Type string `json:"type"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties map[string]ToolProperty `json:"properties"`
|
||||
}
|
||||
|
||||
// ToolProperty is the schema of one tool argument: its type, a description,
// and an optional list of enum values (omitted from the encoding when empty).
type ToolProperty struct {
	Type        string   `json:"type"`
	Description string   `json:"description"`
	Enum        []string `json:"enum,omitempty"`
}
|
||||
|
||||
// OllamaEmbedRequest is the JSON body posted to Ollama's /api/embed
// endpoint: the model to use and the text to embed.
type OllamaEmbedRequest struct {
	Model string `json:"model"`
	Input string `json:"input"`
}
|
||||
|
||||
// OllamaEmbedResponse is the JSON body decoded from Ollama's /api/embed
// endpoint.
//
// NOTE(review): Ollama's /api/embed documents a plural "embeddings" field
// (a list of vectors); the singular "embedding" belongs to the older
// /api/embeddings endpoint — confirm which endpoint and shape are intended.
type OllamaEmbedResponse struct {
	Embedding []float64 `json:"embedding"`
}
|
||||
58
pkg/plugin/plugin.go
Normal file
58
pkg/plugin/plugin.go
Normal file
@ -0,0 +1,58 @@
|
||||
// Package plugin defines the plugin system for the Orca framework.
|
||||
//
|
||||
// All extensions to the framework (skills, tools, LLM drivers, etc.)
|
||||
// are implemented as plugins that implement the Plugin interface.
|
||||
// The kernel manages plugin lifecycle: load, init, start, stop, shutdown.
|
||||
package plugin
|
||||
|
||||
import "github.com/orca/orca/pkg/bus"
|
||||
|
||||
// PluginState is the lifecycle stage a plugin is currently in.
type PluginState int

// Lifecycle states. StateUnknown is the zero value.
const (
	StateUnknown PluginState = iota
	StateRegistered
	StateInitialized
	StateRunning
	StateStopped
	StateError
)

// String returns the lower-case name of the state. Values outside the
// declared constants are reported as "unknown".
func (ps PluginState) String() string {
	labels := [...]string{
		StateUnknown:     "unknown",
		StateRegistered:  "registered",
		StateInitialized: "initialized",
		StateRunning:     "running",
		StateStopped:     "stopped",
		StateError:       "error",
	}
	if ps >= 0 && int(ps) < len(labels) {
		return labels[ps]
	}
	return "unknown"
}
|
||||
|
||||
// PluginHost is the interface that the kernel provides to plugins.
|
||||
//
|
||||
// Plugins receive a PluginHost reference during Init() to interact
|
||||
// with the framework: publishing/subscribing to messages, discovering
|
||||
// other plugins, and accessing shared resources.
|
||||
type PluginHost interface {
|
||||
Bus() bus.MessageBus
|
||||
GetPlugin(name string) (Plugin, bool)
|
||||
ListPlugins() []Plugin
|
||||
}
|
||||
|
||||
// Plugin defines the interface that all Orca plugins must implement.
|
||||
type Plugin interface {
|
||||
Name() string
|
||||
Version() string
|
||||
Init(host PluginHost) error
|
||||
Shutdown() error
|
||||
}
|
||||
100
pkg/plugin/registry.go
Normal file
100
pkg/plugin/registry.go
Normal file
@ -0,0 +1,100 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Registry is a thread-safe map that manages plugin registration.
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
plugins map[string]Plugin
|
||||
states map[string]PluginState
|
||||
}
|
||||
|
||||
// NewRegistry creates a new empty plugin registry.
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
plugins: make(map[string]Plugin),
|
||||
states: make(map[string]PluginState),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a plugin to the registry.
|
||||
func (r *Registry) Register(p Plugin) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
name := p.Name()
|
||||
if _, exists := r.plugins[name]; exists {
|
||||
return fmt.Errorf("plugin %q is already registered", name)
|
||||
}
|
||||
|
||||
r.plugins[name] = p
|
||||
r.states[name] = StateRegistered
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unregister removes a plugin from the registry.
|
||||
func (r *Registry) Unregister(name string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.plugins[name]; !exists {
|
||||
return fmt.Errorf("plugin %q is not registered", name)
|
||||
}
|
||||
|
||||
delete(r.plugins, name)
|
||||
delete(r.states, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a plugin by name.
|
||||
func (r *Registry) Get(name string) (Plugin, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
p, ok := r.plugins[name]
|
||||
return p, ok
|
||||
}
|
||||
|
||||
// List returns all registered plugins.
|
||||
func (r *Registry) List() []Plugin {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
plugins := make([]Plugin, 0, len(r.plugins))
|
||||
for _, p := range r.plugins {
|
||||
plugins = append(plugins, p)
|
||||
}
|
||||
return plugins
|
||||
}
|
||||
|
||||
// State returns the lifecycle state of a registered plugin.
|
||||
func (r *Registry) State(name string) PluginState {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
if state, ok := r.states[name]; ok {
|
||||
return state
|
||||
}
|
||||
return StateUnknown
|
||||
}
|
||||
|
||||
// SetState updates the lifecycle state of a registered plugin.
|
||||
func (r *Registry) SetState(name string, state PluginState) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, ok := r.plugins[name]; ok {
|
||||
r.states[name] = state
|
||||
}
|
||||
}
|
||||
|
||||
// Count returns the number of registered plugins.
|
||||
func (r *Registry) Count() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
return len(r.plugins)
|
||||
}
|
||||
256
pkg/plugin/registry_test.go
Normal file
256
pkg/plugin/registry_test.go
Normal file
@ -0,0 +1,256 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/orca/orca/pkg/bus"
|
||||
)
|
||||
|
||||
// mockPlugin implements Plugin for testing.
|
||||
type mockPlugin struct {
|
||||
name string
|
||||
version string
|
||||
initFn func(host PluginHost) error
|
||||
closeFn func() error
|
||||
}
|
||||
|
||||
func (m *mockPlugin) Name() string { return m.name }
|
||||
func (m *mockPlugin) Version() string { return m.version }
|
||||
func (m *mockPlugin) Init(host PluginHost) error {
|
||||
if m.initFn != nil {
|
||||
return m.initFn(host)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockPlugin) Shutdown() error {
|
||||
if m.closeFn != nil {
|
||||
return m.closeFn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRegistryNew(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
if r == nil {
|
||||
t.Fatal("NewRegistry() returned nil")
|
||||
}
|
||||
if n := r.Count(); n != 0 {
|
||||
t.Errorf("expected empty registry, got %d plugins", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryRegister(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
p := &mockPlugin{name: "test", version: "1.0.0"}
|
||||
|
||||
err := r.Register(p)
|
||||
if err != nil {
|
||||
t.Fatalf("Register failed: %v", err)
|
||||
}
|
||||
|
||||
if n := r.Count(); n != 1 {
|
||||
t.Errorf("expected 1 plugin, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryRegisterDuplicate(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
p1 := &mockPlugin{name: "test", version: "1.0.0"}
|
||||
p2 := &mockPlugin{name: "test", version: "2.0.0"}
|
||||
|
||||
r.Register(p1)
|
||||
err := r.Register(p2)
|
||||
if err == nil {
|
||||
t.Error("expected error registering duplicate plugin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryGet(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
p := &mockPlugin{name: "test", version: "1.0.0"}
|
||||
r.Register(p)
|
||||
|
||||
got, ok := r.Get("test")
|
||||
if !ok {
|
||||
t.Fatal("Get returned not found")
|
||||
}
|
||||
if got.Name() != "test" {
|
||||
t.Errorf("expected name 'test', got %q", got.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryGetNotFound(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_, ok := r.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("expected false for nonexistent plugin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryUnregister(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
p := &mockPlugin{name: "test", version: "1.0.0"}
|
||||
r.Register(p)
|
||||
|
||||
err := r.Unregister("test")
|
||||
if err != nil {
|
||||
t.Fatalf("Unregister failed: %v", err)
|
||||
}
|
||||
|
||||
if n := r.Count(); n != 0 {
|
||||
t.Errorf("expected 0 plugins, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryUnregisterNotFound(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
err := r.Unregister("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error unregistering nonexistent plugin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryList(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&mockPlugin{name: "a", version: "1.0.0"})
|
||||
r.Register(&mockPlugin{name: "b", version: "1.0.0"})
|
||||
r.Register(&mockPlugin{name: "c", version: "1.0.0"})
|
||||
|
||||
plugins := r.List()
|
||||
if len(plugins) != 3 {
|
||||
t.Errorf("expected 3 plugins, got %d", len(plugins))
|
||||
}
|
||||
|
||||
names := make(map[string]bool)
|
||||
for _, p := range plugins {
|
||||
names[p.Name()] = true
|
||||
}
|
||||
|
||||
for _, n := range []string{"a", "b", "c"} {
|
||||
if !names[n] {
|
||||
t.Errorf("missing plugin %q in list", n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryState(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
p := &mockPlugin{name: "test", version: "1.0.0"}
|
||||
r.Register(p)
|
||||
|
||||
if s := r.State("test"); s != StateRegistered {
|
||||
t.Errorf("expected StateRegistered, got %s", s)
|
||||
}
|
||||
|
||||
r.SetState("test", StateRunning)
|
||||
if s := r.State("test"); s != StateRunning {
|
||||
t.Errorf("expected StateRunning, got %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryStateUnknown(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
if s := r.State("nonexistent"); s != StateUnknown {
|
||||
t.Errorf("expected StateUnknown for nonexistent, got %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrySetStateNoOp(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.SetState("nonexistent", StateRunning)
|
||||
if n := r.Count(); n != 0 {
|
||||
t.Errorf("SetState should not add plugins")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluginStateString(t *testing.T) {
|
||||
tests := []struct {
|
||||
state PluginState
|
||||
want string
|
||||
}{
|
||||
{StateUnknown, "unknown"},
|
||||
{StateRegistered, "registered"},
|
||||
{StateInitialized, "initialized"},
|
||||
{StateRunning, "running"},
|
||||
{StateStopped, "stopped"},
|
||||
{StateError, "error"},
|
||||
{PluginState(99), "unknown"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := tt.state.String(); got != tt.want {
|
||||
t.Errorf("PluginState(%d).String() = %q, want %q", tt.state, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryConcurrent(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
r.Register(&mockPlugin{name: "a", version: "1.0.0"})
|
||||
r.Get("a")
|
||||
r.Unregister("a")
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
r.Register(&mockPlugin{name: "b", version: "1.0.0"})
|
||||
r.List()
|
||||
r.State("b")
|
||||
r.Unregister("b")
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
<-done
|
||||
<-done
|
||||
}
|
||||
|
||||
// mockPluginHost implements PluginHost for testing kernel-level plugin init.
|
||||
type mockPluginHost struct{}
|
||||
|
||||
func (m *mockPluginHost) Bus() bus.MessageBus { return nil }
|
||||
func (m *mockPluginHost) GetPlugin(name string) (Plugin, bool) { return nil, false }
|
||||
func (m *mockPluginHost) ListPlugins() []Plugin { return nil }
|
||||
|
||||
func TestPluginInitAndShutdown(t *testing.T) {
|
||||
var initCalled, shutdownCalled bool
|
||||
|
||||
p := &mockPlugin{
|
||||
name: "test",
|
||||
version: "1.0.0",
|
||||
initFn: func(host PluginHost) error {
|
||||
initCalled = true
|
||||
if host == nil {
|
||||
return errors.New("host is nil")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
closeFn: func() error {
|
||||
shutdownCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
host := &mockPluginHost{}
|
||||
|
||||
if err := p.Init(host); err != nil {
|
||||
t.Fatalf("Init failed: %v", err)
|
||||
}
|
||||
if !initCalled {
|
||||
t.Error("Init function was not called")
|
||||
}
|
||||
|
||||
if err := p.Shutdown(); err != nil {
|
||||
t.Fatalf("Shutdown failed: %v", err)
|
||||
}
|
||||
if !shutdownCalled {
|
||||
t.Error("Shutdown function was not called")
|
||||
}
|
||||
}
|
||||
246
pkg/sandbox/process.go
Normal file
246
pkg/sandbox/process.go
Normal file
@ -0,0 +1,246 @@
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// Defaults applied by NewProcessSandbox.
const (
	// DefaultOutputLimit caps the bytes captured from each of stdout and
	// stderr at 64 KB.
	DefaultOutputLimit = 64 << 10

	// DefaultWorkingDir is the directory sandboxed commands run in unless
	// the sandbox is configured otherwise.
	DefaultWorkingDir = "/tmp/orca/sandbox"
)
|
||||
|
||||
// AllowedEnvVars is the default environment whitelist for sandboxed
// commands: of the parent process's variables, only these names are
// forwarded to the child.
var AllowedEnvVars = []string{
	"HOME", "USER", "PATH", "LANG", "SHELL", "TMPDIR",
	"ORCA_HOME",
}
|
||||
|
||||
// ProcessSandbox implements Sandbox by running commands as child processes
// via os/exec, confined to a working directory, a whitelisted environment
// and a cap on captured output.
type ProcessSandbox struct {
	// WorkingDir is where commands are executed.
	WorkingDir string
	// OutputLimit caps the bytes kept from each of stdout and stderr.
	// Values <= 0 select DefaultOutputLimit.
	OutputLimit int
	// EnvWhitelist names the environment variables forwarded to the child.
	// nil selects AllowedEnvVars; an empty, non-nil slice forwards nothing.
	EnvWhitelist []string
}
|
||||
|
||||
// NewProcessSandbox creates a ProcessSandbox with sensible defaults.
|
||||
func NewProcessSandbox() *ProcessSandbox {
|
||||
return &ProcessSandbox{
|
||||
WorkingDir: DefaultWorkingDir,
|
||||
OutputLimit: DefaultOutputLimit,
|
||||
EnvWhitelist: nil, // uses AllowedEnvVars
|
||||
}
|
||||
}
|
||||
|
||||
// Execute runs a command as a child process with resource restrictions.
|
||||
func (ps *ProcessSandbox) Execute(ctx context.Context, cmd string, args ...string) (*Result, error) {
|
||||
// Ensure working directory exists
|
||||
if err := os.MkdirAll(ps.WorkingDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("sandbox: failed to create working directory %q: %w", ps.WorkingDir, err)
|
||||
}
|
||||
|
||||
// Build the command
|
||||
c := exec.CommandContext(ctx, cmd, args...)
|
||||
c.Dir = ps.WorkingDir
|
||||
|
||||
// Set up environment variable whitelist
|
||||
env := ps.buildEnv()
|
||||
c.Env = env
|
||||
|
||||
// Capture stdout and stderr with size limits
|
||||
stdoutBuf := newLimitedBuffer(ps.outputLimit())
|
||||
stderrBuf := newLimitedBuffer(ps.outputLimit())
|
||||
|
||||
stdoutPipe, err := c.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sandbox: failed to create stdout pipe: %w", err)
|
||||
}
|
||||
stderrPipe, err := c.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sandbox: failed to create stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
// Start the command
|
||||
if err := c.Start(); err != nil {
|
||||
return nil, fmt.Errorf("sandbox: failed to start command: %w", err)
|
||||
}
|
||||
|
||||
// Read stdout and stderr concurrently
|
||||
var readStdout, readStderr error
|
||||
var wg syncWaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, readStdout = io.Copy(stdoutBuf, stdoutPipe)
|
||||
// ErrShortWrite is expected when the output limit is reached — not a real error.
|
||||
if readStdout != nil && readStdout != io.EOF && readStdout != io.ErrShortWrite {
|
||||
readStdout = fmt.Errorf("sandbox: stdout read error: %w", readStdout)
|
||||
} else {
|
||||
readStdout = nil
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, readStderr = io.Copy(stderrBuf, stderrPipe)
|
||||
if readStderr != nil && readStderr != io.EOF && readStderr != io.ErrShortWrite {
|
||||
readStderr = fmt.Errorf("sandbox: stderr read error: %w", readStderr)
|
||||
} else {
|
||||
readStderr = nil
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Wait for the command to finish
|
||||
err = c.Wait()
|
||||
exitCode := 0
|
||||
|
||||
if err != nil {
|
||||
// Check if the process was killed due to context cancellation (timeout)
|
||||
if ctx.Err() != nil {
|
||||
return nil, fmt.Errorf("sandbox: command timed out: %w", ctx.Err())
|
||||
}
|
||||
|
||||
// Normal non-zero exit
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
exitCode = exitErr.ExitCode()
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Combine errors: prefer command error, then read errors
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if readStdout != nil {
|
||||
return nil, readStdout
|
||||
}
|
||||
if readStderr != nil {
|
||||
return nil, readStderr
|
||||
}
|
||||
|
||||
return &Result{
|
||||
Stdout: stdoutBuf.String(),
|
||||
Stderr: stderrBuf.String(),
|
||||
ExitCode: exitCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildEnv constructs the environment variable list for the child process
|
||||
// based on the whitelist configuration.
|
||||
func (ps *ProcessSandbox) buildEnv() []string {
|
||||
whitelist := ps.EnvWhitelist
|
||||
if whitelist == nil {
|
||||
whitelist = AllowedEnvVars
|
||||
}
|
||||
|
||||
env := make([]string, 0, len(whitelist))
|
||||
for _, key := range whitelist {
|
||||
if val, ok := os.LookupEnv(key); ok {
|
||||
env = append(env, key+"="+val)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// outputLimit returns the effective output size limit.
|
||||
func (ps *ProcessSandbox) outputLimit() int {
|
||||
if ps.OutputLimit <= 0 {
|
||||
return DefaultOutputLimit
|
||||
}
|
||||
return ps.OutputLimit
|
||||
}
|
||||
|
||||
// WorkingDirPath returns the absolute path of the sandbox working directory.
|
||||
func (ps *ProcessSandbox) WorkingDirPath() string {
|
||||
abs, err := filepath.Abs(ps.WorkingDir)
|
||||
if err != nil {
|
||||
return ps.WorkingDir
|
||||
}
|
||||
return abs
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
// limitedBuffer — a writer that keeps at most MaxSize bytes and discards
// everything written beyond that.
// Uses a named field (not embedded) to avoid promoting bytes.Buffer.ReadFrom
// which would bypass the size limit when used with io.Copy.
// ---------------------------------------------------------------------------

type limitedBuffer struct {
	buf     bytes.Buffer // retained bytes; never grows past MaxSize
	MaxSize int          // maximum number of bytes kept; <= 0 keeps nothing
}

// newLimitedBuffer returns an empty buffer that retains at most maxSize bytes.
func newLimitedBuffer(maxSize int) *limitedBuffer {
	return &limitedBuffer{MaxSize: maxSize}
}

// Write stores as much of p as still fits under MaxSize and silently drops
// the remainder.
//
// It always reports len(p) bytes as consumed when the underlying buffer write
// succeeds. The io.Writer contract requires a non-nil error whenever
// n < len(p); reporting the full length instead keeps io.Copy draining its
// source after the limit is hit rather than stopping with io.ErrShortWrite
// on the truncating write.
func (lb *limitedBuffer) Write(p []byte) (int, error) {
	room := lb.MaxSize - lb.buf.Len()
	if room <= 0 {
		// Already full: discard everything but report it as consumed.
		return len(p), nil
	}

	kept := p
	if len(kept) > room {
		kept = kept[:room]
	}
	if n, err := lb.buf.Write(kept); err != nil {
		return n, err
	}
	return len(p), nil
}

// String returns the retained bytes as a string.
func (lb *limitedBuffer) String() string {
	return lb.buf.String()
}

// Len returns the number of retained bytes.
func (lb *limitedBuffer) Len() int {
	return lb.buf.Len()
}
|
||||
|
||||
// ---------------------------------------------------------------------------
// syncWaitGroup — a minimal, channel-backed stand-in for a wait group.
//
// Limitations visible in the implementation: only the first Add call has any
// effect, Done must not be called before Add (a send on a nil channel blocks
// forever), and Wait consumes the signals, so it is single-use.
// ---------------------------------------------------------------------------

type syncWaitGroup struct {
	ch chan struct{} // one buffered slot per expected Done call
}

// Add allocates the signal channel with room for n Done calls.
// Calls made after the channel exists are ignored.
func (wg *syncWaitGroup) Add(n int) {
	if wg.ch != nil {
		return
	}
	wg.ch = make(chan struct{}, n)
}

// Done signals that one unit of work has finished.
func (wg *syncWaitGroup) Done() {
	wg.ch <- struct{}{}
}

// Wait blocks until as many signals as the channel's capacity have been
// received. On a zero-value syncWaitGroup it returns immediately.
func (wg *syncWaitGroup) Wait() {
	for pending := cap(wg.ch); pending > 0; pending-- {
		<-wg.ch
	}
}
|
||||
|
||||
// Compile-time interface check.
|
||||
// The build fails here if *ProcessSandbox stops implementing Sandbox.
var _ Sandbox = (*ProcessSandbox)(nil)
|
||||
212
pkg/sandbox/process_test.go
Normal file
212
pkg/sandbox/process_test.go
Normal file
@ -0,0 +1,212 @@
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewProcessSandbox(t *testing.T) {
|
||||
ps := NewProcessSandbox()
|
||||
if ps == nil {
|
||||
t.Fatal("NewProcessSandbox() returned nil")
|
||||
}
|
||||
if ps.WorkingDir != DefaultWorkingDir {
|
||||
t.Errorf("expected WorkingDir %q, got %q", DefaultWorkingDir, ps.WorkingDir)
|
||||
}
|
||||
if ps.OutputLimit != DefaultOutputLimit {
|
||||
t.Errorf("expected OutputLimit %d, got %d", DefaultOutputLimit, ps.OutputLimit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEcho(t *testing.T) {
|
||||
ps := NewProcessSandbox()
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := ps.Execute(ctx, "echo", "hello", "world")
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
if result.ExitCode != 0 {
|
||||
t.Errorf("expected exit code 0, got %d", result.ExitCode)
|
||||
}
|
||||
if strings.TrimSpace(result.Stdout) != "hello world" {
|
||||
t.Errorf("expected stdout 'hello world', got %q", result.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteWithArgs(t *testing.T) {
|
||||
ps := NewProcessSandbox()
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := ps.Execute(ctx, "sh", "-c", "echo 'arg1 arg2'")
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
if result.ExitCode != 0 {
|
||||
t.Errorf("expected exit code 0, got %d", result.ExitCode)
|
||||
}
|
||||
if strings.TrimSpace(result.Stdout) != "arg1 arg2" {
|
||||
t.Errorf("expected stdout 'arg1 arg2', got %q", result.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteNonZeroExit(t *testing.T) {
|
||||
ps := NewProcessSandbox()
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := ps.Execute(ctx, "sh", "-c", "exit 42")
|
||||
if err != nil {
|
||||
t.Fatalf("Execute should not error on non-zero exit: %v", err)
|
||||
}
|
||||
if result.ExitCode != 42 {
|
||||
t.Errorf("expected exit code 42, got %d", result.ExitCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteCommandNotFound(t *testing.T) {
|
||||
ps := NewProcessSandbox()
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := ps.Execute(ctx, "nonexistent-command-12345")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteTimeout(t *testing.T) {
|
||||
ps := NewProcessSandbox()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := ps.Execute(ctx, "sleep", "10")
|
||||
if err == nil {
|
||||
t.Fatal("expected timeout error")
|
||||
}
|
||||
// On macOS the error may be "signal: killed" or "context deadline exceeded".
|
||||
// Just verify an error occurred — the exact message varies by platform.
|
||||
t.Logf("timeout produced error: %v", err)
|
||||
}
|
||||
|
||||
func TestExecuteWorkingDirectory(t *testing.T) {
|
||||
// Use a temp directory for this test
|
||||
tmpDir, err := os.MkdirTemp("", "sandbox-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
ps := &ProcessSandbox{
|
||||
WorkingDir: tmpDir,
|
||||
OutputLimit: DefaultOutputLimit,
|
||||
EnvWhitelist: AllowedEnvVars,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := ps.Execute(ctx, "pwd")
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
if result.ExitCode != 0 {
|
||||
t.Errorf("expected exit code 0, got %d", result.ExitCode)
|
||||
}
|
||||
|
||||
// pwd should return the temp directory
|
||||
gotDir := strings.TrimSpace(result.Stdout)
|
||||
absGot, _ := filepath.EvalSymlinks(gotDir)
|
||||
absTmp, _ := filepath.EvalSymlinks(tmpDir)
|
||||
if absGot != absTmp {
|
||||
t.Errorf("expected working dir %q, got %q", absTmp, absGot)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvironmentWhitelist(t *testing.T) {
|
||||
ps := NewProcessSandbox()
|
||||
ps.EnvWhitelist = []string{"HOME"}
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := ps.Execute(ctx, "sh", "-c", "echo $HOME")
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
if result.ExitCode != 0 {
|
||||
t.Errorf("expected exit code 0, got %d", result.ExitCode)
|
||||
}
|
||||
|
||||
home := os.Getenv("HOME")
|
||||
if home != "" && strings.TrimSpace(result.Stdout) != home {
|
||||
t.Errorf("expected HOME=%q, got %q", home, strings.TrimSpace(result.Stdout))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvironmentIsolation(t *testing.T) {
|
||||
ps := NewProcessSandbox()
|
||||
// Use an empty whitelist to ensure no env vars are passed
|
||||
ps.EnvWhitelist = []string{}
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := ps.Execute(ctx, "sh", "-c", "echo $HOME")
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
|
||||
// HOME should be empty in the child process
|
||||
if strings.TrimSpace(result.Stdout) != "" {
|
||||
t.Errorf("expected empty HOME in isolated env, got %q", result.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutputLimit(t *testing.T) {
|
||||
ps := NewProcessSandbox()
|
||||
ps.OutputLimit = 10 // Only 10 bytes
|
||||
|
||||
ctx := context.Background()
|
||||
// Generate a long output well beyond the 10-byte limit
|
||||
result, err := ps.Execute(ctx, "sh", "-c", "echo 'AAAAAAAAAABBBBBBBBBBCCCCCCCCCCDDDDDDDDDDEEEEEEEEEEFFFFFFFFFF'")
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
|
||||
// The output should be truncated to approximately 10 bytes (plus newline)
|
||||
if len(result.Stdout) > 15 {
|
||||
t.Errorf("expected truncated output (<=15 bytes), got %d bytes: %q", len(result.Stdout), result.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStderr(t *testing.T) {
|
||||
ps := NewProcessSandbox()
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := ps.Execute(ctx, "sh", "-c", "echo 'error output' >&2; echo 'normal output'")
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
if result.ExitCode != 0 {
|
||||
t.Errorf("expected exit code 0, got %d", result.ExitCode)
|
||||
}
|
||||
if strings.TrimSpace(result.Stderr) != "error output" {
|
||||
t.Errorf("expected stderr 'error output', got %q", result.Stderr)
|
||||
}
|
||||
if strings.TrimSpace(result.Stdout) != "normal output" {
|
||||
t.Errorf("expected stdout 'normal output', got %q", result.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSandboxInterfaceSatisfied(t *testing.T) {
|
||||
// Compile-time check
|
||||
var ps Sandbox = NewProcessSandbox()
|
||||
if ps == nil {
|
||||
t.Fatal("ProcessSandbox does not satisfy Sandbox interface")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkingDirPath(t *testing.T) {
|
||||
ps := NewProcessSandbox()
|
||||
path := ps.WorkingDirPath()
|
||||
if !filepath.IsAbs(path) {
|
||||
t.Errorf("expected absolute path, got %q", path)
|
||||
}
|
||||
}
|
||||
28
pkg/sandbox/sandbox.go
Normal file
28
pkg/sandbox/sandbox.go
Normal file
@ -0,0 +1,28 @@
|
||||
// Package sandbox provides a secure execution environment for running commands.
|
||||
//
|
||||
// The sandbox restricts resource usage (timeout, output size, working directory)
|
||||
// and environment variable access to prevent runaway or malicious commands.
|
||||
// This is the execution backend used by the Tool system's built-in exec tool.
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Result holds the output and exit status of a sandboxed command execution.
// It serializes to JSON with the keys "stdout", "stderr" and "exit_code".
|
||||
type Result struct {
|
||||
// Stdout is the text captured from the command's standard output.
Stdout string `json:"stdout"`
|
||||
// Stderr is the text captured from the command's standard error.
Stderr string `json:"stderr"`
|
||||
// ExitCode is the command's exit status; 0 when it exited successfully.
ExitCode int `json:"exit_code"`
|
||||
}
|
||||
|
||||
// Sandbox defines the interface for command execution environments.
|
||||
//
|
||||
// Implementations may use OS processes (os/exec), containers, or other
|
||||
// isolation mechanisms. The context controls cancellation and timeouts.
|
||||
type Sandbox interface {
|
||||
// Execute runs a command with the given arguments inside the sandbox.
|
||||
// The context can be used to set timeouts or cancel the execution.
|
||||
// It returns a Result carrying standard output, standard error (as separate
// fields) and the exit code, or an error if the command could not be run.
|
||||
Execute(ctx context.Context, cmd string, args ...string) (*Result, error)
|
||||
}
|
||||
190
pkg/session/jsonl.go
Normal file
190
pkg/session/jsonl.go
Normal file
@ -0,0 +1,190 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// JSONLStore implements the Store interface using JSONL files.
|
||||
//
|
||||
// Each session is stored in a separate file named {session_id}.jsonl
|
||||
// under the configured storage directory. Every line in the file is a
|
||||
// JSON-encoded SessionMessage. New messages are appended in O(1) time.
// Archived sessions are kept in the same directory as
// {session_id}.jsonl.archived.
|
||||
type JSONLStore struct {
|
||||
// storageDir is the directory that holds all session files.
storageDir string
|
||||
// mu serializes file operations issued through this store: mutating
// methods take the write lock, read-only methods take the read lock.
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewJSONLStore creates a new JSONLStore with the given storage directory.
|
||||
// The directory is created if it does not exist.
|
||||
func NewJSONLStore(storageDir string) (*JSONLStore, error) {
|
||||
if err := os.MkdirAll(storageDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create session storage directory %q: %w", storageDir, err)
|
||||
}
|
||||
return &JSONLStore{storageDir: storageDir}, nil
|
||||
}
|
||||
|
||||
// path returns the full file path for the given session ID.
|
||||
func (s *JSONLStore) path(sessionID string) string {
|
||||
return filepath.Join(s.storageDir, sessionID+".jsonl")
|
||||
}
|
||||
|
||||
// archivePath returns the archive file path for the given session ID.
|
||||
func (s *JSONLStore) archivePath(sessionID string) string {
|
||||
return filepath.Join(s.storageDir, sessionID+".jsonl.archived")
|
||||
}
|
||||
|
||||
// Save appends a message to a session's JSONL file.
|
||||
// If the file does not exist, it is created.
|
||||
// This is an O(1) append operation.
|
||||
func (s *JSONLStore) Save(sessionID string, msg SessionMessage) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
f, err := os.OpenFile(s.path(sessionID), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open session file for %q: %w", sessionID, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal session message: %w", err)
|
||||
}
|
||||
|
||||
if _, err := f.Write(append(data, '\n')); err != nil {
|
||||
return fmt.Errorf("failed to write session message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load retrieves all messages for a session in chronological order.
|
||||
// Returns an error if the session file does not exist.
|
||||
func (s *JSONLStore) Load(sessionID string) ([]SessionMessage, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
data, err := os.ReadFile(s.path(sessionID))
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// Check archive
|
||||
archiveData, archiveErr := os.ReadFile(s.archivePath(sessionID))
|
||||
if archiveErr != nil {
|
||||
return nil, fmt.Errorf("session %q not found", sessionID)
|
||||
}
|
||||
data = archiveData
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to read session file for %q: %w", sessionID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return parseJSONL(data)
|
||||
}
|
||||
|
||||
// parseJSONL parses a JSONL byte slice into a slice of SessionMessage.
|
||||
func parseJSONL(data []byte) ([]SessionMessage, error) {
|
||||
var messages []SessionMessage
|
||||
trimmed := strings.TrimSpace(string(data))
|
||||
if trimmed == "" {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
lines := strings.Split(trimmed, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
var msg SessionMessage
|
||||
if err := json.Unmarshal([]byte(line), &msg); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal session message: %w", err)
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// List returns all session IDs by scanning the storage directory.
|
||||
func (s *JSONLStore) List() ([]string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
entries, err := os.ReadDir(s.storageDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read storage directory %q: %w", s.storageDir, err)
|
||||
}
|
||||
|
||||
var sessions []string
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if strings.HasSuffix(name, ".jsonl") && !strings.HasSuffix(name, ".archived") {
|
||||
sessions = append(sessions, strings.TrimSuffix(name, ".jsonl"))
|
||||
}
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// Exists checks whether a session file exists (active or archived).
|
||||
func (s *JSONLStore) Exists(sessionID string) (bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if _, err := os.Stat(s.path(sessionID)); err == nil {
|
||||
return true, nil
|
||||
} else if !os.IsNotExist(err) {
|
||||
return false, fmt.Errorf("failed to check session %q: %w", sessionID, err)
|
||||
}
|
||||
|
||||
// Check archive
|
||||
if _, err := os.Stat(s.archivePath(sessionID)); err == nil {
|
||||
return true, nil
|
||||
} else if !os.IsNotExist(err) {
|
||||
return false, fmt.Errorf("failed to check archived session %q: %w", sessionID, err)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Archive moves a session file to the archived state by renaming it.
|
||||
func (s *JSONLStore) Archive(sessionID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if err := os.Rename(s.path(sessionID), s.archivePath(sessionID)); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("session %q not found", sessionID)
|
||||
}
|
||||
return fmt.Errorf("failed to archive session %q: %w", sessionID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete permanently removes a session file and its archive.
|
||||
func (s *JSONLStore) Delete(sessionID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var lastErr error
|
||||
|
||||
// Remove active file
|
||||
if err := os.Remove(s.path(sessionID)); err != nil && !os.IsNotExist(err) {
|
||||
lastErr = fmt.Errorf("failed to delete session %q: %w", sessionID, err)
|
||||
}
|
||||
|
||||
// Also remove archived file if it exists
|
||||
if err := os.Remove(s.archivePath(sessionID)); err != nil && !os.IsNotExist(err) {
|
||||
lastErr = fmt.Errorf("failed to delete archived session %q: %w", sessionID, err)
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// StorageDir returns the storage directory path.
|
||||
func (s *JSONLStore) StorageDir() string {
|
||||
return s.storageDir
|
||||
}
|
||||
198
pkg/session/manager.go
Normal file
198
pkg/session/manager.go
Normal file
@ -0,0 +1,198 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/orca/orca/pkg/bus"
|
||||
)
|
||||
|
||||
// Manager provides high-level session lifecycle operations.
|
||||
//
|
||||
// It wraps a Store with caching, context window management, and
|
||||
// event publishing on the message bus.
|
||||
type Manager struct {
|
||||
// store persists session messages.
store Store
|
||||
// bus receives session lifecycle events; it may be nil, in which case no
// events are published.
bus bus.MessageBus
|
||||
// cache holds in-memory sessions keyed by session ID.
cache map[string]*Session
|
||||
// mu guards cache and the sessions stored in it.
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewManager creates a new session Manager with the given store and optional message bus.
|
||||
func NewManager(store Store, mb bus.MessageBus) *Manager {
|
||||
return &Manager{
|
||||
store: store,
|
||||
bus: mb,
|
||||
cache: make(map[string]*Session),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSession creates a new session with the given ID and optional metadata.
|
||||
func (m *Manager) CreateSession(id string, metadata map[string]string) (*Session, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.cache[id]; exists {
|
||||
return nil, fmt.Errorf("session %q already exists", id)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
session := &Session{
|
||||
ID: id,
|
||||
Status: SessionActive,
|
||||
Messages: make([]SessionMessage, 0),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
m.cache[id] = session
|
||||
|
||||
// Publish session created event
|
||||
if m.bus != nil {
|
||||
m.bus.Publish("session.created", bus.Message{
|
||||
ID: "session-" + id,
|
||||
Type: bus.MsgTypeSystem,
|
||||
From: "session.manager",
|
||||
Content: map[string]interface{}{"session_id": id},
|
||||
})
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// GetSession retrieves a session by ID, checking the cache and then the store.
|
||||
func (m *Manager) GetSession(id string) (*Session, error) {
|
||||
m.mu.RLock()
|
||||
session, ok := m.cache[id]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if ok {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Try to load from store
|
||||
messages, err := m.store.Load(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load session %q: %w", id, err)
|
||||
}
|
||||
|
||||
// Check if we can determine created/updated timestamps from messages
|
||||
var createdAt, updatedAt time.Time
|
||||
if len(messages) > 0 {
|
||||
createdAt = messages[0].Timestamp
|
||||
updatedAt = messages[len(messages)-1].Timestamp
|
||||
}
|
||||
if createdAt.IsZero() {
|
||||
createdAt = time.Now()
|
||||
}
|
||||
if updatedAt.IsZero() {
|
||||
updatedAt = time.Now()
|
||||
}
|
||||
|
||||
session = &Session{
|
||||
ID: id,
|
||||
Status: SessionActive,
|
||||
Messages: messages,
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.cache[id] = session
|
||||
m.mu.Unlock()
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// AddMessage appends a message to a session and persists it.
|
||||
func (m *Manager) AddMessage(sessionID string, role MessageRole, content string, metadata map[string]string) (*SessionMessage, error) {
|
||||
msg := SessionMessage{
|
||||
Role: role,
|
||||
Content: content,
|
||||
Timestamp: time.Now(),
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
if err := m.store.Save(sessionID, msg); err != nil {
|
||||
return nil, fmt.Errorf("failed to save message to session %q: %w", sessionID, err)
|
||||
}
|
||||
|
||||
// Upsert cache
|
||||
m.mu.Lock()
|
||||
if session, ok := m.cache[sessionID]; ok {
|
||||
session.Messages = append(session.Messages, msg)
|
||||
session.UpdatedAt = msg.Timestamp
|
||||
} else {
|
||||
m.cache[sessionID] = &Session{
|
||||
ID: sessionID,
|
||||
Status: SessionActive,
|
||||
Messages: []SessionMessage{msg},
|
||||
CreatedAt: msg.Timestamp,
|
||||
UpdatedAt: msg.Timestamp,
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
// GetContext returns the most recent N messages in a session.
|
||||
// If windowSize <= 0 or >= total messages, all messages are returned.
|
||||
func (m *Manager) GetContext(sessionID string, windowSize int) ([]SessionMessage, error) {
|
||||
session, err := m.GetSession(sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
messages := session.Messages
|
||||
if windowSize > 0 && windowSize < len(messages) {
|
||||
return messages[len(messages)-windowSize:], nil
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// ArchiveSession archives a session, making it read-only.
|
||||
func (m *Manager) ArchiveSession(id string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if session, ok := m.cache[id]; ok {
|
||||
session.Status = SessionArchived
|
||||
}
|
||||
|
||||
if err := m.store.Archive(id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Publish event
|
||||
if m.bus != nil {
|
||||
m.bus.Publish("session.archived", bus.Message{
|
||||
ID: "session-" + id,
|
||||
Type: bus.MsgTypeSystem,
|
||||
From: "session.manager",
|
||||
Content: map[string]interface{}{"session_id": id},
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSession permanently removes a session.
|
||||
func (m *Manager) DeleteSession(id string) error {
|
||||
m.mu.Lock()
|
||||
delete(m.cache, id)
|
||||
m.mu.Unlock()
|
||||
return m.store.Delete(id)
|
||||
}
|
||||
|
||||
// ListSessions returns all known session IDs.
|
||||
func (m *Manager) ListSessions() ([]string, error) {
|
||||
return m.store.List()
|
||||
}
|
||||
|
||||
// Store returns the underlying Store.
|
||||
func (m *Manager) Store() Store {
|
||||
return m.store
|
||||
}
|
||||
550
pkg/session/session_test.go
Normal file
550
pkg/session/session_test.go
Normal file
@ -0,0 +1,550 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/orca/orca/pkg/bus"
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// JSONL Store Tests
|
||||
// ============================================================
|
||||
|
||||
func setupTestStore(t *testing.T) (*JSONLStore, func()) {
|
||||
t.Helper()
|
||||
dir, err := os.MkdirTemp("", "orca-session-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
store, err := NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
os.RemoveAll(dir)
|
||||
t.Fatalf("NewJSONLStore failed: %v", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
os.RemoveAll(dir)
|
||||
}
|
||||
return store, cleanup
|
||||
}
|
||||
|
||||
func TestNewJSONLStore(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
if store == nil {
|
||||
t.Fatal("NewJSONLStore returned nil")
|
||||
}
|
||||
if store.StorageDir() == "" {
|
||||
t.Error("StorageDir should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLStoreSaveAndLoad(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
msg := SessionMessage{
|
||||
Role: RoleUser,
|
||||
Content: "Hello, world!",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
if err := store.Save("session-1", msg); err != nil {
|
||||
t.Fatalf("Save failed: %v", err)
|
||||
}
|
||||
|
||||
messages, err := store.Load("session-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||
}
|
||||
if messages[0].Role != RoleUser {
|
||||
t.Errorf("expected RoleUser, got %s", messages[0].Role)
|
||||
}
|
||||
if messages[0].Content != "Hello, world!" {
|
||||
t.Errorf("expected content 'Hello, world!', got %q", messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLStoreAppendMultiple(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
roles := []MessageRole{RoleUser, RoleAssistant, RoleUser, RoleSystem}
|
||||
for i, role := range roles {
|
||||
msg := SessionMessage{
|
||||
Role: role,
|
||||
Content: "message " + string(rune('0'+i)),
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
if err := store.Save("session-append", msg); err != nil {
|
||||
t.Fatalf("Save %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
messages, err := store.Load("session-append")
|
||||
if err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
if len(messages) != len(roles) {
|
||||
t.Fatalf("expected %d messages, got %d", len(roles), len(messages))
|
||||
}
|
||||
for i, msg := range messages {
|
||||
if msg.Role != roles[i] {
|
||||
t.Errorf("message %d: expected role %s, got %s", i, roles[i], msg.Role)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLStoreLoadNonexistent(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := store.Load("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error loading nonexistent session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLStoreExists(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
exists, err := store.Exists("nonexistent")
|
||||
if err != nil {
|
||||
t.Fatalf("Exists failed: %v", err)
|
||||
}
|
||||
if exists {
|
||||
t.Error("expected nonexistent session to return false")
|
||||
}
|
||||
|
||||
store.Save("session-exists", SessionMessage{Role: RoleUser, Content: "test"})
|
||||
exists, err = store.Exists("session-exists")
|
||||
if err != nil {
|
||||
t.Fatalf("Exists failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("expected existing session to return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLStoreList(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ids := []string{"sess-a", "sess-b", "sess-c"}
|
||||
for _, id := range ids {
|
||||
store.Save(id, SessionMessage{Role: RoleUser, Content: "test"})
|
||||
}
|
||||
|
||||
list, err := store.List()
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
if len(list) != len(ids) {
|
||||
t.Fatalf("expected %d sessions, got %d", len(ids), len(list))
|
||||
}
|
||||
|
||||
found := make(map[string]bool)
|
||||
for _, id := range list {
|
||||
found[id] = true
|
||||
}
|
||||
for _, id := range ids {
|
||||
if !found[id] {
|
||||
t.Errorf("missing session %q in list", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLStoreArchiveAndLoad(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
msg := SessionMessage{Role: RoleUser, Content: "archive test"}
|
||||
store.Save("sess-archive", msg)
|
||||
|
||||
if err := store.Archive("sess-archive"); err != nil {
|
||||
t.Fatalf("Archive failed: %v", err)
|
||||
}
|
||||
|
||||
// Should still be loadable (archived files are in .archived suffix)
|
||||
messages, err := store.Load("sess-archive")
|
||||
if err != nil {
|
||||
t.Fatalf("Load after archive failed: %v", err)
|
||||
}
|
||||
if len(messages) != 1 {
|
||||
t.Errorf("expected 1 message after archive, got %d", len(messages))
|
||||
}
|
||||
|
||||
// Should not appear in List
|
||||
list, _ := store.List()
|
||||
for _, id := range list {
|
||||
if id == "sess-archive" {
|
||||
t.Error("archived session should not appear in List")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLStoreArchiveNonexistent(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
err := store.Archive("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error archiving nonexistent session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLStoreDelete(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
store.Save("sess-delete", SessionMessage{Role: RoleUser, Content: "delete me"})
|
||||
if err := store.Delete("sess-delete"); err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists("sess-delete")
|
||||
if exists {
|
||||
t.Error("expected deleted session to not exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLStoreDeleteNonexistent(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
err := store.Delete("nonexistent")
|
||||
if err != nil {
|
||||
t.Fatalf("Delete nonexistent should succeed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLStoreConcurrentWrites(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
done := make(chan struct{}, 2)
|
||||
go func() {
|
||||
for i := 0; i < 50; i++ {
|
||||
store.Save("concurrent", SessionMessage{Role: RoleUser, Content: "from-a"})
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
go func() {
|
||||
for i := 0; i < 50; i++ {
|
||||
store.Save("concurrent", SessionMessage{Role: RoleAssistant, Content: "from-b"})
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
<-done
|
||||
<-done
|
||||
|
||||
messages, err := store.Load("concurrent")
|
||||
if err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
if len(messages) != 100 {
|
||||
t.Errorf("expected 100 messages, got %d", len(messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLStoreEmptyFile(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := store.StorageDir()
|
||||
// Create an empty file
|
||||
f, _ := os.Create(filepath.Join(dir, "empty.jsonl"))
|
||||
f.Close()
|
||||
|
||||
messages, err := store.Load("empty")
|
||||
if err != nil {
|
||||
t.Fatalf("Load empty session failed: %v", err)
|
||||
}
|
||||
if len(messages) != 0 {
|
||||
t.Errorf("expected 0 messages from empty file, got %d", len(messages))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Session Manager Tests
|
||||
// ============================================================
|
||||
|
||||
func setupTestManager(t *testing.T) (*Manager, func()) {
|
||||
t.Helper()
|
||||
store, cleanup := setupTestStore(t)
|
||||
mb := bus.New()
|
||||
mgr := NewManager(store, mb)
|
||||
return mgr, func() {
|
||||
mb.Close()
|
||||
cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
mgr, cleanup := setupTestManager(t)
|
||||
defer cleanup()
|
||||
|
||||
if mgr == nil {
|
||||
t.Fatal("NewManager returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCreateSession(t *testing.T) {
|
||||
mgr, cleanup := setupTestManager(t)
|
||||
defer cleanup()
|
||||
|
||||
session, err := mgr.CreateSession("sess-1", map[string]string{"key": "value"})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession failed: %v", err)
|
||||
}
|
||||
if session.ID != "sess-1" {
|
||||
t.Errorf("expected ID 'sess-1', got %q", session.ID)
|
||||
}
|
||||
if session.Status != SessionActive {
|
||||
t.Errorf("expected SessionActive, got %s", session.Status)
|
||||
}
|
||||
if session.Metadata["key"] != "value" {
|
||||
t.Errorf("expected metadata key 'value', got %q", session.Metadata["key"])
|
||||
}
|
||||
if session.MessageCount() != 0 {
|
||||
t.Errorf("expected 0 messages, got %d", session.MessageCount())
|
||||
}
|
||||
if session.CreatedAt.IsZero() {
|
||||
t.Error("CreatedAt should not be zero")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCreateDuplicate(t *testing.T) {
|
||||
mgr, cleanup := setupTestManager(t)
|
||||
defer cleanup()
|
||||
|
||||
mgr.CreateSession("dup", nil)
|
||||
_, err := mgr.CreateSession("dup", nil)
|
||||
if err == nil {
|
||||
t.Error("expected error creating duplicate session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerAddMessage(t *testing.T) {
|
||||
mgr, cleanup := setupTestManager(t)
|
||||
defer cleanup()
|
||||
|
||||
msg, err := mgr.AddMessage("sess-add", RoleUser, "Hello!", map[string]string{"source": "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage failed: %v", err)
|
||||
}
|
||||
if msg.Role != RoleUser {
|
||||
t.Errorf("expected RoleUser, got %s", msg.Role)
|
||||
}
|
||||
if msg.Content != "Hello!" {
|
||||
t.Errorf("expected 'Hello!', got %q", msg.Content)
|
||||
}
|
||||
if msg.Metadata["source"] != "test" {
|
||||
t.Errorf("expected metadata source 'test', got %q", msg.Metadata["source"])
|
||||
}
|
||||
if msg.Timestamp.IsZero() {
|
||||
t.Error("Timestamp should not be zero")
|
||||
}
|
||||
|
||||
// Verify it was persisted
|
||||
messages, _ := mgr.GetContext("sess-add", 10)
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerAddMessageAutoCreatesSession(t *testing.T) {
|
||||
mgr, cleanup := setupTestManager(t)
|
||||
defer cleanup()
|
||||
|
||||
mgr.AddMessage("auto-session", RoleUser, "auto create", nil)
|
||||
|
||||
session, err := mgr.GetSession("auto-session")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession failed: %v", err)
|
||||
}
|
||||
if session.MessageCount() != 1 {
|
||||
t.Errorf("expected 1 message, got %d", session.MessageCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerGetContextWindow(t *testing.T) {
|
||||
mgr, cleanup := setupTestManager(t)
|
||||
defer cleanup()
|
||||
|
||||
// Add 10 messages
|
||||
for i := 0; i < 10; i++ {
|
||||
mgr.AddMessage("window-test", RoleUser, "msg", nil)
|
||||
}
|
||||
|
||||
// Get last 3
|
||||
messages, err := mgr.GetContext("window-test", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("GetContext failed: %v", err)
|
||||
}
|
||||
if len(messages) != 3 {
|
||||
t.Errorf("expected 3 messages, got %d", len(messages))
|
||||
}
|
||||
|
||||
// Get all (window larger than total)
|
||||
all, _ := mgr.GetContext("window-test", 100)
|
||||
if len(all) != 10 {
|
||||
t.Errorf("expected 10 messages, got %d", len(all))
|
||||
}
|
||||
|
||||
// Get with window <= 0
|
||||
all2, _ := mgr.GetContext("window-test", 0)
|
||||
if len(all2) != 10 {
|
||||
t.Errorf("expected 10 messages with window=0, got %d", len(all2))
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerGetContextNonexistent(t *testing.T) {
|
||||
mgr, cleanup := setupTestManager(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := mgr.GetContext("nonexistent", 10)
|
||||
if err == nil {
|
||||
t.Error("expected error getting context for nonexistent session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerArchiveSession(t *testing.T) {
|
||||
mgr, cleanup := setupTestManager(t)
|
||||
defer cleanup()
|
||||
|
||||
mgr.CreateSession("archivable", nil)
|
||||
mgr.AddMessage("archivable", RoleUser, "test", nil)
|
||||
|
||||
if err := mgr.ArchiveSession("archivable"); err != nil {
|
||||
t.Fatalf("ArchiveSession failed: %v", err)
|
||||
}
|
||||
|
||||
session, _ := mgr.GetSession("archivable")
|
||||
if session.Status != SessionArchived {
|
||||
t.Errorf("expected SessionArchived, got %s", session.Status)
|
||||
}
|
||||
if session.IsArchived() != true {
|
||||
t.Error("expected IsArchived to return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerDeleteSession(t *testing.T) {
|
||||
mgr, cleanup := setupTestManager(t)
|
||||
defer cleanup()
|
||||
|
||||
mgr.CreateSession("deletable", nil)
|
||||
if err := mgr.DeleteSession("deletable"); err != nil {
|
||||
t.Fatalf("DeleteSession failed: %v", err)
|
||||
}
|
||||
|
||||
_, err := mgr.GetSession("deletable")
|
||||
if err == nil {
|
||||
t.Error("expected error getting deleted session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerListSessions(t *testing.T) {
|
||||
mgr, cleanup := setupTestManager(t)
|
||||
defer cleanup()
|
||||
|
||||
mgr.AddMessage("list-a", RoleUser, "a", nil)
|
||||
mgr.AddMessage("list-b", RoleUser, "b", nil)
|
||||
|
||||
sessions, err := mgr.ListSessions()
|
||||
if err != nil {
|
||||
t.Fatalf("ListSessions failed: %v", err)
|
||||
}
|
||||
if len(sessions) != 2 {
|
||||
t.Errorf("expected 2 sessions, got %d", len(sessions))
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerMultipleMessagesOrder(t *testing.T) {
|
||||
mgr, cleanup := setupTestManager(t)
|
||||
defer cleanup()
|
||||
|
||||
contents := []string{"first", "second", "third"}
|
||||
for i, c := range contents {
|
||||
mgr.AddMessage("order-test", RoleUser, c, nil)
|
||||
_ = i
|
||||
}
|
||||
|
||||
messages, _ := mgr.GetContext("order-test", 10)
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("expected 3 messages, got %d", len(messages))
|
||||
}
|
||||
if messages[0].Content != "first" {
|
||||
t.Errorf("expected first message content 'first', got %q", messages[0].Content)
|
||||
}
|
||||
if messages[2].Content != "third" {
|
||||
t.Errorf("expected third message content 'third', got %q", messages[2].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerStoreAccess(t *testing.T) {
|
||||
mgr, cleanup := setupTestManager(t)
|
||||
defer cleanup()
|
||||
|
||||
store := mgr.Store()
|
||||
if store == nil {
|
||||
t.Error("Store() should not return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Session Types Tests
|
||||
// ============================================================
|
||||
|
||||
func TestSessionIsArchived(t *testing.T) {
|
||||
s := &Session{Status: SessionActive}
|
||||
if s.IsArchived() {
|
||||
t.Error("active session should not be archived")
|
||||
}
|
||||
|
||||
s.Status = SessionArchived
|
||||
if !s.IsArchived() {
|
||||
t.Error("archived session should be archived")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionMessageCount(t *testing.T) {
|
||||
s := &Session{Messages: make([]SessionMessage, 5)}
|
||||
if n := s.MessageCount(); n != 5 {
|
||||
t.Errorf("expected 5 messages, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageRoleConstants(t *testing.T) {
|
||||
if RoleUser != "user" {
|
||||
t.Errorf("expected RoleUser 'user', got %q", RoleUser)
|
||||
}
|
||||
if RoleAssistant != "assistant" {
|
||||
t.Errorf("expected RoleAssistant 'assistant', got %q", RoleAssistant)
|
||||
}
|
||||
if RoleSystem != "system" {
|
||||
t.Errorf("expected RoleSystem 'system', got %q", RoleSystem)
|
||||
}
|
||||
if RoleTool != "tool" {
|
||||
t.Errorf("expected RoleTool 'tool', got %q", RoleTool)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStatusConstants(t *testing.T) {
|
||||
if SessionActive != "active" {
|
||||
t.Errorf("expected SessionActive 'active', got %q", SessionActive)
|
||||
}
|
||||
if SessionArchived != "archived" {
|
||||
t.Errorf("expected SessionArchived 'archived', got %q", SessionArchived)
|
||||
}
|
||||
}
|
||||
28
pkg/session/store.go
Normal file
28
pkg/session/store.go
Normal file
@ -0,0 +1,28 @@
|
||||
package session
|
||||
|
||||
// Store defines the persistence interface for session message storage.
|
||||
//
|
||||
// Implementations must be safe for concurrent use. The default implementation
|
||||
// uses JSONL files (one file per session) with O(1) append writes.
|
||||
type Store interface {
|
||||
// Save appends a single message to a session's history.
|
||||
// Creates the session file if it does not exist.
|
||||
Save(sessionID string, msg SessionMessage) error
|
||||
|
||||
// Load retrieves all messages for a session in chronological order.
|
||||
// Returns an error if the session does not exist.
|
||||
Load(sessionID string) ([]SessionMessage, error)
|
||||
|
||||
// List returns all known session IDs.
|
||||
List() ([]string, error)
|
||||
|
||||
// Exists checks whether a session exists in the store.
|
||||
Exists(sessionID string) (bool, error)
|
||||
|
||||
// Archive marks a session as archived (read-only).
|
||||
// This is a soft delete that preserves the data.
|
||||
Archive(sessionID string) error
|
||||
|
||||
// Delete removes a session permanently from the store.
|
||||
Delete(sessionID string) error
|
||||
}
|
||||
60
pkg/session/types.go
Normal file
60
pkg/session/types.go
Normal file
@ -0,0 +1,60 @@
|
||||
// Package session provides conversation session management for the Orca framework.
|
||||
//
|
||||
// Sessions persist conversation history and provide context-window-based
|
||||
// retrieval for LLM interactions. The default storage backend uses JSONL
|
||||
// files with O(1) append writes.
|
||||
package session
|
||||
|
||||
import "time"
|
||||
|
||||
// MessageRole identifies who produced a message within a session.
type MessageRole string

// Known message roles.
const (
	// RoleUser marks a message written by a human user.
	RoleUser MessageRole = "user"
	// RoleAssistant marks a message produced by the AI assistant.
	RoleAssistant MessageRole = "assistant"
	// RoleSystem marks a system-level message.
	RoleSystem MessageRole = "system"
	// RoleTool marks the result of a tool execution.
	RoleTool MessageRole = "tool"
)
|
||||
|
||||
// SessionMessage represents a single message entry in a session's history.
|
||||
type SessionMessage struct {
|
||||
Role MessageRole `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// SessionStatus is the lifecycle state of a session.
type SessionStatus string

// Known session statuses.
const (
	// SessionActive marks a session that is active and in use.
	SessionActive SessionStatus = "active"
	// SessionArchived marks a session that has been archived (read-only).
	SessionArchived SessionStatus = "archived"
)
|
||||
|
||||
// Session represents a conversation session with full history.
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
Status SessionStatus `json:"status"`
|
||||
Messages []SessionMessage `json:"messages,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// IsArchived returns true if the session has been archived.
|
||||
func (s *Session) IsArchived() bool {
|
||||
return s.Status == SessionArchived
|
||||
}
|
||||
|
||||
// MessageCount returns the number of messages in the session.
|
||||
func (s *Session) MessageCount() int {
|
||||
return len(s.Messages)
|
||||
}
|
||||
197
pkg/skill/manager.go
Normal file
197
pkg/skill/manager.go
Normal file
@ -0,0 +1,197 @@
|
||||
package skill
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// DefaultSkillsDir is the directory searched for user-installed skills when
// no directory is given explicitly.
const DefaultSkillsDir = "~/.agents/skills"

// SkillManifestFile is the file name of a skill's manifest.
const SkillManifestFile = "SKILL.md"
|
||||
|
||||
// Manager is a thread-safe registry for loading, storing, and querying Skills.
|
||||
//
|
||||
// Skills are loaded from a directory tree where each subdirectory containing
|
||||
// a SKILL.md file is treated as a skill. The Manager automatically discovers
|
||||
// skills on initialization and provides methods for finding skills by trigger
|
||||
// keywords or by name.
|
||||
type Manager struct {
|
||||
mu sync.RWMutex
|
||||
skillsDir string
|
||||
skills map[string]*Skill
|
||||
}
|
||||
|
||||
// NewManager creates a new Skill manager that scans the given directory for skills.
|
||||
// If skillsDir is empty, DefaultSkillsDir is used.
|
||||
func NewManager(skillsDir string) *Manager {
|
||||
if skillsDir == "" {
|
||||
skillsDir = DefaultSkillsDir
|
||||
}
|
||||
// Expand ~ to home directory
|
||||
skillsDir = expandHome(skillsDir)
|
||||
|
||||
return &Manager{
|
||||
skillsDir: skillsDir,
|
||||
skills: make(map[string]*Skill),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadAll scans the skills directory and loads all skills found.
|
||||
// It returns the number of skills loaded and any errors encountered.
|
||||
func (m *Manager) LoadAll() (int, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Clear existing skills
|
||||
m.skills = make(map[string]*Skill)
|
||||
|
||||
// Check if skills directory exists
|
||||
info, err := os.Stat(m.skillsDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return 0, nil // No skills directory yet — not an error
|
||||
}
|
||||
return 0, fmt.Errorf("skill: cannot access skills directory %q: %w", m.skillsDir, err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return 0, fmt.Errorf("skill: %q is not a directory", m.skillsDir)
|
||||
}
|
||||
|
||||
// Read all entries in the skills directory
|
||||
entries, err := os.ReadDir(m.skillsDir)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("skill: failed to read skills directory %q: %w", m.skillsDir, err)
|
||||
}
|
||||
|
||||
var loadErrors []string
|
||||
loaded := 0
|
||||
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
skillDir := filepath.Join(m.skillsDir, entry.Name())
|
||||
skillPath := filepath.Join(skillDir, SkillManifestFile)
|
||||
|
||||
if _, err := os.Stat(skillPath); os.IsNotExist(err) {
|
||||
continue // No SKILL.md in this directory — skip
|
||||
}
|
||||
|
||||
skill, err := ParseSkillFile(skillPath)
|
||||
if err != nil {
|
||||
loadErrors = append(loadErrors, fmt.Sprintf("%s: %v", entry.Name(), err))
|
||||
continue
|
||||
}
|
||||
|
||||
m.skills[skill.Name] = skill
|
||||
loaded++
|
||||
}
|
||||
|
||||
if len(loadErrors) > 0 {
|
||||
return loaded, fmt.Errorf("skill: loaded %d skills with %d errors: %s",
|
||||
loaded, len(loadErrors), joinStrings(loadErrors, "; "))
|
||||
}
|
||||
|
||||
return loaded, nil
|
||||
}
|
||||
|
||||
// GetSkill retrieves a skill by its name. Returns false if not found.
|
||||
func (m *Manager) GetSkill(name string) (*Skill, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
skill, ok := m.skills[name]
|
||||
return skill, ok
|
||||
}
|
||||
|
||||
// ListSkills returns all loaded skills sorted by name.
|
||||
func (m *Manager) ListSkills() []*Skill {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make([]*Skill, 0, len(m.skills))
|
||||
for _, skill := range m.skills {
|
||||
result = append(result, skill)
|
||||
}
|
||||
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
return result[i].Name < result[j].Name
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
// FindSkill finds skills whose triggers match the given query string.
|
||||
// Returns all matching skills sorted by relevance (more trigger matches first).
|
||||
func (m *Manager) FindSkill(query string) []*Skill {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var matches []*Skill
|
||||
for _, skill := range m.skills {
|
||||
if skill.MatchTrigger(query) {
|
||||
matches = append(matches, skill)
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by number of matching triggers (descending)
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
return countMatches(matches[i], query) > countMatches(matches[j], query)
|
||||
})
|
||||
|
||||
return matches
|
||||
}
|
||||
|
||||
// SkillsDir returns the directory being scanned for skills.
|
||||
func (m *Manager) SkillsDir() string {
|
||||
return m.skillsDir
|
||||
}
|
||||
|
||||
// Reload refreshes all skills from disk.
|
||||
func (m *Manager) Reload() (int, error) {
|
||||
return m.LoadAll()
|
||||
}
|
||||
|
||||
// countMatches counts how many of the skill's triggers match the query.
|
||||
func countMatches(skill *Skill, query string) int {
|
||||
count := 0
|
||||
queryLower := strings.ToLower(query)
|
||||
for _, trigger := range skill.Triggers {
|
||||
if strings.Contains(queryLower, strings.ToLower(trigger)) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// expandHome replaces a leading "~" path component with the current user's
// home directory.
//
// Only "~" on its own or "~" followed by a path separator is expanded. Other
// paths, including the "~user" form, are returned unchanged, as is any path
// when the home directory cannot be determined.
func expandHome(path string) string {
	isHomeRelative := path == "~" ||
		strings.HasPrefix(path, "~/") ||
		strings.HasPrefix(path, "~"+string(filepath.Separator))
	if !isHomeRelative {
		return path
	}

	home, err := os.UserHomeDir()
	if err != nil {
		return path
	}
	return filepath.Join(home, path[1:])
}
|
||||
|
||||
// joinStrings concatenates parts, placing sep between consecutive elements.
// An empty or nil slice yields the empty string. It is a thin wrapper around
// strings.Join.
func joinStrings(parts []string, sep string) string {
	return strings.Join(parts, sep)
}
|
||||
309
pkg/skill/manager_test.go
Normal file
309
pkg/skill/manager_test.go
Normal file
@ -0,0 +1,309 @@
|
||||
package skill
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// createTestSkill writes a SKILL.md manifest for a skill called name into
// dir/name and returns the path of the written file. The manifest consists of
// a frontmatter section (name, description, inline triggers list) followed by
// body. Any filesystem error aborts the test.
func createTestSkill(t *testing.T, dir, name, description string, triggers []string, body string) string {
	t.Helper()

	skillDir := filepath.Join(dir, name)
	if err := os.MkdirAll(skillDir, 0755); err != nil {
		t.Fatalf("failed to create skill dir: %v", err)
	}

	// Render the triggers as an inline list of double-quoted items; no
	// triggers yields "[]".
	quoted := make([]string, 0, len(triggers))
	for _, trigger := range triggers {
		quoted = append(quoted, `"`+trigger+`"`)
	}
	triggerStr := "[" + strings.Join(quoted, ", ") + "]"

	var manifest strings.Builder
	manifest.WriteString("---\n")
	manifest.WriteString("name: " + name + "\n")
	manifest.WriteString("description: " + description + "\n")
	manifest.WriteString("triggers: " + triggerStr + "\n")
	manifest.WriteString("---\n\n")
	manifest.WriteString(body)

	manifestPath := filepath.Join(skillDir, "SKILL.md")
	if err := os.WriteFile(manifestPath, []byte(manifest.String()), 0644); err != nil {
		t.Fatalf("failed to write SKILL.md: %v", err)
	}
	return manifestPath
}
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
m := NewManager("")
|
||||
if m == nil {
|
||||
t.Fatal("NewManager() returned nil")
|
||||
}
|
||||
if m.SkillsDir() == "" {
|
||||
t.Error("expected non-empty skills directory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewManagerWithCustomDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
m := NewManager(tmpDir)
|
||||
|
||||
if m.SkillsDir() != tmpDir {
|
||||
t.Errorf("expected skills dir %q, got %q", tmpDir, m.SkillsDir())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAllNoDirectory(t *testing.T) {
|
||||
tmpDir := filepath.Join(t.TempDir(), "nonexistent")
|
||||
m := NewManager(tmpDir)
|
||||
|
||||
count, err := m.LoadAll()
|
||||
if err != nil {
|
||||
t.Fatalf("LoadAll on nonexistent dir should not error: %v", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Errorf("expected 0 skills, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAllWithSkills(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
createTestSkill(t, tmpDir, "skill-a", "Skill A", []string{"alpha", "a"}, "# Skill A\nContent")
|
||||
createTestSkill(t, tmpDir, "skill-b", "Skill B", []string{"beta", "b"}, "# Skill B\nContent")
|
||||
|
||||
m := NewManager(tmpDir)
|
||||
count, err := m.LoadAll()
|
||||
if err != nil {
|
||||
t.Fatalf("LoadAll failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("expected 2 skills, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSkill(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
createTestSkill(t, tmpDir, "test-skill", "Test Skill", []string{"test"}, "# Test")
|
||||
|
||||
m := NewManager(tmpDir)
|
||||
m.LoadAll()
|
||||
|
||||
skill, ok := m.GetSkill("test-skill")
|
||||
if !ok {
|
||||
t.Fatal("GetSkill returned false for existing skill")
|
||||
}
|
||||
if skill.Name != "test-skill" {
|
||||
t.Errorf("expected name 'test-skill', got %q", skill.Name)
|
||||
}
|
||||
if skill.Description != "Test Skill" {
|
||||
t.Errorf("expected description 'Test Skill', got %q", skill.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSkillNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
m := NewManager(tmpDir)
|
||||
m.LoadAll()
|
||||
|
||||
_, ok := m.GetSkill("nonexistent")
|
||||
if ok {
|
||||
t.Error("expected false for nonexistent skill")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListSkills(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
createTestSkill(t, tmpDir, "beta", "Beta", nil, "# Beta")
|
||||
createTestSkill(t, tmpDir, "alpha", "Alpha", nil, "# Alpha")
|
||||
|
||||
m := NewManager(tmpDir)
|
||||
m.LoadAll()
|
||||
|
||||
skills := m.ListSkills()
|
||||
if len(skills) != 2 {
|
||||
t.Errorf("expected 2 skills, got %d", len(skills))
|
||||
}
|
||||
|
||||
// Should be sorted alphabetically
|
||||
if len(skills) >= 2 {
|
||||
if skills[0].Name != "alpha" {
|
||||
t.Errorf("expected first skill 'alpha', got %q", skills[0].Name)
|
||||
}
|
||||
if skills[1].Name != "beta" {
|
||||
t.Errorf("expected second skill 'beta', got %q", skills[1].Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListSkillsEmpty(t *testing.T) {
|
||||
m := NewManager(t.TempDir())
|
||||
m.LoadAll()
|
||||
|
||||
skills := m.ListSkills()
|
||||
if len(skills) != 0 {
|
||||
t.Errorf("expected empty list, got %d skills", len(skills))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSkill(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
createTestSkill(t, tmpDir, "browser", "Browser automation", []string{"browser", "navigate", "screenshot"}, "# Browser")
|
||||
createTestSkill(t, tmpDir, "memory", "Project memory", []string{"memory", "remember"}, "# Memory")
|
||||
createTestSkill(t, tmpDir, "convert", "File converter", []string{"convert", "pdf"}, "# Convert")
|
||||
|
||||
m := NewManager(tmpDir)
|
||||
m.LoadAll()
|
||||
|
||||
// Find by trigger matching "browser"
|
||||
results := m.FindSkill("I need to use the browser to navigate")
|
||||
if len(results) == 0 {
|
||||
t.Fatal("expected at least 1 match for 'browser'")
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, s := range results {
|
||||
if s.Name == "browser" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected 'browser' skill in results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSkillNoMatch(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
createTestSkill(t, tmpDir, "browser", "Browser", []string{"browser"}, "# Browser")
|
||||
|
||||
m := NewManager(tmpDir)
|
||||
m.LoadAll()
|
||||
|
||||
results := m.FindSkill("completely unrelated query")
|
||||
if len(results) != 0 {
|
||||
t.Errorf("expected 0 matches, got %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSkillCaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
createTestSkill(t, tmpDir, "browser", "Browser", []string{"Browser"}, "# Browser")
|
||||
|
||||
m := NewManager(tmpDir)
|
||||
m.LoadAll()
|
||||
|
||||
results := m.FindSkill("browser")
|
||||
if len(results) != 1 {
|
||||
t.Errorf("expected 1 match for lowercase 'browser', got %d", len(results))
|
||||
}
|
||||
|
||||
results = m.FindSkill("BROWSER")
|
||||
if len(results) != 1 {
|
||||
t.Errorf("expected 1 match for uppercase 'BROWSER', got %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSkillRelevanceOrder(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
createTestSkill(t, tmpDir, "multi-match", "Multiple matches", []string{"alpha", "beta", "gamma"}, "# Multi")
|
||||
createTestSkill(t, tmpDir, "single-match", "Single match", []string{"alpha"}, "# Single")
|
||||
|
||||
m := NewManager(tmpDir)
|
||||
m.LoadAll()
|
||||
|
||||
// A query mentioning multiple triggers
|
||||
results := m.FindSkill("alpha beta")
|
||||
if len(results) != 2 {
|
||||
t.Errorf("expected 2 matches, got %d", len(results))
|
||||
}
|
||||
|
||||
// The multi-match skill should come first (more trigger matches)
|
||||
if len(results) >= 2 {
|
||||
if results[0].Name != "multi-match" {
|
||||
t.Errorf("expected 'multi-match' first (more relevance), got %q", results[0].Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReload(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
createTestSkill(t, tmpDir, "initial", "Initial", []string{"init"}, "# Initial")
|
||||
|
||||
m := NewManager(tmpDir)
|
||||
m.LoadAll()
|
||||
|
||||
if len(m.ListSkills()) != 1 {
|
||||
t.Errorf("expected 1 skill after load, got %d", len(m.ListSkills()))
|
||||
}
|
||||
|
||||
// Add another skill
|
||||
createTestSkill(t, tmpDir, "added", "Added later", []string{"new"}, "# New")
|
||||
|
||||
count, err := m.Reload()
|
||||
if err != nil {
|
||||
t.Fatalf("Reload failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("expected 2 skills after reload, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkillMatchTrigger(t *testing.T) {
|
||||
skill := &Skill{
|
||||
Name: "test",
|
||||
Triggers: []string{"browser", "navigate"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
query string
|
||||
want bool
|
||||
}{
|
||||
{"I need to use the browser", true},
|
||||
{"navigate to a page", true},
|
||||
{"Browser automation", true},
|
||||
{"something unrelated", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := skill.MatchTrigger(tt.query)
|
||||
if got != tt.want {
|
||||
t.Errorf("MatchTrigger(%q) = %v, want %v", tt.query, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkillHasScripts(t *testing.T) {
|
||||
s1 := &Skill{Name: "no-scripts"}
|
||||
if s1.HasScripts() {
|
||||
t.Error("expected HasScripts() = false for empty scripts")
|
||||
}
|
||||
|
||||
s2 := &Skill{Name: "has-scripts", Scripts: []string{"script.sh"}}
|
||||
if !s2.HasScripts() {
|
||||
t.Error("expected HasScripts() = true for non-empty scripts")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandHome(t *testing.T) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get home dir: %v", err)
|
||||
}
|
||||
|
||||
result := expandHome("~/test/path")
|
||||
if !strings.HasPrefix(result, home) {
|
||||
t.Errorf("expected path starting with %q, got %q", home, result)
|
||||
}
|
||||
|
||||
// Non-tilde path should not change
|
||||
if expandHome("/absolute/path") != "/absolute/path" {
|
||||
t.Error("absolute path should not be modified")
|
||||
}
|
||||
}
|
||||
246
pkg/skill/parser.go
Normal file
246
pkg/skill/parser.go
Normal file
@ -0,0 +1,246 @@
|
||||
package skill
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// frontmatterDelim is the marker that opens and closes the YAML frontmatter
// section of a SKILL.md file.
const frontmatterDelim = "---"
|
||||
|
||||
// ParseSkillFile parses a SKILL.md file and returns a populated Skill struct.
|
||||
//
|
||||
// The expected format is:
|
||||
//
|
||||
// ---
|
||||
// name: my-skill
|
||||
// description: Does something useful
|
||||
// triggers: ["keyword1", "keyword2"]
|
||||
// ---
|
||||
//
|
||||
// # My Skill
|
||||
//
|
||||
// Detailed description...
|
||||
func ParseSkillFile(path string) (*Skill, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("skill: cannot read %q: %w", path, err)
|
||||
}
|
||||
|
||||
return ParseSkillData(path, data)
|
||||
}
|
||||
|
||||
// ParseSkillData parses SKILL.md content from raw bytes.
|
||||
// The path parameter is used to locate the scripts/ directory.
|
||||
func ParseSkillData(path string, data []byte) (*Skill, error) {
|
||||
content := string(data)
|
||||
|
||||
skill := &Skill{
|
||||
Path: path,
|
||||
Body: content,
|
||||
Triggers: []string{},
|
||||
}
|
||||
|
||||
// Parse YAML frontmatter
|
||||
rest, err := parseFrontmatter(content, skill)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
skill.Body = strings.TrimSpace(rest)
|
||||
|
||||
// Validate required fields
|
||||
if skill.Name == "" {
|
||||
return nil, fmt.Errorf("skill: %q is missing 'name' in frontmatter", path)
|
||||
}
|
||||
|
||||
// Discover scripts directory
|
||||
skillDir := filepath.Dir(path)
|
||||
scriptsDir := filepath.Join(skillDir, "scripts")
|
||||
skill.ScriptsDir = scriptsDir
|
||||
|
||||
if info, err := os.Stat(scriptsDir); err == nil && info.IsDir() {
|
||||
scripts, err := discoverScripts(scriptsDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("skill: failed to discover scripts in %q: %w", scriptsDir, err)
|
||||
}
|
||||
skill.Scripts = scripts
|
||||
}
|
||||
|
||||
return skill, nil
|
||||
}
|
||||
|
||||
// parseFrontmatter extracts YAML frontmatter delimited by "---" lines
|
||||
// and populates the Skill struct fields.
|
||||
func parseFrontmatter(content string, skill *Skill) (string, error) {
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
if !strings.HasPrefix(content, frontmatterDelim) {
|
||||
// No frontmatter — treat entire content as body
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// Find the closing delimiter
|
||||
rest := content[len(frontmatterDelim):]
|
||||
rest = strings.TrimLeft(rest, "\n\r")
|
||||
|
||||
endIdx := strings.Index(rest, "\n"+frontmatterDelim)
|
||||
if endIdx < 0 {
|
||||
// Also check for end-of-file style
|
||||
endIdx = strings.Index(rest, frontmatterDelim)
|
||||
if endIdx < 0 {
|
||||
return "", fmt.Errorf("skill: unclosed frontmatter in skill file")
|
||||
}
|
||||
}
|
||||
|
||||
frontmatter := rest[:endIdx]
|
||||
body := rest[endIdx+len(frontmatterDelim)+1:]
|
||||
|
||||
// Parse the YAML frontmatter (simple key-value parser)
|
||||
if err := parseSimpleYAML(frontmatter, skill); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// parseSimpleYAML parses a simplified YAML format for skill frontmatter.
|
||||
// Supports: string values, quoted strings, and array values.
|
||||
func parseSimpleYAML(yaml string, skill *Skill) error {
|
||||
lines := strings.Split(yaml, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
colonIdx := strings.Index(line, ":")
|
||||
if colonIdx < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(line[:colonIdx])
|
||||
value := strings.TrimSpace(line[colonIdx+1:])
|
||||
|
||||
switch key {
|
||||
case "name":
|
||||
skill.Name = trimQuotes(value)
|
||||
|
||||
case "description":
|
||||
skill.Description = trimQuotes(value)
|
||||
|
||||
case "triggers":
|
||||
triggers, err := parseYAMLArray(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("skill: invalid triggers format: %w", err)
|
||||
}
|
||||
skill.Triggers = triggers
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseYAMLArray parses a YAML array like '["a", "b", "c"]' or '[a, b, c]'.
|
||||
func parseYAMLArray(value string) ([]string, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
// Handle YAML list format: ["a", "b"] or [a, b]
|
||||
if strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") {
|
||||
inner := value[1 : len(value)-1]
|
||||
if strings.TrimSpace(inner) == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
parts := splitCommas(inner)
|
||||
result := make([]string, len(parts))
|
||||
for i, p := range parts {
|
||||
result[i] = trimQuotes(strings.TrimSpace(p))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Handle YAML list format with dashes:
|
||||
// triggers:
|
||||
// - browser
|
||||
// - navigate
|
||||
// (This would be handled line-by-line in a different flow)
|
||||
// For now, treat single value as a one-element list
|
||||
if value != "" && value != "[]" {
|
||||
return []string{trimQuotes(value)}, nil
|
||||
}
|
||||
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// splitCommas breaks s into fields at every comma that is not inside a
// single- or double-quoted section. Quote characters are kept in the returned
// fields. A trailing empty field is not emitted, so an empty input yields nil.
func splitCommas(s string) []string {
	var (
		fields []string
		buf    strings.Builder
		open   byte // quote character currently open; 0 when outside quotes
	)

	for i := 0; i < len(s); i++ {
		ch := s[i]
		switch {
		case open != 0:
			buf.WriteByte(ch)
			if ch == open {
				open = 0
			}
		case ch == '"' || ch == '\'':
			buf.WriteByte(ch)
			open = ch
		case ch == ',':
			fields = append(fields, buf.String())
			buf.Reset()
		default:
			buf.WriteByte(ch)
		}
	}

	if buf.Len() > 0 {
		fields = append(fields, buf.String())
	}
	return fields
}
|
||||
|
||||
// trimQuotes strips surrounding whitespace from s and then, if the result
// begins and ends with the same quote character (' or "), removes that one
// pair of quotes. Otherwise the whitespace-trimmed string is returned.
func trimQuotes(s string) string {
	s = strings.TrimSpace(s)
	if len(s) < 2 {
		return s
	}

	first, last := s[0], s[len(s)-1]
	if first == last && (first == '"' || first == '\'') {
		return s[1 : len(s)-1]
	}
	return s
}
|
||||
|
||||
// discoverScripts returns the sorted names of all non-directory entries found
// directly inside scriptsDir. Subdirectories are skipped and not descended
// into. An error from reading the directory is returned unchanged.
func discoverScripts(scriptsDir string) ([]string, error) {
	entries, err := os.ReadDir(scriptsDir)
	if err != nil {
		return nil, err
	}

	var names []string
	for i := range entries {
		if !entries[i].IsDir() {
			names = append(names, entries[i].Name())
		}
	}

	sort.Strings(names)
	return names, nil
}
|
||||
|
||||
// LoadSkillFromDir loads a skill from a directory containing a SKILL.md file.
|
||||
func LoadSkillFromDir(dir string) (*Skill, error) {
|
||||
skillPath := filepath.Join(dir, "SKILL.md")
|
||||
if _, err := os.Stat(skillPath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("skill: no SKILL.md found in %q", dir)
|
||||
}
|
||||
return ParseSkillFile(skillPath)
|
||||
}
|
||||
56
pkg/skill/skill.go
Normal file
56
pkg/skill/skill.go
Normal file
@ -0,0 +1,56 @@
|
||||
// Package skill provides the Skill definition and management system.
|
||||
//
|
||||
// Skills are composable capabilities loaded from ~/.agents/skills/.
|
||||
// Each skill has a SKILL.md manifest with YAML frontmatter and optional
|
||||
// scripts in a scripts/ subdirectory. Skills can be discovered and
|
||||
// invoked by trigger keywords.
|
||||
package skill
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Skill is a composable capability loaded from the skills directory.
//
// Its metadata (name, description, triggers) comes from the YAML frontmatter
// of a SKILL.md file. The remaining fields carry the `yaml:"-"` tag and are
// therefore not part of the frontmatter mapping.
type Skill struct {
	Name        string   `yaml:"name"`        // unique identifier, e.g. "dev-browser"
	Description string   `yaml:"description"` // human-readable summary of the skill
	Triggers    []string `yaml:"triggers"`    // keywords that activate the skill
	Scripts     []string `yaml:"-"`           // script file names in the scripts/ directory
	ScriptsDir  string   `yaml:"-"`           // absolute path to the scripts/ directory
	Body        string   `yaml:"-"`           // markdown content after the frontmatter
	Path        string   `yaml:"-"`           // absolute path to the SKILL.md file
}
|
||||
|
||||
// MatchTrigger checks if the given query matches any of the skill's triggers.
|
||||
// Matching is case-insensitive and supports partial matches.
|
||||
func (s *Skill) MatchTrigger(query string) bool {
|
||||
query = strings.ToLower(query)
|
||||
for _, trigger := range s.Triggers {
|
||||
if strings.Contains(strings.ToLower(query), strings.ToLower(trigger)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// String returns a human-readable representation of the skill.
|
||||
func (s *Skill) String() string {
|
||||
return fmt.Sprintf("Skill{Name: %q, Triggers: %v, Scripts: %d}", s.Name, s.Triggers, len(s.Scripts))
|
||||
}
|
||||
|
||||
// HasScripts returns true if the skill has at least one script.
|
||||
func (s *Skill) HasScripts() bool {
|
||||
return len(s.Scripts) > 0
|
||||
}
|
||||
433
pkg/tool/builtin.go
Normal file
433
pkg/tool/builtin.go
Normal file
@ -0,0 +1,433 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/orca/orca/pkg/sandbox"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// exec — Execute a shell command via the sandbox
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// execTool runs shell commands through the ProcessSandbox.
|
||||
type execTool struct {
|
||||
sandbox sandbox.Sandbox
|
||||
}
|
||||
|
||||
// NewExecTool creates a new exec tool backed by the given sandbox.
|
||||
func NewExecTool(sb sandbox.Sandbox) Tool {
|
||||
if sb == nil {
|
||||
sb = sandbox.NewProcessSandbox()
|
||||
}
|
||||
return &execTool{sandbox: sb}
|
||||
}
|
||||
|
||||
func (t *execTool) Name() string { return "exec" }
|
||||
|
||||
func (t *execTool) Description() string {
|
||||
return "Execute a shell command and return its output. Use this for running scripts, " +
|
||||
"installing packages, compiling code, or any command-line operation."
|
||||
}
|
||||
|
||||
func (t *execTool) Parameters() map[string]ParameterSchema {
|
||||
return map[string]ParameterSchema{
|
||||
"command": {
|
||||
Type: "string",
|
||||
Description: "The shell command to execute (e.g., 'ls -la' or 'python script.py')",
|
||||
Required: true,
|
||||
},
|
||||
"timeout": {
|
||||
Type: "number",
|
||||
Description: "Timeout in seconds for the command execution (default: 30)",
|
||||
Required: false,
|
||||
Default: float64(30),
|
||||
},
|
||||
"workdir": {
|
||||
Type: "string",
|
||||
Description: "Working directory for the command (default: sandbox default)",
|
||||
Required: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *execTool) Execute(ctx context.Context, args map[string]interface{}) (*Result, error) {
|
||||
cmdStr, ok := args["command"].(string)
|
||||
if !ok || cmdStr == "" {
|
||||
return ErrorResult("'command' argument is required and must be a string"), nil
|
||||
}
|
||||
|
||||
// Use a timeout if specified in args
|
||||
execCtx := ctx
|
||||
if timeoutVal, ok := args["timeout"]; ok {
|
||||
if timeout, err := toFloat64(timeoutVal); err == nil && timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
execCtx, cancel = context.WithTimeout(ctx, time.Duration(timeout*float64(time.Second)))
|
||||
defer cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// Set working directory if specified
|
||||
sb := t.sandbox
|
||||
if wd, ok := args["workdir"].(string); ok && wd != "" {
|
||||
if ps, ok := sb.(*sandbox.ProcessSandbox); ok {
|
||||
ps.WorkingDir = wd
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the command via shell
|
||||
result, err := sb.Execute(execCtx, "sh", "-c", cmdStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("exec tool: %w", err)
|
||||
}
|
||||
|
||||
return &Result{
|
||||
Success: result.ExitCode == 0,
|
||||
Data: map[string]interface{}{
|
||||
"stdout": result.Stdout,
|
||||
"stderr": result.Stderr,
|
||||
"exit_code": result.ExitCode,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// read_file — Read the contents of a file
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type readFileTool struct{}
|
||||
|
||||
func NewReadFileTool() Tool { return &readFileTool{} }
|
||||
|
||||
func (t *readFileTool) Name() string { return "read_file" }
|
||||
|
||||
func (t *readFileTool) Description() string {
|
||||
return "Read the contents of a file from the local filesystem. Returns the file content as a string."
|
||||
}
|
||||
|
||||
func (t *readFileTool) Parameters() map[string]ParameterSchema {
|
||||
return map[string]ParameterSchema{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Absolute path to the file to read",
|
||||
Required: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *readFileTool) Execute(ctx context.Context, args map[string]interface{}) (*Result, error) {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok || path == "" {
|
||||
return ErrorResult("'path' argument is required and must be a string"), nil
|
||||
}
|
||||
|
||||
// Prevent directory traversal / read of non-regular files
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("cannot access %q: %v", path, err)), nil
|
||||
}
|
||||
if info.IsDir() {
|
||||
return ErrorResult(fmt.Sprintf("%q is a directory, not a file", path)), nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to read %q: %v", path, err)), nil
|
||||
}
|
||||
|
||||
return SuccessResult(map[string]interface{}{
|
||||
"path": path,
|
||||
"content": string(data),
|
||||
"size": len(data),
|
||||
}), nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// write_file — Write content to a file
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type writeFileTool struct{}
|
||||
|
||||
func NewWriteFileTool() Tool { return &writeFileTool{} }
|
||||
|
||||
func (t *writeFileTool) Name() string { return "write_file" }
|
||||
|
||||
func (t *writeFileTool) Description() string {
|
||||
return "Write content to a file on the local filesystem. Creates parent directories if needed."
|
||||
}
|
||||
|
||||
func (t *writeFileTool) Parameters() map[string]ParameterSchema {
|
||||
return map[string]ParameterSchema{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Absolute path where the file should be written",
|
||||
Required: true,
|
||||
},
|
||||
"content": {
|
||||
Type: "string",
|
||||
Description: "The content to write to the file",
|
||||
Required: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *writeFileTool) Execute(ctx context.Context, args map[string]interface{}) (*Result, error) {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok || path == "" {
|
||||
return ErrorResult("'path' argument is required and must be a string"), nil
|
||||
}
|
||||
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("'content' argument is required and must be a string"), nil
|
||||
}
|
||||
|
||||
// Create parent directories
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to create directories for %q: %v", path, err)), nil
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to write %q: %v", path, err)), nil
|
||||
}
|
||||
|
||||
return SuccessResult(map[string]interface{}{
|
||||
"path": path,
|
||||
"size": len(content),
|
||||
}), nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// list_dir — List the contents of a directory
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type listDirTool struct{}
|
||||
|
||||
func NewListDirTool() Tool { return &listDirTool{} }
|
||||
|
||||
func (t *listDirTool) Name() string { return "list_dir" }
|
||||
|
||||
func (t *listDirTool) Description() string {
|
||||
return "List files and directories in a given path. Returns names, sizes, and modification times."
|
||||
}
|
||||
|
||||
func (t *listDirTool) Parameters() map[string]ParameterSchema {
|
||||
return map[string]ParameterSchema{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Absolute path to the directory to list",
|
||||
Required: true,
|
||||
},
|
||||
"recursive": {
|
||||
Type: "boolean",
|
||||
Description: "Whether to list recursively (default: false)",
|
||||
Required: false,
|
||||
Default: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *listDirTool) Execute(ctx context.Context, args map[string]interface{}) (*Result, error) {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok || path == "" {
|
||||
return ErrorResult("'path' argument is required and must be a string"), nil
|
||||
}
|
||||
|
||||
recursive, _ := args["recursive"].(bool)
|
||||
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("cannot access %q: %v", path, err)), nil
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return ErrorResult(fmt.Sprintf("%q is not a directory", path)), nil
|
||||
}
|
||||
|
||||
var entries []map[string]interface{}
|
||||
|
||||
if recursive {
|
||||
err = filepath.Walk(path, func(p string, fi os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rel, _ := filepath.Rel(path, p)
|
||||
if rel == "." {
|
||||
return nil
|
||||
}
|
||||
entries = append(entries, entryToMap(p, rel, fi))
|
||||
return nil
|
||||
})
|
||||
} else {
|
||||
files, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to list %q: %v", path, err)), nil
|
||||
}
|
||||
for _, f := range files {
|
||||
fi, err := f.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fullPath := filepath.Join(path, f.Name())
|
||||
entries = append(entries, entryToMap(fullPath, f.Name(), fi))
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to list %q: %v", path, err)), nil
|
||||
}
|
||||
|
||||
return SuccessResult(map[string]interface{}{
|
||||
"path": path,
|
||||
"entries": entries,
|
||||
"count": len(entries),
|
||||
}), nil
|
||||
}
|
||||
|
||||
// entryToMap converts one directory entry into the map form used in list_dir
// results. The map has the keys "name", "path", "size", "is_dir", "mode"
// (the file mode's string form) and "modtime" (formatted with the layout
// "2006-01-02T15:04:05Z07:00").
func entryToMap(fullPath, name string, fi os.FileInfo) map[string]interface{} {
	entry := make(map[string]interface{}, 6)
	entry["name"] = name
	entry["path"] = fullPath
	entry["size"] = fi.Size()
	entry["is_dir"] = fi.IsDir()
	entry["mode"] = fi.Mode().String()
	entry["modtime"] = fi.ModTime().Format("2006-01-02T15:04:05Z07:00")
	return entry
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// search_files — Search for content in files
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type searchFilesTool struct{}
|
||||
|
||||
func NewSearchFilesTool() Tool { return &searchFilesTool{} }
|
||||
|
||||
func (t *searchFilesTool) Name() string { return "search_files" }
|
||||
|
||||
func (t *searchFilesTool) Description() string {
|
||||
return "Search for a pattern in files within a directory. Supports simple substring matching."
|
||||
}
|
||||
|
||||
func (t *searchFilesTool) Parameters() map[string]ParameterSchema {
|
||||
return map[string]ParameterSchema{
|
||||
"pattern": {
|
||||
Type: "string",
|
||||
Description: "The text pattern to search for (substring match)",
|
||||
Required: true,
|
||||
},
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Directory to search in (default: current directory)",
|
||||
Required: false,
|
||||
Default: ".",
|
||||
},
|
||||
"include": {
|
||||
Type: "string",
|
||||
Description: "File glob pattern to include (e.g., '*.go', '*.{ts,tsx}')",
|
||||
Required: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *searchFilesTool) Execute(ctx context.Context, args map[string]interface{}) (*Result, error) {
|
||||
pattern, ok := args["pattern"].(string)
|
||||
if !ok || pattern == "" {
|
||||
return ErrorResult("'pattern' argument is required and must be a string"), nil
|
||||
}
|
||||
|
||||
searchPath := "."
|
||||
if p, ok := args["path"].(string); ok && p != "" {
|
||||
searchPath = p
|
||||
}
|
||||
|
||||
include, _ := args["include"].(string)
|
||||
|
||||
// Verify search path exists
|
||||
info, err := os.Stat(searchPath)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("cannot access search path %q: %v", searchPath, err)), nil
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return ErrorResult(fmt.Sprintf("%q is not a directory", searchPath)), nil
|
||||
}
|
||||
|
||||
var matches []map[string]interface{}
|
||||
|
||||
err = filepath.Walk(searchPath, func(p string, fi os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil // skip files we can't access
|
||||
}
|
||||
if fi.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Apply include filter
|
||||
if include != "" {
|
||||
matched, err := filepath.Match(include, fi.Name())
|
||||
if err != nil || !matched {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Read file and search
|
||||
data, err := os.ReadFile(p)
|
||||
if err != nil {
|
||||
return nil // skip unreadable files
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
if strings.Contains(content, pattern) {
|
||||
matches = append(matches, map[string]interface{}{
|
||||
"path": p,
|
||||
"size": len(data),
|
||||
})
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("search failed: %v", err)), nil
|
||||
}
|
||||
|
||||
return SuccessResult(map[string]interface{}{
|
||||
"pattern": pattern,
|
||||
"path": searchPath,
|
||||
"matches": matches,
|
||||
"count": len(matches),
|
||||
}), nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// toFloat64 converts a numeric interface{} value to float64.
//
// It accepts every built-in integer and floating-point type as well as
// json.Number. Any other type, including nil and string, yields an error that
// names the offending type.
func toFloat64(v interface{}) (float64, error) {
	switch val := v.(type) {
	case float64:
		return val, nil
	case float32:
		return float64(val), nil
	case int:
		return float64(val), nil
	case int8:
		return float64(val), nil
	case int16:
		return float64(val), nil
	case int32:
		return float64(val), nil
	case int64:
		return float64(val), nil
	case uint:
		return float64(val), nil
	case uint8:
		return float64(val), nil
	case uint16:
		return float64(val), nil
	case uint32:
		return float64(val), nil
	case uint64:
		return float64(val), nil
	case json.Number:
		return val.Float64()
	default:
		return 0, fmt.Errorf("cannot convert %T to float64", v)
	}
}
|
||||
|
||||
// Compile-time interface checks.
|
||||
var _ Tool = (*execTool)(nil)
|
||||
var _ Tool = (*readFileTool)(nil)
|
||||
var _ Tool = (*writeFileTool)(nil)
|
||||
var _ Tool = (*listDirTool)(nil)
|
||||
var _ Tool = (*searchFilesTool)(nil)
|
||||
399
pkg/tool/builtin_test.go
Normal file
399
pkg/tool/builtin_test.go
Normal file
@ -0,0 +1,399 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/orca/orca/pkg/sandbox"
|
||||
)
|
||||
|
||||
func TestExecTool(t *testing.T) {
|
||||
sb := sandbox.NewProcessSandbox()
|
||||
execT := NewExecTool(sb)
|
||||
|
||||
if execT.Name() != "exec" {
|
||||
t.Errorf("expected name 'exec', got %q", execT.Name())
|
||||
}
|
||||
if execT.Description() == "" {
|
||||
t.Error("expected non-empty description")
|
||||
}
|
||||
|
||||
params := execT.Parameters()
|
||||
if _, ok := params["command"]; !ok {
|
||||
t.Error("expected 'command' parameter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecToolExecute(t *testing.T) {
|
||||
sb := sandbox.NewProcessSandbox()
|
||||
execT := NewExecTool(sb)
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := execT.Execute(ctx, map[string]interface{}{
|
||||
"command": "echo hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
if !result.Success {
|
||||
t.Errorf("expected success, got error: %s", result.Error)
|
||||
}
|
||||
|
||||
data := result.Data.(map[string]interface{})
|
||||
stdout := data["stdout"].(string)
|
||||
if strings.TrimSpace(stdout) != "hello" {
|
||||
t.Errorf("expected stdout 'hello', got %q", stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecToolMissingCommand(t *testing.T) {
|
||||
sb := sandbox.NewProcessSandbox()
|
||||
execT := NewExecTool(sb)
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := execT.Execute(ctx, map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute should not error for invalid args: %v", err)
|
||||
}
|
||||
if result.Success {
|
||||
t.Error("expected failure for missing command")
|
||||
}
|
||||
if !strings.Contains(result.Error, "command") {
|
||||
t.Errorf("error should mention 'command', got: %s", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileTool(t *testing.T) {
|
||||
readT := NewReadFileTool()
|
||||
|
||||
if readT.Name() != "read_file" {
|
||||
t.Errorf("expected name 'read_file', got %q", readT.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileToolExecute(t *testing.T) {
|
||||
// Create a temp file
|
||||
tmpFile, err := os.CreateTemp("", "orca-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp file: %v", err)
|
||||
}
|
||||
content := "test content\nline 2\n"
|
||||
if _, err := tmpFile.WriteString(content); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
readT := NewReadFileTool()
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := readT.Execute(ctx, map[string]interface{}{
|
||||
"path": tmpFile.Name(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
if !result.Success {
|
||||
t.Errorf("expected success, got error: %s", result.Error)
|
||||
}
|
||||
|
||||
data := result.Data.(map[string]interface{})
|
||||
gotContent := data["content"].(string)
|
||||
if gotContent != content {
|
||||
t.Errorf("expected content %q, got %q", content, gotContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileToolMissingPath(t *testing.T) {
|
||||
readT := NewReadFileTool()
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := readT.Execute(ctx, map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute should not error for invalid args: %v", err)
|
||||
}
|
||||
if result.Success {
|
||||
t.Error("expected failure for missing path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileToolNonexistent(t *testing.T) {
|
||||
readT := NewReadFileTool()
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := readT.Execute(ctx, map[string]interface{}{
|
||||
"path": "/nonexistent/path/that/does/not/exist.txt",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute should not error for missing file: %v", err)
|
||||
}
|
||||
if result.Success {
|
||||
t.Error("expected failure for nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteFileTool(t *testing.T) {
|
||||
writeT := NewWriteFileTool()
|
||||
|
||||
if writeT.Name() != "write_file" {
|
||||
t.Errorf("expected name 'write_file', got %q", writeT.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteFileToolExecute(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "orca-write-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
testPath := filepath.Join(tmpDir, "nested", "test.txt")
|
||||
content := "hello world"
|
||||
|
||||
writeT := NewWriteFileTool()
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := writeT.Execute(ctx, map[string]interface{}{
|
||||
"path": testPath,
|
||||
"content": content,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
if !result.Success {
|
||||
t.Errorf("expected success, got error: %s", result.Error)
|
||||
}
|
||||
|
||||
// Verify the file was written
|
||||
data, err := os.ReadFile(testPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read written file: %v", err)
|
||||
}
|
||||
if string(data) != content {
|
||||
t.Errorf("expected content %q, got %q", content, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteFileToolMissingArgs(t *testing.T) {
|
||||
writeT := NewWriteFileTool()
|
||||
ctx := context.Background()
|
||||
|
||||
// Missing path
|
||||
result, err := writeT.Execute(ctx, map[string]interface{}{
|
||||
"content": "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute should not error for invalid args: %v", err)
|
||||
}
|
||||
if result.Success {
|
||||
t.Error("expected failure for missing path")
|
||||
}
|
||||
|
||||
// Missing content
|
||||
result, err = writeT.Execute(ctx, map[string]interface{}{
|
||||
"path": "/tmp/test.txt",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute should not error for invalid args: %v", err)
|
||||
}
|
||||
if result.Success {
|
||||
t.Error("expected failure for missing content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDirTool(t *testing.T) {
|
||||
listT := NewListDirTool()
|
||||
|
||||
if listT.Name() != "list_dir" {
|
||||
t.Errorf("expected name 'list_dir', got %q", listT.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDirToolExecute(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "orca-list-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create some test files
|
||||
os.WriteFile(filepath.Join(tmpDir, "a.txt"), []byte("a"), 0644)
|
||||
os.WriteFile(filepath.Join(tmpDir, "b.txt"), []byte("bb"), 0644)
|
||||
os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755)
|
||||
|
||||
listT := NewListDirTool()
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := listT.Execute(ctx, map[string]interface{}{
|
||||
"path": tmpDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
if !result.Success {
|
||||
t.Errorf("expected success, got error: %s", result.Error)
|
||||
}
|
||||
|
||||
data := result.Data.(map[string]interface{})
|
||||
count := data["count"].(int)
|
||||
if count != 3 {
|
||||
t.Errorf("expected 3 entries, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDirToolRecursive(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "orca-list-rec-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
os.MkdirAll(filepath.Join(tmpDir, "a", "b"), 0755)
|
||||
os.WriteFile(filepath.Join(tmpDir, "a", "b", "c.txt"), []byte("c"), 0644)
|
||||
|
||||
listT := NewListDirTool()
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := listT.Execute(ctx, map[string]interface{}{
|
||||
"path": tmpDir,
|
||||
"recursive": true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
if !result.Success {
|
||||
t.Errorf("expected success, got error: %s", result.Error)
|
||||
}
|
||||
|
||||
data := result.Data.(map[string]interface{})
|
||||
entries := data["entries"].([]map[string]interface{})
|
||||
if len(entries) < 2 {
|
||||
t.Errorf("expected at least 2 recursive entries, got %d", len(entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDirToolNonexistent(t *testing.T) {
|
||||
listT := NewListDirTool()
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := listT.Execute(ctx, map[string]interface{}{
|
||||
"path": "/nonexistent/path",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute should not error for missing path: %v", err)
|
||||
}
|
||||
if result.Success {
|
||||
t.Error("expected failure for nonexistent path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchFilesTool(t *testing.T) {
|
||||
searchT := NewSearchFilesTool()
|
||||
|
||||
if searchT.Name() != "search_files" {
|
||||
t.Errorf("expected name 'search_files', got %q", searchT.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchFilesToolExecute(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "orca-search-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create test files
|
||||
os.WriteFile(filepath.Join(tmpDir, "findme.go"), []byte("package main\nfunc hello() {\n}\n"), 0644)
|
||||
os.WriteFile(filepath.Join(tmpDir, "other.py"), []byte("def world():\n pass\n"), 0644)
|
||||
|
||||
searchT := NewSearchFilesTool()
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := searchT.Execute(ctx, map[string]interface{}{
|
||||
"pattern": "hello",
|
||||
"path": tmpDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
if !result.Success {
|
||||
t.Errorf("expected success, got error: %s", result.Error)
|
||||
}
|
||||
|
||||
data := result.Data.(map[string]interface{})
|
||||
count := data["count"].(int)
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 match for 'hello', got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchFilesToolNoMatch(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "orca-search-nomatch-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte("nothing here"), 0644)
|
||||
|
||||
searchT := NewSearchFilesTool()
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := searchT.Execute(ctx, map[string]interface{}{
|
||||
"pattern": "nonexistent-pattern-xyz",
|
||||
"path": tmpDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
if !result.Success {
|
||||
t.Errorf("expected success even with no matches, got error: %s", result.Error)
|
||||
}
|
||||
|
||||
data := result.Data.(map[string]interface{})
|
||||
count := data["count"].(int)
|
||||
if count != 0 {
|
||||
t.Errorf("expected 0 matches, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchFilesToolMissingPattern(t *testing.T) {
|
||||
searchT := NewSearchFilesTool()
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := searchT.Execute(ctx, map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute should not error for invalid args: %v", err)
|
||||
}
|
||||
if result.Success {
|
||||
t.Error("expected failure for missing pattern")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolInterfaceSatisfied(t *testing.T) {
|
||||
sb := sandbox.NewProcessSandbox()
|
||||
|
||||
tools := []Tool{
|
||||
NewExecTool(sb),
|
||||
NewReadFileTool(),
|
||||
NewWriteFileTool(),
|
||||
NewListDirTool(),
|
||||
NewSearchFilesTool(),
|
||||
}
|
||||
|
||||
names := []string{"exec", "read_file", "write_file", "list_dir", "search_files"}
|
||||
for i, tool := range tools {
|
||||
if tool.Name() != names[i] {
|
||||
t.Errorf("expected name %q, got %q", names[i], tool.Name())
|
||||
}
|
||||
if tool.Description() == "" {
|
||||
t.Errorf("tool %q has empty description", names[i])
|
||||
}
|
||||
if tool.Parameters() == nil {
|
||||
t.Errorf("tool %q has nil parameters", names[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
108
pkg/tool/manager.go
Normal file
108
pkg/tool/manager.go
Normal file
@ -0,0 +1,108 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Manager is a thread-safe registry that manages tool registration and execution.
|
||||
//
|
||||
// Tools are registered by name (case-sensitive) and can be discovered,
|
||||
// listed, and invoked through the Manager. Duplicate registration returns
|
||||
// an error.
|
||||
type Manager struct {
|
||||
mu sync.RWMutex
|
||||
tools map[string]Tool
|
||||
}
|
||||
|
||||
// NewManager creates a new empty tool manager.
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
tools: make(map[string]Tool),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a tool to the manager. Returns an error if a tool with the
|
||||
// same name is already registered.
|
||||
func (m *Manager) Register(tool Tool) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
name := tool.Name()
|
||||
if _, exists := m.tools[name]; exists {
|
||||
return fmt.Errorf("tool %q is already registered", name)
|
||||
}
|
||||
|
||||
m.tools[name] = tool
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unregister removes a tool from the manager by name.
|
||||
func (m *Manager) Unregister(name string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.tools[name]; !exists {
|
||||
return fmt.Errorf("tool %q is not registered", name)
|
||||
}
|
||||
|
||||
delete(m.tools, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a tool by name. Returns false if not found.
|
||||
func (m *Manager) Get(name string) (Tool, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
t, ok := m.tools[name]
|
||||
return t, ok
|
||||
}
|
||||
|
||||
// List returns all registered tools sorted by name.
|
||||
func (m *Manager) List() []Tool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make([]Tool, 0, len(m.tools))
|
||||
for _, t := range m.tools {
|
||||
result = append(result, t)
|
||||
}
|
||||
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
return result[i].Name() < result[j].Name()
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
// Execute looks up a tool by name and invokes it with the given arguments.
|
||||
// Returns an error if the tool is not found.
|
||||
func (m *Manager) Execute(name string, ctx context.Context, args map[string]interface{}) (*Result, error) {
|
||||
tool, ok := m.Get(name)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tool %q not found", name)
|
||||
}
|
||||
return tool.Execute(ctx, args)
|
||||
}
|
||||
|
||||
// Count returns the number of registered tools.
|
||||
func (m *Manager) Count() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.tools)
|
||||
}
|
||||
|
||||
// Names returns the names of all registered tools sorted alphabetically.
|
||||
func (m *Manager) Names() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(m.tools))
|
||||
for name := range m.tools {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
81
pkg/tool/tool.go
Normal file
81
pkg/tool/tool.go
Normal file
@ -0,0 +1,81 @@
|
||||
// Package tool defines the Tool interface and the tool management system.
|
||||
//
|
||||
// Tools are the atomic capabilities that can be invoked by agents or LLMs.
|
||||
// Each tool has a name, description, a parameter schema (for LLM function calling),
|
||||
// and an Execute method that performs the actual work.
|
||||
//
|
||||
// Built-in tools include file operations (read, write, list, search) and
|
||||
// command execution through the sandbox. Custom tools can be registered
|
||||
// via the Manager.
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// ParameterSchema is the description of one argument a tool accepts. It is a
// recursive structure: object-like parameters list their fields in
// Properties and array-like parameters describe their element in Items.
type ParameterSchema struct {
	// Type is the parameter's type name; always present in the JSON form.
	Type string `json:"type"`
	// Description is the human-readable explanation of the parameter.
	Description string `json:"description"`
	// Required marks whether the argument must be supplied.
	Required bool `json:"required"`
	// Default is an optional default value; omitted from JSON when nil.
	Default interface{} `json:"default,omitempty"`
	// Properties holds nested parameter schemas keyed by name.
	Properties map[string]ParameterSchema `json:"properties,omitempty"`
	// Items is the nested schema for elements; omitted from JSON when nil.
	Items *ParameterSchema `json:"items,omitempty"`
	// Enum optionally lists the permitted string values.
	Enum []string `json:"enum,omitempty"`
}
|
||||
|
||||
// Result is what a tool hands back after running.
type Result struct {
	// Success is true when the tool completed its work.
	Success bool `json:"success"`
	// Data carries the tool's payload; omitted from JSON when nil.
	Data interface{} `json:"data,omitempty"`
	// Error is the failure message; omitted from JSON when empty.
	Error string `json:"error,omitempty"`
	// Metadata holds optional extra key/value information about the run.
	Metadata map[string]interface{} `json:"metadata,omitempty"`
}
|
||||
|
||||
// Tool defines the interface that all tools must implement.
|
||||
//
|
||||
// Tools are registered with a Manager and can be discovered and invoked
|
||||
// by name. The Execute method receives a context for cancellation and
|
||||
// a map of string-keyed arguments.
|
||||
type Tool interface {
|
||||
// Name returns the unique identifier for this tool.
|
||||
Name() string
|
||||
|
||||
// Description returns a human-readable description of what this tool does.
|
||||
Description() string
|
||||
|
||||
// Parameters returns the schema describing accepted arguments.
|
||||
// Used for LLM function calling and validation.
|
||||
Parameters() map[string]ParameterSchema
|
||||
|
||||
// Execute performs the tool's function with the given arguments.
|
||||
// The context controls cancellation and timeouts.
|
||||
Execute(ctx context.Context, args map[string]interface{}) (*Result, error)
|
||||
}
|
||||
|
||||
// SuccessResult creates a successful tool result with the given data.
|
||||
func SuccessResult(data interface{}) *Result {
|
||||
return &Result{
|
||||
Success: true,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorResult creates a failed tool result with the given error message.
|
||||
func ErrorResult(err string) *Result {
|
||||
return &Result{
|
||||
Success: false,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
// MustMarshalArgs encodes args as JSON and returns the bytes. Intended as a
// convenience for logging and debugging.
//
// It panics, with a message prefixed "tool: failed to marshal args: ", if
// json.Marshal reports an error.
func MustMarshalArgs(args map[string]interface{}) []byte {
	encoded, err := json.Marshal(args)
	if err == nil {
		return encoded
	}
	panic("tool: failed to marshal args: " + err.Error())
}
|
||||
@ -0,0 +1,416 @@
|
||||
---
|
||||
date: 2026-05-07
|
||||
topic: "Go Agent Framework - Orca"
|
||||
status: validated
|
||||
---
|
||||
|
||||
# Go Agent Framework (Orca) 设计文档
|
||||
|
||||
## Problem Statement
|
||||
|
||||
构建一个基于 Go 的基础 Agent 框架,支持多 Agent 协作、持久化会话记忆、Skill 技能自动识别、沙箱安全执行、自定义 Tool 注册扩展,并接入 Ollama 本地模型(gemma4:e4b)。
|
||||
|
||||
**核心挑战:**
|
||||
- 如何在 Go 中实现轻量、高并发的多 Agent 系统
|
||||
- 如何安全地执行用户命令和 Skill 脚本
|
||||
- 如何设计可扩展的插件机制(Skill / Tool)
|
||||
- 如何管理会话上下文和记忆
|
||||
|
||||
## Constraints
|
||||
|
||||
1. **语言约束:** 纯 Go 实现,最小化外部依赖
|
||||
2. **存储约束:** 使用 JSON Lines(无 SQLite/数据库依赖)
|
||||
3. **隔离约束:** 进程级限制(chroot + 资源限制),不依赖 Docker
|
||||
4. **模型约束:** 仅接入 Ollama 本地模型,默认 gemma4:e4b
|
||||
5. **Skill 目录:** 读取 `~/.agents/skills/` 下的 Skill 定义
|
||||
6. **部署约束:** 单二进制文件,零配置启动
|
||||
|
||||
## Approach
|
||||
|
||||
### 架构风格:微内核 + Actor 模型
|
||||
|
||||
采用**微内核架构**作为基础,所有功能(Skill、Tool、LLM 驱动)都以**插件**形式注册到核心。
|
||||
|
||||
每个 **Agent 实例是一个独立的 Actor**,通过 **消息总线(Message Bus)** 进行通信。这完美契合 Go 的 goroutine + channel 并发模型。
|
||||
|
||||
**为什么选择这个组合?**
|
||||
- 微内核保证核心最小化,Skill 和 Tool 热插拔
|
||||
- Actor 模型天然支持高并发,避免共享状态
|
||||
- 两者结合 = 轻量级、高扩展、Go 原生友好
|
||||
|
||||
**放弃的其他方案:**
|
||||
- Docker 沙箱:太重,违背最小依赖原则
|
||||
- SQLite 存储:增加依赖,JSONL 已足够
|
||||
- 中央协调器:单点瓶颈,不如 Actor 模型灵活
|
||||
|
||||
## Architecture
|
||||
|
||||
### 整体架构图
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ CLI / API Layer │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Core Kernel (微内核) │
|
||||
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────────┐ │
|
||||
│ │ Message Bus │ │ Plugin Reg │ │ Session Manager │ │
|
||||
│ │ (channel) │ │ (registry) │ │ (JSONL-based) │ │
|
||||
│ └──────────────┘ └──────────────┘ └──────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
┌───────────────┼───────────────┐
|
||||
▼ ▼ ▼
|
||||
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||
│ Agent Actor │ │ Agent Actor │ │ Agent Actor │
|
||||
│ (Specialist 1) │ │ (Specialist 2) │ │ (Orchestrator) │
|
||||
└────────┬────────┘ └────────┬────────┘ └────────┬────────┘
|
||||
│ │ │
|
||||
└───────────────────┼───────────────────┘
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Plugin Layer │
|
||||
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────────┐│
|
||||
│ │ Skills │ │ Tools │ │ Ollama │ │ Custom Tools ││
|
||||
│ │(Skill Mgr)│ │(Tool Mgr)│ │ (Driver) │ │ (Registry) ││
|
||||
│ └──────────┘ └──────────┘ └──────────┘ └──────────────┘│
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Sandbox Layer │
|
||||
│ (Process-level isolation + Resource limits) │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 模块职责
|
||||
|
||||
| 模块 | 职责 |
|
||||
|------|------|
|
||||
| **Core Kernel** | 消息路由、插件生命周期管理、会话协调 |
|
||||
| **Message Bus** | 基于 Go channel 的异步消息传递系统 |
|
||||
| **Plugin Registry** | 统一的 Skill/Tool/LLM 驱动注册中心 |
|
||||
| **Session Manager** | 基于 JSONL 的会话历史读写和上下文窗口管理 |
|
||||
| **Agent Actor** | 独立 goroutine,持有状态,接收/发送消息 |
|
||||
| **Skill Manager** | 扫描 `~/.agents/skills/`,解析 SKILL.md,加载技能 |
|
||||
| **Tool Manager** | 管理内置工具和自定义工具的注册/调用 |
|
||||
| **Ollama Driver** | 封装 Ollama HTTP API,支持流式响应 |
|
||||
| **Sandbox** | 安全执行 shell 命令和脚本,限制资源和时间 |
|
||||
|
||||
## Components
|
||||
|
||||
### 1. Core Kernel (微内核)
|
||||
|
||||
**职责:** 框架的最小化核心,只负责消息路由和插件生命周期。
|
||||
|
||||
**设计要点:**
|
||||
- 使用 Go 的接口(`interface`)或泛型定义插件契约
|
||||
- 启动时加载所有已注册的插件
|
||||
- 提供事件总线供插件间通信
|
||||
- **不**包含任何业务逻辑(如 LLM 调用、命令执行)
|
||||
|
||||
**核心接口:**
|
||||
```
|
||||
// 所有插件必须实现
|
||||
Plugin interface {
|
||||
Name() string
|
||||
Init(kernel *Kernel) error
|
||||
Shutdown() error
|
||||
}
|
||||
|
||||
// 消息总线
|
||||
MessageBus interface {
|
||||
Publish(topic string, msg Message) error
|
||||
Subscribe(topic string, handler Handler) (Subscription, error)
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Actor System (多 Agent 引擎)
|
||||
|
||||
**职责:** 管理 Agent 生命周期和消息通信。
|
||||
|
||||
**设计要点:**
|
||||
- 每个 Agent 是一个独立的 goroutine,通过 channel 接收消息
|
||||
- Agent 持有自己的状态(角色、上下文、工具列表)
|
||||
- 支持三种 Agent 类型:Orchestrator(协调者)、Worker(执行者)、Specialist(专家)
|
||||
- 消息类型:`TaskRequest`、`TaskResponse`、`ToolCall`、`Observation`
|
||||
|
||||
**Agent 状态机:**
|
||||
```
|
||||
Idle → Processing → [ToolCall] → WaitingForTool → Processing → Completed
|
||||
↓
|
||||
[Error] → Failed
|
||||
```
|
||||
|
||||
### 3. Session Manager (会话记忆)
|
||||
|
||||
**职责:** 持久化会话历史,支持上下文窗口管理。
|
||||
|
||||
**设计要点:**
|
||||
- 每个会话一个 JSONL 文件:`~/.orca/sessions/{session_id}.jsonl`
|
||||
- 每行一个 JSON 对象:`{role, content, timestamp, metadata}`
|
||||
- 提供 `GetContext(windowSize)` 方法,返回最近的 N 条消息
|
||||
- 支持会话列表、搜索、归档
|
||||
|
||||
**为什么 JSON Lines?**
|
||||
- 追加写入 O(1),无需加载整个文件
|
||||
- 人类可读,便于调试
|
||||
- 零依赖,无需数据库驱动
|
||||
- 通过简单文件锁保证并发安全
|
||||
|
||||
### 4. Skill Manager (技能系统)
|
||||
|
||||
**职责:** 自动发现和加载 Skill。
|
||||
|
||||
**设计要点:**
|
||||
- 启动时扫描 `~/.agents/skills/` 下的所有子目录
|
||||
- 解析每个 Skill 目录下的 `SKILL.md`
|
||||
- 提取元数据:`name`、`description`、`triggers`(触发词)
|
||||
- Skill 可以包含脚本文件(`scripts/` 目录)
|
||||
- 提供 `FindSkill(query string)` 方法,基于触发词匹配
|
||||
|
||||
**Skill 结构:**
|
||||
```yaml
|
||||
name: "md2pdf"
|
||||
description: "Convert Markdown to PDF..."
|
||||
triggers: ["pdf", "markdown", "export"]
|
||||
scripts:
|
||||
- "scripts/convert.py"
|
||||
- "scripts/setup.sh"
|
||||
```
|
||||
|
||||
### 5. Tool Manager (工具系统)
|
||||
|
||||
**职责:** 管理可执行工具的注册和调用。
|
||||
|
||||
**设计要点:**
|
||||
- **内置工具:** `exec`(执行命令)、`read_file`、`write_file`、`list_dir`、`search_files`
|
||||
- **Skill 工具:** 从 Skill 的 `scripts/` 目录自动注册
|
||||
- **自定义工具:** 通过代码注册,实现 `Tool` 接口
|
||||
- 每个工具定义:名称、描述、参数 schema、执行函数
|
||||
- LLM 通过 Function Calling 调用工具
|
||||
|
||||
**Tool 接口:**
|
||||
```
|
||||
Tool interface {
|
||||
Name() string
|
||||
Description() string
|
||||
Parameters() JSONSchema
|
||||
Execute(ctx Context, args map[string]any) (Result, error)
|
||||
}
|
||||
```
|
||||
|
||||
### 6. Ollama Driver (LLM 驱动)
|
||||
|
||||
**职责:** 封装 Ollama API,提供统一的 LLM 调用接口。
|
||||
|
||||
**设计要点:**
|
||||
- 默认模型:`gemma4:e4b`
|
||||
- 支持流式响应(Ollama 以换行分隔的 JSON,即 NDJSON,逐块返回)
|
||||
- 支持 Function Calling(通过 tools 参数)
|
||||
- 自动处理上下文窗口截断
|
||||
- 可配置参数:temperature、top_p、max_tokens
|
||||
|
||||
**API 封装:**
|
||||
```
|
||||
LLMClient interface {
|
||||
Chat(messages []Message, tools []Tool) (Response, error)
|
||||
ChatStream(messages []Message, tools []Tool) (Stream, error)
|
||||
}
|
||||
```
|
||||
|
||||
### 7. Sandbox (沙箱执行)
|
||||
|
||||
**职责:** 安全地执行终端命令和脚本。
|
||||
|
||||
**设计要点:**
|
||||
- 使用 `os/exec` 创建子进程
|
||||
- 资源限制:CPU 时间、内存、输出大小
|
||||
- 超时控制:默认 30 秒,可配置
|
||||
- 工作目录限制:可选 chroot 或指定工作目录
|
||||
- 环境变量隔离:只允许白名单环境变量
|
||||
- **不**使用 Docker,保持轻量
|
||||
|
||||
**安全策略:**
|
||||
```yaml
|
||||
sandbox:
|
||||
timeout: 30s
|
||||
max_memory: 512MB
|
||||
max_output: 64KB
|
||||
allowed_env: [PATH, HOME, USER]
|
||||
working_dir: /tmp/orca-sandbox
|
||||
read_only_dirs: []
|
||||
blocked_commands: [rm -rf /, mkfs, dd]
|
||||
```
|
||||
|
||||
## Data Flow
|
||||
|
||||
### 典型交互流程
|
||||
|
||||
```
|
||||
用户输入
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ CLI/API │
|
||||
└──────┬──────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐ ┌─────────────┐
|
||||
│ Session Mgr │────▶│ 加载历史上下文 │
|
||||
└──────┬──────┘ └─────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ Orchestrator │ (Agent Actor)
|
||||
│ Agent │
|
||||
└──────┬──────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐ ┌─────────────┐
|
||||
│ Skill Mgr │────▶│ 匹配相关 Skill │
|
||||
└──────┬──────┘ └─────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐ ┌─────────────┐
|
||||
│ Ollama Driver│────▶│ 发送 prompt │
|
||||
└──────┬──────┘ └─────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ LLM Response │
|
||||
│ (Function │
|
||||
│ Calling) │
|
||||
└──────┬──────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐ ┌─────────────┐
|
||||
│ Tool Call │────▶│ 执行 Tool/ │
|
||||
│ │ │ 沙箱命令 │
|
||||
└──────┬──────┘ └─────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ Observation │ (工具执行结果)
|
||||
└──────┬──────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐ ┌─────────────┐
|
||||
│ Orchestrator │────▶│ 决策:继续/完成 │
|
||||
└──────┬──────┘ └─────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ 保存会话 │
|
||||
│ 返回结果 │
|
||||
└─────────────┘
|
||||
```
|
||||
|
||||
### 消息类型定义
|
||||
|
||||
```go
|
||||
type Message struct {
|
||||
ID string
|
||||
Type MessageType // TaskRequest, TaskResponse, ToolCall, Observation, Error
|
||||
From string // Agent ID
|
||||
To string // Agent ID or "broadcast"
|
||||
Content interface{} // 根据 Type 不同而变化
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
type TaskRequest struct {
|
||||
Query string
|
||||
SessionID string
|
||||
Context []ChatMessage
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
ToolName string
|
||||
Arguments map[string]interface{}
|
||||
}
|
||||
|
||||
type Observation struct {
|
||||
ToolCallID string
|
||||
Output string
|
||||
Error string
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
### 策略
|
||||
|
||||
1. **分层错误处理:**
|
||||
- **Kernel 层:** 插件加载失败 → 记录日志,跳过该插件,继续启动
|
||||
- **Agent 层:** 任务执行失败 → 返回错误消息,让 Orchestrator 决策重试或终止
|
||||
- **Tool 层:** 工具执行失败 → 返回结构化错误,LLM 可据此调整策略
|
||||
- **Sandbox 层:** 命令超时/内存超限 → 强制终止进程,返回错误
|
||||
|
||||
2. **重试机制:**
|
||||
- LLM API 调用:指数退避重试 3 次
|
||||
- 工具执行:不重试(避免循环),由 LLM 决策
|
||||
|
||||
3. **优雅降级:**
|
||||
- Ollama 不可用 → 提示用户检查服务
|
||||
- Skill 解析失败 → 跳过该 Skill,不影响其他
|
||||
- 沙箱执行失败 → 返回错误信息,LLM 可尝试其他工具
|
||||
|
||||
### 错误类型
|
||||
|
||||
```go
|
||||
type ErrorCategory int
|
||||
|
||||
const (
|
||||
ErrCategoryKernel ErrorCategory = iota // 内核错误
|
||||
ErrCategoryAgent // Agent 错误
|
||||
ErrCategoryTool // 工具错误
|
||||
ErrCategorySandbox // 沙箱错误
|
||||
ErrCategoryLLM // LLM 错误
|
||||
ErrCategoryNetwork // 网络错误
|
||||
)
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### 测试金字塔
|
||||
|
||||
1. **单元测试(60%):**
|
||||
- `Kernel`:插件注册/卸载、消息路由
|
||||
- `SessionManager`:JSONL 读写、上下文窗口截断
|
||||
- `SkillManager`:Skill 解析、触发词匹配
|
||||
- `Sandbox`:资源限制、超时控制
|
||||
- `OllamaDriver`:HTTP 请求封装(使用 mock server)
|
||||
|
||||
2. **集成测试(30%):**
|
||||
- Agent + Tool:端到端任务执行
|
||||
- Agent + LLM:使用 mock LLM 测试 Function Calling 流程
|
||||
- Skill + Sandbox:加载 Skill 并执行其脚本
|
||||
|
||||
3. **E2E 测试(10%):**
|
||||
- 完整 CLI 工作流
|
||||
- 多 Agent 协作场景
|
||||
|
||||
### Mock 策略
|
||||
|
||||
- `LLMClient`:使用接口,测试时注入 mock
|
||||
- `Sandbox`:提供 `DryRun` 模式,记录命令但不执行
|
||||
- `MessageBus`:内存实现,用于测试
|
||||
|
||||
## Open Questions
|
||||
|
||||
1. **Skill 执行方式:** Skill 脚本是用 Shell 调用还是直接在 Go 中执行?当前设计倾向 Shell 调用(通过 Sandbox),但 Python/Node 脚本需要对应运行时。
|
||||
- **假设:** 用户环境已安装所需运行时(Python、Node 等),Sandbox 只负责安全执行。
|
||||
|
||||
2. **Function Calling 格式:** gemma4:e4b 对 Function Calling 的支持程度?
|
||||
- **假设:** 使用 Ollama 的 `tools` 参数格式,如果不支持则 fallback 到 prompt-based tool calling。
|
||||
|
||||
3. **多 Agent 协作粒度:** Agent 之间是平等协作还是有层级?
|
||||
- **假设:** 支持两种模式:层级(Orchestrator + Workers)和平等(对等协作),由用户配置。
|
||||
|
||||
4. **会话共享:** 多个 Agent 是否可以共享同一个会话上下文?
|
||||
- **假设:** 是,Session Manager 通过文件锁支持并发读取,但同一时间只有一个 Agent 写入。
|
||||
|
||||
5. **Tool 参数 Schema:** 使用 JSON Schema 还是简化格式?
|
||||
- **假设:** 使用简化版 JSON Schema(支持 string/number/boolean/array/object + description)。
|
||||
373
thoughts/shared/plans/2026-05-07-orca-agent-framework.md
Normal file
373
thoughts/shared/plans/2026-05-07-orca-agent-framework.md
Normal file
@ -0,0 +1,373 @@
|
||||
---
|
||||
date: 2026-05-07
|
||||
topic: "Go Agent Framework - Orca"
|
||||
status: draft
|
||||
---
|
||||
|
||||
# Orca Agent Framework - 实现计划
|
||||
|
||||
## 项目概览
|
||||
|
||||
- **项目名称:** orca
|
||||
- **语言:** Go 1.22+
|
||||
- **路径:** /Users/wang/agent_dev/orca.ai/
|
||||
- **架构:** 微内核 + Actor 模型
|
||||
|
||||
## 实现阶段
|
||||
|
||||
### Phase 1: 项目骨架与核心基础设施(Day 1)
|
||||
|
||||
**目标:** 建立项目结构,实现消息总线和插件注册机制。
|
||||
|
||||
**任务清单:**
|
||||
|
||||
1. **初始化 Go 模块**
|
||||
- `go mod init github.com/orca/orca`
|
||||
- 创建基础目录结构
|
||||
|
||||
2. **目录结构**
|
||||
```
|
||||
orca/
|
||||
├── cmd/orca/ # CLI 入口
|
||||
├── pkg/
|
||||
│ ├── kernel/ # 微内核核心
|
||||
│ ├── actor/ # Actor 系统
|
||||
│ ├── bus/ # 消息总线
|
||||
│ ├── plugin/ # 插件接口和注册
|
||||
│ ├── session/ # 会话管理 (JSONL)
|
||||
│ ├── skill/ # Skill 管理
|
||||
│ ├── tool/ # Tool 系统
|
||||
│ ├── llm/ # LLM 接口
|
||||
│ ├── ollama/ # Ollama 驱动
|
||||
│ └── sandbox/ # 沙箱执行
|
||||
├── internal/
|
||||
│ ├── config/ # 配置管理
|
||||
│ └── util/ # 工具函数
|
||||
├── plugins/
|
||||
│ ├── builtin/ # 内置插件
|
||||
│ └── tools/ # 内置工具
|
||||
├── test/
|
||||
│ └── fixtures/ # 测试固件
|
||||
└── go.mod
|
||||
```
|
||||
|
||||
3. **核心接口定义**
|
||||
- `pkg/plugin/plugin.go`: Plugin 接口
|
||||
- `pkg/bus/bus.go`: MessageBus 接口和实现
|
||||
- `pkg/kernel/kernel.go`: Kernel 结构体,插件生命周期管理
|
||||
|
||||
4. **消息总线实现**
|
||||
- 基于 Go channel 的发布/订阅
|
||||
- 支持同步和异步消息
|
||||
- 消息类型枚举定义
|
||||
|
||||
**验收标准:**
|
||||
- `go build ./...` 成功
|
||||
- 消息总线单元测试通过(Publish/Subscribe)
|
||||
- 插件注册/卸载测试通过
|
||||
|
||||
---
|
||||
|
||||
### Phase 2: Actor 系统与会话管理(Day 2)
|
||||
|
||||
**目标:** 实现多 Agent Actor 和 JSONL 会话存储。
|
||||
|
||||
**任务清单:**
|
||||
|
||||
1. **Actor 系统**
|
||||
- `pkg/actor/actor.go`: Agent Actor 接口
|
||||
- `pkg/actor/orchestrator.go`: 协调者 Agent
|
||||
- `pkg/actor/worker.go`: 工作者 Agent
|
||||
- `pkg/actor/system.go`: Actor 生命周期管理(创建、停止、监控)
|
||||
- 状态机实现:Idle → Processing → WaitingForTool → Completed/Failed
|
||||
|
||||
2. **会话管理(JSONL)**
|
||||
- `pkg/session/store.go`: 存储接口
|
||||
- `pkg/session/jsonl.go`: JSONL 实现
|
||||
- `pkg/session/manager.go`: 会话管理器(创建、加载、归档)
|
||||
- 上下文窗口截断逻辑
|
||||
- 文件锁保证并发安全(`flock` 或简单文件锁)
|
||||
|
||||
3. **配置系统**
|
||||
- `internal/config/config.go`: 配置结构体
|
||||
- 支持 YAML 配置文件(`~/.orca/config.yaml`)
|
||||
- 环境变量覆盖
|
||||
- 默认值设置
|
||||
|
||||
**验收标准:**
|
||||
- 创建 10 个 Agent Actor 并发运行测试通过
|
||||
- 会话 CRUD 测试通过
|
||||
- 上下文窗口截断测试通过
|
||||
|
||||
---
|
||||
|
||||
### Phase 3: Skill 与 Tool 系统(Day 3)
|
||||
|
||||
**目标:** 实现 Skill 自动发现和 Tool 注册执行。
|
||||
|
||||
**任务清单:**
|
||||
|
||||
1. **Skill 管理器**
|
||||
- `pkg/skill/manager.go`: Skill 扫描和加载
|
||||
- `pkg/skill/parser.go`: SKILL.md 解析器(提取 name, description, triggers)
|
||||
- `pkg/skill/skill.go`: Skill 结构体定义
|
||||
- 扫描目录:`~/.agents/skills/` 和 `~/.config/opencode/skills/`
|
||||
- 触发词匹配算法(简单关键词匹配或 TF-IDF)
|
||||
|
||||
2. **Tool 系统**
|
||||
- `pkg/tool/tool.go`: Tool 接口定义
|
||||
- `pkg/tool/manager.go`: Tool 注册中心
|
||||
- `pkg/tool/registry.go`: 内置工具注册
|
||||
- **内置工具实现:**
|
||||
- `exec`: 执行 shell 命令(通过 sandbox)
|
||||
- `read_file`: 读取文件内容
|
||||
- `write_file`: 写入文件
|
||||
- `list_dir`: 列出目录
|
||||
- `search_files`: 文件内容搜索
|
||||
|
||||
3. **自定义 Tool 注册**
|
||||
- 支持通过代码注册 Tool
|
||||
- Tool 参数 Schema 定义(简化版 JSON Schema)
|
||||
- Tool 执行上下文传递
|
||||
|
||||
**验收标准:**
|
||||
- 扫描现有 `~/.agents/skills/` 目录,正确解析所有 Skill
|
||||
- 触发词匹配测试通过
|
||||
- 所有内置工具单元测试通过
|
||||
|
||||
---
|
||||
|
||||
### Phase 4: 沙箱与 Ollama 集成(Day 4)
|
||||
|
||||
**目标:** 实现安全执行环境和 LLM 驱动。
|
||||
|
||||
**任务清单:**
|
||||
|
||||
1. **沙箱执行器**
|
||||
- `pkg/sandbox/sandbox.go`: Sandbox 接口
|
||||
- `pkg/sandbox/process.go`: 进程级实现
|
||||
- 资源限制:
|
||||
- 超时控制(context.WithTimeout)
|
||||
- 内存限制(通过 cgroup 或 ulimit,若不可用则软限制)
|
||||
- 输出大小限制
|
||||
- 工作目录隔离
|
||||
- 环境变量白名单
|
||||
- 危险命令黑名单
|
||||
|
||||
2. **Ollama 驱动**
|
||||
- `pkg/ollama/client.go`: HTTP 客户端
|
||||
- `pkg/ollama/chat.go`: Chat API 封装
|
||||
- `pkg/ollama/stream.go`: 流式响应处理(NDJSON,逐行解析 JSON 对象)
|
||||
- `pkg/ollama/tools.go`: Function Calling 支持
|
||||
- 模型配置:temperature、top_p、max_tokens
|
||||
- 自动重试机制(指数退避,3 次)
|
||||
|
||||
3. **LLM 抽象层**
|
||||
- `pkg/llm/client.go`: LLMClient 接口
|
||||
- `pkg/llm/message.go`: Message 结构体定义
|
||||
- `pkg/llm/options.go`: 调用选项
|
||||
|
||||
**验收标准:**
|
||||
- 沙箱执行命令并正确限制资源测试通过
|
||||
- 超时和内存限制测试通过
|
||||
- Ollama API 调用测试通过(需要本地 Ollama 服务)
|
||||
- Function Calling 格式正确
|
||||
|
||||
---
|
||||
|
||||
### Phase 5: CLI 与集成(Day 5)
|
||||
|
||||
**目标:** 实现命令行界面和端到端集成。
|
||||
|
||||
**任务清单:**
|
||||
|
||||
1. **CLI 实现**
|
||||
- `cmd/orca/main.go`: 入口点
|
||||
- `cmd/orca/commands.go`: 子命令定义
|
||||
- 支持命令:
|
||||
- `orca chat`: 交互式对话
|
||||
- `orca run "query"`: 单次执行
|
||||
- `orca sessions`: 会话列表
|
||||
- `orca skills`: 已加载 Skill 列表
|
||||
- `orca tools`: 已注册 Tool 列表
|
||||
- `orca config`: 配置查看/设置
|
||||
|
||||
2. **交互式对话**
|
||||
- 读取用户输入
|
||||
- 创建/恢复会话
|
||||
- 调用 Orchestrator Agent
|
||||
- 显示 Agent 思考过程和结果
|
||||
- 支持多轮对话
|
||||
|
||||
3. **端到端集成测试**
|
||||
- 完整对话流程测试
|
||||
- Skill 触发和调用测试
|
||||
- Tool 调用链测试
|
||||
- 错误恢复测试
|
||||
|
||||
**验收标准:**
|
||||
- `orca --help` 显示正确
|
||||
- `orca chat` 可以开始对话
|
||||
- 完整对话流程测试通过
|
||||
|
||||
---
|
||||
|
||||
### Phase 6: 多 Agent 协作与优化(Day 6-7)
|
||||
|
||||
**目标:** 实现多 Agent 协作和性能优化。
|
||||
|
||||
**任务清单:**
|
||||
|
||||
1. **多 Agent 协作**
|
||||
- Orchestrator 任务分解逻辑
|
||||
- Worker Agent 分配策略
|
||||
- Agent 间消息传递
|
||||
- 结果汇总和冲突解决
|
||||
|
||||
2. **性能优化**
|
||||
- 会话缓存(最近 N 个会话驻留内存)
|
||||
- Skill 索引(倒排索引加速匹配)
|
||||
- 连接池(Ollama HTTP 连接复用)
|
||||
|
||||
3. **可观测性**
|
||||
- 结构化日志(slog)
|
||||
- Agent 执行追踪
|
||||
- 性能指标收集
|
||||
|
||||
**验收标准:**
|
||||
- 多 Agent 协作测试通过
|
||||
- 性能基准测试通过(单次对话 < 5s)
|
||||
|
||||
---
|
||||
|
||||
## 依赖管理
|
||||
|
||||
### 外部依赖(最小化)
|
||||
|
||||
| 依赖 | 用途 | 版本 |
|
||||
|------|------|------|
|
||||
| `github.com/spf13/cobra` | CLI 框架 | latest |
|
||||
| `github.com/spf13/viper` | 配置管理 | latest |
|
||||
| `github.com/stretchr/testify` | 测试 | latest |
|
||||
|
||||
**原则:** 优先使用标准库,必要时才引入外部依赖。
|
||||
|
||||
### 内部依赖图
|
||||
|
||||
```
|
||||
kernel
|
||||
├── bus
|
||||
├── plugin
|
||||
├── actor
|
||||
│ ├── bus
|
||||
│ ├── tool
|
||||
│ └── llm
|
||||
├── session
|
||||
├── skill
|
||||
├── tool
|
||||
│ └── sandbox
|
||||
├── llm
|
||||
│ └── ollama
|
||||
└── sandbox
|
||||
```
|
||||
|
||||
## 关键接口定义
|
||||
|
||||
### Plugin 接口
|
||||
|
||||
```go
|
||||
type Plugin interface {
|
||||
Name() string
|
||||
Version() string
|
||||
Init(kernel *Kernel) error
|
||||
Shutdown() error
|
||||
}
|
||||
```
|
||||
|
||||
### Agent 接口
|
||||
|
||||
```go
|
||||
type Agent interface {
|
||||
ID() string
|
||||
Role() string
|
||||
Process(ctx context.Context, msg Message) (Message, error)
|
||||
Stop() error
|
||||
}
|
||||
```
|
||||
|
||||
### Tool 接口
|
||||
|
||||
```go
|
||||
type Tool interface {
|
||||
Name() string
|
||||
Description() string
|
||||
Parameters() ParameterSchema
|
||||
Execute(ctx context.Context, args map[string]interface{}) (ToolResult, error)
|
||||
}
|
||||
```
|
||||
|
||||
### LLMClient 接口
|
||||
|
||||
```go
|
||||
type LLMClient interface {
|
||||
Chat(ctx context.Context, messages []Message, tools []Tool) (*ChatResponse, error)
|
||||
ChatStream(ctx context.Context, messages []Message, tools []Tool) (StreamReader, error)
|
||||
}
|
||||
```
|
||||
|
||||
## 测试策略
|
||||
|
||||
### 单元测试覆盖目标
|
||||
|
||||
| 模块 | 覆盖率目标 |
|
||||
|------|-----------|
|
||||
| bus | 90% |
|
||||
| kernel | 85% |
|
||||
| session | 90% |
|
||||
| skill | 85% |
|
||||
| tool | 90% |
|
||||
| sandbox | 80% |
|
||||
| ollama | 75% (mock) |
|
||||
|
||||
### Mock 实现
|
||||
|
||||
```go
|
||||
// MockLLMClient 用于测试
|
||||
type MockLLMClient struct {
|
||||
Responses []ChatResponse
|
||||
Index int
|
||||
}
|
||||
|
||||
func (m *MockLLMClient) Chat(ctx context.Context, messages []Message, tools []Tool) (*ChatResponse, error) {
|
||||
if m.Index >= len(m.Responses) {
|
||||
return nil, errors.New("no more mock responses")
|
||||
}
|
||||
resp := m.Responses[m.Index]
|
||||
m.Index++
|
||||
return &resp, nil
|
||||
}
|
||||
```
|
||||
|
||||
## 风险与回退方案
|
||||
|
||||
| 风险 | 影响 | 概率 | 回退方案 |
|
||||
|------|------|------|----------|
|
||||
| gemma4:e4b 不支持 Function Calling | 高 | 中 | 使用 prompt-based tool calling |
|
||||
| 进程级沙箱限制不足 | 中 | 低 | 添加 Docker 支持作为可选 |
|
||||
| JSONL 性能瓶颈 | 中 | 低 | 迁移到 SQLite(保留接口) |
|
||||
| Actor 模型复杂度 | 中 | 中 | 简化为中央协调器模式 |
|
||||
|
||||
## 里程碑
|
||||
|
||||
| 里程碑 | 时间 | 交付物 |
|
||||
|--------|------|--------|
|
||||
| M1 | Day 1 | 项目骨架 + 消息总线 + 插件系统 |
|
||||
| M2 | Day 2 | Actor 系统 + 会话管理 |
|
||||
| M3 | Day 3 | Skill + Tool 系统 |
|
||||
| M4 | Day 4 | 沙箱 + Ollama 集成 |
|
||||
| M5 | Day 5 | CLI + 端到端集成 |
|
||||
| M6 | Day 6-7 | 多 Agent + 优化 |
|
||||
|
||||
## 下一步
|
||||
|
||||
执行 Phase 1,建立项目骨架。
|
||||
Loading…
x
Reference in New Issue
Block a user