package llm
import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"sort"
	"strings"
	"time"
)
// OllamaClient implements the LLM interface for Ollama's API.
//
// It communicates with a running Ollama server via its REST API.
// Supports chat, streaming, tool calling (function calling), and
// embedding generation.
type OllamaClient struct {
	baseURL    string       // server root, e.g. http://localhost:11434 (stored without trailing slash)
	model      string       // default model name sent with every request
	httpClient *http.Client // transport used for all calls; its Timeout bounds each request
}
// OllamaOption is a functional option for configuring the OllamaClient.
// Options are applied in order by NewOllamaClient, so later options
// override the effects of earlier ones.
type OllamaOption func(*OllamaClient)
// WithBaseURL sets the Ollama server base URL.
|
|
func WithBaseURL(url string) OllamaOption {
|
|
return func(c *OllamaClient) {
|
|
c.baseURL = strings.TrimRight(url, "/")
|
|
}
|
|
}
|
|
|
|
// WithModel sets the default model name.
|
|
func WithModel(model string) OllamaOption {
|
|
return func(c *OllamaClient) {
|
|
c.model = model
|
|
}
|
|
}
|
|
|
|
// WithTimeout sets the HTTP client timeout.
|
|
func WithTimeout(timeout time.Duration) OllamaOption {
|
|
return func(c *OllamaClient) {
|
|
c.httpClient.Timeout = timeout
|
|
}
|
|
}
|
|
|
|
// WithHTTPClient sets a custom HTTP client.
|
|
func WithHTTPClient(client *http.Client) OllamaOption {
|
|
return func(c *OllamaClient) {
|
|
c.httpClient = client
|
|
}
|
|
}
|
|
|
|
// NewOllamaClient creates a new OllamaClient with the given options.
|
|
//
|
|
// Default values:
|
|
// - BaseURL: http://localhost:11434
|
|
// - Model: gemma4:e4b
|
|
// - Timeout: 30s
|
|
func NewOllamaClient(opts ...OllamaOption) *OllamaClient {
|
|
c := &OllamaClient{
|
|
baseURL: "http://localhost:11434",
|
|
model: "gemma4:e4b",
|
|
httpClient: &http.Client{
|
|
Timeout: 30 * time.Second,
|
|
},
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(c)
|
|
}
|
|
|
|
return c
|
|
}
|
|
|
|
// ChatWithTools sends a chat request intended to include tool
// definitions alongside the messages.
//
// NOTE(review): the tools argument is currently ignored — the call
// delegates to Chat, which does not attach tool definitions to the
// outgoing request. Confirm whether OllamaChatRequest should carry
// the ToolDefs, or whether tools are injected elsewhere as the
// comment inside Chat suggests.
func (c *OllamaClient) ChatWithTools(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) {
	return c.Chat(ctx, messages)
}
func (c *OllamaClient) Chat(ctx context.Context, messages []Message) (*Response, error) {
|
|
req := OllamaChatRequest{
|
|
Model: c.model,
|
|
Messages: messages,
|
|
Stream: false,
|
|
}
|
|
|
|
// Build tool definitions if tool package is integrated
|
|
// (tools are added externally via BuildToolDefs)
|
|
|
|
body, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ollama: failed to marshal request: %w", err)
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost,
|
|
c.baseURL+"/api/chat", bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ollama: failed to create request: %w", err)
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := c.httpClient.Do(httpReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ollama: request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("ollama: API error (status %d): %s",
|
|
resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
var rawResp json.RawMessage
|
|
if err := json.NewDecoder(resp.Body).Decode(&rawResp); err != nil {
|
|
return nil, fmt.Errorf("ollama: failed to decode response: %w", err)
|
|
}
|
|
|
|
return parseOllamaResponse(rawResp)
|
|
}
|
|
|
|
// parseOllamaResponse attempts to parse the Ollama API response,
|
|
// handling both regular text responses and tool call responses.
|
|
func parseOllamaResponse(raw json.RawMessage) (*Response, error) {
|
|
// Try as tool call response first (has message.tool_calls)
|
|
var toolResp OllamaToolCallResponse
|
|
if err := json.Unmarshal(raw, &toolResp); err == nil && len(toolResp.Message.ToolCalls) > 0 {
|
|
return &Response{
|
|
Content: toolResp.Message.Content,
|
|
ToolCalls: toolResp.Message.ToolCalls,
|
|
}, nil
|
|
}
|
|
|
|
// Try as regular response
|
|
var chatResp OllamaChatResponse
|
|
if err := json.Unmarshal(raw, &chatResp); err != nil {
|
|
return nil, fmt.Errorf("ollama: failed to parse response: %w", err)
|
|
}
|
|
|
|
return &Response{
|
|
Content: chatResp.Message.Content,
|
|
}, nil
|
|
}
|
|
|
|
// Stream sends a chat request to Ollama with streaming enabled.
|
|
// The handler receives each content chunk as it arrives.
|
|
func (c *OllamaClient) Stream(ctx context.Context, messages []Message, handler StreamHandler) error {
|
|
req := OllamaChatRequest{
|
|
Model: c.model,
|
|
Messages: messages,
|
|
Stream: true,
|
|
}
|
|
|
|
body, err := json.Marshal(req)
|
|
if err != nil {
|
|
return fmt.Errorf("ollama: failed to marshal request: %w", err)
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost,
|
|
c.baseURL+"/api/chat", bytes.NewReader(body))
|
|
if err != nil {
|
|
return fmt.Errorf("ollama: failed to create request: %w", err)
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := c.httpClient.Do(httpReq)
|
|
if err != nil {
|
|
return fmt.Errorf("ollama: stream request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("ollama: stream API error (status %d): %s",
|
|
resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if line == "" {
|
|
continue
|
|
}
|
|
|
|
// Each line is a JSON object: {"model":"...","created_at":"...","message":{"role":"assistant","content":"..."},"done":false}
|
|
var streamResp OllamaChatResponse
|
|
if err := json.Unmarshal([]byte(line), &streamResp); err != nil {
|
|
continue // Skip malformed lines
|
|
}
|
|
|
|
if streamResp.Message.Content != "" {
|
|
if err := handler(streamResp.Message.Content); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if streamResp.Done {
|
|
break
|
|
}
|
|
}
|
|
|
|
return scanner.Err()
|
|
}
|
|
|
|
// Embed generates an embedding vector for the given input text.
|
|
func (c *OllamaClient) Embed(ctx context.Context, input string) (*EmbeddingResponse, error) {
|
|
req := OllamaEmbedRequest{
|
|
Model: c.model,
|
|
Input: input,
|
|
}
|
|
|
|
body, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ollama: failed to marshal embed request: %w", err)
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost,
|
|
c.baseURL+"/api/embed", bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ollama: failed to create embed request: %w", err)
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := c.httpClient.Do(httpReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ollama: embed request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("ollama: embed API error (status %d): %s",
|
|
resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
var apiResp OllamaEmbedResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil {
|
|
return nil, fmt.Errorf("ollama: failed to decode embed response: %w", err)
|
|
}
|
|
|
|
return &EmbeddingResponse{
|
|
Embedding: apiResp.Embedding,
|
|
}, nil
|
|
}
|
|
|
|
// BuildToolDefsFromMap converts a generic tool definition map into Ollama ToolDefs.
|
|
// This is used to bridge the tool package's Tool interface with Ollama's API format.
|
|
func BuildToolDefsFromMap(tools []map[string]interface{}) []ToolDef {
|
|
var defs []ToolDef
|
|
for _, t := range tools {
|
|
name, _ := t["name"].(string)
|
|
desc, _ := t["description"].(string)
|
|
|
|
def := ToolDef{
|
|
Type: "function",
|
|
Function: ToolFunction{
|
|
Name: name,
|
|
Description: desc,
|
|
Parameters: ToolFunctionParameters{
|
|
Type: "object",
|
|
Properties: make(map[string]ToolProperty),
|
|
},
|
|
},
|
|
}
|
|
|
|
if params, ok := t["parameters"].(map[string]interface{}); ok {
|
|
if props, ok := params["properties"].(map[string]interface{}); ok {
|
|
for key, val := range props {
|
|
if p, ok := val.(map[string]interface{}); ok {
|
|
prop := ToolProperty{
|
|
Type: toString(p["type"]),
|
|
Description: toString(p["description"]),
|
|
}
|
|
def.Function.Parameters.Properties[key] = prop
|
|
if isRequired(p) {
|
|
def.Function.Parameters.Required = append(def.Function.Parameters.Required, key)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
defs = append(defs, def)
|
|
}
|
|
return defs
|
|
}
|
|
|
|
// toString returns v as a string, or the empty string when v is nil or
// not a string.
func toString(v interface{}) string {
	if s, ok := v.(string); ok {
		return s
	}
	return ""
}
// isRequired reports whether the property map marks itself required via
// a boolean "required" field; a missing or non-bool field counts as
// not required.
func isRequired(p map[string]interface{}) bool {
	if required, ok := p["required"].(bool); ok {
		return required
	}
	return false
}