197 lines
4.2 KiB
Go
197 lines
4.2 KiB
Go
package session
|
|
|
|
import (
	"database/sql"
	"fmt"
	"strconv"
	"strings"
)
|
|
|
|
// VectorStore reads and writes message embeddings through sqlite-vec (vec0)
// virtual tables that live in an existing *sql.DB.
type VectorStore struct {
	// db is the handle passed to NewVectorStore; every query runs on it.
	db *sql.DB
	// enabled reports whether schema creation succeeded. When it is false,
	// the public methods return early without touching the database.
	enabled bool
}
|
|
|
|
func NewVectorStore(db *sql.DB) (*VectorStore, error) {
|
|
vs := &VectorStore{db: db}
|
|
if err := vs.initSchema(); err != nil {
|
|
return &VectorStore{db: db, enabled: false}, nil
|
|
}
|
|
vs.enabled = true
|
|
return vs, nil
|
|
}
|
|
|
|
func (vs *VectorStore) initSchema() error {
|
|
_, err := vs.db.Exec(`
|
|
CREATE VIRTUAL TABLE IF NOT EXISTS vec_main_messages USING vec0(
|
|
msg_id INTEGER PRIMARY KEY,
|
|
embedding FLOAT[1024]
|
|
)
|
|
`)
|
|
return err
|
|
}
|
|
|
|
func (vs *VectorStore) SaveEmbedding(msgID int64, embedding []float32) error {
|
|
if !vs.enabled {
|
|
return nil
|
|
}
|
|
|
|
if len(embedding) != 1024 {
|
|
return fmt.Errorf("expected 1024 dimensions, got %d", len(embedding))
|
|
}
|
|
|
|
embeddingStr := formatEmbedding(embedding)
|
|
_, err := vs.db.Exec(
|
|
"INSERT INTO vec_main_messages (msg_id, embedding) VALUES (?, ?)",
|
|
msgID, embeddingStr,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save embedding: %w", err)
|
|
}
|
|
|
|
_, err = vs.db.Exec(
|
|
"UPDATE main_messages SET has_embedding = TRUE WHERE id = ?",
|
|
msgID,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (vs *VectorStore) SaveLongTermEmbedding(memoryID int64, embedding []float32) error {
|
|
if !vs.enabled {
|
|
return nil
|
|
}
|
|
|
|
if len(embedding) != 1024 {
|
|
return fmt.Errorf("expected 1024 dimensions, got %d", len(embedding))
|
|
}
|
|
|
|
embeddingStr := formatEmbedding(embedding)
|
|
_, err := vs.db.Exec(
|
|
"INSERT INTO vec_long_term_memories (memory_id, embedding) VALUES (?, ?)",
|
|
memoryID, embeddingStr,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save long-term embedding: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (vs *VectorStore) SearchLongTermSimilar(embedding []float32, limit int) ([]struct {
|
|
MemoryID int64
|
|
Distance float64
|
|
}, error) {
|
|
if !vs.enabled {
|
|
return nil, nil
|
|
}
|
|
|
|
if len(embedding) != 1024 {
|
|
return nil, fmt.Errorf("expected 1024 dimensions, got %d", len(embedding))
|
|
}
|
|
|
|
embeddingStr := formatEmbedding(embedding)
|
|
rows, err := vs.db.Query(
|
|
`SELECT memory_id, distance FROM vec_long_term_memories
|
|
WHERE embedding MATCH ?
|
|
ORDER BY distance
|
|
LIMIT ?`,
|
|
embeddingStr, limit,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to search long-term vectors: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var results []struct {
|
|
MemoryID int64
|
|
Distance float64
|
|
}
|
|
for rows.Next() {
|
|
var r struct {
|
|
MemoryID int64
|
|
Distance float64
|
|
}
|
|
if err := rows.Scan(&r.MemoryID, &r.Distance); err != nil {
|
|
return nil, err
|
|
}
|
|
results = append(results, r)
|
|
}
|
|
|
|
return results, rows.Err()
|
|
}
|
|
|
|
func (vs *VectorStore) SearchSimilar(embedding []float32, limit int) ([]int64, error) {
|
|
if !vs.enabled {
|
|
return []int64{}, nil
|
|
}
|
|
|
|
if len(embedding) != 1024 {
|
|
return nil, fmt.Errorf("expected 1024 dimensions, got %d", len(embedding))
|
|
}
|
|
|
|
embeddingStr := formatEmbedding(embedding)
|
|
rows, err := vs.db.Query(
|
|
`SELECT msg_id FROM vec_main_messages
|
|
WHERE embedding MATCH ?
|
|
ORDER BY distance
|
|
LIMIT ?`,
|
|
embeddingStr, limit,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to search vectors: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var msgIDs []int64
|
|
for rows.Next() {
|
|
var id int64
|
|
if err := rows.Scan(&id); err != nil {
|
|
return nil, err
|
|
}
|
|
msgIDs = append(msgIDs, id)
|
|
}
|
|
|
|
return msgIDs, rows.Err()
|
|
}
|
|
|
|
func (vs *VectorStore) SearchSimilarInSession(sessionID string, embedding []float32, limit int) ([]int64, error) {
|
|
if !vs.enabled {
|
|
return []int64{}, nil
|
|
}
|
|
|
|
if len(embedding) != 1024 {
|
|
return nil, fmt.Errorf("expected 1024 dimensions, got %d", len(embedding))
|
|
}
|
|
|
|
embeddingStr := formatEmbedding(embedding)
|
|
rows, err := vs.db.Query(
|
|
`SELECT v.msg_id FROM vec_main_messages v
|
|
JOIN main_messages m ON v.msg_id = m.id
|
|
WHERE m.session_id = ? AND v.embedding MATCH ?
|
|
ORDER BY distance
|
|
LIMIT ?`,
|
|
sessionID, embeddingStr, limit,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to search session vectors: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var msgIDs []int64
|
|
for rows.Next() {
|
|
var id int64
|
|
if err := rows.Scan(&id); err != nil {
|
|
return nil, err
|
|
}
|
|
msgIDs = append(msgIDs, id)
|
|
}
|
|
|
|
return msgIDs, rows.Err()
|
|
}
|
|
|
|
// formatEmbedding renders embedding as a JSON-style array literal such as
// "[0.5,-1,0.0000001]". This text is the vector argument bound into the
// vec0 statements in this package.
//
// Each component is written as the shortest plain decimal (no exponent)
// that parses back to the identical float32. A fixed "%f" format keeps only
// six decimals, which silently rounds small components (1e-7 becomes
// "0.000000") and discards precision from typical normalized embedding
// values. An empty or nil slice yields "[]".
//
// NOTE(review): NaN and ±Inf components come out as "NaN", "+Inf" and
// "-Inf", which are not valid JSON numbers. This presumes callers never pass
// them — confirm.
func formatEmbedding(embedding []float32) string {
	var b strings.Builder
	// Rough size estimate (brackets plus ~12 bytes per component) to avoid
	// repeated regrowth for 1024-element vectors.
	b.Grow(2 + 12*len(embedding))
	b.WriteByte('[')
	var scratch [32]byte
	for i, v := range embedding {
		if i > 0 {
			b.WriteByte(',')
		}
		// bitSize 32 with precision -1 gives the shortest float32 round-trip.
		b.Write(strconv.AppendFloat(scratch[:0], float64(v), 'f', -1, 32))
	}
	b.WriteByte(']')
	return b.String()
}
|