package session

import (
	"database/sql"
	"fmt"
	"strconv"
	"strings"
)

// vecEmbeddingDim is the vector width every method in this file accepts.
// It must stay in sync with the FLOAT[1024] column declared in initSchema.
const vecEmbeddingDim = 1024

// VectorStore saves and searches embeddings stored in vec0 virtual tables.
//
// When enabled is false every method is a no-op: writes return nil and
// searches return no results, so callers can use the store unconditionally.
type VectorStore struct {
	db      *sql.DB
	enabled bool
}

// NewVectorStore wraps db and tries to create the vector schema.
//
// Schema creation is best-effort: if it fails (or db is nil) the returned
// store is disabled rather than an error being reported. The returned error
// is currently always nil; it is kept for interface compatibility.
func NewVectorStore(db *sql.DB) (*VectorStore, error) {
	vs := &VectorStore{db: db}
	if db == nil {
		// Calling Exec on a nil *sql.DB would panic; treat it like a
		// failed schema initialisation instead.
		return vs, nil
	}
	vs.enabled = vs.initSchema() == nil
	return vs, nil
}

// initSchema creates the vec_main_messages virtual table if it is missing.
//
// NOTE(review): vec_long_term_memories, used by the long-term methods below,
// is not created here; presumably it is created elsewhere. Confirm.
func (vs *VectorStore) initSchema() error {
	_, err := vs.db.Exec(`
		CREATE VIRTUAL TABLE IF NOT EXISTS vec_main_messages USING vec0(
			msg_id INTEGER PRIMARY KEY,
			embedding FLOAT[1024]
		)
	`)
	return err
}

// checkEmbeddingDim reports an error unless embedding has vecEmbeddingDim entries.
func checkEmbeddingDim(embedding []float32) error {
	if len(embedding) != vecEmbeddingDim {
		return fmt.Errorf("expected %d dimensions, got %d", vecEmbeddingDim, len(embedding))
	}
	return nil
}

// SaveEmbedding inserts the embedding for message msgID into
// vec_main_messages and then sets has_embedding on the matching
// main_messages row.
//
// It returns nil without doing anything when the store is disabled, and an
// error when embedding does not have exactly 1024 entries or when either
// statement fails. The two statements are not wrapped in a transaction, so a
// failure of the second leaves the vector row in place.
func (vs *VectorStore) SaveEmbedding(msgID int64, embedding []float32) error {
	if !vs.enabled {
		return nil
	}
	if err := checkEmbeddingDim(embedding); err != nil {
		return err
	}

	if _, err := vs.db.Exec(
		"INSERT INTO vec_main_messages (msg_id, embedding) VALUES (?, ?)",
		msgID, formatEmbedding(embedding),
	); err != nil {
		return fmt.Errorf("failed to save embedding: %w", err)
	}

	if _, err := vs.db.Exec(
		"UPDATE main_messages SET has_embedding = TRUE WHERE id = ?",
		msgID,
	); err != nil {
		return fmt.Errorf("failed to mark message %d as embedded: %w", msgID, err)
	}
	return nil
}

// SaveLongTermEmbedding inserts the embedding for memoryID into
// vec_long_term_memories.
//
// It returns nil without doing anything when the store is disabled, and an
// error when embedding does not have exactly 1024 entries or the insert fails.
func (vs *VectorStore) SaveLongTermEmbedding(memoryID int64, embedding []float32) error {
	if !vs.enabled {
		return nil
	}
	if err := checkEmbeddingDim(embedding); err != nil {
		return err
	}

	if _, err := vs.db.Exec(
		"INSERT INTO vec_long_term_memories (memory_id, embedding) VALUES (?, ?)",
		memoryID, formatEmbedding(embedding),
	); err != nil {
		return fmt.Errorf("failed to save long-term embedding: %w", err)
	}
	return nil
}

// SearchLongTermSimilar returns up to limit rows of vec_long_term_memories
// matching embedding, ordered by ascending distance.
//
// A disabled store yields (nil, nil). An error is returned when embedding
// does not have exactly 1024 entries or when the query fails; in that case
// no partial results are returned.
func (vs *VectorStore) SearchLongTermSimilar(embedding []float32, limit int) ([]struct {
	MemoryID int64
	Distance float64
}, error) {
	if !vs.enabled {
		return nil, nil
	}
	if err := checkEmbeddingDim(embedding); err != nil {
		return nil, err
	}

	rows, err := vs.db.Query(
		`SELECT memory_id, distance
		FROM vec_long_term_memories
		WHERE embedding MATCH ?
		ORDER BY distance
		LIMIT ?`,
		formatEmbedding(embedding), limit,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to search long-term vectors: %w", err)
	}
	defer rows.Close()

	var results []struct {
		MemoryID int64
		Distance float64
	}
	for rows.Next() {
		var r struct {
			MemoryID int64
			Distance float64
		}
		if err := rows.Scan(&r.MemoryID, &r.Distance); err != nil {
			return nil, fmt.Errorf("failed to scan long-term vector row: %w", err)
		}
		results = append(results, r)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to search long-term vectors: %w", err)
	}
	return results, nil
}

// SearchSimilar returns the msg_id of up to limit rows of vec_main_messages
// matching embedding, ordered by ascending distance.
//
// A disabled store yields an empty, non-nil slice. An error is returned when
// embedding does not have exactly 1024 entries or when the query fails.
func (vs *VectorStore) SearchSimilar(embedding []float32, limit int) ([]int64, error) {
	if !vs.enabled {
		return []int64{}, nil
	}
	if err := checkEmbeddingDim(embedding); err != nil {
		return nil, err
	}

	return vs.queryMsgIDs(
		"failed to search vectors",
		`SELECT msg_id
		FROM vec_main_messages
		WHERE embedding MATCH ?
		ORDER BY distance
		LIMIT ?`,
		formatEmbedding(embedding), limit,
	)
}

// SearchSimilarInSession is like SearchSimilar but joins main_messages and
// keeps only rows whose session_id equals sessionID.
//
// NOTE(review): whether LIMIT is applied to the vector match before or after
// the session filter depends on how vec0 plans this join; if it is applied
// first, fewer than limit rows may come back even when more matches exist in
// the session. Confirm against the vec0 documentation.
func (vs *VectorStore) SearchSimilarInSession(sessionID string, embedding []float32, limit int) ([]int64, error) {
	if !vs.enabled {
		return []int64{}, nil
	}
	if err := checkEmbeddingDim(embedding); err != nil {
		return nil, err
	}

	return vs.queryMsgIDs(
		"failed to search session vectors",
		`SELECT v.msg_id
		FROM vec_main_messages v
		JOIN main_messages m ON v.msg_id = m.id
		WHERE m.session_id = ? AND v.embedding MATCH ?
		ORDER BY distance
		LIMIT ?`,
		sessionID, formatEmbedding(embedding), limit,
	)
}

// queryMsgIDs runs query, which must select a single integer column, and
// collects that column. failure is the message prefix used when the query
// itself fails. The returned slice is non-nil on success so that the
// "no matches" and "store disabled" cases look the same to callers.
func (vs *VectorStore) queryMsgIDs(failure, query string, args ...interface{}) ([]int64, error) {
	rows, err := vs.db.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", failure, err)
	}
	defer rows.Close()

	ids := []int64{}
	for rows.Next() {
		var id int64
		if err := rows.Scan(&id); err != nil {
			return nil, fmt.Errorf("failed to scan vector search row: %w", err)
		}
		ids = append(ids, id)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("%s: %w", failure, err)
	}
	return ids, nil
}

// formatEmbedding renders embedding as a bracketed, comma-separated list of
// decimals, e.g. "[0.5,-0.25]".
//
// Each value is written with the fewest digits that still parse back to the
// same float32, and never in exponent notation. A fixed six-decimal format
// would silently truncate small components to zero.
func formatEmbedding(embedding []float32) string {
	var b strings.Builder
	// Roughly a dozen bytes per component ("-0.12345678,") plus the brackets.
	b.Grow(len(embedding)*12 + 2)

	var scratch [64]byte
	b.WriteByte('[')
	for i, v := range embedding {
		if i > 0 {
			b.WriteByte(',')
		}
		b.Write(strconv.AppendFloat(scratch[:0], float64(v), 'f', -1, 32))
	}
	b.WriteByte(']')
	return b.String()
}