package session

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"github.com/orca/orca/pkg/embedding"
	"github.com/orca/orca/pkg/llm"
)

// MemoryManager layers working, short-term and long-term conversational
// memory on top of a SQLiteStore. When an embedding client is configured it
// embeds saved messages in the background and uses vector search to pick
// relevant memories; otherwise it falls back to recency/weight ordering.
type MemoryManager struct {
	store       *SQLiteStore
	vectorStore *VectorStore
	embedClient *embedding.Client
	llmBackend  llm.LLM
	tokenBudget TokenBudget
	ownsStore   bool // true when Close should also close store

	// Background embedding pipeline (see startEmbedWorker).
	embedQueue  chan embedTask
	embedWg     sync.WaitGroup
	embedCtx    context.Context
	embedCancel context.CancelFunc

	// Query-embedding cache with LRU eviction. embedCache and the hit/miss
	// counters are accessed under embedCacheMu.
	embedCache     map[string]*embedCacheEntry
	embedCacheMu   sync.RWMutex
	embedCacheMax  int
	embedCacheHit  int64
	embedCacheMiss int64
}

// embedCacheEntry is a cached query embedding plus the time it was last
// served, which evictLRU uses to pick a victim.
type embedCacheEntry struct {
	embedding []float32
	lastUsed  time.Time
}

// embedTask is a queued request to embed the content of a stored message.
type embedTask struct {
	msgID   int64
	content string
}

// TokenBudget splits the share of the model context window reserved for
// memory into per-tier token budgets (see calculateBudget).
type TokenBudget struct {
	Total     int
	Working   int
	ShortTerm int
	LongTerm  int
}

// MemoryConfig configures a MemoryManager.
type MemoryConfig struct {
	// DBPath is the SQLite database path; only NewMemoryManager uses it.
	DBPath string
	// EmbedConfig configures the embedding client. Embeddings (and the
	// background embedding worker) are disabled when APIKey is empty.
	EmbedConfig embedding.Config
	// ModelWindow is the model context window size; <= 0 defaults to 8192.
	ModelWindow int
}

// longTermRecord is one long-term memory row. It is a type alias for an
// unnamed struct, so the exported methods that return it keep exactly their
// original anonymous-struct result types.
type longTermRecord = struct {
	ID      int64
	Content string
	Weight  float64
}

// dialogueTurn is one buffered user/assistant exchange. Like longTermRecord
// it is an alias, so signatures using it are unchanged for callers.
type dialogueTurn = struct {
	UserQuery         string
	AssistantResponse string
}

// NewMemoryManager opens (and owns) a SQLiteStore at cfg.DBPath and builds a
// MemoryManager on it. Close will also close the store.
func NewMemoryManager(cfg MemoryConfig) (*MemoryManager, error) {
	store, err := NewSQLiteStore(cfg.DBPath)
	if err != nil {
		return nil, fmt.Errorf("failed to create store: %w", err)
	}
	mm, err := newMemoryManager(cfg, store, true)
	if err != nil {
		// Best-effort cleanup; the construction error is the one worth reporting.
		_ = store.Close()
		return nil, err
	}
	return mm, nil
}

// NewMemoryManagerWithStore builds a MemoryManager on an existing store.
// The caller keeps ownership: Close will not close store.
func NewMemoryManagerWithStore(cfg MemoryConfig, store *SQLiteStore) (*MemoryManager, error) {
	return newMemoryManager(cfg, store, false)
}

// newMemoryManager holds the construction logic shared by both public
// constructors; ownsStore records whether Close should close store.
func newMemoryManager(cfg MemoryConfig, store *SQLiteStore, ownsStore bool) (*MemoryManager, error) {
	vectorStore, err := NewVectorStore(store.DB())
	if err != nil {
		return nil, fmt.Errorf("failed to create vector store: %w", err)
	}

	ctx, cancel := context.WithCancel(context.Background())
	mm := &MemoryManager{
		store:         store,
		vectorStore:   vectorStore,
		tokenBudget:   calculateBudget(cfg.ModelWindow),
		ownsStore:     ownsStore,
		embedQueue:    make(chan embedTask, 100),
		embedCtx:      ctx,
		embedCancel:   cancel,
		embedCache:    make(map[string]*embedCacheEntry),
		embedCacheMax: 500,
	}

	if cfg.EmbedConfig.APIKey != "" {
		mm.embedClient = embedding.NewClient(cfg.EmbedConfig)
		mm.startEmbedWorker()
	}
	return mm, nil
}

// calculateBudget reserves 60% of the model window for memory and divides
// that 50/30/20 between working, short-term and long-term memory. A
// non-positive window defaults to 8192.
func calculateBudget(modelWindow int) TokenBudget {
	if modelWindow <= 0 {
		modelWindow = 8192
	}
	total := int(float64(modelWindow) * 0.6)
	return TokenBudget{
		Total:     total,
		Working:   int(float64(total) * 0.5),
		ShortTerm: int(float64(total) * 0.3),
		LongTerm:  int(float64(total) * 0.2),
	}
}

// Close stops the background embedding worker, waits for it to exit, and
// closes the store if this manager opened it. The worker flushes the batch
// it is currently accumulating; tasks still waiting in the queue are dropped.
func (mm *MemoryManager) Close() error {
	mm.embedCancel()
	mm.embedWg.Wait()
	if mm.ownsStore {
		return mm.store.Close()
	}
	return nil
}

// SaveMessage persists msg. When embeddings are enabled, user and assistant
// messages longer than 10 bytes are also queued for background embedding.
// The enqueue is non-blocking: if the queue is full, the message is stored
// but never embedded.
//
// NOTE(review): the message ID is recovered by re-querying the newest row of
// the session, so a concurrent SaveMessage on the same session can attach the
// embedding to the wrong message. Having store.Save return the inserted ID
// would close this race.
func (mm *MemoryManager) SaveMessage(sessionID string, msg SessionMessage) error {
	if err := mm.store.Save(sessionID, msg); err != nil {
		return err
	}
	if mm.embedClient == nil || len(msg.Content) <= 10 {
		return nil
	}
	if msg.Role != RoleUser && msg.Role != RoleAssistant {
		return nil
	}

	msgID, err := mm.getLastMessageID(sessionID)
	if err != nil {
		// Embedding is best-effort; the message itself was saved.
		return nil
	}
	select {
	case mm.embedQueue <- embedTask{msgID: msgID, content: msg.Content}:
	default:
	}
	return nil
}

// getLastMessageID returns the highest main_messages.id for sessionID.
func (mm *MemoryManager) getLastMessageID(sessionID string) (int64, error) {
	var id int64
	err := mm.store.DB().QueryRow(
		"SELECT id FROM main_messages WHERE session_id = ? ORDER BY id DESC LIMIT 1",
		sessionID,
	).Scan(&id)
	return id, err
}

// startEmbedWorker launches the goroutine that drains embedQueue. It embeds
// queued messages in batches of up to 5. A partial batch is flushed every 5
// seconds, and again when embedCtx is cancelled (see Close).
func (mm *MemoryManager) startEmbedWorker() {
	const (
		batchSize     = 5
		flushInterval = 5 * time.Second
	)

	mm.embedWg.Add(1)
	go func() {
		defer mm.embedWg.Done()

		batch := make([]embedTask, 0, batchSize)
		timer := time.NewTimer(flushInterval)
		defer timer.Stop()

		// rearm restarts the flush timer. It first discards an expiry that
		// fired but was never received (pre-Go 1.23 timers keep it buffered),
		// so a stale tick cannot trigger an early flush.
		rearm := func() {
			if !timer.Stop() {
				select {
				case <-timer.C:
				default:
				}
			}
			timer.Reset(flushInterval)
		}

		for {
			select {
			case task := <-mm.embedQueue:
				batch = append(batch, task)
				if len(batch) >= batchSize {
					mm.processBatch(batch)
					batch = batch[:0]
					rearm()
				}
			case <-timer.C:
				if len(batch) > 0 {
					mm.processBatch(batch)
					batch = batch[:0]
				}
				// The expiry was just received, so a plain Reset is safe here.
				timer.Reset(flushInterval)
			case <-mm.embedCtx.Done():
				if len(batch) > 0 {
					mm.processBatch(batch)
				}
				return
			}
		}
	}()
}

// processBatch embeds the tasks' contents in one call and stores each vector
// against its message ID. Failures are logged, never returned.
func (mm *MemoryManager) processBatch(tasks []embedTask) {
	texts := make([]string, len(tasks))
	for i, t := range tasks {
		texts[i] = t.content
	}

	embeddings, err := mm.embedClient.Embed(texts)
	if err != nil {
		log.Printf("[memory] Embedding batch failed: %v", err)
		return
	}
	// Vectors are matched to tasks by position. If the counts differ, the
	// pairing is unknowable (and indexing tasks[i] could panic and kill the
	// worker), so skip the batch.
	if len(embeddings) != len(tasks) {
		log.Printf("[memory] Embedding batch size mismatch: sent %d texts, got %d vectors", len(tasks), len(embeddings))
		return
	}

	for i, emb := range embeddings {
		if err := mm.vectorStore.SaveEmbedding(tasks[i].msgID, emb); err != nil {
			log.Printf("[memory] Save embedding failed: %v", err)
		}
	}
}

// GetWorkingMemory returns the session's most recent messages, oldest first,
// that fit within the working-memory token budget. The newest message is
// always included, even if it alone exceeds the budget. Rows that fail to
// scan are skipped. An unparsable timestamp yields a zero Timestamp.
func (mm *MemoryManager) GetWorkingMemory(sessionID string) ([]SessionMessage, error) {
	// metadata is not used here, so it is not selected. Scanning it into a
	// string used to drop any row whose metadata was NULL.
	rows, err := mm.store.DB().Query(
		`SELECT role, content, timestamp FROM main_messages
		 WHERE session_id = ? ORDER BY timestamp DESC`,
		sessionID,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var messages []SessionMessage
	totalTokens := 0
	for rows.Next() {
		var msg SessionMessage
		var timestampStr string
		if err := rows.Scan(&msg.Role, &msg.Content, &timestampStr); err != nil {
			continue
		}
		msg.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)

		tokens := estimateTokens(msg.Content)
		if totalTokens+tokens > mm.tokenBudget.Working && len(messages) > 0 {
			break
		}
		totalTokens += tokens
		messages = append(messages, msg)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	// Rows were read newest-first; callers get chronological order.
	reverseMessages(messages)
	return messages, nil
}

// getEmbeddingWithCache returns the embedding for query, serving repeats
// from an in-memory LRU cache of at most embedCacheMax entries. The returned
// slice is shared with the cache and must not be modified.
func (mm *MemoryManager) getEmbeddingWithCache(query string) ([]float32, error) {
	if mm.embedClient == nil {
		return nil, fmt.Errorf("no embed client")
	}

	// A hit updates entry.lastUsed and the hit counter, so it needs the write
	// lock. Doing this under RLock was a data race between concurrent readers.
	mm.embedCacheMu.Lock()
	if entry, ok := mm.embedCache[query]; ok {
		entry.lastUsed = time.Now()
		mm.embedCacheHit++
		mm.embedCacheMu.Unlock()
		return entry.embedding, nil
	}
	mm.embedCacheMu.Unlock()

	// Embed outside the lock so a slow client call does not block other
	// lookups. Two concurrent misses for the same query both embed it; the
	// later one simply overwrites the cache entry.
	vec, err := mm.embedClient.EmbedSingle(query)
	if err != nil {
		return nil, err
	}

	mm.embedCacheMu.Lock()
	defer mm.embedCacheMu.Unlock()
	mm.embedCache[query] = &embedCacheEntry{
		embedding: vec,
		lastUsed:  time.Now(),
	}
	mm.embedCacheMiss++
	if len(mm.embedCache) > mm.embedCacheMax {
		mm.evictLRU()
	}
	return vec, nil
}

// evictLRU removes the least recently used cache entry with a linear scan.
// The caller must hold embedCacheMu for writing.
func (mm *MemoryManager) evictLRU() {
	var oldestKey string
	var oldestTime time.Time
	found := false
	for k, v := range mm.embedCache {
		if !found || v.lastUsed.Before(oldestTime) {
			oldestKey = k
			oldestTime = v.lastUsed
			found = true
		}
	}
	// Track "found" rather than testing oldestKey != "" so an entry cached
	// under the empty query can still be evicted.
	if found {
		delete(mm.embedCache, oldestKey)
	}
}

// GetShortTermMemory returns up to 3 snippets relevant to sessionID. With a
// non-empty query and embeddings enabled, it first tries a vector search over
// the session's embedded messages. If that fails or finds nothing, it falls
// back to the 3 most recently updated short-term memories of the session.
func (mm *MemoryManager) GetShortTermMemory(sessionID string, query string) ([]string, error) {
	if query != "" && mm.embedClient != nil {
		if vec, err := mm.getEmbeddingWithCache(query); err == nil {
			msgIDs, err := mm.vectorStore.SearchSimilarInSession(sessionID, vec, 3)
			if err == nil && len(msgIDs) > 0 {
				return mm.loadMemoryContents(msgIDs)
			}
		}
	}

	rows, err := mm.store.DB().Query(
		`SELECT content FROM short_term_memories
		 WHERE session_id = ? ORDER BY updated_at DESC LIMIT 3`,
		sessionID,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var memories []string
	for rows.Next() {
		var content string
		if err := rows.Scan(&content); err != nil {
			continue
		}
		memories = append(memories, content)
	}
	return memories, rows.Err()
}

// GetLongTermMemory returns non-archived long-term memories relevant to
// query. With a non-empty query and embeddings enabled, it uses the top 5
// vector matches. Matches that are archived or missing in
// long_term_memories are skipped. If that yields nothing, it falls back to
// the 2 memories with the highest weight, with ties broken by access_count.
// Every returned memory gets access_count+1 and last_accessed=now; failures
// of that bookkeeping are logged, not returned.
//
// NOTE(review): BuildMemoryContextWithStats also calls RecordMemoryUsage for
// each result, which increments access_count a second time per use. Confirm
// the double count is intended.
func (mm *MemoryManager) GetLongTermMemory(query string) ([]longTermRecord, error) {
	var results []longTermRecord

	if query != "" && mm.embedClient != nil {
		if vec, err := mm.getEmbeddingWithCache(query); err == nil {
			if vecResults, err := mm.vectorStore.SearchLongTermSimilar(vec, 5); err == nil {
				for _, vr := range vecResults {
					var r longTermRecord
					r.ID = vr.MemoryID
					err := mm.store.DB().QueryRow(
						"SELECT content, weight FROM long_term_memories WHERE id = ? AND archived = 0",
						vr.MemoryID,
					).Scan(&r.Content, &r.Weight)
					if err == nil {
						results = append(results, r)
					}
				}
			}
		}
	}

	if len(results) == 0 {
		rows, err := mm.store.DB().Query(
			`SELECT id, content, weight FROM long_term_memories
			 WHERE archived = 0 ORDER BY weight DESC, access_count DESC LIMIT 2`,
		)
		if err != nil {
			return nil, err
		}
		defer rows.Close()

		for rows.Next() {
			var r longTermRecord
			if err := rows.Scan(&r.ID, &r.Content, &r.Weight); err != nil {
				continue
			}
			results = append(results, r)
		}
		if err := rows.Err(); err != nil {
			return nil, err
		}
	}

	now := time.Now()
	for _, r := range results {
		if _, err := mm.store.DB().Exec(
			"UPDATE long_term_memories SET access_count = access_count + 1, last_accessed = ? WHERE id = ?",
			now, r.ID,
		); err != nil {
			log.Printf("[memory] Failed to update access stats for memory %d: %v", r.ID, err)
		}
	}
	return results, nil
}

// loadMemoryContents loads the content of the given main_messages IDs. The
// result follows the order of msgIDs (the ranking of the vector search that
// produced them) rather than whatever order SQL returns. IDs with no row are
// skipped, and a repeated ID is emitted once.
func (mm *MemoryManager) loadMemoryContents(msgIDs []int64) ([]string, error) {
	if len(msgIDs) == 0 {
		return nil, nil
	}

	placeholders := make([]string, len(msgIDs))
	args := make([]interface{}, len(msgIDs))
	for i, id := range msgIDs {
		placeholders[i] = "?"
		args[i] = id
	}
	query := fmt.Sprintf(
		"SELECT id, content FROM main_messages WHERE id IN (%s)",
		strings.Join(placeholders, ","),
	)

	rows, err := mm.store.DB().Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	byID := make(map[int64]string, len(msgIDs))
	for rows.Next() {
		var id int64
		var content string
		if err := rows.Scan(&id, &content); err != nil {
			continue
		}
		byID[id] = content
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	var contents []string
	for _, id := range msgIDs {
		if content, ok := byID[id]; ok {
			contents = append(contents, content)
			delete(byID, id) // emit each message once even if msgIDs repeats it
		}
	}
	return contents, nil
}

// AddShortTermMemory upserts content as a short-term memory of sessionID.
// If the (session_id, content) pair already exists, its source_count is
// incremented and updated_at is refreshed instead.
func (mm *MemoryManager) AddShortTermMemory(sessionID string, content string) error {
	now := time.Now()
	_, err := mm.store.DB().Exec(
		`INSERT INTO short_term_memories (session_id, content, updated_at) VALUES (?, ?, ?)
		 ON CONFLICT(session_id, content) DO UPDATE SET source_count = source_count + 1, updated_at = ?`,
		sessionID, content, now, now,
	)
	return err
}

// AddLongTermMemory inserts content as a long-term memory of memoryType with
// confidence 0.8. If the same content already exists, it instead increments
// access_count and raises confidence to at least 0.8; memoryType is not
// updated in that case.
func (mm *MemoryManager) AddLongTermMemory(content string, memoryType string) error {
	const defaultConfidence = 0.8
	_, err := mm.store.DB().Exec(
		`INSERT INTO long_term_memories (content, memory_type, confidence) VALUES (?, ?, ?)
		 ON CONFLICT(content) DO UPDATE SET access_count = access_count + 1, confidence = MAX(confidence, ?)`,
		content, memoryType, defaultConfidence, defaultConfidence,
	)
	return err
}

// Cleanup deletes all short-term memories except the 10 most recently
// updated ones.
//
// NOTE(review): the limit applies across ALL sessions, not per session. A
// busy session can evict every other session's short-term memory. Confirm
// this is intended.
func (mm *MemoryManager) Cleanup() error {
	_, err := mm.store.DB().Exec(
		`DELETE FROM short_term_memories WHERE id NOT IN (
			SELECT id FROM short_term_memories ORDER BY updated_at DESC LIMIT 10
		)`,
	)
	return err
}

// estimateTokens approximates the token count of text as half its rune count.
func estimateTokens(text string) int {
	return utf8.RuneCountInString(text) / 2
}

// reverseMessages reverses msgs in place.
func reverseMessages(msgs []SessionMessage) {
	for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
		msgs[i], msgs[j] = msgs[j], msgs[i]
	}
}

// toInterfaceSlice converts values to []interface{}, for example for use as
// SQL query arguments.
func toInterfaceSlice(values []string) []interface{} {
	result := make([]interface{}, len(values))
	for i, s := range values {
		result[i] = s
	}
	return result
}

// MemoryContextStats describes what BuildMemoryContextWithStats injected.
// TotalTokens is the estimateTokens sum over the injected snippets only; the
// section headers are not counted.
type MemoryContextStats struct {
	ShortTermCount int
	LongTermCount  int
	TotalTokens    int
}

// BuildMemoryContextWithStats assembles a markdown memory section for a
// prompt from short-term and long-term memory, plus stats about it. Every
// injected long-term memory is recorded as referenced via RecordMemoryUsage.
// It returns "" when neither tier has content. Retrieval errors just omit
// the corresponding section.
func (mm *MemoryManager) BuildMemoryContextWithStats(sessionID string, query string) (string, MemoryContextStats) {
	var parts []string
	stats := MemoryContextStats{}

	if shortTerm, err := mm.GetShortTermMemory(sessionID, query); err == nil && len(shortTerm) > 0 {
		parts = append(parts, "## 相关上下文\n"+strings.Join(shortTerm, "\n"))
		stats.ShortTermCount = len(shortTerm)
		for _, m := range shortTerm {
			stats.TotalTokens += estimateTokens(m)
		}
	}

	if longTerm, err := mm.GetLongTermMemory(query); err == nil && len(longTerm) > 0 {
		contents := make([]string, 0, len(longTerm))
		for _, m := range longTerm {
			contents = append(contents, m.Content)
			stats.TotalTokens += estimateTokens(m.Content)
			if err := mm.RecordMemoryUsage(m.ID, sessionID, query, true); err != nil {
				log.Printf("[memory] Failed to record usage of memory %d: %v", m.ID, err)
			}
		}
		parts = append(parts, "## 背景知识\n"+strings.Join(contents, "\n"))
		stats.LongTermCount = len(longTerm)
	}

	if len(parts) == 0 {
		return "", stats
	}
	return "## 记忆信息\n" + strings.Join(parts, "\n\n"), stats
}

// BuildMemoryContext is BuildMemoryContextWithStats without the stats.
func (mm *MemoryManager) BuildMemoryContext(sessionID string, query string) string {
	ctx, _ := mm.BuildMemoryContextWithStats(sessionID, query)
	return ctx
}

// ShouldInjectMemory reports whether memory context is worth adding for a
// query. It requires a non-empty session ID, at least one stored message in
// the session (a count-query error counts as none), and a query of at least
// 10 bytes.
//
// NOTE(review): len counts bytes, so a four-character CJK query (12 bytes)
// already passes. Use a rune count if characters were intended.
func (mm *MemoryManager) ShouldInjectMemory(sessionID string, query string) bool {
	if sessionID == "" {
		return false
	}
	if msgCount, err := mm.getSessionMessageCount(sessionID); err != nil || msgCount == 0 {
		return false
	}
	return len(query) >= 10
}

// getSessionMessageCount returns how many main_messages rows sessionID has.
func (mm *MemoryManager) getSessionMessageCount(sessionID string) (int, error) {
	var count int
	err := mm.store.DB().QueryRow(
		"SELECT COUNT(*) FROM main_messages WHERE session_id = ?",
		sessionID,
	).Scan(&count)
	return count, err
}

// MaintainSessionMemory updates memory after one exchange. Responses shorter
// than 20 bytes are ignored. Otherwise it:
//   - stores a summary (the query plus at most ~100 bytes of the response) as short-term memory;
//   - buffers the full exchange;
//   - triggers fact extraction once enough exchanges are buffered.
func (mm *MemoryManager) MaintainSessionMemory(sessionID string, userQuery string, assistantResponse string) {
	if len(assistantResponse) < 20 {
		return
	}

	summary := fmt.Sprintf("用户问:%s\n回答:%s", userQuery, truncateString(assistantResponse, 100))
	if err := mm.AddShortTermMemory(sessionID, summary); err != nil {
		log.Printf("[memory] Failed to add short-term memory: %v", err)
	}

	mm.bufferDialogue(sessionID, userQuery, assistantResponse)

	if mm.shouldExtract(sessionID) {
		mm.triggerExtraction(sessionID)
	}
}

// shouldExtract reports whether the session has at least 5 buffered
// exchanges. A query error is treated as "no".
func (mm *MemoryManager) shouldExtract(sessionID string) bool {
	var count int
	err := mm.store.DB().QueryRow(
		"SELECT COUNT(*) FROM dialogue_buffer WHERE session_id = ?",
		sessionID,
	).Scan(&count)
	if err != nil {
		return false
	}
	return count >= 5
}

// SetLLM sets the backend used for long-term fact extraction.
//
// NOTE(review): the field is read without synchronization by
// triggerExtraction and its goroutine. Call SetLLM before the manager is in
// concurrent use.
func (mm *MemoryManager) SetLLM(backend llm.LLM) {
	mm.llmBackend = backend
}

// triggerExtraction drains the session's dialogue buffer. If an LLM backend
// is configured, it asynchronously extracts facts and stores those with
// confidence >= 0.6 as long-term memories. Without a backend, the drained
// dialogues are discarded.
//
// NOTE(review): the extraction goroutine is not tracked by Close and runs
// with context.Background(), so it can outlive the manager.
func (mm *MemoryManager) triggerExtraction(sessionID string) {
	dialogues, err := mm.FlushDialogueBuffer(sessionID)
	if err != nil || len(dialogues) == 0 {
		return
	}

	if mm.llmBackend == nil {
		log.Printf("[memory] No LLM backend configured, skipping extraction for session=%s", sessionID)
		return
	}

	go func() {
		facts, err := mm.extractFacts(dialogues)
		if err != nil {
			log.Printf("[memory] Extraction failed: %v", err)
			return
		}
		for _, fact := range facts {
			if fact.Confidence < 0.6 {
				continue
			}
			if err := mm.AddLongTermMemory(fact.Content, fact.Type); err != nil {
				log.Printf("[memory] Failed to save long-term memory: %v", err)
			}
		}
	}()
}

// bufferDialogue appends one exchange to the session's dialogue buffer.
// Failures are logged, not returned.
func (mm *MemoryManager) bufferDialogue(sessionID string, userQuery string, assistantResponse string) {
	_, err := mm.store.DB().Exec(
		"INSERT INTO dialogue_buffer (session_id, user_query, assistant_response) VALUES (?, ?, ?)",
		sessionID, userQuery, assistantResponse,
	)
	if err != nil {
		log.Printf("[memory] Failed to buffer dialogue: %v", err)
	}
}

// FlushDialogueBuffer returns the session's buffered exchanges, ordered by
// created_at, and then deletes them from the buffer. Rows that fail to scan
// are skipped. If the read fails part-way, nothing is deleted and the error
// is returned. A failed DELETE is only logged, so a later call may return the
// same exchanges again.
//
// NOTE(review): the DELETE removes every buffered row of the session,
// including rows inserted after the SELECT ran. Those exchanges are lost.
func (mm *MemoryManager) FlushDialogueBuffer(sessionID string) ([]dialogueTurn, error) {
	rows, err := mm.store.DB().Query(
		"SELECT user_query, assistant_response FROM dialogue_buffer WHERE session_id = ? ORDER BY created_at ASC",
		sessionID,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var dialogues []dialogueTurn
	for rows.Next() {
		var d dialogueTurn
		if err := rows.Scan(&d.UserQuery, &d.AssistantResponse); err != nil {
			continue
		}
		dialogues = append(dialogues, d)
	}
	// Bail out before deleting if the read was incomplete; otherwise rows we
	// never returned would be dropped from the buffer.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	if len(dialogues) > 0 {
		if _, err := mm.store.DB().Exec("DELETE FROM dialogue_buffer WHERE session_id = ?", sessionID); err != nil {
			log.Printf("[memory] Failed to clear dialogue buffer: %v", err)
		}
	}
	return dialogues, nil
}

// truncateString shortens s to at most maxLen bytes and appends "..." when
// it had to cut. The cut is moved back to a rune boundary, so multi-byte
// UTF-8 text (e.g. Chinese) is never split mid-character and the result is
// always valid UTF-8 if s is. A negative maxLen is treated as 0.
func truncateString(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}
	cut := maxLen
	for cut > 0 && !utf8.RuneStart(s[cut]) {
		cut--
	}
	return s[:cut] + "..."
}

// RecordMemoryUsage logs that memoryID was offered for query in sessionID.
// It then adjusts the memory's weight: +0.5 if it was referenced, -0.3
// otherwise. It also increments access_count and sets last_accessed. The two
// statements are not run in a transaction.
func (mm *MemoryManager) RecordMemoryUsage(memoryID int64, sessionID, query string, referenced bool) error {
	_, err := mm.store.DB().Exec(
		"INSERT INTO memory_usage_log (memory_id, session_id, query, was_referenced) VALUES (?, ?, ?, ?)",
		memoryID, sessionID, query, referenced,
	)
	if err != nil {
		return err
	}

	delta := 0.5
	if !referenced {
		delta = -0.3
	}
	_, err = mm.store.DB().Exec(
		"UPDATE long_term_memories SET weight = weight + ?, access_count = access_count + 1, last_accessed = ? WHERE id = ?",
		delta, time.Now(), memoryID,
	)
	return err
}

// ArchiveLowWeightMemories archives every non-archived long-term memory whose
// weight is below threshold and returns how many were archived.
func (mm *MemoryManager) ArchiveLowWeightMemories(threshold float64) (int, error) {
	result, err := mm.store.DB().Exec(
		"UPDATE long_term_memories SET archived = 1 WHERE weight < ? AND archived = 0",
		threshold,
	)
	if err != nil {
		return 0, err
	}
	count, err := result.RowsAffected()
	if err != nil {
		return 0, err
	}
	return int(count), nil
}

// GetCoreMemories returns all non-archived long-term memories with weight
// >= minWeight, heaviest first. Rows that fail to scan are skipped.
func (mm *MemoryManager) GetCoreMemories(minWeight float64) ([]longTermRecord, error) {
	rows, err := mm.store.DB().Query(
		"SELECT id, content, weight FROM long_term_memories WHERE weight >= ? AND archived = 0 ORDER BY weight DESC",
		minWeight,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var memories []longTermRecord
	for rows.Next() {
		var m longTermRecord
		if err := rows.Scan(&m.ID, &m.Content, &m.Weight); err != nil {
			continue
		}
		memories = append(memories, m)
	}
	return memories, rows.Err()
}

// CacheStats reports the query-embedding cache size and its hit/miss counts.
func (mm *MemoryManager) CacheStats() (size int, hits, misses int64) {
	mm.embedCacheMu.RLock()
	defer mm.embedCacheMu.RUnlock()
	return len(mm.embedCache), mm.embedCacheHit, mm.embedCacheMiss
}

// memoryExtractionPrompt instructs the LLM to extract facts, preferences and
// project details from a dialogue transcript and to answer with JSON shaped
// like extractionResult.
const memoryExtractionPrompt = `# Memory Extractor 你是一个专门从对话中提取用户信息的助手。将非结构化的对话转化为结构化的长期记忆。 ## 任务 分析给定的对话记录,提取以下类型的信息: 1. **事实 (fact)**:客观信息 - 工作:公司、职位、技术栈、行业 - 技术:擅长语言、框架偏好、架构经验 - 个人:教育背景、所在城市(仅用户明确提及) 2. **偏好 (preference)**:主观倾向 - 回答风格:简洁/详细/代码示例/架构图 - 技术偏好:语言、数据库、部署方式 - 沟通偏好:正式/ casual 3. **项目 (project)**:当前工作 - 项目名称、技术方案、当前阶段、遇到的挑战 ## 输出格式 只输出 JSON,不要任何解释: ` + "```json" + ` { "facts": [ { "content": "用户在电商公司担任后端工程师", "type": "fact", "confidence": 0.95 }, { "content": "用户偏好简洁的技术回答", "type": "preference", "confidence": 0.85 } ] } ` + "```" + ` ## 规则 - confidence < 0.6 的事实不输出 - 不猜测、不推断,只提取用户明确表达的信息`

// extractedFact is one item of the LLM's extraction output.
type extractedFact struct {
	Content    string  `json:"content"`
	Type       string  `json:"type"`
	Confidence float64 `json:"confidence"`
}

// extractionResult is the top-level JSON object the extraction prompt asks for.
type extractionResult struct {
	Facts []extractedFact `json:"facts"`
}

// extractFacts sends the extraction prompt plus a numbered transcript of
// dialogues to the configured LLM backend and parses the facts from its
// reply. The caller must ensure mm.llmBackend is non-nil.
func (mm *MemoryManager) extractFacts(dialogues []dialogueTurn) ([]extractedFact, error) {
	var sb strings.Builder
	sb.WriteString(memoryExtractionPrompt)
	sb.WriteString("\n\n## 对话记录\n\n")
	for i, d := range dialogues {
		fmt.Fprintf(&sb, "--- 对话 %d ---\n", i+1)
		fmt.Fprintf(&sb, "用户:%s\n", d.UserQuery)
		fmt.Fprintf(&sb, "助手:%s\n\n", d.AssistantResponse)
	}

	messages := []llm.Message{
		{Role: "system", Content: "你是一个专门从对话中提取用户信息的助手。"},
		{Role: "user", Content: sb.String()},
	}

	resp, err := mm.llmBackend.Chat(context.Background(), messages)
	if err != nil {
		return nil, fmt.Errorf("llm chat failed: %w", err)
	}
	return parseExtractionJSON(resp.Content)
}

// parseExtractionJSON pulls the JSON payload out of an LLM reply (see
// extractJSONBlock) and decodes its "facts" array.
func parseExtractionJSON(content string) ([]extractedFact, error) {
	content = extractJSONBlock(content)

	var result extractionResult
	if err := json.Unmarshal([]byte(content), &result); err != nil {
		return nil, fmt.Errorf("parse extraction json failed: %w", err)
	}
	return result.Facts, nil
}

// extractJSONBlock returns the JSON payload of an LLM reply. It tries these
// in order:
//   - the body of the first complete ```json fence;
//   - the body of the first complete ``` fence;
//   - the outermost {...} span, which covers replies with surrounding prose
//     or an unterminated fence;
//   - the whole reply, trimmed.
func extractJSONBlock(content string) string {
	for _, fence := range []string{"```json", "```"} {
		if idx := strings.Index(content, fence); idx != -1 {
			start := idx + len(fence)
			if end := strings.Index(content[start:], "```"); end != -1 {
				return strings.TrimSpace(content[start : start+end])
			}
		}
	}
	if first, last := strings.Index(content, "{"), strings.LastIndex(content, "}"); first != -1 && last > first {
		return content[first : last+1]
	}
	return strings.TrimSpace(content)
}