package llm import ( "context" "encoding/json" "net/http" "net/http/httptest" "strings" "testing" ) // ============================================================ // Helper: create a mock Ollama server // ============================================================ // mockOllamaHandler returns an http.Handler that simulates the Ollama API. func mockOllamaHandler(t *testing.T, responseFunc func(reqBody map[string]interface{}) (int, interface{})) *httptest.Server { t.Helper() return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Verify request path if r.URL.Path != "/api/chat" && r.URL.Path != "/api/embed" { t.Errorf("unexpected path: %s", r.URL.Path) w.WriteHeader(http.StatusNotFound) return } // Decode request body var reqBody map[string]interface{} if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { t.Fatalf("failed to decode request body: %v", err) } status, resp := responseFunc(reqBody) w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) if err := json.NewEncoder(w).Encode(resp); err != nil { t.Fatalf("failed to encode response: %v", err) } })) } // ============================================================ // NewOllamaClient Tests // ============================================================ func TestNewOllamaClientDefaults(t *testing.T) { c := NewOllamaClient() if c == nil { t.Fatal("NewOllamaClient() returned nil") } if c.baseURL != "http://localhost:11434" { t.Errorf("expected default base URL 'http://localhost:11434', got %q", c.baseURL) } if c.model != "gemma4:e4b" { t.Errorf("expected default model 'gemma4:e4b', got %q", c.model) } } func TestNewOllamaClientWithOptions(t *testing.T) { c := NewOllamaClient( WithBaseURL("http://custom:11434"), WithModel("codellama"), WithTimeout(60), ) if c.baseURL != "http://custom:11434" { t.Errorf("expected base URL 'http://custom:11434', got %q", c.baseURL) } if c.model != "codellama" { t.Errorf("expected model 'codellama', got %q", c.model) } } // ============================================================ // Chat Tests // ============================================================ func TestChat(t *testing.T) { srv := mockOllamaHandler(t, func(reqBody map[string]interface{}) (int, interface{}) { // Verify the request has the expected shape if model, ok := reqBody["model"]; !ok || model != "gemma4:e4b" { t.Errorf("expected model 'gemma4:e4b', got %v", model) } if stream, ok := reqBody["stream"]; !ok || stream != false { t.Errorf("expected stream false, got %v", stream) } return http.StatusOK, OllamaChatResponse{ Model: "gemma4:e4b", Message: Message{ Role: "assistant", Content: "Hello! How can I help you?", }, Done: true, } }) defer srv.Close() client := NewOllamaClient(WithBaseURL(srv.URL)) resp, err := client.Chat(context.Background(), []Message{ {Role: "user", Content: "Hello"}, }) if err != nil { t.Fatalf("Chat failed: %v", err) } if resp.Content != "Hello! How can I help you?" { t.Errorf("expected content 'Hello! How can I help you?', got %q", resp.Content) } if len(resp.ToolCalls) != 0 { t.Errorf("expected no tool calls, got %d", len(resp.ToolCalls)) } } func TestChatWithToolCalls(t *testing.T) { srv := mockOllamaHandler(t, func(reqBody map[string]interface{}) (int, interface{}) { return http.StatusOK, OllamaToolCallResponse{ Model: "gemma4:e4b", Message: OllamaToolMsg{ Role: "assistant", Content: "", ToolCalls: []ToolCall{ { ID: "call-1", Type: "function", Function: FunctionCall{ Name: "exec", Arguments: `{"command":"ls -la"}`, }, }, }, }, Done: true, } }) defer srv.Close() client := NewOllamaClient(WithBaseURL(srv.URL)) resp, err := client.Chat(context.Background(), []Message{ {Role: "user", Content: "List files"}, }) if err != nil { t.Fatalf("Chat failed: %v", err) } if len(resp.ToolCalls) != 1 { t.Fatalf("expected 1 tool call, got %d", len(resp.ToolCalls)) } if resp.ToolCalls[0].Function.Name != "exec" { t.Errorf("expected tool name 'exec', got %q", resp.ToolCalls[0].Function.Name) } if resp.ToolCalls[0].Function.Arguments != `{"command":"ls -la"}` { t.Errorf("unexpected arguments: %q", resp.ToolCalls[0].Function.Arguments) } } func TestChatAPIError(t *testing.T) { srv := mockOllamaHandler(t, func(reqBody map[string]interface{}) (int, interface{}) { return http.StatusInternalServerError, map[string]string{"error": "internal error"} }) defer srv.Close() client := NewOllamaClient(WithBaseURL(srv.URL)) _, err := client.Chat(context.Background(), []Message{ {Role: "user", Content: "Hello"}, }) if err == nil { t.Fatal("expected error for API error response") } if !strings.Contains(err.Error(), "500") { t.Errorf("expected error to contain status code, got: %v", err) } } func TestChatContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately client := NewOllamaClient(WithBaseURL("http://localhost:11434")) _, err := client.Chat(ctx, []Message{{Role: "user", Content: "Hello"}}) if err == nil { t.Error("expected error for cancelled context") } } // ============================================================ // Stream Tests // ============================================================ func TestStream(t *testing.T) { chunks := []string{"Hello", "!", " How", " can", " I", " help?"} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/api/chat" { t.Errorf("unexpected path: %s", r.URL.Path) w.WriteHeader(http.StatusNotFound) return } w.Header().Set("Content-Type", "application/x-ndjson") w.WriteHeader(http.StatusOK) flusher, ok := w.(http.Flusher) if !ok { t.Fatal("expected http.Flusher") } for _, chunk := range chunks { resp := OllamaChatResponse{ Model: "gemma4:e4b", Message: Message{ Role: "assistant", Content: chunk, }, Done: false, } data, _ := json.Marshal(resp) w.Write(append(data, '\n')) flusher.Flush() } // Send done signal doneResp := OllamaChatResponse{ Model: "gemma4:e4b", Message: Message{ Role: "assistant", Content: "", }, Done: true, } data, _ := json.Marshal(doneResp) w.Write(append(data, '\n')) flusher.Flush() })) defer srv.Close() client := NewOllamaClient(WithBaseURL(srv.URL)) var received []string err := client.Stream(context.Background(), []Message{{Role: "user", Content: "Hi"}}, func(chunk string) error { received = append(received, chunk) return nil }) if err != nil { t.Fatalf("Stream failed: %v", err) } if len(received) != len(chunks) { t.Errorf("expected %d chunks, got %d", len(chunks), len(received)) } } func TestStreamHandlerError(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/x-ndjson") w.WriteHeader(http.StatusOK) resp := OllamaChatResponse{ Model: "gemma4:e4b", Message: Message{ Role: "assistant", Content: "chunk", }, Done: false, } data, _ := json.Marshal(resp) w.Write(append(data, '\n')) if f, ok := w.(http.Flusher); ok { f.Flush() } })) defer srv.Close() client := NewOllamaClient(WithBaseURL(srv.URL)) err := client.Stream(context.Background(), []Message{{Role: "user", Content: "Hi"}}, func(chunk string) error { return &streamError{msg: "handler error"} }) if err == nil { t.Fatal("expected error from handler") } if !strings.Contains(err.Error(), "handler error") { t.Errorf("expected 'handler error', got: %v", err) } } type streamError struct{ msg string } func (e *streamError) Error() string { return e.msg } // ============================================================ // Embed Tests // ============================================================ func TestEmbed(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/api/embed" { t.Errorf("unexpected path: %s", r.URL.Path) w.WriteHeader(http.StatusNotFound) return } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(OllamaEmbedResponse{ Embedding: []float64{0.1, 0.2, 0.3, 0.4, 0.5}, }) })) defer srv.Close() client := NewOllamaClient(WithBaseURL(srv.URL)) resp, err := client.Embed(context.Background(), "test input") if err != nil { t.Fatalf("Embed failed: %v", err) } if len(resp.Embedding) != 5 { t.Errorf("expected 5 embedding values, got %d", len(resp.Embedding)) } if resp.Embedding[0] != 0.1 { t.Errorf("expected first value 0.1, got %f", resp.Embedding[0]) } } // ============================================================ // Tool Def Builder Tests // ============================================================ func TestBuildToolDefsFromMap(t *testing.T) { tools := []map[string]interface{}{ { "name": "exec", "description": "Execute a shell command", "parameters": map[string]interface{}{ "properties": map[string]interface{}{ "command": map[string]interface{}{ "type": "string", "description": "Command to run", "required": true, }, "timeout": map[string]interface{}{ "type": "number", "description": "Timeout in seconds", "required": false, }, }, }, }, } defs := BuildToolDefsFromMap(tools) if len(defs) != 1 { t.Fatalf("expected 1 tool def, got %d", len(defs)) } if defs[0].Function.Name != "exec" { t.Errorf("expected name 'exec', got %q", defs[0].Function.Name) } if len(defs[0].Function.Parameters.Required) != 1 { t.Errorf("expected 1 required parameter, got %d", len(defs[0].Function.Parameters.Required)) } if defs[0].Function.Parameters.Required[0] != "command" { t.Errorf("expected required 'command', got %q", defs[0].Function.Parameters.Required[0]) } } // ============================================================ // Mock LLM (for use by other tests) // ============================================================ // MockLLM is a configurable mock implementation of the LLM interface for testing. type MockLLM struct { ChatFunc func(ctx context.Context, messages []Message) (*Response, error) StreamFunc func(ctx context.Context, messages []Message, handler StreamHandler) error } func (m *MockLLM) Chat(ctx context.Context, messages []Message) (*Response, error) { if m.ChatFunc != nil { return m.ChatFunc(ctx, messages) } return &Response{Content: "mock response"}, nil } func (m *MockLLM) Stream(ctx context.Context, messages []Message, handler StreamHandler) error { if m.StreamFunc != nil { return m.StreamFunc(ctx, messages, handler) } return handler("mock stream response") }