orca.ai/pkg/session/vector_store.go
2026-05-12 00:09:01 +08:00

197 lines
4.2 KiB
Go

package session
import (
	"database/sql"
	"fmt"
	"strconv"
	"strings"
)
// VectorStore saves embeddings into vec0 virtual tables and runs
// nearest-neighbour queries against them through a database/sql handle.
type VectorStore struct {
	// db is the handle every statement of the store is executed on.
	db *sql.DB

	// enabled is set by NewVectorStore when schema initialisation succeeded.
	// While it is false the Save*/Search* methods return immediately without
	// touching the database.
	enabled bool
}
// NewVectorStore wraps db in a VectorStore and tries to create the vector
// schema on it.
//
// A schema-creation failure is not reported to the caller: the returned store
// is simply left disabled, which turns its save and search methods into
// no-ops. The error result is consequently always nil.
func NewVectorStore(db *sql.DB) (*VectorStore, error) {
	store := &VectorStore{db: db}
	// Best-effort initialisation: the store is enabled only if the schema
	// statement ran without error.
	store.enabled = store.initSchema() == nil
	return store, nil
}
// initSchema creates the vec_main_messages vec0 virtual table (an integer
// msg_id primary key plus a 1024-float embedding column) unless it already
// exists, and returns whatever error the database reports.
//
// NOTE(review): vec_long_term_memories is read and written by the long-term
// methods of this type but is not created here — presumably it is set up
// elsewhere; confirm.
func (vs *VectorStore) initSchema() error {
	const createMessagesTable = `
CREATE VIRTUAL TABLE IF NOT EXISTS vec_main_messages USING vec0(
msg_id INTEGER PRIMARY KEY,
embedding FLOAT[1024]
)
`
	if _, err := vs.db.Exec(createMessagesTable); err != nil {
		return err
	}
	return nil
}
// SaveEmbedding stores the embedding of message msgID in vec_main_messages
// and sets has_embedding on the matching main_messages row.
//
// Both statements run inside a single transaction, so a failure of the second
// one cannot leave an embedding row behind without the has_embedding flag
// (after which a retry would try to insert an already existing msg_id).
//
// It returns nil without touching the database when the store is disabled,
// and an error when embedding does not have exactly 1024 dimensions or when
// any database step fails.
func (vs *VectorStore) SaveEmbedding(msgID int64, embedding []float32) error {
	if !vs.enabled {
		return nil
	}
	if len(embedding) != 1024 {
		return fmt.Errorf("expected 1024 dimensions, got %d", len(embedding))
	}

	tx, err := vs.db.Begin()
	if err != nil {
		return fmt.Errorf("failed to begin embedding transaction: %w", err)
	}
	// Undo partial work on every early return. After a successful Commit the
	// rollback only reports sql.ErrTxDone, so its result is deliberately
	// ignored.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.Exec(
		"INSERT INTO vec_main_messages (msg_id, embedding) VALUES (?, ?)",
		msgID, formatEmbedding(embedding),
	); err != nil {
		return fmt.Errorf("failed to save embedding: %w", err)
	}
	if _, err := tx.Exec(
		"UPDATE main_messages SET has_embedding = TRUE WHERE id = ?",
		msgID,
	); err != nil {
		return fmt.Errorf("failed to mark message %d as embedded: %w", msgID, err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit embedding: %w", err)
	}
	return nil
}
// SaveLongTermEmbedding inserts the embedding of long-term memory memoryID
// into vec_long_term_memories.
//
// It returns nil without touching the database when the store is disabled,
// and an error when embedding does not have exactly 1024 dimensions or the
// insert fails.
func (vs *VectorStore) SaveLongTermEmbedding(memoryID int64, embedding []float32) error {
	switch {
	case !vs.enabled:
		return nil
	case len(embedding) != 1024:
		return fmt.Errorf("expected 1024 dimensions, got %d", len(embedding))
	}

	const insertSQL = "INSERT INTO vec_long_term_memories (memory_id, embedding) VALUES (?, ?)"
	if _, err := vs.db.Exec(insertSQL, memoryID, formatEmbedding(embedding)); err != nil {
		return fmt.Errorf("failed to save long-term embedding: %w", err)
	}
	return nil
}
// SearchLongTermSimilar runs a MATCH query for embedding against
// vec_long_term_memories and returns up to limit (memory_id, distance) pairs
// ordered by distance.
//
// A disabled store yields (nil, nil). An embedding that does not have exactly
// 1024 dimensions, a failing query, or a failing row scan yields an error.
func (vs *VectorStore) SearchLongTermSimilar(embedding []float32, limit int) ([]struct {
	MemoryID int64
	Distance float64
}, error) {
	// match aliases the anonymous element type of the result slice so it does
	// not have to be spelled out again for every local.
	type match = struct {
		MemoryID int64
		Distance float64
	}

	switch {
	case !vs.enabled:
		return nil, nil
	case len(embedding) != 1024:
		return nil, fmt.Errorf("expected 1024 dimensions, got %d", len(embedding))
	}

	const knnSQL = `SELECT memory_id, distance FROM vec_long_term_memories
WHERE embedding MATCH ?
ORDER BY distance
LIMIT ?`
	rows, err := vs.db.Query(knnSQL, formatEmbedding(embedding), limit)
	if err != nil {
		return nil, fmt.Errorf("failed to search long-term vectors: %w", err)
	}
	defer rows.Close()

	var matches []match
	for rows.Next() {
		var m match
		if scanErr := rows.Scan(&m.MemoryID, &m.Distance); scanErr != nil {
			return nil, scanErr
		}
		matches = append(matches, m)
	}
	// rows.Err surfaces an iteration error that ended the loop early.
	return matches, rows.Err()
}
// SearchSimilar runs a MATCH query for embedding against vec_main_messages
// and returns up to limit msg_id values ordered by distance.
//
// A disabled store yields an empty, non-nil slice; an enabled store whose
// query produces no rows yields a nil slice. An embedding that does not have
// exactly 1024 dimensions, a failing query, or a failing row scan yields an
// error.
func (vs *VectorStore) SearchSimilar(embedding []float32, limit int) ([]int64, error) {
	if !vs.enabled {
		return []int64{}, nil
	}
	if dims := len(embedding); dims != 1024 {
		return nil, fmt.Errorf("expected 1024 dimensions, got %d", dims)
	}

	const knnSQL = `SELECT msg_id FROM vec_main_messages
WHERE embedding MATCH ?
ORDER BY distance
LIMIT ?`
	rows, err := vs.db.Query(knnSQL, formatEmbedding(embedding), limit)
	if err != nil {
		return nil, fmt.Errorf("failed to search vectors: %w", err)
	}
	defer rows.Close()

	var found []int64
	for rows.Next() {
		var msgID int64
		if scanErr := rows.Scan(&msgID); scanErr != nil {
			return nil, scanErr
		}
		found = append(found, msgID)
	}
	// rows.Err surfaces an iteration error that ended the loop early.
	return found, rows.Err()
}
// SearchSimilarInSession runs a MATCH query for embedding against
// vec_main_messages joined to main_messages, keeping only rows whose
// session_id equals sessionID, and returns up to limit msg_id values ordered
// by distance.
//
// A disabled store yields an empty, non-nil slice; an enabled store whose
// query produces no rows yields a nil slice. An embedding that does not have
// exactly 1024 dimensions, a failing query, or a failing row scan yields an
// error.
//
// NOTE(review): whether the vec0 module applies LIMIT before or after the
// session filter of the JOIN is not visible from this code — confirm that the
// query returns the expected number of in-session matches.
func (vs *VectorStore) SearchSimilarInSession(sessionID string, embedding []float32, limit int) ([]int64, error) {
	if !vs.enabled {
		return []int64{}, nil
	}
	if dims := len(embedding); dims != 1024 {
		return nil, fmt.Errorf("expected 1024 dimensions, got %d", dims)
	}

	const sessionKnnSQL = `SELECT v.msg_id FROM vec_main_messages v
JOIN main_messages m ON v.msg_id = m.id
WHERE m.session_id = ? AND v.embedding MATCH ?
ORDER BY distance
LIMIT ?`
	rows, err := vs.db.Query(sessionKnnSQL, sessionID, formatEmbedding(embedding), limit)
	if err != nil {
		return nil, fmt.Errorf("failed to search session vectors: %w", err)
	}
	defer rows.Close()

	var found []int64
	for rows.Next() {
		var msgID int64
		if scanErr := rows.Scan(&msgID); scanErr != nil {
			return nil, scanErr
		}
		found = append(found, msgID)
	}
	// rows.Err surfaces an iteration error that ended the loop early.
	return found, rows.Err()
}
// formatEmbedding renders embedding as a bracketed, comma-separated list of
// decimal numbers, e.g. "[0.5,-2,0.0000001]"; an empty or nil slice becomes
// "[]".
//
// Each component is written with the shortest decimal representation that
// parses back to the same float32, always in plain (non-exponent) notation.
// A fixed six-decimal format would instead round small components and turn
// anything below 5e-7 into zero.
func formatEmbedding(embedding []float32) string {
	var b strings.Builder
	// Rough size estimate (about a dozen bytes per component plus brackets)
	// to avoid repeated buffer growth; the builder still grows if needed.
	b.Grow(len(embedding)*12 + 2)

	b.WriteByte('[')
	var scratch [48]byte
	for i, v := range embedding {
		if i > 0 {
			b.WriteByte(',')
		}
		// bitSize 32 with precision -1 selects the shortest digits that
		// round-trip as float32 rather than as float64.
		b.Write(strconv.AppendFloat(scratch[:0], float64(v), 'f', -1, 32))
	}
	b.WriteByte(']')
	return b.String()
}