871 lines
20 KiB
Go
871 lines
20 KiB
Go
package session
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"log"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/orca/orca/pkg/embedding"
|
||
"github.com/orca/orca/pkg/llm"
|
||
)
|
||
|
||
type MemoryManager struct {
|
||
store *SQLiteStore
|
||
vectorStore *VectorStore
|
||
embedClient *embedding.Client
|
||
llmBackend llm.LLM
|
||
tokenBudget TokenBudget
|
||
ownsStore bool
|
||
|
||
embedQueue chan embedTask
|
||
embedWg sync.WaitGroup
|
||
embedCtx context.Context
|
||
embedCancel context.CancelFunc
|
||
|
||
embedCache map[string]*embedCacheEntry
|
||
embedCacheMu sync.RWMutex
|
||
embedCacheMax int
|
||
embedCacheHit int64
|
||
embedCacheMiss int64
|
||
}
|
||
|
||
type (
	// embedCacheEntry is one cached query embedding together with the time it
	// was last stored or read, which evictLRU uses to pick a victim.
	embedCacheEntry struct {
		embedding []float32
		lastUsed  time.Time
	}

	// embedTask is a saved message waiting for the background worker to embed it.
	embedTask struct {
		msgID   int64
		content string
	}

	// TokenBudget splits the memory token allowance (Total) between working,
	// short-term and long-term memory.
	TokenBudget struct {
		Total     int
		Working   int
		ShortTerm int
		LongTerm  int
	}
)
|
||
|
||
type MemoryConfig struct {
|
||
DBPath string
|
||
EmbedConfig embedding.Config
|
||
ModelWindow int
|
||
}
|
||
|
||
func NewMemoryManager(cfg MemoryConfig) (*MemoryManager, error) {
|
||
store, err := NewSQLiteStore(cfg.DBPath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create store: %w", err)
|
||
}
|
||
|
||
vectorStore, err := NewVectorStore(store.DB())
|
||
if err != nil {
|
||
store.Close()
|
||
return nil, fmt.Errorf("failed to create vector store: %w", err)
|
||
}
|
||
|
||
budget := calculateBudget(cfg.ModelWindow)
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
mm := &MemoryManager{
|
||
store: store,
|
||
vectorStore: vectorStore,
|
||
tokenBudget: budget,
|
||
ownsStore: true,
|
||
embedQueue: make(chan embedTask, 100),
|
||
embedCtx: ctx,
|
||
embedCancel: cancel,
|
||
embedCache: make(map[string]*embedCacheEntry),
|
||
embedCacheMax: 500,
|
||
}
|
||
|
||
if cfg.EmbedConfig.APIKey != "" {
|
||
mm.embedClient = embedding.NewClient(cfg.EmbedConfig)
|
||
mm.startEmbedWorker()
|
||
}
|
||
|
||
return mm, nil
|
||
}
|
||
|
||
func NewMemoryManagerWithStore(cfg MemoryConfig, store *SQLiteStore) (*MemoryManager, error) {
|
||
vectorStore, err := NewVectorStore(store.DB())
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create vector store: %w", err)
|
||
}
|
||
|
||
budget := calculateBudget(cfg.ModelWindow)
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
mm := &MemoryManager{
|
||
store: store,
|
||
vectorStore: vectorStore,
|
||
tokenBudget: budget,
|
||
ownsStore: false,
|
||
embedQueue: make(chan embedTask, 100),
|
||
embedCtx: ctx,
|
||
embedCancel: cancel,
|
||
embedCache: make(map[string]*embedCacheEntry),
|
||
embedCacheMax: 500,
|
||
}
|
||
|
||
if cfg.EmbedConfig.APIKey != "" {
|
||
mm.embedClient = embedding.NewClient(cfg.EmbedConfig)
|
||
mm.startEmbedWorker()
|
||
}
|
||
|
||
return mm, nil
|
||
}
|
||
|
||
func calculateBudget(modelWindow int) TokenBudget {
|
||
if modelWindow <= 0 {
|
||
modelWindow = 8192
|
||
}
|
||
total := int(float64(modelWindow) * 0.6)
|
||
return TokenBudget{
|
||
Total: total,
|
||
Working: int(float64(total) * 0.5),
|
||
ShortTerm: int(float64(total) * 0.3),
|
||
LongTerm: int(float64(total) * 0.2),
|
||
}
|
||
}
|
||
|
||
func (mm *MemoryManager) Close() error {
|
||
mm.embedCancel()
|
||
mm.embedWg.Wait()
|
||
if mm.ownsStore {
|
||
return mm.store.Close()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (mm *MemoryManager) SaveMessage(sessionID string, msg SessionMessage) error {
|
||
if err := mm.store.Save(sessionID, msg); err != nil {
|
||
return err
|
||
}
|
||
|
||
if mm.embedClient != nil && len(msg.Content) > 10 &&
|
||
(msg.Role == RoleUser || msg.Role == RoleAssistant) {
|
||
msgID, err := mm.getLastMessageID(sessionID)
|
||
if err == nil {
|
||
select {
|
||
case mm.embedQueue <- embedTask{msgID: msgID, content: msg.Content}:
|
||
default:
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (mm *MemoryManager) getLastMessageID(sessionID string) (int64, error) {
|
||
var id int64
|
||
err := mm.store.DB().QueryRow(
|
||
"SELECT id FROM main_messages WHERE session_id = ? ORDER BY id DESC LIMIT 1",
|
||
sessionID,
|
||
).Scan(&id)
|
||
return id, err
|
||
}
|
||
|
||
func (mm *MemoryManager) startEmbedWorker() {
|
||
mm.embedWg.Add(1)
|
||
go func() {
|
||
defer mm.embedWg.Done()
|
||
batch := make([]embedTask, 0, 5)
|
||
timer := time.NewTimer(5 * time.Second)
|
||
defer timer.Stop()
|
||
|
||
for {
|
||
select {
|
||
case task := <- mm.embedQueue:
|
||
batch = append(batch, task)
|
||
if len(batch) >= 5 {
|
||
mm.processBatch(batch)
|
||
batch = batch[:0]
|
||
timer.Reset(5 * time.Second)
|
||
}
|
||
case <- timer.C:
|
||
if len(batch) > 0 {
|
||
mm.processBatch(batch)
|
||
batch = batch[:0]
|
||
}
|
||
timer.Reset(5 * time.Second)
|
||
case <- mm.embedCtx.Done():
|
||
if len(batch) > 0 {
|
||
mm.processBatch(batch)
|
||
}
|
||
return
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
func (mm *MemoryManager) processBatch(tasks []embedTask) {
|
||
texts := make([]string, len(tasks))
|
||
for i, t := range tasks {
|
||
texts[i] = t.content
|
||
}
|
||
|
||
embeddings, err := mm.embedClient.Embed(texts)
|
||
if err != nil {
|
||
log.Printf("[memory] Embedding batch failed: %v", err)
|
||
return
|
||
}
|
||
|
||
for i, emb := range embeddings {
|
||
if err := mm.vectorStore.SaveEmbedding(tasks[i].msgID, emb); err != nil {
|
||
log.Printf("[memory] Save embedding failed: %v", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
func (mm *MemoryManager) GetWorkingMemory(sessionID string) ([]SessionMessage, error) {
|
||
rows, err := mm.store.DB().Query(
|
||
`SELECT role, content, timestamp, metadata FROM main_messages
|
||
WHERE session_id = ?
|
||
ORDER BY timestamp DESC`,
|
||
sessionID,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var messages []SessionMessage
|
||
totalTokens := 0
|
||
for rows.Next() {
|
||
var msg SessionMessage
|
||
var timestampStr string
|
||
var metadataStr string
|
||
if err := rows.Scan(&msg.Role, &msg.Content, ×tampStr, &metadataStr); err != nil {
|
||
continue
|
||
}
|
||
msg.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
|
||
|
||
tokens := estimateTokens(msg.Content)
|
||
if totalTokens+tokens > mm.tokenBudget.Working && len(messages) > 0 {
|
||
break
|
||
}
|
||
totalTokens += tokens
|
||
messages = append(messages, msg)
|
||
}
|
||
|
||
reverseMessages(messages)
|
||
return messages, nil
|
||
}
|
||
|
||
func (mm *MemoryManager) getEmbeddingWithCache(query string) ([]float32, error) {
|
||
if mm.embedClient == nil {
|
||
return nil, fmt.Errorf("no embed client")
|
||
}
|
||
|
||
mm.embedCacheMu.RLock()
|
||
if entry, ok := mm.embedCache[query]; ok {
|
||
entry.lastUsed = time.Now()
|
||
mm.embedCacheHit++
|
||
mm.embedCacheMu.RUnlock()
|
||
return entry.embedding, nil
|
||
}
|
||
mm.embedCacheMu.RUnlock()
|
||
|
||
embedding, err := mm.embedClient.EmbedSingle(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
mm.embedCacheMu.Lock()
|
||
mm.embedCache[query] = &embedCacheEntry{
|
||
embedding: embedding,
|
||
lastUsed: time.Now(),
|
||
}
|
||
mm.embedCacheMiss++
|
||
|
||
if len(mm.embedCache) > mm.embedCacheMax {
|
||
mm.evictLRU()
|
||
}
|
||
mm.embedCacheMu.Unlock()
|
||
|
||
return embedding, nil
|
||
}
|
||
|
||
func (mm *MemoryManager) evictLRU() {
|
||
var oldestKey string
|
||
var oldestTime time.Time
|
||
first := true
|
||
for k, v := range mm.embedCache {
|
||
if first || v.lastUsed.Before(oldestTime) {
|
||
oldestKey = k
|
||
oldestTime = v.lastUsed
|
||
first = false
|
||
}
|
||
}
|
||
if oldestKey != "" {
|
||
delete(mm.embedCache, oldestKey)
|
||
}
|
||
}
|
||
|
||
func (mm *MemoryManager) GetShortTermMemory(sessionID string, query string) ([]string, error) {
|
||
if query != "" && mm.embedClient != nil {
|
||
embedding, err := mm.getEmbeddingWithCache(query)
|
||
if err == nil {
|
||
msgIDs, err := mm.vectorStore.SearchSimilarInSession(sessionID, embedding, 3)
|
||
if err == nil && len(msgIDs) > 0 {
|
||
return mm.loadMemoryContents(msgIDs)
|
||
}
|
||
}
|
||
}
|
||
|
||
rows, err := mm.store.DB().Query(
|
||
`SELECT content FROM short_term_memories
|
||
WHERE session_id = ?
|
||
ORDER BY updated_at DESC
|
||
LIMIT 3`,
|
||
sessionID,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var memories []string
|
||
for rows.Next() {
|
||
var content string
|
||
if err := rows.Scan(&content); err != nil {
|
||
continue
|
||
}
|
||
memories = append(memories, content)
|
||
}
|
||
return memories, rows.Err()
|
||
}
|
||
|
||
func (mm *MemoryManager) GetLongTermMemory(query string) ([]struct {
|
||
ID int64
|
||
Content string
|
||
Weight float64
|
||
}, error) {
|
||
var results []struct {
|
||
ID int64
|
||
Content string
|
||
Weight float64
|
||
}
|
||
|
||
if query != "" && mm.embedClient != nil {
|
||
embedding, err := mm.getEmbeddingWithCache(query)
|
||
if err == nil {
|
||
vecResults, err := mm.vectorStore.SearchLongTermSimilar(embedding, 5)
|
||
if err == nil && len(vecResults) > 0 {
|
||
for _, vr := range vecResults {
|
||
var content string
|
||
var weight float64
|
||
err := mm.store.DB().QueryRow(
|
||
"SELECT content, weight FROM long_term_memories WHERE id = ? AND archived = 0",
|
||
vr.MemoryID,
|
||
).Scan(&content, &weight)
|
||
if err == nil {
|
||
results = append(results, struct {
|
||
ID int64
|
||
Content string
|
||
Weight float64
|
||
}{ID: vr.MemoryID, Content: content, Weight: weight})
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if len(results) == 0 {
|
||
rows, err := mm.store.DB().Query(
|
||
`SELECT id, content, weight FROM long_term_memories
|
||
WHERE archived = 0
|
||
ORDER BY weight DESC, access_count DESC
|
||
LIMIT 2`,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
for rows.Next() {
|
||
var r struct {
|
||
ID int64
|
||
Content string
|
||
Weight float64
|
||
}
|
||
if err := rows.Scan(&r.ID, &r.Content, &r.Weight); err != nil {
|
||
continue
|
||
}
|
||
results = append(results, r)
|
||
}
|
||
}
|
||
|
||
now := time.Now()
|
||
for _, r := range results {
|
||
mm.store.DB().Exec(
|
||
"UPDATE long_term_memories SET access_count = access_count + 1, last_accessed = ? WHERE id = ?",
|
||
now, r.ID,
|
||
)
|
||
}
|
||
|
||
return results, nil
|
||
}
|
||
|
||
func (mm *MemoryManager) loadMemoryContents(msgIDs []int64) ([]string, error) {
|
||
if len(msgIDs) == 0 {
|
||
return nil, nil
|
||
}
|
||
|
||
placeholders := make([]string, len(msgIDs))
|
||
args := make([]interface{}, len(msgIDs))
|
||
for i, id := range msgIDs {
|
||
placeholders[i] = "?"
|
||
args[i] = id
|
||
}
|
||
|
||
query := fmt.Sprintf(
|
||
"SELECT content FROM main_messages WHERE id IN (%s)",
|
||
strings.Join(placeholders, ","),
|
||
)
|
||
rows, err := mm.store.DB().Query(query, args...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var contents []string
|
||
for rows.Next() {
|
||
var content string
|
||
if err := rows.Scan(&content); err != nil {
|
||
continue
|
||
}
|
||
contents = append(contents, content)
|
||
}
|
||
return contents, rows.Err()
|
||
}
|
||
|
||
func (mm *MemoryManager) AddShortTermMemory(sessionID string, content string) error {
|
||
_, err := mm.store.DB().Exec(
|
||
`INSERT INTO short_term_memories (session_id, content, updated_at)
|
||
VALUES (?, ?, ?)
|
||
ON CONFLICT(session_id, content) DO UPDATE SET
|
||
source_count = source_count + 1,
|
||
updated_at = ?`,
|
||
sessionID, content, time.Now(), time.Now(),
|
||
)
|
||
return err
|
||
}
|
||
|
||
func (mm *MemoryManager) AddLongTermMemory(content string, memoryType string) error {
|
||
_, err := mm.store.DB().Exec(
|
||
`INSERT INTO long_term_memories (content, memory_type, confidence)
|
||
VALUES (?, ?, ?)
|
||
ON CONFLICT(content) DO UPDATE SET
|
||
access_count = access_count + 1,
|
||
confidence = MAX(confidence, ?)`,
|
||
content, memoryType, 0.8, 0.8,
|
||
)
|
||
return err
|
||
}
|
||
|
||
func (mm *MemoryManager) Cleanup() error {
|
||
_, err := mm.store.DB().Exec(
|
||
`DELETE FROM short_term_memories
|
||
WHERE id NOT IN (
|
||
SELECT id FROM short_term_memories
|
||
ORDER BY updated_at DESC
|
||
LIMIT 10
|
||
)`,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// estimateTokens approximates a token count as half the number of code points
// in text (integer division). Each invalid UTF-8 byte counts as one code point.
func estimateTokens(text string) int {
	codePoints := 0
	for range text {
		codePoints++
	}
	return codePoints / 2
}
|
||
|
||
func reverseMessages(msgs []SessionMessage) {
|
||
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
|
||
msgs[i], msgs[j] = msgs[j], msgs[i]
|
||
}
|
||
}
|
||
|
||
// toInterfaceSlice copies ss into a new []interface{} of the same length. The
// result is never nil.
//
// The parameter is named ss rather than strings so that it does not shadow
// the imported strings package inside the function.
func toInterfaceSlice(ss []string) []interface{} {
	out := make([]interface{}, len(ss))
	for i, s := range ss {
		out[i] = s
	}
	return out
}
|
||
|
||
// MemoryContextStats describes what BuildMemoryContextWithStats included.
type MemoryContextStats struct {
	ShortTermCount int // short-term memories included
	LongTermCount  int // long-term memories included
	TotalTokens    int // estimateTokens summed over every included memory
}
|
||
|
||
func (mm *MemoryManager) BuildMemoryContextWithStats(sessionID string, query string) (string, MemoryContextStats) {
|
||
var parts []string
|
||
stats := MemoryContextStats{}
|
||
|
||
shortTerm, err := mm.GetShortTermMemory(sessionID, query)
|
||
if err == nil && len(shortTerm) > 0 {
|
||
parts = append(parts, "## 相关上下文\n"+strings.Join(shortTerm, "\n"))
|
||
stats.ShortTermCount = len(shortTerm)
|
||
for _, m := range shortTerm {
|
||
stats.TotalTokens += estimateTokens(m)
|
||
}
|
||
}
|
||
|
||
longTerm, err := mm.GetLongTermMemory(query)
|
||
if err == nil && len(longTerm) > 0 {
|
||
var contents []string
|
||
for _, m := range longTerm {
|
||
contents = append(contents, m.Content)
|
||
mm.RecordMemoryUsage(m.ID, sessionID, query, true)
|
||
}
|
||
parts = append(parts, "## 背景知识\n"+strings.Join(contents, "\n"))
|
||
stats.LongTermCount = len(longTerm)
|
||
for _, m := range longTerm {
|
||
stats.TotalTokens += estimateTokens(m.Content)
|
||
}
|
||
}
|
||
|
||
if len(parts) == 0 {
|
||
return "", stats
|
||
}
|
||
|
||
return "## 记忆信息\n" + strings.Join(parts, "\n\n"), stats
|
||
}
|
||
|
||
func (mm *MemoryManager) BuildMemoryContext(sessionID string, query string) string {
|
||
ctx, _ := mm.BuildMemoryContextWithStats(sessionID, query)
|
||
return ctx
|
||
}
|
||
|
||
func (mm *MemoryManager) ShouldInjectMemory(sessionID string, query string) bool {
|
||
if sessionID == "" {
|
||
return false
|
||
}
|
||
|
||
msgCount, err := mm.getSessionMessageCount(sessionID)
|
||
if err != nil || msgCount == 0 {
|
||
return false
|
||
}
|
||
|
||
if len(query) < 10 {
|
||
return false
|
||
}
|
||
|
||
return true
|
||
}
|
||
|
||
func (mm *MemoryManager) getSessionMessageCount(sessionID string) (int, error) {
|
||
var count int
|
||
err := mm.store.DB().QueryRow(
|
||
"SELECT COUNT(*) FROM main_messages WHERE session_id = ?",
|
||
sessionID,
|
||
).Scan(&count)
|
||
return count, err
|
||
}
|
||
|
||
func (mm *MemoryManager) MaintainSessionMemory(sessionID string, userQuery string, assistantResponse string) {
|
||
if len(assistantResponse) < 20 {
|
||
return
|
||
}
|
||
|
||
summary := fmt.Sprintf("用户问:%s\n回答:%s", userQuery, truncateString(assistantResponse, 100))
|
||
if err := mm.AddShortTermMemory(sessionID, summary); err != nil {
|
||
log.Printf("[memory] Failed to add short-term memory: %v", err)
|
||
}
|
||
|
||
mm.bufferDialogue(sessionID, userQuery, assistantResponse)
|
||
|
||
if mm.shouldExtract(sessionID) {
|
||
mm.triggerExtraction(sessionID)
|
||
}
|
||
}
|
||
|
||
func (mm *MemoryManager) shouldExtract(sessionID string) bool {
|
||
var count int
|
||
err := mm.store.DB().QueryRow(
|
||
"SELECT COUNT(*) FROM dialogue_buffer WHERE session_id = ?",
|
||
sessionID,
|
||
).Scan(&count)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
return count >= 5
|
||
}
|
||
|
||
func (mm *MemoryManager) SetLLM(backend llm.LLM) {
|
||
mm.llmBackend = backend
|
||
}
|
||
|
||
func (mm *MemoryManager) triggerExtraction(sessionID string) {
|
||
dialogues, err := mm.FlushDialogueBuffer(sessionID)
|
||
if err != nil || len(dialogues) == 0 {
|
||
return
|
||
}
|
||
|
||
if mm.llmBackend == nil {
|
||
log.Printf("[memory] No LLM backend configured, skipping extraction for session=%s", sessionID)
|
||
return
|
||
}
|
||
|
||
go func() {
|
||
facts, err := mm.extractFacts(dialogues)
|
||
if err != nil {
|
||
log.Printf("[memory] Extraction failed: %v", err)
|
||
return
|
||
}
|
||
|
||
for _, fact := range facts {
|
||
if fact.Confidence < 0.6 {
|
||
continue
|
||
}
|
||
if err := mm.AddLongTermMemory(fact.Content, fact.Type); err != nil {
|
||
log.Printf("[memory] Failed to save long-term memory: %v", err)
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
func (mm *MemoryManager) bufferDialogue(sessionID string, userQuery string, assistantResponse string) {
|
||
_, err := mm.store.DB().Exec(
|
||
"INSERT INTO dialogue_buffer (session_id, user_query, assistant_response) VALUES (?, ?, ?)",
|
||
sessionID, userQuery, assistantResponse,
|
||
)
|
||
if err != nil {
|
||
log.Printf("[memory] Failed to buffer dialogue: %v", err)
|
||
}
|
||
}
|
||
|
||
func (mm *MemoryManager) FlushDialogueBuffer(sessionID string) ([]struct {
|
||
UserQuery string
|
||
AssistantResponse string
|
||
}, error) {
|
||
rows, err := mm.store.DB().Query(
|
||
"SELECT user_query, assistant_response FROM dialogue_buffer WHERE session_id = ? ORDER BY created_at ASC",
|
||
sessionID,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var dialogues []struct {
|
||
UserQuery string
|
||
AssistantResponse string
|
||
}
|
||
for rows.Next() {
|
||
var d struct {
|
||
UserQuery string
|
||
AssistantResponse string
|
||
}
|
||
if err := rows.Scan(&d.UserQuery, &d.AssistantResponse); err != nil {
|
||
continue
|
||
}
|
||
dialogues = append(dialogues, d)
|
||
}
|
||
|
||
if len(dialogues) > 0 {
|
||
_, err = mm.store.DB().Exec("DELETE FROM dialogue_buffer WHERE session_id = ?", sessionID)
|
||
if err != nil {
|
||
log.Printf("[memory] Failed to clear dialogue buffer: %v", err)
|
||
}
|
||
}
|
||
|
||
return dialogues, rows.Err()
|
||
}
|
||
|
||
// truncateString returns s unchanged if it is at most maxLen bytes long.
// Otherwise it returns the longest prefix of at most maxLen bytes that ends on
// a UTF-8 rune boundary, followed by "...".
//
// Cutting at a rune boundary matters because this file truncates Chinese text;
// a plain s[:maxLen] could split a multi-byte character and produce invalid
// UTF-8. A negative maxLen yields just "..." instead of panicking.
func truncateString(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}

	cut := 0
	for i := range s { // i visits the byte offset of each rune start
		if i > maxLen {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}
|
||
|
||
func (mm *MemoryManager) RecordMemoryUsage(memoryID int64, sessionID, query string, referenced bool) error {
|
||
_, err := mm.store.DB().Exec(
|
||
"INSERT INTO memory_usage_log (memory_id, session_id, query, was_referenced) VALUES (?, ?, ?, ?)",
|
||
memoryID, sessionID, query, referenced,
|
||
)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
delta := 0.5
|
||
if !referenced {
|
||
delta = -0.3
|
||
}
|
||
|
||
_, err = mm.store.DB().Exec(
|
||
"UPDATE long_term_memories SET weight = weight + ?, access_count = access_count + 1, last_accessed = ? WHERE id = ?",
|
||
delta, time.Now(), memoryID,
|
||
)
|
||
return err
|
||
}
|
||
|
||
func (mm *MemoryManager) ArchiveLowWeightMemories(threshold float64) (int, error) {
|
||
result, err := mm.store.DB().Exec(
|
||
"UPDATE long_term_memories SET archived = 1 WHERE weight < ? AND archived = 0",
|
||
threshold,
|
||
)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
count, _ := result.RowsAffected()
|
||
return int(count), nil
|
||
}
|
||
|
||
func (mm *MemoryManager) GetCoreMemories(minWeight float64) ([]struct {
|
||
ID int64
|
||
Content string
|
||
Weight float64
|
||
}, error) {
|
||
rows, err := mm.store.DB().Query(
|
||
"SELECT id, content, weight FROM long_term_memories WHERE weight >= ? AND archived = 0 ORDER BY weight DESC",
|
||
minWeight,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var memories []struct {
|
||
ID int64
|
||
Content string
|
||
Weight float64
|
||
}
|
||
for rows.Next() {
|
||
var m struct {
|
||
ID int64
|
||
Content string
|
||
Weight float64
|
||
}
|
||
if err := rows.Scan(&m.ID, &m.Content, &m.Weight); err != nil {
|
||
continue
|
||
}
|
||
memories = append(memories, m)
|
||
}
|
||
return memories, rows.Err()
|
||
}
|
||
|
||
func (mm *MemoryManager) CacheStats() (size int, hits, misses int64) {
|
||
mm.embedCacheMu.RLock()
|
||
defer mm.embedCacheMu.RUnlock()
|
||
return len(mm.embedCache), mm.embedCacheHit, mm.embedCacheMiss
|
||
}
|
||
|
||
// memoryExtractionPrompt is the instruction text extractFacts sends to the
// LLM. It asks for a JSON object of the form
// {"facts": [{"content": ..., "type": ..., "confidence": ...}]}, which
// parseExtractionJSON decodes into extractionResult. The Markdown code fences
// are spliced in by string concatenation because a Go raw string literal
// cannot contain backticks.
const memoryExtractionPrompt = `# Memory Extractor

你是一个专门从对话中提取用户信息的助手。将非结构化的对话转化为结构化的长期记忆。

## 任务

分析给定的对话记录,提取以下类型的信息:

1. **事实 (fact)**:客观信息
   - 工作:公司、职位、技术栈、行业
   - 技术:擅长语言、框架偏好、架构经验
   - 个人:教育背景、所在城市(仅用户明确提及)

2. **偏好 (preference)**:主观倾向
   - 回答风格:简洁/详细/代码示例/架构图
   - 技术偏好:语言、数据库、部署方式
   - 沟通偏好:正式/ casual

3. **项目 (project)**:当前工作
   - 项目名称、技术方案、当前阶段、遇到的挑战

## 输出格式

只输出 JSON,不要任何解释:

` + "```json" + `
{
  "facts": [
    {
      "content": "用户在电商公司担任后端工程师",
      "type": "fact",
      "confidence": 0.95
    },
    {
      "content": "用户偏好简洁的技术回答",
      "type": "preference",
      "confidence": 0.85
    }
  ]
}
` + "```" + `

## 规则

- confidence < 0.6 的事实不输出
- 不猜测、不推断,只提取用户明确表达的信息`
|
||
|
||
type (
	// extractedFact is one fact in the extractor's JSON reply.
	extractedFact struct {
		Content    string  `json:"content"`
		Type       string  `json:"type"`       // "fact", "preference" or "project" per the prompt
		Confidence float64 `json:"confidence"` // triggerExtraction discards values below 0.6
	}

	// extractionResult is the top-level object of the extractor's JSON reply.
	extractionResult struct {
		Facts []extractedFact `json:"facts"`
	}
)
|
||
|
||
func (mm *MemoryManager) extractFacts(dialogues []struct {
|
||
UserQuery string
|
||
AssistantResponse string
|
||
}) ([]extractedFact, error) {
|
||
var sb strings.Builder
|
||
sb.WriteString(memoryExtractionPrompt)
|
||
sb.WriteString("\n\n## 对话记录\n\n")
|
||
for i, d := range dialogues {
|
||
sb.WriteString(fmt.Sprintf("--- 对话 %d ---\n", i+1))
|
||
sb.WriteString(fmt.Sprintf("用户:%s\n", d.UserQuery))
|
||
sb.WriteString(fmt.Sprintf("助手:%s\n\n", d.AssistantResponse))
|
||
}
|
||
|
||
messages := []llm.Message{
|
||
{Role: "system", Content: "你是一个专门从对话中提取用户信息的助手。"},
|
||
{Role: "user", Content: sb.String()},
|
||
}
|
||
|
||
resp, err := mm.llmBackend.Chat(context.Background(), messages)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("llm chat failed: %w", err)
|
||
}
|
||
|
||
return parseExtractionJSON(resp.Content)
|
||
}
|
||
|
||
func parseExtractionJSON(content string) ([]extractedFact, error) {
|
||
content = extractJSONBlock(content)
|
||
|
||
var result extractionResult
|
||
if err := json.Unmarshal([]byte(content), &result); err != nil {
|
||
return nil, fmt.Errorf("parse extraction json failed: %w", err)
|
||
}
|
||
|
||
return result.Facts, nil
|
||
}
|
||
|
||
// extractJSONBlock pulls a JSON object out of an LLM reply.
//
// It prefers the body of the first closed ```json fence, then the body of the
// first closed plain ``` fence, then the whole reply. From that candidate it
// returns the span from the first '{' to the last '}'. This strips a language
// tag, an unclosed fence marker, or prose around the object. If the candidate
// has no braces, it is returned trimmed.
func extractJSONBlock(content string) string {
	candidate := content
	for _, fence := range []string{"```json", "```"} {
		idx := strings.Index(content, fence)
		if idx == -1 {
			continue
		}
		rest := content[idx+len(fence):]
		if end := strings.Index(rest, "```"); end != -1 {
			candidate = rest[:end]
			break
		}
	}

	candidate = strings.TrimSpace(candidate)
	start := strings.IndexByte(candidate, '{')
	end := strings.LastIndexByte(candidate, '}')
	if start != -1 && end > start {
		return candidate[start : end+1]
	}
	return candidate
}
|