orca.ai/pkg/embedding/client.go
2026-05-12 00:09:01 +08:00

125 lines
2.5 KiB
Go

package embedding
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
// Client posts batches of texts to an embeddings HTTP endpoint and decodes the
// returned vectors. Create it with NewClient: the zero value has a nil HTTP
// client and must not be used to send requests.
type Client struct {
	// Request settings, filled in by NewClient from Config.
	apiKey, baseURL, model string

	// client sends every request; its Timeout comes from Config.Timeout.
	client *http.Client
}
// Config carries the settings passed to NewClient. BaseURL, Model and Timeout
// are optional: NewClient replaces their zero values with built-in defaults.
type Config struct {
	// APIKey is sent on every request as "Authorization: Bearer <APIKey>".
	APIKey string

	// BaseURL is the API root; requests are POSTed to BaseURL + "/embeddings".
	// Empty means "https://api.siliconflow.cn/v1".
	BaseURL string

	// Model is the model name placed in every request body.
	// Empty means "Pro/BAAI/bge-m3".
	Model string

	// Timeout is the per-request limit for the underlying http.Client.
	// Zero means 30 seconds.
	Timeout time.Duration
}
// NewClient returns a Client configured from cfg.
//
// Zero-valued fields are replaced with defaults:
//   - BaseURL: "https://api.siliconflow.cn/v1"
//   - Model: "Pro/BAAI/bge-m3"
//   - Timeout: 30 seconds
//
// Trailing slashes are stripped from BaseURL, so ".../v1/" and ".../v1" behave
// the same. APIKey is used as given; an empty key is not rejected here.
func NewClient(cfg Config) *Client {
	if cfg.BaseURL == "" {
		cfg.BaseURL = "https://api.siliconflow.cn/v1"
	}
	// Embed appends "/embeddings" to the base URL. Without this trim, a base
	// URL ending in "/" would yield a path containing "//embeddings", which
	// the server may not route to the endpoint.
	for len(cfg.BaseURL) > 0 && cfg.BaseURL[len(cfg.BaseURL)-1] == '/' {
		cfg.BaseURL = cfg.BaseURL[:len(cfg.BaseURL)-1]
	}
	if cfg.Model == "" {
		cfg.Model = "Pro/BAAI/bge-m3"
	}
	if cfg.Timeout == 0 {
		cfg.Timeout = 30 * time.Second
	}
	return &Client{
		apiKey:  cfg.APIKey,
		baseURL: cfg.BaseURL,
		model:   cfg.Model,
		client:  &http.Client{Timeout: cfg.Timeout},
	}
}
// Wire types for the JSON exchange with the embeddings endpoint.
type (
	// embedRequest is the JSON body POSTed to the endpoint: the model name
	// plus the batch of texts to embed.
	embedRequest struct {
		Model string   `json:"model"`
		Input []string `json:"input"`
	}

	// embedResponse holds the parts of the endpoint's JSON reply that are
	// decoded here. Each Data entry carries one vector and an Index, which
	// Embed uses as the position of the corresponding input text. Error is
	// non-nil when the reply contains an "error" object.
	embedResponse struct {
		Data []struct {
			Embedding []float32 `json:"embedding"`
			Index     int       `json:"index"`
		} `json:"data"`
		Error *struct {
			Message string `json:"message"`
		} `json:"error,omitempty"`
	}
)
// Embed returns one embedding vector per input text, in input order.
//
// All texts are sent in a single POST to baseURL + "/embeddings". The request
// carries the configured model and a Bearer Authorization header. The result
// has len(texts) entries: entry i is the vector the API reported for index i.
//
// An error is returned in these cases:
//   - texts is empty
//   - the request cannot be built or sent
//   - the API answers with a non-200 status or an "error" object
//   - the body is not valid JSON
//   - the API does not return a non-empty embedding for every input
func (c *Client) Embed(texts []string) ([][]float32, error) {
	if len(texts) == 0 {
		return nil, fmt.Errorf("no texts to embed")
	}

	reqBody, err := json.Marshal(embedRequest{
		Model: c.model,
		Input: texts,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, c.baseURL+"/embeddings", bytes.NewReader(reqBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.apiKey)

	resp, err := c.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("embedding API returned %d: %s", resp.StatusCode, string(body))
	}

	var embedResp embedResponse
	if err := json.Unmarshal(body, &embedResp); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}
	if embedResp.Error != nil {
		return nil, fmt.Errorf("embedding API error: %s", embedResp.Error.Message)
	}

	// Place each vector by the index the API reports, not by its position in
	// Data, so an out-of-order reply still lines up with the inputs. Indices
	// outside [0, len(texts)) are ignored. Without the lower-bound check, a
	// negative index would panic.
	results := make([][]float32, len(texts))
	for _, d := range embedResp.Data {
		if d.Index >= 0 && d.Index < len(results) {
			results[d.Index] = d.Embedding
		}
	}

	// A short or malformed reply must not surface as nil vectors with a nil error.
	for i, vec := range results {
		if len(vec) == 0 {
			return nil, fmt.Errorf("embedding API returned no embedding for input %d", i)
		}
	}
	return results, nil
}
// EmbedSingle embeds a single text and returns its vector.
//
// It calls Embed with a one-element batch and returns Embed's error unchanged.
// It also reports an error when no vector, or an empty one, comes back for
// the text.
func (c *Client) EmbedSingle(text string) ([]float32, error) {
	results, err := c.Embed([]string{text})
	if err != nil {
		return nil, err
	}
	// Embed sizes its result to the number of inputs, so a missing embedding
	// may show up as a nil/empty slot rather than a short slice. Check both
	// so callers never receive a nil vector with a nil error.
	if len(results) == 0 || len(results[0]) == 0 {
		return nil, fmt.Errorf("no embedding returned")
	}
	return results[0], nil
}