orca.ai/pkg/session/memory_manager.go
2026-05-12 00:09:01 +08:00

871 lines
20 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package session
import (
"context"
"encoding/json"
"fmt"
"log"
"strings"
"sync"
"time"
"github.com/orca/orca/pkg/embedding"
"github.com/orca/orca/pkg/llm"
)
type (
	// MemoryManager layers working, short-term and long-term memory on top of a
	// SQLiteStore, with optional background embedding of saved messages and an
	// optional LLM used for fact extraction.
	MemoryManager struct {
		store       *SQLiteStore
		vectorStore *VectorStore
		embedClient *embedding.Client // nil when no embedding API key was configured
		llmBackend  llm.LLM           // set via SetLLM; nil disables fact extraction
		tokenBudget TokenBudget
		ownsStore   bool // when true, Close also closes store

		// Background embedding pipeline: SaveMessage enqueues tasks, the worker
		// started by startEmbedWorker consumes them until embedCtx is canceled.
		embedQueue  chan embedTask
		embedWg     sync.WaitGroup
		embedCtx    context.Context
		embedCancel context.CancelFunc

		// Query-embedding cache. embedCache and its hit/miss counters are
		// accessed under embedCacheMu; embedCacheMax is the size above which
		// the least-recently-used entry is evicted.
		embedCache     map[string]*embedCacheEntry
		embedCacheMu   sync.RWMutex
		embedCacheMax  int
		embedCacheHit  int64
		embedCacheMiss int64
	}

	// embedCacheEntry is one cached query embedding plus the time it was last
	// returned, used for LRU eviction.
	embedCacheEntry struct {
		embedding []float32
		lastUsed  time.Time
	}

	// embedTask asks the embedding worker to embed content and store the vector
	// against message msgID.
	embedTask struct {
		msgID   int64
		content string
	}

	// TokenBudget holds token allowances derived by calculateBudget from the
	// model context window. Working caps how much history GetWorkingMemory returns.
	TokenBudget struct {
		Total     int
		Working   int
		ShortTerm int
		LongTerm  int
	}

	// MemoryConfig configures a MemoryManager.
	MemoryConfig struct {
		DBPath      string           // SQLite database path (used by NewMemoryManager)
		EmbedConfig embedding.Config // embeddings are enabled only when APIKey is non-empty
		ModelWindow int              // model context window in tokens; <= 0 means 8192
	}
)
// NewMemoryManager opens a SQLiteStore at cfg.DBPath and builds a MemoryManager
// that owns it: Close will close the store. If anything after opening the store
// fails, the store is closed before the error is returned.
func NewMemoryManager(cfg MemoryConfig) (*MemoryManager, error) {
	store, err := NewSQLiteStore(cfg.DBPath)
	if err != nil {
		return nil, fmt.Errorf("failed to create store: %w", err)
	}
	mm, err := NewMemoryManagerWithStore(cfg, store)
	if err != nil {
		store.Close()
		return nil, err
	}
	// This constructor opened the store, so ownership passes to the manager.
	mm.ownsStore = true
	return mm, nil
}

// NewMemoryManagerWithStore builds a MemoryManager on top of a caller-owned
// store; Close will NOT close it. When cfg.EmbedConfig.APIKey is non-empty an
// embedding client is created and the background embedding worker is started.
//
// It returns an error if store is nil or the vector store cannot be created.
func NewMemoryManagerWithStore(cfg MemoryConfig, store *SQLiteStore) (*MemoryManager, error) {
	const (
		embedQueueSize = 100 // pending embed tasks before SaveMessage starts dropping
		embedCacheSize = 500 // cached query embeddings before LRU eviction
	)
	if store == nil {
		return nil, fmt.Errorf("memory manager requires a non-nil store")
	}
	vectorStore, err := NewVectorStore(store.DB())
	if err != nil {
		return nil, fmt.Errorf("failed to create vector store: %w", err)
	}
	ctx, cancel := context.WithCancel(context.Background())
	mm := &MemoryManager{
		store:         store,
		vectorStore:   vectorStore,
		tokenBudget:   calculateBudget(cfg.ModelWindow),
		ownsStore:     false,
		embedQueue:    make(chan embedTask, embedQueueSize),
		embedCtx:      ctx,
		embedCancel:   cancel,
		embedCache:    make(map[string]*embedCacheEntry),
		embedCacheMax: embedCacheSize,
	}
	if cfg.EmbedConfig.APIKey != "" {
		mm.embedClient = embedding.NewClient(cfg.EmbedConfig)
		mm.startEmbedWorker()
	}
	return mm, nil
}
// calculateBudget reserves 60% of the model window (8192 tokens when
// modelWindow <= 0) for memory, then splits that total 50/30/20 across the
// working, short-term and long-term tiers. Every share is truncated toward zero.
func calculateBudget(modelWindow int) TokenBudget {
	const (
		fallbackWindow = 8192
		memoryShare    = 0.6
	)
	window := modelWindow
	if window <= 0 {
		window = fallbackWindow
	}
	total := int(float64(window) * memoryShare)
	portion := func(fraction float64) int {
		return int(float64(total) * fraction)
	}
	return TokenBudget{
		Total:     total,
		Working:   portion(0.5),
		ShortTerm: portion(0.3),
		LongTerm:  portion(0.2),
	}
}
// Close stops the background embedding worker, waits for it to exit, and then
// closes the underlying store only if this manager opened it (ownsStore).
func (mm *MemoryManager) Close() error {
	mm.embedCancel()
	mm.embedWg.Wait()
	if !mm.ownsStore {
		return nil
	}
	return mm.store.Close()
}
// SaveMessage persists msg for sessionID. When embeddings are enabled and the
// message is a user/assistant turn longer than 10 bytes, it also queues the
// message for background embedding.
//
// Embedding is best-effort: the error from the store save is returned, but a
// failed ID lookup or a full queue is only logged so the caller's save still
// succeeds.
func (mm *MemoryManager) SaveMessage(sessionID string, msg SessionMessage) error {
	if err := mm.store.Save(sessionID, msg); err != nil {
		return err
	}
	if mm.embedClient == nil || len(msg.Content) <= 10 {
		return nil
	}
	if msg.Role != RoleUser && msg.Role != RoleAssistant {
		return nil
	}
	// NOTE(review): the ID is re-read as "newest row for this session", which can
	// pick up a different message if another save for the same session races in
	// between. Having store.Save return the inserted ID would close that gap.
	msgID, err := mm.getLastMessageID(sessionID)
	if err != nil {
		log.Printf("[memory] Skip embedding for session=%s: cannot resolve message id: %v", sessionID, err)
		return nil
	}
	select {
	case mm.embedQueue <- embedTask{msgID: msgID, content: msg.Content}:
	default:
		// Never block the caller on the background pipeline; record the drop instead.
		log.Printf("[memory] Embed queue full, dropping embedding for message %d", msgID)
	}
	return nil
}
// getLastMessageID returns the highest main_messages.id recorded for sessionID,
// or the Scan error (e.g. sql.ErrNoRows when the session has no messages).
func (mm *MemoryManager) getLastMessageID(sessionID string) (int64, error) {
	const newestIDQuery = "SELECT id FROM main_messages WHERE session_id = ? ORDER BY id DESC LIMIT 1"
	var lastID int64
	row := mm.store.DB().QueryRow(newestIDQuery, sessionID)
	err := row.Scan(&lastID)
	return lastID, err
}
// startEmbedWorker launches the goroutine that consumes embedQueue. Tasks are
// embedded in batches: a batch is flushed when it reaches batchSize or when
// flushInterval elapses. The goroutine is tracked by embedWg and exits when
// embedCtx is canceled.
//
// On cancellation it first drains any tasks already waiting in the queue, so
// Close does not silently discard queued messages. Note that this can make
// Close take as long as the remaining embedding requests.
func (mm *MemoryManager) startEmbedWorker() {
	const (
		batchSize     = 5
		flushInterval = 5 * time.Second
	)
	mm.embedWg.Add(1)
	go func() {
		defer mm.embedWg.Done()
		batch := make([]embedTask, 0, batchSize)
		flush := func() {
			if len(batch) > 0 {
				mm.processBatch(batch)
				batch = batch[:0]
			}
		}
		timer := time.NewTimer(flushInterval)
		defer timer.Stop()
		for {
			select {
			case task := <-mm.embedQueue:
				batch = append(batch, task)
				if len(batch) >= batchSize {
					flush()
					timer.Reset(flushInterval)
				}
			case <-timer.C:
				flush()
				timer.Reset(flushInterval)
			case <-mm.embedCtx.Done():
				// Drain what is already queued, then stop as soon as the queue
				// is momentarily empty.
				for {
					select {
					case task := <-mm.embedQueue:
						batch = append(batch, task)
						if len(batch) >= batchSize {
							flush()
						}
					default:
						flush()
						return
					}
				}
			}
		}
	}()
}
// processBatch embeds the content of every task in a single Embed call and
// stores each returned vector against its task's message ID. Failures are
// logged rather than returned, because this runs on the background worker.
func (mm *MemoryManager) processBatch(tasks []embedTask) {
	if len(tasks) == 0 {
		return
	}
	texts := make([]string, len(tasks))
	for i, t := range tasks {
		texts[i] = t.content
	}
	embeddings, err := mm.embedClient.Embed(texts)
	if err != nil {
		log.Printf("[memory] Embedding batch failed: %v", err)
		return
	}
	// Vectors are matched to messages by position. If the counts disagree the
	// mapping cannot be trusted, and indexing tasks[i] could panic and kill
	// the worker goroutine, so skip the whole batch.
	if len(embeddings) != len(tasks) {
		log.Printf("[memory] Embedding batch returned %d vectors for %d texts, skipping batch", len(embeddings), len(tasks))
		return
	}
	for i, emb := range embeddings {
		if err := mm.vectorStore.SaveEmbedding(tasks[i].msgID, emb); err != nil {
			log.Printf("[memory] Save embedding failed: %v", err)
		}
	}
}
// GetWorkingMemory returns the most recent messages of sessionID in
// chronological order. It keeps adding messages (newest first) until the
// estimated token total would exceed tokenBudget.Working. At least one message
// is always returned if the session has any.
//
// Rows that fail to scan are skipped. A database error during iteration is
// returned instead of a silently truncated history.
func (mm *MemoryManager) GetWorkingMemory(sessionID string) ([]SessionMessage, error) {
	rows, err := mm.store.DB().Query(
		`SELECT role, content, timestamp, metadata FROM main_messages
WHERE session_id = ?
ORDER BY timestamp DESC`,
		sessionID,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var messages []SessionMessage
	totalTokens := 0
	for rows.Next() {
		var msg SessionMessage
		var timestampStr string
		var metadataStr string
		if err := rows.Scan(&msg.Role, &msg.Content, &timestampStr, &metadataStr); err != nil {
			continue
		}
		// Best-effort: a timestamp that is not RFC 3339 leaves the zero time.
		msg.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
		tokens := estimateTokens(msg.Content)
		if totalTokens+tokens > mm.tokenBudget.Working && len(messages) > 0 {
			break
		}
		totalTokens += tokens
		messages = append(messages, msg)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	// Rows were read newest-first; callers expect oldest-first.
	reverseMessages(messages)
	return messages, nil
}
// getEmbeddingWithCache returns the embedding for query, serving it from the
// in-memory LRU cache when possible and calling embedClient.EmbedSingle on a
// miss. New entries are cached; when the cache grows past embedCacheMax, the
// least-recently-used entry is evicted. It errors if embeddings are disabled.
//
// The returned slice is shared with the cache and must not be modified.
func (mm *MemoryManager) getEmbeddingWithCache(query string) ([]float32, error) {
	if mm.embedClient == nil {
		return nil, fmt.Errorf("no embed client")
	}
	// A cache hit mutates entry.lastUsed and the hit counter, so it needs the
	// write lock; doing this under RLock was a data race.
	mm.embedCacheMu.Lock()
	if entry, ok := mm.embedCache[query]; ok {
		entry.lastUsed = time.Now()
		mm.embedCacheHit++
		mm.embedCacheMu.Unlock()
		return entry.embedding, nil
	}
	mm.embedCacheMu.Unlock()

	// Compute outside the lock: this is a remote call.
	vec, err := mm.embedClient.EmbedSingle(query)
	if err != nil {
		return nil, err
	}

	mm.embedCacheMu.Lock()
	defer mm.embedCacheMu.Unlock()
	mm.embedCache[query] = &embedCacheEntry{
		embedding: vec,
		lastUsed:  time.Now(),
	}
	mm.embedCacheMiss++
	if len(mm.embedCache) > mm.embedCacheMax {
		mm.evictLRU()
	}
	return vec, nil
}
// evictLRU removes the embedCache entry with the oldest lastUsed time.
// The caller must hold embedCacheMu for writing. This is a linear scan, which
// is acceptable for the small, bounded cache.
func (mm *MemoryManager) evictLRU() {
	var (
		oldestKey  string
		oldestTime time.Time
		found      bool // a flag, not an empty-key sentinel: "" is a valid map key
	)
	for k, v := range mm.embedCache {
		if !found || v.lastUsed.Before(oldestTime) {
			oldestKey, oldestTime, found = k, v.lastUsed, true
		}
	}
	if found {
		delete(mm.embedCache, oldestKey)
	}
}
// GetShortTermMemory returns up to three memory snippets for sessionID.
//
// When query is non-empty and embeddings are enabled, it first tries a vector
// search over the session's messages and returns their contents. If that path
// fails at any step or yields nothing, it falls back to the three most recently
// updated rows of short_term_memories for the session.
func (mm *MemoryManager) GetShortTermMemory(sessionID string, query string) ([]string, error) {
	if query != "" && mm.embedClient != nil {
		if vec, err := mm.getEmbeddingWithCache(query); err == nil {
			msgIDs, err := mm.vectorStore.SearchSimilarInSession(sessionID, vec, 3)
			if err == nil && len(msgIDs) > 0 {
				contents, err := mm.loadMemoryContents(msgIDs)
				if err == nil && len(contents) > 0 {
					return contents, nil
				}
				// Previously a load error was returned with no fallback; prefer
				// recency-based memories over returning nothing.
				if err != nil {
					log.Printf("[memory] Loading similar messages failed, falling back to recent memories: %v", err)
				}
			}
		}
	}
	rows, err := mm.store.DB().Query(
		`SELECT content FROM short_term_memories
WHERE session_id = ?
ORDER BY updated_at DESC
LIMIT 3`,
		sessionID,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var memories []string
	for rows.Next() {
		var content string
		if err := rows.Scan(&content); err != nil {
			continue
		}
		memories = append(memories, content)
	}
	return memories, rows.Err()
}
// GetLongTermMemory returns non-archived long-term memories relevant to query.
//
// With a non-empty query and embeddings enabled, it looks up the 5 most similar
// memories via the vector store. If that yields nothing, it falls back to the 2
// highest-weight (then most-accessed) memories. Every returned memory has its
// access_count incremented and last_accessed set to now. Failures of that
// bookkeeping are logged, not returned.
//
// NOTE(review): BuildMemoryContextWithStats also calls RecordMemoryUsage, which
// increments access_count again, so memories injected into context are counted
// twice. Confirm whether that is intended.
func (mm *MemoryManager) GetLongTermMemory(query string) ([]struct {
	ID      int64
	Content string
	Weight  float64
}, error) {
	// longTermHit aliases (not redefines) the anonymous element type of the
	// return slice, so the exported signature is unchanged.
	type longTermHit = struct {
		ID      int64
		Content string
		Weight  float64
	}
	var results []longTermHit
	if query != "" && mm.embedClient != nil {
		if vec, err := mm.getEmbeddingWithCache(query); err == nil {
			vecResults, err := mm.vectorStore.SearchLongTermSimilar(vec, 5)
			if err == nil {
				for _, vr := range vecResults {
					var hit longTermHit
					err := mm.store.DB().QueryRow(
						"SELECT content, weight FROM long_term_memories WHERE id = ? AND archived = 0",
						vr.MemoryID,
					).Scan(&hit.Content, &hit.Weight)
					if err != nil {
						// Archived or deleted since it was indexed.
						continue
					}
					hit.ID = vr.MemoryID
					results = append(results, hit)
				}
			}
		}
	}
	if len(results) == 0 {
		rows, err := mm.store.DB().Query(
			`SELECT id, content, weight FROM long_term_memories
WHERE archived = 0
ORDER BY weight DESC, access_count DESC
LIMIT 2`,
		)
		if err != nil {
			return nil, err
		}
		defer rows.Close()
		for rows.Next() {
			var hit longTermHit
			if err := rows.Scan(&hit.ID, &hit.Content, &hit.Weight); err != nil {
				continue
			}
			results = append(results, hit)
		}
		if err := rows.Err(); err != nil {
			return nil, err
		}
	}
	now := time.Now()
	for _, hit := range results {
		if _, err := mm.store.DB().Exec(
			"UPDATE long_term_memories SET access_count = access_count + 1, last_accessed = ? WHERE id = ?",
			now, hit.ID,
		); err != nil {
			log.Printf("[memory] Failed to update access stats for memory %d: %v", hit.ID, err)
		}
	}
	return results, nil
}
// loadMemoryContents returns the content of the main_messages rows identified
// by msgIDs, in the same order as msgIDs. Unknown IDs are skipped and repeated
// IDs are emitted once.
//
// Ordering matters because msgIDs typically come from a similarity search
// (presumably ranked best-first; verify against VectorStore). A bare
// `WHERE id IN (...)` returns rows in arbitrary order and loses that ranking.
// The SQL is built only from "?" placeholders; all values are bound parameters.
func (mm *MemoryManager) loadMemoryContents(msgIDs []int64) ([]string, error) {
	if len(msgIDs) == 0 {
		return nil, nil
	}
	placeholders := make([]string, len(msgIDs))
	args := make([]interface{}, len(msgIDs))
	for i, id := range msgIDs {
		placeholders[i] = "?"
		args[i] = id
	}
	query := fmt.Sprintf(
		"SELECT id, content FROM main_messages WHERE id IN (%s)",
		strings.Join(placeholders, ","),
	)
	rows, err := mm.store.DB().Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	byID := make(map[int64]string, len(msgIDs))
	for rows.Next() {
		var id int64
		var content string
		if err := rows.Scan(&id, &content); err != nil {
			continue
		}
		byID[id] = content
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	var contents []string
	for _, id := range msgIDs {
		if content, ok := byID[id]; ok {
			contents = append(contents, content)
			delete(byID, id) // emit each message once even if its ID repeats
		}
	}
	return contents, nil
}
// AddShortTermMemory upserts content into short_term_memories for sessionID.
// A new row gets the current time as updated_at. An existing
// (session_id, content) row has its source_count bumped and updated_at refreshed.
func (mm *MemoryManager) AddShortTermMemory(sessionID string, content string) error {
	const upsertShortTerm = `INSERT INTO short_term_memories (session_id, content, updated_at)
VALUES (?, ?, ?)
ON CONFLICT(session_id, content) DO UPDATE SET
source_count = source_count + 1,
updated_at = ?`
	insertedAt := time.Now()
	refreshedAt := time.Now()
	_, err := mm.store.DB().Exec(upsertShortTerm, sessionID, content, insertedAt, refreshedAt)
	return err
}
// AddLongTermMemory upserts content into long_term_memories with a fixed
// confidence of 0.8. If the content already exists, its access_count is bumped
// and its confidence is raised to at least 0.8 (memoryType is not updated).
func (mm *MemoryManager) AddLongTermMemory(content string, memoryType string) error {
	const fixedConfidence = 0.8
	const upsertLongTerm = `INSERT INTO long_term_memories (content, memory_type, confidence)
VALUES (?, ?, ?)
ON CONFLICT(content) DO UPDATE SET
access_count = access_count + 1,
confidence = MAX(confidence, ?)`
	_, err := mm.store.DB().Exec(upsertLongTerm, content, memoryType, fixedConfidence, fixedConfidence)
	return err
}
// Cleanup deletes every short_term_memories row except the 10 most recently
// updated ones.
//
// NOTE(review): the retention window is global across ALL sessions, not per
// session. Confirm this is intended rather than "keep 10 per session".
func (mm *MemoryManager) Cleanup() error {
	const pruneShortTerm = `DELETE FROM short_term_memories
WHERE id NOT IN (
SELECT id FROM short_term_memories
ORDER BY updated_at DESC
LIMIT 10
)`
	if _, err := mm.store.DB().Exec(pruneShortTerm); err != nil {
		return err
	}
	return nil
}
// estimateTokens approximates the token count of text as half its rune count,
// using integer division (so text shorter than two runes counts as zero).
// Each invalid UTF-8 byte counts as one rune.
func estimateTokens(text string) int {
	runes := 0
	for range text {
		runes++
	}
	return runes / 2
}
// reverseMessages reverses msgs in place.
func reverseMessages(msgs []SessionMessage) {
	n := len(msgs)
	for left := 0; left < n/2; left++ {
		right := n - 1 - left
		msgs[left], msgs[right] = msgs[right], msgs[left]
	}
}
// toInterfaceSlice converts values into a freshly allocated []interface{}
// (never nil), e.g. for use as variadic SQL arguments. The parameter is named
// values so it does not shadow the strings package.
func toInterfaceSlice(values []string) []interface{} {
	out := make([]interface{}, 0, len(values))
	for _, v := range values {
		out = append(out, v)
	}
	return out
}
// MemoryContextStats reports what BuildMemoryContextWithStats injected: how many
// short-term and long-term snippets were included, and their combined
// estimateTokens total.
type MemoryContextStats struct {
	ShortTermCount, LongTermCount, TotalTokens int
}
// BuildMemoryContextWithStats assembles a markdown memory section for a prompt.
// It includes a "相关上下文" part built from GetShortTermMemory and a "背景知识"
// part built from GetLongTermMemory, under a "记忆信息" heading, and returns
// counts plus an estimated token total. It returns "" when neither source
// yields anything.
//
// Each injected long-term memory is recorded via RecordMemoryUsage with
// referenced=true. Retrieval and usage-recording failures are logged, not
// returned: memory is an optional enhancement to the prompt.
func (mm *MemoryManager) BuildMemoryContextWithStats(sessionID string, query string) (string, MemoryContextStats) {
	var parts []string
	stats := MemoryContextStats{}

	shortTerm, err := mm.GetShortTermMemory(sessionID, query)
	if err != nil {
		log.Printf("[memory] Short-term memory lookup failed for session=%s: %v", sessionID, err)
	} else if len(shortTerm) > 0 {
		parts = append(parts, "## 相关上下文\n"+strings.Join(shortTerm, "\n"))
		stats.ShortTermCount = len(shortTerm)
		for _, m := range shortTerm {
			stats.TotalTokens += estimateTokens(m)
		}
	}

	longTerm, err := mm.GetLongTermMemory(query)
	if err != nil {
		log.Printf("[memory] Long-term memory lookup failed: %v", err)
	} else if len(longTerm) > 0 {
		contents := make([]string, 0, len(longTerm))
		for _, m := range longTerm {
			contents = append(contents, m.Content)
			stats.TotalTokens += estimateTokens(m.Content)
			if err := mm.RecordMemoryUsage(m.ID, sessionID, query, true); err != nil {
				log.Printf("[memory] Failed to record usage of memory %d: %v", m.ID, err)
			}
		}
		parts = append(parts, "## 背景知识\n"+strings.Join(contents, "\n"))
		stats.LongTermCount = len(longTerm)
	}

	if len(parts) == 0 {
		return "", stats
	}
	return "## 记忆信息\n" + strings.Join(parts, "\n\n"), stats
}
// BuildMemoryContext is BuildMemoryContextWithStats without the statistics.
func (mm *MemoryManager) BuildMemoryContext(sessionID string, query string) string {
	text, _ := mm.BuildMemoryContextWithStats(sessionID, query)
	return text
}
// ShouldInjectMemory reports whether memory context should be added for query.
// It requires a non-empty sessionID, a query of at least 10 bytes (note: bytes,
// not characters), and at least one stored message for the session. A failed
// count lookup yields false.
func (mm *MemoryManager) ShouldInjectMemory(sessionID string, query string) bool {
	// Purely local checks first, so short queries never cost a DB round trip.
	if sessionID == "" || len(query) < 10 {
		return false
	}
	msgCount, err := mm.getSessionMessageCount(sessionID)
	return err == nil && msgCount != 0
}
// getSessionMessageCount returns how many main_messages rows belong to sessionID.
func (mm *MemoryManager) getSessionMessageCount(sessionID string) (int, error) {
	const countQuery = "SELECT COUNT(*) FROM main_messages WHERE session_id = ?"
	var total int
	row := mm.store.DB().QueryRow(countQuery, sessionID)
	err := row.Scan(&total)
	return total, err
}
// MaintainSessionMemory records one completed exchange. Exchanges whose
// response is shorter than 20 bytes are ignored. Otherwise it:
//   - stores a short summary (query plus a truncated response) as short-term memory,
//   - appends the full exchange to the dialogue buffer, and
//   - triggers long-term fact extraction once shouldExtract reports enough
//     buffered dialogue.
//
// Failures are logged; nothing is returned to the caller.
func (mm *MemoryManager) MaintainSessionMemory(sessionID string, userQuery string, assistantResponse string) {
	if len(assistantResponse) < 20 {
		return
	}
	// The answer label now carries the same ":" separator as the question label;
	// previously the response text ran straight into "回答".
	summary := fmt.Sprintf("用户问:%s\n回答:%s", userQuery, truncateString(assistantResponse, 100))
	if err := mm.AddShortTermMemory(sessionID, summary); err != nil {
		log.Printf("[memory] Failed to add short-term memory: %v", err)
	}
	mm.bufferDialogue(sessionID, userQuery, assistantResponse)
	if mm.shouldExtract(sessionID) {
		mm.triggerExtraction(sessionID)
	}
}
// shouldExtract reports whether sessionID has at least 5 buffered dialogue
// turns. A failed count query is treated as "not yet".
func (mm *MemoryManager) shouldExtract(sessionID string) bool {
	const extractionThreshold = 5
	var buffered int
	row := mm.store.DB().QueryRow(
		"SELECT COUNT(*) FROM dialogue_buffer WHERE session_id = ?",
		sessionID,
	)
	if row.Scan(&buffered) != nil {
		return false
	}
	return buffered >= extractionThreshold
}
// SetLLM installs the backend used for long-term fact extraction. While no
// backend is set, buffered dialogue is flushed without extracting facts.
func (mm *MemoryManager) SetLLM(backend llm.LLM) {
	mm.llmBackend = backend
}
// triggerExtraction flushes the session's dialogue buffer and, if an LLM backend
// is configured, extracts facts from it in a background goroutine. Facts with
// confidence >= 0.6 are saved as long-term memories.
//
// The buffer is flushed before the backend check so it stays bounded even when
// no backend is configured; in that case the flushed turns are discarded.
// NOTE(review): the extraction goroutine is not tracked by embedWg, so Close
// does not wait for it.
func (mm *MemoryManager) triggerExtraction(sessionID string) {
	dialogues, err := mm.FlushDialogueBuffer(sessionID)
	if err != nil {
		log.Printf("[memory] Failed to flush dialogue buffer for session=%s: %v", sessionID, err)
		return
	}
	if len(dialogues) == 0 {
		return
	}
	if mm.llmBackend == nil {
		log.Printf("[memory] No LLM backend configured, skipping extraction for session=%s", sessionID)
		return
	}
	go func() {
		facts, err := mm.extractFacts(dialogues)
		if err != nil {
			log.Printf("[memory] Extraction failed: %v", err)
			return
		}
		for _, fact := range facts {
			if fact.Confidence < 0.6 {
				continue
			}
			if err := mm.AddLongTermMemory(fact.Content, fact.Type); err != nil {
				log.Printf("[memory] Failed to save long-term memory: %v", err)
			}
		}
	}()
}
// bufferDialogue appends one user/assistant exchange to dialogue_buffer for
// later fact extraction. Insert failures are logged and otherwise ignored.
func (mm *MemoryManager) bufferDialogue(sessionID string, userQuery string, assistantResponse string) {
	const insertTurn = "INSERT INTO dialogue_buffer (session_id, user_query, assistant_response) VALUES (?, ?, ?)"
	if _, err := mm.store.DB().Exec(insertTurn, sessionID, userQuery, assistantResponse); err != nil {
		log.Printf("[memory] Failed to buffer dialogue: %v", err)
	}
}
// FlushDialogueBuffer returns every buffered exchange for sessionID, oldest
// first, and then deletes them from dialogue_buffer. Rows that fail to scan are
// skipped.
//
// If reading fails (including an error surfaced by rows.Err), nothing is
// deleted and the error is returned, so buffered dialogue is not lost. A failed
// DELETE is only logged; the read dialogues are still returned.
//
// NOTE(review): turns inserted between the SELECT and the DELETE are deleted
// without being returned; a transaction or ID-bounded delete would close this.
func (mm *MemoryManager) FlushDialogueBuffer(sessionID string) ([]struct {
	UserQuery         string
	AssistantResponse string
}, error) {
	// dialogue aliases the anonymous element type of the return slice.
	type dialogue = struct {
		UserQuery         string
		AssistantResponse string
	}
	rows, err := mm.store.DB().Query(
		"SELECT user_query, assistant_response FROM dialogue_buffer WHERE session_id = ? ORDER BY created_at ASC",
		sessionID,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var dialogues []dialogue
	for rows.Next() {
		var d dialogue
		if err := rows.Scan(&d.UserQuery, &d.AssistantResponse); err != nil {
			continue
		}
		dialogues = append(dialogues, d)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	// Release the result set before issuing a write on the same database.
	rows.Close()
	if len(dialogues) > 0 {
		if _, err := mm.store.DB().Exec("DELETE FROM dialogue_buffer WHERE session_id = ?", sessionID); err != nil {
			log.Printf("[memory] Failed to clear dialogue buffer: %v", err)
		}
	}
	return dialogues, nil
}
// truncateString returns s unchanged if it is at most maxLen bytes long.
// Otherwise it returns at most maxLen bytes of s followed by "...".
//
// The cut is moved back to the nearest rune boundary, so multi-byte (e.g. CJK)
// characters are never split into invalid UTF-8. The previous byte slice
// s[:maxLen] could cut a character in half.
func truncateString(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	cut := 0
	// Ranging over a string yields the byte offset of each rune start.
	for i := range s {
		if i > maxLen {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}
// RecordMemoryUsage logs one use of long-term memory memoryID in
// memory_usage_log. It then adjusts the memory's weight (+0.5 when referenced,
// -0.3 otherwise), increments access_count, and sets last_accessed to now.
// The two statements are not wrapped in a transaction.
func (mm *MemoryManager) RecordMemoryUsage(memoryID int64, sessionID, query string, referenced bool) error {
	if _, err := mm.store.DB().Exec(
		"INSERT INTO memory_usage_log (memory_id, session_id, query, was_referenced) VALUES (?, ?, ?, ?)",
		memoryID, sessionID, query, referenced,
	); err != nil {
		return err
	}
	var delta float64
	if referenced {
		delta = 0.5
	} else {
		delta = -0.3
	}
	_, err := mm.store.DB().Exec(
		"UPDATE long_term_memories SET weight = weight + ?, access_count = access_count + 1, last_accessed = ? WHERE id = ?",
		delta, time.Now(), memoryID,
	)
	return err
}
// ArchiveLowWeightMemories marks every non-archived long-term memory with
// weight below threshold as archived and returns how many rows changed.
// If the update succeeds but the driver cannot report the affected-row count,
// that error is returned (wrapped) instead of being silently reported as 0.
func (mm *MemoryManager) ArchiveLowWeightMemories(threshold float64) (int, error) {
	result, err := mm.store.DB().Exec(
		"UPDATE long_term_memories SET archived = 1 WHERE weight < ? AND archived = 0",
		threshold,
	)
	if err != nil {
		return 0, err
	}
	count, err := result.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("archived low-weight memories but could not count them: %w", err)
	}
	return int(count), nil
}
// GetCoreMemories returns every non-archived long-term memory whose weight is
// at least minWeight, heaviest first. Rows that fail to scan are skipped;
// iteration errors are returned alongside whatever was read.
func (mm *MemoryManager) GetCoreMemories(minWeight float64) ([]struct {
	ID      int64
	Content string
	Weight  float64
}, error) {
	type coreMemory = struct {
		ID      int64
		Content string
		Weight  float64
	}
	rows, err := mm.store.DB().Query(
		"SELECT id, content, weight FROM long_term_memories WHERE weight >= ? AND archived = 0 ORDER BY weight DESC",
		minWeight,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var out []coreMemory
	for rows.Next() {
		var cm coreMemory
		if scanErr := rows.Scan(&cm.ID, &cm.Content, &cm.Weight); scanErr != nil {
			continue
		}
		out = append(out, cm)
	}
	return out, rows.Err()
}
// CacheStats returns a consistent snapshot of the query-embedding cache:
// current entry count plus cumulative hit and miss counters.
func (mm *MemoryManager) CacheStats() (size int, hits, misses int64) {
	mm.embedCacheMu.RLock()
	size, hits, misses = len(mm.embedCache), mm.embedCacheHit, mm.embedCacheMiss
	mm.embedCacheMu.RUnlock()
	return size, hits, misses
}
// memoryExtractionPrompt is the instruction text that extractFacts places ahead
// of the numbered dialogue transcript in the user message. It asks the model
// for fact/preference/project information and for JSON output: an object whose
// "facts" array holds {content, type, confidence} items, matching
// extractionResult / extractedFact.
// The fenced example is spliced in with "```" string concatenations because a
// Go raw string literal cannot contain a backtick.
// NOTE(review): some full-width punctuation appears to be missing from this
// literal (e.g. "只输出 JSON不要任何解释"); confirm against the intended text.
const memoryExtractionPrompt = `# Memory Extractor
你是一个专门从对话中提取用户信息的助手。将非结构化的对话转化为结构化的长期记忆。
## 任务
分析给定的对话记录,提取以下类型的信息:
1. **事实 (fact)**:客观信息
- 工作:公司、职位、技术栈、行业
- 技术:擅长语言、框架偏好、架构经验
- 个人:教育背景、所在城市(仅用户明确提及)
2. **偏好 (preference)**:主观倾向
- 回答风格:简洁/详细/代码示例/架构图
- 技术偏好:语言、数据库、部署方式
- 沟通偏好:正式/ casual
3. **项目 (project)**:当前工作
- 项目名称、技术方案、当前阶段、遇到的挑战
## 输出格式
只输出 JSON不要任何解释
` + "```json" + `
{
"facts": [
{
"content": "用户在电商公司担任后端工程师",
"type": "fact",
"confidence": 0.95
},
{
"content": "用户偏好简洁的技术回答",
"type": "preference",
"confidence": 0.85
}
]
}
` + "```" + `
## 规则
- confidence < 0.6 的事实不输出
- 不猜测、不推断,只提取用户明确表达的信息`
type (
	// extractedFact is one item of the "facts" array the extraction LLM returns.
	// Type is expected to be one of the categories named in
	// memoryExtractionPrompt (e.g. "fact", "preference").
	extractedFact struct {
		Content    string  `json:"content"`
		Type       string  `json:"type"`
		Confidence float64 `json:"confidence"`
	}

	// extractionResult is the top-level JSON object the extraction LLM returns.
	extractionResult struct {
		Facts []extractedFact `json:"facts"`
	}
)
// extractFacts sends memoryExtractionPrompt plus a numbered transcript of
// dialogues to the configured LLM and parses the reply with parseExtractionJSON.
//
// It returns an error if no backend is configured, the chat call fails, or the
// reply cannot be parsed. The call uses the manager's lifetime context
// (embedCtx), so Close cancels an in-flight extraction instead of letting it
// run unbounded.
func (mm *MemoryManager) extractFacts(dialogues []struct {
	UserQuery         string
	AssistantResponse string
}) ([]extractedFact, error) {
	backend := mm.llmBackend
	if backend == nil {
		return nil, fmt.Errorf("no llm backend configured")
	}
	var sb strings.Builder
	sb.WriteString(memoryExtractionPrompt)
	sb.WriteString("\n\n## 对话记录\n\n")
	for i, d := range dialogues {
		fmt.Fprintf(&sb, "--- 对话 %d ---\n", i+1)
		fmt.Fprintf(&sb, "用户:%s\n", d.UserQuery)
		fmt.Fprintf(&sb, "助手:%s\n\n", d.AssistantResponse)
	}
	messages := []llm.Message{
		{Role: "system", Content: "你是一个专门从对话中提取用户信息的助手。"},
		{Role: "user", Content: sb.String()},
	}
	ctx := mm.embedCtx
	if ctx == nil {
		// Defensive: a manager built outside the constructors has no lifetime context.
		ctx = context.Background()
	}
	resp, err := backend.Chat(ctx, messages)
	if err != nil {
		return nil, fmt.Errorf("llm chat failed: %w", err)
	}
	return parseExtractionJSON(resp.Content)
}
// parseExtractionJSON isolates the JSON payload in an LLM reply (via
// extractJSONBlock) and decodes its "facts" array. A decode failure is wrapped
// with context; a reply without a "facts" key yields nil facts and no error.
func parseExtractionJSON(content string) ([]extractedFact, error) {
	payload := []byte(extractJSONBlock(content))
	var parsed extractionResult
	if err := json.Unmarshal(payload, &parsed); err != nil {
		return nil, fmt.Errorf("parse extraction json failed: %w", err)
	}
	return parsed.Facts, nil
}
// extractJSONBlock pulls the JSON payload out of an LLM reply.
//
// It checks, in order:
//  1. A ```json fence. Its body is used, running to the end of the input if the
//     closing fence is missing (previously an unclosed fence returned the
//     "```json" marker itself, which never parses).
//  2. Any closed ``` fence.
//  3. The whole reply, trimmed.
//
// If the candidate still does not start with '{' or '[' (prose around the JSON,
// or a fence tagged e.g. "JSON"), it is narrowed to the span from the first
// '{' to the last '}' when such a span exists. Inputs that were already bare
// JSON are returned unchanged.
func extractJSONBlock(content string) string {
	candidate := strings.TrimSpace(content)
	if idx := strings.Index(content, "```json"); idx != -1 {
		body := content[idx+len("```json"):]
		if end := strings.Index(body, "```"); end != -1 {
			body = body[:end]
		}
		candidate = strings.TrimSpace(body)
	} else if idx := strings.Index(content, "```"); idx != -1 {
		body := content[idx+len("```"):]
		if end := strings.Index(body, "```"); end != -1 {
			candidate = strings.TrimSpace(body[:end])
		}
	}
	if strings.HasPrefix(candidate, "{") || strings.HasPrefix(candidate, "[") {
		return candidate
	}
	start := strings.Index(candidate, "{")
	end := strings.LastIndex(candidate, "}")
	if start != -1 && end > start {
		return candidate[start : end+1]
	}
	return candidate
}