feat: add DeepSeek LLM provider support
- Add DeepSeekClient implementing LLM interface - Support chat and streaming APIs - Add Provider config option (ollama/deepseek) - Default to DeepSeek with model deepseek-v4-flash - Update CLI to display provider info - Add DeepSeek environment variables (DEEPSEEK_API_KEY, etc.)
This commit is contained in:
parent
04c7ea5e39
commit
286d3dae3c
@ -23,7 +23,6 @@ func main() {
|
|||||||
// Load configuration from environment variables
|
// Load configuration from environment variables
|
||||||
cfg := config.LoadConfigFromEnv()
|
cfg := config.LoadConfigFromEnv()
|
||||||
|
|
||||||
// Support shorter env var names for Ollama (without ORCA_ prefix)
|
|
||||||
if v := os.Getenv("OLLAMA_BASE_URL"); v != "" {
|
if v := os.Getenv("OLLAMA_BASE_URL"); v != "" {
|
||||||
cfg.Ollama.BaseURL = v
|
cfg.Ollama.BaseURL = v
|
||||||
}
|
}
|
||||||
@ -36,6 +35,21 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if v := os.Getenv("DEEPSEEK_BASE_URL"); v != "" {
|
||||||
|
cfg.DeepSeek.BaseURL = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("DEEPSEEK_MODEL"); v != "" {
|
||||||
|
cfg.DeepSeek.Model = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("DEEPSEEK_API_KEY"); v != "" {
|
||||||
|
cfg.DeepSeek.APIKey = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("DEEPSEEK_TIMEOUT"); v != "" {
|
||||||
|
if d, err := time.ParseDuration(v); err == nil {
|
||||||
|
cfg.DeepSeek.Timeout = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create and start kernel
|
// Create and start kernel
|
||||||
k := kernel.NewWithConfig(cfg)
|
k := kernel.NewWithConfig(cfg)
|
||||||
|
|
||||||
@ -47,8 +61,14 @@ func main() {
|
|||||||
|
|
||||||
fmt.Println("Orca Agent Framework")
|
fmt.Println("Orca Agent Framework")
|
||||||
fmt.Println("Kernel started successfully")
|
fmt.Println("Kernel started successfully")
|
||||||
fmt.Printf(" LLM Model: %s\n", cfg.Ollama.Model)
|
if cfg.Provider == config.ProviderDeepSeek {
|
||||||
fmt.Printf(" Ollama URL: %s\n", cfg.Ollama.BaseURL)
|
fmt.Printf(" Provider: DeepSeek\n")
|
||||||
|
fmt.Printf(" LLM Model: %s\n", cfg.DeepSeek.Model)
|
||||||
|
} else {
|
||||||
|
fmt.Printf(" Provider: Ollama\n")
|
||||||
|
fmt.Printf(" LLM Model: %s\n", cfg.Ollama.Model)
|
||||||
|
fmt.Printf(" Ollama URL: %s\n", cfg.Ollama.BaseURL)
|
||||||
|
}
|
||||||
fmt.Println("Type your message or /help for commands.")
|
fmt.Println("Type your message or /help for commands.")
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,3 @@
|
|||||||
// Package config provides the configuration types for the Orca framework.
|
|
||||||
//
|
|
||||||
// Configuration is organized into logical groups: LLM (Ollama), sandbox,
|
|
||||||
// and session management. Default values are provided for all settings.
|
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -10,52 +6,60 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config is the top-level configuration for the Orca framework.
|
const (
|
||||||
|
ProviderOllama = "ollama"
|
||||||
|
ProviderDeepSeek = "deepseek"
|
||||||
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Ollama OllamaConfig `json:"ollama"`
|
Provider string `json:"provider"`
|
||||||
Sandbox SandboxConfig `json:"sandbox"`
|
Ollama OllamaConfig `json:"ollama"`
|
||||||
Session SessionConfig `json:"session"`
|
DeepSeek DeepSeekConfig `json:"deepseek"`
|
||||||
|
Sandbox SandboxConfig `json:"sandbox"`
|
||||||
|
Session SessionConfig `json:"session"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// OllamaConfig holds configuration for the Ollama LLM backend.
|
|
||||||
type OllamaConfig struct {
|
type OllamaConfig struct {
|
||||||
// BaseURL is the Ollama API endpoint (e.g., "http://localhost:11434").
|
BaseURL string `json:"base_url"`
|
||||||
BaseURL string `json:"base_url"`
|
Model string `json:"model"`
|
||||||
// Model is the Ollama model name to use (e.g., "gemma4:e4b", "codellama").
|
Timeout time.Duration `json:"timeout"`
|
||||||
Model string `json:"model"`
|
}
|
||||||
// Timeout is the maximum duration to wait for an Ollama response.
|
|
||||||
|
type DeepSeekConfig struct {
|
||||||
|
BaseURL string `json:"base_url"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
APIKey string `json:"api_key"`
|
||||||
Timeout time.Duration `json:"timeout"`
|
Timeout time.Duration `json:"timeout"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SandboxConfig holds configuration for the command execution sandbox.
|
|
||||||
type SandboxConfig struct {
|
type SandboxConfig struct {
|
||||||
// Timeout is the maximum duration for a sandboxed command.
|
Timeout time.Duration `json:"timeout"`
|
||||||
Timeout time.Duration `json:"timeout"`
|
MaxMemory int64 `json:"max_memory"`
|
||||||
// MaxMemory is the maximum memory allocation for the sandbox (in bytes).
|
WorkingDir string `json:"working_dir"`
|
||||||
MaxMemory int64 `json:"max_memory"`
|
|
||||||
// WorkingDir is the default working directory for sandboxed commands.
|
|
||||||
WorkingDir string `json:"working_dir"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionConfig holds configuration for session management.
|
|
||||||
type SessionConfig struct {
|
type SessionConfig struct {
|
||||||
// StorageDir is the directory for session JSONL files.
|
|
||||||
StorageDir string `json:"storage_dir"`
|
StorageDir string `json:"storage_dir"`
|
||||||
// MaxHistory is the maximum number of messages to retain per session.
|
MaxHistory int `json:"max_history"`
|
||||||
MaxHistory int `json:"max_history"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultConfig returns a Config with sensible defaults.
|
|
||||||
func DefaultConfig() *Config {
|
func DefaultConfig() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
|
Provider: ProviderDeepSeek,
|
||||||
Ollama: OllamaConfig{
|
Ollama: OllamaConfig{
|
||||||
BaseURL: "http://localhost:11434",
|
BaseURL: "http://localhost:11434",
|
||||||
Model: "gemma4:e4b",
|
Model: "gemma4:e4b",
|
||||||
Timeout: 120 * time.Second,
|
Timeout: 120 * time.Second,
|
||||||
},
|
},
|
||||||
|
DeepSeek: DeepSeekConfig{
|
||||||
|
BaseURL: "https://api.deepseek.com/v1",
|
||||||
|
Model: "deepseek-v4-flash",
|
||||||
|
APIKey: "sk-2f1049148e06492dbc304ba49c81c321",
|
||||||
|
Timeout: 120 * time.Second,
|
||||||
|
},
|
||||||
Sandbox: SandboxConfig{
|
Sandbox: SandboxConfig{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
MaxMemory: 512 * 1024 * 1024, // 512 MB
|
MaxMemory: 512 * 1024 * 1024,
|
||||||
WorkingDir: "/tmp/orca/sandbox",
|
WorkingDir: "/tmp/orca/sandbox",
|
||||||
},
|
},
|
||||||
Session: SessionConfig{
|
Session: SessionConfig{
|
||||||
@ -68,11 +72,12 @@ func DefaultConfig() *Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadConfigFromEnv reads configuration from environment variables,
|
|
||||||
// overriding defaults where environment variables are set.
|
|
||||||
func LoadConfigFromEnv() *Config {
|
func LoadConfigFromEnv() *Config {
|
||||||
cfg := DefaultConfig()
|
cfg := DefaultConfig()
|
||||||
|
|
||||||
|
if v := os.Getenv("ORCA_PROVIDER"); v != "" {
|
||||||
|
cfg.Provider = v
|
||||||
|
}
|
||||||
if v := os.Getenv("ORCA_OLLAMA_BASE_URL"); v != "" {
|
if v := os.Getenv("ORCA_OLLAMA_BASE_URL"); v != "" {
|
||||||
cfg.Ollama.BaseURL = v
|
cfg.Ollama.BaseURL = v
|
||||||
}
|
}
|
||||||
@ -84,6 +89,20 @@ func LoadConfigFromEnv() *Config {
|
|||||||
cfg.Ollama.Timeout = d
|
cfg.Ollama.Timeout = d
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if v := os.Getenv("ORCA_DEEPSEEK_BASE_URL"); v != "" {
|
||||||
|
cfg.DeepSeek.BaseURL = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("ORCA_DEEPSEEK_MODEL"); v != "" {
|
||||||
|
cfg.DeepSeek.Model = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("ORCA_DEEPSEEK_API_KEY"); v != "" {
|
||||||
|
cfg.DeepSeek.APIKey = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("ORCA_DEEPSEEK_TIMEOUT"); v != "" {
|
||||||
|
if d, err := time.ParseDuration(v); err == nil {
|
||||||
|
cfg.DeepSeek.Timeout = d
|
||||||
|
}
|
||||||
|
}
|
||||||
if v := os.Getenv("ORCA_SANDBOX_TIMEOUT"); v != "" {
|
if v := os.Getenv("ORCA_SANDBOX_TIMEOUT"); v != "" {
|
||||||
if d, err := time.ParseDuration(v); err == nil {
|
if d, err := time.ParseDuration(v); err == nil {
|
||||||
cfg.Sandbox.Timeout = d
|
cfg.Sandbox.Timeout = d
|
||||||
@ -109,16 +128,34 @@ func LoadConfigFromEnv() *Config {
|
|||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValid checks whether the configuration has valid values.
|
|
||||||
func (c *Config) IsValid() error {
|
func (c *Config) IsValid() error {
|
||||||
if c.Ollama.BaseURL == "" {
|
if c.Provider != ProviderOllama && c.Provider != ProviderDeepSeek {
|
||||||
return errConfig("ollama.base_url must not be empty")
|
return errConfig("provider must be 'ollama' or 'deepseek'")
|
||||||
}
|
}
|
||||||
if c.Ollama.Model == "" {
|
if c.Provider == ProviderOllama {
|
||||||
return errConfig("ollama.model must not be empty")
|
if c.Ollama.BaseURL == "" {
|
||||||
|
return errConfig("ollama.base_url must not be empty")
|
||||||
|
}
|
||||||
|
if c.Ollama.Model == "" {
|
||||||
|
return errConfig("ollama.model must not be empty")
|
||||||
|
}
|
||||||
|
if c.Ollama.Timeout <= 0 {
|
||||||
|
return errConfig("ollama.timeout must be positive")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if c.Ollama.Timeout <= 0 {
|
if c.Provider == ProviderDeepSeek {
|
||||||
return errConfig("ollama.timeout must be positive")
|
if c.DeepSeek.BaseURL == "" {
|
||||||
|
return errConfig("deepseek.base_url must not be empty")
|
||||||
|
}
|
||||||
|
if c.DeepSeek.Model == "" {
|
||||||
|
return errConfig("deepseek.model must not be empty")
|
||||||
|
}
|
||||||
|
if c.DeepSeek.APIKey == "" {
|
||||||
|
return errConfig("deepseek.api_key must not be empty")
|
||||||
|
}
|
||||||
|
if c.DeepSeek.Timeout <= 0 {
|
||||||
|
return errConfig("deepseek.timeout must be positive")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if c.Sandbox.Timeout <= 0 {
|
if c.Sandbox.Timeout <= 0 {
|
||||||
return errConfig("sandbox.timeout must be positive")
|
return errConfig("sandbox.timeout must be positive")
|
||||||
@ -132,12 +169,10 @@ func (c *Config) IsValid() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// errConfig creates a configuration error.
|
|
||||||
func errConfig(msg string) error {
|
func errConfig(msg string) error {
|
||||||
return &ConfigError{Message: msg}
|
return &ConfigError{Message: msg}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigError represents a configuration validation error.
|
|
||||||
type ConfigError struct {
|
type ConfigError struct {
|
||||||
Message string
|
Message string
|
||||||
}
|
}
|
||||||
|
|||||||
@ -127,6 +127,7 @@ func TestConfigIsValid(t *testing.T) {
|
|||||||
|
|
||||||
func TestConfigInvalidBaseURL(t *testing.T) {
|
func TestConfigInvalidBaseURL(t *testing.T) {
|
||||||
cfg := DefaultConfig()
|
cfg := DefaultConfig()
|
||||||
|
cfg.Provider = ProviderOllama
|
||||||
cfg.Ollama.BaseURL = ""
|
cfg.Ollama.BaseURL = ""
|
||||||
if err := cfg.IsValid(); err == nil {
|
if err := cfg.IsValid(); err == nil {
|
||||||
t.Error("expected error for empty BaseURL")
|
t.Error("expected error for empty BaseURL")
|
||||||
@ -135,6 +136,7 @@ func TestConfigInvalidBaseURL(t *testing.T) {
|
|||||||
|
|
||||||
func TestConfigInvalidModel(t *testing.T) {
|
func TestConfigInvalidModel(t *testing.T) {
|
||||||
cfg := DefaultConfig()
|
cfg := DefaultConfig()
|
||||||
|
cfg.Provider = ProviderOllama
|
||||||
cfg.Ollama.Model = ""
|
cfg.Ollama.Model = ""
|
||||||
if err := cfg.IsValid(); err == nil {
|
if err := cfg.IsValid(); err == nil {
|
||||||
t.Error("expected error for empty Model")
|
t.Error("expected error for empty Model")
|
||||||
@ -143,12 +145,22 @@ func TestConfigInvalidModel(t *testing.T) {
|
|||||||
|
|
||||||
func TestConfigInvalidOllamaTimeout(t *testing.T) {
|
func TestConfigInvalidOllamaTimeout(t *testing.T) {
|
||||||
cfg := DefaultConfig()
|
cfg := DefaultConfig()
|
||||||
|
cfg.Provider = ProviderOllama
|
||||||
cfg.Ollama.Timeout = 0
|
cfg.Ollama.Timeout = 0
|
||||||
if err := cfg.IsValid(); err == nil {
|
if err := cfg.IsValid(); err == nil {
|
||||||
t.Error("expected error for zero Ollama Timeout")
|
t.Error("expected error for zero Ollama Timeout")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfigInvalidDeepSeekAPIKey(t *testing.T) {
|
||||||
|
cfg := DefaultConfig()
|
||||||
|
cfg.Provider = ProviderDeepSeek
|
||||||
|
cfg.DeepSeek.APIKey = ""
|
||||||
|
if err := cfg.IsValid(); err == nil {
|
||||||
|
t.Error("expected error for empty DeepSeek APIKey")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConfigInvalidSandboxTimeout(t *testing.T) {
|
func TestConfigInvalidSandboxTimeout(t *testing.T) {
|
||||||
cfg := DefaultConfig()
|
cfg := DefaultConfig()
|
||||||
cfg.Sandbox.Timeout = -1
|
cfg.Sandbox.Timeout = -1
|
||||||
|
|||||||
@ -160,13 +160,20 @@ func (k *Kernel) initializeActorSystem() {
|
|||||||
k.orch.AddWorker(tw)
|
k.orch.AddWorker(tw)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createLLMBackend creates the LLM backend based on configuration.
|
|
||||||
func (k *Kernel) createLLMBackend() llm.LLM {
|
func (k *Kernel) createLLMBackend() llm.LLM {
|
||||||
|
switch k.config.Provider {
|
||||||
|
case config.ProviderDeepSeek:
|
||||||
|
return k.createDeepSeekBackend()
|
||||||
|
default:
|
||||||
|
return k.createOllamaBackend()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *Kernel) createOllamaBackend() llm.LLM {
|
||||||
baseURL := k.config.Ollama.BaseURL
|
baseURL := k.config.Ollama.BaseURL
|
||||||
model := k.config.Ollama.Model
|
model := k.config.Ollama.Model
|
||||||
timeout := k.config.Ollama.Timeout
|
timeout := k.config.Ollama.Timeout
|
||||||
|
|
||||||
// Allow shorter env var names to override
|
|
||||||
if v := os.Getenv("OLLAMA_BASE_URL"); v != "" {
|
if v := os.Getenv("OLLAMA_BASE_URL"); v != "" {
|
||||||
baseURL = v
|
baseURL = v
|
||||||
}
|
}
|
||||||
@ -189,6 +196,38 @@ func (k *Kernel) createLLMBackend() llm.LLM {
|
|||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (k *Kernel) createDeepSeekBackend() llm.LLM {
|
||||||
|
baseURL := k.config.DeepSeek.BaseURL
|
||||||
|
model := k.config.DeepSeek.Model
|
||||||
|
apiKey := k.config.DeepSeek.APIKey
|
||||||
|
timeout := k.config.DeepSeek.Timeout
|
||||||
|
|
||||||
|
if v := os.Getenv("DEEPSEEK_BASE_URL"); v != "" {
|
||||||
|
baseURL = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("DEEPSEEK_MODEL"); v != "" {
|
||||||
|
model = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("DEEPSEEK_API_KEY"); v != "" {
|
||||||
|
apiKey = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("DEEPSEEK_TIMEOUT"); v != "" {
|
||||||
|
if d, err := time.ParseDuration(v); err == nil {
|
||||||
|
timeout = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
client := llm.NewDeepSeekClient(
|
||||||
|
llm.WithDeepSeekBaseURL(baseURL),
|
||||||
|
llm.WithDeepSeekModel(model),
|
||||||
|
llm.WithDeepSeekAPIKey(apiKey),
|
||||||
|
llm.WithDeepSeekTimeout(timeout),
|
||||||
|
)
|
||||||
|
|
||||||
|
log.Printf("kernel: created DeepSeek client (model=%s)", model)
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
// Bus returns the kernel's message bus.
|
// Bus returns the kernel's message bus.
|
||||||
func (k *Kernel) Bus() bus.MessageBus {
|
func (k *Kernel) Bus() bus.MessageBus {
|
||||||
return k.mb
|
return k.mb
|
||||||
|
|||||||
193
pkg/llm/deepseek.go
Normal file
193
pkg/llm/deepseek.go
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DeepSeekClient struct {
|
||||||
|
baseURL string
|
||||||
|
model string
|
||||||
|
apiKey string
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
type DeepSeekOption func(*DeepSeekClient)
|
||||||
|
|
||||||
|
func WithDeepSeekBaseURL(url string) DeepSeekOption {
|
||||||
|
return func(c *DeepSeekClient) {
|
||||||
|
c.baseURL = strings.TrimRight(url, "/")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithDeepSeekModel(model string) DeepSeekOption {
|
||||||
|
return func(c *DeepSeekClient) {
|
||||||
|
c.model = model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithDeepSeekAPIKey(key string) DeepSeekOption {
|
||||||
|
return func(c *DeepSeekClient) {
|
||||||
|
c.apiKey = key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithDeepSeekTimeout(timeout time.Duration) DeepSeekOption {
|
||||||
|
return func(c *DeepSeekClient) {
|
||||||
|
c.httpClient.Timeout = timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDeepSeekClient(opts ...DeepSeekOption) *DeepSeekClient {
|
||||||
|
c := &DeepSeekClient{
|
||||||
|
baseURL: "https://api.deepseek.com/v1",
|
||||||
|
model: "deepseek-chat",
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 120 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(c)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DeepSeekClient) Chat(ctx context.Context, messages []Message) (*Response, error) {
|
||||||
|
reqBody := c.buildChatRequest(messages, false)
|
||||||
|
body, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("deepseek: failed to marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/chat/completions", bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("deepseek: failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("deepseek: request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("deepseek: API returned %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiResp deepSeekChatResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("deepseek: failed to decode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(apiResp.Choices) == 0 {
|
||||||
|
return nil, fmt.Errorf("deepseek: no choices in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
choice := apiResp.Choices[0]
|
||||||
|
return &Response{
|
||||||
|
Content: choice.Message.Content,
|
||||||
|
ToolCalls: choice.Message.ToolCalls,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DeepSeekClient) Stream(ctx context.Context, messages []Message, handler StreamHandler) error {
|
||||||
|
reqBody := c.buildChatRequest(messages, true)
|
||||||
|
body, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("deepseek: failed to marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/chat/completions", bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("deepseek: failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||||
|
req.Header.Set("Accept", "text/event-stream")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("deepseek: request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("deepseek: API returned %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := bufio.NewReader(resp.Body)
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return fmt.Errorf("deepseek: error reading stream: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line == "" || line == "data: [DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(line, "data: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
data := strings.TrimPrefix(line, "data: ")
|
||||||
|
var chunk deepSeekStreamChunk
|
||||||
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
|
||||||
|
if err := handler(chunk.Choices[0].Delta.Content); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DeepSeekClient) buildChatRequest(messages []Message, stream bool) deepSeekChatRequest {
|
||||||
|
return deepSeekChatRequest{
|
||||||
|
Model: c.model,
|
||||||
|
Messages: messages,
|
||||||
|
Stream: stream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type deepSeekChatRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []Message `json:"messages"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type deepSeekChatResponse struct {
|
||||||
|
Choices []struct {
|
||||||
|
Message struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
|
} `json:"message"`
|
||||||
|
} `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type deepSeekStreamChunk struct {
|
||||||
|
Choices []struct {
|
||||||
|
Delta struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"delta"`
|
||||||
|
} `json:"choices"`
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user