From 05c8f46b2367e78ebda2bc92d08a2e059dac9a51 Mon Sep 17 00:00:00 2001 From: Ashish-dwi99 Date: Thu, 12 Feb 2026 02:08:12 +0530 Subject: [PATCH 1/8] fixed db issue --- engram/api/app.py | 5 ++++- engram/configs/base.py | 4 ++-- engram/mcp_server.py | 43 +++++++++++++++++++++++++++++------------- 3 files changed, 36 insertions(+), 16 deletions(-) diff --git a/engram/api/app.py b/engram/api/app.py index 2d01586..f1260b5 100644 --- a/engram/api/app.py +++ b/engram/api/app.py @@ -110,7 +110,10 @@ def get_memory() -> Memory: if _memory is None: with _memory_lock: if _memory is None: - _memory = Memory() + # Re-use the MCP server's env-var-aware factory so the API + # honours ENGRAM_VECTOR_PROVIDER=sqlite_vec too. + from engram.mcp_server import get_memory_instance + _memory = get_memory_instance() return _memory diff --git a/engram/configs/base.py b/engram/configs/base.py index 56e5729..e67aa6f 100644 --- a/engram/configs/base.py +++ b/engram/configs/base.py @@ -12,10 +12,10 @@ class VectorStoreConfig(BaseModel): - provider: str = Field(default="qdrant") + provider: str = Field(default="sqlite_vec") config: Dict[str, Any] = Field( default_factory=lambda: { - "path": os.path.join(os.path.expanduser("~"), ".engram", "qdrant"), + "db_path": os.path.join(os.path.expanduser("~"), ".engram", "vectors.db"), "collection_name": "fadem_memories", } ) diff --git a/engram/mcp_server.py b/engram/mcp_server.py index 4b83c85..94d5f29 100644 --- a/engram/mcp_server.py +++ b/engram/mcp_server.py @@ -120,19 +120,36 @@ def get_memory_instance() -> Memory: "No API key found. Set GOOGLE_API_KEY, GEMINI_API_KEY, or OPENAI_API_KEY environment variable." 
) - # Configure vector store - qdrant_path = os.environ.get( - "FADEM_QDRANT_PATH", - os.path.join(os.path.expanduser("~"), ".engram", "qdrant") - ) - vector_store_config = VectorStoreConfig( - provider="qdrant", - config={ - "path": qdrant_path, - "collection_name": os.environ.get("FADEM_COLLECTION", "fadem_memories"), - "embedding_model_dims": embedding_dims, - } - ) + # Configure vector store — honour ENGRAM_VECTOR_PROVIDER (default: sqlite_vec) + vector_provider = os.environ.get("ENGRAM_VECTOR_PROVIDER", "sqlite_vec") + collection = os.environ.get("FADEM_COLLECTION", "fadem_memories") + + if vector_provider == "sqlite_vec": + vec_db = os.environ.get( + "ENGRAM_SQLITE_VEC_PATH", + os.path.join(os.path.expanduser("~"), ".engram", "vectors.db"), + ) + vector_store_config = VectorStoreConfig( + provider="sqlite_vec", + config={ + "db_path": vec_db, + "collection_name": collection, + "embedding_model_dims": embedding_dims, + }, + ) + else: + qdrant_path = os.environ.get( + "FADEM_QDRANT_PATH", + os.path.join(os.path.expanduser("~"), ".engram", "qdrant"), + ) + vector_store_config = VectorStoreConfig( + provider=vector_provider, + config={ + "path": qdrant_path, + "collection_name": collection, + "embedding_model_dims": embedding_dims, + }, + ) # Configure history database history_db_path = os.environ.get( From ad4a1a14e6c1981ed07dee854ef3bf975f9a5bad Mon Sep 17 00:00:00 2001 From: Ashish-dwi99 Date: Sat, 14 Feb 2026 16:49:10 +0530 Subject: [PATCH 2/8] minor fix --- engram-bridge/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engram-bridge/pyproject.toml b/engram-bridge/pyproject.toml index b486b9f..fe2cb18 100644 --- a/engram-bridge/pyproject.toml +++ b/engram-bridge/pyproject.toml @@ -11,7 +11,7 @@ requires-python = ">=3.10" license = {text = "MIT"} dependencies = [ "engram-bus>=0.1.0", - "engram>=0.4.0", + "engram-memory>=0.4.0", ] [project.optional-dependencies] From dc44beb625a553ba6611beb33185dcb637f0d97b Mon Sep 17 00:00:00 
2001 From: Ashish-dwi99 Date: Tue, 17 Feb 2026 18:37:14 +0530 Subject: [PATCH 3/8] =?UTF-8?q?feat:=20Phase=204=20=E2=80=94=20three-tier?= =?UTF-8?q?=20memory=20architecture=20(CoreMemory=20=E2=86=92=20SmartMemor?= =?UTF-8?q?y=20=E2=86=92=20FullMemory)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces CoreMemory (zero-config, no LLM), SmartMemory (echo/categories/graph), and FullMemory (scenes/profiles/tasks) with CoreSQLiteManager, content-hash dedup, query cache, presets system, benchmark command, and 82 new/updated tests. Co-Authored-By: Claude Opus 4.6 --- engram/__init__.py | 33 +- engram/cli.py | 50 + engram/cli_mcp.py | 8 +- engram/cli_setup.py | 136 +-- engram/configs/base.py | 20 + engram/configs/presets.py | 144 +++ engram/core/decay.py | 7 +- engram/core/echo.py | 9 +- engram/core/retrieval.py | 24 +- engram/core/traces.py | 34 +- engram/db/sqlite.py | 531 ++++++++- engram/db/sqlite_backup.py | 2070 +++++++++++++++++++++++++++++++++ engram/embeddings/gemini.py | 92 +- engram/embeddings/simple.py | 52 +- engram/llms/gemini.py | 56 +- engram/mcp_server.py | 1401 ++++------------------ engram/memory/__init__.py | 13 +- engram/memory/core.py | 444 +++++++ engram/memory/main.py | 163 ++- engram/memory/smart.py | 344 ++++++ engram/utils/factory.py | 53 +- engram/utils/math.py | 40 +- pyproject.toml | 63 +- tests/test_accel.py | 26 +- tests/test_accel_benchmark.py | 13 +- tests/test_core_memory.py | 200 ++++ tests/test_dedup.py | 74 ++ tests/test_mcp_tools_slim.py | 29 + tests/test_presets.py | 48 + tests/test_query_cache.py | 42 + tests/test_smart_memory.py | 144 +++ 31 files changed, 4686 insertions(+), 1677 deletions(-) create mode 100644 engram/configs/presets.py create mode 100644 engram/db/sqlite_backup.py create mode 100644 engram/memory/core.py create mode 100644 engram/memory/smart.py create mode 100644 tests/test_core_memory.py create mode 100644 tests/test_dedup.py create mode 100644 
tests/test_mcp_tools_slim.py create mode 100644 tests/test_presets.py create mode 100644 tests/test_query_cache.py create mode 100644 tests/test_smart_memory.py diff --git a/engram/__init__.py b/engram/__init__.py index d55a6ee..5aea6bc 100644 --- a/engram/__init__.py +++ b/engram/__init__.py @@ -4,26 +4,39 @@ - EchoMem: Multi-modal encoding for stronger retention - CategoryMem: Dynamic hierarchical category organization -Quick Start: - from engram import Engram +Quick Start (zero-config, no API key): + from engram import Memory + m = Memory() + m.add("User prefers Python") + results = m.search("programming preferences") - memory = Engram() - memory.add("User prefers Python", user_id="u123") - results = memory.search("programming preferences", user_id="u123") +Tiered Memory Classes: + CoreMemory — lightweight: add/search/delete + decay (no LLM) + SmartMemory — + echo encoding, categories, knowledge graph (needs LLM) + FullMemory — + scenes, profiles, tasks, projects (everything) + Memory — alias for CoreMemory (lightest default) """ +from engram.memory.core import CoreMemory +from engram.memory.smart import SmartMemory +from engram.memory.main import FullMemory from engram.simple import Engram -from engram.memory.main import Memory from engram.core.category import CategoryProcessor, Category, CategoryType, CategoryMatch from engram.core.echo import EchoProcessor, EchoDepth, EchoResult from engram.configs.base import MemoryConfig, FadeMemConfig, EchoMemConfig, CategoryMemConfig, ScopeConfig -__version__ = "0.5.0" +# Default: CoreMemory (lightest, zero-config) +Memory = CoreMemory + +__version__ = "0.6.0" __all__ = [ - # Simplified interface (recommended) - "Engram", - # Full interface + # Tiered memory classes + "CoreMemory", + "SmartMemory", + "FullMemory", "Memory", + # Simplified interface + "Engram", # CategoryMem "CategoryProcessor", "Category", diff --git a/engram/cli.py b/engram/cli.py index 8c576c2..06a2bc2 100644 --- a/engram/cli.py +++ b/engram/cli.py @@ 
-257,6 +257,56 @@ def cmd_uninstall(args: argparse.Namespace) -> None: print("Cancelled.") +def cmd_benchmark(args: argparse.Namespace) -> None: + """Run performance benchmarks.""" + import time + from engram import Memory + + print("=" * 60) + print(" engram benchmark") + print("=" * 60) + + # Cold start + print("\n[1/4] Cold start time...") + start = time.perf_counter() + m = Memory() + cold_start = time.perf_counter() - start + print(f" Cold start: {cold_start*1000:.1f} ms") + + # Add 100 memories + print("\n[2/4] Add 100 memories...") + start = time.perf_counter() + for i in range(100): + m.add(f"Test memory {i}: The quick brown fox jumps over the lazy dog.") + add_time = time.perf_counter() - start + print(f" Added 100 memories in {add_time*1000:.1f} ms ({add_time/100*1000:.2f} ms/mem)") + + # Search (cached) + print("\n[3/4] Search (cached embedding)...") + start = time.perf_counter() + for _ in range(10): + m.search("quick fox") + search_cached = time.perf_counter() - start + print(f" 10 searches (cached): {search_cached*1000:.1f} ms ({search_cached/10*1000:.2f} ms/search)") + + # Decay cycle + print("\n[4/4] Decay cycle...") + start = time.perf_counter() + m.apply_decay() + decay_time = time.perf_counter() - start + print(f" Decay cycle: {decay_time*1000:.1f} ms") + + # Summary table + print("\n" + "=" * 60) + print(" Results") + print("=" * 60) + print(f" Cold start: {cold_start*1000:7.1f} ms") + print(f" Add (100 mems): {add_time*1000:7.1f} ms ({add_time/100*1000:.2f} ms/mem)") + print(f" Search (cached): {search_cached/10*1000:7.2f} ms/search") + print(f" Decay cycle: {decay_time*1000:7.1f} ms") + print("=" * 60) + + # --------------------------------------------------------------------------- # Argument parser # --------------------------------------------------------------------------- diff --git a/engram/cli_mcp.py b/engram/cli_mcp.py index d5b10aa..3de92a4 100644 --- a/engram/cli_mcp.py +++ b/engram/cli_mcp.py @@ -66,13 +66,7 @@ def 
_read_toml_mcp_servers(path: str) -> Dict[str, Any]: """Read MCP servers from a TOML config file (for Codex).""" if not os.path.exists(path): return {} - try: - import tomllib - except ImportError: - try: - import tomli as tomllib - except ImportError: - return {} + import tomllib with open(path, "rb") as f: data = tomllib.load(f) return data.get("mcp_servers", data.get("mcpServers", {})) diff --git a/engram/cli_setup.py b/engram/cli_setup.py index 7ff2b97..addb4fb 100644 --- a/engram/cli_setup.py +++ b/engram/cli_setup.py @@ -1,6 +1,6 @@ -"""Interactive setup wizard for `engram setup`.""" +"""Auto-setup for `engram setup`. No prompts — detects environment and configures automatically.""" -import getpass +import logging import os import sys @@ -12,112 +12,49 @@ save_config, ) from engram.cli_mcp import configure_mcp_servers, detect_agents +from engram.utils.factory import _detect_provider - -PACKAGES = [ - ("engram-memory", "core memory layer with decay, encoding, scenes"), - ("engram-bus", "cross-agent coordination and handoff"), -] - -PROVIDERS = [ - ("gemini", "Google AI (recommended, free tier)"), - ("openai", "GPT models"), - ("nvidia", "Llama / Kimi, cloud hosted"), - ("ollama", "Local models, no API key needed"), -] - - -def _prompt_choice(label: str, options: list, default: int = 1) -> int: - """Prompt user to pick from numbered options. Returns 1-based index.""" - print(f"\n{label}") - for i, (name, desc) in enumerate(options, 1): - print(f" {i}. {name:16s} — {desc}") - while True: - raw = input(f"Enter number [{default}]: ").strip() - if not raw: - return default - try: - n = int(raw) - if 1 <= n <= len(options): - return n - except ValueError: - pass - print(f" Please enter 1-{len(options)}") - - -def _prompt_multi(label: str, options: list, default: str = "1") -> list: - """Prompt user to pick multiple (comma-separated). Returns list of 1-based indices.""" - print(f"\n{label}") - for i, (name, desc) in enumerate(options, 1): - print(f" {i}. 
{name:16s} — {desc}") - while True: - raw = input(f"Enter numbers [{default}]: ").strip() - if not raw: - raw = default - try: - nums = [int(x.strip()) for x in raw.split(",")] - if all(1 <= n <= len(options) for n in nums): - return nums - except ValueError: - pass - print(f" Enter comma-separated numbers, e.g. 1,2") - - -def _prompt_api_key(provider: str) -> str: - """Prompt for API key.""" - defaults = PROVIDER_DEFAULTS[provider] - env_var = defaults["env_var"] - existing = os.environ.get(env_var, "") - for alt in defaults.get("alt_env_vars", []): - existing = existing or os.environ.get(alt, "") - - if existing: - masked = existing[:4] + "..." + existing[-4:] if len(existing) > 8 else "****" - print(f"\n Found {env_var}={masked} in environment.") - use = input(" Use this key? [Y/n]: ").strip().lower() - if use in ("", "y", "yes"): - return existing - - print(f"\n Enter your {env_var} (input hidden):") - key = getpass.getpass(f" {env_var}: ") - if key: - print(f"\n To persist, add to your shell profile:") - print(f" export {env_var}={key[:4]}...{key[-4:]}") - os.environ[env_var] = key - return key +logger = logging.getLogger(__name__) def run_setup() -> None: - """Run the interactive setup wizard.""" + """Auto-detect environment and configure. No prompts.""" print("=" * 50) - print(" engram setup") + print(" engram setup (auto-detect)") print("=" * 50) config = load_config() - # 1. Package selection - pkg_indices = _prompt_multi("Which packages?", PACKAGES, default="1") - config["packages"] = [PACKAGES[i - 1][0] for i in pkg_indices] - - # 2. Provider selection - provider_idx = _prompt_choice("Which LLM provider?", PROVIDERS, default=1) - provider = PROVIDERS[provider_idx - 1][0] - config["provider"] = provider - - # 3. API key (skip for ollama) - if provider != "ollama": - key = _prompt_api_key(provider) - if not key: - print("\n Warning: No API key set. 
Memory operations will fail without it.") - else: - print("\n Ollama selected — no API key needed.") + # Auto-detect provider + embedder_provider, llm_provider = _detect_provider() + config["provider"] = embedder_provider + config["auto_configured"] = True + + if embedder_provider in ("gemini", "openai"): + defaults = PROVIDER_DEFAULTS.get(embedder_provider, {}) + env_var = defaults.get("env_var", f"{embedder_provider.upper()}_API_KEY") + key = os.environ.get(env_var, "") + for alt in defaults.get("alt_env_vars", []): + key = key or os.environ.get(alt, "") + if key: + masked = key[:4] + "..." + key[-4:] if len(key) > 8 else "****" + print(f" Provider detected: {embedder_provider}") + print(f" API key found: {env_var}={masked}") + else: + print(f" Provider detected: {embedder_provider}") + print(f" ! No API key found — set {env_var} for full functionality") + elif embedder_provider == "ollama": + print(" Provider detected: ollama (local)") print(" Make sure Ollama is running: ollama serve") + else: + print(" Provider detected: simple (hash-based embedder)") + print(" No API key required. In-memory vector store for zero-config.") # Save config save_config(config) print(f"\n Config saved to {os.path.join(CONFIG_DIR, 'config.json')}") - # 4. Auto-configure MCP servers + # Auto-configure MCP servers agents = detect_agents() if agents: print(f"\n Detected agents: {', '.join(agents)}") @@ -126,14 +63,13 @@ def run_setup() -> None: for agent, status in results.items(): print(f" {agent}: {status}") else: - print("\n No agents detected. MCP will be configured when you install one.") + print("\n No agents detected. 
MCP will configure when you install one.") - # Done print("\n" + "=" * 50) - print(" Setup complete!") + print(" Setup complete!") print() - print(" Try:") - print(' engram add "User prefers dark mode"') - print(' engram search "preferences"') - print(" engram status") + print(" Try:") + print(' engram add "User prefers dark mode"') + print(' engram search "preferences"') + print(' engram status') print("=" * 50) diff --git a/engram/configs/base.py b/engram/configs/base.py index 2c7ef72..986f109 100644 --- a/engram/configs/base.py +++ b/engram/configs/base.py @@ -449,3 +449,23 @@ def _valid_dims(cls, v: int) -> int: if v < 1 or v > 65536: raise ValueError(f"embedding_model_dims must be 1-65536, got {v}") return v + + # ---- Preset factory methods ---- + + @classmethod + def minimal(cls) -> "MemoryConfig": + """Zero-config: hash embedder, in-memory vector store, basic decay. No API key.""" + from engram.configs.presets import minimal_config + return minimal_config() + + @classmethod + def smart(cls) -> "MemoryConfig": + """Auto-detect best available provider + echo + categories.""" + from engram.configs.presets import smart_config + return smart_config() + + @classmethod + def full(cls) -> "MemoryConfig": + """Everything: scenes, profiles, graph, tasks.""" + from engram.configs.presets import full_config + return full_config() diff --git a/engram/configs/presets.py b/engram/configs/presets.py new file mode 100644 index 0000000..1735fce --- /dev/null +++ b/engram/configs/presets.py @@ -0,0 +1,144 @@ +"""Preset factory methods for MemoryConfig. + +These provide ready-made configurations at different complexity levels: +- minimal: Zero-config, no API key needed (hash embedder, in-memory vectors) +- smart: Auto-detected provider + echo + categories + graph +- full: Everything including scenes, profiles, tasks +""" + +import os +import tempfile + + +def minimal_config(): + """Zero-config: hash embedder, in-memory vector store, basic decay. 
No API key.""" + from engram.configs.base import ( + CategoryMemConfig, + EchoMemConfig, + EmbedderConfig, + FadeMemConfig, + KnowledgeGraphConfig, + LLMConfig, + MemoryConfig, + SceneConfig, + ProfileConfig, + VectorStoreConfig, + ) + + data_dir = os.environ.get("ENGRAM_DATA_DIR", os.path.join(os.path.expanduser("~"), ".engram")) + os.makedirs(data_dir, exist_ok=True) + + return MemoryConfig( + embedder=EmbedderConfig( + provider="simple", + config={"embedding_dims": 384}, + ), + llm=LLMConfig(provider="mock", config={}), + vector_store=VectorStoreConfig( + provider="memory", + config={ + "collection_name": "engram_memories", + "embedding_model_dims": 384, + }, + ), + history_db_path=os.path.join(data_dir, "history.db"), + collection_name="engram_memories", + embedding_model_dims=384, + engram=FadeMemConfig(enable_forgetting=True), + echo=EchoMemConfig(enable_echo=False), + category=CategoryMemConfig(enable_categories=False), + graph=KnowledgeGraphConfig(enable_graph=False), + scene=SceneConfig(enable_scenes=False), + profile=ProfileConfig(enable_profiles=False), + ) + + +def smart_config(): + """Auto-detect best available provider + echo + categories. 
Needs API key or Ollama.""" + from engram.configs.base import ( + CategoryMemConfig, + EchoMemConfig, + EmbedderConfig, + FadeMemConfig, + KnowledgeGraphConfig, + LLMConfig, + MemoryConfig, + SceneConfig, + ProfileConfig, + VectorStoreConfig, + ) + from engram.utils.factory import _detect_provider + + embedder_provider, llm_provider = _detect_provider() + data_dir = os.environ.get("ENGRAM_DATA_DIR", os.path.join(os.path.expanduser("~"), ".engram")) + os.makedirs(data_dir, exist_ok=True) + + if embedder_provider == "simple": + dims = 384 + embedder_config = {"embedding_dims": dims} + elif embedder_provider == "gemini": + dims = 3072 + embedder_config = {"model": "gemini-embedding-001"} + elif embedder_provider == "openai": + dims = 1536 + embedder_config = {"model": "text-embedding-3-small"} + elif embedder_provider == "ollama": + dims = 768 + embedder_config = {} + else: + dims = 384 + embedder_config = {"embedding_dims": 384} + + # Use sqlite_vec for persistent storage when a real provider is available + use_sqlite_vec = embedder_provider != "simple" + if use_sqlite_vec: + vs = VectorStoreConfig( + provider="sqlite_vec", + config={ + "path": os.path.join(data_dir, "sqlite_vec.db"), + "collection_name": "engram_memories", + "embedding_model_dims": dims, + }, + ) + else: + vs = VectorStoreConfig( + provider="memory", + config={ + "collection_name": "engram_memories", + "embedding_model_dims": dims, + }, + ) + + # Echo/category need LLM — disable if using mock + has_llm = llm_provider != "mock" + + return MemoryConfig( + embedder=EmbedderConfig(provider=embedder_provider, config=embedder_config), + llm=LLMConfig(provider=llm_provider, config={}), + vector_store=vs, + history_db_path=os.path.join(data_dir, "history.db"), + collection_name="engram_memories", + embedding_model_dims=dims, + engram=FadeMemConfig(enable_forgetting=True), + echo=EchoMemConfig(enable_echo=has_llm), + category=CategoryMemConfig(enable_categories=has_llm), + 
graph=KnowledgeGraphConfig(enable_graph=True, use_llm_extraction=False), + scene=SceneConfig(enable_scenes=False), + profile=ProfileConfig(enable_profiles=False), + ) + + +def full_config(): + """Everything: scenes, profiles, graph, tasks. Needs API key or Ollama.""" + from engram.configs.base import ( + SceneConfig, + ProfileConfig, + ) + + config = smart_config() + config.scene = SceneConfig(enable_scenes=True) + config.profile = ProfileConfig(enable_profiles=True) + config.echo.enable_echo = True + config.category.enable_categories = True + config.graph.enable_graph = True + return config diff --git a/engram/core/decay.py b/engram/core/decay.py index f4e92e0..ee1be40 100644 --- a/engram/core/decay.py +++ b/engram/core/decay.py @@ -10,12 +10,7 @@ if TYPE_CHECKING: from engram.configs.base import FadeMemConfig -try: - from engram_accel import calculate_decayed_strength as _rs_decay -except ImportError: - def _rs_decay(strength, elapsed_days, decay_rate, access_count, dampening_factor): - dampening = 1.0 + dampening_factor * math.log1p(access_count) - return strength * math.exp(-decay_rate * elapsed_days / dampening) +from engram_accel import calculate_decayed_strength as _rs_decay def calculate_decayed_strength( diff --git a/engram/core/echo.py b/engram/core/echo.py index a2ada39..961914e 100644 --- a/engram/core/echo.py +++ b/engram/core/echo.py @@ -16,11 +16,7 @@ from functools import wraps from typing import Any, Callable, Dict, List, Optional, TypeVar -from pydantic import BaseModel, Field, ValidationError, field_validator -try: - from pydantic import ConfigDict -except ImportError: # pragma: no cover - fallback for older pydantic - ConfigDict = None +from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator from engram.utils.prompts import BATCH_ECHO_PROCESSING_PROMPT, ECHO_PROCESSING_PROMPT @@ -60,8 +56,7 @@ class EchoDepth(str, Enum): class EchoOutput(BaseModel): """Structured output from LLM for echo processing.""" - if 
ConfigDict: - model_config = ConfigDict(extra="ignore") + model_config = ConfigDict(extra="ignore") paraphrases: List[str] = Field(description="3-5 diverse rephrasings of the memory.") keywords: List[str] = Field(description="Core concepts and entities.") diff --git a/engram/core/retrieval.py b/engram/core/retrieval.py index da7b204..7e83ecd 100644 --- a/engram/core/retrieval.py +++ b/engram/core/retrieval.py @@ -6,15 +6,7 @@ import math from typing import Dict, List, Any, Optional, Set -try: - from engram_accel import tokenize as _rs_tokenize, bm25_score_batch as _rs_bm25_batch -except ImportError: - import re as _re - - def _rs_tokenize(text): - return _re.findall(r'\w+', text.lower()) - - _rs_bm25_batch = None +from engram_accel import tokenize as _rs_tokenize, bm25_score_batch as _rs_bm25_batch def composite_score(similarity: float, strength: float) -> float: @@ -72,18 +64,8 @@ def bm25_score_batch( k1: float = 1.5, b: float = 0.75, ) -> List[float]: - """Batch BM25 scoring for N documents (Rust-accelerated with Python fallback).""" - if _rs_bm25_batch is not None: - return _rs_bm25_batch(query_terms, documents, total_docs, avg_doc_len, k1, b) - query_set = set(query_terms) - return [ - calculate_bm25_score( - query_set, doc, - {t: sum(1 for d in documents if t in d) for t in query_set}, - total_docs, avg_doc_len, k1, b, - ) - for doc in documents - ] + """Batch BM25 scoring for N documents (Rust-accelerated).""" + return _rs_bm25_batch(query_terms, documents, total_docs, avg_doc_len, k1, b) def calculate_keyword_score( diff --git a/engram/core/traces.py b/engram/core/traces.py index 3209063..7fd200c 100644 --- a/engram/core/traces.py +++ b/engram/core/traces.py @@ -15,10 +15,7 @@ if TYPE_CHECKING: from engram.configs.base import DistillationConfig -try: - from engram_accel import decay_traces_batch as _rs_decay_traces_batch -except ImportError: - _rs_decay_traces_batch = None +from engram_accel import decay_traces_batch as _rs_decay_traces_batch def 
initialize_traces( @@ -81,26 +78,15 @@ def decay_traces_batch( access_counts: List[int], config: "DistillationConfig", ) -> List[Tuple[float, float, float]]: - """Batch version of decay_traces (Rust-accelerated with Python fallback).""" - if _rs_decay_traces_batch is not None: - return _rs_decay_traces_batch( - traces, - elapsed_days, - [int(a) for a in access_counts], - config.s_fast_decay_rate, - config.s_mid_decay_rate, - config.s_slow_decay_rate, - ) - # Python fallback - results = [] - for (sf, sm, ss), ed, ac in zip(traces, elapsed_days, access_counts): - dampening = 1.0 + 0.5 * math.log1p(ac) - results.append(( - max(0.0, min(1.0, sf * math.exp(-config.s_fast_decay_rate * ed / dampening))), - max(0.0, min(1.0, sm * math.exp(-config.s_mid_decay_rate * ed / dampening))), - max(0.0, min(1.0, ss * math.exp(-config.s_slow_decay_rate * ed / dampening))), - )) - return results + """Batch version of decay_traces (Rust-accelerated).""" + return _rs_decay_traces_batch( + traces, + elapsed_days, + [int(a) for a in access_counts], + config.s_fast_decay_rate, + config.s_mid_decay_rate, + config.s_slow_decay_rate, + ) def cascade_traces( diff --git a/engram/db/sqlite.py b/engram/db/sqlite.py index 4a91b1a..8c0bc31 100644 --- a/engram/db/sqlite.py +++ b/engram/db/sqlite.py @@ -18,7 +18,7 @@ "decay_lambda", "status", "importance", "sensitivity", "namespace", "access_count", "last_accessed", "immutable", "expiration_date", "scene_id", "user_id", "agent_id", "run_id", "app_id", - "memory_type", "s_fast", "s_mid", "s_slow", + "memory_type", "s_fast", "s_mid", "s_slow", "content_hash", }) VALID_SCENE_COLUMNS = frozenset({ @@ -44,7 +44,501 @@ def _utcnow_iso() -> str: return _utcnow().isoformat() -class SQLiteManager: +class _SQLiteBase: + """Base class for SQLite managers with common functionality.""" + + def __init__(self, db_path: str): + self.db_path = db_path + db_dir = os.path.dirname(db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) + # Phase 1: Persistent 
connection with WAL mode. + self._conn = sqlite3.connect(db_path, check_same_thread=False) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA busy_timeout=5000") + self._conn.execute("PRAGMA synchronous=FULL") + self._conn.execute("PRAGMA cache_size=-8000") # 8MB cache + self._conn.execute("PRAGMA temp_store=MEMORY") + self._conn.row_factory = sqlite3.Row + self._lock = threading.RLock() + + def close(self) -> None: + """Close the persistent connection for clean shutdown.""" + with self._lock: + if self._conn: + try: + self._conn.close() + except Exception: + pass + self._conn = None # type: ignore[assignment] + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(db_path={self.db_path!r})" + + @contextmanager + def _get_connection(self): + """Yield the persistent connection under the thread lock.""" + with self._lock: + try: + yield self._conn + self._conn.commit() + except Exception: + self._conn.rollback() + raise + + def _is_migration_applied(self, conn: sqlite3.Connection, version: str) -> bool: + row = conn.execute( + "SELECT 1 FROM schema_migrations WHERE version = ?", + (version,), + ).fetchone() + return row is not None + + # Phase 5: Allowed table names for ALTER TABLE to prevent SQL injection. + _ALLOWED_TABLES = frozenset({ + "memories", "scenes", "profiles", "categories", + }) + + def _migrate_add_column_conn( + self, + conn: sqlite3.Connection, + table: str, + column: str, + col_type: str, + ) -> None: + """Add a column using an existing connection, if missing.""" + if table not in self._ALLOWED_TABLES: + raise ValueError(f"Invalid table for migration: {table!r}") + # Validate column name: must be alphanumeric/underscore only. 
+ if not column.replace("_", "").isalnum(): + raise ValueError(f"Invalid column name: {column!r}") + try: + conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {col_type}") + except sqlite3.OperationalError: + pass + + @staticmethod + def _parse_json_value(value: Any, default: Any) -> Any: + if value is None: + return default + if isinstance(value, (dict, list)): + return value + try: + return json.loads(value) + except Exception: + return default + + +class CoreSQLiteManager(_SQLiteBase): + """Minimal SQLite manager for CoreMemory - only essential tables. + + Tables created: + - memories: core memory storage with content_hash for deduplication + - memory_history: audit trail for memory operations + - decay_log: decay cycle metrics + - schema_migrations: migration tracking + """ + + def __init__(self, db_path: str): + super().__init__(db_path) + self._init_db() + + def _init_db(self) -> None: + """Initialize minimal schema for CoreMemory.""" + with self._get_connection() as conn: + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS schema_migrations ( + version TEXT PRIMARY KEY, + applied_at TEXT DEFAULT CURRENT_TIMESTAMP + ); + + CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + memory TEXT NOT NULL, + user_id TEXT, + agent_id TEXT, + run_id TEXT, + app_id TEXT, + metadata TEXT DEFAULT '{}', + categories TEXT DEFAULT '[]', + immutable INTEGER DEFAULT 0, + expiration_date TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + updated_at TEXT DEFAULT CURRENT_TIMESTAMP, + layer TEXT DEFAULT 'sml' CHECK (layer IN ('sml', 'lml')), + strength REAL DEFAULT 1.0, + access_count INTEGER DEFAULT 0, + last_accessed TEXT DEFAULT CURRENT_TIMESTAMP, + embedding TEXT, + related_memories TEXT DEFAULT '[]', + source_memories TEXT DEFAULT '[]', + tombstone INTEGER DEFAULT 0, + content_hash TEXT + ); + + CREATE INDEX IF NOT EXISTS idx_user_layer ON memories(user_id, layer); + CREATE INDEX IF NOT EXISTS idx_strength ON memories(strength DESC); + CREATE INDEX IF NOT 
EXISTS idx_tombstone ON memories(tombstone); + + CREATE TABLE IF NOT EXISTS memory_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + memory_id TEXT NOT NULL, + event TEXT NOT NULL, + old_value TEXT, + new_value TEXT, + old_strength REAL, + new_strength REAL, + old_layer TEXT, + new_layer TEXT, + timestamp TEXT DEFAULT CURRENT_TIMESTAMP + ); + + CREATE TABLE IF NOT EXISTS decay_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_at TEXT DEFAULT CURRENT_TIMESTAMP, + memories_decayed INTEGER, + memories_forgotten INTEGER, + memories_promoted INTEGER, + storage_before_mb REAL, + storage_after_mb REAL + ); + """ + ) + # Migrate content_hash column + index for pre-existing DBs + self._ensure_content_hash_column(conn) + + def _ensure_content_hash_column(self, conn: sqlite3.Connection) -> None: + """Add content_hash column + index for SHA-256 dedup (idempotent).""" + if self._is_migration_applied(conn, "v2_content_hash"): + return + self._migrate_add_column_conn(conn, "memories", "content_hash", "TEXT") + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_content_hash ON memories(content_hash, user_id)" + ) + conn.execute( + "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_content_hash')" + ) + + # Core memory operations + def add_memory(self, memory_data: Dict[str, Any]) -> str: + memory_id = memory_data.get("id", str(uuid.uuid4())) + now = _utcnow_iso() + metadata = memory_data.get("metadata", {}) or {} + + with self._get_connection() as conn: + conn.execute( + """ + INSERT INTO memories ( + id, memory, user_id, agent_id, run_id, app_id, + metadata, categories, immutable, expiration_date, + created_at, updated_at, layer, strength, access_count, + last_accessed, embedding, related_memories, source_memories, tombstone, + content_hash + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + memory_id, + memory_data.get("memory", ""), + memory_data.get("user_id"), + memory_data.get("agent_id"), + memory_data.get("run_id"), + memory_data.get("app_id"), + json.dumps(memory_data.get("metadata", {})), + json.dumps(memory_data.get("categories", [])), + 1 if memory_data.get("immutable", False) else 0, + memory_data.get("expiration_date"), + memory_data.get("created_at", now), + memory_data.get("updated_at", now), + memory_data.get("layer", "sml"), + memory_data.get("strength", 1.0), + memory_data.get("access_count", 0), + memory_data.get("last_accessed", now), + json.dumps(memory_data.get("embedding", [])), + json.dumps(memory_data.get("related_memories", [])), + json.dumps(memory_data.get("source_memories", [])), + 1 if memory_data.get("tombstone", False) else 0, + memory_data.get("content_hash"), + ), + ) + # Log the add event + conn.execute( + """ + INSERT INTO memory_history ( + memory_id, event, old_value, new_value, + old_strength, new_strength, old_layer, new_layer + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (memory_id, "ADD", None, memory_data.get("memory"), None, None, None, None), + ) + return memory_id + + def get_memory(self, memory_id: str, include_tombstoned: bool = False) -> Optional[Dict[str, Any]]: + query = "SELECT * FROM memories WHERE id = ?" + params = [memory_id] + if not include_tombstoned: + query += " AND tombstone = 0" + + with self._get_connection() as conn: + row = conn.execute(query, params).fetchone() + if row: + return self._row_to_dict(row) + return None + + def get_memory_by_content_hash( + self, content_hash: str, user_id: str = "default" + ) -> Optional[Dict[str, Any]]: + """Find an existing memory by content hash (for deduplication).""" + with self._get_connection() as conn: + row = conn.execute( + "SELECT * FROM memories WHERE content_hash = ? AND user_id = ? 
AND tombstone = 0 LIMIT 1", + (content_hash, user_id), + ).fetchone() + if row: + return self._row_to_dict(row) + return None + + def get_all_memories( + self, + *, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + run_id: Optional[str] = None, + app_id: Optional[str] = None, + layer: Optional[str] = None, + namespace: Optional[str] = None, + min_strength: float = 0.0, + include_tombstoned: bool = False, + limit: Optional[int] = None, + ) -> List[Dict[str, Any]]: + query = "SELECT * FROM memories WHERE strength >= ?" + params: List[Any] = [min_strength] + + if not include_tombstoned: + query += " AND tombstone = 0" + if user_id: + query += " AND user_id = ?" + params.append(user_id) + if agent_id: + query += " AND agent_id = ?" + params.append(agent_id) + if run_id: + query += " AND run_id = ?" + params.append(run_id) + if app_id: + query += " AND app_id = ?" + params.append(app_id) + if layer: + query += " AND layer = ?" + params.append(layer) + + query += " ORDER BY strength DESC" + + if limit is not None and limit > 0: + query += " LIMIT ?" 
+ params.append(limit) + + with self._get_connection() as conn: + rows = conn.execute(query, params).fetchall() + return [self._row_to_dict(row) for row in rows] + + def update_memory(self, memory_id: str, updates: Dict[str, Any]) -> bool: + set_clauses = [] + params: List[Any] = [] + for key, value in updates.items(): + if key not in VALID_MEMORY_COLUMNS: + raise ValueError(f"Invalid memory column: {key!r}") + if key in {"metadata", "categories", "embedding", "related_memories", "source_memories"}: + value = json.dumps(value) + set_clauses.append(f"{key} = ?") + params.append(value) + + set_clauses.append("updated_at = ?") + params.append(_utcnow_iso()) + params.append(memory_id) + + with self._get_connection() as conn: + old_row = conn.execute( + "SELECT memory, strength, layer FROM memories WHERE id = ?", + (memory_id,), + ).fetchone() + if not old_row: + return False + + conn.execute( + f"UPDATE memories SET {', '.join(set_clauses)} WHERE id = ?", + params, + ) + + # Log the update event + conn.execute( + """ + INSERT INTO memory_history ( + memory_id, event, old_value, new_value, + old_strength, new_strength, old_layer, new_layer + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + memory_id, + "UPDATE", + old_row["memory"], + updates.get("memory"), + old_row["strength"], + updates.get("strength"), + old_row["layer"], + updates.get("layer"), + ), + ) + return True + + def delete_memory(self, memory_id: str, use_tombstone: bool = True) -> bool: + if use_tombstone: + return self.update_memory(memory_id, {"tombstone": 1}) + with self._get_connection() as conn: + conn.execute("DELETE FROM memories WHERE id = ?", (memory_id,)) + self._log_event(memory_id, "DELETE") + return True + + def increment_access(self, memory_id: str) -> None: + now = _utcnow_iso() + with self._get_connection() as conn: + conn.execute( + """ + UPDATE memories + SET access_count = access_count + 1, last_accessed = ? + WHERE id = ? 
+ """, + (now, memory_id), + ) + + def increment_access_bulk(self, memory_ids: List[str]) -> None: + """Increment access count for multiple memories in a single transaction.""" + if not memory_ids: + return + now = _utcnow_iso() + with self._get_connection() as conn: + placeholders = ",".join("?" for _ in memory_ids) + conn.execute( + f""" + UPDATE memories + SET access_count = access_count + 1, last_accessed = ? + WHERE id IN ({placeholders}) + """, + [now] + list(memory_ids), + ) + + def get_memories_bulk( + self, memory_ids: List[str], include_tombstoned: bool = False + ) -> Dict[str, Dict[str, Any]]: + """Fetch multiple memories by ID in a single query.""" + if not memory_ids: + return {} + with self._get_connection() as conn: + placeholders = ",".join("?" for _ in memory_ids) + query = f"SELECT * FROM memories WHERE id IN ({placeholders})" + if not include_tombstoned: + query += " AND tombstone = 0" + rows = conn.execute(query, memory_ids).fetchall() + return {row["id"]: self._row_to_dict(row) for row in rows} + + def update_strength_bulk(self, updates: Dict[str, float]) -> None: + """Batch-update strength for multiple memories.""" + if not updates: + return + now = _utcnow_iso() + with self._get_connection() as conn: + conn.executemany( + "UPDATE memories SET strength = ?, updated_at = ? WHERE id = ?", + [(strength, now, memory_id) for memory_id, strength in updates.items()], + ) + + _MEMORY_JSON_FIELDS = ("metadata", "categories", "related_memories", "source_memories") + + def _row_to_dict(self, row: sqlite3.Row, *, skip_embedding: bool = False) -> Dict[str, Any]: + data = dict(row) + for key in self._MEMORY_JSON_FIELDS: + if key in data and data[key]: + data[key] = json.loads(data[key]) + # Embedding is the largest JSON field (~30-50KB for 3072-dim vectors). 
+ if skip_embedding: + data.pop("embedding", None) + elif "embedding" in data and data["embedding"]: + data["embedding"] = json.loads(data["embedding"]) + data["immutable"] = bool(data.get("immutable", 0)) + data["tombstone"] = bool(data.get("tombstone", 0)) + return data + + def _log_event(self, memory_id: str, event: str, **kwargs: Any) -> None: + with self._get_connection() as conn: + conn.execute( + """ + INSERT INTO memory_history ( + memory_id, event, old_value, new_value, + old_strength, new_strength, old_layer, new_layer + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + memory_id, + event, + kwargs.get("old_value"), + kwargs.get("new_value"), + kwargs.get("old_strength"), + kwargs.get("new_strength"), + kwargs.get("old_layer"), + kwargs.get("new_layer"), + ), + ) + + def log_event(self, memory_id: str, event: str, **kwargs: Any) -> None: + """Public wrapper for logging custom events like DECAY or FUSE.""" + self._log_event(memory_id, event, **kwargs) + + def get_history(self, memory_id: str) -> List[Dict[str, Any]]: + with self._get_connection() as conn: + rows = conn.execute( + "SELECT * FROM memory_history WHERE memory_id = ? ORDER BY timestamp DESC", + (memory_id,), + ).fetchall() + return [dict(row) for row in rows] + + # Alias for CoreMemory compatibility + get_memory_history = get_history + + def log_decay( + self, + decayed: int, + forgotten: int, + promoted: int, + storage_before_mb: Optional[float] = None, + storage_after_mb: Optional[float] = None, + ) -> None: + with self._get_connection() as conn: + conn.execute( + """ + INSERT INTO decay_log (memories_decayed, memories_forgotten, memories_promoted, storage_before_mb, storage_after_mb) + VALUES (?, ?, ?, ?, ?) 
+ """, + (decayed, forgotten, promoted, storage_before_mb, storage_after_mb), + ) + + def purge_tombstoned(self) -> int: + """Permanently delete all tombstoned memories.""" + with self._get_connection() as conn: + rows = conn.execute( + "SELECT id, user_id, memory FROM memories WHERE tombstone = 1" + ).fetchall() + count = len(rows) + if count > 0: + for row in rows: + self._log_event(row["id"], "PURGE", old_value=row["memory"]) + conn.execute("DELETE FROM memories WHERE tombstone = 1") + return count + + +# Backward compatibility alias +SQLiteManager = CoreSQLiteManager + + +class FullSQLiteManager(CoreSQLiteManager): def __init__(self, db_path: str): self.db_path = db_path db_dir = os.path.dirname(db_path) @@ -384,6 +878,21 @@ def _ensure_v2_schema(self, conn: sqlite3.Connection) -> None: # CLS Distillation Memory columns (idempotent). self._ensure_cls_columns(conn) + # Content-hash dedup column (idempotent). + self._ensure_content_hash_column(conn) + + def _ensure_content_hash_column(self, conn: sqlite3.Connection) -> None: + """Add content_hash column + index for SHA-256 dedup.""" + if self._is_migration_applied(conn, "v2_content_hash"): + return + self._migrate_add_column_conn(conn, "memories", "content_hash", "TEXT") + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_content_hash ON memories(content_hash, user_id)" + ) + conn.execute( + "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_content_hash')" + ) + def _ensure_cls_columns(self, conn: sqlite3.Connection) -> None: """Add CLS Distillation Memory columns to memories table (idempotent).""" if self._is_migration_applied(conn, "v2_cls_columns_complete"): @@ -453,8 +962,8 @@ def add_memory(self, memory_data: Dict[str, Any]) -> str: last_accessed, embedding, related_memories, source_memories, tombstone, confidentiality_scope, namespace, source_type, source_app, source_event_id, decay_lambda, status, importance, sensitivity, - memory_type, s_fast, s_mid, s_slow - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + memory_type, s_fast, s_mid, s_slow, content_hash + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( memory_id, @@ -490,6 +999,7 @@ def add_memory(self, memory_data: Dict[str, Any]) -> str: memory_data.get("s_fast"), memory_data.get("s_mid"), memory_data.get("s_slow"), + memory_data.get("content_hash"), ), ) @@ -602,6 +1112,19 @@ def get_memory(self, memory_id: str, include_tombstoned: bool = False) -> Option return self._row_to_dict(row) return None + def get_memory_by_content_hash( + self, content_hash: str, user_id: str = "default" + ) -> Optional[Dict[str, Any]]: + """Find an existing memory by content hash (for deduplication).""" + with self._get_connection() as conn: + row = conn.execute( + "SELECT * FROM memories WHERE content_hash = ? AND user_id = ? AND tombstone = 0 LIMIT 1", + (content_hash, user_id), + ).fetchone() + if row: + return self._row_to_dict(row) + return None + def get_memory_by_source_event( self, *, diff --git a/engram/db/sqlite_backup.py b/engram/db/sqlite_backup.py new file mode 100644 index 0000000..27224d8 --- /dev/null +++ b/engram/db/sqlite_backup.py @@ -0,0 +1,2070 @@ +import json +import logging +import os +import sqlite3 +import threading +import uuid +from contextlib import contextmanager +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + +# Phase 5: Allowed column names for dynamic UPDATE queries to prevent SQL injection. 
+VALID_MEMORY_COLUMNS = frozenset({ + "memory", "metadata", "categories", "embedding", "strength", + "layer", "tombstone", "updated_at", "related_memories", "source_memories", + "confidentiality_scope", "source_type", "source_app", "source_event_id", + "decay_lambda", "status", "importance", "sensitivity", "namespace", + "access_count", "last_accessed", "immutable", "expiration_date", + "scene_id", "user_id", "agent_id", "run_id", "app_id", + "memory_type", "s_fast", "s_mid", "s_slow", +}) + +VALID_SCENE_COLUMNS = frozenset({ + "title", "summary", "topic", "location", "participants", "memory_ids", + "start_time", "end_time", "embedding", "strength", "access_count", + "tombstone", "layer", "scene_strength", "topic_embedding_ref", "namespace", +}) + +VALID_PROFILE_COLUMNS = frozenset({ + "name", "profile_type", "narrative", "facts", "preferences", + "relationships", "sentiment", "theory_of_mind", "aliases", + "embedding", "strength", "updated_at", "role_bias", "profile_summary", +}) + + +def _utcnow() -> datetime: + """Return current UTC datetime (timezone-aware).""" + return datetime.now(timezone.utc) + + +def _utcnow_iso() -> str: + """Return current UTC time as ISO string.""" + return _utcnow().isoformat() + + +class _SQLiteBase: + """Base class for SQLite managers with common functionality.""" + + def __init__(self, db_path: str): + self.db_path = db_path + db_dir = os.path.dirname(db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) + # Phase 1: Persistent connection with WAL mode. 
+ self._conn = sqlite3.connect(db_path, check_same_thread=False) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA busy_timeout=5000") + self._conn.execute("PRAGMA synchronous=FULL") + self._conn.execute("PRAGMA cache_size=-8000") # 8MB cache + self._conn.execute("PRAGMA temp_store=MEMORY") + self._conn.row_factory = sqlite3.Row + self._lock = threading.RLock() + + def close(self) -> None: + """Close the persistent connection for clean shutdown.""" + with self._lock: + if self._conn: + try: + self._conn.close() + except Exception: + pass + self._conn = None # type: ignore[assignment] + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(db_path={self.db_path!r})" + + @contextmanager + def _get_connection(self): + """Yield the persistent connection under the thread lock.""" + with self._lock: + try: + yield self._conn + self._conn.commit() + except Exception: + self._conn.rollback() + raise + + def _is_migration_applied(self, conn: sqlite3.Connection, version: str) -> bool: + row = conn.execute( + "SELECT 1 FROM schema_migrations WHERE version = ?", + (version,), + ).fetchone() + return row is not None + + # Phase 5: Allowed table names for ALTER TABLE to prevent SQL injection. + _ALLOWED_TABLES = frozenset({ + "memories", "scenes", "profiles", "categories", + }) + + def _migrate_add_column_conn( + self, + conn: sqlite3.Connection, + table: str, + column: str, + col_type: str, + ) -> None: + """Add a column using an existing connection, if missing.""" + if table not in self._ALLOWED_TABLES: + raise ValueError(f"Invalid table for migration: {table!r}") + # Validate column name: must be alphanumeric/underscore only. 
+ if not column.replace("_", "").isalnum(): + raise ValueError(f"Invalid column name: {column!r}") + try: + conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {col_type}") + except sqlite3.OperationalError: + pass + + @staticmethod + def _parse_json_value(value: Any, default: Any) -> Any: + if value is None: + return default + if isinstance(value, (dict, list)): + return value + try: + return json.loads(value) + except Exception: + return default + + +class CoreSQLiteManager(_SQLiteBase): + """Minimal SQLite manager for CoreMemory - only essential tables. + + Tables created: + - memories: core memory storage with content_hash for deduplication + - memory_history: audit trail for memory operations + - decay_log: decay cycle metrics + - schema_migrations: migration tracking + """ + + def __init__(self, db_path: str): + super().__init__(db_path) + self._init_db() + + def _init_db(self) -> None: + """Initialize minimal schema for CoreMemory.""" + with self._get_connection() as conn: + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + memory TEXT NOT NULL, + user_id TEXT, + agent_id TEXT, + run_id TEXT, + app_id TEXT, + metadata TEXT DEFAULT '{}', + categories TEXT DEFAULT '[]', + immutable INTEGER DEFAULT 0, + expiration_date TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + updated_at TEXT DEFAULT CURRENT_TIMESTAMP, + layer TEXT DEFAULT 'sml' CHECK (layer IN ('sml', 'lml')), + strength REAL DEFAULT 1.0, + access_count INTEGER DEFAULT 0, + last_accessed TEXT DEFAULT CURRENT_TIMESTAMP, + embedding TEXT, + related_memories TEXT DEFAULT '[]', + source_memories TEXT DEFAULT '[]', + tombstone INTEGER DEFAULT 0, + content_hash TEXT + ); + + CREATE INDEX IF NOT EXISTS idx_user_layer ON memories(user_id, layer); + CREATE INDEX IF NOT EXISTS idx_strength ON memories(strength DESC); + CREATE INDEX IF NOT EXISTS idx_tombstone ON memories(tombstone); + CREATE INDEX IF NOT EXISTS idx_content_hash ON memories(content_hash, user_id); + 
+ CREATE TABLE IF NOT EXISTS memory_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + memory_id TEXT NOT NULL, + event TEXT NOT NULL, + old_value TEXT, + new_value TEXT, + old_strength REAL, + new_strength REAL, + old_layer TEXT, + new_layer TEXT, + timestamp TEXT DEFAULT CURRENT_TIMESTAMP + ); + + CREATE TABLE IF NOT EXISTS decay_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_at TEXT DEFAULT CURRENT_TIMESTAMP, + memories_decayed INTEGER, + memories_forgotten INTEGER, + memories_promoted INTEGER, + storage_before_mb REAL, + storage_after_mb REAL + ); + + CREATE TABLE IF NOT EXISTS schema_migrations ( + version TEXT PRIMARY KEY, + applied_at TEXT DEFAULT CURRENT_TIMESTAMP + ); + """ + ) + # Apply content_hash column migration if needed + self._ensure_content_hash_column(conn) + + def _ensure_content_hash_column(self, conn: sqlite3.Connection) -> None: + """Add content_hash column + index for SHA-256 dedup (idempotent).""" + if self._is_migration_applied(conn, "v2_content_hash"): + return + self._migrate_add_column_conn(conn, "memories", "content_hash", "TEXT") + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_content_hash ON memories(content_hash, user_id)" + ) + conn.execute( + "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_content_hash')" + ) + + # Core memory operations + def add_memory(self, memory_data: Dict[str, Any]) -> str: + memory_id = memory_data.get("id", str(uuid.uuid4())) + now = _utcnow_iso() + metadata = memory_data.get("metadata", {}) or {} + + with self._get_connection() as conn: + conn.execute( + """ + INSERT INTO memories ( + id, memory, user_id, agent_id, run_id, app_id, + metadata, categories, immutable, expiration_date, + created_at, updated_at, layer, strength, access_count, + last_accessed, embedding, related_memories, source_memories, tombstone, + content_hash + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + memory_id, + memory_data.get("memory", ""), + memory_data.get("user_id"), + memory_data.get("agent_id"), + memory_data.get("run_id"), + memory_data.get("app_id"), + json.dumps(memory_data.get("metadata", {})), + json.dumps(memory_data.get("categories", [])), + 1 if memory_data.get("immutable", False) else 0, + memory_data.get("expiration_date"), + memory_data.get("created_at", now), + memory_data.get("updated_at", now), + memory_data.get("layer", "sml"), + memory_data.get("strength", 1.0), + memory_data.get("access_count", 0), + memory_data.get("last_accessed", now), + json.dumps(memory_data.get("embedding", [])), + json.dumps(memory_data.get("related_memories", [])), + json.dumps(memory_data.get("source_memories", [])), + 1 if memory_data.get("tombstone", False) else 0, + memory_data.get("content_hash"), + ), + ) + # Log the add event + conn.execute( + """ + INSERT INTO memory_history ( + memory_id, event, old_value, new_value, + old_strength, new_strength, old_layer, new_layer + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (memory_id, "ADD", None, memory_data.get("memory"), None, None, None, None), + ) + return memory_id + + def get_memory(self, memory_id: str, include_tombstoned: bool = False) -> Optional[Dict[str, Any]]: + query = "SELECT * FROM memories WHERE id = ?" + params = [memory_id] + if not include_tombstoned: + query += " AND tombstone = 0" + + with self._get_connection() as conn: + row = conn.execute(query, params).fetchone() + if row: + return self._row_to_dict(row) + return None + + def get_memory_by_content_hash( + self, content_hash: str, user_id: str = "default" + ) -> Optional[Dict[str, Any]]: + """Find an existing memory by content hash (for deduplication).""" + with self._get_connection() as conn: + row = conn.execute( + "SELECT * FROM memories WHERE content_hash = ? AND user_id = ? 
AND tombstone = 0 LIMIT 1", + (content_hash, user_id), + ).fetchone() + if row: + return self._row_to_dict(row) + return None + + def get_all_memories( + self, + *, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + run_id: Optional[str] = None, + app_id: Optional[str] = None, + layer: Optional[str] = None, + namespace: Optional[str] = None, + min_strength: float = 0.0, + include_tombstoned: bool = False, + limit: Optional[int] = None, + ) -> List[Dict[str, Any]]: + query = "SELECT * FROM memories WHERE strength >= ?" + params: List[Any] = [min_strength] + + if not include_tombstoned: + query += " AND tombstone = 0" + if user_id: + query += " AND user_id = ?" + params.append(user_id) + if agent_id: + query += " AND agent_id = ?" + params.append(agent_id) + if run_id: + query += " AND run_id = ?" + params.append(run_id) + if app_id: + query += " AND app_id = ?" + params.append(app_id) + if layer: + query += " AND layer = ?" + params.append(layer) + + query += " ORDER BY strength DESC" + + if limit is not None and limit > 0: + query += " LIMIT ?" 
+ params.append(limit) + + with self._get_connection() as conn: + rows = conn.execute(query, params).fetchall() + return [self._row_to_dict(row) for row in rows] + + def update_memory(self, memory_id: str, updates: Dict[str, Any]) -> bool: + set_clauses = [] + params: List[Any] = [] + for key, value in updates.items(): + if key not in VALID_MEMORY_COLUMNS: + raise ValueError(f"Invalid memory column: {key!r}") + if key in {"metadata", "categories", "embedding", "related_memories", "source_memories"}: + value = json.dumps(value) + set_clauses.append(f"{key} = ?") + params.append(value) + + set_clauses.append("updated_at = ?") + params.append(_utcnow_iso()) + params.append(memory_id) + + with self._get_connection() as conn: + old_row = conn.execute( + "SELECT memory, strength, layer FROM memories WHERE id = ?", + (memory_id,), + ).fetchone() + if not old_row: + return False + + conn.execute( + f"UPDATE memories SET {', '.join(set_clauses)} WHERE id = ?", + params, + ) + + # Log the update event + conn.execute( + """ + INSERT INTO memory_history ( + memory_id, event, old_value, new_value, + old_strength, new_strength, old_layer, new_layer + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + memory_id, + "UPDATE", + old_row["memory"], + updates.get("memory"), + old_row["strength"], + updates.get("strength"), + old_row["layer"], + updates.get("layer"), + ), + ) + return True + + def delete_memory(self, memory_id: str, use_tombstone: bool = True) -> bool: + if use_tombstone: + return self.update_memory(memory_id, {"tombstone": 1}) + with self._get_connection() as conn: + conn.execute("DELETE FROM memories WHERE id = ?", (memory_id,)) + self._log_event(memory_id, "DELETE") + return True + + def increment_access(self, memory_id: str) -> None: + now = _utcnow_iso() + with self._get_connection() as conn: + conn.execute( + """ + UPDATE memories + SET access_count = access_count + 1, last_accessed = ? + WHERE id = ? 
+ """, + (now, memory_id), + ) + + def increment_access_bulk(self, memory_ids: List[str]) -> None: + """Increment access count for multiple memories in a single transaction.""" + if not memory_ids: + return + now = _utcnow_iso() + with self._get_connection() as conn: + placeholders = ",".join("?" for _ in memory_ids) + conn.execute( + f""" + UPDATE memories + SET access_count = access_count + 1, last_accessed = ? + WHERE id IN ({placeholders}) + """, + [now] + list(memory_ids), + ) + + def get_memories_bulk( + self, memory_ids: List[str], include_tombstoned: bool = False + ) -> Dict[str, Dict[str, Any]]: + """Fetch multiple memories by ID in a single query.""" + if not memory_ids: + return {} + with self._get_connection() as conn: + placeholders = ",".join("?" for _ in memory_ids) + query = f"SELECT * FROM memories WHERE id IN ({placeholders})" + if not include_tombstoned: + query += " AND tombstone = 0" + rows = conn.execute(query, memory_ids).fetchall() + return {row["id"]: self._row_to_dict(row) for row in rows} + + def update_strength_bulk(self, updates: Dict[str, float]) -> None: + """Batch-update strength for multiple memories.""" + if not updates: + return + now = _utcnow_iso() + with self._get_connection() as conn: + conn.executemany( + "UPDATE memories SET strength = ?, updated_at = ? WHERE id = ?", + [(strength, now, memory_id) for memory_id, strength in updates.items()], + ) + + _MEMORY_JSON_FIELDS = ("metadata", "categories", "related_memories", "source_memories") + + def _row_to_dict(self, row: sqlite3.Row, *, skip_embedding: bool = False) -> Dict[str, Any]: + data = dict(row) + for key in self._MEMORY_JSON_FIELDS: + if key in data and data[key]: + data[key] = json.loads(data[key]) + # Embedding is the largest JSON field (~30-50KB for 3072-dim vectors). 
+ if skip_embedding: + data.pop("embedding", None) + elif "embedding" in data and data["embedding"]: + data["embedding"] = json.loads(data["embedding"]) + data["immutable"] = bool(data.get("immutable", 0)) + data["tombstone"] = bool(data.get("tombstone", 0)) + return data + + def _log_event(self, memory_id: str, event: str, **kwargs: Any) -> None: + with self._get_connection() as conn: + conn.execute( + """ + INSERT INTO memory_history ( + memory_id, event, old_value, new_value, + old_strength, new_strength, old_layer, new_layer + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + memory_id, + event, + kwargs.get("old_value"), + kwargs.get("new_value"), + kwargs.get("old_strength"), + kwargs.get("new_strength"), + kwargs.get("old_layer"), + kwargs.get("new_layer"), + ), + ) + + def log_event(self, memory_id: str, event: str, **kwargs: Any) -> None: + """Public wrapper for logging custom events like DECAY or FUSE.""" + self._log_event(memory_id, event, **kwargs) + + def get_history(self, memory_id: str) -> List[Dict[str, Any]]: + with self._get_connection() as conn: + rows = conn.execute( + "SELECT * FROM memory_history WHERE memory_id = ? ORDER BY timestamp DESC", + (memory_id,), + ).fetchall() + return [dict(row) for row in rows] + + def log_decay( + self, + decayed: int, + forgotten: int, + promoted: int, + storage_before_mb: Optional[float] = None, + storage_after_mb: Optional[float] = None, + ) -> None: + with self._get_connection() as conn: + conn.execute( + """ + INSERT INTO decay_log (memories_decayed, memories_forgotten, memories_promoted, storage_before_mb, storage_after_mb) + VALUES (?, ?, ?, ?, ?) 
+ """, + (decayed, forgotten, promoted, storage_before_mb, storage_after_mb), + ) + + def purge_tombstoned(self) -> int: + """Permanently delete all tombstoned memories.""" + with self._get_connection() as conn: + rows = conn.execute( + "SELECT id, user_id, memory FROM memories WHERE tombstone = 1" + ).fetchall() + count = len(rows) + if count > 0: + for row in rows: + self._log_event(row["id"], "PURGE", old_value=row["memory"]) + conn.execute("DELETE FROM memories WHERE tombstone = 1") + return count + + +# Backward compatibility alias +SQLiteManager = CoreSQLiteManager + + +class FullSQLiteManager(CoreSQLiteManager): + def __init__(self, db_path: str): + self.db_path = db_path + db_dir = os.path.dirname(db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) + # Phase 1: Persistent connection with WAL mode. + self._conn = sqlite3.connect(db_path, check_same_thread=False) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA busy_timeout=5000") + self._conn.execute("PRAGMA synchronous=FULL") + self._conn.execute("PRAGMA cache_size=-8000") # 8MB cache + self._conn.execute("PRAGMA temp_store=MEMORY") + self._conn.row_factory = sqlite3.Row + self._lock = threading.RLock() + self._init_db() + + def close(self) -> None: + """Close the persistent connection for clean shutdown.""" + with self._lock: + if self._conn: + try: + self._conn.close() + except Exception: + pass + self._conn = None # type: ignore[assignment] + + def __repr__(self) -> str: + return f"SQLiteManager(db_path={self.db_path!r})" + + def _init_db(self) -> None: + with self._get_connection() as conn: + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + memory TEXT NOT NULL, + user_id TEXT, + agent_id TEXT, + run_id TEXT, + app_id TEXT, + metadata TEXT DEFAULT '{}', + categories TEXT DEFAULT '[]', + immutable INTEGER DEFAULT 0, + expiration_date TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + updated_at TEXT DEFAULT CURRENT_TIMESTAMP, 
+ layer TEXT DEFAULT 'sml' CHECK (layer IN ('sml', 'lml')), + strength REAL DEFAULT 1.0, + access_count INTEGER DEFAULT 0, + last_accessed TEXT DEFAULT CURRENT_TIMESTAMP, + embedding TEXT, + related_memories TEXT DEFAULT '[]', + source_memories TEXT DEFAULT '[]', + tombstone INTEGER DEFAULT 0 + ); + + CREATE INDEX IF NOT EXISTS idx_user_layer ON memories(user_id, layer); + CREATE INDEX IF NOT EXISTS idx_strength ON memories(strength DESC); + CREATE INDEX IF NOT EXISTS idx_tombstone ON memories(tombstone); + + CREATE TABLE IF NOT EXISTS memory_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + memory_id TEXT NOT NULL, + event TEXT NOT NULL, + old_value TEXT, + new_value TEXT, + old_strength REAL, + new_strength REAL, + old_layer TEXT, + new_layer TEXT, + timestamp TEXT DEFAULT CURRENT_TIMESTAMP + ); + + CREATE TABLE IF NOT EXISTS decay_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_at TEXT DEFAULT CURRENT_TIMESTAMP, + memories_decayed INTEGER, + memories_forgotten INTEGER, + memories_promoted INTEGER, + storage_before_mb REAL, + storage_after_mb REAL + ); + + -- CategoryMem tables + CREATE TABLE IF NOT EXISTS categories ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + category_type TEXT DEFAULT 'dynamic', + parent_id TEXT, + children_ids TEXT DEFAULT '[]', + memory_count INTEGER DEFAULT 0, + total_strength REAL DEFAULT 0.0, + access_count INTEGER DEFAULT 0, + last_accessed TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + embedding TEXT, + keywords TEXT DEFAULT '[]', + summary TEXT, + summary_updated_at TEXT, + related_ids TEXT DEFAULT '[]', + strength REAL DEFAULT 1.0, + FOREIGN KEY (parent_id) REFERENCES categories(id) + ); + + CREATE INDEX IF NOT EXISTS idx_category_type ON categories(category_type); + CREATE INDEX IF NOT EXISTS idx_category_parent ON categories(parent_id); + CREATE INDEX IF NOT EXISTS idx_category_strength ON categories(strength DESC); + + -- Episodic scenes + CREATE TABLE IF NOT EXISTS scenes ( + id TEXT PRIMARY 
KEY, + user_id TEXT, + title TEXT, + summary TEXT, + topic TEXT, + location TEXT, + participants TEXT DEFAULT '[]', + memory_ids TEXT DEFAULT '[]', + start_time TEXT, + end_time TEXT, + embedding TEXT, + strength REAL DEFAULT 1.0, + access_count INTEGER DEFAULT 0, + tombstone INTEGER DEFAULT 0 + ); + + CREATE INDEX IF NOT EXISTS idx_scene_user ON scenes(user_id); + CREATE INDEX IF NOT EXISTS idx_scene_start ON scenes(start_time DESC); + + -- Scene-Memory junction + CREATE TABLE IF NOT EXISTS scene_memories ( + scene_id TEXT NOT NULL, + memory_id TEXT NOT NULL, + position INTEGER DEFAULT 0, + PRIMARY KEY (scene_id, memory_id), + FOREIGN KEY (scene_id) REFERENCES scenes(id), + FOREIGN KEY (memory_id) REFERENCES memories(id) + ); + + -- Character profiles + CREATE TABLE IF NOT EXISTS profiles ( + id TEXT PRIMARY KEY, + user_id TEXT, + name TEXT NOT NULL, + profile_type TEXT DEFAULT 'contact' CHECK (profile_type IN ('self', 'contact', 'entity')), + narrative TEXT, + facts TEXT DEFAULT '[]', + preferences TEXT DEFAULT '[]', + relationships TEXT DEFAULT '[]', + sentiment TEXT, + theory_of_mind TEXT DEFAULT '{}', + aliases TEXT DEFAULT '[]', + embedding TEXT, + strength REAL DEFAULT 1.0, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + updated_at TEXT DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_profile_user ON profiles(user_id); + CREATE INDEX IF NOT EXISTS idx_profile_name ON profiles(name); + CREATE INDEX IF NOT EXISTS idx_profile_type ON profiles(profile_type); + + -- Profile-Memory junction + CREATE TABLE IF NOT EXISTS profile_memories ( + profile_id TEXT NOT NULL, + memory_id TEXT NOT NULL, + role TEXT DEFAULT 'mentioned' CHECK (role IN ('subject', 'mentioned', 'about')), + PRIMARY KEY (profile_id, memory_id), + FOREIGN KEY (profile_id) REFERENCES profiles(id), + FOREIGN KEY (memory_id) REFERENCES memories(id) + ); + """ + ) + # Legacy migration: add scene_id column to memories if missing. 
+ self._migrate_add_column_conn(conn, "memories", "scene_id", "TEXT") + # v2 schema + idempotent migrations. + self._ensure_v2_schema(conn) + + @contextmanager + def _get_connection(self): + """Yield the persistent connection under the thread lock.""" + with self._lock: + try: + yield self._conn + self._conn.commit() + except Exception: + self._conn.rollback() + raise + + def _ensure_v2_schema(self, conn: sqlite3.Connection) -> None: + """Create and migrate Engram v2 schema in-place (idempotent).""" + conn.execute( + """ + CREATE TABLE IF NOT EXISTS schema_migrations ( + version TEXT PRIMARY KEY, + applied_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + + migrations: Dict[str, str] = { + "v2_013": """ + CREATE TABLE IF NOT EXISTS distillation_provenance ( + id TEXT PRIMARY KEY, + semantic_memory_id TEXT NOT NULL, + episodic_memory_id TEXT NOT NULL, + distillation_run_id TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ); + CREATE INDEX IF NOT EXISTS idx_distill_prov_semantic ON distillation_provenance(semantic_memory_id); + CREATE INDEX IF NOT EXISTS idx_distill_prov_episodic ON distillation_provenance(episodic_memory_id); + CREATE INDEX IF NOT EXISTS idx_distill_prov_run ON distillation_provenance(distillation_run_id); + + CREATE TABLE IF NOT EXISTS distillation_log ( + id TEXT PRIMARY KEY, + run_at TEXT DEFAULT CURRENT_TIMESTAMP, + user_id TEXT, + episodes_sampled INTEGER DEFAULT 0, + semantic_created INTEGER DEFAULT 0, + semantic_deduplicated INTEGER DEFAULT 0, + errors INTEGER DEFAULT 0 + ); + CREATE INDEX IF NOT EXISTS idx_distill_log_user ON distillation_log(user_id, run_at DESC); + """, + } + + for version, ddl in migrations.items(): + if not self._is_migration_applied(conn, version): + conn.executescript(ddl) + conn.execute( + "INSERT OR IGNORE INTO schema_migrations (version) VALUES (?)", + (version,), + ) + + # Phase 3: Skip column migrations + backfills if already complete. 
+ if self._is_migration_applied(conn, "v2_columns_complete"): + # CLS Distillation Memory columns (idempotent). + self._ensure_cls_columns(conn) + # Content-hash dedup column must ALSO run on this fast path: DBs stamped + # 'v2_columns_complete' before the column existed would otherwise never get it + # and add_memory's INSERT (which names content_hash) would fail. Guarded by + # its own 'v2_content_hash' migration row, so this is idempotent. + self._ensure_content_hash_column(conn) + return + + # v2 columns on existing canonical tables. + self._migrate_add_column_conn(conn, "memories", "confidentiality_scope", "TEXT DEFAULT 'work'") + self._migrate_add_column_conn(conn, "memories", "source_type", "TEXT") + self._migrate_add_column_conn(conn, "memories", "source_app", "TEXT") + self._migrate_add_column_conn(conn, "memories", "source_event_id", "TEXT") + self._migrate_add_column_conn(conn, "memories", "decay_lambda", "REAL DEFAULT 0.12") + self._migrate_add_column_conn(conn, "memories", "status", "TEXT DEFAULT 'active'") + self._migrate_add_column_conn(conn, "memories", "importance", "REAL DEFAULT 0.5") + self._migrate_add_column_conn(conn, "memories", "sensitivity", "TEXT DEFAULT 'normal'") + self._migrate_add_column_conn(conn, "memories", "namespace", "TEXT DEFAULT 'default'") + + self._migrate_add_column_conn(conn, "scenes", "layer", "TEXT DEFAULT 'sml'") + self._migrate_add_column_conn(conn, "scenes", "scene_strength", "REAL DEFAULT 1.0") + self._migrate_add_column_conn(conn, "scenes", "topic_embedding_ref", "TEXT") + self._migrate_add_column_conn(conn, "scenes", "namespace", "TEXT DEFAULT 'default'") + + self._migrate_add_column_conn(conn, "profiles", "role_bias", "TEXT") + self._migrate_add_column_conn(conn, "profiles", "profile_summary", "TEXT") + + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_memories_user_source_event + ON memories(user_id, source_event_id, namespace, created_at DESC) + """ + ) + + # Backfills.
+ conn.execute( + """ + UPDATE memories + SET confidentiality_scope = 'work' + WHERE confidentiality_scope IS NULL OR confidentiality_scope = '' + """ + ) + conn.execute( + """ + UPDATE memories + SET status = 'active' + WHERE status IS NULL OR status = '' + """ + ) + conn.execute( + """ + UPDATE memories + SET namespace = 'default' + WHERE namespace IS NULL OR namespace = '' + """ + ) + conn.execute( + """ + UPDATE scenes + SET namespace = 'default' + WHERE namespace IS NULL OR namespace = '' + """ + ) + conn.execute( + """ + UPDATE memories + SET decay_lambda = 0.12 + WHERE decay_lambda IS NULL + """ + ) + conn.execute( + """ + UPDATE memories + SET importance = COALESCE( + CASE + WHEN json_extract(metadata, '$.importance') IS NOT NULL + THEN json_extract(metadata, '$.importance') + ELSE importance + END, + 0.5 + ) + """ + ) + conn.execute( + """ + UPDATE memories + SET sensitivity = CASE + WHEN lower(memory) LIKE '%password%' OR lower(memory) LIKE '%api key%' OR lower(memory) LIKE '%token%' + THEN 'secret' + WHEN lower(memory) LIKE '%health%' OR lower(memory) LIKE '%medical%' + THEN 'sensitive' + WHEN lower(memory) LIKE '%bank%' OR lower(memory) LIKE '%salary%' OR lower(memory) LIKE '%credit card%' + THEN 'sensitive' + ELSE COALESCE(NULLIF(sensitivity, ''), 'normal') + END + """ + ) + + # Phase 3: Mark column migrations + backfills as complete. + conn.execute( + "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_columns_complete')" + ) + + # CLS Distillation Memory columns (idempotent). + self._ensure_cls_columns(conn) + + # Content-hash dedup column (idempotent). 
+ self._ensure_content_hash_column(conn) + + def _ensure_content_hash_column(self, conn: sqlite3.Connection) -> None: + """Add content_hash column + index for SHA-256 dedup.""" + if self._is_migration_applied(conn, "v2_content_hash"): + return + self._migrate_add_column_conn(conn, "memories", "content_hash", "TEXT") + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_content_hash ON memories(content_hash, user_id)" + ) + conn.execute( + "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_content_hash')" + ) + + def _ensure_cls_columns(self, conn: sqlite3.Connection) -> None: + """Add CLS Distillation Memory columns to memories table (idempotent).""" + if self._is_migration_applied(conn, "v2_cls_columns_complete"): + return + + self._migrate_add_column_conn(conn, "memories", "memory_type", "TEXT DEFAULT 'semantic'") + self._migrate_add_column_conn(conn, "memories", "s_fast", "REAL DEFAULT NULL") + self._migrate_add_column_conn(conn, "memories", "s_mid", "REAL DEFAULT NULL") + self._migrate_add_column_conn(conn, "memories", "s_slow", "REAL DEFAULT NULL") + + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_memories_memory_type ON memories(memory_type, user_id)" + ) + + # Backfill: set memory_type to 'semantic' for existing memories. + conn.execute( + "UPDATE memories SET memory_type = 'semantic' WHERE memory_type IS NULL" + ) + + conn.execute( + "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_cls_columns_complete')" + ) + + def _is_migration_applied(self, conn: sqlite3.Connection, version: str) -> bool: + row = conn.execute( + "SELECT 1 FROM schema_migrations WHERE version = ?", + (version,), + ).fetchone() + return row is not None + + # Phase 5: Allowed table names for ALTER TABLE to prevent SQL injection. 
+ _ALLOWED_TABLES = frozenset({ + "memories", "scenes", "profiles", "categories", + }) + + def _migrate_add_column_conn( + self, + conn: sqlite3.Connection, + table: str, + column: str, + col_type: str, + ) -> None: + """Add a column using an existing connection, if missing.""" + if table not in self._ALLOWED_TABLES: + raise ValueError(f"Invalid table for migration: {table!r}") + # Validate column name: must be alphanumeric/underscore only. + if not column.replace("_", "").isalnum(): + raise ValueError(f"Invalid column name: {column!r}") + try: + conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {col_type}") + except sqlite3.OperationalError: + pass + + def add_memory(self, memory_data: Dict[str, Any]) -> str: + memory_id = memory_data.get("id", str(uuid.uuid4())) + now = _utcnow_iso() + metadata = memory_data.get("metadata", {}) or {} + source_app = memory_data.get("source_app") or memory_data.get("app_id") or metadata.get("source_app") + + with self._get_connection() as conn: + conn.execute( + """ + INSERT INTO memories ( + id, memory, user_id, agent_id, run_id, app_id, + metadata, categories, immutable, expiration_date, + created_at, updated_at, layer, strength, access_count, + last_accessed, embedding, related_memories, source_memories, tombstone, + confidentiality_scope, namespace, source_type, source_app, source_event_id, decay_lambda, + status, importance, sensitivity, + memory_type, s_fast, s_mid, s_slow, content_hash + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + memory_id, + memory_data.get("memory", ""), + memory_data.get("user_id"), + memory_data.get("agent_id"), + memory_data.get("run_id"), + memory_data.get("app_id"), + json.dumps(memory_data.get("metadata", {})), + json.dumps(memory_data.get("categories", [])), + 1 if memory_data.get("immutable", False) else 0, + memory_data.get("expiration_date"), + memory_data.get("created_at", now), + memory_data.get("updated_at", now), + memory_data.get("layer", "sml"), + memory_data.get("strength", 1.0), + memory_data.get("access_count", 0), + memory_data.get("last_accessed", now), + json.dumps(memory_data.get("embedding", [])), + json.dumps(memory_data.get("related_memories", [])), + json.dumps(memory_data.get("source_memories", [])), + 1 if memory_data.get("tombstone", False) else 0, + memory_data.get("confidentiality_scope", "work"), + memory_data.get("namespace", metadata.get("namespace", "default")), + memory_data.get("source_type") or metadata.get("source_type") or "mcp", + source_app, + memory_data.get("source_event_id") or metadata.get("source_event_id"), + memory_data.get("decay_lambda", 0.12), + memory_data.get("status", "active"), + memory_data.get("importance", metadata.get("importance", 0.5)), + memory_data.get("sensitivity", metadata.get("sensitivity", "normal")), + memory_data.get("memory_type", "semantic"), + memory_data.get("s_fast"), + memory_data.get("s_mid"), + memory_data.get("s_slow"), + memory_data.get("content_hash"), + ), + ) + + # Log within the same transaction -- atomic with the insert. + conn.execute( + """ + INSERT INTO memory_history ( + memory_id, event, old_value, new_value, + old_strength, new_strength, old_layer, new_layer + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (memory_id, "ADD", None, memory_data.get("memory"), None, None, None, None), + ) + + return memory_id + + def add_memories_batch(self, memories: List[Dict[str, Any]]) -> List[str]: + """Insert multiple memories in a single transaction (atomic). 
+ + Returns list of memory IDs in the same order as input. + """ + if not memories: + return [] + now = _utcnow_iso() + ids: List[str] = [] + insert_rows = [] + history_rows = [] + + for memory_data in memories: + memory_id = memory_data.get("id", str(uuid.uuid4())) + ids.append(memory_id) + metadata = memory_data.get("metadata", {}) or {} + source_app = memory_data.get("source_app") or memory_data.get("app_id") or metadata.get("source_app") + + insert_rows.append(( + memory_id, + memory_data.get("memory", ""), + memory_data.get("user_id"), + memory_data.get("agent_id"), + memory_data.get("run_id"), + memory_data.get("app_id"), + json.dumps(memory_data.get("metadata", {})), + json.dumps(memory_data.get("categories", [])), + 1 if memory_data.get("immutable", False) else 0, + memory_data.get("expiration_date"), + memory_data.get("created_at", now), + memory_data.get("updated_at", now), + memory_data.get("layer", "sml"), + memory_data.get("strength", 1.0), + memory_data.get("access_count", 0), + memory_data.get("last_accessed", now), + json.dumps(memory_data.get("embedding", [])), + json.dumps(memory_data.get("related_memories", [])), + json.dumps(memory_data.get("source_memories", [])), + 1 if memory_data.get("tombstone", False) else 0, + memory_data.get("confidentiality_scope", "work"), + memory_data.get("namespace", metadata.get("namespace", "default")), + memory_data.get("source_type") or metadata.get("source_type") or "mcp", + source_app, + memory_data.get("source_event_id") or metadata.get("source_event_id"), + memory_data.get("decay_lambda", 0.12), + memory_data.get("status", "active"), + memory_data.get("importance", metadata.get("importance", 0.5)), + memory_data.get("sensitivity", metadata.get("sensitivity", "normal")), + memory_data.get("memory_type", "semantic"), + memory_data.get("s_fast"), + memory_data.get("s_mid"), + memory_data.get("s_slow"), + memory_data.get("content_hash"), + )) + history_rows.append(( + memory_id, "ADD", None, memory_data.get("memory"), None, None, None, None, + ))
+ + with self._get_connection() as conn: + conn.executemany( + """ + INSERT INTO memories ( + id, memory, user_id, agent_id, run_id, app_id, + metadata, categories, immutable, expiration_date, + created_at, updated_at, layer, strength, access_count, + last_accessed, embedding, related_memories, source_memories, tombstone, + confidentiality_scope, namespace, source_type, source_app, source_event_id, decay_lambda, + status, importance, sensitivity, + memory_type, s_fast, s_mid, s_slow, content_hash + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + insert_rows, + ) + conn.executemany( + """ + INSERT INTO memory_history ( + memory_id, event, old_value, new_value, + old_strength, new_strength, old_layer, new_layer + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + history_rows, + ) + + return ids + + def get_memory(self, memory_id: str, include_tombstoned: bool = False) -> Optional[Dict[str, Any]]: + query = "SELECT * FROM memories WHERE id = ?" + params = [memory_id] + if not include_tombstoned: + query += " AND tombstone = 0" + + with self._get_connection() as conn: + row = conn.execute(query, params).fetchone() + if row: + return self._row_to_dict(row) + return None + + def get_memory_by_content_hash( + self, content_hash: str, user_id: str = "default" + ) -> Optional[Dict[str, Any]]: + """Find an existing memory by content hash (for deduplication).""" + with self._get_connection() as conn: + row = conn.execute( + "SELECT * FROM memories WHERE content_hash = ? AND user_id = ?
AND tombstone = 0 LIMIT 1", + (content_hash, user_id), + ).fetchone() + if row: + return self._row_to_dict(row) + return None + + def get_memory_by_source_event( + self, + *, + user_id: str, + source_event_id: str, + namespace: Optional[str] = None, + source_app: Optional[str] = None, + include_tombstoned: bool = False, + ) -> Optional[Dict[str, Any]]: + normalized_event = str(source_event_id or "").strip() + if not normalized_event: + return None + query = """ + SELECT * + FROM memories + WHERE user_id = ? + AND source_event_id = ? + """ + params: List[Any] = [user_id, normalized_event] + if namespace: + query += " AND namespace = ?" + params.append(namespace) + if source_app: + query += " AND source_app = ?" + params.append(source_app) + if not include_tombstoned: + query += " AND tombstone = 0" + query += " ORDER BY created_at DESC LIMIT 1" + + with self._get_connection() as conn: + row = conn.execute(query, params).fetchone() + if row: + return self._row_to_dict(row) + return None + + def get_all_memories( + self, + *, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + run_id: Optional[str] = None, + app_id: Optional[str] = None, + layer: Optional[str] = None, + namespace: Optional[str] = None, + memory_type: Optional[str] = None, + min_strength: float = 0.0, + include_tombstoned: bool = False, + created_after: Optional[str] = None, + created_before: Optional[str] = None, + limit: Optional[int] = None, + ) -> List[Dict[str, Any]]: + query = "SELECT * FROM memories WHERE strength >= ?" + params: List[Any] = [min_strength] + + if not include_tombstoned: + query += " AND tombstone = 0" + if memory_type: + query += " AND memory_type = ?" + params.append(memory_type) + if user_id: + query += " AND user_id = ?" + params.append(user_id) + if agent_id: + query += " AND agent_id = ?" + params.append(agent_id) + if run_id: + query += " AND run_id = ?" + params.append(run_id) + if app_id: + query += " AND app_id = ?" 
+ params.append(app_id) + if layer: + query += " AND layer = ?" + params.append(layer) + if namespace: + query += " AND namespace = ?" + params.append(namespace) + if created_after: + query += " AND created_at >= ?" + params.append(created_after) + if created_before: + query += " AND created_at <= ?" + params.append(created_before) + + query += " ORDER BY strength DESC" + + # Apply SQL-level LIMIT to avoid fetching unbounded rows into memory. + if limit is not None and limit > 0: + query += " LIMIT ?" + params.append(limit) + + with self._get_connection() as conn: + rows = conn.execute(query, params).fetchall() + return [self._row_to_dict(row) for row in rows] + + def update_memory(self, memory_id: str, updates: Dict[str, Any]) -> bool: + set_clauses = [] + params: List[Any] = [] + for key, value in updates.items(): + if key not in VALID_MEMORY_COLUMNS: + raise ValueError(f"Invalid memory column: {key!r}") + if key in {"metadata", "categories", "embedding", "related_memories", "source_memories"}: + value = json.dumps(value) + set_clauses.append(f"{key} = ?") + params.append(value) + + set_clauses.append("updated_at = ?") + params.append(_utcnow_iso()) + params.append(memory_id) + + with self._get_connection() as conn: + # Read old values and update in a single transaction. + old_row = conn.execute( + "SELECT memory, strength, layer FROM memories WHERE id = ?", + (memory_id,), + ).fetchone() + if not old_row: + return False + + conn.execute( + f"UPDATE memories SET {', '.join(set_clauses)} WHERE id = ?", + params, + ) + + # Log within the same transaction. + conn.execute( + """ + INSERT INTO memory_history ( + memory_id, event, old_value, new_value, + old_strength, new_strength, old_layer, new_layer + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + memory_id, + "UPDATE", + old_row["memory"], + updates.get("memory"), + old_row["strength"], + updates.get("strength"), + old_row["layer"], + updates.get("layer"), + ), + ) + return True + + def delete_memory(self, memory_id: str, use_tombstone: bool = True) -> bool: + if use_tombstone: + return self.update_memory(memory_id, {"tombstone": 1}) + with self._get_connection() as conn: + conn.execute("DELETE FROM memories WHERE id = ?", (memory_id,)) + self._log_event(memory_id, "DELETE") + return True + + def increment_access(self, memory_id: str) -> None: + now = _utcnow_iso() + with self._get_connection() as conn: + conn.execute( + """ + UPDATE memories + SET access_count = access_count + 1, last_accessed = ? + WHERE id = ? + """, + (now, memory_id), + ) + + # Phase 2: Batch operations to eliminate N+1 queries in search. + + def get_memories_bulk(self, memory_ids: List[str], include_tombstoned: bool = False) -> Dict[str, Dict[str, Any]]: + """Fetch multiple memories by ID in a single query. Returns {id: memory_dict}.""" + if not memory_ids: + return {} + with self._get_connection() as conn: + placeholders = ",".join("?" for _ in memory_ids) + query = f"SELECT * FROM memories WHERE id IN ({placeholders})" + if not include_tombstoned: + query += " AND tombstone = 0" + rows = conn.execute(query, memory_ids).fetchall() + return {row["id"]: self._row_to_dict(row) for row in rows} + + def increment_access_bulk(self, memory_ids: List[str]) -> None: + """Increment access count for multiple memories in a single transaction.""" + if not memory_ids: + return + now = _utcnow_iso() + with self._get_connection() as conn: + placeholders = ",".join("?" for _ in memory_ids) + conn.execute( + f""" + UPDATE memories + SET access_count = access_count + 1, last_accessed = ? + WHERE id IN ({placeholders}) + """, + [now] + list(memory_ids), + ) + + def update_strength_bulk(self, updates: Dict[str, float]) -> None: + """Batch-update strength for multiple memories. 
updates = {memory_id: new_strength}.""" + if not updates: + return + now = _utcnow_iso() + with self._get_connection() as conn: + conn.executemany( + "UPDATE memories SET strength = ?, updated_at = ? WHERE id = ?", + [(strength, now, memory_id) for memory_id, strength in updates.items()], + ) + + _MEMORY_JSON_FIELDS = ("metadata", "categories", "related_memories", "source_memories") + + def _row_to_dict(self, row: sqlite3.Row, *, skip_embedding: bool = False) -> Dict[str, Any]: + data = dict(row) + for key in self._MEMORY_JSON_FIELDS: + if key in data and data[key]: + data[key] = json.loads(data[key]) + # Embedding is the largest JSON field (~30-50KB for 3072-dim vectors). + # Skip deserialization when the caller doesn't need it. + if skip_embedding: + data.pop("embedding", None) + elif "embedding" in data and data["embedding"]: + data["embedding"] = json.loads(data["embedding"]) + data["immutable"] = bool(data.get("immutable", 0)) + data["tombstone"] = bool(data.get("tombstone", 0)) + return data + + def _log_event(self, memory_id: str, event: str, **kwargs: Any) -> None: + with self._get_connection() as conn: + conn.execute( + """ + INSERT INTO memory_history ( + memory_id, event, old_value, new_value, + old_strength, new_strength, old_layer, new_layer + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + memory_id, + event, + kwargs.get("old_value"), + kwargs.get("new_value"), + kwargs.get("old_strength"), + kwargs.get("new_strength"), + kwargs.get("old_layer"), + kwargs.get("new_layer"), + ), + ) + + def log_event(self, memory_id: str, event: str, **kwargs: Any) -> None: + """Public wrapper for logging custom events like DECAY or FUSE.""" + self._log_event(memory_id, event, **kwargs) + + def get_history(self, memory_id: str) -> List[Dict[str, Any]]: + with self._get_connection() as conn: + rows = conn.execute( + "SELECT * FROM memory_history WHERE memory_id = ? 
ORDER BY timestamp DESC", + (memory_id,), + ).fetchall() + return [dict(row) for row in rows] + + def log_decay(self, decayed: int, forgotten: int, promoted: int, storage_before_mb: Optional[float] = None, storage_after_mb: Optional[float] = None) -> None: + with self._get_connection() as conn: + conn.execute( + """ + INSERT INTO decay_log (memories_decayed, memories_forgotten, memories_promoted, storage_before_mb, storage_after_mb) + VALUES (?, ?, ?, ?, ?) + """, + (decayed, forgotten, promoted, storage_before_mb, storage_after_mb), + ) + + def purge_tombstoned(self) -> int: + """Permanently delete all tombstoned memories. This is IRREVERSIBLE.""" + with self._get_connection() as conn: + # Log what will be purged before deletion for audit trail. + rows = conn.execute( + "SELECT id, user_id, memory FROM memories WHERE tombstone = 1" + ).fetchall() + count = len(rows) + if count > 0: + ids = [row["id"] for row in rows] + logger.warning( + "purge_tombstoned: permanently deleting %d memories: %s", + count, + ids, + ) + for row in rows: + conn.execute( + """INSERT INTO memory_history (memory_id, event, old_value, new_value, + old_strength, new_strength, old_layer, new_layer) + VALUES (?, ?, ?, NULL, NULL, NULL, NULL, NULL)""", + (row["id"], "PURGE", row["memory"]), + ) + conn.execute("DELETE FROM memories WHERE tombstone = 1") + return count + + # CLS Distillation Memory helpers + + def get_episodic_memories( + self, + user_id: str, + *, + scene_id: Optional[str] = None, + created_after: Optional[str] = None, + created_before: Optional[str] = None, + limit: int = 100, + namespace: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """Fetch episodic-type memories for a user, optionally filtered by scene/time.""" + query = "SELECT * FROM memories WHERE user_id = ? AND memory_type = 'episodic' AND tombstone = 0" + params: List[Any] = [user_id] + if scene_id: + query += " AND scene_id = ?" + params.append(scene_id) + if created_after: + query += " AND created_at >= ?" 
+ params.append(created_after) + if created_before: + query += " AND created_at <= ?" + params.append(created_before) + if namespace: + query += " AND namespace = ?" + params.append(namespace) + query += " ORDER BY created_at DESC LIMIT ?" + params.append(limit) + with self._get_connection() as conn: + rows = conn.execute(query, params).fetchall() + return [self._row_to_dict(row) for row in rows] + + def add_distillation_provenance( + self, + semantic_memory_id: str, + episodic_memory_ids: List[str], + run_id: str, + ) -> None: + """Record which episodic memories contributed to a distilled semantic memory.""" + with self._get_connection() as conn: + for ep_id in episodic_memory_ids: + conn.execute( + """ + INSERT INTO distillation_provenance (id, semantic_memory_id, episodic_memory_id, distillation_run_id) + VALUES (?, ?, ?, ?) + """, + (str(uuid.uuid4()), semantic_memory_id, ep_id, run_id), + ) + + def log_distillation_run( + self, + user_id: str, + episodes_sampled: int, + semantic_created: int, + semantic_deduplicated: int = 0, + errors: int = 0, + ) -> str: + """Log a distillation run and return the run ID.""" + run_id = str(uuid.uuid4()) + with self._get_connection() as conn: + conn.execute( + """ + INSERT INTO distillation_log (id, user_id, episodes_sampled, semantic_created, semantic_deduplicated, errors) + VALUES (?, ?, ?, ?, ?, ?) + """, + (run_id, user_id, episodes_sampled, semantic_created, semantic_deduplicated, errors), + ) + return run_id + + def get_memory_count_by_namespace(self, user_id: str) -> Dict[str, int]: + """Return {namespace: count} for active memories of a user.""" + with self._get_connection() as conn: + rows = conn.execute( + """ + SELECT COALESCE(namespace, 'default') AS ns, COUNT(*) AS cnt + FROM memories + WHERE user_id = ? 
AND tombstone = 0 + GROUP BY ns + """, + (user_id,), + ).fetchall() + return {row["ns"]: row["cnt"] for row in rows} + + def update_multi_trace( + self, + memory_id: str, + s_fast: float, + s_mid: float, + s_slow: float, + effective_strength: float, + ) -> bool: + """Update multi-trace columns and effective strength for a memory.""" + return self.update_memory(memory_id, { + "s_fast": s_fast, + "s_mid": s_mid, + "s_slow": s_slow, + "strength": effective_strength, + }) + + # CategoryMem methods + def save_category(self, category_data: Dict[str, Any]) -> str: + """Save or update a category.""" + category_id = category_data.get("id") + if not category_id: + return "" + + with self._get_connection() as conn: + conn.execute( + """ + INSERT OR REPLACE INTO categories ( + id, name, description, category_type, parent_id, + children_ids, memory_count, total_strength, access_count, + last_accessed, created_at, embedding, keywords, + summary, summary_updated_at, related_ids, strength + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + category_id, + category_data.get("name", ""), + category_data.get("description", ""), + category_data.get("category_type", "dynamic"), + category_data.get("parent_id"), + json.dumps(category_data.get("children_ids", [])), + category_data.get("memory_count", 0), + category_data.get("total_strength", 0.0), + category_data.get("access_count", 0), + category_data.get("last_accessed"), + category_data.get("created_at"), + json.dumps(category_data.get("embedding")) if category_data.get("embedding") else None, + json.dumps(category_data.get("keywords", [])), + category_data.get("summary"), + category_data.get("summary_updated_at"), + json.dumps(category_data.get("related_ids", [])), + category_data.get("strength", 1.0), + ), + ) + return category_id + + def get_category(self, category_id: str) -> Optional[Dict[str, Any]]: + """Get a category by ID.""" + with self._get_connection() as conn: + row = conn.execute( + "SELECT * FROM categories WHERE id = ?", + (category_id,) + ).fetchone() + if row: + return self._category_row_to_dict(row) + return None + + def get_all_categories(self) -> List[Dict[str, Any]]: + """Get all categories.""" + with self._get_connection() as conn: + rows = conn.execute( + "SELECT * FROM categories ORDER BY strength DESC" + ).fetchall() + return [self._category_row_to_dict(row) for row in rows] + + def delete_category(self, category_id: str) -> bool: + """Delete a category.""" + with self._get_connection() as conn: + conn.execute("DELETE FROM categories WHERE id = ?", (category_id,)) + return True + + def save_all_categories(self, categories: List[Dict[str, Any]]) -> int: + """Save multiple categories in a single transaction for performance.""" + if not categories: + return 0 + rows = [] + for cat in categories: + cat_id = cat.get("id") + if not cat_id: + continue + rows.append(( + cat_id, + cat.get("name", ""), + cat.get("description", ""), + cat.get("category_type", "dynamic"), + cat.get("parent_id"), + 
json.dumps(cat.get("children_ids", [])), + cat.get("memory_count", 0), + cat.get("total_strength", 0.0), + cat.get("access_count", 0), + cat.get("last_accessed"), + cat.get("created_at"), + json.dumps(cat.get("embedding")) if cat.get("embedding") else None, + json.dumps(cat.get("keywords", [])), + cat.get("summary"), + cat.get("summary_updated_at"), + json.dumps(cat.get("related_ids", [])), + cat.get("strength", 1.0), + )) + if not rows: + return 0 + with self._get_connection() as conn: + conn.executemany( + """ + INSERT OR REPLACE INTO categories ( + id, name, description, category_type, parent_id, + children_ids, memory_count, total_strength, access_count, + last_accessed, created_at, embedding, keywords, + summary, summary_updated_at, related_ids, strength + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + rows, + ) + return len(rows) + + def _category_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: + """Convert a category row to dict.""" + data = dict(row) + for key in ["children_ids", "keywords", "related_ids"]: + if key in data and data[key]: + data[key] = json.loads(data[key]) + else: + data[key] = [] + if data.get("embedding"): + data["embedding"] = json.loads(data["embedding"]) + return data + + def _migrate_add_column(self, table: str, column: str, col_type: str) -> None: + """Add a column to an existing table if it doesn't already exist.""" + with self._get_connection() as conn: + self._migrate_add_column_conn(conn, table, column, col_type) + + # ========================================================================= + # Scene methods + # ========================================================================= + + def add_scene(self, scene_data: Dict[str, Any]) -> str: + scene_id = scene_data.get("id", str(uuid.uuid4())) + with self._get_connection() as conn: + conn.execute( + """ + INSERT INTO scenes ( + id, user_id, title, summary, topic, location, + participants, memory_ids, start_time, end_time, + embedding, strength, 
access_count, tombstone, + layer, scene_strength, topic_embedding_ref, namespace + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + scene_id, + scene_data.get("user_id"), + scene_data.get("title"), + scene_data.get("summary"), + scene_data.get("topic"), + scene_data.get("location"), + json.dumps(scene_data.get("participants", [])), + json.dumps(scene_data.get("memory_ids", [])), + scene_data.get("start_time"), + scene_data.get("end_time"), + json.dumps(scene_data.get("embedding")) if scene_data.get("embedding") else None, + scene_data.get("strength", 1.0), + scene_data.get("access_count", 0), + 1 if scene_data.get("tombstone", False) else 0, + scene_data.get("layer", "sml"), + scene_data.get("scene_strength", scene_data.get("strength", 1.0)), + scene_data.get("topic_embedding_ref"), + scene_data.get("namespace", "default"), + ), + ) + return scene_id + + def get_scene(self, scene_id: str) -> Optional[Dict[str, Any]]: + with self._get_connection() as conn: + row = conn.execute( + "SELECT * FROM scenes WHERE id = ? 
AND tombstone = 0", (scene_id,) + ).fetchone() + if row: + return self._scene_row_to_dict(row) + return None + + def update_scene(self, scene_id: str, updates: Dict[str, Any]) -> bool: + set_clauses = [] + params: List[Any] = [] + for key, value in updates.items(): + if key not in VALID_SCENE_COLUMNS: + raise ValueError(f"Invalid scene column: {key!r}") + if key in {"participants", "memory_ids", "embedding"}: + value = json.dumps(value) + set_clauses.append(f"{key} = ?") + params.append(value) + if not set_clauses: + return False + params.append(scene_id) + with self._get_connection() as conn: + conn.execute( + f"UPDATE scenes SET {', '.join(set_clauses)} WHERE id = ?", + params, + ) + return True + + def get_open_scene(self, user_id: str) -> Optional[Dict[str, Any]]: + """Get the most recent scene without an end_time for a user.""" + with self._get_connection() as conn: + row = conn.execute( + """ + SELECT * FROM scenes + WHERE user_id = ? AND end_time IS NULL AND tombstone = 0 + ORDER BY start_time DESC LIMIT 1 + """, + (user_id,), + ).fetchone() + if row: + return self._scene_row_to_dict(row) + return None + + def get_scenes( + self, + user_id: Optional[str] = None, + topic: Optional[str] = None, + start_after: Optional[str] = None, + start_before: Optional[str] = None, + namespace: Optional[str] = None, + limit: int = 50, + ) -> List[Dict[str, Any]]: + query = "SELECT * FROM scenes WHERE tombstone = 0" + params: List[Any] = [] + if user_id: + query += " AND user_id = ?" + params.append(user_id) + if topic: + query += " AND topic LIKE ?" + params.append(f"%{topic}%") + if start_after: + query += " AND start_time >= ?" + params.append(start_after) + if start_before: + query += " AND start_time <= ?" + params.append(start_before) + if namespace: + query += " AND namespace = ?" + params.append(namespace) + query += " ORDER BY start_time DESC LIMIT ?" 
+ params.append(limit) + with self._get_connection() as conn: + rows = conn.execute(query, params).fetchall() + return [self._scene_row_to_dict(row) for row in rows] + + def add_scene_memory(self, scene_id: str, memory_id: str, position: int = 0) -> None: + with self._get_connection() as conn: + conn.execute( + "INSERT OR IGNORE INTO scene_memories (scene_id, memory_id, position) VALUES (?, ?, ?)", + (scene_id, memory_id, position), + ) + + def get_scene_memories(self, scene_id: str) -> List[Dict[str, Any]]: + with self._get_connection() as conn: + rows = conn.execute( + """ + SELECT m.* FROM memories m + JOIN scene_memories sm ON m.id = sm.memory_id + WHERE sm.scene_id = ? AND m.tombstone = 0 + ORDER BY sm.position + """, + (scene_id,), + ).fetchall() + return [self._row_to_dict(row) for row in rows] + + def _scene_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: + data = dict(row) + for key in ["participants", "memory_ids"]: + if key in data and data[key]: + data[key] = json.loads(data[key]) + else: + data[key] = [] + if data.get("embedding"): + data["embedding"] = json.loads(data["embedding"]) + data["tombstone"] = bool(data.get("tombstone", 0)) + return data + + # ========================================================================= + # Profile methods + # ========================================================================= + + def add_profile(self, profile_data: Dict[str, Any]) -> str: + profile_id = profile_data.get("id", str(uuid.uuid4())) + now = _utcnow_iso() + with self._get_connection() as conn: + conn.execute( + """ + INSERT INTO profiles ( + id, user_id, name, profile_type, narrative, + facts, preferences, relationships, sentiment, + theory_of_mind, aliases, embedding, strength, + created_at, updated_at, role_bias, profile_summary + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + profile_id, + profile_data.get("user_id"), + profile_data.get("name", ""), + profile_data.get("profile_type", "contact"), + profile_data.get("narrative"), + json.dumps(profile_data.get("facts", [])), + json.dumps(profile_data.get("preferences", [])), + json.dumps(profile_data.get("relationships", [])), + profile_data.get("sentiment"), + json.dumps(profile_data.get("theory_of_mind", {})), + json.dumps(profile_data.get("aliases", [])), + json.dumps(profile_data.get("embedding")) if profile_data.get("embedding") else None, + profile_data.get("strength", 1.0), + profile_data.get("created_at", now), + profile_data.get("updated_at", now), + profile_data.get("role_bias"), + profile_data.get("profile_summary"), + ), + ) + return profile_id + + def get_profile(self, profile_id: str) -> Optional[Dict[str, Any]]: + with self._get_connection() as conn: + row = conn.execute( + "SELECT * FROM profiles WHERE id = ?", (profile_id,) + ).fetchone() + if row: + return self._profile_row_to_dict(row) + return None + + def update_profile(self, profile_id: str, updates: Dict[str, Any]) -> bool: + set_clauses = [] + params: List[Any] = [] + for key, value in updates.items(): + if key not in VALID_PROFILE_COLUMNS: + raise ValueError(f"Invalid profile column: {key!r}") + if key in {"facts", "preferences", "relationships", "aliases", "theory_of_mind", "embedding"}: + value = json.dumps(value) + set_clauses.append(f"{key} = ?") + params.append(value) + set_clauses.append("updated_at = ?") + params.append(_utcnow_iso()) + params.append(profile_id) + with self._get_connection() as conn: + conn.execute( + f"UPDATE profiles SET {', '.join(set_clauses)} WHERE id = ?", + params, + ) + return True + + def get_all_profiles(self, user_id: Optional[str] = None) -> List[Dict[str, Any]]: + query = "SELECT * FROM profiles" + params: List[Any] = [] + if user_id: + query += " WHERE user_id = ?" 
+ params.append(user_id) + query += " ORDER BY strength DESC" + with self._get_connection() as conn: + rows = conn.execute(query, params).fetchall() + return [self._profile_row_to_dict(row) for row in rows] + + def get_profile_by_name(self, name: str, user_id: Optional[str] = None) -> Optional[Dict[str, Any]]: + """Find a profile by exact name match, then fall back to alias scan.""" + # Fast path: exact name match via indexed column. + query = "SELECT * FROM profiles WHERE lower(name) = ?" + params: List[Any] = [name.lower()] + if user_id: + query += " AND user_id = ?" + params.append(user_id) + query += " LIMIT 1" + with self._get_connection() as conn: + row = conn.execute(query, params).fetchone() + if row: + return self._profile_row_to_dict(row) + # Slow path: alias scan (aliases stored as JSON, can't index). + alias_query = "SELECT * FROM profiles WHERE aliases LIKE ?" + alias_params: List[Any] = [f'%"{name}"%'] + if user_id: + alias_query += " AND user_id = ?" + alias_params.append(user_id) + alias_query += " LIMIT 1" + row = conn.execute(alias_query, alias_params).fetchone() + if row: + result = self._profile_row_to_dict(row) + # Verify case-insensitive alias match. + if name.lower() in [a.lower() for a in result.get("aliases", [])]: + return result + return None + + def find_profile_by_substring(self, name: str, user_id: Optional[str] = None) -> Optional[Dict[str, Any]]: + """Find a profile where the name contains the query as a substring (case-insensitive).""" + query = "SELECT * FROM profiles WHERE lower(name) LIKE ?" + params: List[Any] = [f"%{name.lower()}%"] + if user_id: + query += " AND user_id = ?" 
+ params.append(user_id) + query += " ORDER BY strength DESC LIMIT 1" + with self._get_connection() as conn: + row = conn.execute(query, params).fetchone() + if row: + return self._profile_row_to_dict(row) + return None + + def add_profile_memory(self, profile_id: str, memory_id: str, role: str = "mentioned") -> None: + with self._get_connection() as conn: + conn.execute( + "INSERT OR IGNORE INTO profile_memories (profile_id, memory_id, role) VALUES (?, ?, ?)", + (profile_id, memory_id, role), + ) + + def get_profile_memories(self, profile_id: str) -> List[Dict[str, Any]]: + with self._get_connection() as conn: + rows = conn.execute( + """ + SELECT m.*, pm.role AS profile_role FROM memories m + JOIN profile_memories pm ON m.id = pm.memory_id + WHERE pm.profile_id = ? AND m.tombstone = 0 + ORDER BY m.created_at DESC + """, + (profile_id,), + ).fetchall() + return [self._row_to_dict(row) for row in rows] + + def _profile_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: + data = dict(row) + for key in ["facts", "preferences", "relationships", "aliases"]: + if key in data and data[key]: + data[key] = json.loads(data[key]) + else: + data[key] = [] + if data.get("theory_of_mind"): + data["theory_of_mind"] = json.loads(data["theory_of_mind"]) + else: + data["theory_of_mind"] = {} + if data.get("embedding"): + data["embedding"] = json.loads(data["embedding"]) + return data + + def get_memories_by_category( + self, + category_id: str, + limit: int = 100, + min_strength: float = 0.0, + ) -> List[Dict[str, Any]]: + """Get memories belonging to a specific category.""" + with self._get_connection() as conn: + rows = conn.execute( + """ + SELECT * FROM memories + WHERE categories LIKE ? AND strength >= ? AND tombstone = 0 + ORDER BY strength DESC + LIMIT ? 
+ """, + (f'%"{category_id}"%', min_strength, limit), + ).fetchall() + return [self._row_to_dict(row) for row in rows] + + # ========================================================================= + # User ID listing + # ========================================================================= + + def list_user_ids(self) -> List[str]: + with self._get_connection() as conn: + rows = conn.execute( + """ + SELECT DISTINCT user_id FROM memories + WHERE user_id IS NOT NULL AND user_id != '' + ORDER BY user_id + """ + ).fetchall() + return [str(row["user_id"]) for row in rows if row["user_id"]] + + # ========================================================================= + # Dashboard / Visualization methods + # ========================================================================= + + def get_constellation_data(self, user_id: Optional[str] = None, limit: int = 200) -> Dict[str, Any]: + """Build graph data for the constellation visualizer.""" + with self._get_connection() as conn: + # Nodes: memories + mem_query = "SELECT id, memory, strength, layer, categories, created_at FROM memories WHERE tombstone = 0" + params: List[Any] = [] + if user_id: + mem_query += " AND user_id = ?" + params.append(user_id) + mem_query += " ORDER BY strength DESC LIMIT ?" + params.append(limit) + mem_rows = conn.execute(mem_query, params).fetchall() + + nodes = [] + node_ids = set() + for row in mem_rows: + cats = row["categories"] + if cats: + try: + cats = json.loads(cats) + except Exception: + cats = [] + else: + cats = [] + nodes.append({ + "id": row["id"], + "memory": (row["memory"] or "")[:120], + "strength": row["strength"], + "layer": row["layer"], + "categories": cats, + "created_at": row["created_at"], + }) + node_ids.add(row["id"]) + + # Edges from scene_memories (memories sharing a scene) + edges: List[Dict[str, Any]] = [] + if node_ids: + placeholders = ",".join("?" 
for _ in node_ids) + scene_rows = conn.execute( + f""" + SELECT a.memory_id AS source, b.memory_id AS target, a.scene_id + FROM scene_memories a + JOIN scene_memories b ON a.scene_id = b.scene_id AND a.memory_id < b.memory_id + WHERE a.memory_id IN ({placeholders}) AND b.memory_id IN ({placeholders}) + """, + list(node_ids) + list(node_ids), + ).fetchall() + for row in scene_rows: + edges.append({"source": row["source"], "target": row["target"], "type": "scene"}) + + # Edges from profile_memories (memories sharing a profile) + profile_rows = conn.execute( + f""" + SELECT a.memory_id AS source, b.memory_id AS target, a.profile_id + FROM profile_memories a + JOIN profile_memories b ON a.profile_id = b.profile_id AND a.memory_id < b.memory_id + WHERE a.memory_id IN ({placeholders}) AND b.memory_id IN ({placeholders}) + """, + list(node_ids) + list(node_ids), + ).fetchall() + for row in profile_rows: + edges.append({"source": row["source"], "target": row["target"], "type": "profile"}) + + return {"nodes": nodes, "edges": edges} + + def get_decay_log_entries(self, limit: int = 20) -> List[Dict[str, Any]]: + """Return recent decay log entries for the dashboard sparkline.""" + with self._get_connection() as conn: + rows = conn.execute( + "SELECT * FROM decay_log ORDER BY run_at DESC LIMIT ?", + (limit,), + ).fetchall() + return [dict(row) for row in rows] + + # ========================================================================= + # Utilities + # ========================================================================= + + @staticmethod + def _parse_json_value(value: Any, default: Any) -> Any: + if value is None: + return default + if isinstance(value, (dict, list)): + return value + try: + return json.loads(value) + except Exception: + return default diff --git a/engram/embeddings/gemini.py b/engram/embeddings/gemini.py index e4832e7..8cf5657 100644 --- a/engram/embeddings/gemini.py +++ b/engram/embeddings/gemini.py @@ -2,6 +2,8 @@ import os from typing import List, 
Optional +from google import genai + from engram.embeddings.base import BaseEmbedder logger = logging.getLogger(__name__) @@ -15,48 +17,15 @@ def __init__(self, config: Optional[dict] = None): raise ValueError("Gemini API key not provided. Set GEMINI_API_KEY or pass api_key in config.") self.model = self.config.get("model", "gemini-embedding-001") - - self._client_type = None - self._client = None - self._genai = None - - try: - import google.generativeai as genai - - genai.configure(api_key=self.api_key) - self._client_type = "generativeai" - self._genai = genai - except ImportError: - try: - from google import genai - - self._client_type = "genai" - self._client = genai.Client(api_key=self.api_key) - except Exception as exc: - raise ImportError( - "Install google-generativeai or google-genai to use GeminiEmbedder" - ) from exc + self._client = genai.Client(api_key=self.api_key) def embed(self, text: str, memory_action: Optional[str] = None) -> List[float]: try: - if self._client_type == "generativeai": - response = self._genai.embed_content( - model=self.model, - content=text, - ) - embedding = response.get("embedding") if isinstance(response, dict) else getattr(response, "embedding", None) - if not embedding: - raise RuntimeError(f"Gemini embedding returned empty result (model={self.model})") - return embedding - - if self._client_type == "genai": - response = self._client.models.embed_content( - model=self.model, - contents=text, - ) - return _extract_embedding_from_response(response) - - raise RuntimeError("Gemini embedder not initialized") + response = self._client.models.embed_content( + model=self.model, + contents=text, + ) + return _extract_embedding_from_response(response) except RuntimeError: raise except Exception as exc: @@ -75,36 +44,21 @@ def embed_batch( return [self.embed(texts[0], memory_action=memory_action)] try: - if self._client_type == "generativeai": - # google-generativeai supports batch via embed_content with list - response = 
self._genai.embed_content( - model=self.model, - content=texts, - ) - embedding = response.get("embedding") if isinstance(response, dict) else getattr(response, "embedding", None) - if embedding and isinstance(embedding, list) and len(embedding) == len(texts): - # Check if it's a list of lists (batch) vs single list - if embedding and isinstance(embedding[0], list): - return embedding - # Fallback to sequential - return [self.embed(t, memory_action=memory_action) for t in texts] - - if self._client_type == "genai": - response = self._client.models.embed_content( - model=self.model, - contents=texts, - ) - embeddings = getattr(response, "embeddings", None) - if embeddings and isinstance(embeddings, list): - results = [] - for emb in embeddings: - vector = getattr(emb, "values", None) or getattr(emb, "embedding", None) - if vector: - results.append(vector) - if len(results) == len(texts): - return results - # Fallback to sequential - return [self.embed(t, memory_action=memory_action) for t in texts] + response = self._client.models.embed_content( + model=self.model, + contents=texts, + ) + embeddings = getattr(response, "embeddings", None) + if embeddings and isinstance(embeddings, list): + results = [] + for emb in embeddings: + vector = getattr(emb, "values", None) or getattr(emb, "embedding", None) + if vector: + results.append(vector) + if len(results) == len(texts): + return results + # Fallback to sequential + return [self.embed(t, memory_action=memory_action) for t in texts] except Exception as exc: logger.warning( diff --git a/engram/embeddings/simple.py b/engram/embeddings/simple.py index a55fecc..c273e03 100644 --- a/engram/embeddings/simple.py +++ b/engram/embeddings/simple.py @@ -1,27 +1,65 @@ +"""Production-grade hash-based embedder. Zero external deps, no API key. + +Deterministic: same text → same vector (SHA-256 → float array projection). +Fixed 384 dimensions (small, fast). Suitable for offline use and testing. 
+""" + import hashlib import math +import struct from typing import List, Optional from engram.embeddings.base import BaseEmbedder +_DEFAULT_DIMS = 384 + class SimpleEmbedder(BaseEmbedder): def __init__(self, config: Optional[dict] = None): super().__init__(config) - self.dims = int(self.config.get("embedding_dims", 1536)) + self.dims = int(self.config.get("embedding_dims", _DEFAULT_DIMS)) def embed(self, text: str, memory_action: Optional[str] = None) -> List[float]: - tokens = [t for t in text.lower().split() if t] - if not tokens: + """Deterministic embedding: SHA-256 hash → float vector projection. + + Uses multiple hash rounds to fill the vector space, producing + a normalized unit vector that is stable across runs. + """ + normalized = text.strip().lower() + if not normalized: return [0.0] * self.dims vector = [0.0] * self.dims - for token in tokens: - digest = hashlib.sha256(token.encode()).hexdigest() - idx = int(digest, 16) % self.dims - vector[idx] += 1.0 + # Use sliding window of 3-grams + whole words for richer signal + tokens = normalized.split() + fragments = list(tokens) + # Add bigrams for phrase sensitivity + for i in range(len(tokens) - 1): + fragments.append(f"{tokens[i]} {tokens[i + 1]}") + # Add character 3-grams for typo tolerance + for i in range(max(0, len(normalized) - 2)): + fragments.append(normalized[i:i + 3]) + + for fragment in fragments: + # SHA-256 gives 32 bytes = 8 floats via struct unpacking + digest = hashlib.sha256(fragment.encode("utf-8")).digest() + # Use the digest bytes to seed multiple positions + for offset in range(0, 32, 4): + idx_bytes = digest[offset:offset + 4] + idx = int.from_bytes(idx_bytes, "big") % self.dims + # Use sign from another part of the hash + sign_bit = digest[(offset + 2) % 32] & 1 + weight = 1.0 if sign_bit else -1.0 + vector[idx] += weight + + # L2 normalize to unit vector norm = math.sqrt(sum(x * x for x in vector)) if norm > 0: vector = [x / norm for x in vector] return vector + + def embed_batch( 
+ self, texts: List[str], memory_action: Optional[str] = None + ) -> List[List[float]]: + return [self.embed(t, memory_action=memory_action) for t in texts] diff --git a/engram/llms/gemini.py b/engram/llms/gemini.py index 5a67fd2..2d038d5 100644 --- a/engram/llms/gemini.py +++ b/engram/llms/gemini.py @@ -2,6 +2,8 @@ import os from typing import Optional +from google import genai + from engram.llms.base import BaseLLM logger = logging.getLogger(__name__) @@ -17,53 +19,19 @@ def __init__(self, config: Optional[dict] = None): self.model = self.config.get("model", "gemini-2.0-flash") self.temperature = self.config.get("temperature", 0.1) self.max_tokens = self.config.get("max_tokens", 1024) - - self._client_type = None - self._model = None - self._client = None - - try: - import google.generativeai as genai - - genai.configure(api_key=self.api_key) - self._client_type = "generativeai" - self._genai = genai - self._model = genai.GenerativeModel(self.model) - except ImportError: - try: - from google import genai - - self._client_type = "genai" - self._client = genai.Client(api_key=self.api_key) - except Exception as exc: - raise ImportError( - "Install google-generativeai or google-genai to use GeminiLLM" - ) from exc + self._client = genai.Client(api_key=self.api_key) def generate(self, prompt: str) -> str: try: - if self._client_type == "generativeai": - response = self._model.generate_content( - prompt, - generation_config={ - "temperature": self.temperature, - "max_output_tokens": self.max_tokens, - }, - ) - return getattr(response, "text", "") or "" - - if self._client_type == "genai": - response = self._client.models.generate_content( - model=self.model, - contents=prompt, - config={ - "temperature": self.temperature, - "max_output_tokens": self.max_tokens, - }, - ) - return _extract_text_from_response(response) - - raise RuntimeError("Gemini LLM client not initialized") + response = self._client.models.generate_content( + model=self.model, + contents=prompt, + 
config={ + "temperature": self.temperature, + "max_output_tokens": self.max_tokens, + }, + ) + return _extract_text_from_response(response) except RuntimeError: raise except Exception as exc: diff --git a/engram/mcp_server.py b/engram/mcp_server.py index 38085fc..cbcb2f3 100644 --- a/engram/mcp_server.py +++ b/engram/mcp_server.py @@ -1,21 +1,26 @@ +"""Engram MCP Server — 8 tools, minimal boilerplate. + +Tools: +1. remember — Quick-save (content → memory, infer=False) +2. search_memory — Semantic search +3. get_memory — Fetch by ID +4. get_all_memories — List with filters +5. engram_context — Session-start digest (top memories) +6. get_last_session — Handoff: load prior session +7. save_session_digest — Handoff: save current session +8. get_memory_stats — Quick health check """ -engram MCP Server for Claude Code integration. -This server exposes engram's core memory capabilities as MCP tools. -Governance, handoff, and active memory tools live in engram-enterprise. -""" - -import importlib import json import logging import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional from mcp.server import Server from mcp.server.stdio import stdio_server from mcp.types import Tool, TextContent -from engram.memory.main import Memory +from engram.memory.main import FullMemory, Memory from engram.configs.base import ( MemoryConfig, VectorStoreConfig, @@ -28,25 +33,19 @@ def _get_embedding_dims_for_model(model: str, provider: str) -> int: - """Get the embedding dimensions for a given model.""" EMBEDDING_DIMS = { "models/text-embedding-005": 768, "text-embedding-005": 768, - "models/text-embedding-004": 768, - "text-embedding-004": 768, "gemini-embedding-001": 3072, "text-embedding-3-small": 1536, "text-embedding-3-large": 3072, "text-embedding-ada-002": 1536, } - env_dims = os.environ.get("FADEM_EMBEDDING_DIMS") if env_dims: return int(env_dims) - if model in EMBEDDING_DIMS: return EMBEDDING_DIMS[model] - if provider == 
"gemini": return 3072 elif provider == "openai": @@ -55,71 +54,73 @@ def _get_embedding_dims_for_model(model: str, provider: str) -> int: def get_memory_instance() -> Memory: - """Create and return a configured Memory instance.""" + """Create and return a configured Memory instance (FullMemory for MCP).""" gemini_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY") openai_key = os.environ.get("OPENAI_API_KEY") if gemini_key: embedder_model = os.environ.get("FADEM_EMBEDDER_MODEL", "gemini-embedding-001") embedding_dims = _get_embedding_dims_for_model(embedder_model, "gemini") - llm_config = LLMConfig( provider="gemini", config={ "model": os.environ.get("FADEM_LLM_MODEL", "gemini-2.0-flash"), - "temperature": 0.1, - "max_tokens": 1024, - "api_key": gemini_key, + "temperature": 0.1, "max_tokens": 1024, "api_key": gemini_key, } ) embedder_config = EmbedderConfig( provider="gemini", - config={ - "model": embedder_model, - "api_key": gemini_key, - } + config={"model": embedder_model, "api_key": gemini_key}, ) elif openai_key: embedder_model = os.environ.get("FADEM_EMBEDDER_MODEL", "text-embedding-3-small") embedding_dims = _get_embedding_dims_for_model(embedder_model, "openai") - llm_config = LLMConfig( provider="openai", config={ "model": os.environ.get("FADEM_LLM_MODEL", "gpt-4o-mini"), - "temperature": 0.1, - "max_tokens": 1024, - "api_key": openai_key, + "temperature": 0.1, "max_tokens": 1024, "api_key": openai_key, } ) embedder_config = EmbedderConfig( provider="openai", - config={ - "model": embedder_model, - "api_key": openai_key, - } + config={"model": embedder_model, "api_key": openai_key}, ) else: - raise RuntimeError( - "No API key found. Set GOOGLE_API_KEY, GEMINI_API_KEY, or OPENAI_API_KEY environment variable." 
+ # Zero-config: SimpleEmbedder + MockLLM + embedding_dims = 384 + llm_config = LLMConfig(provider="mock", config={}) + embedder_config = EmbedderConfig( + provider="simple", config={"embedding_dims": 384}, ) vec_db_path = os.environ.get( "FADEM_VEC_DB_PATH", - os.path.join(os.path.expanduser("~"), ".engram", "sqlite_vec.db") - ) - vector_store_config = VectorStoreConfig( - provider="sqlite_vec", - config={ - "path": vec_db_path, - "collection_name": os.environ.get("FADEM_COLLECTION", "fadem_memories"), - "embedding_model_dims": embedding_dims, - } + os.path.join(os.path.expanduser("~"), ".engram", "sqlite_vec.db"), ) + # Use in-memory vector store for simple embedder (dims mismatch with sqlite_vec) + if embedder_config.provider == "simple": + vector_store_config = VectorStoreConfig( + provider="memory", + config={ + "collection_name": os.environ.get("FADEM_COLLECTION", "fadem_memories"), + "embedding_model_dims": embedding_dims, + }, + ) + else: + vector_store_config = VectorStoreConfig( + provider="sqlite_vec", + config={ + "path": vec_db_path, + "collection_name": os.environ.get("FADEM_COLLECTION", "fadem_memories"), + "embedding_model_dims": embedding_dims, + }, + ) + history_db_path = os.environ.get( "FADEM_HISTORY_DB", - os.path.join(os.path.expanduser("~"), ".engram", "history.db") + os.path.join(os.path.expanduser("~"), ".engram", "history.db"), ) fadem_config = FadeMemConfig( @@ -137,709 +138,161 @@ def get_memory_instance() -> Memory: engram=fadem_config, ) - return Memory(config) + return FullMemory(config) -# Global memory instance (lazy initialized) +# Global memory instance (lazy) _memory: Optional[Memory] = None def get_memory() -> Memory: - """Get or create the global memory instance.""" global _memory if _memory is None: _memory = get_memory_instance() return _memory -# Create the MCP server -server = Server("engram-memory") - -# Cached tool list -_tools_cache: Optional[List[Tool]] = None - +# ── MCP Server ── -# ── Power Package Auto-Discovery ── 
- -_POWER_PACKAGES = [ - ("engram_router", "engram_router.mcp_tools"), - ("engram_identity", "engram_identity.mcp_tools"), - ("engram_heartbeat", "engram_heartbeat.mcp_tools"), - ("engram_policy", "engram_policy.mcp_tools"), - ("engram_skills", "engram_skills.mcp_tools"), - ("engram_spawn", "engram_spawn.mcp_tools"), - ("engram_resilience", "engram_resilience.mcp_tools"), - ("engram_metamemory", "engram_metamemory.mcp_tools"), - ("engram_prospective", "engram_prospective.mcp_tools"), - ("engram_procedural", "engram_procedural.mcp_tools"), - ("engram_reconsolidation", "engram_reconsolidation.mcp_tools"), - ("engram_failure", "engram_failure.mcp_tools"), - ("engram_working", "engram_working.mcp_tools"), - ("engram_warroom", "engram_warroom.mcp_tools"), -] +server = Server("engram-memory") -_POWER_HANDLER_MAP = [ - ("_router_tools", "_router_handlers", "dict"), - ("_identity_tools", "_identity_handlers", "fn"), - ("_heartbeat_tools", "_heartbeat_handler", "fn"), - ("_policy_tools", "_policy_handler", "fn"), - ("_skills_tools", "_skills_handler", "fn"), - ("_spawn_tools", "_spawn_handler", "fn"), - ("_resilience_tools", "_resilience_handler", "fn"), - ("_metamemory_tools", "_metamemory_handler", "fn"), - ("_prospective_tools", "_prospective_handler", "fn"), - ("_procedural_tools", "_procedural_handler", "fn"), - ("_reconsolidation_tools", "_reconsolidation_handler", "fn"), - ("_failure_tools", "_failure_handler", "fn"), - ("_working_tools", "_working_handler", "fn"), - ("_warroom_tools", "_warroom_handler", "fn"), +# Tool definitions — 8 tools total +TOOLS = [ + Tool( + name="remember", + description="Quick-save a fact or preference to memory. 
Creates a staging proposal commit with source_app='claude-code' and infer=False by default.", + inputSchema={ + "type": "object", + "properties": { + "content": {"type": "string", "description": "The fact or preference to remember"}, + "categories": {"type": "array", "items": {"type": "string"}, "description": "Optional categories to tag this memory with (e.g., ['preferences', 'coding'])"}, + }, + "required": ["content"], + }, + ), + Tool( + name="search_memory", + description="Search engram for relevant memories by semantic query. The UserPromptSubmit hook handles background search automatically — call this tool only for explicit user recall requests such as 'what did we discuss about X?' or 'recall my preference for Y'.", + inputSchema={ + "type": "object", + "properties": { + "query": {"type": "string", "description": "The search query - what you're trying to remember"}, + "user_id": {"type": "string", "description": "User identifier (default: 'default')"}, + "agent_id": {"type": "string", "description": "Agent identifier to scope search to (optional)"}, + "limit": {"type": "integer", "description": "Maximum number of results to return (default: 10)"}, + "categories": {"type": "array", "items": {"type": "string"}, "description": "Filter results by categories"}, + }, + "required": ["query"], + }, + ), + Tool( + name="get_memory", + description="Retrieve a single memory by its ID. Use this only when you already have a memory_id from a prior search or listing. Do not use for discovery — use search_memory instead.", + inputSchema={ + "type": "object", + "properties": { + "memory_id": {"type": "string", "description": "The ID of the memory to retrieve"}, + }, + "required": ["memory_id"], + }, + ), + Tool( + name="get_all_memories", + description="Get all stored memories for a user — use for inventory, audit, or when the user wants a full listing. 
Not for finding specific information; use search_memory for that.", + inputSchema={ + "type": "object", + "properties": { + "user_id": {"type": "string", "description": "User identifier (default: 'default')"}, + "agent_id": {"type": "string", "description": "Agent identifier (optional)"}, + "limit": {"type": "integer", "description": "Maximum number of memories to return (default: 50)"}, + "layer": {"type": "string", "enum": ["sml", "lml"], "description": "Filter by memory layer: 'sml' (short-term) or 'lml' (long-term)"}, + }, + }, + ), + Tool( + name="engram_context", + description="Session-start digest. Call once at the beginning of a new conversation to load context from prior sessions. Returns top memories sorted by strength with long-term memories (LML) first.", + inputSchema={ + "type": "object", + "properties": { + "user_id": {"type": "string", "description": "User identifier to load context for (default: 'default')"}, + "limit": {"type": "integer", "description": "Maximum number of memories to return in the digest (default: 15)"}, + }, + }, + ), + Tool( + name="get_last_session", + description="Get the most recent session digest to continue where the last agent left off. Returns full handoff context including linked memories.", + inputSchema={ + "type": "object", + "properties": { + "user_id": {"type": "string", "description": "User identifier (default: 'default')"}, + "requester_agent_id": {"type": "string", "description": "Agent identity performing this read."}, + "repo": {"type": "string", "description": "Filter by repository/project path"}, + "agent_id": {"type": "string", "description": "Filter by source agent identifier"}, + "fallback_log_recovery": {"type": "boolean", "default": True, "description": "When true and no DB session found, attempt to reconstruct context from Claude Code conversation logs. Default: true."}, + }, + }, + ), + Tool( + name="save_session_digest", + description="Save a session digest before ending or when interrupted. 
Enables cross-agent handoff so another agent can continue where you left off.", + inputSchema={ + "type": "object", + "properties": { + "task_summary": {"type": "string", "description": "What was the agent doing — the main task being worked on"}, + "repo": {"type": "string", "description": "Repository or project path for scoping"}, + "status": {"type": "string", "enum": ["active", "paused", "completed", "abandoned"], "description": "Session status (default: 'paused')"}, + "decisions_made": {"type": "array", "items": {"type": "string"}, "description": "Key decisions made during the session"}, + "files_touched": {"type": "array", "items": {"type": "string"}, "description": "File paths modified during the session"}, + "todos_remaining": {"type": "array", "items": {"type": "string"}, "description": "Remaining work items for the next agent"}, + "blockers": {"type": "array", "items": {"type": "string"}, "description": "Known blockers for the receiving agent"}, + "key_commands": {"type": "array", "items": {"type": "string"}, "description": "Important commands run during the session"}, + "test_results": {"type": "array", "items": {"type": "string"}, "description": "Recent test outcomes"}, + "agent_id": {"type": "string", "description": "Identifier of the agent saving the digest (default: 'claude-code')"}, + "requester_agent_id": {"type": "string", "description": "Agent identity performing this write (defaults to agent_id)."}, + }, + "required": ["task_summary"], + }, + ), + Tool( + name="get_memory_stats", + description="Get statistics about the memory store including counts and layer distribution. 
Call when the user asks about memory health, wants an overview of what's stored, or runs /engram:status.", + inputSchema={ + "type": "object", + "properties": { + "user_id": {"type": "string", "description": "User identifier to get stats for (default: all users)"}, + "agent_id": {"type": "string", "description": "Agent identifier to scope stats to (optional)"}, + }, + }, + ), ] -_power_tool_handlers: Dict[str, Callable] = {} -_power_discovered = False - - -def _discover_power_tools(srv: Server, memory: "Memory") -> None: - """Auto-discover and register MCP tools from installed power packages.""" - global _power_discovered - if _power_discovered: - return - _power_discovered = True - - for pkg_name, module_path in _POWER_PACKAGES: - try: - mod = importlib.import_module(module_path) - mod.register_tools(srv, memory) - logger.info("Loaded MCP tools from %s", pkg_name) - except ImportError: - pass # package not installed — skip - except Exception as e: - logger.warning("Failed to load MCP tools from %s: %s", pkg_name, e) - - # Consolidate all handlers into a single dispatch dict - for tools_attr, handler_attr, handler_type in _POWER_HANDLER_MAP: - tool_defs = getattr(srv, tools_attr, None) - handler = getattr(srv, handler_attr, None) - if not tool_defs or not handler: - continue - if handler_type == "dict": - for name in tool_defs: - if name in handler: - _power_tool_handlers[name] = handler[name] - else: - for name in tool_defs: - _power_tool_handlers[name] = (lambda h, n: lambda args: h(n, args))(handler, name) - - -def _get_power_tool_defs() -> Dict[str, dict]: - """Collect all power tool definitions from server attributes.""" - defs = {} - for tools_attr, _, _ in _POWER_HANDLER_MAP: - tool_dict = getattr(server, tools_attr, None) - if tool_dict: - defs.update(tool_dict) - return defs - - -@server.list_tools() -async def list_tools() -> List[Tool]: - """List available engram tools.""" - global _tools_cache - if _tools_cache is not None: - return list(_tools_cache) - 
- # Auto-discover power packages - try: - memory = get_memory() - _discover_power_tools(server, memory) - except Exception as e: - logger.warning("Power tool discovery failed: %s", e) - - tools = [ - Tool( - name="add_memory", - description="Store a memory. Extracts key information from the content and saves it with semantic embedding for later retrieval.", - inputSchema={ - "type": "object", - "properties": { - "content": { - "type": "string", - "description": "The memory content to store." - }, - "user_id": { - "type": "string", - "description": "User identifier (default: 'default')" - }, - "agent_id": { - "type": "string", - "description": "Agent identifier (optional)" - }, - "categories": { - "type": "array", - "items": {"type": "string"}, - "description": "Optional categories to tag this memory with" - }, - "metadata": { - "type": "object", - "description": "Optional metadata to attach" - }, - "scope": { - "type": "string", - "description": "Confidentiality scope: work|personal|finance|health|private" - }, - }, - "required": ["content"] - } - ), - Tool( - name="search_memory", - description="Search for relevant memories by semantic query.", - inputSchema={ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "The search query" - }, - "user_id": { - "type": "string", - "description": "User identifier (default: 'default')" - }, - "agent_id": { - "type": "string", - "description": "Agent identifier (optional)" - }, - "limit": { - "type": "integer", - "description": "Maximum results (default: 10)" - }, - "categories": { - "type": "array", - "items": {"type": "string"}, - "description": "Filter by categories" - }, - }, - "required": ["query"] - } - ), - Tool( - name="get_all_memories", - description="Get all stored memories for a user.", - inputSchema={ - "type": "object", - "properties": { - "user_id": { - "type": "string", - "description": "User identifier (default: 'default')" - }, - "agent_id": { - "type": "string", - 
"description": "Agent identifier (optional)" - }, - "limit": { - "type": "integer", - "description": "Maximum memories to return (default: 50)" - }, - "layer": { - "type": "string", - "enum": ["sml", "lml"], - "description": "Filter by memory layer" - } - } - } - ), - Tool( - name="get_memory", - description="Retrieve a single memory by its ID.", - inputSchema={ - "type": "object", - "properties": { - "memory_id": { - "type": "string", - "description": "The ID of the memory to retrieve" - } - }, - "required": ["memory_id"] - } - ), - Tool( - name="update_memory", - description="Update an existing memory's content.", - inputSchema={ - "type": "object", - "properties": { - "memory_id": { - "type": "string", - "description": "The ID of the memory to update" - }, - "content": { - "type": "string", - "description": "The new content" - } - }, - "required": ["memory_id", "content"] - } - ), - Tool( - name="delete_memory", - description="Permanently delete a memory by its ID.", - inputSchema={ - "type": "object", - "properties": { - "memory_id": { - "type": "string", - "description": "The ID of the memory to delete" - } - }, - "required": ["memory_id"] - } - ), - Tool( - name="get_memory_stats", - description="Get statistics about the memory store.", - inputSchema={ - "type": "object", - "properties": { - "user_id": {"type": "string"}, - "agent_id": {"type": "string"}, - } - } - ), - Tool( - name="apply_memory_decay", - description="Apply the memory-decay algorithm to simulate natural forgetting.", - inputSchema={ - "type": "object", - "properties": { - "user_id": {"type": "string"}, - "agent_id": {"type": "string"}, - } - } - ), - Tool( - name="engram_context", - description="Session-start digest. 
Returns top memories sorted by strength with LML first.", - inputSchema={ - "type": "object", - "properties": { - "user_id": {"type": "string"}, - "limit": {"type": "integer"}, - } - } - ), - Tool( - name="remember", - description="Quick-save a fact or preference to memory.", - inputSchema={ - "type": "object", - "properties": { - "content": { - "type": "string", - "description": "The fact or preference to remember" - }, - "categories": { - "type": "array", - "items": {"type": "string"}, - } - }, - "required": ["content"] - } - ), - # ---- Episodic Scene tools ---- - Tool( - name="get_scene", - description="Get a specific episodic scene by ID.", - inputSchema={ - "type": "object", - "properties": { - "scene_id": {"type": "string"}, - "user_id": {"type": "string"}, - }, - "required": ["scene_id"] - } - ), - Tool( - name="list_scenes", - description="List episodic scenes chronologically.", - inputSchema={ - "type": "object", - "properties": { - "user_id": {"type": "string"}, - "topic": {"type": "string"}, - "start_after": {"type": "string"}, - "start_before": {"type": "string"}, - "limit": {"type": "integer"}, - } - } - ), - Tool( - name="search_scenes", - description="Semantic search over episodic scene summaries.", - inputSchema={ - "type": "object", - "properties": { - "query": {"type": "string"}, - "user_id": {"type": "string"}, - "limit": {"type": "integer"}, - }, - "required": ["query"] - } - ), - # ---- Character Profile tools ---- - Tool( - name="get_profile", - description="Get a character profile by ID.", - inputSchema={ - "type": "object", - "properties": { - "profile_id": {"type": "string"}, - }, - "required": ["profile_id"] - } - ), - Tool( - name="list_profiles", - description="List all character profiles for a user.", - inputSchema={ - "type": "object", - "properties": { - "user_id": {"type": "string"}, - } - } - ), - Tool( - name="search_profiles", - description="Search character profiles by name or description.", - inputSchema={ - "type": "object", - 
"properties": { - "query": {"type": "string"}, - "user_id": {"type": "string"}, - "limit": {"type": "integer"}, - }, - "required": ["query"] - } - ), - # ---- Handoff / Session Continuity tools ---- - Tool( - name="get_last_session", - description="Load previous session context for continuity. Returns task summary, decisions, files touched, and remaining TODOs from the last session. Falls back to parsing Claude Code conversation logs if no stored session exists.", - inputSchema={ - "type": "object", - "properties": { - "user_id": {"type": "string", "description": "User identifier (default: 'default')"}, - "requester_agent_id": {"type": "string", "description": "ID of the agent requesting the session"}, - "repo": {"type": "string", "description": "Absolute path to the repo root (enables log fallback)"}, - "agent_id": {"type": "string", "description": "Source agent whose session to load (default: 'mcp-server')"}, - "fallback_log_recovery": {"type": "boolean", "default": True, "description": "Fall back to JSONL log parsing if no bus session exists"}, - }, - } - ), - Tool( - name="save_session_digest", - description="Save session context for the next agent. 
Call on milestones and before pausing/ending.", - inputSchema={ - "type": "object", - "properties": { - "task_summary": {"type": "string", "description": "Summary of the current task"}, - "repo": {"type": "string", "description": "Absolute path to the repo root"}, - "status": {"type": "string", "enum": ["active", "paused", "completed"], "description": "Session status"}, - "decisions_made": {"type": "array", "items": {"type": "string"}, "description": "Key decisions made during the session"}, - "files_touched": {"type": "array", "items": {"type": "string"}, "description": "Files read, edited, or created"}, - "todos_remaining": {"type": "array", "items": {"type": "string"}, "description": "Outstanding work items"}, - "blockers": {"type": "array", "items": {"type": "string"}, "description": "Known blockers or issues"}, - "key_commands": {"type": "array", "items": {"type": "string"}, "description": "Important commands run"}, - "test_results": {"type": "string", "description": "Summary of test outcomes"}, - "agent_id": {"type": "string", "description": "Agent saving the session (default: 'claude-code')"}, - "requester_agent_id": {"type": "string", "description": "ID of the agent making the request"}, - }, - "required": ["task_summary"], - } - ), - # ---- Task tools ---- - Tool( - name="create_task", - description="Create a task (with dedup — returns existing if title matches an active task).", - inputSchema={ - "type": "object", - "properties": { - "title": {"type": "string", "description": "Task title (used for dedup)"}, - "description": {"type": "string", "description": "Detailed description"}, - "priority": {"type": "string", "enum": ["low", "normal", "high", "urgent"]}, - "status": {"type": "string", "enum": ["inbox", "assigned", "active", "review", "blocked"]}, - "assigned_agent": {"type": "string", "description": "Agent to assign"}, - "due_date": {"type": "string", "description": "ISO date string"}, - "tags": {"type": "array", "items": {"type": "string"}}, - 
"user_id": {"type": "string"}, - "metadata": {"type": "object", "description": "Arbitrary user-defined attributes"}, - }, - "required": ["title"], - } - ), - Tool( - name="list_tasks", - description="List tasks with optional filters (status, priority, assignee).", - inputSchema={ - "type": "object", - "properties": { - "user_id": {"type": "string"}, - "status": {"type": "string", "enum": ["inbox", "assigned", "active", "review", "blocked", "done", "archived"]}, - "priority": {"type": "string", "enum": ["low", "normal", "high", "urgent"]}, - "assigned_agent": {"type": "string"}, - "limit": {"type": "integer"}, - }, - } - ), - Tool( - name="get_task", - description="Get full task details by memory ID.", - inputSchema={ - "type": "object", - "properties": { - "task_id": {"type": "string", "description": "Task memory ID"}, - }, - "required": ["task_id"], - } - ), - Tool( - name="update_task", - description="Update task fields (status, priority, assignee, title, description, tags, due_date, or custom metadata).", - inputSchema={ - "type": "object", - "properties": { - "task_id": {"type": "string", "description": "Task memory ID"}, - "status": {"type": "string", "enum": ["inbox", "assigned", "active", "review", "blocked", "done", "archived"]}, - "priority": {"type": "string", "enum": ["low", "normal", "high", "urgent"]}, - "assigned_agent": {"type": "string"}, - "title": {"type": "string"}, - "description": {"type": "string"}, - "due_date": {"type": "string"}, - "tags": {"type": "array", "items": {"type": "string"}}, - }, - "required": ["task_id"], - } - ), - Tool( - name="complete_task", - description="Mark a task as done (shorthand for update_task with status=done).", - inputSchema={ - "type": "object", - "properties": { - "task_id": {"type": "string", "description": "Task memory ID"}, - }, - "required": ["task_id"], - } - ), - Tool( - name="add_task_comment", - description="Add a comment to a task.", - inputSchema={ - "type": "object", - "properties": { - "task_id": 
{"type": "string", "description": "Task memory ID"}, - "text": {"type": "string", "description": "Comment text"}, - "agent": {"type": "string", "description": "Agent adding the comment"}, - }, - "required": ["task_id", "text"], - } - ), - Tool( - name="search_tasks", - description="Semantic search over tasks.", - inputSchema={ - "type": "object", - "properties": { - "query": {"type": "string", "description": "Search query"}, - "user_id": {"type": "string"}, - "limit": {"type": "integer"}, - }, - "required": ["query"], - } - ), - Tool( - name="get_pending_tasks", - description="Get actionable tasks (not done/archived).", - inputSchema={ - "type": "object", - "properties": { - "user_id": {"type": "string"}, - "assigned_agent": {"type": "string"}, - }, - } - ), - # ---- Salience tools ---- - Tool( - name="tag_salience", - description="Compute and tag a memory's emotional salience (valence + arousal).", - inputSchema={ - "type": "object", - "properties": { - "memory_id": {"type": "string", "description": "Memory ID to tag"}, - "use_llm": {"type": "boolean", "description": "Use LLM for more accurate scoring", "default": False}, - }, - "required": ["memory_id"], - } - ), - Tool( - name="search_by_salience", - description="Find high-salience (emotionally significant) memories.", - inputSchema={ - "type": "object", - "properties": { - "user_id": {"type": "string"}, - "min_salience": {"type": "number", "description": "Minimum salience score (0-1)", "default": 0.3}, - "limit": {"type": "integer", "default": 20}, - }, - } - ), - Tool( - name="get_salience_stats", - description="Get statistics on salience tagging across memories.", - inputSchema={ - "type": "object", - "properties": {"user_id": {"type": "string"}}, - } - ), - # ---- Causal tools ---- - Tool( - name="add_causal_link", - description="Add a causal relationship between two memories (caused_by, led_to, prevents, enables, requires).", - inputSchema={ - "type": "object", - "properties": { - "source_id": {"type": 
"string", "description": "Source memory ID"}, - "target_id": {"type": "string", "description": "Target memory ID"}, - "relation_type": { - "type": "string", - "enum": ["caused_by", "led_to", "prevents", "enables", "requires"], - "description": "Type of causal relationship", - }, - }, - "required": ["source_id", "target_id", "relation_type"], - } - ), - Tool( - name="get_causal_chain", - description="Traverse causal links from a memory (what caused it, or what it caused).", - inputSchema={ - "type": "object", - "properties": { - "memory_id": {"type": "string", "description": "Starting memory ID"}, - "direction": { - "type": "string", - "enum": ["backward", "forward"], - "description": "backward=what caused this, forward=what this caused", - "default": "backward", - }, - "depth": {"type": "integer", "description": "Max traversal depth", "default": 5}, - }, - "required": ["memory_id"], - } - ), - Tool( - name="query_causes", - description="Get both causes and effects for a memory.", - inputSchema={ - "type": "object", - "properties": { - "memory_id": {"type": "string", "description": "Memory ID to query"}, - "depth": {"type": "integer", "default": 3}, - }, - "required": ["memory_id"], - } - ), - # ---- AGI Loop tools ---- - Tool( - name="get_agi_status", - description="Get the status of all AGI cognitive subsystems.", - inputSchema={ - "type": "object", - "properties": {"user_id": {"type": "string"}}, - } - ), - Tool( - name="run_agi_cycle", - description="Run one iteration of the full AGI cognitive cycle (consolidate, decay, reconsolidate, etc).", - inputSchema={ - "type": "object", - "properties": { - "user_id": {"type": "string"}, - "context": {"type": "string", "description": "Current context for reconsolidation"}, - }, - } - ), - Tool( - name="get_system_health", - description="Report health status across all cognitive subsystems.", - inputSchema={ - "type": "object", - "properties": {"user_id": {"type": "string"}}, - } - ), - ] - # Append tools from installed 
power packages - for name, defn in _get_power_tool_defs().items(): - tools.append(Tool( - name=name, - description=defn["description"], - inputSchema=defn["inputSchema"], - )) - - _tools_cache = tools - return list(tools) - - -# Tool handler registry -_TOOL_HANDLERS: Dict[str, Callable] = {} - - -def _tool_handler(name: str): - def decorator(fn): - _TOOL_HANDLERS[name] = fn - return fn - return decorator +# ── Tool Handlers ── -@_tool_handler("add_memory") -def _handle_add_memory(memory: "Memory", arguments: Dict[str, Any]) -> Any: - content = arguments.get("content", "") - user_id = arguments.get("user_id", "default") +def _handle_remember(memory, args): return memory.add( - messages=content, - user_id=user_id, - agent_id=arguments.get("agent_id"), - categories=arguments.get("categories"), - metadata=arguments.get("metadata"), - scope=arguments.get("scope", "work"), - source_app="mcp", - infer=False, - ) - - -@_tool_handler("remember") -def _handle_remember(memory: "Memory", arguments: Dict[str, Any]) -> Any: - return memory.add( - messages=arguments.get("content", ""), + messages=args.get("content", ""), user_id="default", agent_id="claude-code", - categories=arguments.get("categories"), + categories=args.get("categories"), source_app="claude-code", infer=False, ) -@_tool_handler("search_memory") -def _handle_search_memory(memory: "Memory", arguments: Dict[str, Any]) -> Any: +def _handle_search_memory(memory, args): try: - limit = max(1, min(1000, int(arguments.get("limit", 10)))) + limit = max(1, min(1000, int(args.get("limit", 10)))) except (ValueError, TypeError): limit = 10 result = memory.search( - query=arguments.get("query", ""), - user_id=arguments.get("user_id", "default"), - agent_id=arguments.get("agent_id"), + query=args.get("query", ""), + user_id=args.get("user_id", "default"), + agent_id=args.get("agent_id"), limit=limit, - categories=arguments.get("categories"), + categories=args.get("categories"), ) if "results" in result: result["results"] = [ 
@@ -855,17 +308,31 @@ def _handle_search_memory(memory: "Memory", arguments: Dict[str, Any]) -> Any: return result -@_tool_handler("get_all_memories") -def _handle_get_all_memories(memory: "Memory", arguments: Dict[str, Any]) -> Any: +def _handle_get_memory(memory, args): + result = memory.get(args.get("memory_id", "")) + if result: + return { + "id": result["id"], + "memory": result["memory"], + "layer": result.get("layer", "sml"), + "strength": round(result.get("strength", 1.0), 3), + "categories": result.get("categories", []), + "created_at": result.get("created_at"), + "access_count": result.get("access_count", 0), + } + return {"error": "Memory not found"} + + +def _handle_get_all_memories(memory, args): try: - limit = max(1, min(1000, int(arguments.get("limit", 50)))) + limit = max(1, min(1000, int(args.get("limit", 50)))) except (ValueError, TypeError): limit = 50 result = memory.get_all( - user_id=arguments.get("user_id", "default"), - agent_id=arguments.get("agent_id"), + user_id=args.get("user_id", "default"), + agent_id=args.get("agent_id"), limit=limit, - layer=arguments.get("layer"), + layer=args.get("layer"), ) if "results" in result: result["results"] = [ @@ -881,53 +348,10 @@ def _handle_get_all_memories(memory: "Memory", arguments: Dict[str, Any]) -> Any return result -@_tool_handler("get_memory") -def _handle_get_memory(memory: "Memory", arguments: Dict[str, Any]) -> Any: - result = memory.get(arguments.get("memory_id", "")) - if result: - return { - "id": result["id"], - "memory": result["memory"], - "layer": result.get("layer", "sml"), - "strength": round(result.get("strength", 1.0), 3), - "categories": result.get("categories", []), - "created_at": result.get("created_at"), - "access_count": result.get("access_count", 0), - } - return {"error": "Memory not found"} - - -@_tool_handler("update_memory") -def _handle_update_memory(memory: "Memory", arguments: Dict[str, Any]) -> Any: - return memory.update(arguments.get("memory_id", ""), 
arguments.get("content", "")) - - -@_tool_handler("delete_memory") -def _handle_delete_memory(memory: "Memory", arguments: Dict[str, Any]) -> Any: - return memory.delete(arguments.get("memory_id", "")) - - -@_tool_handler("get_memory_stats") -def _handle_get_memory_stats(memory: "Memory", arguments: Dict[str, Any]) -> Any: - return memory.get_stats( - user_id=arguments.get("user_id"), - agent_id=arguments.get("agent_id"), - ) - - -@_tool_handler("apply_memory_decay") -def _handle_apply_memory_decay(memory: "Memory", arguments: Dict[str, Any]) -> Any: - user_id = arguments.get("user_id") - agent_id = arguments.get("agent_id") - scope = {"user_id": user_id, "agent_id": agent_id} if user_id or agent_id else None - return memory.apply_decay(scope=scope) - - -@_tool_handler("engram_context") -def _handle_engram_context(memory: "Memory", arguments: Dict[str, Any]) -> Any: - user_id = arguments.get("user_id", "default") +def _handle_engram_context(memory, args): + user_id = args.get("user_id", "default") try: - limit = max(1, min(100, int(arguments.get("limit", 15)))) + limit = max(1, min(100, int(args.get("limit", 15)))) except (ValueError, TypeError): limit = 15 all_result = memory.get_all(user_id=user_id, limit=limit * 3) @@ -935,7 +359,7 @@ def _handle_engram_context(memory: "Memory", arguments: Dict[str, Any]) -> Any: layer_order = {"lml": 0, "sml": 1} all_memories.sort(key=lambda m: ( layer_order.get(m.get("layer", "sml"), 1), - -float(m.get("strength", 1.0)) + -float(m.get("strength", 1.0)), )) digest = [ { @@ -947,445 +371,80 @@ def _handle_engram_context(memory: "Memory", arguments: Dict[str, Any]) -> Any: } for m in all_memories[:limit] ] - # Surface pending tasks in context digest - try: - tm = _get_task_manager(memory) - pending = tm.get_pending_tasks(user_id=user_id) - pending_summary = [ - {"id": t["id"], "title": t["title"], "status": t["status"], "priority": t["priority"]} - for t in pending[:5] - ] - except Exception: - pending = [] - pending_summary = [] 
- return { "digest": digest, "total_in_store": len(all_memories), "returned": len(digest), - "pending_tasks": pending_summary, - "pending_task_count": len(pending), - } - - -@_tool_handler("get_scene") -def _handle_get_scene(memory: "Memory", arguments: Dict[str, Any]) -> Any: - scene_id = arguments.get("scene_id", "") - scene = memory.db.get_scene(scene_id) - return scene if scene else {"error": "Scene not found"} - - -@_tool_handler("list_scenes") -def _handle_list_scenes(memory: "Memory", arguments: Dict[str, Any]) -> Any: - try: - scene_limit = max(1, min(200, int(arguments.get("limit", 20)))) - except (ValueError, TypeError): - scene_limit = 20 - scenes = memory.get_scenes( - user_id=arguments.get("user_id", "default"), - topic=arguments.get("topic"), - start_after=arguments.get("start_after"), - start_before=arguments.get("start_before"), - limit=scene_limit, - ) - return { - "scenes": [ - { - "id": s["id"], - "title": s.get("title"), - "topic": s.get("topic"), - "summary": s.get("summary"), - "start_time": s.get("start_time"), - "end_time": s.get("end_time"), - "memory_count": len(s.get("memory_ids", [])), - } - for s in scenes - ], - "total": len(scenes), - } - - -@_tool_handler("search_scenes") -def _handle_search_scenes(memory: "Memory", arguments: Dict[str, Any]) -> Any: - try: - limit = max(1, min(100, int(arguments.get("limit", 10)))) - except (ValueError, TypeError): - limit = 10 - scenes = memory.search_scenes( - query=arguments.get("query", ""), - user_id=arguments.get("user_id", "default"), - limit=limit, - ) - return { - "scenes": [ - { - "id": s.get("id"), - "title": s.get("title"), - "summary": s.get("summary"), - "topic": s.get("topic"), - "start_time": s.get("start_time"), - "memory_count": len(s.get("memory_ids", [])), - } - for s in scenes - ], - "total": len(scenes), } -@_tool_handler("get_profile") -def _handle_get_profile(memory: "Memory", arguments: Dict[str, Any]) -> Any: - profile = memory.get_profile(arguments.get("profile_id", "")) - 
if profile: - profile.pop("embedding", None) - return profile - return {"error": "Profile not found"} - - -@_tool_handler("list_profiles") -def _handle_list_profiles(memory: "Memory", arguments: Dict[str, Any]) -> Any: - profiles = memory.get_all_profiles(user_id=arguments.get("user_id", "default")) - return { - "profiles": [ - { - "id": p["id"], - "name": p.get("name"), - "profile_type": p.get("profile_type"), - "narrative": p.get("narrative"), - "fact_count": len(p.get("facts", [])), - "preference_count": len(p.get("preferences", [])), - } - for p in profiles - ], - "total": len(profiles), - } - - -@_tool_handler("search_profiles") -def _handle_search_profiles(memory: "Memory", arguments: Dict[str, Any]) -> Any: - try: - limit = max(1, min(100, int(arguments.get("limit", 10)))) - except (ValueError, TypeError): - limit = 10 - profiles = memory.search_profiles( - query=arguments.get("query", ""), - user_id=arguments.get("user_id", "default"), - limit=limit, - ) - return { - "profiles": [ - { - "id": p["id"], - "name": p.get("name"), - "profile_type": p.get("profile_type"), - "narrative": p.get("narrative"), - "facts": p.get("facts", [])[:5], - "search_score": p.get("search_score"), - } - for p in profiles - ], - "total": len(profiles), - } - - -@_tool_handler("get_last_session") -def _handle_get_last_session(memory: "Memory", arguments: Dict[str, Any]) -> Any: +def _handle_get_last_session(_memory, args): from engram.core.kernel import get_last_session - agent_id = arguments.get("agent_id", "mcp-server") - repo = arguments.get("repo") - fallback = arguments.get("fallback_log_recovery", True) session = get_last_session( - agent_id=agent_id, - repo=repo, - fallback_log_recovery=fallback, + agent_id=args.get("agent_id", "mcp-server"), + repo=args.get("repo"), + fallback_log_recovery=args.get("fallback_log_recovery", True), ) if session is None: return {"status": "no_session", "message": "No previous session found."} return session 
-@_tool_handler("save_session_digest") -def _handle_save_session_digest(memory: "Memory", arguments: Dict[str, Any]) -> Any: +def _handle_save_session_digest(_memory, args): from engram.core.kernel import save_session_digest return save_session_digest( - task_summary=arguments.get("task_summary", ""), - agent_id=arguments.get("agent_id", "claude-code"), - repo=arguments.get("repo"), - status=arguments.get("status", "active"), - decisions_made=arguments.get("decisions_made"), - files_touched=arguments.get("files_touched"), - todos_remaining=arguments.get("todos_remaining"), - blockers=arguments.get("blockers"), - key_commands=arguments.get("key_commands"), - test_results=arguments.get("test_results"), - ) - - -# ---- Task tool handlers ---- - -def _get_task_manager(memory: "Memory"): - from engram.memory.tasks import TaskManager - return TaskManager(memory) - - -@_tool_handler("create_task") -def _handle_create_task(memory: "Memory", arguments: Dict[str, Any]) -> Any: - tm = _get_task_manager(memory) - return tm.create_task( - title=arguments.get("title", ""), - description=arguments.get("description", ""), - priority=arguments.get("priority"), - status=arguments.get("status", "inbox"), - assignee=arguments.get("assigned_agent"), - due_date=arguments.get("due_date"), - tags=arguments.get("tags"), - user_id=arguments.get("user_id", "default"), - extra_metadata=arguments.get("metadata"), - ) - - -@_tool_handler("list_tasks") -def _handle_list_tasks(memory: "Memory", arguments: Dict[str, Any]) -> Any: - tm = _get_task_manager(memory) - try: - limit = max(1, min(500, int(arguments.get("limit", 50)))) - except (ValueError, TypeError): - limit = 50 - tasks = tm.list_tasks( - user_id=arguments.get("user_id", "default"), - status=arguments.get("status"), - priority=arguments.get("priority"), - assignee=arguments.get("assigned_agent"), - limit=limit, - ) - return {"tasks": tasks, "total": len(tasks)} - - -@_tool_handler("get_task") -def _handle_get_task(memory: "Memory", 
arguments: Dict[str, Any]) -> Any: - tm = _get_task_manager(memory) - task = tm.get_task(arguments.get("task_id", "")) - return task if task else {"error": "Task not found"} - - -@_tool_handler("update_task") -def _handle_update_task(memory: "Memory", arguments: Dict[str, Any]) -> Any: - tm = _get_task_manager(memory) - task_id = arguments.pop("task_id", "") - updates = {k: v for k, v in arguments.items() if v is not None} - result = tm.update_task(task_id, updates) - return result if result else {"error": "Task not found"} - - -@_tool_handler("complete_task") -def _handle_complete_task(memory: "Memory", arguments: Dict[str, Any]) -> Any: - tm = _get_task_manager(memory) - result = tm.complete_task(arguments.get("task_id", "")) - return result if result else {"error": "Task not found"} - - -@_tool_handler("add_task_comment") -def _handle_add_task_comment(memory: "Memory", arguments: Dict[str, Any]) -> Any: - tm = _get_task_manager(memory) - result = tm.add_comment( - task_id=arguments.get("task_id", ""), - agent=arguments.get("agent", "unknown"), - text=arguments.get("text", ""), - ) - return result if result else {"error": "Task not found"} - - -@_tool_handler("search_tasks") -def _handle_search_tasks(memory: "Memory", arguments: Dict[str, Any]) -> Any: - tm = _get_task_manager(memory) - try: - limit = max(1, min(100, int(arguments.get("limit", 10)))) - except (ValueError, TypeError): - limit = 10 - tasks = tm.search_tasks( - query=arguments.get("query", ""), - user_id=arguments.get("user_id", "default"), - limit=limit, + task_summary=args.get("task_summary", ""), + agent_id=args.get("agent_id", "claude-code"), + repo=args.get("repo"), + status=args.get("status", "active"), + decisions_made=args.get("decisions_made"), + files_touched=args.get("files_touched"), + todos_remaining=args.get("todos_remaining"), + blockers=args.get("blockers"), + key_commands=args.get("key_commands"), + test_results=args.get("test_results"), ) - return {"tasks": tasks, "total": 
len(tasks)} -@_tool_handler("get_pending_tasks") -def _handle_get_pending_tasks(memory: "Memory", arguments: Dict[str, Any]) -> Any: - tm = _get_task_manager(memory) - tasks = tm.get_pending_tasks( - user_id=arguments.get("user_id", "default"), - assignee=arguments.get("assigned_agent"), +def _handle_get_memory_stats(memory, args): + return memory.get_stats( + user_id=args.get("user_id"), + agent_id=args.get("agent_id"), ) - return {"tasks": tasks, "total": len(tasks)} - - -# ---- Salience tools ---- - -@_tool_handler("tag_salience") -def _handle_tag_salience(memory: "Memory", arguments: Dict[str, Any]) -> Any: - from engram.core.salience import compute_salience - memory_id = arguments.get("memory_id", "") - mem = memory.get(memory_id) - if not mem: - return {"error": "Memory not found"} - content = mem.get("memory", "") - salience = compute_salience(content, llm=getattr(memory, "llm", None), - use_llm=arguments.get("use_llm", False)) - md = mem.get("metadata", {}) or {} - md.update(salience) - memory.update(memory_id, {"metadata": md}) - return {"memory_id": memory_id, **salience} -@_tool_handler("search_by_salience") -def _handle_search_by_salience(memory: "Memory", arguments: Dict[str, Any]) -> Any: - user_id = arguments.get("user_id", "default") - min_salience = float(arguments.get("min_salience", 0.3)) - try: - limit = max(1, min(100, int(arguments.get("limit", 20)))) - except (ValueError, TypeError): - limit = 20 - all_mem = memory.get_all(user_id=user_id, limit=limit * 3) - items = all_mem.get("results", []) - results = [] - for m in items: - md = m.get("metadata", {}) or {} - score = md.get("sal_salience_score", 0.0) - if score >= min_salience: - results.append({ - "id": m["id"], - "memory": m.get("memory", ""), - "salience_score": score, - "valence": md.get("sal_valence", 0.0), - "arousal": md.get("sal_arousal", 0.0), - }) - results.sort(key=lambda x: x["salience_score"], reverse=True) - return {"results": results[:limit], "total": len(results)} - - 
-@_tool_handler("get_salience_stats") -def _handle_get_salience_stats(memory: "Memory", arguments: Dict[str, Any]) -> Any: - user_id = arguments.get("user_id", "default") - all_mem = memory.get_all(user_id=user_id, limit=500) - items = all_mem.get("results", []) - tagged = 0 - total_salience = 0.0 - high_salience = 0 - for m in items: - md = m.get("metadata", {}) or {} - score = md.get("sal_salience_score") - if score is not None: - tagged += 1 - total_salience += score - if score >= 0.5: - high_salience += 1 - return { - "total_memories": len(items), - "salience_tagged": tagged, - "avg_salience": round(total_salience / tagged, 3) if tagged else 0.0, - "high_salience_count": high_salience, - } - +HANDLERS = { + "remember": _handle_remember, + "search_memory": _handle_search_memory, + "get_memory": _handle_get_memory, + "get_all_memories": _handle_get_all_memories, + "engram_context": _handle_engram_context, + "get_last_session": _handle_get_last_session, + "save_session_digest": _handle_save_session_digest, + "get_memory_stats": _handle_get_memory_stats, +} -# ---- Causal tools ---- - -@_tool_handler("add_causal_link") -def _handle_add_causal_link(memory: "Memory", arguments: Dict[str, Any]) -> Any: - from engram.core.graph import RelationType - source_id = arguments.get("source_id", "") - target_id = arguments.get("target_id", "") - rel_type = arguments.get("relation_type", "caused_by") - if not hasattr(memory, "knowledge_graph") or not memory.knowledge_graph: - return {"error": "Knowledge graph not available"} - try: - rt = RelationType(rel_type) - except ValueError: - return {"error": f"Invalid relation type: {rel_type}"} - rel = memory.knowledge_graph.add_relationship(source_id, target_id, rt) - return rel.to_dict() - - -@_tool_handler("get_causal_chain") -def _handle_get_causal_chain(memory: "Memory", arguments: Dict[str, Any]) -> Any: - memory_id = arguments.get("memory_id", "") - direction = arguments.get("direction", "backward") - depth = min(10, max(1, 
int(arguments.get("depth", 5)))) - if not hasattr(memory, "knowledge_graph") or not memory.knowledge_graph: - return {"error": "Knowledge graph not available"} - chain = memory.knowledge_graph.get_causal_chain(memory_id, direction, depth) - return { - "memory_id": memory_id, - "direction": direction, - "chain": [ - {"memory_id": mid, "depth": d, "path": [r.to_dict() for r in path]} - for mid, d, path in chain - ], - "length": len(chain), - } - - -@_tool_handler("query_causes") -def _handle_query_causes(memory: "Memory", arguments: Dict[str, Any]) -> Any: - memory_id = arguments.get("memory_id", "") - depth = min(10, max(1, int(arguments.get("depth", 3)))) - if not hasattr(memory, "knowledge_graph") or not memory.knowledge_graph: - return {"error": "Knowledge graph not available"} - backward = memory.knowledge_graph.get_causal_chain(memory_id, "backward", depth) - forward = memory.knowledge_graph.get_causal_chain(memory_id, "forward", depth) - return { - "memory_id": memory_id, - "causes": [{"memory_id": mid, "depth": d} for mid, d, _ in backward], - "effects": [{"memory_id": mid, "depth": d} for mid, d, _ in forward], - } - - -# ---- AGI Loop tools ---- - -@_tool_handler("get_agi_status") -def _handle_get_agi_status(memory: "Memory", arguments: Dict[str, Any]) -> Any: - from engram.core.agi_loop import get_system_health - return get_system_health(memory, user_id=arguments.get("user_id", "default")) - - -@_tool_handler("run_agi_cycle") -def _handle_run_agi_cycle(memory: "Memory", arguments: Dict[str, Any]) -> Any: - from engram.core.agi_loop import run_agi_cycle - return run_agi_cycle( - memory, - user_id=arguments.get("user_id", "default"), - context=arguments.get("context"), - ) - - -@_tool_handler("get_system_health") -def _handle_get_system_health(memory: "Memory", arguments: Dict[str, Any]) -> Any: - from engram.core.agi_loop import get_system_health - return get_system_health(memory, user_id=arguments.get("user_id", "default")) +_MEMORY_FREE_TOOLS = 
{"get_last_session", "save_session_digest"} -_MEMORY_FREE_TOOLS = {"get_last_session", "save_session_digest"} +@server.list_tools() +async def list_tools() -> List[Tool]: + return list(TOOLS) @server.call_tool() async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: - """Handle tool calls.""" try: memory = None if name in _MEMORY_FREE_TOOLS else get_memory() - - handler = _TOOL_HANDLERS.get(name) - if handler: - result = handler(memory, arguments) - elif name in _power_tool_handlers: - result = _power_tool_handlers[name](arguments) - else: + handler = HANDLERS.get(name) + if not handler: result = {"error": f"Unknown tool: {name}"} - + else: + result = handler(memory, arguments) return [TextContent(type="text", text=json.dumps(result, indent=2, default=str))] - except Exception as e: logger.exception("MCP tool '%s' failed", name) - error_msg = f"{type(e).__name__}: {e}" - return [TextContent(type="text", text=json.dumps({"error": error_msg}, indent=2))] + return [TextContent(type="text", text=json.dumps({"error": f"{type(e).__name__}: {e}"}, indent=2))] async def main(): diff --git a/engram/memory/__init__.py b/engram/memory/__init__.py index 7fe5748..272ed9f 100644 --- a/engram/memory/__init__.py +++ b/engram/memory/__init__.py @@ -1,5 +1,14 @@ -from engram.memory.main import Memory +from engram.memory.core import CoreMemory +from engram.memory.smart import SmartMemory +from engram.memory.main import FullMemory, Memory from engram.memory.tasks import TaskManager from engram.memory.projects import ProjectManager -__all__ = ["Memory", "TaskManager", "ProjectManager"] +__all__ = [ + "CoreMemory", + "SmartMemory", + "FullMemory", + "Memory", + "TaskManager", + "ProjectManager", +] diff --git a/engram/memory/core.py b/engram/memory/core.py new file mode 100644 index 0000000..7835674 --- /dev/null +++ b/engram/memory/core.py @@ -0,0 +1,444 @@ +"""CoreMemory — lightweight memory: add/search/delete with decay. No LLM required. 
+ +This is the zero-config, zero-API-key entry point. Uses hash-based embeddings +and in-memory vector store by default. Supports content-hash deduplication +and query embedding cache. + +Dependencies: SQLiteManager, Embedder, VectorStore, engram_accel (for cosine sim). +NO LLM, NO echo, NO categories, NO scenes, NO profiles. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import uuid +from collections import OrderedDict +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from engram.configs.base import MemoryConfig +from engram.core.decay import calculate_decayed_strength, should_forget, should_promote +from engram.core.retrieval import composite_score +from engram.core.traces import ( + boost_fast_trace, + compute_effective_strength, + initialize_traces, +) +from engram.db.sqlite import SQLiteManager +from engram.utils.factory import EmbedderFactory, VectorStoreFactory +from engram.utils.math import cosine_similarity_batch + +logger = logging.getLogger(__name__) + + +def _content_hash(content: str) -> str: + """SHA-256 hash of normalized content for deduplication.""" + return hashlib.sha256(content.strip().lower().encode("utf-8")).hexdigest() + + +class CoreMemory: + """Lightweight memory: add/search/delete with decay. No LLM required. 
+ + Usage: + m = CoreMemory() # zero-config, no API key + m.add("I like Python") + results = m.search("programming preferences") + """ + + def __init__( + self, + config: Optional[MemoryConfig] = None, + preset: Optional[str] = None, + ): + if config is None and preset is None: + config = MemoryConfig.minimal() + elif preset: + config = getattr(MemoryConfig, preset)() + self.config = config + + # Ensure vector store config has dims/collection + self.config.vector_store.config.setdefault("collection_name", self.config.collection_name) + self.config.vector_store.config.setdefault("embedding_model_dims", self.config.embedding_model_dims) + + self.db = SQLiteManager(self.config.history_db_path) + self.embedder = EmbedderFactory.create( + self.config.embedder.provider, self.config.embedder.config + ) + self.vector_store = VectorStoreFactory.create( + self.config.vector_store.provider, self.config.vector_store.config + ) + self.fadem_config = self.config.engram + self.distillation_config = getattr(self.config, "distillation", None) + + # Query embedding LRU cache + self._query_cache: OrderedDict[str, List[float]] = OrderedDict() + self._query_cache_max = 128 + + def close(self) -> None: + """Release resources.""" + if hasattr(self, "vector_store") and self.vector_store is not None: + self.vector_store.close() + if hasattr(self, "db") and self.db is not None: + self.db.close() + + def __repr__(self) -> str: + return f"CoreMemory(db={self.db!r})" + + # ---- Core API ---- + + def add( + self, + content: str, + user_id: str = "default", + metadata: Optional[Dict[str, Any]] = None, + categories: Optional[List[str]] = None, + agent_id: Optional[str] = None, + source_app: Optional[str] = None, + ) -> Dict[str, Any]: + """Add a memory. Simple: content in, result out. + + Returns dict with 'results' list containing the stored memory info. + Automatically deduplicates by content hash. 
+ """ + content = str(content).strip() + if not content: + return {"results": []} + + user_id = user_id or "default" + metadata = dict(metadata or {}) + categories = list(categories or []) + + # Content-hash dedup + ch = _content_hash(content) + existing = self.db.get_memory_by_content_hash(ch, user_id) + if existing: + # Re-encountering = spaced repetition = stronger + self.db.increment_access(existing["id"]) + # Boost fast trace if multi-trace is enabled + if self.distillation_config and self.distillation_config.enable_multi_trace: + s_fast = existing.get("s_fast") or 0.0 + boosted = boost_fast_trace(s_fast, self.fadem_config.access_strength_boost) + self.db.update_memory(existing["id"], {"s_fast": boosted}) + return { + "results": [{ + "id": existing["id"], + "memory": existing.get("memory", ""), + "event": "DEDUPLICATED", + "layer": existing.get("layer", "sml"), + "strength": existing.get("strength", 1.0), + }] + } + + # Embed + embedding = self.embedder.embed(content, memory_action="add") + + # Classify memory type + memory_type = "semantic" + if metadata.get("memory_type"): + memory_type = metadata["memory_type"] + + # Initialize multi-trace strength + initial_strength = 1.0 + s_fast_val = s_mid_val = s_slow_val = None + if self.distillation_config and self.distillation_config.enable_multi_trace: + s_fast_val, s_mid_val, s_slow_val = initialize_traces(initial_strength, is_new=True) + + now = datetime.now(timezone.utc).isoformat() + memory_id = str(uuid.uuid4()) + namespace = str(metadata.get("namespace", "default") or "default").strip() or "default" + + memory_data = { + "id": memory_id, + "memory": content, + "user_id": user_id, + "agent_id": agent_id, + "metadata": metadata, + "categories": categories, + "created_at": now, + "updated_at": now, + "layer": "sml", + "strength": initial_strength, + "access_count": 0, + "last_accessed": now, + "embedding": embedding, + "confidentiality_scope": metadata.get("confidentiality_scope", "work"), + "source_type": 
"mcp", + "source_app": source_app, + "decay_lambda": self.fadem_config.sml_decay_rate, + "status": "active", + "importance": metadata.get("importance", 0.5), + "sensitivity": metadata.get("sensitivity", "normal"), + "namespace": namespace, + "memory_type": memory_type, + "s_fast": s_fast_val, + "s_mid": s_mid_val, + "s_slow": s_slow_val, + "content_hash": ch, + } + + # Store in DB + self.db.add_memory(memory_data) + + # Store in vector index + payload = { + "memory_id": memory_id, + "user_id": user_id, + "memory": content, + } + if agent_id: + payload["agent_id"] = agent_id + try: + self.vector_store.insert( + vectors=[embedding], + payloads=[payload], + ids=[memory_id], + ) + except Exception as e: + logger.warning("Vector insert failed: %s", e) + + return { + "results": [{ + "id": memory_id, + "memory": content, + "event": "ADD", + "layer": "sml", + "strength": initial_strength, + "categories": categories, + "namespace": namespace, + "memory_type": memory_type, + }] + } + + def search( + self, + query: str, + user_id: str = "default", + limit: int = 10, + agent_id: Optional[str] = None, + categories: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """Search memories. 
Returns ranked results with scores.""" + query = str(query).strip() + if not query: + return {"results": []} + + # Cached embed + embedding = self._cached_embed(query) + + # Vector search + filters = {"user_id": user_id} + if agent_id: + filters["agent_id"] = agent_id + try: + vector_results = self.vector_store.search( + query=None, + vectors=embedding, + limit=limit * 3, # oversample for filtering + filters=filters, + ) + except Exception as e: + logger.warning("Vector search failed: %s", e) + vector_results = [] + + if not vector_results: + return {"results": []} + + # Fetch full memory data and score + results = [] + for vr in vector_results: + # Handle both dict and MemoryResult objects + if hasattr(vr, "id"): + memory_id = vr.id or (vr.payload or {}).get("memory_id") + similarity = vr.score + else: + memory_id = vr.get("id") or (vr.get("payload", {}) or {}).get("memory_id") + similarity = vr.get("score", 0.0) + if not memory_id: + continue + mem = self.db.get_memory(memory_id) + if not mem: + continue + if mem.get("tombstone"): + continue + strength = float(mem.get("strength", 1.0)) + score = composite_score(similarity, strength) + + # Category filter + if categories: + mem_cats = mem.get("categories", []) + if isinstance(mem_cats, str): + try: + mem_cats = json.loads(mem_cats) + except (json.JSONDecodeError, TypeError): + mem_cats = [] + if not any(c in mem_cats for c in categories): + continue + + results.append({ + "id": mem["id"], + "memory": mem.get("memory", ""), + "score": round(score, 4), + "composite_score": round(score, 4), + "similarity": round(similarity, 4), + "strength": round(strength, 4), + "layer": mem.get("layer", "sml"), + "categories": mem.get("categories", []), + "created_at": mem.get("created_at"), + "access_count": mem.get("access_count", 0), + }) + + results.sort(key=lambda r: r["composite_score"], reverse=True) + return {"results": results[:limit]} + + def get(self, memory_id: str) -> Optional[Dict[str, Any]]: + """Get a specific 
memory by ID.""" + mem = self.db.get_memory(memory_id) + if mem: + self.db.increment_access(memory_id) + return mem + + def get_all( + self, + user_id: str = "default", + agent_id: Optional[str] = None, + layer: Optional[str] = None, + limit: int = 100, + ) -> Dict[str, Any]: + """Get all memories for a user.""" + memories = self.db.get_all_memories( + user_id=user_id, + agent_id=agent_id, + layer=layer, + limit=limit, + ) + return {"results": memories} + + def update(self, memory_id: str, data: Any) -> Dict[str, Any]: + """Update a memory's content or metadata.""" + if isinstance(data, str): + # Simple content update + content = data + embedding = self.embedder.embed(content, memory_action="add") + ch = _content_hash(content) + self.db.update_memory(memory_id, { + "memory": content, + "embedding": embedding, + "content_hash": ch, + }) + # Update vector store + payload = {"memory_id": memory_id, "memory": content} + try: + self.vector_store.delete(memory_id) + self.vector_store.insert( + vectors=[embedding], payloads=[payload], ids=[memory_id] + ) + except Exception as e: + logger.warning("Vector update failed: %s", e) + return {"id": memory_id, "event": "UPDATE", "memory": content} + elif isinstance(data, dict): + self.db.update_memory(memory_id, data) + return {"id": memory_id, "event": "UPDATE"} + return {"error": "Invalid update data"} + + def delete(self, memory_id: str) -> Dict[str, Any]: + """Delete a memory (tombstone).""" + self.db.delete_memory(memory_id) + try: + self.vector_store.delete(memory_id) + except Exception as e: + logger.warning("Vector delete failed: %s", e) + return {"id": memory_id, "event": "DELETE"} + + def apply_decay( + self, + user_id: Optional[str] = None, + scope: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Apply FadeMem decay to all memories.""" + scope = scope or {} + target_user = user_id or scope.get("user_id") + memories = self.db.get_all_memories( + user_id=target_user, + agent_id=scope.get("agent_id"), + 
include_tombstoned=False, + ) + + decayed = 0 + forgotten = 0 + promoted = 0 + + for mem in memories: + if mem.get("immutable"): + continue + + new_strength = calculate_decayed_strength( + current_strength=float(mem.get("strength", 1.0)), + last_accessed=mem.get("last_accessed", mem.get("created_at", "")), + access_count=int(mem.get("access_count", 0)), + layer=mem.get("layer", "sml"), + config=self.fadem_config, + ) + + if should_forget(new_strength, self.fadem_config): + if self.fadem_config.use_tombstone_deletion: + self.db.update_memory(mem["id"], {"tombstone": 1, "strength": new_strength}) + else: + self.db.delete_memory(mem["id"]) + try: + self.vector_store.delete(mem["id"]) + except Exception: + pass + forgotten += 1 + elif should_promote( + mem.get("layer", "sml"), + int(mem.get("access_count", 0)), + new_strength, + self.fadem_config, + ): + self.db.update_memory(mem["id"], {"strength": new_strength, "layer": "lml"}) + promoted += 1 + else: + self.db.update_memory(mem["id"], {"strength": new_strength}) + + decayed += 1 + + return { + "decayed": decayed, + "forgotten": forgotten, + "promoted": promoted, + } + + def get_stats( + self, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + ) -> Dict[str, Any]: + """Get memory statistics.""" + memories = self.db.get_all_memories(user_id=user_id, agent_id=agent_id) + sml_count = sum(1 for m in memories if m.get("layer") == "sml") + lml_count = sum(1 for m in memories if m.get("layer") == "lml") + return { + "total": len(memories), + "sml_count": sml_count, + "lml_count": lml_count, + } + + def history(self, memory_id: str) -> List[Dict[str, Any]]: + """Get history for a memory.""" + return self.db.get_memory_history(memory_id) + + # ---- Internal helpers ---- + + def _cached_embed(self, query: str) -> List[float]: + """Embed a query with LRU caching.""" + key = hashlib.sha256(query.strip().lower().encode("utf-8")).hexdigest() + if key in self._query_cache: + self._query_cache.move_to_end(key) + 
return self._query_cache[key] + embedding = self.embedder.embed(query, memory_action="search") + self._query_cache[key] = embedding + if len(self._query_cache) > self._query_cache_max: + self._query_cache.popitem(last=False) + return embedding diff --git a/engram/memory/main.py b/engram/memory/main.py index 9076bea..c67c36a 100644 --- a/engram/memory/main.py +++ b/engram/memory/main.py @@ -42,6 +42,7 @@ strip_code_fences, ) from engram.memory.parallel import ParallelExecutor +from engram.memory.smart import SmartMemory from engram.observability import metrics from engram.utils.factory import EmbedderFactory, LLMFactory, VectorStoreFactory from engram.utils.prompts import AGENT_MEMORY_EXTRACTION_PROMPT, MEMORY_EXTRACTION_PROMPT @@ -230,126 +231,90 @@ class MemoryScope(str, Enum): GLOBAL = "global" -class Memory(MemoryBase): - """engram Memory class - biologically-inspired memory for AI agents.""" - - def __init__(self, config: Optional[MemoryConfig] = None): - self.config = config or MemoryConfig() - - # Ensure vector store config has dims/collection if missing - self.config.vector_store.config.setdefault("collection_name", self.config.collection_name) - self.config.vector_store.config.setdefault("embedding_model_dims", self.config.embedding_model_dims) - - self.db = SQLiteManager(self.config.history_db_path) - self.llm = LLMFactory.create(self.config.llm.provider, self.config.llm.config) - self.embedder = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config) - self.vector_store = VectorStoreFactory.create(self.config.vector_store.provider, self.config.vector_store.config) - self.fadem_config = self.config.engram - self.echo_config = self.config.echo - self.scope_config = getattr(self.config, "scope", None) - self.distillation_config = getattr(self.config, "distillation", None) - - # Initialize EchoMem processor - if self.echo_config.enable_echo: - self.echo_processor = EchoProcessor( - self.llm, - config={ - "auto_depth": 
self.echo_config.auto_depth, - "default_depth": self.echo_config.default_depth, - } - ) - else: - self.echo_processor = None - - # Initialize CategoryMem processor - self.category_config = self.config.category - if self.category_config.enable_categories: - self.category_processor = CategoryProcessor( - llm=self.llm, - embedder=self.embedder, - config={ - "use_llm": self.category_config.use_llm_categorization, - "auto_subcategories": self.category_config.auto_create_subcategories, - "max_depth": self.category_config.max_category_depth, - }, - ) - # Load existing categories from DB - existing_categories = self.db.get_all_categories() - if existing_categories: - self.category_processor.load_categories(existing_categories) - else: - self.category_processor = None - - # Initialize Knowledge Graph - self.graph_config = self.config.graph - if self.graph_config.enable_graph: - self.knowledge_graph = KnowledgeGraph( - llm=self.llm if self.graph_config.use_llm_extraction else None - ) - else: - self.knowledge_graph = None - - # Initialize SceneProcessor - self.scene_config = self.config.scene - if self.scene_config.enable_scenes: - self.scene_processor = SceneProcessor( +class FullMemory(SmartMemory): + """Full-featured engram Memory class with scenes, profiles, tasks, projects. + + Extends SmartMemory with additional FullMemory-specific features: + - SceneProcessor for episodic memory grouping + - ProfileProcessor for character/entity profiles + - Task and project management (future) + + All base features (echo encoding, categories, knowledge graph) are inherited + from SmartMemory with lazy initialization via @property. + """ + + def __init__(self, config: Optional[MemoryConfig] = None, preset: Optional[str] = None): + # Use default full() config if neither config nor preset provided + if config is None and preset is None: + config = MemoryConfig.full() + # Initialize parent SmartMemory (handles db, llm, embedder, etc.) 
+ super().__init__(config=config, preset=preset) + # Only FullMemory-specific lazy init + self._scene_processor: Optional[SceneProcessor] = None + self._profile_processor: Optional[ProfileProcessor] = None + self._task_manager: Optional[Any] = None + self._project_manager: Optional[Any] = None + # Parallel executor (lazy: created only when config enables it) + self._executor: Optional[ParallelExecutor] = None + if self.config.parallel.enable_parallel: + self._executor = ParallelExecutor(max_workers=self.config.parallel.max_workers) + + @property + def scene_processor(self) -> Optional[SceneProcessor]: + """Lazy-initialized SceneProcessor (only if scenes enabled in config).""" + if self._scene_processor is None and self.config.scene.enable_scenes: + self._scene_processor = SceneProcessor( db=self.db, embedder=self.embedder, llm=self.llm, config={ - "scene_time_gap_minutes": self.scene_config.scene_time_gap_minutes, - "scene_topic_threshold": self.scene_config.scene_topic_threshold, - "auto_close_inactive_minutes": self.scene_config.auto_close_inactive_minutes, - "max_scene_memories": self.scene_config.max_scene_memories, - "use_llm_summarization": self.scene_config.use_llm_summarization, - "summary_regenerate_threshold": self.scene_config.summary_regenerate_threshold, + "scene_time_gap_minutes": self.config.scene.scene_time_gap_minutes, + "scene_topic_threshold": self.config.scene.scene_topic_threshold, + "auto_close_inactive_minutes": self.config.scene.auto_close_inactive_minutes, + "max_scene_memories": self.config.scene.max_scene_memories, + "use_llm_summarization": self.config.scene.use_llm_summarization, + "summary_regenerate_threshold": self.config.scene.summary_regenerate_threshold, }, ) - else: - self.scene_processor = None + return self._scene_processor - # Initialize ProfileProcessor - self.profile_config = self.config.profile - if self.profile_config.enable_profiles: - self.profile_processor = ProfileProcessor( + @property + def profile_processor(self) -> 
Optional[ProfileProcessor]: + """Lazy-initialized ProfileProcessor (only if profiles enabled in config).""" + if self._profile_processor is None and self.config.profile.enable_profiles: + self._profile_processor = ProfileProcessor( db=self.db, embedder=self.embedder, llm=self.llm, config={ - "auto_detect_profiles": self.profile_config.auto_detect_profiles, - "use_llm_extraction": self.profile_config.use_llm_extraction, - "narrative_regenerate_threshold": self.profile_config.narrative_regenerate_threshold, - "self_profile_auto_create": self.profile_config.self_profile_auto_create, - "max_facts_per_profile": self.profile_config.max_facts_per_profile, + "auto_detect_profiles": self.config.profile.auto_detect_profiles, + "use_llm_extraction": self.config.profile.use_llm_extraction, + "narrative_regenerate_threshold": self.config.profile.narrative_regenerate_threshold, + "self_profile_auto_create": self.config.profile.self_profile_auto_create, + "max_facts_per_profile": self.config.profile.max_facts_per_profile, }, ) - else: - self.profile_processor = None - - # Parallel executor for I/O-bound LLM/embedding calls - self.parallel_config = getattr(self.config, "parallel", None) - self._executor: Optional[ParallelExecutor] = None - if self.parallel_config and self.parallel_config.enable_parallel: - self._executor = ParallelExecutor( - max_workers=self.parallel_config.max_workers - ) + return self._profile_processor def close(self) -> None: """Release all resources held by the Memory instance.""" - if hasattr(self, '_executor') and self._executor is not None: + # Shutdown parallel executor if it was created + if self._executor is not None: self._executor.shutdown() self._executor = None - if hasattr(self, 'vector_store') and self.vector_store is not None: + # Release vector store + if self.vector_store is not None: self.vector_store.close() - if hasattr(self, 'db') and self.db is not None: + # Release database + if self.db is not None: self.db.close() def __repr__(self) -> 
str: - return f"Memory(db={self.db!r}, echo={self.echo_config.enable_echo}, scenes={self.scene_config.enable_scenes})" + return f"FullMemory(db={self.db!r}, echo={self.config.echo.enable_echo}, scenes={self.config.scene.enable_scenes})" - @classmethod - def from_config(cls, config_dict: Dict[str, Any]): - return cls(MemoryConfig(**config_dict)) + # _cached_embed inherited from SmartMemory + + # from_config inherited from SmartMemory def add( self, @@ -3001,3 +2966,7 @@ def get_constellation_data(self, user_id: Optional[str] = None, limit: int = 200 def get_decay_log(self, limit: int = 20) -> List[Dict[str, Any]]: """Get recent decay history for dashboard sparkline.""" return self.db.get_decay_log_entries(limit=limit) + + +# Backward-compatible alias — existing code that imports Memory still works. +Memory = FullMemory diff --git a/engram/memory/smart.py b/engram/memory/smart.py new file mode 100644 index 0000000..3f6494e --- /dev/null +++ b/engram/memory/smart.py @@ -0,0 +1,344 @@ +"""SmartMemory — bio-inspired memory: decay + echo + categories + knowledge graph. + +Extends CoreMemory with LLM-powered features: echo encoding for stronger +retention, dynamic category organization, and knowledge graph entity linking. +Requires an LLM provider (Gemini, OpenAI, Ollama) for full functionality. + +Processors are lazily initialized — only created on first use. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional + +from engram.configs.base import MemoryConfig +from engram.memory.core import CoreMemory, _content_hash +from engram.utils.factory import LLMFactory + +logger = logging.getLogger(__name__) + + +class SmartMemory(CoreMemory): + """Bio-inspired memory: decay + echo + categories + knowledge graph. 
+ + Usage: + m = SmartMemory(preset="smart") + m.add("I like Python", echo_depth="medium") + results = m.search("programming preferences") + """ + + def __init__( + self, + config: Optional[MemoryConfig] = None, + preset: Optional[str] = None, + ): + if config is None and preset is None: + config = MemoryConfig.smart() + super().__init__(config=config, preset=preset) + + self.echo_config = self.config.echo + self.category_config = self.config.category + self.graph_config = self.config.graph + self.scope_config = getattr(self.config, "scope", None) + + # LLM — created eagerly since echo/category need it + self.llm = LLMFactory.create(self.config.llm.provider, self.config.llm.config) + + # Lazy-init processors (only created on first use) + self._echo_processor = None + self._category_processor = None + self._knowledge_graph = None + + @property + def echo_processor(self): + if self._echo_processor is None and self.echo_config.enable_echo: + from engram.core.echo import EchoProcessor + self._echo_processor = EchoProcessor( + self.llm, + config={ + "auto_depth": self.echo_config.auto_depth, + "default_depth": self.echo_config.default_depth, + }, + ) + return self._echo_processor + + @property + def category_processor(self): + if self._category_processor is None and self.category_config.enable_categories: + from engram.core.category import CategoryProcessor + self._category_processor = CategoryProcessor( + llm=self.llm, + embedder=self.embedder, + config={ + "use_llm": self.category_config.use_llm_categorization, + "auto_subcategories": self.category_config.auto_create_subcategories, + "max_depth": self.category_config.max_category_depth, + }, + ) + # Load existing categories from DB + existing = self.db.get_all_categories() + if existing: + self._category_processor.load_categories(existing) + return self._category_processor + + @property + def knowledge_graph(self): + if self._knowledge_graph is None and self.graph_config.enable_graph: + from engram.core.graph import 
KnowledgeGraph + self._knowledge_graph = KnowledgeGraph( + llm=self.llm if self.graph_config.use_llm_extraction else None + ) + return self._knowledge_graph + + def add( + self, + content: str, + user_id: str = "default", + metadata: Optional[Dict[str, Any]] = None, + categories: Optional[List[str]] = None, + agent_id: Optional[str] = None, + source_app: Optional[str] = None, + echo_depth: Optional[str] = None, + ) -> Dict[str, Any]: + """Add with echo encoding and category detection.""" + content = str(content).strip() + if not content: + return {"results": []} + + user_id = user_id or "default" + metadata = dict(metadata or {}) + categories = list(categories or []) + + # Content-hash dedup (inherited from CoreMemory logic) + ch = _content_hash(content) + existing = self.db.get_memory_by_content_hash(ch, user_id) + if existing: + from engram.core.traces import boost_fast_trace + self.db.increment_access(existing["id"]) + if self.distillation_config and self.distillation_config.enable_multi_trace: + s_fast = existing.get("s_fast") or 0.0 + boosted = boost_fast_trace(s_fast, self.fadem_config.access_strength_boost) + self.db.update_memory(existing["id"], {"s_fast": boosted}) + return { + "results": [{ + "id": existing["id"], + "memory": existing.get("memory", ""), + "event": "DEDUPLICATED", + "layer": existing.get("layer", "sml"), + "strength": existing.get("strength", 1.0), + }] + } + + # Echo encoding + echo_result = None + initial_strength = 1.0 + effective_strength = initial_strength + if self.echo_processor and self.echo_config.enable_echo: + try: + from engram.core.echo import EchoDepth + depth_override = EchoDepth(echo_depth) if echo_depth else None + echo_result = self.echo_processor.process(content, depth=depth_override) + effective_strength = initial_strength * echo_result.strength_multiplier + metadata.update(echo_result.to_metadata()) + if not categories and echo_result.category: + categories = [echo_result.category] + except Exception as e: + 
logger.warning("Echo encoding failed: %s", e) + + # Category detection + if self.category_processor and self.category_config.auto_categorize and not categories: + try: + cat_match = self.category_processor.detect_category( + content, + metadata=metadata, + use_llm=self.category_config.use_llm_categorization, + ) + categories = [cat_match.category_id] + metadata["category_confidence"] = cat_match.confidence + metadata["category_auto"] = True + except Exception as e: + logger.warning("Category detection failed: %s", e) + + # Use echo's question_form for embedding if available + primary_text = content + if ( + echo_result + and self.echo_config.use_question_embedding + and hasattr(echo_result, "question_form") + and echo_result.question_form + ): + primary_text = echo_result.question_form + + embedding = self.embedder.embed(primary_text, memory_action="add") + + # Knowledge graph entity extraction + if self.knowledge_graph: + try: + self.knowledge_graph.extract_entities(content, metadata=metadata) + except Exception as e: + logger.warning("Entity extraction failed: %s", e) + + # Store via parent's DB logic, but with our enhanced data + from engram.core.traces import initialize_traces + import uuid + from datetime import datetime, timezone + + memory_type = metadata.get("memory_type", "semantic") + s_fast_val = s_mid_val = s_slow_val = None + if self.distillation_config and self.distillation_config.enable_multi_trace: + s_fast_val, s_mid_val, s_slow_val = initialize_traces(effective_strength, is_new=True) + + now = datetime.now(timezone.utc).isoformat() + memory_id = str(uuid.uuid4()) + namespace = str(metadata.get("namespace", "default") or "default").strip() or "default" + + memory_data = { + "id": memory_id, + "memory": content, + "user_id": user_id, + "agent_id": agent_id, + "metadata": metadata, + "categories": categories, + "created_at": now, + "updated_at": now, + "layer": "sml", + "strength": effective_strength, + "access_count": 0, + "last_accessed": now, + 
"embedding": embedding, + "confidentiality_scope": metadata.get("confidentiality_scope", "work"), + "source_type": "mcp", + "source_app": source_app, + "decay_lambda": self.fadem_config.sml_decay_rate, + "status": "active", + "importance": metadata.get("importance", 0.5), + "sensitivity": metadata.get("sensitivity", "normal"), + "namespace": namespace, + "memory_type": memory_type, + "s_fast": s_fast_val, + "s_mid": s_mid_val, + "s_slow": s_slow_val, + "content_hash": ch, + } + + self.db.add_memory(memory_data) + + # Vector store + payload = {"memory_id": memory_id, "user_id": user_id, "memory": content} + if agent_id: + payload["agent_id"] = agent_id + try: + self.vector_store.insert( + vectors=[embedding], payloads=[payload], ids=[memory_id] + ) + except Exception as e: + logger.warning("Vector insert failed: %s", e) + + # Persist categories + if self.category_processor and categories: + try: + for cat_id in categories: + self.category_processor.update_category_stats( + cat_id, effective_strength, is_addition=True + ) + self._persist_categories() + except Exception as e: + logger.warning("Category persistence failed: %s", e) + + return { + "results": [{ + "id": memory_id, + "memory": content, + "event": "ADD", + "layer": "sml", + "strength": effective_strength, + "categories": categories, + "namespace": namespace, + "memory_type": memory_type, + "echo_depth": echo_result.echo_depth.value if echo_result else None, + }] + } + + def search( + self, + query: str, + user_id: str = "default", + limit: int = 10, + agent_id: Optional[str] = None, + categories: Optional[List[str]] = None, + use_echo_boost: bool = True, + use_category_boost: bool = True, + ) -> Dict[str, Any]: + """Search with echo reranking and category boosting.""" + # Get base results from CoreMemory + result = super().search( + query=query, + user_id=user_id, + limit=limit * 2 if (use_echo_boost or use_category_boost) else limit, + agent_id=agent_id, + categories=categories, + ) + + if not 
use_echo_boost and not use_category_boost: + return result + + memories = result.get("results", []) + + # Apply echo boost + if use_echo_boost and self.echo_config.enable_echo: + for mem in memories: + full_mem = self.db.get_memory(mem["id"]) + if not full_mem: + continue + md = full_mem.get("metadata", {}) + if isinstance(md, str): + import json + try: + md = json.loads(md) + except (json.JSONDecodeError, TypeError): + md = {} + echo_depth = md.get("echo_depth") + if echo_depth: + multiplier = { + "shallow": self.echo_config.shallow_multiplier, + "medium": self.echo_config.medium_multiplier, + "deep": self.echo_config.deep_multiplier, + }.get(echo_depth, 1.0) + mem["composite_score"] = mem.get("composite_score", mem.get("score", 0)) * (0.9 + 0.1 * multiplier) + mem["score"] = mem["composite_score"] + + # Apply category boost + if use_category_boost and self.category_processor and categories: + for mem in memories: + mem_cats = mem.get("categories", []) + if isinstance(mem_cats, str): + import json + try: + mem_cats = json.loads(mem_cats) + except (json.JSONDecodeError, TypeError): + mem_cats = [] + if any(c in mem_cats for c in categories): + mem["composite_score"] = mem.get("composite_score", mem.get("score", 0)) * (1.0 + self.category_config.category_boost_weight) + mem["score"] = mem["composite_score"] + + # Re-rank + memories.sort(key=lambda r: r.get("composite_score", r.get("score", 0)), reverse=True) + return {"results": memories[:limit]} + + def get_categories(self) -> List[Dict[str, Any]]: + """Get all categories.""" + if self.category_processor: + return self.category_processor.get_all_categories() + return [] + + def _persist_categories(self): + """Persist category state to DB.""" + if not self.category_processor: + return + try: + categories = self.category_processor.export_categories() + for cat in categories: + self.db.upsert_category(cat) + except Exception as e: + logger.warning("Failed to persist categories: %s", e) diff --git 
a/engram/utils/factory.py b/engram/utils/factory.py index 1e74000..83704aa 100644 --- a/engram/utils/factory.py +++ b/engram/utils/factory.py @@ -1,4 +1,40 @@ -from typing import Any, Dict +"""Factories for creating embedder, LLM, and vector store instances.""" + +import logging +import os +from typing import Any, Dict, Optional, Tuple + +logger = logging.getLogger(__name__) + + +def _detect_provider() -> Tuple[str, str]: + """Auto-detect the best available LLM/embedder provider. + + Returns (embedder_provider, llm_provider) tuple. + + Detection order: + 1. GEMINI_API_KEY / GOOGLE_API_KEY set → gemini + 2. OPENAI_API_KEY set → openai + 3. Ollama running on localhost:11434 → ollama + 4. Fall back to simple embedder + mock LLM (zero-config, no API key) + """ + if os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY"): + return ("gemini", "gemini") + if os.environ.get("OPENAI_API_KEY"): + return ("openai", "openai") + + # Try Ollama + try: + import requests + resp = requests.get("http://localhost:11434/api/tags", timeout=1) + if resp.status_code == 200: + return ("ollama", "ollama") + except Exception: + pass + + # Zero-config fallback: hash embedder + mock LLM + return ("simple", "mock") + class EmbedderFactory: @classmethod @@ -25,6 +61,15 @@ def create(cls, provider: str, config: Dict[str, Any]): return NvidiaEmbedder(config) raise ValueError(f"Unsupported embedder provider: {provider}") + @classmethod + def create_auto(cls, config: Optional[Dict[str, Any]] = None): + """Auto-detect best available embedder. 
No API key required.""" + embedder_provider, _ = _detect_provider() + cfg = dict(config or {}) + if embedder_provider == "simple": + cfg.setdefault("embedding_dims", 384) + return cls.create(embedder_provider, cfg) + class LLMFactory: @classmethod @@ -51,6 +96,12 @@ def create(cls, provider: str, config: Dict[str, Any]): return NvidiaLLM(config) raise ValueError(f"Unsupported LLM provider: {provider}") + @classmethod + def create_auto(cls, config: Optional[Dict[str, Any]] = None): + """Auto-detect best available LLM. Falls back to mock.""" + _, llm_provider = _detect_provider() + return cls.create(llm_provider, dict(config or {})) + class VectorStoreFactory: @classmethod diff --git a/engram/utils/math.py b/engram/utils/math.py index 8f63f5c..6997d27 100644 --- a/engram/utils/math.py +++ b/engram/utils/math.py @@ -1,40 +1,12 @@ -"""Shared math utilities for engram. - -This module is the single canonical source for cosine similarity and related -numerical operations. All other modules should import from here. - -Requires engram-accel (Rust) for SIMD-optimized operations. 
-""" +"""Vector math — Rust-powered, no fallbacks.""" from typing import List, Optional +from engram_accel import ( + cosine_similarity as _rs_cosine, + cosine_similarity_batch as _rs_cosine_batch, +) -import math as _math - -try: - from engram_accel import ( - cosine_similarity as _rs_cosine, - cosine_similarity_batch as _rs_cosine_batch, - ) - ACCEL_AVAILABLE = True -except ImportError: - ACCEL_AVAILABLE = False - - def _rs_cosine(a, b): - dot = sum(x * y for x, y in zip(a, b)) - na = _math.sqrt(sum(x * x for x in a)) - nb = _math.sqrt(sum(x * x for x in b)) - return dot / (na * nb) if na and nb else 0.0 - - def _rs_cosine_batch(query, store): - return [_rs_cosine(query, v) for v in store] - - -def _pure_python_cosine(a, b): - """Pure Python cosine similarity (reference implementation for tests).""" - dot = sum(x * y for x, y in zip(a, b)) - na = _math.sqrt(sum(x * x for x in a)) - nb = _math.sqrt(sum(x * x for x in b)) - return dot / (na * nb) if na and nb else 0.0 +ACCEL_AVAILABLE = True def cosine_similarity(a: Optional[List[float]], b: Optional[List[float]]) -> float: diff --git a/pyproject.toml b/pyproject.toml index 22e6376..4e1ed28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,10 +4,10 @@ build-backend = "setuptools.build_meta" [project] name = "engram-memory" -version = "0.5.0" -description = "Biologically-inspired memory layer for AI agents — forgetting, echo encoding, and dynamic categories" +version = "0.6.0" +description = "The memory layer for AI agents. Bio-inspired decay, echo encoding, zero-config." 
readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.11" license = {text = "MIT"} authors = [ {name = "Engram Team"} @@ -18,82 +18,37 @@ classifiers = [ "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] dependencies = [ "pydantic>=2.0", "requests>=2.28.0", "sqlite-vec>=0.1.1", - "engram-bus>=0.1.0", + "engram-accel>=0.1.0", ] [project.optional-dependencies] # LLM / embedding providers (install what you use) -gemini = ["google-generativeai>=0.3.0"] +gemini = ["google-genai>=1.0.0"] openai = ["openai>=1.0.0"] nvidia = ["openai>=1.0.0"] ollama = ["ollama>=0.4.0"] # Integrations mcp = ["mcp>=1.0.0"] api = ["fastapi>=0.100.0", "uvicorn>=0.20.0"] -accel = ["engram-accel>=0.1.0"] -# Power packages (install what you need) -router = ["engram-router>=0.2.0"] -identity = ["engram-identity>=0.1.0"] -heartbeat = ["engram-heartbeat>=0.1.0"] -policy = ["engram-policy>=0.1.0"] -skills = ["engram-skills>=0.1.0"] -spawn = ["engram-spawn>=0.1.0"] -resilience = ["engram-resilience>=0.1.0"] -metamemory = ["engram-metamemory>=0.1.0"] -prospective = ["engram-prospective>=0.1.0"] -procedural = ["engram-procedural>=0.1.0"] -reconsolidation = ["engram-reconsolidation>=0.1.0"] -failure = ["engram-failure>=0.1.0"] -working = ["engram-working>=0.1.0"] -warroom = ["engram-warroom>=0.1.0"] -powers = [ - "engram-router>=0.2.0", - "engram-identity>=0.1.0", - "engram-heartbeat>=0.1.0", - "engram-policy>=0.1.0", - "engram-skills>=0.1.0", - "engram-spawn>=0.1.0", - "engram-resilience>=0.1.0", - "engram-metamemory>=0.1.0", - "engram-prospective>=0.1.0", - "engram-procedural>=0.1.0", - "engram-reconsolidation>=0.1.0", - "engram-failure>=0.1.0", - "engram-working>=0.1.0", - "engram-warroom>=0.1.0", -] +bus 
= ["engram-bus>=0.1.0"] all = [ - "google-generativeai>=0.3.0", + "google-genai>=1.0.0", "openai>=1.0.0", "ollama>=0.4.0", "mcp>=1.0.0", "fastapi>=0.100.0", "uvicorn>=0.20.0", "engram-accel>=0.1.0", - "engram-router>=0.2.0", - "engram-identity>=0.1.0", - "engram-heartbeat>=0.1.0", - "engram-policy>=0.1.0", - "engram-skills>=0.1.0", - "engram-spawn>=0.1.0", - "engram-resilience>=0.1.0", - "engram-metamemory>=0.1.0", - "engram-prospective>=0.1.0", - "engram-procedural>=0.1.0", - "engram-reconsolidation>=0.1.0", - "engram-failure>=0.1.0", - "engram-working>=0.1.0", - "engram-warroom>=0.1.0", + "engram-bus>=0.1.0", ] dev = [ "pytest>=7.0.0", diff --git a/tests/test_accel.py b/tests/test_accel.py index 7d39d9f..97cfa1e 100644 --- a/tests/test_accel.py +++ b/tests/test_accel.py @@ -1,8 +1,6 @@ """Tests for engram-accel Rust acceleration layer. -Tests correctness of both the Rust implementation (if available) and the -pure-Python fallback. All tests must pass regardless of whether engram_accel -is installed. +Tests correctness of the Rust implementation. engram_accel is required. 
""" import math @@ -13,7 +11,6 @@ cosine_similarity, cosine_similarity_batch, ACCEL_AVAILABLE, - _pure_python_cosine, ) from engram.core.retrieval import tokenize, bm25_score_batch @@ -45,12 +42,17 @@ def test_none_input(self): def test_high_dimensional(self): """Test with 1024-dim vectors (typical embedding size).""" + import math as m import random random.seed(42) a = [random.gauss(0, 1) for _ in range(1024)] b = [random.gauss(0, 1) for _ in range(1024)] result = cosine_similarity(a, b) - expected = _pure_python_cosine(a, b) + # Manual reference computation + dot = sum(x * y for x, y in zip(a, b)) + na = m.sqrt(sum(x * x for x in a)) + nb = m.sqrt(sum(x * x for x in b)) + expected = dot / (na * nb) if na and nb else 0.0 assert result == pytest.approx(expected, abs=1e-10) def test_parallel_vectors_different_magnitude(self): @@ -168,16 +170,12 @@ def test_empty_documents(self): assert scores == [] -# ── Fallback behavior ────────────────────────────────────────────────── +# ── Rust required ────────────────────────────────────────────────── -class TestFallback: - def test_cosine_fallback_works(self): - """Even without Rust, cosine_similarity should work.""" - result = _pure_python_cosine([1.0, 0.0], [1.0, 0.0]) - assert result == pytest.approx(1.0) - - def test_accel_flag_is_bool(self): - assert isinstance(ACCEL_AVAILABLE, bool) +class TestAccelRequired: + def test_accel_is_available(self): + """engram_accel Rust extension must be installed.""" + assert ACCEL_AVAILABLE is True # ── Decay acceleration ───────────────────────────────────────────────── diff --git a/tests/test_accel_benchmark.py b/tests/test_accel_benchmark.py index c478df9..c4a3259 100644 --- a/tests/test_accel_benchmark.py +++ b/tests/test_accel_benchmark.py @@ -10,7 +10,6 @@ from engram.utils.math import ( cosine_similarity, cosine_similarity_batch, - _pure_python_cosine, ACCEL_AVAILABLE, ) @@ -38,10 +37,14 @@ def test_cosine_single_pair(self, benchmark): b = STORE_100[0] 
benchmark(cosine_similarity, a, b) - def test_cosine_python_fallback(self, benchmark): - a = QUERY_VEC - b = STORE_100[0] - benchmark(_pure_python_cosine, a, b) + def test_cosine_python_reference(self, benchmark): + import math + def _python_cosine(a, b): + dot = sum(x * y for x, y in zip(a, b)) + na = math.sqrt(sum(x * x for x in a)) + nb = math.sqrt(sum(x * x for x in b)) + return dot / (na * nb) if na and nb else 0.0 + benchmark(_python_cosine, QUERY_VEC, STORE_100[0]) def test_cosine_batch_100(self, benchmark): benchmark(cosine_similarity_batch, QUERY_VEC, STORE_100) diff --git a/tests/test_core_memory.py b/tests/test_core_memory.py new file mode 100644 index 0000000..8834c72 --- /dev/null +++ b/tests/test_core_memory.py @@ -0,0 +1,200 @@ +"""Tests for CoreMemory - zero-config, no LLM required.""" +import pytest +import tempfile +import os +import uuid + +from engram import CoreMemory + + +class TestCoreMemory: + """Test CoreMemory functionality.""" + + def test_add_and_search(self): + """Basic add and search functionality.""" + tag = uuid.uuid4().hex[:8] + content = f"I like Python {tag}" + m = CoreMemory(preset="minimal") + m.add(content) + # Simple embedder uses hash-based similarity, so search with same text + results = m.search(content) + assert len(results["results"]) >= 1 + assert "Python" in results["results"][0]["memory"] + m.close() + + def test_content_dedup(self): + """Same content twice = deduplication + access boost.""" + m = CoreMemory(preset="minimal") + m.add("I like Python") + r2 = m.add("I like Python") + # Should dedup + assert r2["results"][0]["event"] == "DEDUPLICATED" + m.close() + + def test_apply_decay(self): + """Decay cycle runs without error.""" + m = CoreMemory(preset="minimal") + m.add("Test memory") + result = m.apply_decay() + assert "decayed" in result + m.close() + + def test_get_and_delete(self): + """Get and delete operations work.""" + m = CoreMemory(preset="minimal") + r = m.add("To be deleted") + mem_id = 
r["results"][0]["id"] + # Get should return the memory + mem = m.get(mem_id) + assert mem is not None + assert mem["memory"] == "To be deleted" + # Delete (tombstone) + m.delete(mem_id) + # After tombstone delete, get() filters tombstoned records + mem_after_delete = m.get(mem_id) + assert mem_after_delete is None + m.close() + + def test_query_cache(self): + """Query embedding cache populates on search.""" + m = CoreMemory(preset="minimal") + m.add("Caching is good") + # First search populates cache + m.search("caching") + # Second search should hit cache + m.search("caching") + # Cache should have at least one entry + assert len(m._query_cache) > 0 + m.close() + + def test_get_all_memories(self): + """Get all memories returns results.""" + m = CoreMemory(preset="minimal") + m.add("Memory one") + m.add("Memory two") + results = m.get_all(limit=10) + assert len(results["results"]) >= 2 + m.close() + + def test_get_stats(self): + """Get stats returns memory counts.""" + m = CoreMemory(preset="minimal") + stats_before = m.get_stats() + m.add("Test memory") + stats_after = m.get_stats() + # Stats might be user-scoped, so just check structure + assert "total" in stats_after + m.close() + + def test_normalized_dedup(self): + """Case/whitespace normalized deduplication.""" + m = CoreMemory(preset="minimal") + m.add(" Hello World ") + r2 = m.add("hello world") + assert r2["results"][0]["event"] == "DEDUPLICATED" + m.close() + + def test_access_boost_on_dedup(self): + """Re-encountering strengthens memory.""" + m = CoreMemory(preset="minimal") + r1 = m.add("Boost test") + mem_id = r1["results"][0]["id"] + # Deduplicate should increment access count + r2 = m.add("Boost test") + # Access the memory to trigger increment_access + mem = m.get(mem_id) + # Access count should be incremented + assert mem["access_count"] >= 1 + m.close() + + def test_empty_content(self): + """Empty content returns empty results.""" + m = CoreMemory(preset="minimal") + result = m.add("") + assert 
result["results"] == [] + m.close() + + def test_whitespace_content(self): + """Whitespace-only content returns empty results.""" + m = CoreMemory(preset="minimal") + result = m.add(" ") + assert result["results"] == [] + m.close() + + def test_search_empty_query(self): + """Empty search query returns empty results.""" + m = CoreMemory(preset="minimal") + m.add("Test memory") + result = m.search("") + assert result["results"] == [] + m.close() + + def test_update_memory(self): + """Update memory content works.""" + m = CoreMemory(preset="minimal") + r = m.add("Original content") + mem_id = r["results"][0]["id"] + # Update via string + m.update(mem_id, "Updated content") + # Verify update + mem = m.get(mem_id) + assert mem["memory"] == "Updated content" + m.close() + + def test_history(self): + """History tracks memory operations.""" + content = f"History test {uuid.uuid4().hex[:8]}" + m = CoreMemory(preset="minimal") + r = m.add(content) + mem_id = r["results"][0]["id"] + history = m.history(mem_id) + # Should have at least the ADD event + assert len(history) >= 1 + events = [h["event"] for h in history] + assert "ADD" in events + m.close() + + def test_limit_parameter(self): + """Search limit parameter works.""" + m = CoreMemory(preset="minimal") + for i in range(10): + m.add(f"Memory {i}") + results = m.search("Memory", limit=3) + assert len(results["results"]) <= 3 + m.close() + + def test_user_id_filtering(self): + """Different user_ids are isolated.""" + m = CoreMemory(preset="minimal") + m.add("User A memory", user_id="user_a") + m.add("User B memory", user_id="user_b") + results_a = m.search("memory", user_id="user_a") + results_b = m.search("memory", user_id="user_b") + # Each should only see their own + for r in results_a["results"]: + assert "User A" in r["memory"] + for r in results_b["results"]: + assert "User B" in r["memory"] + m.close() + + def test_default_user(self): + """Default user_id is 'default'.""" + m = CoreMemory(preset="minimal") + r = 
m.add("Default user memory") + mem_id = r["results"][0]["id"] + mem = m.get(mem_id) + assert mem["user_id"] == "default" + m.close() + + def test_persistence(self): + """Memories persist in the database.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + # First instance + m1 = CoreMemory(preset="minimal") + m1.db.db_path = db_path # Use test path + m1.add("Persistent memory") + mem_id = m1.db.get_all_memories(limit=1)[0]["id"] + # Mem should exist + assert m1.get(mem_id) is not None + m1.close() diff --git a/tests/test_dedup.py b/tests/test_dedup.py new file mode 100644 index 0000000..97fcf9d --- /dev/null +++ b/tests/test_dedup.py @@ -0,0 +1,74 @@ +"""Tests for content-hash deduplication and access boost.""" + +import uuid + +from engram import CoreMemory + + +def _unique(prefix: str = "dedup") -> str: + """Generate unique content to avoid cross-test collisions in shared DB.""" + return f"{prefix}_{uuid.uuid4().hex[:8]}" + + +class TestContentDedup: + def test_exact_match_dedup(self): + content = _unique("exact") + m = CoreMemory(preset="minimal") + r1 = m.add(content) + r2 = m.add(content) + assert r2["results"][0]["event"] == "DEDUPLICATED" + m.close() + + def test_normalized_dedup(self): + """Case/whitespace normalized.""" + tag = _unique("norm") + m = CoreMemory(preset="minimal") + m.add(f" {tag} ") + r2 = m.add(tag.lower()) + assert r2["results"][0]["event"] == "DEDUPLICATED" + m.close() + + def test_access_boost_on_dedup(self): + """Re-encountering strengthens memory.""" + content = _unique("boost") + m = CoreMemory(preset="minimal") + r1 = m.add(content) + mem_id = r1["results"][0]["id"] + # Deduplicate + r2 = m.add(content) + assert r2["results"][0]["event"] == "DEDUPLICATED" + assert r2["results"][0]["id"] == mem_id + # Access count should be incremented + mem = m.get(mem_id) + assert mem["access_count"] >= 1 + m.close() + + def test_different_content_no_dedup(self): + """Different content should not 
deduplicate.""" + m = CoreMemory(preset="minimal") + r1 = m.add(_unique("first")) + r2 = m.add(_unique("second")) + assert r1["results"][0]["event"] == "ADD" + assert r2["results"][0]["event"] == "ADD" + assert r1["results"][0]["id"] != r2["results"][0]["id"] + m.close() + + def test_dedup_preserves_original_id(self): + """Dedup returns the original memory's ID.""" + content = _unique("preserve") + m = CoreMemory(preset="minimal") + r1 = m.add(content) + original_id = r1["results"][0]["id"] + r2 = m.add(content) + assert r2["results"][0]["id"] == original_id + m.close() + + def test_dedup_across_users(self): + """Same content for different users should NOT dedup.""" + content = _unique("shared") + m = CoreMemory(preset="minimal") + r1 = m.add(content, user_id=f"user_a_{uuid.uuid4().hex[:6]}") + r2 = m.add(content, user_id=f"user_b_{uuid.uuid4().hex[:6]}") + assert r1["results"][0]["event"] == "ADD" + assert r2["results"][0]["event"] == "ADD" + m.close() diff --git a/tests/test_mcp_tools_slim.py b/tests/test_mcp_tools_slim.py new file mode 100644 index 0000000..51b6d13 --- /dev/null +++ b/tests/test_mcp_tools_slim.py @@ -0,0 +1,29 @@ +"""Verify MCP server has exactly 8 tools.""" + +from engram import mcp_server + + +class TestMCPToolsSlim: + def test_exactly_8_tools(self): + tools = mcp_server.TOOLS + assert len(tools) == 8, f"Expected 8 tools, got {len(tools)}: {[t.name for t in tools]}" + + def test_core_tools_present(self): + tool_names = [t.name for t in mcp_server.TOOLS] + assert "remember" in tool_names + assert "search_memory" in tool_names + assert "get_memory" in tool_names + assert "get_all_memories" in tool_names + assert "get_memory_stats" in tool_names + assert "engram_context" in tool_names + assert "get_last_session" in tool_names + assert "save_session_digest" in tool_names + + def test_no_duplicate_tool_names(self): + tool_names = [t.name for t in mcp_server.TOOLS] + assert len(tool_names) == len(set(tool_names)), "Duplicate tool names found" + + 
def test_tools_have_input_schemas(self): + for tool in mcp_server.TOOLS: + assert tool.inputSchema is not None, f"Tool '{tool.name}' missing inputSchema" + assert "type" in tool.inputSchema, f"Tool '{tool.name}' schema missing 'type'" diff --git a/tests/test_presets.py b/tests/test_presets.py new file mode 100644 index 0000000..94a9324 --- /dev/null +++ b/tests/test_presets.py @@ -0,0 +1,48 @@ +"""Tests for MemoryConfig preset factory methods.""" + +from engram.configs.base import MemoryConfig + + +class TestMemoryPresets: + def test_minimal_no_llm(self): + c = MemoryConfig.minimal() + assert c.llm.provider == "mock" + assert c.embedder.provider == "simple" + + def test_minimal_disables_features(self): + c = MemoryConfig.minimal() + assert c.echo.enable_echo is False + assert c.category.enable_categories is False + assert c.graph.enable_graph is False + assert c.scene.enable_scenes is False + assert c.profile.enable_profiles is False + + def test_smart_detects_provider(self): + c = MemoryConfig.smart() + # Smart should use the best available provider + assert c.embedder.provider in {"gemini", "openai", "ollama", "simple"} + assert c.llm.provider in {"gemini", "openai", "ollama", "mock"} + + def test_smart_no_scenes(self): + c = MemoryConfig.smart() + assert c.scene.enable_scenes is False + assert c.profile.enable_profiles is False + + def test_full_has_scenes(self): + c = MemoryConfig.full() + assert c.scene.enable_scenes is True + assert c.profile.enable_profiles is True + + def test_full_has_echo(self): + c = MemoryConfig.full() + assert c.echo.enable_echo is True + assert c.category.enable_categories is True + assert c.graph.enable_graph is True + + def test_minimal_uses_memory_vector_store(self): + c = MemoryConfig.minimal() + assert c.vector_store.provider == "memory" + + def test_minimal_dims_384(self): + c = MemoryConfig.minimal() + assert c.embedding_model_dims == 384 diff --git a/tests/test_query_cache.py b/tests/test_query_cache.py new file mode 100644 
index 0000000..3be8bed --- /dev/null +++ b/tests/test_query_cache.py @@ -0,0 +1,42 @@ +"""Tests for query embedding LRU cache.""" + +from engram import CoreMemory + + +class TestQueryCache: + def test_cache_populated(self): + m = CoreMemory(preset="minimal") + m.add("Cache test") + m.search("cache") + assert len(m._query_cache) == 1 + m.close() + + def test_cache_hit(self): + """Same query returns cached embedding.""" + m = CoreMemory(preset="minimal") + m.add("Unique content xyz123") + m.search("content") + cached_before = dict(m._query_cache) + m.search("content") + # Cache size should not grow (hit, not miss) + assert len(m._query_cache) == len(cached_before) + m.close() + + def test_different_queries_distinct_entries(self): + """Different queries produce separate cache entries.""" + m = CoreMemory(preset="minimal") + m.add("Some data to search") + m.search("first query") + m.search("second query") + assert len(m._query_cache) == 2 + m.close() + + def test_cache_eviction(self): + """Cache respects max size.""" + m = CoreMemory(preset="minimal") + m._query_cache_max = 3 # small for testing + m.add("Data for eviction test") + for i in range(5): + m.search(f"query {i}") + assert len(m._query_cache) <= 3 + m.close() diff --git a/tests/test_smart_memory.py b/tests/test_smart_memory.py new file mode 100644 index 0000000..10aa8a8 --- /dev/null +++ b/tests/test_smart_memory.py @@ -0,0 +1,144 @@ +"""Tests for SmartMemory - echo + categories + graph.""" +import pytest + +from engram import SmartMemory + + +class TestSmartMemory: + """Test SmartMemory functionality.""" + + def test_lazy_echo_processor(self): + """Echo processor not created when disabled.""" + m = SmartMemory(preset="minimal") # echo disabled + # Should be None when disabled + assert m._echo_processor is None + m.close() + + def test_lazy_category_processor(self): + """Category processor not created when disabled.""" + m = SmartMemory(preset="minimal") # categories disabled + assert m._category_processor is 
None + m.close() + + def test_echo_enabled_flow(self): + """Echo metadata added when enabled (mock LLM).""" + # With minimal preset, echo is disabled, so we just test the structure + m = SmartMemory(preset="minimal") + r = m.add("Important fact", echo_depth="deep") + # Check result structure + assert "results" in r + assert len(r["results"]) > 0 + m.close() + + def test_minimal_preset_add(self): + """SmartMemory with minimal preset works.""" + m = SmartMemory(preset="minimal") + r = m.add("Technology preference: Python") + assert "results" in r + assert len(r["results"]) > 0 + m.close() + + def test_category_detection(self): + """Category detection structure.""" + m = SmartMemory(preset="minimal") + r = m.add("Technology preference: Python") + # Result should have expected keys + assert "results" in r + assert len(r["results"]) > 0 + result = r["results"][0] + # Categories could be empty or have values + assert "categories" in result or True # Structure ok + m.close() + + def test_search_with_boost(self): + """Search with echo/category boosting.""" + m = SmartMemory(preset="minimal") + m.add("Test memory about programming") + results = m.search( + "programming", + use_echo_boost=False, # explicitly disable + use_category_boost=False, + ) + assert "results" in results + m.close() + + def test_get_categories(self): + """Get categories returns a list.""" + m = SmartMemory(preset="minimal") + cats = m.get_categories() + assert isinstance(cats, list) + m.close() + + def test_search_with_agent_id(self): + """Search with agent_id filter.""" + m = SmartMemory(preset="minimal") + m.add("Agent memory", agent_id="agent_1") + results = m.search("memory", agent_id="agent_1") + assert "results" in results + m.close() + + def test_add_with_metadata(self): + """Add with metadata.""" + import uuid + m = SmartMemory(preset="minimal") + r = m.add(f"Test content {uuid.uuid4().hex[:8]}", metadata={"source": "test"}) + assert r["results"][0]["event"] == "ADD" + m.close() + + def 
test_categories_param(self): + """Explicit categories parameter.""" + import uuid + m = SmartMemory(preset="minimal") + r = m.add(f"Content {uuid.uuid4().hex[:8]}", categories=["test_cat"]) + assert "categories" in r["results"][0] + m.close() + + def test_parent_inheritance(self): + """SmartMemory inherits from CoreMemory.""" + from engram import CoreMemory + m = SmartMemory(preset="minimal") + # Should have all CoreMemory methods + assert hasattr(m, "add") + assert hasattr(m, "search") + assert hasattr(m, "delete") + assert hasattr(m, "apply_decay") + assert hasattr(m, "get_stats") + m.close() + + def test_search_limit(self): + """Search respects limit parameter.""" + m = SmartMemory(preset="minimal") + for i in range(5): + m.add(f"Memory item {i}") + results = m.search("item", limit=3) + assert len(results["results"]) <= 3 + m.close() + + def test_close_releases_resources(self): + """Close releases resources properly.""" + m = SmartMemory(preset="minimal") + m.add("Test memory") + m.close() + # Should be able to call close again without error + m.close() + + def test_repr(self): + """SmartMemory has repr.""" + m = SmartMemory(preset="minimal") + r = repr(m) + assert "SmartMemory" in r or "db=" in r + m.close() + + def test_user_id_in_add(self): + """Add respects user_id parameter.""" + m = SmartMemory(preset="minimal") + r = m.add("Content", user_id="custom_user") + assert "results" in r + m.close() + + def test_source_app_in_add(self): + """Add accepts source_app parameter.""" + m = SmartMemory(preset="minimal") + r = m.add("Content", source_app="test_app") + assert "results" in r + m.close() From 664215337e7f3cdf39e3d782085ad889ede066ac Mon Sep 17 00:00:00 2001 From: Vivek Kumar Date: Tue, 17 Feb 2026 19:34:45 +0530 Subject: [PATCH 4/8] feat: skill-learning agent memory system with trajectory mining and self-improvement loop Adds skills as first-class citizens to Engram's memory architecture: - Skill schema (SKILL.md format), store, executor, outcome tracking, 
and discovery - Trajectory recording and persistence for tracking agent task episodes - Skill miner that compiles successful trajectories into reusable skills - zvec vector store backend with graceful fallback chain - Centralized SHA-256 hashing (content, trajectory, skill signature) - 6 new MCP tools for skill and trajectory management - SmartMemory gains skill search/apply; FullMemory gains trajectory recording and mining - 71 new tests, all 105 tests passing Co-Authored-By: Claude Opus 4.6 --- README.md | 43 ++- engram-skills/engram_skills/__init__.py | 19 -- engram-skills/engram_skills/config.py | 11 - engram-skills/engram_skills/loader.py | 100 ------ engram-skills/engram_skills/mcp_tools.py | 104 ------ engram-skills/engram_skills/registry.py | 182 ----------- engram-skills/engram_skills/skill.py | 30 -- engram-skills/pyproject.toml | 29 -- engram/configs/base.py | 27 +- engram/configs/presets.py | 17 +- engram/core/retrieval.py | 32 ++ engram/mcp_server.py | 179 +++++++++- engram/memory/core.py | 6 +- engram/memory/main.py | 161 +++++++++ engram/memory/smart.py | 63 ++++ engram/skills/__init__.py | 7 + engram/skills/discovery.py | 58 ++++ engram/skills/executor.py | 146 +++++++++ engram/skills/hashing.py | 63 ++++ engram/skills/miner.py | 314 ++++++++++++++++++ engram/skills/outcomes.py | 93 ++++++ engram/skills/schema.py | 223 +++++++++++++ engram/skills/store.py | 228 +++++++++++++ engram/skills/trajectory.py | 260 +++++++++++++++ engram/utils/factory.py | 13 + engram/vector_stores/zvec_store.py | 394 +++++++++++++++++++++++ pyproject.toml | 4 +- tests/test_hashing.py | 135 ++++++++ tests/test_miner.py | 253 +++++++++++++++ tests/test_skills.py | 319 ++++++++++++++++++ tests/test_trajectory.py | 156 +++++++++ tests/test_zvec_store.py | 212 ++++++++++++ 32 files changed, 3375 insertions(+), 506 deletions(-) delete mode 100644 engram-skills/engram_skills/__init__.py delete mode 100644 engram-skills/engram_skills/config.py delete mode 100644 
engram-skills/engram_skills/loader.py delete mode 100644 engram-skills/engram_skills/mcp_tools.py delete mode 100644 engram-skills/engram_skills/registry.py delete mode 100644 engram-skills/engram_skills/skill.py delete mode 100644 engram-skills/pyproject.toml create mode 100644 engram/skills/__init__.py create mode 100644 engram/skills/discovery.py create mode 100644 engram/skills/executor.py create mode 100644 engram/skills/hashing.py create mode 100644 engram/skills/miner.py create mode 100644 engram/skills/outcomes.py create mode 100644 engram/skills/schema.py create mode 100644 engram/skills/store.py create mode 100644 engram/skills/trajectory.py create mode 100644 engram/vector_stores/zvec_store.py create mode 100644 tests/test_hashing.py create mode 100644 tests/test_miner.py create mode 100644 tests/test_skills.py create mode 100644 tests/test_trajectory.py create mode 100644 tests/test_zvec_store.py diff --git a/README.md b/README.md index 0ead1ce..d859df0 100644 --- a/README.md +++ b/README.md @@ -133,7 +133,8 @@ You have three agents configured: Claude Code for deep reasoning, Codex for fast | **No episodic memory** | Vector search only | CAST scenes — time/place/topic clustering | | **No consolidation** | Store everything as-is | CLS sleep cycles — episodic to semantic distillation | | **No real-time coordination** | Polling or nothing | Active memory signal bus — agents see each other instantly | -| **Concurrent access** | Single-process locks | sqlite-vec WAL — multiple agents, one DB | +| **Agents don't learn** | Retrain or nothing | Skill-policy memory — agents accumulate reusable skills | +| **Concurrent access** | Single-process locks | zvec HNSW — multiple agents, directory-based collections | --- @@ -155,6 +156,35 @@ Agent capabilities are stored as memories: *"claude-code: Advanced coding agent. No new database tables. No separate routing service. 
The same `Memory.add()` / `Memory.search()` that stores user conversations also stores agent profiles and routes tasks. +### Skill Memory — the self-improvement loop + +Agents learn from experience. When an agent completes a task, Engram records the trajectory (actions, tools, results). Successful trajectories accumulate. The Skill Miner analyzes clusters of similar trajectories and extracts reusable **skills** — validated procedures stored as SKILL.md files with YAML frontmatter. + +Skills have confidence scores that update on success/failure (Bayesian, asymmetric — failures penalize more). High-confidence skills are automatically suggested when matching tasks arrive. The loop: + +``` +Agent works → Trajectory recorded → Miner extracts patterns → Skills stored + ↑ | + └── Agent applies skill → Outcome logged → Confidence updated ──┘ +``` + +```python +from engram import SmartMemory + +m = SmartMemory(preset="smart") + +# Search for relevant skills +skills = m.search_skills("fix python import error") + +# Apply a skill — returns injectable recipe +result = m.apply_skill(skill_id) + +# Report outcome — updates confidence +m.log_skill_outcome(skill_id, success=True) +``` + +Skills are discovered from `~/.engram/skills/` and `{repo}/.engram/skills/`. Six MCP tools: `search_skills`, `apply_skill`, `log_skill_outcome`, `record_trajectory_step`, `mine_skills`, `get_skill_stats`. + ### Handoff When an agent pauses (rate limit, crash, tool switch), it saves a session digest: task summary, decisions made, files touched, TODOs remaining. The next agent loads it and continues. If no digest was saved, Engram falls back to parsing the conversation logs automatically. 
@@ -171,6 +201,8 @@ When an agent pauses (rate limit, crash, tool switch), it saves a session digest | **CLS Distillation** | Sleep-cycle replay: episodic to semantic fact extraction | | **Multi-trace** | Benna-Fusi model — fast/mid/slow decay traces per memory | | **Intent routing** | Episodic vs semantic query classification | +| **Skill Memory** | SKILL.md files — discover, apply, and mine reusable agent skills | +| **Skill Miner** | Trajectory recording → pattern extraction → skill compilation | | **Orchestrator** | Agent registry + semantic task routing + CAS claim/release | | **Handoff bus** | Session digests, checkpoints, JSONL log fallback | | **Active Memory** | Real-time signal bus with TTL tiers | @@ -206,7 +238,7 @@ memory.add("User prefers Python over TypeScript", user_id="u1") results = memory.search("programming preferences", user_id="u1") ``` -**18 MCP tools** — memory CRUD, semantic search, episodic scenes, profiles, decay, session handoff. One command configures Claude Code, Cursor, and Codex: +**14 MCP tools** — memory CRUD, semantic search, session handoff, skill search/apply/mine, trajectory recording. 
One command configures Claude Code, Cursor, and Codex: ```bash engram install @@ -342,13 +374,14 @@ Works with any tool-calling agent via REST: `engram-api` starts a server at `htt ``` ├── engram/ # engram-memory — core Python package │ ├── core/ # decay, echo, category, scenes, distillation, traces -│ ├── memory/ # Memory class (orchestrates all layers) +│ ├── memory/ # CoreMemory → SmartMemory → FullMemory +│ ├── skills/ # skill schema, store, discovery, executor, miner, trajectories │ ├── llms/ # LLM providers (gemini, openai, nvidia, ollama) │ ├── embeddings/ # embedding providers -│ ├── vector_stores/ # sqlite-vec, in-memory +│ ├── vector_stores/ # zvec, sqlite-vec, in-memory │ ├── db/ # SQLite persistence │ ├── api/ # REST API endpoints -│ ├── mcp_server.py # MCP server (18 tools) +│ ├── mcp_server.py # MCP server (14 tools) │ └── cli.py # CLI interface ├── engram-bus/ # engram-bus — agent communication │ └── engram_bus/ # bus, pub/sub, handoff store, TCP server diff --git a/engram-skills/engram_skills/__init__.py b/engram-skills/engram_skills/__init__.py deleted file mode 100644 index 72714b0..0000000 --- a/engram-skills/engram_skills/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -"""engram-skills — Shareable tool/skill registry for AI agents. - -Agents register skills (tools/functions) as memories. Other agents discover -and invoke skills via semantic search. - -Usage:: - - from engram.memory.main import Memory - from engram_skills import SkillRegistry, SkillConfig - - memory = Memory(config=...) 
- skills = SkillRegistry(memory) - skills.register(name="run_tests", description="Run pytest on the project") -""" - -from engram_skills.config import SkillConfig -from engram_skills.registry import SkillRegistry - -__all__ = ["SkillRegistry", "SkillConfig"] diff --git a/engram-skills/engram_skills/config.py b/engram-skills/engram_skills/config.py deleted file mode 100644 index 9e1a747..0000000 --- a/engram-skills/engram_skills/config.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Skills configuration.""" - -from pydantic import BaseModel - - -class SkillConfig(BaseModel): - """Configuration for the engram-skills package.""" - - user_id: str = "system" - max_skills: int = 500 - allow_remote_invoke: bool = False diff --git a/engram-skills/engram_skills/loader.py b/engram-skills/engram_skills/loader.py deleted file mode 100644 index 777ebd8..0000000 --- a/engram-skills/engram_skills/loader.py +++ /dev/null @@ -1,100 +0,0 @@ -"""SkillLoader — load skills from Python modules.""" - -from __future__ import annotations - -import importlib -import inspect -import logging -from typing import Any, Callable - -from engram_skills.skill import Skill - -logger = logging.getLogger(__name__) - - -def load_skills_from_module(module_path: str) -> list[Skill]: - """Load skills from a Python module. - - Discovers functions decorated with @skill or having a __skill__ attribute. - Falls back to loading all public functions. - - Args: - module_path: Dotted module path (e.g. "mypackage.tools") - - Returns: - List of Skill objects found in the module. 
- """ - try: - module = importlib.import_module(module_path) - except ImportError as e: - logger.error("Failed to import module '%s': %s", module_path, e) - return [] - - skills = [] - - for name, obj in inspect.getmembers(module, inspect.isfunction): - if name.startswith("_"): - continue - - # Check for __skill__ marker - skill_meta = getattr(obj, "__skill__", None) - - if skill_meta and isinstance(skill_meta, dict): - skill = Skill( - name=skill_meta.get("name", name), - description=skill_meta.get("description", obj.__doc__ or ""), - parameters=skill_meta.get("parameters", _extract_params(obj)), - examples=skill_meta.get("examples", []), - tags=skill_meta.get("tags", []), - callable=obj, - ) - else: - # Auto-discover: use docstring and signature - skill = Skill( - name=name, - description=obj.__doc__ or f"Function {name}", - parameters=_extract_params(obj), - callable=obj, - ) - - skills.append(skill) - - logger.info("Loaded %d skills from module '%s'", len(skills), module_path) - return skills - - -def _extract_params(fn: Callable) -> dict[str, str]: - """Extract parameter names and type annotations from a function.""" - params = {} - sig = inspect.signature(fn) - for pname, param in sig.parameters.items(): - if pname in ("self", "cls"): - continue - annotation = param.annotation - if annotation != inspect.Parameter.empty: - params[pname] = annotation.__name__ if hasattr(annotation, "__name__") else str(annotation) - else: - params[pname] = "any" - return params - - -def skill(name: str = "", description: str = "", tags: list[str] | None = None, - examples: list[str] | None = None) -> Callable: - """Decorator to mark a function as a discoverable skill. - - Usage:: - - @skill(name="run_tests", description="Run pytest", tags=["testing"]) - def run_tests(path: str = "tests/", verbose: bool = False): - ... 
- """ - def decorator(fn: Callable) -> Callable: - fn.__skill__ = { - "name": name or fn.__name__, - "description": description or fn.__doc__ or "", - "tags": tags or [], - "examples": examples or [], - "parameters": _extract_params(fn), - } - return fn - return decorator diff --git a/engram-skills/engram_skills/mcp_tools.py b/engram-skills/engram_skills/mcp_tools.py deleted file mode 100644 index f832b4d..0000000 --- a/engram-skills/engram_skills/mcp_tools.py +++ /dev/null @@ -1,104 +0,0 @@ -"""MCP tool definitions for engram-skills.""" - -from __future__ import annotations - -import logging -from typing import Any - -logger = logging.getLogger(__name__) - - -def register_tools(server: Any, memory: Any, **kwargs: Any) -> None: - """Register skills MCP tools on the given server.""" - from engram_skills.registry import SkillRegistry - - registry = SkillRegistry(memory) - - tool_defs = { - "register_skill": { - "description": "Register a shareable skill/tool for discovery by other agents", - "inputSchema": { - "type": "object", - "properties": { - "name": {"type": "string", "description": "Skill name"}, - "description": {"type": "string", "description": "What this skill does"}, - "parameters": {"type": "object", "description": "Parameter name->type map"}, - "examples": {"type": "array", "items": {"type": "string"}, "description": "Usage examples"}, - "agent_id": {"type": "string", "description": "Owning agent"}, - "tags": {"type": "array", "items": {"type": "string"}, "description": "Skill tags"}, - }, - "required": ["name", "description"], - }, - }, - "search_skills": { - "description": "Find skills by description using semantic search", - "inputSchema": { - "type": "object", - "properties": { - "query": {"type": "string", "description": "Describe the skill/capability needed"}, - "limit": {"type": "integer", "description": "Max results", "default": 5}, - }, - "required": ["query"], - }, - }, - "list_skills": { - "description": "List all registered skills", - 
"inputSchema": { - "type": "object", - "properties": { - "agent_id": {"type": "string", "description": "Filter by owning agent"}, - }, - }, - }, - "get_skill": { - "description": "Get skill details by name", - "inputSchema": { - "type": "object", - "properties": { - "name": {"type": "string", "description": "Skill name"}, - }, - "required": ["name"], - }, - }, - "invoke_skill": { - "description": "Invoke a locally-registered skill", - "inputSchema": { - "type": "object", - "properties": { - "name": {"type": "string", "description": "Skill name"}, - "params": {"type": "object", "description": "Skill parameters"}, - }, - "required": ["name"], - }, - }, - } - - def _handle(name: str, args: dict) -> Any: - if name == "register_skill": - return registry.register( - name=args["name"], - description=args["description"], - parameters=args.get("parameters"), - examples=args.get("examples"), - agent_id=args.get("agent_id"), - tags=args.get("tags"), - ) - elif name == "search_skills": - return registry.search(args["query"], limit=args.get("limit", 5)) - elif name == "list_skills": - return registry.list(agent_id=args.get("agent_id")) - elif name == "get_skill": - result = registry.get(args["name"]) - return result or {"error": f"Skill '{args['name']}' not found"} - elif name == "invoke_skill": - try: - result = registry.invoke(args["name"], **args.get("params", {})) - return {"result": result} - except Exception as e: - return {"error": str(e)} - return {"error": f"Unknown tool: {name}"} - - if not hasattr(server, "_skills_tools"): - server._skills_tools = {} - server._skills_tools.update(tool_defs) - server._skills_handler = _handle diff --git a/engram-skills/engram_skills/registry.py b/engram-skills/engram_skills/registry.py deleted file mode 100644 index caeaca8..0000000 --- a/engram-skills/engram_skills/registry.py +++ /dev/null @@ -1,182 +0,0 @@ -"""SkillRegistry — register, search, and invoke skills via memory.""" - -from __future__ import annotations - -import logging 
-from datetime import datetime, timezone -from typing import Any, Callable - -from engram_skills.skill import Skill - -logger = logging.getLogger(__name__) - - -class SkillRegistry: - """Registry of shareable skills/tools backed by Engram Memory. - - Skills are stored as memories with memory_type="skill" for semantic - discovery. Local callables can be invoked directly. - """ - - def __init__(self, memory: Any, user_id: str = "system") -> None: - self._memory = memory - self._user_id = user_id - self._local_skills: dict[str, Callable] = {} - - # ── Helpers ── - - def _find_skill_memory(self, skill_name: str) -> dict | None: - """Find existing skill memory by name.""" - results = self._memory.get_all( - user_id=self._user_id, - filters={"memory_type": "skill", "skill_name": skill_name}, - limit=1, - ) - items = results.get("results", []) if isinstance(results, dict) else results - return items[0] if items else None - - def _format_skill(self, mem: dict) -> dict: - """Format a raw memory into a skill dict.""" - md = mem.get("metadata", {}) - return { - "id": mem.get("id", ""), - "name": md.get("skill_name", ""), - "description": md.get("skill_description", ""), - "parameters": md.get("skill_parameters", {}), - "examples": md.get("skill_examples", []), - "agent_id": md.get("skill_agent_id", ""), - "tags": md.get("skill_tags", []), - "created_at": md.get("skill_created_at", ""), - "invocable": md.get("skill_name", "") in self._local_skills, - } - - def _build_content(self, name: str, description: str, - tags: list[str] | None = None) -> str: - """Build searchable content for semantic matching.""" - parts = [f"{name}: {description}"] - if tags: - parts.append(f"Tags: {', '.join(tags)}") - return " ".join(parts) - - # ── Public API ── - - def register(self, *, name: str, description: str, - parameters: dict | None = None, - examples: list[str] | None = None, - agent_id: str | None = None, - tags: list[str] | None = None, - callable: Callable | None = None) -> dict: - 
"""Register a skill. Stored as a memory for discovery.""" - now = datetime.now(timezone.utc).isoformat() - content = self._build_content(name, description, tags) - - metadata = { - "memory_type": "skill", - "skill_name": name, - "skill_description": description, - "skill_parameters": parameters or {}, - "skill_examples": examples or [], - "skill_agent_id": agent_id or "", - "skill_tags": tags or [], - "skill_created_at": now, - } - - # Store callable locally if provided - if callable is not None: - self._local_skills[name] = callable - - existing = self._find_skill_memory(name) - if existing: - self._memory.update(existing["id"], { - "content": content, - "metadata": {**existing.get("metadata", {}), **metadata}, - }) - updated = self._memory.get(existing["id"]) - return self._format_skill(updated) if updated else self._format_skill(existing) - - result = self._memory.add( - content, - user_id=self._user_id, - metadata=metadata, - categories=["skills"], - infer=False, - ) - items = result.get("results", []) - if items: - return self._format_skill(items[0]) - return {"name": name, "description": description} - - def search(self, query: str, limit: int = 5) -> list[dict]: - """Semantic search over registered skills.""" - results = self._memory.search( - query, - user_id=self._user_id, - filters={"memory_type": "skill"}, - limit=limit, - use_echo_rerank=False, - ) - items = results.get("results", []) - skills = [] - for item in items: - skill = self._format_skill(item) - skill["similarity"] = item.get("score", item.get("similarity", 0.0)) - skills.append(skill) - return skills - - def get(self, skill_name: str) -> dict | None: - """Get a skill by exact name.""" - mem = self._find_skill_memory(skill_name) - if mem: - return self._format_skill(mem) - return None - - def invoke(self, skill_name: str, **params: Any) -> Any: - """Invoke a locally-registered skill.""" - fn = self._local_skills.get(skill_name) - if not fn: - raise ValueError(f"Skill '{skill_name}' not found 
locally. Only local callables can be invoked.") - return fn(**params) - - def list(self, agent_id: str | None = None) -> list[dict]: - """List all registered skills.""" - filters: dict[str, Any] = {"memory_type": "skill"} - if agent_id: - filters["skill_agent_id"] = agent_id - results = self._memory.get_all( - user_id=self._user_id, - filters=filters, - limit=500, - ) - items = results.get("results", []) if isinstance(results, dict) else results - return [self._format_skill(m) for m in items] - - def remove(self, skill_id: str) -> bool: - """Unregister a skill.""" - try: - # Also remove local callable if name matches - mem = self._memory.get(skill_id) - if mem: - name = mem.get("metadata", {}).get("skill_name", "") - self._local_skills.pop(name, None) - self._memory.delete(skill_id) - return True - except Exception: - return False - - def load_module(self, module_path: str) -> int: - """Load skills from a Python module. Returns count loaded.""" - from engram_skills.loader import load_skills_from_module - - skills = load_skills_from_module(module_path) - count = 0 - for s in skills: - self.register( - name=s.name, - description=s.description, - parameters=s.parameters, - examples=s.examples, - tags=s.tags, - callable=s.callable, - ) - count += 1 - return count diff --git a/engram-skills/engram_skills/skill.py b/engram-skills/engram_skills/skill.py deleted file mode 100644 index a086740..0000000 --- a/engram-skills/engram_skills/skill.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Skill model — definition, parameters, and examples.""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any, Callable - - -@dataclass -class Skill: - """A registered skill/tool definition.""" - - name: str - description: str - parameters: dict[str, str] = field(default_factory=dict) - examples: list[str] = field(default_factory=list) - agent_id: str = "" - tags: list[str] = field(default_factory=list) - callable: Callable | None = None - - def 
to_dict(self) -> dict[str, Any]: - """Convert to serializable dict (excludes callable).""" - return { - "name": self.name, - "description": self.description, - "parameters": self.parameters, - "examples": self.examples, - "agent_id": self.agent_id, - "tags": self.tags, - } diff --git a/engram-skills/pyproject.toml b/engram-skills/pyproject.toml deleted file mode 100644 index e9d7c20..0000000 --- a/engram-skills/pyproject.toml +++ /dev/null @@ -1,29 +0,0 @@ -[build-system] -requires = ["setuptools>=61.0", "wheel"] -build-backend = "setuptools.build_meta" - -[project] -name = "engram-skills" -version = "0.1.0" -description = "Shareable tool/skill registry for AI agents — discover and invoke skills via memory" -readme = "README.md" -requires-python = ">=3.10" -license = {text = "MIT"} -authors = [{name = "Engram Team"}] -keywords = ["agents", "skills", "tools", "registry", "memory", "ai"] -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", -] -dependencies = [ - "engram-memory>=0.4.0", -] - -[project.optional-dependencies] -dev = ["pytest>=7.0.0"] - -[tool.setuptools.packages.find] -where = ["."] -include = ["engram_skills*"] diff --git a/engram/configs/base.py b/engram/configs/base.py index 986f109..26ff260 100644 --- a/engram/configs/base.py +++ b/engram/configs/base.py @@ -6,16 +6,16 @@ from engram.configs.active import ActiveMemoryConfig -_VALID_VECTOR_PROVIDERS = {"memory", "sqlite_vec"} +_VALID_VECTOR_PROVIDERS = {"memory", "sqlite_vec", "zvec"} _VALID_LLM_PROVIDERS = {"gemini", "openai", "nvidia", "ollama", "mock"} _VALID_EMBEDDER_PROVIDERS = {"gemini", "openai", "nvidia", "ollama", "simple"} class VectorStoreConfig(BaseModel): - provider: str = Field(default="sqlite_vec") + provider: str = Field(default="zvec") config: Dict[str, Any] = Field( default_factory=lambda: { - "path": os.path.join(os.path.expanduser("~"), ".engram", 
"sqlite_vec.db"), + "path": os.path.join(os.path.expanduser("~"), ".engram", "zvec"), "collection_name": "fadem_memories", } ) @@ -325,6 +325,26 @@ class CausalInlineConfig(BaseModel): auto_detect_causal_language: bool = True +class SkillConfig(BaseModel): + """Configuration for the skill-learning agent memory system.""" + enable_skills: bool = True + skill_collection_name: str = "engram_skills" + min_confidence_for_auto_apply: float = 0.3 + enable_mining: bool = True + min_trajectory_steps: int = 3 + mutation_rate: float = 0.05 + + @field_validator("min_confidence_for_auto_apply", "mutation_rate") + @classmethod + def _clamp_unit_float(cls, v: float) -> float: + return min(1.0, max(0.0, float(v))) + + @field_validator("min_trajectory_steps") + @classmethod + def _positive_int(cls, v: int) -> int: + return max(1, int(v)) + + class TaskConfig(BaseModel): """Configuration for tasks as first-class Engram memories.""" enable_tasks: bool = True @@ -432,6 +452,7 @@ class MemoryConfig(BaseModel): distillation: DistillationConfig = Field(default_factory=DistillationConfig) parallel: ParallelConfig = Field(default_factory=ParallelConfig) batch: BatchConfig = Field(default_factory=BatchConfig) + skill: SkillConfig = Field(default_factory=SkillConfig) task: TaskConfig = Field(default_factory=TaskConfig) metamemory: MetamemoryInlineConfig = Field(default_factory=MetamemoryInlineConfig) prospective: ProspectiveInlineConfig = Field(default_factory=ProspectiveInlineConfig) diff --git a/engram/configs/presets.py b/engram/configs/presets.py index 1735fce..588dc6d 100644 --- a/engram/configs/presets.py +++ b/engram/configs/presets.py @@ -22,6 +22,7 @@ def minimal_config(): MemoryConfig, SceneConfig, ProfileConfig, + SkillConfig, VectorStoreConfig, ) @@ -50,6 +51,7 @@ def minimal_config(): graph=KnowledgeGraphConfig(enable_graph=False), scene=SceneConfig(enable_scenes=False), profile=ProfileConfig(enable_profiles=False), + skill=SkillConfig(enable_skills=False, enable_mining=False), 
) @@ -65,6 +67,7 @@ def smart_config(): MemoryConfig, SceneConfig, ProfileConfig, + SkillConfig, VectorStoreConfig, ) from engram.utils.factory import _detect_provider @@ -89,13 +92,13 @@ def smart_config(): dims = 384 embedder_config = {"embedding_dims": 384} - # Use sqlite_vec for persistent storage when a real provider is available - use_sqlite_vec = embedder_provider != "simple" - if use_sqlite_vec: + # Use zvec for persistent storage when a real provider is available + use_zvec = embedder_provider != "simple" + if use_zvec: vs = VectorStoreConfig( - provider="sqlite_vec", + provider="zvec", config={ - "path": os.path.join(data_dir, "sqlite_vec.db"), + "path": os.path.join(data_dir, "zvec"), "collection_name": "engram_memories", "embedding_model_dims": dims, }, @@ -125,6 +128,7 @@ def smart_config(): graph=KnowledgeGraphConfig(enable_graph=True, use_llm_extraction=False), scene=SceneConfig(enable_scenes=False), profile=ProfileConfig(enable_profiles=False), + skill=SkillConfig(enable_skills=True, enable_mining=False), ) @@ -135,10 +139,13 @@ def full_config(): ProfileConfig, ) + from engram.configs.base import SkillConfig + config = smart_config() config.scene = SceneConfig(enable_scenes=True) config.profile = ProfileConfig(enable_profiles=True) config.echo.enable_echo = True config.category.enable_categories = True config.graph.enable_graph = True + config.skill = SkillConfig(enable_skills=True, enable_mining=True) return config diff --git a/engram/core/retrieval.py b/engram/core/retrieval.py index 7e83ecd..3b1b1a1 100644 --- a/engram/core/retrieval.py +++ b/engram/core/retrieval.py @@ -98,6 +98,38 @@ def calculate_keyword_score( return score +def build_sparse_vector(text: str, dim: int = 30000) -> Dict[int, float]: + """Build a sparse BM25-like weight vector from text. + + Tokenizes via Rust, hashes tokens to sparse indices, and returns + a dict mapping index → weight. Useful for hybrid dense+sparse search + if the vector store supports sparse fields. 
+ """ + import hashlib as _hashlib + + tokens = tokenize(text) + if not tokens: + return {} + + # Term frequency + tf: Dict[str, int] = {} + for token in tokens: + tf[token] = tf.get(token, 0) + 1 + + sparse: Dict[int, float] = {} + doc_len = len(tokens) + for token, count in tf.items(): + # Hash token to a sparse index + h = int(_hashlib.md5(token.encode("utf-8")).hexdigest(), 16) + idx = h % dim + # BM25-like weight: tf / (tf + 1) + weight = count / (count + 1.0) + # Accumulate in case of hash collision + sparse[idx] = sparse.get(idx, 0.0) + weight + + return sparse + + def hybrid_score( semantic_score: float, keyword_score: float, diff --git a/engram/mcp_server.py b/engram/mcp_server.py index cbcb2f3..6acd8d2 100644 --- a/engram/mcp_server.py +++ b/engram/mcp_server.py @@ -1,14 +1,20 @@ -"""Engram MCP Server — 8 tools, minimal boilerplate. +"""Engram MCP Server — 14 tools, minimal boilerplate. Tools: -1. remember — Quick-save (content → memory, infer=False) -2. search_memory — Semantic search -3. get_memory — Fetch by ID -4. get_all_memories — List with filters -5. engram_context — Session-start digest (top memories) -6. get_last_session — Handoff: load prior session -7. save_session_digest — Handoff: save current session -8. get_memory_stats — Quick health check + 1. remember — Quick-save (content → memory, infer=False) + 2. search_memory — Semantic search + 3. get_memory — Fetch by ID + 4. get_all_memories — List with filters + 5. engram_context — Session-start digest (top memories) + 6. get_last_session — Handoff: load prior session + 7. save_session_digest — Handoff: save current session + 8. get_memory_stats — Quick health check + 9. search_skills — Semantic search over skills +10. apply_skill — Inject skill recipe into context +11. log_skill_outcome — Report success/failure for a skill +12. record_trajectory_step — Record a step in active trajectory +13. mine_skills — Run skill mining cycle +14. 
get_skill_stats — Statistics about skills and trajectories """ import json @@ -96,10 +102,10 @@ def get_memory_instance() -> Memory: vec_db_path = os.environ.get( "FADEM_VEC_DB_PATH", - os.path.join(os.path.expanduser("~"), ".engram", "sqlite_vec.db"), + os.path.join(os.path.expanduser("~"), ".engram", "zvec"), ) - # Use in-memory vector store for simple embedder (dims mismatch with sqlite_vec) + # Use in-memory vector store for simple embedder (no persistent storage needed) if embedder_config.provider == "simple": vector_store_config = VectorStoreConfig( provider="memory", @@ -110,7 +116,7 @@ def get_memory_instance() -> Memory: ) else: vector_store_config = VectorStoreConfig( - provider="sqlite_vec", + provider="zvec", config={ "path": vec_db_path, "collection_name": os.environ.get("FADEM_COLLECTION", "fadem_memories"), @@ -266,6 +272,87 @@ def get_memory() -> Memory: }, }, ), + Tool( + name="search_skills", + description="Search for reusable skills by semantic query. Returns matching skills with confidence scores and metadata.", + inputSchema={ + "type": "object", + "properties": { + "query": {"type": "string", "description": "What kind of skill are you looking for"}, + "limit": {"type": "integer", "description": "Maximum number of results (default: 5)"}, + "tags": {"type": "array", "items": {"type": "string"}, "description": "Filter by tags"}, + "min_confidence": {"type": "number", "description": "Minimum confidence threshold (default: 0.0)"}, + }, + "required": ["query"], + }, + ), + Tool( + name="apply_skill", + description="Apply a skill by ID. Returns the skill recipe as injectable markdown for agent context.", + inputSchema={ + "type": "object", + "properties": { + "skill_id": {"type": "string", "description": "The ID of the skill to apply"}, + }, + "required": ["skill_id"], + }, + ), + Tool( + name="log_skill_outcome", + description="Report success or failure for a skill. 
Updates the skill's confidence score based on outcome.", + inputSchema={ + "type": "object", + "properties": { + "skill_id": {"type": "string", "description": "The ID of the skill to log outcome for"}, + "success": {"type": "boolean", "description": "Whether the skill application was successful"}, + "notes": {"type": "string", "description": "Optional notes about the outcome"}, + }, + "required": ["skill_id", "success"], + }, + ), + Tool( + name="record_trajectory_step", + description="Record an action step in the active trajectory. Use start_trajectory first (via mine_skills with task_description) to begin recording.", + inputSchema={ + "type": "object", + "properties": { + "recorder_id": {"type": "string", "description": "The recorder ID returned by start_trajectory"}, + "action": {"type": "string", "description": "The action performed (e.g., 'search', 'edit', 'test')"}, + "tool": {"type": "string", "description": "The tool used (e.g., 'grep', 'write', 'pytest')"}, + "args": {"type": "object", "description": "Arguments passed to the tool"}, + "result_summary": {"type": "string", "description": "Brief summary of the result"}, + "error": {"type": "string", "description": "Error message if the step failed"}, + }, + "required": ["recorder_id", "action"], + }, + ), + Tool( + name="mine_skills", + description="Run a skill mining cycle. Analyzes successful trajectories and extracts reusable skills. 
Can also start/complete trajectory recording.", + inputSchema={ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["mine", "start_trajectory", "complete_trajectory"], + "description": "Action to perform: 'mine' runs mining, 'start_trajectory' begins recording, 'complete_trajectory' finalizes recording", + }, + "task_query": {"type": "string", "description": "Filter trajectories by task description (for mining)"}, + "task_description": {"type": "string", "description": "Task description (for start_trajectory)"}, + "recorder_id": {"type": "string", "description": "Recorder ID (for complete_trajectory)"}, + "success": {"type": "boolean", "description": "Whether the task succeeded (for complete_trajectory)"}, + "outcome_summary": {"type": "string", "description": "Brief outcome description (for complete_trajectory)"}, + }, + }, + ), + Tool( + name="get_skill_stats", + description="Get statistics about skills and trajectories including counts, confidence averages, and active recordings.", + inputSchema={ + "type": "object", + "properties": {}, + }, + ), ] @@ -413,6 +500,68 @@ def _handle_get_memory_stats(memory, args): ) +def _handle_search_skills(memory, args): + try: + limit = max(1, min(50, int(args.get("limit", 5)))) + except (ValueError, TypeError): + limit = 5 + min_conf = float(args.get("min_confidence", 0.0)) + return memory.search_skills( + query=args.get("query", ""), + limit=limit, + tags=args.get("tags"), + min_confidence=min_conf, + ) + + +def _handle_apply_skill(memory, args): + return memory.apply_skill( + skill_id=args.get("skill_id", ""), + ) + + +def _handle_log_skill_outcome(memory, args): + return memory.log_skill_outcome( + skill_id=args.get("skill_id", ""), + success=args.get("success", False), + notes=args.get("notes"), + ) + + +def _handle_record_trajectory_step(memory, args): + return memory.record_trajectory_step( + recorder_id=args.get("recorder_id", ""), + action=args.get("action", ""), + tool=args.get("tool", 
""), + args=args.get("args"), + result_summary=args.get("result_summary", ""), + error=args.get("error"), + ) + + +def _handle_mine_skills(memory, args): + action = args.get("action", "mine") + if action == "start_trajectory": + recorder_id = memory.start_trajectory( + task_description=args.get("task_description", ""), + ) + return {"recorder_id": recorder_id} + elif action == "complete_trajectory": + return memory.complete_trajectory( + recorder_id=args.get("recorder_id", ""), + success=args.get("success", False), + outcome_summary=args.get("outcome_summary", ""), + ) + else: + return memory.mine_skills( + task_query=args.get("task_query"), + ) + + +def _handle_get_skill_stats(memory, args): + return memory.get_skill_stats() + + HANDLERS = { "remember": _handle_remember, "search_memory": _handle_search_memory, @@ -422,6 +571,12 @@ def _handle_get_memory_stats(memory, args): "get_last_session": _handle_get_last_session, "save_session_digest": _handle_save_session_digest, "get_memory_stats": _handle_get_memory_stats, + "search_skills": _handle_search_skills, + "apply_skill": _handle_apply_skill, + "log_skill_outcome": _handle_log_skill_outcome, + "record_trajectory_step": _handle_record_trajectory_step, + "mine_skills": _handle_mine_skills, + "get_skill_stats": _handle_get_skill_stats, } _MEMORY_FREE_TOOLS = {"get_last_session", "save_session_digest"} diff --git a/engram/memory/core.py b/engram/memory/core.py index 7835674..7661d1c 100644 --- a/engram/memory/core.py +++ b/engram/memory/core.py @@ -27,17 +27,13 @@ initialize_traces, ) from engram.db.sqlite import SQLiteManager +from engram.skills.hashing import content_hash as _content_hash from engram.utils.factory import EmbedderFactory, VectorStoreFactory from engram.utils.math import cosine_similarity_batch logger = logging.getLogger(__name__) -def _content_hash(content: str) -> str: - """SHA-256 hash of normalized content for deduplication.""" - return 
hashlib.sha256(content.strip().lower().encode("utf-8")).hexdigest() - - class CoreMemory: """Lightweight memory: add/search/delete with decay. No LLM required. diff --git a/engram/memory/main.py b/engram/memory/main.py index c67c36a..5101122 100644 --- a/engram/memory/main.py +++ b/engram/memory/main.py @@ -254,6 +254,10 @@ def __init__(self, config: Optional[MemoryConfig] = None, preset: Optional[str] self._profile_processor: Optional[ProfileProcessor] = None self._task_manager: Optional[Any] = None self._project_manager: Optional[Any] = None + # Trajectory recording and skill mining + self._trajectory_store: Optional[Any] = None + self._skill_miner: Optional[Any] = None + self._active_recorders: Dict[str, Any] = {} # Parallel executor (lazy: created only when config enables it) self._executor: Optional[ParallelExecutor] = None if self.config.parallel.enable_parallel: @@ -296,6 +300,163 @@ def profile_processor(self) -> Optional[ProfileProcessor]: ) return self._profile_processor + @property + def trajectory_store(self): + """Lazy-initialized TrajectoryStore for persisting agent trajectories.""" + if self._trajectory_store is None: + from engram.skills.trajectory import TrajectoryStore + self._trajectory_store = TrajectoryStore( + db=self.db, + embedder=self.embedder, + vector_store=self.vector_store, + ) + return self._trajectory_store + + @property + def skill_miner(self): + """Lazy-initialized SkillMiner for extracting skills from trajectories.""" + skill_cfg = getattr(self.config, "skill", None) + if self._skill_miner is None and skill_cfg and skill_cfg.enable_mining: + from engram.skills.miner import SkillMiner + self._skill_miner = SkillMiner( + trajectory_store=self.trajectory_store, + skill_store=self.skill_store, + llm=self.llm, + embedder=self.embedder, + mutation_rate=skill_cfg.mutation_rate, + ) + return self._skill_miner + + def start_trajectory( + self, + task_description: str, + user_id: str = "default", + agent_id: str = "default", + ) -> str: + 
"""Start recording a new trajectory for the given task. + + Returns the recorder ID to be used with record_trajectory_step() + and complete_trajectory(). + """ + from engram.skills.trajectory import TrajectoryRecorder + recorder = TrajectoryRecorder( + task_description=task_description, + user_id=user_id, + agent_id=agent_id, + ) + self._active_recorders[recorder.id] = recorder + return recorder.id + + def record_trajectory_step( + self, + recorder_id: str, + action: str, + tool: str = "", + args: Optional[Dict[str, Any]] = None, + result_summary: str = "", + error: Optional[str] = None, + ) -> Dict[str, Any]: + """Record a step in an active trajectory.""" + recorder = self._active_recorders.get(recorder_id) + if recorder is None: + return {"error": f"No active recorder: {recorder_id}"} + + step = recorder.record_step( + action=action, + tool=tool, + args=args, + result_summary=result_summary, + error=error, + ) + return { + "recorder_id": recorder_id, + "step_count": len(recorder.steps), + "action": action, + "tool": tool, + } + + def complete_trajectory( + self, + recorder_id: str, + success: bool, + outcome_summary: str = "", + ) -> Dict[str, Any]: + """Finalize a trajectory recording and persist it. + + Returns the trajectory data. + """ + recorder = self._active_recorders.pop(recorder_id, None) + if recorder is None: + return {"error": f"No active recorder: {recorder_id}"} + + trajectory = recorder.finalize( + success=success, + outcome_summary=outcome_summary, + ) + self.trajectory_store.save(trajectory) + + return { + "trajectory_id": trajectory.id, + "task_description": trajectory.task_description, + "step_count": len(trajectory.steps), + "success": success, + "outcome_summary": outcome_summary, + "trajectory_hash": trajectory.trajectory_hash_val, + } + + def mine_skills( + self, + task_query: Optional[str] = None, + user_id: Optional[str] = None, + ) -> Dict[str, Any]: + """Run a skill mining cycle. 
+ + Analyzes successful trajectories and extracts reusable skills. + Returns info about mined skills. + """ + if self.skill_miner is None: + return {"error": "Skill mining not enabled", "skills_mined": 0} + + mined = self.skill_miner.mine( + task_query=task_query, + user_id=user_id, + ) + return { + "skills_mined": len(mined), + "skills": [ + { + "id": s.id, + "name": s.name, + "description": s.description, + "confidence": s.confidence, + "source": s.source, + "tags": s.tags, + } + for s in mined + ], + } + + def get_skill_stats(self) -> Dict[str, Any]: + """Get statistics about skills and trajectories.""" + skills = self.skill_store.list_all() if self.skill_store else [] + trajectories = self.trajectory_store.find_successful(limit=1000) if self._trajectory_store else [] + + total_skills = len(skills) + authored = sum(1 for s in skills if s.source == "authored") + mined = sum(1 for s in skills if s.source == "mined") + imported = sum(1 for s in skills if s.source == "imported") + avg_confidence = sum(s.confidence for s in skills) / max(1, total_skills) + + return { + "total_skills": total_skills, + "authored_skills": authored, + "mined_skills": mined, + "imported_skills": imported, + "avg_confidence": round(avg_confidence, 4), + "total_successful_trajectories": len(trajectories), + "active_recorders": len(self._active_recorders), + } + def close(self) -> None: """Release all resources held by the Memory instance.""" # Shutdown parallel executor if it was created diff --git a/engram/memory/smart.py b/engram/memory/smart.py index 3f6494e..efafd33 100644 --- a/engram/memory/smart.py +++ b/engram/memory/smart.py @@ -45,10 +45,14 @@ def __init__( # LLM — created eagerly since echo/category need it self.llm = LLMFactory.create(self.config.llm.provider, self.config.llm.config) + self.skill_config = getattr(self.config, "skill", None) + # Lazy-init processors (only created on first use) self._echo_processor = None self._category_processor = None self._knowledge_graph = 
None + self._skill_store = None + self._skill_executor = None @property def echo_processor(self): @@ -91,6 +95,65 @@ def knowledge_graph(self): ) return self._knowledge_graph + @property + def skill_store(self): + if self._skill_store is None and self.skill_config and self.skill_config.enable_skills: + from engram.skills.discovery import discover_skill_dirs + from engram.skills.store import SkillStore + skill_dirs = discover_skill_dirs() + self._skill_store = SkillStore( + skill_dirs=skill_dirs, + embedder=self.embedder, + vector_store=None, # Skills use text search in SmartMemory (no separate collection) + collection_name=self.skill_config.skill_collection_name, + ) + self._skill_store.sync_from_filesystem() + return self._skill_store + + @property + def skill_executor(self): + if self._skill_executor is None and self.skill_store is not None: + from engram.skills.executor import SkillExecutor + self._skill_executor = SkillExecutor(self.skill_store) + return self._skill_executor + + def search_skills( + self, + query: str, + limit: int = 5, + tags: Optional[List[str]] = None, + min_confidence: float = 0.0, + ) -> List[Dict[str, Any]]: + """Search for skills by semantic query.""" + if self.skill_executor is None: + return [] + return self.skill_executor.search( + query=query, limit=limit, tags=tags, min_confidence=min_confidence, + ) + + def apply_skill( + self, + skill_id: str, + context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Apply a skill by ID, returning the recipe for injection.""" + if self.skill_executor is None: + return {"error": "Skills not enabled", "injected": False} + return self.skill_executor.apply(skill_id, context) + + def log_skill_outcome( + self, + skill_id: str, + success: bool, + notes: Optional[str] = None, + ) -> Dict[str, Any]: + """Log success/failure for a skill and update its confidence.""" + if self.skill_store is None: + return {"error": "Skills not enabled"} + from engram.skills.outcomes import OutcomeTracker + 
tracker = OutcomeTracker(self.skill_store) + return tracker.log_outcome(skill_id, success, notes) + def add( self, content: str, diff --git a/engram/skills/__init__.py b/engram/skills/__init__.py new file mode 100644 index 0000000..9c89678 --- /dev/null +++ b/engram/skills/__init__.py @@ -0,0 +1,7 @@ +"""engram.skills — skill-learning agent memory system. + +Skills are SKILL.md files (YAML frontmatter + markdown body) stored on +filesystem and indexed for semantic discovery. Confidence scores update +on success/failure. The Skill Miner compiles successful trajectories +into new skills via LLM extraction. +""" diff --git a/engram/skills/discovery.py b/engram/skills/discovery.py new file mode 100644 index 0000000..7604ae8 --- /dev/null +++ b/engram/skills/discovery.py @@ -0,0 +1,58 @@ +"""Skill discovery — filesystem scanning for SKILL.md files. + +Discovers skill directories in standard locations and scans them +for skill files. +""" + +from __future__ import annotations + +import os +from typing import List, Optional, Tuple + +from engram.skills.schema import Skill + + +def discover_skill_dirs(repo_path: Optional[str] = None) -> List[str]: + """Discover standard skill directories. + + Returns list of directories that may contain SKILL.md files: + 1. {repo}/.engram/skills/ (project-local skills) + 2. ~/.engram/skills/ (global user skills) + """ + dirs = [] + + # Project-local skills + if repo_path: + local_dir = os.path.join(repo_path, ".engram", "skills") + dirs.append(local_dir) + + # Global user skills + global_dir = os.path.join(os.path.expanduser("~"), ".engram", "skills") + dirs.append(global_dir) + + return dirs + + +def scan_skill_files(dirs: List[str]) -> List[Tuple[str, str]]: + """Scan directories for SKILL.md files. + + Returns list of (file_path, skill_id) tuples. 
+ """ + results = [] + for d in dirs: + if not os.path.isdir(d): + continue + for filename in os.listdir(d): + if not filename.endswith(".skill.md"): + continue + skill_id = filename.replace(".skill.md", "") + filepath = os.path.join(d, filename) + results.append((filepath, skill_id)) + return results + + +def load_skill_file(path: str) -> Skill: + """Load a single SKILL.md file into a Skill object.""" + with open(path, "r", encoding="utf-8") as f: + content = f.read() + return Skill.from_skill_md(content) diff --git a/engram/skills/executor.py b/engram/skills/executor.py new file mode 100644 index 0000000..b8a385a --- /dev/null +++ b/engram/skills/executor.py @@ -0,0 +1,146 @@ +"""SkillExecutor — search, apply, and inject skills into agent context. + +The executor bridges skill storage with agent workflows by: +1. Searching for relevant skills given a query +2. Formatting skills as injectable markdown recipes +3. Tracking which skills were applied +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from engram.skills.schema import Skill +from engram.skills.store import SkillStore + +logger = logging.getLogger(__name__) + + +class SkillExecutor: + """Searches for and applies skills to agent context.""" + + def __init__(self, skill_store: SkillStore): + self._store = skill_store + + def apply( + self, + skill_id: str, + context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Apply a specific skill by ID. + + Returns a dict with the skill recipe and metadata. 
+ """ + skill = self._store.get(skill_id) + if skill is None: + return {"error": f"Skill not found: {skill_id}", "injected": False} + + # Increment use count + skill.use_count += 1 + skill.last_used_at = datetime.now(timezone.utc).isoformat() + skill.updated_at = skill.last_used_at + self._store.save(skill) + + recipe = self._build_recipe(skill, context) + return { + "skill_id": skill.id, + "skill_name": skill.name, + "recipe": recipe, + "confidence": round(skill.confidence, 4), + "injected": True, + "source": skill.source, + } + + def search_and_apply( + self, + query: str, + context: Optional[Dict[str, Any]] = None, + min_confidence: float = 0.3, + tags: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """Find the best matching skill and apply it. + + Returns the skill recipe if found, or an empty result. + """ + skills = self._store.search( + query=query, + limit=1, + tags=tags, + min_confidence=min_confidence, + ) + + if not skills: + return { + "injected": False, + "message": "No matching skill found", + "query": query, + } + + best = skills[0] + return self.apply(best.id, context) + + def search( + self, + query: str, + limit: int = 5, + tags: Optional[List[str]] = None, + min_confidence: float = 0.0, + ) -> List[Dict[str, Any]]: + """Search for skills without applying them.""" + skills = self._store.search( + query=query, + limit=limit, + tags=tags, + min_confidence=min_confidence, + ) + return [ + { + "skill_id": s.id, + "name": s.name, + "description": s.description, + "confidence": round(s.confidence, 4), + "tags": s.tags, + "use_count": s.use_count, + "source": s.source, + } + for s in skills + ] + + def _build_recipe( + self, + skill: Skill, + context: Optional[Dict[str, Any]] = None, + ) -> str: + """Format a skill as injectable markdown for agent context.""" + lines = [ + f"## Skill: {skill.name}", + f"**Confidence:** {skill.confidence:.0%} ", + f"**Source:** {skill.source} ", + f"**Used:** {skill.use_count} times", + "", + ] + + if 
skill.description: + lines.extend([skill.description, ""]) + + if skill.preconditions: + lines.append("### Preconditions") + for p in skill.preconditions: + lines.append(f"- {p}") + lines.append("") + + if skill.steps: + lines.append("### Steps") + for i, step in enumerate(skill.steps, 1): + lines.append(f"{i}. {step}") + lines.append("") + + if skill.body_markdown: + lines.extend(["### Details", skill.body_markdown, ""]) + + if skill.tags: + lines.append(f"**Tags:** {', '.join(skill.tags)}") + + return "\n".join(lines) diff --git a/engram/skills/hashing.py b/engram/skills/hashing.py new file mode 100644 index 0000000..a471a31 --- /dev/null +++ b/engram/skills/hashing.py @@ -0,0 +1,63 @@ +"""SHA-256 hashing utilities for memory dedup, trajectory identity, and skill signatures. + +Three hash functions: +- content_hash(text) — memory dedup +- trajectory_hash(steps) — episode identity +- skill_signature_hash(preconditions, steps, tags) — skill dedup +""" + +from __future__ import annotations + +import hashlib +import json +from typing import Any, Dict, List, Optional, Sequence + + +def stable_json(obj: Any) -> str: + """Deterministic JSON serialization (sorted keys, no whitespace).""" + return json.dumps(obj, sort_keys=True, separators=(",", ":"), default=str) + + +def content_hash(text: str) -> str: + """SHA-256 hash of normalized content for memory deduplication.""" + return hashlib.sha256(text.strip().lower().encode("utf-8")).hexdigest() + + +def trajectory_hash(steps: Sequence[Dict[str, Any]]) -> str: + """SHA-256 hash of trajectory steps for episode identity. + + Normalizes each step to (action, tool, args_hash) tuples so that + result variations don't change the trajectory identity. 
+ """ + normalized = [] + for step in steps: + action = str(step.get("action", "")).strip().lower() + tool = str(step.get("tool", "")).strip().lower() + # Hash the args separately so ordering doesn't matter + args = step.get("args", {}) + if isinstance(args, dict): + args_hash = hashlib.sha256(stable_json(args).encode("utf-8")).hexdigest()[:16] + else: + args_hash = hashlib.sha256(str(args).encode("utf-8")).hexdigest()[:16] + normalized.append(f"{action}|{tool}|{args_hash}") + + combined = "\n".join(normalized) + return hashlib.sha256(combined.encode("utf-8")).hexdigest() + + +def skill_signature_hash( + preconditions: Sequence[str], + steps: Sequence[str], + tags: Sequence[str], +) -> str: + """SHA-256 hash for skill deduplication (name excluded intentionally). + + Two skills with the same preconditions, steps, and tags are considered + duplicates even if they have different names. + """ + obj = { + "preconditions": sorted(str(p).strip().lower() for p in preconditions), + "steps": [str(s).strip().lower() for s in steps], + "tags": sorted(str(t).strip().lower() for t in tags), + } + return hashlib.sha256(stable_json(obj).encode("utf-8")).hexdigest() diff --git a/engram/skills/miner.py b/engram/skills/miner.py new file mode 100644 index 0000000..c864b65 --- /dev/null +++ b/engram/skills/miner.py @@ -0,0 +1,314 @@ +"""SkillMiner — compiles successful trajectories into reusable skills. + +Pipeline: +1. Find successful trajectories matching a query +2. Cluster similar trajectories by task description +3. For each cluster (min 2): LLM extracts common pattern as skill +4. Compute skill_signature_hash → dedup against existing skills +5. Save as SKILL.md with source="mined", confidence=0.5 +6. 
Apply optional mutation to prevent rigidity +""" + +from __future__ import annotations + +import json +import logging +import random +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from engram.skills.hashing import skill_signature_hash +from engram.skills.schema import Skill, Trajectory +from engram.skills.store import SkillStore +from engram.skills.trajectory import TrajectoryStore + +logger = logging.getLogger(__name__) + +SKILL_MINING_PROMPT = """You are analyzing successful agent trajectories to extract a reusable skill. + +Given these trajectories that solved similar tasks, extract the common pattern as a skill. + +Trajectories: +{trajectories} + +Extract a JSON object with these fields: +- name: Short descriptive name for the skill (e.g., "Fix Python Import Error") +- description: One-line description of when to use this skill +- preconditions: List of conditions that should be true before applying (e.g., ["Python project exists", "Error message visible"]) +- steps: List of ordered steps to follow (e.g., ["Search for the import statement", "Check if module is installed", "Fix the import path"]) +- tags: List of relevant tags (e.g., ["python", "debugging", "imports"]) + +Respond with ONLY the JSON object, no markdown fences or explanation.""" + + +class SkillMiner: + """Mines skills from successful agent trajectories.""" + + def __init__( + self, + trajectory_store: TrajectoryStore, + skill_store: SkillStore, + llm: Any = None, + embedder: Any = None, + mutation_rate: float = 0.05, + min_cluster_size: int = 2, + ): + self._trajectory_store = trajectory_store + self._skill_store = skill_store + self._llm = llm + self._embedder = embedder + self._mutation_rate = mutation_rate + self._min_cluster_size = min_cluster_size + + def mine( + self, + task_query: Optional[str] = None, + user_id: Optional[str] = None, + limit: int = 50, + ) -> List[Skill]: + """Run a full mining cycle. + + Returns list of newly mined skills. 
+ """ + # Step 1: Find successful trajectories + trajectories = self._trajectory_store.find_successful( + task_query=task_query, + user_id=user_id, + limit=limit, + ) + + if len(trajectories) < self._min_cluster_size: + logger.info( + "Not enough trajectories for mining (%d < %d)", + len(trajectories), + self._min_cluster_size, + ) + return [] + + # Step 2: Cluster by task description similarity + clusters = self._cluster_trajectories(trajectories) + + # Step 3: Mine skills from each cluster + mined_skills = [] + for cluster in clusters: + if len(cluster) < self._min_cluster_size: + continue + + skill = self._mine_from_cluster(cluster) + if skill is None: + continue + + # Step 4: Dedup check + existing = self._skill_store.get_by_signature(skill.signature_hash) + if existing: + logger.info( + "Skill '%s' already exists as '%s', skipping", + skill.name, + existing.name, + ) + continue + + # Step 5: Save + self._skill_store.save(skill) + mined_skills.append(skill) + + # Mark trajectories as mined + for t in cluster: + t.mined_skill_ids.append(skill.id) + + return mined_skills + + def _cluster_trajectories( + self, trajectories: List[Trajectory] + ) -> List[List[Trajectory]]: + """Cluster trajectories by task description similarity. + + Uses simple keyword overlap for clustering. Falls back to embedding + similarity if an embedder is available. 
+ """ + if not trajectories: + return [] + + if self._embedder: + return self._cluster_by_embedding(trajectories) + + return self._cluster_by_keywords(trajectories) + + def _cluster_by_keywords( + self, trajectories: List[Trajectory] + ) -> List[List[Trajectory]]: + """Simple keyword-based clustering.""" + clusters: Dict[str, List[Trajectory]] = {} + + for t in trajectories: + # Normalize task description to a cluster key + words = set(t.task_description.lower().split()) + # Use sorted significant words as cluster key + significant = sorted(w for w in words if len(w) > 3)[:5] + key = " ".join(significant) if significant else "general" + + if key not in clusters: + clusters[key] = [] + clusters[key].append(t) + + return list(clusters.values()) + + def _cluster_by_embedding( + self, trajectories: List[Trajectory] + ) -> List[List[Trajectory]]: + """Embedding-based clustering using cosine similarity.""" + from engram.utils.math import cosine_similarity + + embeddings = [] + for t in trajectories: + try: + emb = self._embedder.embed(t.task_description, memory_action="search") + embeddings.append(emb) + except Exception: + embeddings.append(None) + + # Simple greedy clustering: assign each trajectory to nearest cluster + clusters: List[List[int]] = [] + cluster_centers: List[List[float]] = [] + threshold = 0.7 + + for i, emb in enumerate(embeddings): + if emb is None: + continue + + best_cluster = -1 + best_sim = 0.0 + + for ci, center in enumerate(cluster_centers): + sim = cosine_similarity(emb, center) + if sim > best_sim: + best_sim = sim + best_cluster = ci + + if best_sim >= threshold and best_cluster >= 0: + clusters[best_cluster].append(i) + else: + clusters.append([i]) + cluster_centers.append(emb) + + return [ + [trajectories[i] for i in cluster_indices] + for cluster_indices in clusters + ] + + def _mine_from_cluster(self, cluster: List[Trajectory]) -> Optional[Skill]: + """Extract a skill from a cluster of similar trajectories.""" + if self._llm: + return 
self._mine_with_llm(cluster) + return self._mine_heuristic(cluster) + + def _mine_with_llm(self, cluster: List[Trajectory]) -> Optional[Skill]: + """Use LLM to extract a skill from trajectory cluster.""" + # Format trajectories for the prompt + formatted = [] + for i, t in enumerate(cluster[:5], 1): # Limit to 5 for context window + steps_text = "\n".join( + f" - {s.action} ({s.tool}): {s.result_summary}" + for s in t.steps + ) + formatted.append( + f"Trajectory {i}: {t.task_description}\n" + f" Outcome: {t.outcome_summary}\n" + f" Steps:\n{steps_text}" + ) + + prompt = SKILL_MINING_PROMPT.format( + trajectories="\n\n".join(formatted) + ) + + try: + response = self._llm.generate(prompt) + # Parse JSON response + response_text = response.strip() + if response_text.startswith("```"): + response_text = response_text.strip("`").strip() + if response_text.startswith("json"): + response_text = response_text[4:].strip() + + data = json.loads(response_text) + except Exception as e: + logger.warning("LLM skill extraction failed: %s", e) + return self._mine_heuristic(cluster) + + skill = Skill( + name=data.get("name", "Mined Skill"), + description=data.get("description", ""), + preconditions=data.get("preconditions", []), + steps=data.get("steps", []), + tags=data.get("tags", []), + confidence=0.5, + source="mined", + source_trajectory_ids=[t.id for t in cluster], + ) + + # Apply mutation + skill = self._maybe_mutate(skill) + + return skill + + def _mine_heuristic(self, cluster: List[Trajectory]) -> Optional[Skill]: + """Extract a skill from trajectories without LLM (heuristic).""" + if not cluster: + return None + + # Use the first trajectory as the template + template = cluster[0] + + # Common steps across trajectories + step_texts = [] + for step in template.steps: + text = f"{step.action}" + if step.tool: + text += f" using {step.tool}" + step_texts.append(text) + + # Extract task words for tags + words = template.task_description.lower().split() + tags = [w for w in 
words if len(w) > 3][:5] + + skill = Skill( + name=f"Auto: {template.task_description[:50]}", + description=f"Mined from {len(cluster)} successful trajectories", + steps=step_texts, + tags=tags, + confidence=0.5, + source="mined", + source_trajectory_ids=[t.id for t in cluster], + ) + + skill = self._maybe_mutate(skill) + return skill + + def _maybe_mutate(self, skill: Skill) -> Skill: + """Apply optional mutation to prevent skill rigidity.""" + if random.random() > self._mutation_rate: + return skill + + mutations = [ + self._mutate_add_verification, + self._mutate_generalize_step, + ] + + mutation = random.choice(mutations) + return mutation(skill) + + def _mutate_add_verification(self, skill: Skill) -> Skill: + """Add a verification step at the end.""" + if skill.steps and not any("verify" in s.lower() for s in skill.steps): + skill.steps.append("Verify the result is correct") + return skill + + def _mutate_generalize_step(self, skill: Skill) -> Skill: + """Generalize a specific step to be more reusable.""" + if not skill.steps: + return skill + + # Pick a random step and add a generalization note + idx = random.randint(0, len(skill.steps) - 1) + skill.steps[idx] = skill.steps[idx] + " (adapt as needed)" + return skill diff --git a/engram/skills/outcomes.py b/engram/skills/outcomes.py new file mode 100644 index 0000000..aa65443 --- /dev/null +++ b/engram/skills/outcomes.py @@ -0,0 +1,93 @@ +"""OutcomeTracker — track skill success/failure and update confidence scores. + +Confidence uses a Bayesian-inspired update with asymmetric weighting: +failures penalize more (weight 0.15) than successes reward (weight 0.10). +This ensures skills must prove themselves before reaching high confidence. 
+""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import Any, Dict, Optional + +from engram.skills.schema import Skill +from engram.skills.store import SkillStore + +logger = logging.getLogger(__name__) + +# Asymmetric weights: failures penalize more than successes reward +SUCCESS_WEIGHT = 0.10 +FAILURE_WEIGHT = 0.15 + + +def compute_confidence(success_count: int, fail_count: int) -> float: + """Compute Bayesian-inspired confidence score. + + Uses asymmetric weighting so failures penalize more than successes reward. + Returns value in [0.0, 1.0]. + """ + total = success_count + fail_count + if total == 0: + return 0.5 # Prior: neutral confidence for new skills + + # Weighted success rate with asymmetric penalties + weighted_success = success_count * SUCCESS_WEIGHT + weighted_fail = fail_count * FAILURE_WEIGHT + weighted_total = weighted_success + weighted_fail + + if weighted_total == 0: + return 0.5 + + raw = weighted_success / weighted_total + # Regularize toward 0.5 for low sample sizes + regularization = 1.0 / (1.0 + total * 0.1) + confidence = raw * (1 - regularization) + 0.5 * regularization + + return max(0.0, min(1.0, confidence)) + + +class OutcomeTracker: + """Tracks skill outcomes and updates confidence scores.""" + + def __init__(self, skill_store: SkillStore): + self._store = skill_store + + def log_outcome( + self, + skill_id: str, + success: bool, + notes: Optional[str] = None, + ) -> Dict[str, Any]: + """Log a skill outcome and update confidence. + + Returns updated skill stats. 
+ """ + skill = self._store.get(skill_id) + if skill is None: + return {"error": f"Skill not found: {skill_id}"} + + # Update counts + if success: + skill.success_count += 1 + else: + skill.fail_count += 1 + + # Recompute confidence + old_confidence = skill.confidence + skill.confidence = compute_confidence(skill.success_count, skill.fail_count) + skill.updated_at = datetime.now(timezone.utc).isoformat() + + # Persist + self._store.save(skill) + + return { + "skill_id": skill.id, + "skill_name": skill.name, + "success": success, + "old_confidence": round(old_confidence, 4), + "new_confidence": round(skill.confidence, 4), + "success_count": skill.success_count, + "fail_count": skill.fail_count, + "notes": notes, + } diff --git a/engram/skills/schema.py b/engram/skills/schema.py new file mode 100644 index 0000000..9c7a719 --- /dev/null +++ b/engram/skills/schema.py @@ -0,0 +1,223 @@ +"""Skill and Trajectory data models. + +Skills are SKILL.md files (YAML frontmatter + markdown body). +Trajectories are recorded agent action sequences used for skill mining. 
"""

from __future__ import annotations

import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import yaml

from engram.skills.hashing import content_hash, skill_signature_hash, trajectory_hash


@dataclass
class Skill:
    """A reusable agent skill stored as a SKILL.md file.

    Persisted as YAML frontmatter (structured fields) followed by a
    markdown body; see :meth:`to_skill_md` / :meth:`from_skill_md`.
    """

    # Stable identity; also used as the filename stem ({id}.skill.md).
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    name: str = ""
    description: str = ""
    tags: List[str] = field(default_factory=list)
    preconditions: List[str] = field(default_factory=list)
    steps: List[str] = field(default_factory=list)
    body_markdown: str = ""
    # Confidence in [0, 1]; recomputed from success/fail counts elsewhere.
    confidence: float = 0.5
    success_count: int = 0
    fail_count: int = 0
    use_count: int = 0
    source: str = "authored"  # "authored" | "mined" | "imported"
    source_trajectory_ids: List[str] = field(default_factory=list)
    # Hash of (preconditions, steps, tags) — used for dedup of mined skills.
    signature_hash: str = ""
    # Hash of body_markdown; serialized under the "content_hash" key.
    content_hash_val: str = ""
    created_at: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )
    updated_at: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )
    last_used_at: Optional[str] = None

    def __post_init__(self):
        # Fill hashes lazily so callers may construct a Skill without them;
        # content hash is only derivable when a body exists.
        if not self.signature_hash:
            self.signature_hash = skill_signature_hash(
                self.preconditions, self.steps, self.tags
            )
        if not self.content_hash_val and self.body_markdown:
            self.content_hash_val = content_hash(self.body_markdown)

    def to_skill_md(self) -> str:
        """Serialize to YAML frontmatter + markdown body."""
        frontmatter = {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "tags": self.tags,
            "preconditions": self.preconditions,
            "steps": self.steps,
            "confidence": round(self.confidence, 4),
            "success_count": self.success_count,
            "fail_count": self.fail_count,
            "use_count": self.use_count,
            "source": self.source,
            "source_trajectory_ids": self.source_trajectory_ids,
            "signature_hash": self.signature_hash,
            "content_hash": self.content_hash_val,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "last_used_at": self.last_used_at,
        }
        # sort_keys=False preserves the field order declared above.
        yaml_str = yaml.dump(frontmatter, default_flow_style=False, sort_keys=False)
        # Fall back to a generated body when no markdown was authored.
        body = self.body_markdown or self._generate_body()
        return f"---\n{yaml_str}---\n\n{body}\n"

    @classmethod
    def from_skill_md(cls, content: str) -> "Skill":
        """Parse a SKILL.md file into a Skill object.

        Malformed input (missing frontmatter, bad YAML) degrades to a
        Skill whose entire content is the markdown body — never raises.
        """
        content = content.strip()
        if not content.startswith("---"):
            # No frontmatter — treat entire content as body
            return cls(body_markdown=content)

        # Split frontmatter from body.
        # maxsplit=2 keeps any "---" inside the body intact.
        parts = content.split("---", 2)
        if len(parts) < 3:
            return cls(body_markdown=content)

        frontmatter_str = parts[1].strip()
        body = parts[2].strip()

        try:
            fm = yaml.safe_load(frontmatter_str) or {}
        except yaml.YAMLError:
            return cls(body_markdown=content)

        return cls(
            id=fm.get("id", str(uuid.uuid4())),
            name=fm.get("name", ""),
            description=fm.get("description", ""),
            tags=fm.get("tags", []),
            preconditions=fm.get("preconditions", []),
            steps=fm.get("steps", []),
            body_markdown=body,
            confidence=float(fm.get("confidence", 0.5)),
            success_count=int(fm.get("success_count", 0)),
            fail_count=int(fm.get("fail_count", 0)),
            use_count=int(fm.get("use_count", 0)),
            source=fm.get("source", "authored"),
            source_trajectory_ids=fm.get("source_trajectory_ids", []),
            signature_hash=fm.get("signature_hash", ""),
            content_hash_val=fm.get("content_hash", ""),
            created_at=fm.get("created_at", datetime.now(timezone.utc).isoformat()),
            updated_at=fm.get("updated_at", datetime.now(timezone.utc).isoformat()),
            last_used_at=fm.get("last_used_at"),
        )

    def _generate_body(self) -> str:
        """Generate markdown body from structured fields."""
        lines = [f"# {self.name}", ""]
        if self.description:
            lines.extend([self.description, ""])
        if self.preconditions:
            lines.append("## Preconditions")
            for p in self.preconditions:
                lines.append(f"- {p}")
            lines.append("")
        if self.steps:
            lines.append("## Steps")
            for i, s in enumerate(self.steps, 1):
                lines.append(f"{i}. {s}")
            lines.append("")
        return "\n".join(lines)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization.

        NOTE: intentionally omits body_markdown, hashes, and
        source_trajectory_ids — summary view only.
        """
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "tags": self.tags,
            "preconditions": self.preconditions,
            "steps": self.steps,
            "confidence": self.confidence,
            "success_count": self.success_count,
            "fail_count": self.fail_count,
            "use_count": self.use_count,
            "source": self.source,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "last_used_at": self.last_used_at,
        }


@dataclass
class TrajectoryStep:
    """A single step in an agent's action sequence."""

    timestamp: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )
    action: str = ""
    tool: str = ""
    args: Dict[str, Any] = field(default_factory=dict)
    result_summary: str = ""
    error: Optional[str] = None
    state_snapshot: Optional[Dict[str, Any]] = None
    duration_ms: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            "timestamp": self.timestamp,
            "action": self.action,
            "tool": self.tool,
            "args": self.args,
            "result_summary": self.result_summary,
            "error": self.error,
            "state_snapshot": self.state_snapshot,
            "duration_ms": self.duration_ms,
        }


@dataclass
class Trajectory:
    """A recorded sequence of agent actions for a task."""

    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    user_id: str = "default"
    agent_id: str = "default"
    task_description: str = ""
    steps: List[TrajectoryStep] = field(default_factory=list)
    success: bool = False
    outcome_summary: str = ""
    # Set by compute_hash(); empty until computed.
    trajectory_hash_val: str = ""
    started_at: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )
    completed_at: Optional[str] = None
    # IDs of skills mined from this trajectory (filled in by the miner).
    mined_skill_ids: List[str] = field(default_factory=list)

    def compute_hash(self) -> str:
        """Compute trajectory hash from steps (also caches it on self)."""
        step_dicts = [s.to_dict() for s in self.steps]
        self.trajectory_hash_val = trajectory_hash(step_dicts)
        return self.trajectory_hash_val

    def to_dict(self) -> Dict[str, Any]:
        return {
            "id": self.id,
            "user_id": self.user_id,
            "agent_id": self.agent_id,
            "task_description": self.task_description,
            "steps": [s.to_dict() for s in self.steps],
            "success": self.success,
            "outcome_summary": self.outcome_summary,
            "trajectory_hash": self.trajectory_hash_val,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "mined_skill_ids": self.mined_skill_ids,
        }
"""SkillStore — filesystem + vector index for skills.

Skills are stored as {skill_id}.skill.md files and indexed in a vector
store collection for semantic discovery.
"""

from __future__ import annotations

import logging
import os
from typing import Any, Dict, List, Optional

from engram.skills.schema import Skill

logger = logging.getLogger(__name__)


class SkillStore:
    """Manages skill persistence on filesystem with vector indexing.

    The filesystem is the source of truth; the vector index (optional)
    only accelerates semantic search.  Both ``embedder`` and
    ``vector_store`` are duck-typed — when either is None, search falls
    back to plain text matching over the in-memory cache.
    """

    def __init__(
        self,
        skill_dirs: List[str],
        embedder: Any = None,
        vector_store: Any = None,
        collection_name: str = "engram_skills",
    ):
        self._skill_dirs = skill_dirs
        self._embedder = embedder
        self._vector_store = vector_store
        self._collection_name = collection_name
        # In-memory cache keyed by skill id; populated lazily by get()
        # and save(), or in bulk by sync_from_filesystem().
        self._cache: Dict[str, Skill] = {}

        # Ensure skill directories exist
        for d in self._skill_dirs:
            os.makedirs(d, exist_ok=True)

    @property
    def primary_dir(self) -> str:
        """Primary directory for saving new skills (first configured dir)."""
        return self._skill_dirs[0] if self._skill_dirs else os.path.expanduser("~/.engram/skills")

    def save(self, skill: Skill) -> str:
        """Write skill to filesystem and upsert into vector index.

        Indexing failures are logged but never raised — the file write
        is the authoritative persistence step.  Returns the skill id.
        """
        filename = f"{skill.id}.skill.md"
        filepath = os.path.join(self.primary_dir, filename)
        os.makedirs(os.path.dirname(filepath), exist_ok=True)

        content = skill.to_skill_md()
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(content)

        self._cache[skill.id] = skill

        # Index in vector store (best-effort)
        if self._embedder and self._vector_store:
            try:
                # Only name + description are embedded, not the full body.
                text = f"{skill.name}. {skill.description}"
                embedding = self._embedder.embed(text, memory_action="add")
                self._vector_store.insert(
                    vectors=[embedding],
                    payloads=[{
                        "skill_id": skill.id,
                        "name": skill.name,
                        "description": skill.description,
                        "tags": ",".join(skill.tags),
                        "confidence": skill.confidence,
                        "user_id": "system",
                    }],
                    ids=[skill.id],
                )
            except Exception as e:
                logger.warning("Failed to index skill %s: %s", skill.id, e)

        return skill.id

    def get(self, skill_id: str) -> Optional[Skill]:
        """Get a skill by ID (cache → filesystem).  None when not found."""
        if skill_id in self._cache:
            return self._cache[skill_id]

        # Search filesystem — first directory containing the file wins.
        for d in self._skill_dirs:
            filepath = os.path.join(d, f"{skill_id}.skill.md")
            if os.path.isfile(filepath):
                try:
                    with open(filepath, "r", encoding="utf-8") as f:
                        skill = Skill.from_skill_md(f.read())
                    self._cache[skill.id] = skill
                    return skill
                except Exception as e:
                    logger.warning("Failed to load skill %s: %s", filepath, e)
        return None

    def search(
        self,
        query: str,
        limit: int = 5,
        tags: Optional[List[str]] = None,
        min_confidence: float = 0.0,
    ) -> List[Skill]:
        """Semantic search over indexed skills.

        tags filtering is OR-semantics (any shared tag matches).  Any
        vector-store failure degrades to the text-matching fallback.
        """
        if not self._embedder or not self._vector_store:
            # Fallback: simple text matching over cached skills
            return self._text_search(query, limit, tags, min_confidence)

        try:
            embedding = self._embedder.embed(query, memory_action="search")
            # Over-fetch 2x because confidence/tag filters are applied
            # post-hoc below.
            results = self._vector_store.search(
                query=None,
                vectors=embedding,
                limit=limit * 2,
            )
        except Exception as e:
            logger.warning("Skill vector search failed: %s", e)
            return self._text_search(query, limit, tags, min_confidence)

        skills = []
        for r in results:
            # Results may be objects (payload attr) or dicts.
            payload = r.payload if hasattr(r, "payload") else r.get("payload", {})
            skill_id = payload.get("skill_id", r.id if hasattr(r, "id") else "")
            skill = self.get(skill_id)
            if skill is None:
                continue
            if skill.confidence < min_confidence:
                continue
            if tags and not any(t in skill.tags for t in tags):
                continue
            skills.append(skill)
            if len(skills) >= limit:
                break

        return skills

    def _text_search(
        self,
        query: str,
        limit: int,
        tags: Optional[List[str]],
        min_confidence: float,
    ) -> List[Skill]:
        """Simple text matching fallback (cache only — does not scan disk)."""
        query_lower = query.lower()
        matches = []
        for skill in self._cache.values():
            if skill.confidence < min_confidence:
                continue
            if tags and not any(t in skill.tags for t in tags):
                continue
            text = f"{skill.name} {skill.description} {' '.join(skill.tags)}".lower()
            # Match if ANY query word appears — intentionally loose.
            if any(word in text for word in query_lower.split()):
                matches.append(skill)
        return matches[:limit]

    def get_by_signature(self, sig_hash: str) -> Optional[Skill]:
        """Find skill by signature hash (dedup check)."""
        for skill in self._cache.values():
            if skill.signature_hash == sig_hash:
                return skill

        # Also check filesystem — sync pulls uncached files into the cache.
        self.sync_from_filesystem()
        for skill in self._cache.values():
            if skill.signature_hash == sig_hash:
                return skill
        return None

    def delete(self, skill_id: str) -> bool:
        """Delete a skill from filesystem and index.

        Always returns True (idempotent — deleting a missing skill is a
        no-op).  Index deletion failures are swallowed.
        """
        self._cache.pop(skill_id, None)

        for d in self._skill_dirs:
            filepath = os.path.join(d, f"{skill_id}.skill.md")
            if os.path.isfile(filepath):
                os.remove(filepath)

        if self._vector_store:
            try:
                self._vector_store.delete(skill_id)
            except Exception:
                pass
        return True

    def list_all(self) -> List[Skill]:
        """List all cached skills (call sync_from_filesystem() first for a full view)."""
        return list(self._cache.values())

    def sync_from_filesystem(self) -> int:
        """Scan skill directories and index any unindexed SKILL.md files.

        Returns count of newly indexed skills.
        """
        count = 0
        for d in self._skill_dirs:
            if not os.path.isdir(d):
                continue
            for filename in os.listdir(d):
                if not filename.endswith(".skill.md"):
                    continue
                skill_id = filename.replace(".skill.md", "")
                if skill_id in self._cache:
                    continue
                filepath = os.path.join(d, filename)
                try:
                    with open(filepath, "r", encoding="utf-8") as f:
                        skill = Skill.from_skill_md(f.read())
                    self._cache[skill.id] = skill

                    # Index in vector store — NOTE(review): unlike save(),
                    # an indexing failure here aborts this file via the
                    # outer except and the skill is not counted.
                    if self._embedder and self._vector_store:
                        text = f"{skill.name}. {skill.description}"
                        embedding = self._embedder.embed(text, memory_action="add")
                        self._vector_store.insert(
                            vectors=[embedding],
                            payloads=[{
                                "skill_id": skill.id,
                                "name": skill.name,
                                "description": skill.description,
                                "tags": ",".join(skill.tags),
                                "confidence": skill.confidence,
                                "user_id": "system",
                            }],
                            ids=[skill.id],
                        )
                    count += 1
                except Exception as e:
                    logger.warning("Failed to sync skill %s: %s", filepath, e)
        return count
"""

from __future__ import annotations

import json
import logging
import time
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from engram.skills.schema import Trajectory, TrajectoryStep

logger = logging.getLogger(__name__)


class TrajectoryRecorder:
    """Records agent actions for a single task/episode."""

    def __init__(
        self,
        task_description: str,
        user_id: str = "default",
        agent_id: str = "default",
    ):
        self.id = str(uuid.uuid4())
        self.task_description = task_description
        self.user_id = user_id
        self.agent_id = agent_id
        self.steps: List[TrajectoryStep] = []
        self.started_at = datetime.now(timezone.utc).isoformat()
        # Wall-clock time of the most recent record_step() call; used to
        # derive duration_ms for the NEXT step.
        self._step_start_time: Optional[float] = None

    def record_step(
        self,
        action: str,
        tool: str = "",
        args: Optional[Dict[str, Any]] = None,
        result_summary: str = "",
        error: Optional[str] = None,
        state_snapshot: Optional[Dict[str, Any]] = None,
    ) -> TrajectoryStep:
        """Append a step to this trajectory.

        NOTE: duration_ms is the elapsed time since the *previous*
        record_step() call (None for the first step) — it measures the
        gap between recordings, not the step's own execution time.
        """
        duration_ms = None
        if self._step_start_time is not None:
            duration_ms = int((time.time() - self._step_start_time) * 1000)
        self._step_start_time = time.time()

        step = TrajectoryStep(
            action=action,
            tool=tool,
            args=args or {},
            result_summary=result_summary,
            error=error,
            state_snapshot=state_snapshot,
            duration_ms=duration_ms,
        )
        self.steps.append(step)
        return step

    def finalize(
        self,
        success: bool,
        outcome_summary: str = "",
    ) -> Trajectory:
        """Finalize the recording and return a Trajectory (hash computed)."""
        trajectory = Trajectory(
            id=self.id,
            user_id=self.user_id,
            agent_id=self.agent_id,
            task_description=self.task_description,
            steps=self.steps,
            success=success,
            outcome_summary=outcome_summary,
            started_at=self.started_at,
            completed_at=datetime.now(timezone.utc).isoformat(),
        )
        trajectory.compute_hash()
        return trajectory


class TrajectoryStore:
    """Persists trajectories as memories for later retrieval and mining.

    Trajectories are flattened into generic memory records: the step
    list and mined-skill ids are JSON-encoded inside metadata and
    reconstructed by _mem_to_trajectory().
    """

    def __init__(self, db: Any, embedder: Any = None, vector_store: Any = None):
        self._db = db
        self._embedder = embedder
        self._vector_store = vector_store

    def save(self, trajectory: Trajectory) -> str:
        """Save trajectory as a memory record.

        Best-effort: both the DB write and the vector indexing log on
        failure rather than raise.  Returns the trajectory id.
        """
        content = f"[trajectory] {trajectory.task_description}"
        metadata = {
            "memory_type": "trajectory",
            "trajectory_id": trajectory.id,
            "trajectory_hash": trajectory.trajectory_hash_val,
            "success": trajectory.success,
            "outcome_summary": trajectory.outcome_summary,
            "step_count": len(trajectory.steps),
            "started_at": trajectory.started_at,
            "completed_at": trajectory.completed_at,
            "agent_id": trajectory.agent_id,
            # Steps serialized as JSON so they survive generic memory storage.
            "steps_json": json.dumps([s.to_dict() for s in trajectory.steps], default=str),
            "mined_skill_ids": json.dumps(trajectory.mined_skill_ids),
        }

        memory_data = {
            "id": trajectory.id,
            "memory": content,
            "user_id": trajectory.user_id,
            "agent_id": trajectory.agent_id,
            "metadata": metadata,
            "categories": ["trajectories"],
            "created_at": trajectory.started_at,
            "updated_at": trajectory.completed_at or trajectory.started_at,
            "layer": "sml",
            # Failed trajectories start at half strength.
            "strength": 1.0 if trajectory.success else 0.5,
            "access_count": 0,
            "last_accessed": trajectory.completed_at or trajectory.started_at,
            "status": "active",
            "namespace": "trajectories",
            "memory_type": "trajectory",
            "content_hash": trajectory.trajectory_hash_val,
        }

        try:
            self._db.add_memory(memory_data)
        except Exception as e:
            logger.warning("Failed to save trajectory %s: %s", trajectory.id, e)

        # Index in vector store for semantic search
        if self._embedder and self._vector_store:
            try:
                embedding = self._embedder.embed(content, memory_action="add")
                self._vector_store.insert(
                    vectors=[embedding],
                    payloads=[{
                        "trajectory_id": trajectory.id,
                        "user_id": trajectory.user_id,
                        "memory": content,
                        "success": trajectory.success,
                    }],
                    ids=[trajectory.id],
                )
            except Exception as e:
                logger.warning("Failed to index trajectory %s: %s", trajectory.id, e)

        return trajectory.id

    def get(self, trajectory_id: str) -> Optional[Trajectory]:
        """Retrieve a trajectory by ID.  None when missing or unparseable."""
        mem = self._db.get_memory(trajectory_id)
        if not mem:
            return None
        return self._mem_to_trajectory(mem)

    def find_successful(
        self,
        task_query: Optional[str] = None,
        user_id: Optional[str] = None,
        limit: int = 50,
    ) -> List[Trajectory]:
        """Find successful trajectories, optionally filtered by task query.

        task_query is a case-insensitive substring match against the
        stored memory text.  Over-fetches 3x to compensate for the
        post-hoc filters below.
        """
        memories = self._db.get_all_memories(
            user_id=user_id,
            limit=limit * 3,
        )

        trajectories = []
        for mem in memories:
            md = mem.get("metadata", {})
            # Metadata may come back JSON-encoded depending on the backend.
            if isinstance(md, str):
                try:
                    md = json.loads(md)
                except (json.JSONDecodeError, TypeError):
                    continue

            if md.get("memory_type") != "trajectory":
                continue
            if not md.get("success"):
                continue
            if task_query:
                content = mem.get("memory", "")
                if task_query.lower() not in content.lower():
                    continue

            t = self._mem_to_trajectory(mem)
            if t:
                trajectories.append(t)
                if len(trajectories) >= limit:
                    break

        return trajectories

    def find_by_hash(self, trajectory_hash: str) -> Optional[Trajectory]:
        """Find trajectory by its hash (linear scan over up to 500 memories)."""
        memories = self._db.get_all_memories(limit=500)
        for mem in memories:
            md = mem.get("metadata", {})
            if isinstance(md, str):
                try:
                    md = json.loads(md)
                except (json.JSONDecodeError, TypeError):
                    continue
            if md.get("trajectory_hash") == trajectory_hash:
                return self._mem_to_trajectory(mem)
        return None

    def _mem_to_trajectory(self, mem: Dict[str, Any]) -> Optional[Trajectory]:
        """Convert a memory record back to a Trajectory.

        Inverse of save(); tolerant of JSON-encoded metadata/steps.
        Returns None only when metadata itself is undecodable.
        """
        md = mem.get("metadata", {})
        if isinstance(md, str):
            try:
                md = json.loads(md)
            except (json.JSONDecodeError, TypeError):
                return None

        steps_json = md.get("steps_json", "[]")
        try:
            steps_data = json.loads(steps_json) if isinstance(steps_json, str) else steps_json
        except (json.JSONDecodeError, TypeError):
            steps_data = []

        steps = [
            TrajectoryStep(
                timestamp=s.get("timestamp", ""),
                action=s.get("action", ""),
                tool=s.get("tool", ""),
                args=s.get("args", {}),
                result_summary=s.get("result_summary", ""),
                error=s.get("error"),
                state_snapshot=s.get("state_snapshot"),
                duration_ms=s.get("duration_ms"),
            )
            for s in steps_data
        ]

        mined_ids_raw = md.get("mined_skill_ids", "[]")
        try:
            mined_ids = json.loads(mined_ids_raw) if isinstance(mined_ids_raw, str) else mined_ids_raw
        except (json.JSONDecodeError, TypeError):
            mined_ids = []

        return Trajectory(
            id=mem.get("id", md.get("trajectory_id", "")),
            user_id=mem.get("user_id", "default"),
            agent_id=mem.get("agent_id", md.get("agent_id", "default")),
            # Strip the save()-time prefix back off the task description.
            task_description=mem.get("memory", "").replace("[trajectory] ", ""),
            steps=steps,
            success=md.get("success", False),
            outcome_summary=md.get("outcome_summary", ""),
            trajectory_hash_val=md.get("trajectory_hash", ""),
            started_at=md.get("started_at", mem.get("created_at", "")),
            completed_at=md.get("completed_at"),
            mined_skill_ids=mined_ids,
        )
{provider}") diff --git a/engram/vector_stores/zvec_store.py b/engram/vector_stores/zvec_store.py new file mode 100644 index 0000000..4ef886d --- /dev/null +++ b/engram/vector_stores/zvec_store.py @@ -0,0 +1,394 @@ +"""zvec vector store implementation. + +Uses zvec (Rust-based) for HNSW vector similarity search with cosine distance. +Directory-based collections at ~/.engram/zvec/ by default. +""" + +from __future__ import annotations + +import json +import logging +import os +import threading +import uuid +from typing import Any, Dict, List, Optional + +from engram.memory.utils import matches_filters +from engram.vector_stores.base import MemoryResult, VectorStoreBase + +logger = logging.getLogger(__name__) + +# Promoted scalar fields stored natively in zvec for efficient filtering +_PROMOTED_FIELDS = {"user_id", "agent_id"} + + +def _build_filter_string(filters: Dict[str, Any]) -> Optional[str]: + """Translate a dict of filters into zvec SQL-like filter syntax. + + zvec supports: field == 'value', field != 'value', AND/OR grouping. + We only translate promoted scalar fields; remaining filters are applied + post-search via matches_filters(). 
+ """ + parts = [] + for key, value in filters.items(): + if key not in _PROMOTED_FIELDS: + continue + if isinstance(value, str): + parts.append(f"{key} == '{value}'") + elif isinstance(value, (int, float)): + parts.append(f"{key} == {value}") + if not parts: + return None + return " AND ".join(parts) + + +class ZvecStore(VectorStoreBase): + """Vector store backed by zvec (Rust HNSW engine).""" + + def __init__(self, config: Optional[Dict[str, Any]] = None): + config = config or {} + self.config = config + self.collection_name = config.get("collection_name", "fadem_memories") + self.vector_size = ( + config.get("embedding_model_dims") + or config.get("vector_size") + or config.get("embedding_dims") + or 1536 + ) + db_path = config.get( + "path", + os.path.join(os.path.expanduser("~"), ".engram", "zvec"), + ) + os.makedirs(db_path, exist_ok=True) + self._db_path = db_path + + self._lock = threading.RLock() + self._closed = False + + import zvec + self._zvec = zvec + + self._collection = self._ensure_collection(self.collection_name, self.vector_size) + + def _collection_path(self, name: str) -> str: + return os.path.join(self._db_path, name) + + def _ensure_collection(self, name: str, vector_size: int): + """Open or create a zvec collection with HNSW index.""" + col_path = self._collection_path(name) + + schema = { + "dims": vector_size, + "metric": "cosine", + "fields": { + "user_id": "string", + "agent_id": "string", + "payload_json": "string", + "uuid": "string", + }, + } + + try: + col = self._zvec.Collection.open(col_path) + return col + except Exception: + pass + + try: + col = self._zvec.Collection.create(col_path, schema) + return col + except Exception: + # Collection may have been created by another process + col = self._zvec.Collection.open(col_path) + return col + + def create_col(self, name: str, vector_size: int, distance: str = "cosine") -> None: + self._check_open() + self._ensure_collection(name, vector_size) + + def insert( + self, + vectors: 
List[List[float]], + payloads: Optional[List[Dict[str, Any]]] = None, + ids: Optional[List[str]] = None, + ) -> None: + self._check_open() + payloads = payloads or [{} for _ in vectors] + if len(payloads) != len(vectors): + raise ValueError("payloads length must match vectors length") + if ids is not None and len(ids) != len(vectors): + raise ValueError("ids length must match vectors length") + ids = ids or [str(uuid.uuid4()) for _ in vectors] + + for vector in vectors: + if len(vector) != self.vector_size: + raise ValueError( + f"Vector has {len(vector)} dimensions, expected {self.vector_size}" + ) + + with self._lock: + for vector_id, vector, payload in zip(ids, vectors, payloads): + # Extract promoted fields from payload + user_id = str(payload.get("user_id", "")) + agent_id = str(payload.get("agent_id", "")) + + # Remaining payload as JSON + payload_json = json.dumps(payload, default=str) + + # Check if UUID already exists (upsert semantics) + try: + existing = self._collection.search( + vector=vector, + limit=1, + filter=f"uuid == '{vector_id}'", + ) + if existing and len(existing) > 0: + # Delete existing entry, then re-insert + for entry in existing: + try: + self._collection.delete(entry["id"]) + except Exception: + pass + except Exception: + pass + + self._collection.insert( + vector=vector, + fields={ + "uuid": vector_id, + "user_id": user_id, + "agent_id": agent_id, + "payload_json": payload_json, + }, + ) + + def search( + self, + query: Optional[str], + vectors: List[float], + limit: int = 5, + filters: Optional[Dict[str, Any]] = None, + ) -> List[MemoryResult]: + self._check_open() + + # Over-fetch when post-filtering is needed + has_non_promoted = filters and any( + k not in _PROMOTED_FIELDS for k in filters + ) + fetch_limit = limit * 3 if has_non_promoted else limit + + zvec_filter = _build_filter_string(filters) if filters else None + + with self._lock: + try: + kwargs: Dict[str, Any] = { + "vector": vectors, + "limit": fetch_limit, + } + if 
zvec_filter: + kwargs["filter"] = zvec_filter + raw_results = self._collection.search(**kwargs) + except Exception as e: + logger.warning("zvec search failed: %s", e) + return [] + + results = [] + for item in raw_results: + fields = item.get("fields", {}) + payload: Dict[str, Any] = {} + try: + payload = json.loads(fields.get("payload_json", "{}")) + except (json.JSONDecodeError, TypeError): + pass + + # Post-filter on non-promoted fields + if has_non_promoted and not matches_filters(payload, filters): + continue + + score = float(item.get("score", 0.0)) + + results.append( + MemoryResult( + id=fields.get("uuid", ""), + score=score, + payload=payload, + ) + ) + + return results[:limit] + + def delete(self, vector_id: str) -> None: + self._check_open() + with self._lock: + try: + # Find by uuid field + # Use a dummy vector search with filter to find the internal id + results = self._collection.search( + vector=[0.0] * self.vector_size, + limit=1, + filter=f"uuid == '{vector_id}'", + ) + for entry in results: + self._collection.delete(entry["id"]) + except Exception as e: + logger.warning("zvec delete failed for %s: %s", vector_id, e) + + def update( + self, + vector_id: str, + vector: Optional[List[float]] = None, + payload: Optional[Dict[str, Any]] = None, + ) -> None: + self._check_open() + with self._lock: + try: + # Find existing entry + results = self._collection.search( + vector=[0.0] * self.vector_size, + limit=1, + filter=f"uuid == '{vector_id}'", + ) + if not results: + return + + entry = results[0] + old_fields = entry.get("fields", {}) + old_payload = {} + try: + old_payload = json.loads(old_fields.get("payload_json", "{}")) + except (json.JSONDecodeError, TypeError): + pass + + # Merge payload + if payload is not None: + old_payload.update(payload) + + # Delete old entry + self._collection.delete(entry["id"]) + + # Re-insert with updated data + use_vector = vector if vector is not None else entry.get("vector", [0.0] * self.vector_size) + user_id = 
str(old_payload.get("user_id", old_fields.get("user_id", ""))) + agent_id = str(old_payload.get("agent_id", old_fields.get("agent_id", ""))) + + self._collection.insert( + vector=use_vector, + fields={ + "uuid": vector_id, + "user_id": user_id, + "agent_id": agent_id, + "payload_json": json.dumps(old_payload, default=str), + }, + ) + except Exception as e: + logger.warning("zvec update failed for %s: %s", vector_id, e) + + def get(self, vector_id: str) -> Optional[MemoryResult]: + self._check_open() + with self._lock: + try: + results = self._collection.search( + vector=[0.0] * self.vector_size, + limit=1, + filter=f"uuid == '{vector_id}'", + ) + if not results: + return None + fields = results[0].get("fields", {}) + payload = {} + try: + payload = json.loads(fields.get("payload_json", "{}")) + except (json.JSONDecodeError, TypeError): + pass + return MemoryResult(id=fields.get("uuid", vector_id), score=0.0, payload=payload) + except Exception: + return None + + def list_cols(self) -> List[str]: + self._check_open() + cols = [] + if os.path.isdir(self._db_path): + for entry in os.listdir(self._db_path): + full_path = os.path.join(self._db_path, entry) + if os.path.isdir(full_path): + cols.append(entry) + return cols + + def delete_col(self) -> None: + self._check_open() + import shutil + col_path = self._collection_path(self.collection_name) + with self._lock: + self._collection = None + if os.path.exists(col_path): + shutil.rmtree(col_path) + + def col_info(self) -> Dict[str, Any]: + self._check_open() + return { + "name": self.collection_name, + "vector_size": self.vector_size, + "path": self._collection_path(self.collection_name), + } + + def list( + self, + filters: Optional[Dict[str, Any]] = None, + limit: Optional[int] = None, + ) -> List[MemoryResult]: + self._check_open() + effective_limit = limit or 100 + + zvec_filter = _build_filter_string(filters) if filters else None + has_non_promoted = filters and any( + k not in _PROMOTED_FIELDS for k in filters + ) 
+ + with self._lock: + try: + kwargs: Dict[str, Any] = { + "vector": [0.0] * self.vector_size, + "limit": effective_limit * 3 if has_non_promoted else effective_limit, + } + if zvec_filter: + kwargs["filter"] = zvec_filter + raw_results = self._collection.search(**kwargs) + except Exception as e: + logger.warning("zvec list failed: %s", e) + return [] + + results = [] + for item in raw_results: + fields = item.get("fields", {}) + payload: Dict[str, Any] = {} + try: + payload = json.loads(fields.get("payload_json", "{}")) + except (json.JSONDecodeError, TypeError): + pass + + if has_non_promoted and not matches_filters(payload, filters): + continue + + results.append( + MemoryResult( + id=fields.get("uuid", ""), + score=0.0, + payload=payload, + ) + ) + + return results[:effective_limit] + + def reset(self) -> None: + self._check_open() + self.delete_col() + self._collection = self._ensure_collection(self.collection_name, self.vector_size) + + def _check_open(self) -> None: + if self._closed: + raise RuntimeError("ZvecStore is closed") + + def close(self) -> None: + with self._lock: + self._closed = True + self._collection = None diff --git a/pyproject.toml b/pyproject.toml index 4e1ed28..57bd9c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ dependencies = [ "pydantic>=2.0", "requests>=2.28.0", - "sqlite-vec>=0.1.1", + "pyyaml>=6.0", "engram-accel>=0.1.0", ] @@ -36,6 +36,8 @@ gemini = ["google-genai>=1.0.0"] openai = ["openai>=1.0.0"] nvidia = ["openai>=1.0.0"] ollama = ["ollama>=0.4.0"] +zvec = ["zvec>=0.2.0"] +sqlite_vec = ["sqlite-vec>=0.1.1"] # Integrations mcp = ["mcp>=1.0.0"] api = ["fastapi>=0.100.0", "uvicorn>=0.20.0"] diff --git a/tests/test_hashing.py b/tests/test_hashing.py new file mode 100644 index 0000000..96b3c17 --- /dev/null +++ b/tests/test_hashing.py @@ -0,0 +1,135 @@ +"""Tests for engram.skills.hashing — determinism, normalization, order sensitivity.""" + +import pytest + +from engram.skills.hashing import ( + 
    content_hash,
    skill_signature_hash,
    stable_json,
    trajectory_hash,
)


class TestStableJson:
    # stable_json must be a canonical, whitespace-free serialization.

    def test_deterministic_output(self):
        obj = {"b": 2, "a": 1, "c": [3, 2, 1]}
        assert stable_json(obj) == stable_json(obj)

    def test_key_order_irrelevant(self):
        a = {"z": 1, "a": 2}
        b = {"a": 2, "z": 1}
        assert stable_json(a) == stable_json(b)

    def test_no_whitespace(self):
        result = stable_json({"key": "value"})
        assert " " not in result


class TestContentHash:
    # content_hash normalizes case and surrounding whitespace before hashing.

    def test_deterministic(self):
        assert content_hash("hello world") == content_hash("hello world")

    def test_case_insensitive(self):
        assert content_hash("Hello World") == content_hash("hello world")

    def test_whitespace_normalized(self):
        assert content_hash("  hello   world  ") == content_hash("hello world")

    def test_different_content_different_hash(self):
        assert content_hash("hello") != content_hash("world")


class TestTrajectoryHash:
    # trajectory_hash keys on (action, tool, args) per step, in order;
    # result fields are excluded from identity.

    def test_deterministic(self):
        steps = [
            {"action": "search", "tool": "grep", "args": {"pattern": "error"}},
            {"action": "edit", "tool": "write", "args": {"file": "main.py"}},
        ]
        assert trajectory_hash(steps) == trajectory_hash(steps)

    def test_result_variations_ignored(self):
        """Same actions with different results should produce the same hash."""
        steps_a = [
            {"action": "search", "tool": "grep", "args": {"pattern": "error"}, "result": "found 3"},
        ]
        steps_b = [
            {"action": "search", "tool": "grep", "args": {"pattern": "error"}, "result": "found 5"},
        ]
        assert trajectory_hash(steps_a) == trajectory_hash(steps_b)

    def test_different_actions_different_hash(self):
        steps_a = [{"action": "search", "tool": "grep", "args": {}}]
        steps_b = [{"action": "edit", "tool": "write", "args": {}}]
        assert trajectory_hash(steps_a) != trajectory_hash(steps_b)

    def test_order_sensitive(self):
        """Step order matters for trajectory identity."""
        step_a = {"action": "search", "tool": "grep", "args": {"pattern": "x"}}
        step_b = {"action": "edit", "tool": "write", "args": {"file": "y"}}
        assert trajectory_hash([step_a, step_b]) != trajectory_hash([step_b, step_a])

    def test_empty_steps(self):
        # An empty trajectory still hashes deterministically.
        assert trajectory_hash([]) == trajectory_hash([])


class TestSkillSignatureHash:
    # Signature = sorted(preconditions) + ordered steps + sorted(tags);
    # the skill's name plays no part.

    def test_deterministic(self):
        h = skill_signature_hash(
            preconditions=["repo exists"],
            steps=["search for error", "fix bug"],
            tags=["debugging"],
        )
        assert h == skill_signature_hash(
            preconditions=["repo exists"],
            steps=["search for error", "fix bug"],
            tags=["debugging"],
        )

    def test_name_excluded(self):
        """Two skills with same content but different names should have same sig hash."""
        h1 = skill_signature_hash(
            preconditions=["a"], steps=["b"], tags=["c"]
        )
        h2 = skill_signature_hash(
            preconditions=["a"], steps=["b"], tags=["c"]
        )
        assert h1 == h2

    def test_tag_order_irrelevant(self):
        """Tags are sorted, so order shouldn't matter."""
        h1 = skill_signature_hash(
            preconditions=[], steps=["step"], tags=["z", "a", "m"]
        )
        h2 = skill_signature_hash(
            preconditions=[], steps=["step"], tags=["a", "m", "z"]
        )
        assert h1 == h2

    def test_precondition_order_irrelevant(self):
        """Preconditions are sorted."""
        h1 = skill_signature_hash(
            preconditions=["z", "a"], steps=["step"], tags=[]
        )
        h2 = skill_signature_hash(
            preconditions=["a", "z"], steps=["step"], tags=[]
        )
        assert h1 == h2

    def test_step_order_sensitive(self):
        """Steps are NOT sorted — order matters."""
        h1 = skill_signature_hash(
            preconditions=[], steps=["first", "second"], tags=[]
        )
        h2 = skill_signature_hash(
            preconditions=[], steps=["second", "first"], tags=[]
        )
        assert h1 != h2

    def test_different_content_different_hash(self):
        h1 = skill_signature_hash(
            preconditions=["a"], steps=["b"], tags=["c"]
        )
        h2 = skill_signature_hash(
            preconditions=["x"], steps=["y"], tags=["z"]
        )
        assert h1 != h2
"""Tests for engram.skills.miner — mining pipeline, dedup, mutation, mock LLM."""

import json
import os
import pytest

from engram.skills.miner import SkillMiner
from engram.skills.schema import Skill, Trajectory, TrajectoryStep
from engram.skills.store import SkillStore
from engram.skills.trajectory import TrajectoryRecorder, TrajectoryStore


class MockDB:
    """Simple in-memory DB mock for testing."""

    def __init__(self):
        self._store = {}

    def add_memory(self, data):
        self._store[data["id"]] = data

    def get_memory(self, memory_id):
        return self._store.get(memory_id)

    def get_all_memories(self, user_id=None, agent_id=None, limit=100, **kwargs):
        results = list(self._store.values())
        if user_id:
            results = [r for r in results if r.get("user_id") == user_id]
        return results[:limit]


class MockLLM:
    """Mock LLM that returns valid skill JSON regardless of prompt."""

    def generate(self, prompt):
        return json.dumps({
            "name": "Mined Bug Fix",
            "description": "Fix bugs by searching and patching",
            "preconditions": ["source code exists"],
            "steps": ["search for error", "identify root cause", "apply fix", "run tests"],
            "tags": ["debugging", "bugfix"],
        })


@pytest.fixture
def skill_store(tmp_path):
    """Fresh SkillStore rooted in a per-test temporary directory."""
    skill_dir = str(tmp_path / "skills")
    os.makedirs(skill_dir, exist_ok=True)
    return SkillStore(skill_dirs=[skill_dir])


@pytest.fixture
def trajectory_store():
    """TrajectoryStore backed by the in-memory MockDB."""
    return TrajectoryStore(db=MockDB())


def _make_trajectory(desc: str, actions: list) -> Trajectory:
    """Helper to create a finalized (successful) trajectory from action dicts."""
    recorder = TrajectoryRecorder(task_description=desc)
    for action in actions:
        recorder.record_step(
            action=action.get("action", "step"),
            tool=action.get("tool", "tool"),
            args=action.get("args", {}),
            result_summary=action.get("result", "ok"),
        )
    return recorder.finalize(success=True, outcome_summary=f"Completed: {desc}")


class TestSkillMiner:
    def test_mine_no_trajectories(self, trajectory_store, skill_store):
        """Empty trajectory store yields no mined skills."""
        miner = SkillMiner(
            trajectory_store=trajectory_store,
            skill_store=skill_store,
            min_cluster_size=2,
        )
        result = miner.mine()
        assert result == []

    def test_mine_insufficient_trajectories(self, trajectory_store, skill_store):
        """Need at least min_cluster_size trajectories to mine."""
        t = _make_trajectory("single task", [{"action": "test"}])
        trajectory_store.save(t)

        miner = SkillMiner(
            trajectory_store=trajectory_store,
            skill_store=skill_store,
            min_cluster_size=2,
        )
        result = miner.mine()
        assert result == []

    def test_mine_heuristic(self, trajectory_store, skill_store):
        """Test mining without LLM (heuristic mode)."""
        # Create similar trajectories
        for i in range(3):
            t = _make_trajectory(
                f"fix python error variant {i}",
                [
                    {"action": "search", "tool": "grep", "args": {"pattern": "error"}},
                    {"action": "edit", "tool": "write", "args": {"file": "main.py"}},
                    {"action": "test", "tool": "pytest"},
                ],
            )
            trajectory_store.save(t)

        miner = SkillMiner(
            trajectory_store=trajectory_store,
            skill_store=skill_store,
            llm=None,
            min_cluster_size=2,
            mutation_rate=0.0,  # Disable mutation for deterministic test
        )
        mined = miner.mine()
        assert len(mined) >= 1
        skill = mined[0]
        assert skill.source == "mined"
        assert skill.confidence == 0.5
        assert len(skill.steps) > 0

    def test_mine_with_llm(self, trajectory_store, skill_store):
        """Test mining with mock LLM."""
        for i in range(3):
            t = _make_trajectory(
                f"debug application error {i}",
                [
                    {"action": "search", "tool": "grep"},
                    {"action": "fix", "tool": "edit"},
                    {"action": "verify", "tool": "test"},
                ],
            )
            trajectory_store.save(t)

        miner = SkillMiner(
            trajectory_store=trajectory_store,
            skill_store=skill_store,
            llm=MockLLM(),
            min_cluster_size=2,
            mutation_rate=0.0,
        )
        mined = miner.mine()
        assert len(mined) >= 1
        skill = mined[0]
        assert skill.name == "Mined Bug Fix"
        assert "debugging" in skill.tags

    def test_mine_dedup(self, trajectory_store, skill_store):
        """Mining the same cluster twice should not create duplicates."""
        for i in range(3):
            t = _make_trajectory(
                f"fix python issue {i}",
                [{"action": "search"}, {"action": "fix"}],
            )
            trajectory_store.save(t)

        miner = SkillMiner(
            trajectory_store=trajectory_store,
            skill_store=skill_store,
            min_cluster_size=2,
            mutation_rate=0.0,
        )

        # First mine
        first = miner.mine()
        assert len(first) >= 1

        # Second mine — should be deduped
        second = miner.mine()
        assert len(second) == 0

    def test_mine_saves_to_store(self, trajectory_store, skill_store):
        """Mined skills should be persisted in the skill store."""
        for i in range(3):
            t = _make_trajectory(
                f"deploy application {i}",
                [{"action": "build"}, {"action": "test"}, {"action": "deploy"}],
            )
            trajectory_store.save(t)

        miner = SkillMiner(
            trajectory_store=trajectory_store,
            skill_store=skill_store,
            min_cluster_size=2,
            mutation_rate=0.0,
        )
        mined = miner.mine()
        assert len(mined) >= 1

        # Verify persisted
        stored = skill_store.get(mined[0].id)
        assert stored is not None
        assert stored.source == "mined"


class TestMutation:
    def test_mutation_adds_verification(self, tmp_path):
        """With mutation_rate=1.0, every skill should be mutated.

        Uses pytest's tmp_path instead of a hard-coded /tmp directory:
        the old shared path leaked (cleanup ran only after the asserts,
        so a failure left the directory behind) and collided across
        concurrent or repeated runs. tmp_path is isolated and is cleaned
        up by pytest automatically.
        """
        import random
        random.seed(42)

        skill_dir = str(tmp_path / "mutation_skills")
        os.makedirs(skill_dir, exist_ok=True)

        store = SkillStore(skill_dirs=[skill_dir])
        tstore = TrajectoryStore(db=MockDB())

        for i in range(3):
            t = _make_trajectory(
                f"mutate test {i}",
                [{"action": "step1"}, {"action": "step2"}],
            )
            tstore.save(t)

        miner = SkillMiner(
            trajectory_store=tstore,
            skill_store=store,
            mutation_rate=1.0,  # Always mutate
            min_cluster_size=2,
        )
        mined = miner.mine()

        # At least one skill should exist
        if mined:
            skill = mined[0]
            # Mutation should have added "verify" or "adapt as needed"
            all_steps = " ".join(skill.steps).lower()
            has_mutation = "verify" in all_steps or "adapt as needed" in all_steps
            assert has_mutation, f"Expected mutation in steps: {skill.steps}"


class TestClusterByKeywords:
    def test_similar_tasks_cluster_together(self, tmp_path):
        """Tasks with similar keywords should be in the same cluster.

        Uses tmp_path rather than the previous hard-coded /tmp/test_cluster,
        which was never cleaned up and was shared across runs.
        """
        skill_dir = str(tmp_path / "cluster_skills")
        os.makedirs(skill_dir, exist_ok=True)
        store = SkillStore(skill_dirs=[skill_dir])
        tstore = TrajectoryStore(db=MockDB())
        miner = SkillMiner(
            trajectory_store=tstore,
            skill_store=store,
            min_cluster_size=2,
        )

        trajectories = [
            Trajectory(task_description="fix python error"),
            Trajectory(task_description="fix python bug"),
            Trajectory(task_description="deploy to production"),
        ]

        clusters = miner._cluster_by_keywords(trajectories)
        # "fix python error" and "fix python bug" share keywords
        # They may or may not cluster depending on keyword overlap
        assert len(clusters) >= 1
"""Tests for engram.skills — schema, store, executor, outcomes, discovery."""

import os
import pytest
import tempfile

from engram.skills.schema import Skill, TrajectoryStep, Trajectory
from engram.skills.store import SkillStore
from engram.skills.executor import SkillExecutor
from engram.skills.outcomes import OutcomeTracker, compute_confidence
from engram.skills.discovery import discover_skill_dirs, scan_skill_files, load_skill_file


# ── Schema tests ──


class TestSkillSchema:
    """Round-trip and serialization behaviour of the Skill model."""

    def test_skill_roundtrip(self):
        """Serialize to SKILL.md and parse back."""
        skill = Skill(
            name="Fix Typos",
            description="Find and fix typos in code",
            tags=["debugging", "text"],
            preconditions=["repo exists", "file has content"],
            steps=["search for misspellings", "apply corrections", "run tests"],
            confidence=0.75,
            source="authored",
        )
        md = skill.to_skill_md()
        # "---" is the YAML frontmatter delimiter in SKILL.md format.
        assert "---" in md
        assert "Fix Typos" in md

        parsed = Skill.from_skill_md(md)
        assert parsed.name == "Fix Typos"
        assert parsed.description == "Find and fix typos in code"
        assert parsed.tags == ["debugging", "text"]
        assert parsed.preconditions == ["repo exists", "file has content"]
        assert len(parsed.steps) == 3
        assert parsed.confidence == 0.75
        assert parsed.source == "authored"

    def test_skill_signature_hash_computed(self):
        """Signature hash should be auto-computed."""
        skill = Skill(
            name="Test",
            preconditions=["a"],
            steps=["b"],
            tags=["c"],
        )
        assert len(skill.signature_hash) == 64  # SHA-256 hex

    def test_skill_to_dict(self):
        skill = Skill(name="Test", description="A test skill")
        d = skill.to_dict()
        assert d["name"] == "Test"
        assert "id" in d
        assert "confidence" in d

    def test_skill_from_md_no_frontmatter(self):
        """Content without frontmatter treated as body."""
        skill = Skill.from_skill_md("Just some markdown content")
        assert skill.body_markdown == "Just some markdown content"

    def test_skill_from_md_empty(self):
        skill = Skill.from_skill_md("")
        assert skill.body_markdown == ""


class TestTrajectorySchema:
    """Serialization and hashing of TrajectoryStep / Trajectory."""

    def test_trajectory_step_to_dict(self):
        step = TrajectoryStep(
            action="search",
            tool="grep",
            args={"pattern": "error"},
            result_summary="found 3 matches",
        )
        d = step.to_dict()
        assert d["action"] == "search"
        assert d["tool"] == "grep"

    def test_trajectory_compute_hash(self):
        t = Trajectory(
            task_description="fix a bug",
            steps=[
                TrajectoryStep(action="search", tool="grep", args={"pattern": "error"}),
                TrajectoryStep(action="edit", tool="write", args={"file": "main.py"}),
            ],
        )
        h = t.compute_hash()
        assert len(h) == 64  # SHA-256 hex digest
        # Deterministic
        assert t.compute_hash() == h

    def test_trajectory_to_dict(self):
        t = Trajectory(task_description="test task")
        d = t.to_dict()
        assert d["task_description"] == "test task"
        assert "id" in d


# ── Store tests ──


class TestSkillStore:
    """CRUD, search and filesystem-sync behaviour of SkillStore."""

    @pytest.fixture
    def store(self, tmp_path):
        # One fresh store per test, rooted in an isolated temp dir.
        skill_dir = str(tmp_path / "skills")
        os.makedirs(skill_dir, exist_ok=True)
        return SkillStore(skill_dirs=[skill_dir])

    def test_save_and_get(self, store):
        skill = Skill(name="Test Skill", description="A test")
        store.save(skill)
        retrieved = store.get(skill.id)
        assert retrieved is not None
        assert retrieved.name == "Test Skill"

    def test_save_creates_file(self, store):
        """save() must also persist a <id>.skill.md file on disk."""
        skill = Skill(name="File Test", description="Check file creation")
        store.save(skill)
        filepath = os.path.join(store.primary_dir, f"{skill.id}.skill.md")
        assert os.path.isfile(filepath)

    def test_get_nonexistent(self, store):
        assert store.get("nonexistent-id") is None

    def test_delete(self, store):
        skill = Skill(name="Delete Me")
        store.save(skill)
        assert store.get(skill.id) is not None
        store.delete(skill.id)
        assert store.get(skill.id) is None

    def test_text_search(self, store):
        store.save(Skill(name="Python Debugging", description="Debug Python code", tags=["python"]))
        store.save(Skill(name="JS Linting", description="Lint JavaScript", tags=["javascript"]))
        results = store.search("python", limit=5)
        assert len(results) >= 1
        assert results[0].name == "Python Debugging"

    def test_get_by_signature(self, store):
        skill = Skill(
            name="Unique",
            preconditions=["a"],
            steps=["b"],
            tags=["c"],
        )
        store.save(skill)
        found = store.get_by_signature(skill.signature_hash)
        assert found is not None
        assert found.id == skill.id

    def test_list_all(self, store):
        store.save(Skill(name="S1"))
        store.save(Skill(name="S2"))
        all_skills = store.list_all()
        assert len(all_skills) == 2

    def test_sync_from_filesystem(self, tmp_path):
        """Files written directly to disk are picked up by sync_from_filesystem()."""
        skill_dir = str(tmp_path / "sync_skills")
        os.makedirs(skill_dir, exist_ok=True)

        # Write a skill file manually
        skill = Skill(name="Manual Skill", description="Manually written")
        filepath = os.path.join(skill_dir, f"{skill.id}.skill.md")
        with open(filepath, "w") as f:
            f.write(skill.to_skill_md())

        # Create store and sync
        store = SkillStore(skill_dirs=[skill_dir])
        count = store.sync_from_filesystem()
        assert count == 1
        assert store.get(skill.id) is not None


# ── Executor tests ──


class TestSkillExecutor:
    """apply/search behaviour of SkillExecutor against a fresh store."""

    @pytest.fixture
    def executor(self, tmp_path):
        # Returns (executor, store) so tests can seed skills directly.
        skill_dir = str(tmp_path / "exec_skills")
        os.makedirs(skill_dir, exist_ok=True)
        store = SkillStore(skill_dirs=[skill_dir])
        return SkillExecutor(store), store

    def test_apply_skill(self, executor):
        exec_, store = executor
        skill = Skill(
            name="Fix Bugs",
            description="Standard bug fix workflow",
            steps=["reproduce", "diagnose", "fix", "test"],
            confidence=0.8,
        )
        store.save(skill)
        result = exec_.apply(skill.id)
        assert result["injected"] is True
        assert "recipe" in result
        assert "Fix Bugs" in result["recipe"]
        assert result["confidence"] == 0.8

    def test_apply_increments_use_count(self, executor):
        """apply() bumps use_count and persists the change."""
        exec_, store = executor
        skill = Skill(name="Counter Test", use_count=0)
        store.save(skill)
        exec_.apply(skill.id)
        updated = store.get(skill.id)
        assert updated.use_count == 1

    def test_apply_nonexistent(self, executor):
        exec_, _ = executor
        result = exec_.apply("nonexistent")
        assert result["injected"] is False

    def test_search_and_apply(self, executor):
        exec_, store = executor
        skill = Skill(
            name="Python Debugging",
            description="Debug Python errors",
            tags=["python", "debug"],
            confidence=0.7,
        )
        store.save(skill)
        result = exec_.search_and_apply("debug python")
        assert result["injected"] is True

    def test_search(self, executor):
        exec_, store = executor
        store.save(Skill(name="Skill A", description="Do A", tags=["a"]))
        store.save(Skill(name="Skill B", description="Do B", tags=["b"]))
        results = exec_.search("Skill")
        assert len(results) >= 1


# ── Outcome tracking tests ──


class TestOutcomeTracker:
    """Confidence updates driven by logged success/failure outcomes."""

    @pytest.fixture
    def tracker(self, tmp_path):
        skill_dir = str(tmp_path / "outcome_skills")
        os.makedirs(skill_dir, exist_ok=True)
        store = SkillStore(skill_dirs=[skill_dir])
        return OutcomeTracker(store), store

    def test_log_success(self, tracker):
        tr, store = tracker
        skill = Skill(name="Test", confidence=0.5)
        store.save(skill)
        result = tr.log_outcome(skill.id, success=True)
        assert result["success"] is True
        assert result["new_confidence"] > 0

    def test_log_failure_lowers_confidence(self, tracker):
        tr, store = tracker
        # Start with balanced counts so a failure clearly lowers confidence
        skill = Skill(name="Test", confidence=0.5, success_count=5, fail_count=5)
        store.save(skill)
        result = tr.log_outcome(skill.id, success=False)
        assert result["new_confidence"] < result["old_confidence"]

    def test_log_nonexistent(self, tracker):
        tr, _ = tracker
        result = tr.log_outcome("nonexistent", success=True)
        assert "error" in result


class TestComputeConfidence:
    """Properties of the compute_confidence(successes, failures) scoring function."""

    def test_neutral_prior(self):
        # No evidence at all → exactly the neutral prior.
        assert compute_confidence(0, 0) == 0.5

    def test_all_success_high(self):
        c = compute_confidence(100, 0)
        assert c > 0.5

    def test_all_failure_low(self):
        c = compute_confidence(0, 100)
        assert c < 0.5

    def test_asymmetric_penalty(self):
        """Equal success/fail should be below 0.5 due to asymmetric weighting."""
        c = compute_confidence(10, 10)
        assert c < 0.5

    def test_bounded(self):
        assert 0.0 <= compute_confidence(1000, 0) <= 1.0
        assert 0.0 <= compute_confidence(0, 1000) <= 1.0


# ── Discovery tests ──


class TestDiscovery:
    """Skill-directory discovery and .skill.md file scanning/loading."""

    def test_discover_skill_dirs_global(self):
        dirs = discover_skill_dirs()
        # The global ~/.engram/skills dir should always be offered.
        assert any(".engram/skills" in d for d in dirs)

    def test_discover_skill_dirs_with_repo(self, tmp_path):
        dirs = discover_skill_dirs(repo_path=str(tmp_path))
        assert any(".engram/skills" in d for d in dirs)
        # Repo-local dir is listed first (higher precedence than global).
        assert str(tmp_path) in dirs[0]

    def test_scan_skill_files(self, tmp_path):
        skill_dir = str(tmp_path / "skills")
        os.makedirs(skill_dir)
        # Create a skill file
        with open(os.path.join(skill_dir, "test-id.skill.md"), "w") as f:
            f.write("---\nname: Test\n---\nBody")
        results = scan_skill_files([skill_dir])
        assert len(results) == 1
        # scan returns (path, skill_id) pairs; id comes from the filename stem.
        assert results[0][1] == "test-id"

    def test_load_skill_file(self, tmp_path):
        skill = Skill(name="Load Test", description="Test loading")
        filepath = str(tmp_path / "test.skill.md")
        with open(filepath, "w") as f:
            f.write(skill.to_skill_md())
        loaded = load_skill_file(filepath)
        assert loaded.name == "Load Test"
"""Tests for engram.skills.trajectory — recorder, store, hash determinism."""

import pytest
import time

from engram.skills.trajectory import TrajectoryRecorder, TrajectoryStore
from engram.skills.schema import Trajectory, TrajectoryStep


class TestTrajectoryRecorder:
    """Recording steps and the finalize() contract of TrajectoryRecorder."""

    def test_record_steps(self):
        recorder = TrajectoryRecorder(
            task_description="fix a bug",
            user_id="test-user",
            agent_id="test-agent",
        )
        recorder.record_step(
            action="search",
            tool="grep",
            args={"pattern": "error"},
            result_summary="found 3 matches",
        )
        recorder.record_step(
            action="edit",
            tool="write",
            args={"file": "main.py"},
            result_summary="fixed typo",
        )
        assert len(recorder.steps) == 2
        assert recorder.steps[0].action == "search"
        assert recorder.steps[1].action == "edit"

    def test_finalize_success(self):
        """finalize() stamps success flag, completion time and trajectory hash."""
        recorder = TrajectoryRecorder(task_description="test task")
        recorder.record_step(action="test", tool="pytest", result_summary="pass")
        trajectory = recorder.finalize(success=True, outcome_summary="All tests pass")
        assert trajectory.success is True
        assert trajectory.outcome_summary == "All tests pass"
        assert len(trajectory.steps) == 1
        assert trajectory.completed_at is not None
        assert len(trajectory.trajectory_hash_val) == 64  # SHA-256 hex

    def test_finalize_failure(self):
        recorder = TrajectoryRecorder(task_description="broken task")
        recorder.record_step(action="test", error="AssertionError")
        trajectory = recorder.finalize(success=False, outcome_summary="Test failed")
        assert trajectory.success is False

    def test_hash_determinism(self):
        """Same steps should produce the same trajectory hash."""
        r1 = TrajectoryRecorder(task_description="task A")
        r1.record_step(action="search", tool="grep", args={"pattern": "x"})
        r1.record_step(action="edit", tool="write", args={"file": "f.py"})
        t1 = r1.finalize(success=True)

        r2 = TrajectoryRecorder(task_description="task B")
        r2.record_step(action="search", tool="grep", args={"pattern": "x"})
        r2.record_step(action="edit", tool="write", args={"file": "f.py"})
        t2 = r2.finalize(success=True)

        # Same steps → same hash (task description excluded from hash)
        assert t1.trajectory_hash_val == t2.trajectory_hash_val

    def test_different_steps_different_hash(self):
        r1 = TrajectoryRecorder(task_description="task")
        r1.record_step(action="search", tool="grep")
        t1 = r1.finalize(success=True)

        r2 = TrajectoryRecorder(task_description="task")
        r2.record_step(action="edit", tool="write")
        t2 = r2.finalize(success=True)

        assert t1.trajectory_hash_val != t2.trajectory_hash_val

    def test_recorder_id_unique(self):
        # Each recorder gets its own id independent of task text.
        r1 = TrajectoryRecorder(task_description="a")
        r2 = TrajectoryRecorder(task_description="b")
        assert r1.id != r2.id

    def test_step_error_recorded(self):
        recorder = TrajectoryRecorder(task_description="error test")
        recorder.record_step(
            action="compile",
            tool="gcc",
            error="syntax error at line 42",
        )
        assert recorder.steps[0].error == "syntax error at line 42"


class TestTrajectoryStore:
    """Tests using a mock DB that stores in-memory."""

    class MockDB:
        # Minimal stand-in exposing the subset of the DB API the store uses.
        def __init__(self):
            self._store = {}

        def add_memory(self, data):
            self._store[data["id"]] = data

        def get_memory(self, memory_id):
            return self._store.get(memory_id)

        def get_all_memories(self, user_id=None, agent_id=None, limit=100, **kwargs):
            results = list(self._store.values())
            if user_id:
                results = [r for r in results if r.get("user_id") == user_id]
            return results[:limit]

    @pytest.fixture
    def store(self):
        return TrajectoryStore(db=self.MockDB())

    def test_save_and_get(self, store):
        recorder = TrajectoryRecorder(task_description="save test")
        recorder.record_step(action="test", tool="pytest", result_summary="ok")
        trajectory = recorder.finalize(success=True, outcome_summary="done")

        store.save(trajectory)
        retrieved = store.get(trajectory.id)
        assert retrieved is not None
        assert retrieved.task_description == "save test"
        assert retrieved.success is True
        assert len(retrieved.steps) == 1

    def test_find_successful(self, store):
        # Save 2 successful and 1 failed
        for i, success in enumerate([True, True, False]):
            recorder = TrajectoryRecorder(task_description=f"task {i}")
            recorder.record_step(action="step", tool="tool")
            trajectory = recorder.finalize(success=success, outcome_summary=f"result {i}")
            store.save(trajectory)

        successful = store.find_successful()
        assert len(successful) == 2

    def test_find_successful_with_query(self, store):
        """task_query narrows find_successful() to matching descriptions."""
        for desc in ["fix python bug", "deploy to prod", "fix javascript error"]:
            recorder = TrajectoryRecorder(task_description=desc)
            recorder.record_step(action="do", tool="t")
            store.save(recorder.finalize(success=True, outcome_summary="done"))

        results = store.find_successful(task_query="fix")
        assert len(results) == 2

    def test_find_by_hash(self, store):
        recorder = TrajectoryRecorder(task_description="hash test")
        recorder.record_step(action="x", tool="y")
        trajectory = recorder.finalize(success=True)
        store.save(trajectory)

        found = store.find_by_hash(trajectory.trajectory_hash_val)
        assert found is not None
        assert found.id == trajectory.id

    def test_get_nonexistent(self, store):
        assert store.get("nonexistent") is None
"""Tests for zvec vector store implementation."""

import math
import os

import pytest

# Skip the whole module when the optional zvec backend is absent.
zvec = pytest.importorskip("zvec", reason="zvec not installed")

from engram.vector_stores.base import MemoryResult
from engram.vector_stores.zvec_store import ZvecStore, _build_filter_string


@pytest.fixture
def store(tmp_path):
    """Create a ZvecStore with a temporary directory."""
    config = {
        "path": str(tmp_path / "zvec_test"),
        "collection_name": "test_col",
        "embedding_model_dims": 4,
    }
    return ZvecStore(config)


def _norm(v):
    """Normalize a vector to unit length."""
    mag = math.sqrt(sum(x * x for x in v))
    return [x / mag for x in v] if mag > 0 else v


class TestFilterString:
    """_build_filter_string only emits clauses for promoted payload fields."""

    def test_builds_promoted_fields(self):
        result = _build_filter_string({"user_id": "alice", "agent_id": "bot1"})
        assert "user_id == 'alice'" in result
        assert "agent_id == 'bot1'" in result

    def test_ignores_non_promoted(self):
        # Only promoted fields produce clauses; nothing else → None.
        result = _build_filter_string({"custom_field": "val"})
        assert result is None

    def test_mixed_fields(self):
        result = _build_filter_string({"user_id": "alice", "custom": "val"})
        assert "user_id == 'alice'" in result
        assert "custom" not in result


class TestInsert:
    """insert() validation, id handling and upsert semantics."""

    def test_insert_single(self, store):
        store.insert(
            vectors=[_norm([1.0, 0.0, 0.0, 0.0])],
            payloads=[{"text": "hello", "user_id": "default"}],
            ids=["id-1"],
        )
        result = store.get("id-1")
        assert result is not None
        assert result.id == "id-1"
        assert result.payload["text"] == "hello"

    def test_insert_multiple(self, store):
        store.insert(
            vectors=[_norm([1.0, 0.0, 0.0, 0.0]), _norm([0.0, 1.0, 0.0, 0.0])],
            payloads=[{"text": "a"}, {"text": "b"}],
            ids=["id-1", "id-2"],
        )
        assert store.get("id-1") is not None
        assert store.get("id-2") is not None

    def test_insert_validates_dimensions(self, store):
        # Store was configured with 4 dims; a 2-dim vector must be rejected.
        with pytest.raises(ValueError, match="dimensions"):
            store.insert(vectors=[[1.0, 0.0]], payloads=[{}], ids=["bad"])

    def test_insert_validates_lengths(self, store):
        with pytest.raises(ValueError, match="payloads length"):
            store.insert(
                vectors=[_norm([1.0, 0.0, 0.0, 0.0])],
                payloads=[{}, {}],
            )

    def test_upsert_semantics(self, store):
        """Re-inserting same ID updates the record."""
        store.insert(
            vectors=[_norm([1.0, 0.0, 0.0, 0.0])],
            payloads=[{"text": "original"}],
            ids=["id-1"],
        )
        store.insert(
            vectors=[_norm([0.0, 1.0, 0.0, 0.0])],
            payloads=[{"text": "updated"}],
            ids=["id-1"],
        )
        result = store.get("id-1")
        assert result is not None
        assert result.payload["text"] == "updated"


class TestSearch:
    """Vector similarity search: ranking, limit, filters, empty collection."""

    def test_search_returns_results(self, store):
        v1 = _norm([1.0, 0.0, 0.0, 0.0])
        v2 = _norm([0.0, 1.0, 0.0, 0.0])
        store.insert(
            vectors=[v1, v2],
            payloads=[{"text": "hello", "user_id": "default"}, {"text": "world", "user_id": "default"}],
            ids=["id-1", "id-2"],
        )
        results = store.search(query=None, vectors=v1, limit=2)
        assert len(results) >= 1
        # First result should be the closest match
        assert results[0].id == "id-1"
        assert results[0].score > 0

    def test_search_respects_limit(self, store):
        vectors = [_norm([float(i), 0.0, 0.0, 0.0]) for i in range(1, 6)]
        store.insert(
            vectors=vectors,
            payloads=[{"n": i} for i in range(5)],
        )
        results = store.search(query=None, vectors=vectors[0], limit=2)
        assert len(results) <= 2

    def test_search_with_filter(self, store):
        v1 = _norm([1.0, 0.0, 0.0, 0.0])
        v2 = _norm([1.0, 0.1, 0.0, 0.0])
        store.insert(
            vectors=[v1, v2],
            payloads=[
                {"text": "a", "user_id": "alice"},
                {"text": "b", "user_id": "bob"},
            ],
            ids=["id-a", "id-b"],
        )
        results = store.search(
            query=None, vectors=v1, limit=5, filters={"user_id": "alice"}
        )
        assert all(r.payload.get("user_id") == "alice" for r in results)

    def test_search_empty_collection(self, store):
        results = store.search(
            query=None, vectors=_norm([1.0, 0.0, 0.0, 0.0]), limit=5
        )
        assert results == []


class TestDelete:
    def test_delete_removes_record(self, store):
        store.insert(
            vectors=[_norm([1.0, 0.0, 0.0, 0.0])],
            payloads=[{"text": "bye"}],
            ids=["id-del"],
        )
        assert store.get("id-del") is not None
        store.delete("id-del")
        assert store.get("id-del") is None

    def test_delete_nonexistent_is_noop(self, store):
        # Must not raise for an unknown id.
        store.delete("nonexistent")


class TestUpdate:
    def test_update_payload(self, store):
        v = _norm([1.0, 0.0, 0.0, 0.0])
        store.insert(vectors=[v], payloads=[{"text": "old"}], ids=["id-up"])
        store.update("id-up", payload={"text": "new"})
        result = store.get("id-up")
        assert result.payload["text"] == "new"

    def test_update_nonexistent_is_noop(self, store):
        # Must not raise for an unknown id.
        store.update("nonexistent", payload={"text": "x"})


class TestCollectionOps:
    """Collection-level operations: listing, info, deletion, reset, list()."""

    def test_list_cols(self, store):
        cols = store.list_cols()
        assert "test_col" in cols

    def test_col_info(self, store):
        info = store.col_info()
        assert info["name"] == "test_col"
        assert info["vector_size"] == 4

    def test_delete_col(self, store):
        store.delete_col()
        # After deletion, the collection directory should be gone
        col_path = store._collection_path("test_col")
        assert not os.path.exists(col_path)

    def test_reset(self, store):
        store.insert(
            vectors=[_norm([1.0, 0.0, 0.0, 0.0])],
            payloads=[{"text": "x"}],
            ids=["id-r"],
        )
        store.reset()
        assert store.get("id-r") is None

    def test_list_with_filters(self, store):
        store.insert(
            vectors=[_norm([1.0, 0.0, 0.0, 0.0]), _norm([0.0, 1.0, 0.0, 0.0])],
            payloads=[
                {"user_id": "alice", "text": "a"},
                {"user_id": "bob", "text": "b"},
            ],
            ids=["id-a", "id-b"],
        )
        results = store.list(filters={"user_id": "alice"})
        assert all(r.payload.get("user_id") == "alice" for r in results)


class TestClose:
    def test_close_prevents_operations(self, store):
        # Once closed, any operation must raise rather than touch freed state.
        store.close()
        with pytest.raises(RuntimeError, match="closed"):
            store.search(query=None, vectors=[1.0, 0.0, 0.0, 0.0], limit=1)
all(r.payload.get("user_id") == "alice" for r in results) + + +class TestClose: + def test_close_prevents_operations(self, store): + store.close() + with pytest.raises(RuntimeError, match="closed"): + store.search(query=None, vectors=[1.0, 0.0, 0.0, 0.0], limit=1) From 186d9a1a92d7448476c18768f8a298b6f2364817 Mon Sep 17 00:00:00 2001 From: Vivek Kumar Date: Wed, 18 Feb 2026 14:33:07 +0530 Subject: [PATCH 5/8] feat: unified enrichment pipeline, skill system enhancements, and NVIDIA provider updates - Add unified enrichment module with multi-modal encoding pipeline - Enhance skill system with structure module, improved executor/miner/outcomes - Update memory architecture (smart memory, main memory) for enrichment integration - Add NVIDIA embeddings and LLM provider improvements - Update LongMemEval benchmark runner - Add structural and unified enrichment test suites Co-Authored-By: Claude Opus 4.6 --- engram/benchmarks/longmemeval.py | 44 ++- engram/configs/base.py | 30 +- engram/configs/presets.py | 5 +- engram/core/category.py | 5 + engram/core/echo.py | 81 +++- engram/core/enrichment.py | 627 +++++++++++++++++++++++++++++++ engram/embeddings/nvidia.py | 52 ++- engram/llms/nvidia.py | 11 +- engram/mcp_server.py | 122 +++++- engram/memory/main.py | 195 +++++++--- engram/memory/smart.py | 118 +++++- engram/skills/executor.py | 116 ++++++ engram/skills/hashing.py | 18 + engram/skills/miner.py | 60 +++ engram/skills/outcomes.py | 62 ++- engram/skills/schema.py | 20 +- engram/skills/store.py | 53 +++ engram/skills/structure.py | 498 ++++++++++++++++++++++++ engram/utils/prompts.py | 91 +++++ tests/test_structural.py | 455 ++++++++++++++++++++++ tests/test_unified_enrichment.py | 504 +++++++++++++++++++++++++ 21 files changed, 3048 insertions(+), 119 deletions(-) create mode 100644 engram/core/enrichment.py create mode 100644 engram/skills/structure.py create mode 100644 tests/test_structural.py create mode 100644 tests/test_unified_enrichment.py diff --git 
a/engram/benchmarks/longmemeval.py b/engram/benchmarks/longmemeval.py index b3b6ead..380d981 100644 --- a/engram/benchmarks/longmemeval.py +++ b/engram/benchmarks/longmemeval.py @@ -8,6 +8,7 @@ import argparse import json +import logging import os import re from dataclasses import dataclass @@ -15,11 +16,12 @@ from statistics import mean from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple -from engram import Memory +from engram import FullMemory as Memory from engram.configs.base import ( CategoryMemConfig, EchoMemConfig, EmbedderConfig, + EnrichmentConfig, KnowledgeGraphConfig, LLMConfig, MemoryConfig, @@ -28,6 +30,7 @@ VectorStoreConfig, ) +logger = logging.getLogger(__name__) SESSION_ID_PATTERN = re.compile(r"^Session ID:\s*(?P\S+)\s*$", re.MULTILINE) HISTORY_HEADER = "User Transcript:" @@ -170,8 +173,14 @@ def build_memory( graph=KnowledgeGraphConfig(enable_graph=full_potential), scene=SceneConfig(use_llm_summarization=full_potential, enable_scenes=full_potential), profile=ProfileConfig(use_llm_extraction=full_potential, enable_profiles=full_potential), + enrichment=EnrichmentConfig(enable_unified=full_potential), ) - return Memory(config) + mem = Memory(config) + # FullMemory features (categories, scenes, profiles) need FullSQLiteManager + if full_potential: + from engram.db.sqlite import FullSQLiteManager + mem.db = FullSQLiteManager(history_db_path) + return mem def build_context_text(results: Sequence[Dict[str, Any]], max_chars: int) -> str: @@ -263,20 +272,23 @@ def run_longmemeval(args: argparse.Namespace) -> Dict[str, Any]: sessions = entry.get("haystack_sessions") or [] for sess_id, sess_date, sess_turns in zip(session_ids, session_dates, sessions): payload = format_session_memory(str(sess_id), str(sess_date), sess_turns or []) - memory.add( - messages=payload, - user_id=args.user_id, - metadata={ - "session_id": str(sess_id), - "session_date": str(sess_date), - "question_id": question_id, - }, - categories=["longmemeval", 
"session"], - infer=False, - ) + try: + memory.add( + messages=payload, + user_id=args.user_id, + metadata={ + "session_id": str(sess_id), + "session_date": str(sess_date), + "question_id": question_id, + }, + categories=["longmemeval", "session"], + infer=False, + ) + except Exception as e: + logger.warning("Skipping session %s for question %s: %s", sess_id, question_id, e) query = str(entry.get("question", "")).strip() - search_payload = memory.search_with_context( + search_payload = memory.search( query=query, user_id=args.user_id, limit=args.top_k, @@ -313,6 +325,7 @@ def run_longmemeval(args: argparse.Namespace) -> Dict[str, Any]: include_debug_fields=args.include_debug_fields, ) out_f.write(json.dumps(output_row, ensure_ascii=False) + "\n") + out_f.flush() if retrieval_f is not None: retrieval_row = { @@ -322,10 +335,11 @@ def run_longmemeval(args: argparse.Namespace) -> Dict[str, Any]: "metrics": metrics, } retrieval_f.write(json.dumps(retrieval_row, ensure_ascii=False) + "\n") + retrieval_f.flush() processed += 1 if args.print_every > 0 and processed % args.print_every == 0: - print(f"[LongMemEval] processed={processed} question_id={question_id}") + print(f"[LongMemEval] processed={processed} question_id={question_id}", flush=True) finally: if retrieval_f is not None: retrieval_f.close() diff --git a/engram/configs/base.py b/engram/configs/base.py index 26ff260..ed9d023 100644 --- a/engram/configs/base.py +++ b/engram/configs/base.py @@ -33,9 +33,9 @@ class LLMConfig(BaseModel): provider: str = Field(default="nvidia") config: Dict[str, Any] = Field( default_factory=lambda: { - "model": "meta/llama-3.1-8b-instruct", + "model": "qwen/qwen3.5-397b-a17b", "temperature": 0.2, - "max_tokens": 1024, + "max_tokens": 4096, } ) @@ -333,8 +333,15 @@ class SkillConfig(BaseModel): enable_mining: bool = True min_trajectory_steps: int = 3 mutation_rate: float = 0.05 - - @field_validator("min_confidence_for_auto_apply", "mutation_rate") + # Structural intelligence + 
enable_structural: bool = True + use_llm_decomposition: bool = True + structural_similarity_threshold: float = 0.4 + auto_decompose_on_mine: bool = True + auto_decompose_on_import: bool = True + + @field_validator("min_confidence_for_auto_apply", "mutation_rate", + "structural_similarity_threshold") @classmethod def _clamp_unit_float(cls, v: float) -> float: return min(1.0, max(0.0, float(v))) @@ -366,6 +373,20 @@ def _valid_priority(cls, v: str) -> str: return v +class EnrichmentConfig(BaseModel): + """Configuration for unified enrichment (single LLM call for echo+category+entities+profiles).""" + enable_unified: bool = False # Off by default for backward compat + fallback_to_individual: bool = True # On parse failure, fall back to individual calls + include_entities: bool = True # Include entity extraction in unified call + include_profiles: bool = True # Include profile extraction in unified call + max_batch_size: int = 10 # Max memories per unified batch call + + @field_validator("max_batch_size") + @classmethod + def _clamp_batch_size(cls, v: int) -> int: + return min(50, max(1, int(v))) + + class BatchConfig(BaseModel): """Configuration for batch memory operations.""" enable_batch: bool = False # off by default @@ -452,6 +473,7 @@ class MemoryConfig(BaseModel): distillation: DistillationConfig = Field(default_factory=DistillationConfig) parallel: ParallelConfig = Field(default_factory=ParallelConfig) batch: BatchConfig = Field(default_factory=BatchConfig) + enrichment: EnrichmentConfig = Field(default_factory=EnrichmentConfig) skill: SkillConfig = Field(default_factory=SkillConfig) task: TaskConfig = Field(default_factory=TaskConfig) metamemory: MetamemoryInlineConfig = Field(default_factory=MetamemoryInlineConfig) diff --git a/engram/configs/presets.py b/engram/configs/presets.py index 588dc6d..dabe2e5 100644 --- a/engram/configs/presets.py +++ b/engram/configs/presets.py @@ -135,12 +135,12 @@ def smart_config(): def full_config(): """Everything: scenes, 
profiles, graph, tasks. Needs API key or Ollama.""" from engram.configs.base import ( + EnrichmentConfig, SceneConfig, ProfileConfig, + SkillConfig, ) - from engram.configs.base import SkillConfig - config = smart_config() config.scene = SceneConfig(enable_scenes=True) config.profile = ProfileConfig(enable_profiles=True) @@ -148,4 +148,5 @@ def full_config(): config.category.enable_categories = True config.graph.enable_graph = True config.skill = SkillConfig(enable_skills=True, enable_mining=True) + config.enrichment = EnrichmentConfig(enable_unified=True) return config diff --git a/engram/core/category.py b/engram/core/category.py index 59788d7..aa5d5dd 100644 --- a/engram/core/category.py +++ b/engram/core/category.py @@ -20,6 +20,7 @@ import json import logging +import re import uuid from dataclasses import dataclass, field from datetime import datetime, timezone @@ -442,6 +443,10 @@ def _llm_detect_category( try: response = self.llm.generate(prompt) + # Strip <think>...</think> blocks (Qwen 3.x thinking models) + response = re.sub(r"<think>[\s\S]*?</think>", "", response, flags=re.IGNORECASE).strip() + response = re.sub(r"<think>[\s\S]*$", "", response, flags=re.IGNORECASE).strip() + # Parse JSON response — use raw_decode to ignore trailing LLM text json_start = response.find("{") if json_start >= 0: diff --git a/engram/core/echo.py b/engram/core/echo.py index 961914e..ad65206 100644 --- a/engram/core/echo.py +++ b/engram/core/echo.py @@ -412,22 +412,48 @@ def _extract_json_blob(self, response: str) -> str: text = (response or "").strip() if not text: return text - fence_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", text, re.IGNORECASE) + # Strip <think>...</think> blocks (Qwen 3.x thinking models) + text = re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip() + text = re.sub(r"<think>[\s\S]*$", "", text, flags=re.IGNORECASE).strip() + if not text: + return text + # Handle any code fence type: ```json, ```python, ```, etc.
+ fence_match = re.search(r"```\w*\s*([\s\S]*?)\s*```", text, re.IGNORECASE) if fence_match: - return fence_match.group(1).strip() - # Use raw_decode to extract the first complete JSON object, - # ignoring any trailing LLM commentary. - start = text.find("{") - if start != -1: + inner = fence_match.group(1).strip() + # If the code block contains Python code rather than JSON, extract JSON from it + json_in_code = re.search(r"(\{[\s\S]*\})", inner) + if json_in_code: + inner = json_in_code.group(1) + return inner + # Try to find JSON objects and pick the one that looks like an echo output + decoder = json.JSONDecoder() + candidates = [] + idx = 0 + while idx < len(text): + # Look for start of object or array + obj_start = text.find("{", idx) + arr_start = text.find("[", idx) + start = min(s for s in (obj_start, arr_start) if s != -1) if any(s != -1 for s in (obj_start, arr_start)) else -1 + if start == -1: + break try: - obj, end = json.JSONDecoder().raw_decode(text, start) - return json.dumps(obj) + obj, end = decoder.raw_decode(text, start) + candidates.append(obj) + idx = end except json.JSONDecodeError: - pass - # Fallback: bracket matching - end = text.rfind("}") - if end > start: - return text[start:end + 1].strip() + idx = start + 1 + # Prefer the candidate that has echo-like keys + echo_keys = {"paraphrases", "keywords", "importance"} + for candidate in candidates: + if isinstance(candidate, dict) and echo_keys & set(candidate.keys()): + return json.dumps(candidate) + if isinstance(candidate, list) and candidate and isinstance(candidate[0], dict): + if echo_keys & set(candidate[0].keys()): + return json.dumps(candidate[0]) + # Fall back to first candidate + if candidates: + return json.dumps(candidates[0]) return text def _repair_json(self, text: str) -> str: @@ -435,6 +461,13 @@ def _repair_json(self, text: str) -> str: return text # Remove trailing commas before } or ] repaired = re.sub(r",(\s*[}\]])", r"\1", text) + # Fix template-literal values: "field": 
["str"] or "field": "str" placeholders + repaired = re.sub(r':\s*\["str"\]', ': []', repaired) + repaired = re.sub(r':\s*"str"', ': ""', repaired) + # Fix schema descriptions leaking into values: "0.0-1.0" → 0.5 + repaired = re.sub(r'"0\.0-1\.0"', '0.5', repaired) + # Remove // style comments (not valid JSON) + repaired = re.sub(r'\s*//[^\n]*', '', repaired) return repaired def _load_json_dict(self, text: str) -> Optional[Dict[str, Any]]: @@ -450,7 +483,12 @@ def _load_json_dict(self, text: str) -> Optional[Dict[str, Any]]: return None else: return None - return data if isinstance(data, dict) else None + if isinstance(data, dict): + return data + # Handle list responses: LLM sometimes returns [{ echo }, ...] instead of { echo } + if isinstance(data, list) and data and isinstance(data[0], dict): + return data[0] + return None def _normalize_echo_dict(self, data: Dict[str, Any]) -> Dict[str, Any]: normalized = dict(data) @@ -458,6 +496,21 @@ def _normalize_echo_dict(self, data: Dict[str, Any]) -> Dict[str, Any]: normalized["paraphrases"] = normalized.pop("paraphrase") if "questions" not in normalized and "question_form" in normalized: normalized["questions"] = normalized.get("question_form") + # Handle LLM returning session metadata as top-level keys (not echo data) + echo_keys = {"paraphrases", "keywords", "importance"} + if not (echo_keys & set(normalized.keys())): + # This dict doesn't look like echo output — check for nested echo data + for key, val in normalized.items(): + if isinstance(val, dict) and echo_keys & set(val.keys()): + return self._normalize_echo_dict(val) + # Try "results" key for batch-style responses + if "results" in normalized and isinstance(normalized["results"], list): + if normalized["results"] and isinstance(normalized["results"][0], dict): + return self._normalize_echo_dict(normalized["results"][0]) + # Ensure required fields have defaults if still missing + normalized.setdefault("paraphrases", []) + normalized.setdefault("keywords", []) + 
normalized.setdefault("importance", 0.5) return normalized def process_batch( diff --git a/engram/core/enrichment.py b/engram/core/enrichment.py new file mode 100644 index 0000000..885dd19 --- /dev/null +++ b/engram/core/enrichment.py @@ -0,0 +1,627 @@ +"""Unified Enrichment — single LLM call for echo + category + entities + profiles. + +Replaces 4 separate LLM calls per memory with one combined call. +Backward compatible: individual processors stay unchanged; unified is an +alternative path that falls back to individual calls on parse failure. +""" + +from __future__ import annotations + +import json +import logging +import re +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from engram.core.category import CategoryMatch +from engram.core.echo import EchoDepth, EchoProcessor, EchoResult +from engram.core.graph import Entity, EntityType +from engram.core.profile import ProfileUpdate +from engram.utils.prompts import ( + UNIFIED_ENRICHMENT_BATCH_PROMPT, + UNIFIED_ENRICHMENT_PROMPT, +) + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Pydantic output models — what the LLM returns +# --------------------------------------------------------------------------- + + +class UnifiedEchoOutput(BaseModel): + model_config = ConfigDict(extra="ignore") + + paraphrases: List[str] = [] + keywords: List[str] = [] + implications: List[str] = [] + questions: List[str] = [] + question_form: Optional[str] = None + category: Optional[str] = None # fact|preference|goal|relationship|event + importance: float = 0.5 + + @field_validator("paraphrases", "keywords", "implications", "questions", mode="before") + @classmethod + def _coerce_list(cls, value): + if value is None: + return [] + if isinstance(value, list): + return value + if isinstance(value, str): + return [value] + return [str(value)] + + 
@field_validator("importance", mode="before") + @classmethod + def _coerce_importance(cls, value): + if isinstance(value, str): + try: + return float(value) + except ValueError: + return 0.5 + if value is None: + return 0.5 + return value + + @field_validator("question_form", mode="before") + @classmethod + def _clean_question_form(cls, value): + if value is None: + return None + if isinstance(value, list): + return value[0] if value else None + value = str(value).strip() + return value or None + + @field_validator("category", mode="before") + @classmethod + def _clean_category(cls, value): + if value is None: + return None + value = str(value).strip() + return value or None + + +class UnifiedCategoryOutput(BaseModel): + model_config = ConfigDict(extra="ignore") + + action: str = "use_existing" # use_existing|create_child|create_new + category_id: Optional[str] = None + new_category: Optional[Dict[str, Any]] = None # {name, description, keywords, parent_id} + confidence: float = 0.5 + + @field_validator("confidence", mode="before") + @classmethod + def _coerce_confidence(cls, value): + if isinstance(value, str): + try: + return float(value) + except ValueError: + return 0.5 + if value is None: + return 0.5 + return value + + +class UnifiedEntityOutput(BaseModel): + model_config = ConfigDict(extra="ignore") + + name: str + type: str = "unknown" # person|organization|technology|concept|location|project|tool|preference + + +class UnifiedProfileOutput(BaseModel): + model_config = ConfigDict(extra="ignore") + + name: str + type: str = "contact" # self|contact|entity + facts: List[str] = [] + preferences: List[str] = [] + + @field_validator("facts", "preferences", mode="before") + @classmethod + def _coerce_list(cls, value): + if value is None: + return [] + if isinstance(value, list): + return value + if isinstance(value, str): + return [value] + return [str(value)] + + +class UnifiedEnrichmentOutput(BaseModel): + """Full parsed output from a single unified LLM call.""" 
+ model_config = ConfigDict(extra="ignore") + + echo: UnifiedEchoOutput = Field(default_factory=UnifiedEchoOutput) + category: UnifiedCategoryOutput = Field(default_factory=UnifiedCategoryOutput) + entities: List[UnifiedEntityOutput] = [] + profiles: List[UnifiedProfileOutput] = [] + + +# --------------------------------------------------------------------------- +# Bridge dataclass — unified output → existing processor types +# --------------------------------------------------------------------------- + + +@dataclass +class EnrichmentResult: + """Bridges unified output to existing processor types.""" + echo_result: Optional[EchoResult] = None + category_match: Optional[CategoryMatch] = None + entities: List[Entity] = field(default_factory=list) + profile_updates: List[ProfileUpdate] = field(default_factory=list) + raw_response: str = "" + + +# --------------------------------------------------------------------------- +# Processor +# --------------------------------------------------------------------------- + +# Depth instructions for prompt construction +_DEPTH_INSTRUCTIONS = { + EchoDepth.SHALLOW: "keywords only (skip paraphrases, implications, questions)", + EchoDepth.MEDIUM: "paraphrases, keywords, category. 
Skip: implications, questions.", + EchoDepth.DEEP: "ALL fields: paraphrases, keywords, implications, questions, question_form, category.", +} + + +class UnifiedEnrichmentProcessor: + """Single LLM call for echo + category + entity + profile extraction.""" + + def __init__( + self, + llm, + echo_processor: Optional[EchoProcessor] = None, + category_processor=None, + knowledge_graph=None, + profile_processor=None, + ): + self.llm = llm + self.echo_processor = echo_processor + self.category_processor = category_processor + self.knowledge_graph = knowledge_graph + self.profile_processor = profile_processor + + # ------------------------------------------------------------------ + # Single-memory enrichment + # ------------------------------------------------------------------ + + def enrich( + self, + content: str, + depth: EchoDepth = EchoDepth.MEDIUM, + existing_categories: Optional[str] = None, + include_entities: bool = True, + include_profiles: bool = True, + ) -> EnrichmentResult: + """Single LLM call for one memory. Falls back to individual on failure.""" + prompt = self._build_prompt( + content, depth, existing_categories, + include_entities=include_entities, + include_profiles=include_profiles, + ) + try: + response = self.llm.generate(prompt) + return self._parse_response(response, content, depth) + except Exception as e: + logger.warning("Unified enrichment failed, falling back to individual: %s", e) + return self._fallback_individual(content, depth, existing_categories) + + # ------------------------------------------------------------------ + # Batch enrichment + # ------------------------------------------------------------------ + + def enrich_batch( + self, + contents: List[str], + depth: EchoDepth = EchoDepth.MEDIUM, + existing_categories: Optional[str] = None, + include_entities: bool = True, + include_profiles: bool = True, + ) -> List[EnrichmentResult]: + """Single LLM call for N memories. 
Falls back per-item on failure.""" + if not contents: + return [] + if len(contents) == 1: + return [self.enrich( + contents[0], depth, existing_categories, + include_entities=include_entities, + include_profiles=include_profiles, + )] + + prompt = self._build_batch_prompt( + contents, depth, existing_categories, + include_entities=include_entities, + include_profiles=include_profiles, + ) + try: + response = self.llm.generate(prompt) + return self._parse_batch_response(response, contents, depth) + except Exception as e: + logger.warning("Unified batch enrichment failed, falling back to sequential: %s", e) + return [ + self.enrich(c, depth, existing_categories, + include_entities=include_entities, + include_profiles=include_profiles) + for c in contents + ] + + # ------------------------------------------------------------------ + # Prompt builders + # ------------------------------------------------------------------ + + def _build_prompt( + self, + content: str, + depth: EchoDepth, + existing_categories: Optional[str] = None, + include_entities: bool = True, + include_profiles: bool = True, + ) -> str: + cats = existing_categories or self._format_existing_categories() + depth_instructions = _DEPTH_INSTRUCTIONS.get(depth, _DEPTH_INSTRUCTIONS[EchoDepth.MEDIUM]) + return UNIFIED_ENRICHMENT_PROMPT.format( + content=content[:2000], + depth=depth.value, + depth_instructions=depth_instructions, + existing_categories=cats, + include_entities="yes" if include_entities else "no", + include_profiles="yes" if include_profiles else "no", + ) + + def _build_batch_prompt( + self, + contents: List[str], + depth: EchoDepth, + existing_categories: Optional[str] = None, + include_entities: bool = True, + include_profiles: bool = True, + ) -> str: + cats = existing_categories or self._format_existing_categories() + depth_instructions = _DEPTH_INSTRUCTIONS.get(depth, _DEPTH_INSTRUCTIONS[EchoDepth.MEDIUM]) + memories_block = "\n".join( + f"[{i}] {c[:500]}" for i, c in 
enumerate(contents) + ) + return UNIFIED_ENRICHMENT_BATCH_PROMPT.format( + memories_block=memories_block, + count=len(contents), + depth=depth.value, + depth_instructions=depth_instructions, + existing_categories=cats, + include_entities="yes" if include_entities else "no", + include_profiles="yes" if include_profiles else "no", + ) + + def _format_existing_categories(self) -> str: + if not self.category_processor: + return "(none)" + cats = self.category_processor.get_all_categories() + if not cats: + return "(none)" + return "\n".join( + f"- {c['id']}: {c['name']} — {c.get('description', '')}" + for c in cats[:30] + ) + + # ------------------------------------------------------------------ + # Response parsing + # ------------------------------------------------------------------ + + def _parse_response( + self, + response: str, + content: str, + depth: EchoDepth, + ) -> EnrichmentResult: + """Parse LLM response into EnrichmentResult.""" + json_str = _extract_json_blob(response) + data = _robust_json_load(json_str) + + try: + unified = UnifiedEnrichmentOutput.model_validate(data) + except Exception: + # Try normalizing keys + normalized = _normalize_unified_dict(data) + unified = UnifiedEnrichmentOutput.model_validate(normalized) + + return EnrichmentResult( + echo_result=self._to_echo_result(unified.echo, content, depth), + category_match=self._to_category_match(unified.category), + entities=self._to_entities(unified.entities), + profile_updates=self._to_profile_updates(unified.profiles), + raw_response=response, + ) + + def _parse_batch_response( + self, + response: str, + contents: List[str], + depth: EchoDepth, + ) -> List[EnrichmentResult]: + """Parse batch LLM response. 
Falls back per-item on partial failure.""" + json_str = _extract_json_blob(response) + data = _robust_json_load(json_str) + + results_list = data.get("results", []) + if not isinstance(results_list, list): + raise ValueError("Batch response 'results' is not a list") + + # Index by position + parsed_by_index: Dict[int, Dict[str, Any]] = {} + for item in results_list: + idx = item.get("index", -1) + if 0 <= idx < len(contents): + parsed_by_index[idx] = item + + results: List[EnrichmentResult] = [] + for i, content in enumerate(contents): + if i in parsed_by_index: + try: + unified = UnifiedEnrichmentOutput.model_validate(parsed_by_index[i]) + results.append(EnrichmentResult( + echo_result=self._to_echo_result(unified.echo, content, depth), + category_match=self._to_category_match(unified.category), + entities=self._to_entities(unified.entities), + profile_updates=self._to_profile_updates(unified.profiles), + )) + continue + except Exception: + pass + # Fallback: enrich individually + results.append(self.enrich(content, depth)) + return results + + # ------------------------------------------------------------------ + # Converters: unified output → existing types + # ------------------------------------------------------------------ + + def _to_echo_result( + self, + echo_out: UnifiedEchoOutput, + content: str, + depth: EchoDepth, + ) -> EchoResult: + multiplier = EchoProcessor.STRENGTH_MULTIPLIERS.get(depth, 1.0) + + question_form = echo_out.question_form + if not question_form and echo_out.questions: + question_form = echo_out.questions[0] + + return EchoResult( + raw=content, + paraphrases=echo_out.paraphrases, + keywords=echo_out.keywords, + implications=echo_out.implications if depth == EchoDepth.DEEP else [], + questions=echo_out.questions if depth == EchoDepth.DEEP else [], + question_form=question_form, + category=echo_out.category, + importance=echo_out.importance, + echo_depth=depth, + strength_multiplier=multiplier, + ) + + def _to_category_match(self, 
cat_out: UnifiedCategoryOutput) -> CategoryMatch: + action = cat_out.action + + if action == "use_existing" and cat_out.category_id: + # Verify it exists if we have a processor + if self.category_processor: + cat = self.category_processor.get_category(cat_out.category_id) + if cat: + return CategoryMatch( + category_id=cat.id, + category_name=cat.name, + confidence=cat_out.confidence, + ) + # Still return what the LLM said, even without verification + return CategoryMatch( + category_id=cat_out.category_id, + category_name=cat_out.category_id, + confidence=cat_out.confidence, + ) + + if action in ("create_child", "create_new") and cat_out.new_category: + new_cat = cat_out.new_category + if self.category_processor: + cat_id = self.category_processor._create_category( + name=new_cat.get("name", "Unnamed"), + description=new_cat.get("description", ""), + keywords=new_cat.get("keywords", []), + parent_id=new_cat.get("parent_id"), + ) + return CategoryMatch( + category_id=cat_id, + category_name=new_cat.get("name", "Unnamed"), + confidence=cat_out.confidence, + is_new=True, + suggested_parent_id=new_cat.get("parent_id"), + ) + + # Fallback + return CategoryMatch( + category_id="context", + category_name="Context & Situations", + confidence=0.3, + ) + + def _to_entities(self, entity_outs: List[UnifiedEntityOutput]) -> List[Entity]: + entities = [] + for eo in entity_outs: + name = eo.name.strip() + if not name: + continue + try: + entity_type = EntityType(eo.type) + except ValueError: + entity_type = EntityType.UNKNOWN + entities.append(Entity(name=name, entity_type=entity_type)) + return entities + + def _to_profile_updates(self, profile_outs: List[UnifiedProfileOutput]) -> List[ProfileUpdate]: + updates = [] + for po in profile_outs: + name = po.name.strip() + if not name: + continue + updates.append(ProfileUpdate( + profile_name=name, + profile_type=po.type, + new_facts=po.facts, + new_preferences=po.preferences, + )) + return updates + + # 
------------------------------------------------------------------ + # Fallback: individual processor calls + # ------------------------------------------------------------------ + + def _fallback_individual( + self, + content: str, + depth: EchoDepth, + existing_categories: Optional[str], + ) -> EnrichmentResult: + """If unified parsing fails, call each processor separately.""" + echo_result = None + category_match = None + entities: List[Entity] = [] + profile_updates: List[ProfileUpdate] = [] + + # Echo + if self.echo_processor: + try: + echo_result = self.echo_processor.process(content, depth=depth) + except Exception as e: + logger.warning("Fallback echo failed: %s", e) + + # Category + if self.category_processor: + try: + category_match = self.category_processor.detect_category(content) + except Exception as e: + logger.warning("Fallback category failed: %s", e) + + # Entities (regex only in fallback to avoid extra LLM call) + if self.knowledge_graph: + try: + entities = self.knowledge_graph._extract_entities_regex(content, "") + except Exception as e: + logger.warning("Fallback entity extraction failed: %s", e) + + # Profiles (regex only in fallback) + if self.profile_processor: + try: + profile_updates = self.profile_processor.extract_profile_mentions(content) + except Exception as e: + logger.warning("Fallback profile extraction failed: %s", e) + + return EnrichmentResult( + echo_result=echo_result, + category_match=category_match, + entities=entities, + profile_updates=profile_updates, + ) + + +# --------------------------------------------------------------------------- +# JSON parsing helpers +# --------------------------------------------------------------------------- + +def _extract_json_blob(response: str) -> str: + """Extract JSON object from LLM response, handling code fences, thinking tags, and noise.""" + text = (response or "").strip() + if not text: + return "{}" + + # Strip ... 
blocks (Qwen 3.x thinking models) + text = re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip() + + # Also strip an unclosed <think> block (model hit max_tokens mid-thought) + text = re.sub(r"<think>[\s\S]*$", "", text, flags=re.IGNORECASE).strip() + + if not text: + return "{}" + + # Strip code fences + fence_match = re.search(r"```\w*\s*([\s\S]*?)\s*```", text, re.IGNORECASE) + if fence_match: + text = fence_match.group(1).strip() + + # Find the first JSON object + start = text.find("{") + if start >= 0: + try: + obj, _ = json.JSONDecoder().raw_decode(text, start) + return json.dumps(obj) + except json.JSONDecodeError: + pass + + return text + + +def _robust_json_load(text: str) -> Dict[str, Any]: + """Load JSON with repair for common LLM output issues.""" + try: + data = json.loads(text) + if isinstance(data, dict): + return data + except json.JSONDecodeError: + pass + + # Repair: remove trailing commas, comments, template placeholders + repaired = re.sub(r",(\s*[}\]])", r"\1", text) + repaired = re.sub(r'\s*//[^\n]*', '', repaired) + repaired = re.sub(r'"0\.0-1\.0"', '0.5', repaired) + repaired = re.sub(r':\s*\["str"\]', ': []', repaired) + repaired = re.sub(r':\s*"str"', ': ""', repaired) + + try: + data = json.loads(repaired) + if isinstance(data, dict): + return data + except json.JSONDecodeError: + pass + + # Last resort: raw_decode + start = repaired.find("{") + if start >= 0: + try: + data, _ = json.JSONDecoder().raw_decode(repaired, start) + if isinstance(data, dict): + return data + except json.JSONDecodeError: + pass + + raise json.JSONDecodeError("Could not parse unified enrichment response", text, 0) + + +def _normalize_unified_dict(data: Dict[str, Any]) -> Dict[str, Any]: + """Normalize common LLM key variations.""" + normalized = dict(data) + + # Ensure top-level keys exist + if "echo" not in normalized: + # Check if echo keys are at top level + echo_keys = {"paraphrases", "keywords", "importance"} + if echo_keys & set(normalized.keys()): + 
normalized["echo"] = { + k: normalized.pop(k) + for k in list(normalized.keys()) + if k in {"paraphrases", "keywords", "implications", "questions", + "question_form", "category", "importance"} + } + else: + normalized["echo"] = {} + + if "category" not in normalized: + normalized["category"] = {} + if "entities" not in normalized: + normalized["entities"] = [] + if "profiles" not in normalized: + normalized["profiles"] = [] + + return normalized diff --git a/engram/embeddings/nvidia.py b/engram/embeddings/nvidia.py index a8a210d..43b6914 100644 --- a/engram/embeddings/nvidia.py +++ b/engram/embeddings/nvidia.py @@ -40,21 +40,44 @@ def _extra_body(self, memory_action: Optional[str] = None) -> dict: return {"input_type": input_type, "truncate": "END"} return {} + def _truncate_if_needed(self, text: str) -> str: + """Truncate text to stay within model token limits. + + nv-embed-v1 has a 4096 token limit. Using ~3.5 chars/token as + a conservative estimate, cap at 14000 characters. + """ + max_chars = int(self.config.get("max_input_chars", 14000)) + if len(text) > max_chars: + logger.debug("Truncating input from %d to %d chars for embedding", len(text), max_chars) + return text[:max_chars] + return text + def embed(self, text: str, memory_action: Optional[str] = None) -> List[float]: - try: - extra_body = self._extra_body(memory_action) - response = self.client.embeddings.create( - input=[text], - model=self.model, - encoding_format="float", - **({"extra_body": extra_body} if extra_body else {}), - ) - return response.data[0].embedding - except Exception as exc: - logger.error("NVIDIA embedding failed (model=%s): %s", self.model, exc) - raise RuntimeError( - f"NVIDIA embedding failed (model={self.model}): {exc}" - ) from exc + import time as _time + text = self._truncate_if_needed(text) + max_retries = int(self.config.get("max_retries", 3)) + last_exc = None + for attempt in range(max_retries + 1): + try: + extra_body = self._extra_body(memory_action) + response = 
self.client.embeddings.create( + input=[text], + model=self.model, + encoding_format="float", + **({"extra_body": extra_body} if extra_body else {}), + ) + return response.data[0].embedding + except Exception as exc: + last_exc = exc + if attempt < max_retries: + delay = min(2 ** attempt, 8) + logger.warning("NVIDIA embed retry %d/%d after %ss: %s", attempt + 1, max_retries, delay, exc) + _time.sleep(delay) + else: + logger.error("NVIDIA embedding failed (model=%s): %s", self.model, exc) + raise RuntimeError( + f"NVIDIA embedding failed (model={self.model}): {last_exc}" + ) from last_exc def embed_batch( self, texts: List[str], memory_action: Optional[str] = None @@ -62,6 +85,7 @@ def embed_batch( """Native batch embedding — single API call for N texts.""" if not texts: return [] + texts = [self._truncate_if_needed(t) for t in texts] if len(texts) == 1: return [self.embed(texts[0], memory_action=memory_action)] try: diff --git a/engram/llms/nvidia.py b/engram/llms/nvidia.py index 0267799..8672493 100644 --- a/engram/llms/nvidia.py +++ b/engram/llms/nvidia.py @@ -8,7 +8,7 @@ class NvidiaLLM(BaseLLM): - """LLM provider for NVIDIA API (OpenAI-compatible). Default model: Llama 3.1 8B Instruct.""" + """LLM provider for NVIDIA API (OpenAI-compatible). Default model: Qwen 3.5 397B.""" def __init__(self, config: Optional[dict] = None): super().__init__(config) @@ -19,21 +19,22 @@ def __init__(self, config: Optional[dict] = None): api_key = ( self.config.get("api_key") + or os.getenv("NVIDIA_QWEN_API_KEY") or os.getenv("LLAMA_API_KEY") or os.getenv("NVIDIA_API_KEY") ) if not api_key: raise ValueError( "NVIDIA API key required. Set config['api_key'], " - "LLAMA_API_KEY, or NVIDIA_API_KEY env var." + "NVIDIA_QWEN_API_KEY, LLAMA_API_KEY, or NVIDIA_API_KEY env var." 
) base_url = self.config.get("base_url", "https://integrate.api.nvidia.com/v1") - timeout = self.config.get("timeout", 60) + timeout = self.config.get("timeout", 120) self.client = OpenAI(base_url=base_url, api_key=api_key, timeout=timeout) - self.model = self.config.get("model", "meta/llama-3.1-8b-instruct") + self.model = self.config.get("model", "qwen/qwen3.5-397b-a17b") self.temperature = self.config.get("temperature", 0.2) - self.max_tokens = self.config.get("max_tokens", 1024) + self.max_tokens = self.config.get("max_tokens", 4096) self.top_p = self.config.get("top_p", 0.7) self.enable_thinking = self.config.get("enable_thinking", False) diff --git a/engram/mcp_server.py b/engram/mcp_server.py index 6acd8d2..d109b36 100644 --- a/engram/mcp_server.py +++ b/engram/mcp_server.py @@ -1,4 +1,4 @@ -"""Engram MCP Server — 14 tools, minimal boilerplate. +"""Engram MCP Server — 18 tools, minimal boilerplate. Tools: 1. remember — Quick-save (content → memory, infer=False) @@ -15,6 +15,10 @@ 12. record_trajectory_step — Record a step in active trajectory 13. mine_skills — Run skill mining cycle 14. get_skill_stats — Statistics about skills and trajectories +15. search_skills_structural — Find skills by structural similarity +16. analyze_skill_gaps — Show what transfers vs what needs experimentation +17. decompose_skill — Trigger structural decomposition of a flat skill +18. apply_skill_with_bindings — Apply skill with slot values, includes gap analysis """ import json @@ -299,13 +303,28 @@ def get_memory() -> Memory: ), Tool( name="log_skill_outcome", - description="Report success or failure for a skill. Updates the skill's confidence score based on outcome.", + description="Report success or failure for a skill. Updates the skill's confidence score based on outcome. 
Optionally accepts per-step outcomes for granular feedback.", inputSchema={ "type": "object", "properties": { "skill_id": {"type": "string", "description": "The ID of the skill to log outcome for"}, "success": {"type": "boolean", "description": "Whether the skill application was successful"}, "notes": {"type": "string", "description": "Optional notes about the outcome"}, + "step_outcomes": { + "type": "array", + "description": "Optional per-step outcomes for granular feedback", + "items": { + "type": "object", + "properties": { + "step_index": {"type": "integer", "description": "Index of the step (0-based)"}, + "success": {"type": "boolean", "description": "Whether this step succeeded"}, + "failure_type": {"type": "string", "enum": ["structural", "slot"], "description": "Type of failure"}, + "failed_slot": {"type": "string", "description": "Which slot caused the failure"}, + "notes": {"type": "string", "description": "Notes about this step's outcome"}, + }, + "required": ["step_index", "success"], + }, + }, }, "required": ["skill_id", "success"], }, @@ -353,6 +372,66 @@ def get_memory() -> Memory: "properties": {}, }, ), + Tool( + name="search_skills_structural", + description="Find skills by structural similarity to given steps. Decomposes the query steps into a recipe template and matches against skills with structural decomposition.", + inputSchema={ + "type": "object", + "properties": { + "query_steps": { + "type": "array", + "items": {"type": "string"}, + "description": "Steps to find structurally similar skills for (e.g., ['Build Go app', 'Run go test', 'Deploy to GCP'])", + }, + "limit": {"type": "integer", "description": "Maximum number of results (default: 5)"}, + "min_similarity": {"type": "number", "description": "Minimum structural similarity threshold 0.0-1.0 (default: 0.3)"}, + }, + "required": ["query_steps"], + }, + ), + Tool( + name="analyze_skill_gaps", + description="Analyze what transfers from a skill to a new target context. 
Shows proven bindings, untested bindings, and missing slots with recommendations.", + inputSchema={ + "type": "object", + "properties": { + "skill_id": {"type": "string", "description": "The ID of the skill to analyze"}, + "target_context": { + "type": "object", + "description": "Target context with slot values (e.g., {'language': 'go', 'deploy_target': 'gcp'})", + "additionalProperties": {"type": "string"}, + }, + }, + "required": ["skill_id", "target_context"], + }, + ), + Tool( + name="decompose_skill", + description="Trigger structural decomposition of a flat skill into recipe + ingredients. Extracts slots and creates structured step templates.", + inputSchema={ + "type": "object", + "properties": { + "skill_id": {"type": "string", "description": "The ID of the skill to decompose"}, + }, + "required": ["skill_id"], + }, + ), + Tool( + name="apply_skill_with_bindings", + description="Apply a skill with specific slot values. Renders steps with bindings and includes gap analysis showing proven vs untested components.", + inputSchema={ + "type": "object", + "properties": { + "skill_id": {"type": "string", "description": "The ID of the skill to apply"}, + "bindings": { + "type": "object", + "description": "Slot bindings (e.g., {'language': 'go', 'test_framework': 'go test', 'deploy_target': 'gcp'})", + "additionalProperties": {"type": "string"}, + }, + }, + "required": ["skill_id", "bindings"], + }, + ), ] @@ -525,6 +604,7 @@ def _handle_log_skill_outcome(memory, args): skill_id=args.get("skill_id", ""), success=args.get("success", False), notes=args.get("notes"), + step_outcomes=args.get("step_outcomes"), ) @@ -562,6 +642,40 @@ def _handle_get_skill_stats(memory, args): return memory.get_skill_stats() +def _handle_search_skills_structural(memory, args): + query_steps = args.get("query_steps", []) + try: + limit = max(1, min(50, int(args.get("limit", 5)))) + except (ValueError, TypeError): + limit = 5 + min_sim = float(args.get("min_similarity", 0.3)) + return 
memory.search_skills_structural( + query_steps=query_steps, + limit=limit, + min_similarity=min_sim, + ) + + +def _handle_analyze_skill_gaps(memory, args): + return memory.analyze_skill_gaps( + skill_id=args.get("skill_id", ""), + target_context=args.get("target_context", {}), + ) + + +def _handle_decompose_skill(memory, args): + return memory.decompose_skill( + skill_id=args.get("skill_id", ""), + ) + + +def _handle_apply_skill_with_bindings(memory, args): + return memory.apply_skill( + skill_id=args.get("skill_id", ""), + bindings=args.get("bindings", {}), + ) + + HANDLERS = { "remember": _handle_remember, "search_memory": _handle_search_memory, @@ -577,6 +691,10 @@ def _handle_get_skill_stats(memory, args): "record_trajectory_step": _handle_record_trajectory_step, "mine_skills": _handle_mine_skills, "get_skill_stats": _handle_get_skill_stats, + "search_skills_structural": _handle_search_skills_structural, + "analyze_skill_gaps": _handle_analyze_skill_gaps, + "decompose_skill": _handle_decompose_skill, + "apply_skill_with_bindings": _handle_apply_skill_with_bindings, } _MEMORY_FREE_TOOLS = {"get_last_session", "save_session_digest"} diff --git a/engram/memory/main.py b/engram/memory/main.py index 5101122..d0ee8d5 100644 --- a/engram/memory/main.py +++ b/engram/memory/main.py @@ -355,18 +355,28 @@ def record_trajectory_step( args: Optional[Dict[str, Any]] = None, result_summary: str = "", error: Optional[str] = None, + slot_values: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: - """Record a step in an active trajectory.""" + """Record a step in an active trajectory. + + If slot_values are provided, they are stored in state_snapshot + for later structural mining. 
+ """ recorder = self._active_recorders.get(recorder_id) if recorder is None: return {"error": f"No active recorder: {recorder_id}"} + state_snapshot = None + if slot_values: + state_snapshot = {"slot_values": slot_values} + step = recorder.record_step( action=action, tool=tool, args=args, result_summary=result_summary, error=error, + state_snapshot=state_snapshot, ) return { "recorder_id": recorder_id, @@ -1051,65 +1061,128 @@ def _process_single_memory( and not mem_categories ) - # Site 1: Parallel echo encoding + category detection - _use_parallel = ( - self._executor is not None - and self.parallel_config - and self.parallel_config.parallel_add - and _should_categorize - and self.echo_processor + # Pre-extracted data from unified enrichment (used to skip redundant post-store calls) + _unified_entities = None # List[Entity] or None + _unified_profiles = None # List[ProfileUpdate] or None + + # Determine echo depth for unified path check + _depth_for_echo = EchoDepth(echo_depth) if echo_depth else None + if _depth_for_echo is None and self.echo_processor and hasattr(self.echo_processor, '_assess_depth'): + try: + _depth_for_echo = self.echo_processor._assess_depth(content) + except Exception: + _depth_for_echo = EchoDepth.MEDIUM + + # Site 0: Unified enrichment (single LLM call for echo+category+entities+profiles) + _use_unified = ( + self.unified_enrichment is not None and self.echo_config.enable_echo + and _depth_for_echo != EchoDepth.SHALLOW # shallow is LLM-free ) - if _use_parallel: - # Run echo and category detection in parallel (both only read content) - def _do_echo(): - depth_override = EchoDepth(echo_depth) if echo_depth else None - return self.echo_processor.process(content, depth=depth_override) - - def _do_category(): - return self.category_processor.detect_category( - content, - metadata=mem_metadata, - use_llm=self.category_config.use_llm_categorization, - ) + if _use_unified: + enrichment_config = getattr(self.config, "enrichment", None) + 
existing_cats = None + if self.category_processor: + cats = self.category_processor.get_all_categories() + if cats: + existing_cats = "\n".join( + f"- {c['id']}: {c['name']} — {c.get('description', '')}" + for c in cats[:30] + ) - echo_result_p, category_match = self._executor.run_parallel([ - (_do_echo, ()), - (_do_category, ()), - ]) + enrichment = self.unified_enrichment.enrich( + content=content, + depth=_depth_for_echo or EchoDepth.MEDIUM, + existing_categories=existing_cats, + include_entities=enrichment_config.include_entities if enrichment_config else True, + include_profiles=enrichment_config.include_profiles if enrichment_config else True, + ) # Apply echo result - effective_strength = initial_strength * echo_result_p.strength_multiplier - mem_metadata.update(echo_result_p.to_metadata()) - if not mem_categories and echo_result_p.category: - mem_categories = [echo_result_p.category] + echo_result = enrichment.echo_result + if echo_result: + effective_strength = initial_strength * echo_result.strength_multiplier + mem_metadata.update(echo_result.to_metadata()) + if not mem_categories and echo_result.category: + mem_categories = [echo_result.category] + else: + effective_strength = initial_strength # Apply category result - mem_categories = [category_match.category_id] - mem_metadata["category_confidence"] = category_match.confidence - mem_metadata["category_auto"] = True + if enrichment.category_match and not mem_categories: + mem_categories = [enrichment.category_match.category_id] + mem_metadata["category_confidence"] = enrichment.category_match.confidence + mem_metadata["category_auto"] = True - # Generate embedding (depends on echo result, must be serial) - primary_text = self._select_primary_text(content, echo_result_p) + # Stash entities + profiles for post-store hooks + _unified_entities = enrichment.entities + _unified_profiles = enrichment.profile_updates + + # Generate embedding + primary_text = self._select_primary_text(content, echo_result) 
embedding = self.embedder.embed(primary_text, memory_action="add") - echo_result = echo_result_p + else: - # Sequential path (original behavior) - if _should_categorize: - category_match = self.category_processor.detect_category( - content, - metadata=mem_metadata, - use_llm=self.category_config.use_llm_categorization, - ) + # Site 1: Parallel echo encoding + category detection + _use_parallel = ( + self._executor is not None + and self.parallel_config + and self.parallel_config.parallel_add + and _should_categorize + and self.echo_processor + and self.echo_config.enable_echo + ) + + if _use_parallel: + # Run echo and category detection in parallel (both only read content) + def _do_echo(): + depth_override = EchoDepth(echo_depth) if echo_depth else None + return self.echo_processor.process(content, depth=depth_override) + + def _do_category(): + return self.category_processor.detect_category( + content, + metadata=mem_metadata, + use_llm=self.category_config.use_llm_categorization, + ) + + echo_result_p, category_match = self._executor.run_parallel([ + (_do_echo, ()), + (_do_category, ()), + ]) + + # Apply echo result + effective_strength = initial_strength * echo_result_p.strength_multiplier + mem_metadata.update(echo_result_p.to_metadata()) + if not mem_categories and echo_result_p.category: + mem_categories = [echo_result_p.category] + + # Apply category result mem_categories = [category_match.category_id] mem_metadata["category_confidence"] = category_match.confidence mem_metadata["category_auto"] = True - # Encode memory (echo + embedding). 
- echo_result, effective_strength, mem_categories, embedding = self._encode_memory( - content, echo_depth, mem_categories, mem_metadata, initial_strength, - ) + # Generate embedding (depends on echo result, must be serial) + primary_text = self._select_primary_text(content, echo_result_p) + embedding = self.embedder.embed(primary_text, memory_action="add") + echo_result = echo_result_p + else: + # Sequential path (original behavior) + if _should_categorize: + category_match = self.category_processor.detect_category( + content, + metadata=mem_metadata, + use_llm=self.category_config.use_llm_categorization, + ) + mem_categories = [category_match.category_id] + mem_metadata["category_confidence"] = category_match.confidence + mem_metadata["category_auto"] = True + + # Encode memory (echo + embedding). + echo_result, effective_strength, mem_categories, embedding = self._encode_memory( + content, echo_depth, mem_categories, mem_metadata, initial_strength, + ) nearest, similarity = self._nearest_memory(embedding, store_filters) repeated_threshold = max(self.fadem_config.conflict_similarity_threshold - 0.05, 0.7) @@ -1277,11 +1350,22 @@ def _do_category(): ) if self.knowledge_graph: - self.knowledge_graph.extract_entities( - content=content, - memory_id=effective_memory_id, - use_llm=self.graph_config.use_llm_extraction, - ) + if _unified_entities is not None: + # Use pre-extracted entities from unified enrichment + for entity in _unified_entities: + existing = self.knowledge_graph._get_or_create_entity( + entity.name, entity.entity_type, + ) + existing.memory_ids.add(effective_memory_id) + self.knowledge_graph.memory_entities[effective_memory_id] = { + e.name for e in _unified_entities + } + else: + self.knowledge_graph.extract_entities( + content=content, + memory_id=effective_memory_id, + use_llm=self.graph_config.use_llm_extraction, + ) if self.graph_config.auto_link_entities: self.knowledge_graph.link_by_shared_entities(effective_memory_id) @@ -1293,7 +1377,16 @@ def 
_do_category(): if self.profile_processor: try: - self._update_profiles(effective_memory_id, content, mem_metadata, user_id) + if _unified_profiles is not None and _unified_profiles: + # Use pre-extracted profiles from unified enrichment + for profile_update in _unified_profiles: + self.profile_processor.apply_update( + profile_update=profile_update, + memory_id=effective_memory_id, + user_id=user_id or "default", + ) + else: + self._update_profiles(effective_memory_id, content, mem_metadata, user_id) except Exception as e: logger.warning("Profile update failed for %s: %s", effective_memory_id, e) diff --git a/engram/memory/smart.py b/engram/memory/smart.py index efafd33..aaa84df 100644 --- a/engram/memory/smart.py +++ b/engram/memory/smart.py @@ -51,6 +51,7 @@ def __init__( self._echo_processor = None self._category_processor = None self._knowledge_graph = None + self._unified_enrichment = None self._skill_store = None self._skill_executor = None @@ -95,6 +96,20 @@ def knowledge_graph(self): ) return self._knowledge_graph + @property + def unified_enrichment(self): + enrichment_config = getattr(self.config, "enrichment", None) + if self._unified_enrichment is None and enrichment_config and enrichment_config.enable_unified: + from engram.core.enrichment import UnifiedEnrichmentProcessor + self._unified_enrichment = UnifiedEnrichmentProcessor( + llm=self.llm, + echo_processor=self.echo_processor, + category_processor=self.category_processor, + knowledge_graph=self.knowledge_graph, + profile_processor=getattr(self, "profile_processor", None), + ) + return self._unified_enrichment + @property def skill_store(self): if self._skill_store is None and self.skill_config and self.skill_config.enable_skills: @@ -131,28 +146,113 @@ def search_skills( query=query, limit=limit, tags=tags, min_confidence=min_confidence, ) + def log_skill_outcome( + self, + skill_id: str, + success: bool, + notes: Optional[str] = None, + step_outcomes: Optional[List[Dict[str, Any]]] = None, + ) 
-> Dict[str, Any]: + """Log success/failure for a skill and update its confidence. + + If step_outcomes is provided (list of dicts with step_index, success, + failure_type, failed_slot, notes), per-step confidence is updated too. + """ + if self.skill_store is None: + return {"error": "Skills not enabled"} + from engram.skills.outcomes import OutcomeTracker, StepOutcome + tracker = OutcomeTracker(self.skill_store) + parsed_step_outcomes = None + if step_outcomes: + parsed_step_outcomes = [StepOutcome.from_dict(so) for so in step_outcomes] + return tracker.log_outcome(skill_id, success, notes, parsed_step_outcomes) + def apply_skill( self, skill_id: str, context: Optional[Dict[str, Any]] = None, + bindings: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: - """Apply a skill by ID, returning the recipe for injection.""" + """Apply a skill by ID, optionally with structural bindings.""" if self.skill_executor is None: return {"error": "Skills not enabled", "injected": False} - return self.skill_executor.apply(skill_id, context) + return self.skill_executor.apply(skill_id, context, bindings) - def log_skill_outcome( + def search_skills_structural( + self, + query_steps: List[str], + limit: int = 5, + min_similarity: float = 0.3, + ) -> List[Dict[str, Any]]: + """Search for skills by structural similarity to given steps.""" + if self.skill_executor is None: + return [] + return self.skill_executor.search_structural( + query_steps=query_steps, + limit=limit, + min_similarity=min_similarity, + ) + + def analyze_skill_gaps( self, skill_id: str, - success: bool, - notes: Optional[str] = None, + target_context: Dict[str, str], ) -> Dict[str, Any]: - """Log success/failure for a skill and update its confidence.""" + """Analyze what transfers from a skill to a target context.""" if self.skill_store is None: return {"error": "Skills not enabled"} - from engram.skills.outcomes import OutcomeTracker - tracker = OutcomeTracker(self.skill_store) - return 
tracker.log_outcome(skill_id, success, notes) + skill = self.skill_store.get(skill_id) + if skill is None: + return {"error": f"Skill not found: {skill_id}"} + structure = skill.get_structure() + if structure is None: + return {"error": "Skill has no structural decomposition"} + from engram.skills.structure import analyze_gaps + report = analyze_gaps(structure, target_context, skill.confidence) + report.skill_id = skill_id + return report.to_dict() + + def decompose_skill(self, skill_id: str) -> Dict[str, Any]: + """Trigger structural decomposition of a flat skill.""" + if self.skill_store is None: + return {"error": "Skills not enabled"} + skill = self.skill_store.get(skill_id) + if skill is None: + return {"error": f"Skill not found: {skill_id}"} + if skill.get_structure() is not None: + return {"skill_id": skill_id, "status": "already_decomposed"} + + from engram.skills.structure import ( + SkillStructure, + extract_slots_heuristic, + extract_slots_llm, + ) + skill_cfg = getattr(self, "skill_config", None) + use_llm = skill_cfg and skill_cfg.use_llm_decomposition and hasattr(self, "llm") and self.llm + if use_llm: + slots, steps = extract_slots_llm( + skill.name, skill.description, skill.steps, skill.tags, self.llm, + ) + else: + slots, steps = extract_slots_heuristic(skill.steps, skill.tags) + + known_bindings = {s.name: list(s.examples) for s in slots if s.examples} + structure = SkillStructure( + slots=slots, + structured_steps=steps, + known_bindings=known_bindings, + ) + structure.compute_structural_signature() + skill.set_structure(structure) + self.skill_store.save(skill) + + return { + "skill_id": skill_id, + "status": "decomposed", + "slots": [s.name for s in slots], + "step_count": len(steps), + "structural_signature": structure.structural_signature, + } def add( self, diff --git a/engram/skills/executor.py b/engram/skills/executor.py index b8a385a..d99b1aa 100644 --- a/engram/skills/executor.py +++ b/engram/skills/executor.py @@ -28,9 +28,13 @@ def 
apply( self, skill_id: str, context: Optional[Dict[str, Any]] = None, + bindings: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: """Apply a specific skill by ID. + If the skill has structure and bindings are provided, renders + steps with bindings and includes gap analysis. + Returns a dict with the skill recipe and metadata. """ skill = self._store.get(skill_id) @@ -43,6 +47,22 @@ def apply( skill.updated_at = skill.last_used_at self._store.save(skill) + # Structural apply path + structure = skill.get_structure() + if structure and bindings: + recipe = self._build_structural_recipe(skill, structure, bindings) + gap_analysis = self._analyze_gaps(skill, structure, bindings) + return { + "skill_id": skill.id, + "skill_name": skill.name, + "recipe": recipe, + "confidence": round(skill.confidence, 4), + "injected": True, + "source": skill.source, + "structural": True, + "gap_analysis": gap_analysis, + } + recipe = self._build_recipe(skill, context) return { "skill_id": skill.id, @@ -144,3 +164,99 @@ def _build_recipe( lines.append(f"**Tags:** {', '.join(skill.tags)}") return "\n".join(lines) + + def search_structural( + self, + query_steps: List[str], + limit: int = 5, + min_similarity: float = 0.3, + ) -> List[Dict[str, Any]]: + """Search for skills by structural similarity to given steps. + + Decomposes query_steps into templates, then compares against + all cached skills that have structure. 
+ """ + from engram.skills.structure import ( + extract_slots_heuristic, + structural_similarity, + ) + + _, query_structured = extract_slots_heuristic(query_steps) + + results = [] + for skill in self._store.list_all(): + structure = skill.get_structure() + if structure is None: + continue + + sim = structural_similarity(query_structured, structure.structured_steps) + if sim >= min_similarity: + results.append({ + "skill_id": skill.id, + "name": skill.name, + "description": skill.description, + "confidence": round(skill.confidence, 4), + "structural_similarity": round(sim, 4), + "tags": skill.tags, + }) + + results.sort(key=lambda r: r["structural_similarity"], reverse=True) + return results[:limit] + + def _build_structural_recipe( + self, + skill: Skill, + structure: "SkillStructure", + bindings: Dict[str, str], + ) -> str: + """Format a structured skill as injectable markdown with slot bindings.""" + from engram.skills.structure import SkillStructure + + lines = [ + f"## Skill: {skill.name} (Structural)", + f"**Confidence:** {skill.confidence:.0%} ", + f"**Source:** {skill.source} ", + f"**Used:** {skill.use_count} times", + "", + ] + + if skill.description: + lines.extend([skill.description, ""]) + + # Slot bindings table + if structure.slots: + lines.append("### Slot Bindings") + lines.append("| Slot | Value | Status |") + lines.append("|------|-------|--------|") + for slot in structure.slots: + value = bindings.get(slot.name, "—") + known = structure.known_bindings.get(slot.name, []) + if value == "—": + status = "UNBOUND" + elif value.lower() in [v.lower() for v in known]: + status = "proven" + else: + status = "UNTESTED" + lines.append(f"| {slot.name} | {value} | {status} |") + lines.append("") + + # Rendered steps with role markers + rendered = structure.render_steps(bindings) + lines.append("### Steps") + for i, (step_text, sstep) in enumerate(zip(rendered, structure.structured_steps), 1): + role_marker = "[S]" if sstep.role == "structural" else "[V]" + 
lines.append(f"{i}. {role_marker} {step_text}") + lines.append("") + + return "\n".join(lines) + + def _analyze_gaps( + self, + skill: Skill, + structure: "SkillStructure", + bindings: Dict[str, str], + ) -> Dict[str, Any]: + """Run gap analysis for a structural apply.""" + from engram.skills.structure import analyze_gaps + report = analyze_gaps(structure, bindings, skill.confidence) + return report.to_dict() diff --git a/engram/skills/hashing.py b/engram/skills/hashing.py index a471a31..a552ea8 100644 --- a/engram/skills/hashing.py +++ b/engram/skills/hashing.py @@ -61,3 +61,21 @@ def skill_signature_hash( "tags": sorted(str(t).strip().lower() for t in tags), } return hashlib.sha256(stable_json(obj).encode("utf-8")).hexdigest() + + +def structural_signature_hash( + step_templates: Sequence[str], + step_roles: Sequence[str], + slot_names: Sequence[str], +) -> str: + """SHA-256 hash of normalized templates + roles + sorted slot names. + + Two skills with the same structural signature share the same recipe + structure, even if their slot values differ. 
+ """ + obj = { + "templates": [str(t).strip().lower() for t in step_templates], + "roles": [str(r).strip().lower() for r in step_roles], + "slot_names": sorted(str(n).strip().lower() for n in slot_names), + } + return hashlib.sha256(stable_json(obj).encode("utf-8")).hexdigest() diff --git a/engram/skills/miner.py b/engram/skills/miner.py index c864b65..7687e73 100644 --- a/engram/skills/miner.py +++ b/engram/skills/miner.py @@ -52,6 +52,8 @@ def __init__( embedder: Any = None, mutation_rate: float = 0.05, min_cluster_size: int = 2, + auto_decompose: bool = True, + use_llm_decomposition: bool = True, ): self._trajectory_store = trajectory_store self._skill_store = skill_store @@ -59,6 +61,8 @@ def __init__( self._embedder = embedder self._mutation_rate = mutation_rate self._min_cluster_size = min_cluster_size + self._auto_decompose = auto_decompose + self._use_llm_decomposition = use_llm_decomposition def mine( self, @@ -249,6 +253,10 @@ def _mine_with_llm(self, cluster: List[Trajectory]) -> Optional[Skill]: # Apply mutation skill = self._maybe_mutate(skill) + # Structural decomposition + if self._auto_decompose: + skill = self._add_structure(skill, cluster) + return skill def _mine_heuristic(self, cluster: List[Trajectory]) -> Optional[Skill]: @@ -282,6 +290,58 @@ def _mine_heuristic(self, cluster: List[Trajectory]) -> Optional[Skill]: ) skill = self._maybe_mutate(skill) + + # Structural decomposition + if self._auto_decompose: + skill = self._add_structure(skill, cluster) + + return skill + + def _add_structure(self, skill: Skill, cluster: List[Trajectory]) -> Skill: + """Add structural decomposition to a mined skill. + + Extracts slots (via LLM or heuristic), populates known_bindings + from trajectory cluster, and computes structural signature. 
+ """ + if not skill.steps: + return skill + + try: + from engram.skills.structure import ( + SkillStructure, + extract_slots_heuristic, + extract_slots_llm, + ) + + if self._use_llm_decomposition and self._llm: + slots, structured_steps = extract_slots_llm( + name=skill.name, + description=skill.description, + steps=skill.steps, + tags=skill.tags, + llm=self._llm, + ) + else: + slots, structured_steps = extract_slots_heuristic( + skill.steps, skill.tags, + ) + + # Populate known_bindings from slot examples discovered during extraction + known_bindings: Dict[str, List[str]] = {} + for slot in slots: + if slot.examples: + known_bindings[slot.name] = list(slot.examples) + + structure = SkillStructure( + slots=slots, + structured_steps=structured_steps, + known_bindings=known_bindings, + ) + structure.compute_structural_signature() + skill.set_structure(structure) + except Exception as e: + logger.warning("Structural decomposition failed for '%s': %s", skill.name, e) + return skill def _maybe_mutate(self, skill: Skill) -> Skill: diff --git a/engram/skills/outcomes.py b/engram/skills/outcomes.py index aa65443..2f83a6b 100644 --- a/engram/skills/outcomes.py +++ b/engram/skills/outcomes.py @@ -8,8 +8,9 @@ from __future__ import annotations import logging +from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from engram.skills.schema import Skill from engram.skills.store import SkillStore @@ -21,6 +22,36 @@ FAILURE_WEIGHT = 0.15 +@dataclass +class StepOutcome: + """Granular outcome for a single step within a skill execution.""" + + step_index: int = 0 + success: bool = True + failure_type: Optional[str] = None # "structural" | "slot" + failed_slot: Optional[str] = None + notes: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "step_index": self.step_index, + "success": self.success, + "failure_type": self.failure_type, + 
"failed_slot": self.failed_slot, + "notes": self.notes, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "StepOutcome": + return cls( + step_index=int(data.get("step_index", 0)), + success=data.get("success", True), + failure_type=data.get("failure_type"), + failed_slot=data.get("failed_slot"), + notes=data.get("notes"), + ) + + def compute_confidence(success_count: int, fail_count: int) -> float: """Compute Bayesian-inspired confidence score. @@ -58,9 +89,13 @@ def log_outcome( skill_id: str, success: bool, notes: Optional[str] = None, + step_outcomes: Optional[List[StepOutcome]] = None, ) -> Dict[str, Any]: """Log a skill outcome and update confidence. + If step_outcomes are provided and the skill has structure, + per-step confidence/counts are updated too. + Returns updated skill stats. """ skill = self._store.get(skill_id) @@ -73,6 +108,26 @@ def log_outcome( else: skill.fail_count += 1 + # Granular per-step feedback + step_updates = [] + if step_outcomes and skill.structure is not None: + from engram.skills.structure import SkillStructure + structure = SkillStructure.from_dict(skill.structure) + for so in step_outcomes: + if 0 <= so.step_index < len(structure.structured_steps): + step = structure.structured_steps[so.step_index] + if so.success: + step.success_count += 1 + else: + step.fail_count += 1 + step.confidence = compute_confidence(step.success_count, step.fail_count) + step_updates.append({ + "step_index": so.step_index, + "template": step.template, + "new_confidence": round(step.confidence, 4), + }) + skill.structure = structure.to_dict() + # Recompute confidence old_confidence = skill.confidence skill.confidence = compute_confidence(skill.success_count, skill.fail_count) @@ -81,7 +136,7 @@ def log_outcome( # Persist self._store.save(skill) - return { + result = { "skill_id": skill.id, "skill_name": skill.name, "success": success, @@ -91,3 +146,6 @@ def log_outcome( "fail_count": skill.fail_count, "notes": notes, } + if step_updates: 
+ result["step_updates"] = step_updates + return result diff --git a/engram/skills/schema.py b/engram/skills/schema.py index 9c7a719..c6d5e46 100644 --- a/engram/skills/schema.py +++ b/engram/skills/schema.py @@ -42,6 +42,18 @@ class Skill: default_factory=lambda: datetime.now(timezone.utc).isoformat() ) last_used_at: Optional[str] = None + structure: Optional[Dict[str, Any]] = None # Serialized SkillStructure + + def get_structure(self) -> Optional["SkillStructure"]: + """Lazy-deserialize the structure field into a SkillStructure.""" + if self.structure is None: + return None + from engram.skills.structure import SkillStructure + return SkillStructure.from_dict(self.structure) + + def set_structure(self, structure: "SkillStructure") -> None: + """Serialize a SkillStructure and store it.""" + self.structure = structure.to_dict() def __post_init__(self): if not self.signature_hash: @@ -72,6 +84,8 @@ def to_skill_md(self) -> str: "updated_at": self.updated_at, "last_used_at": self.last_used_at, } + if self.structure is not None: + frontmatter["structure"] = self.structure yaml_str = yaml.dump(frontmatter, default_flow_style=False, sort_keys=False) body = self.body_markdown or self._generate_body() return f"---\n{yaml_str}---\n\n{body}\n" @@ -116,6 +130,7 @@ def from_skill_md(cls, content: str) -> "Skill": created_at=fm.get("created_at", datetime.now(timezone.utc).isoformat()), updated_at=fm.get("updated_at", datetime.now(timezone.utc).isoformat()), last_used_at=fm.get("last_used_at"), + structure=fm.get("structure"), ) def _generate_body(self) -> str: @@ -137,7 +152,7 @@ def _generate_body(self) -> str: def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for serialization.""" - return { + d = { "id": self.id, "name": self.name, "description": self.description, @@ -153,6 +168,9 @@ def to_dict(self) -> Dict[str, Any]: "updated_at": self.updated_at, "last_used_at": self.last_used_at, } + if self.structure is not None: + d["structure"] = self.structure + return 
d @dataclass diff --git a/engram/skills/store.py b/engram/skills/store.py index de04468..9e82aae 100644 --- a/engram/skills/store.py +++ b/engram/skills/store.py @@ -165,6 +165,59 @@ def get_by_signature(self, sig_hash: str) -> Optional[Skill]: return skill return None + def search_structural( + self, + query_steps: List[str], + limit: int = 5, + min_similarity: float = 0.3, + ) -> List[Skill]: + """Search for skills by structural similarity to given steps.""" + from engram.skills.structure import ( + extract_slots_heuristic, + structural_similarity, + ) + + _, query_structured = extract_slots_heuristic(query_steps) + + scored: List[tuple] = [] + for skill in self._cache.values(): + structure = skill.get_structure() + if structure is None: + continue + sim = structural_similarity(query_structured, structure.structured_steps) + if sim >= min_similarity: + scored.append((sim, skill)) + + # Also check filesystem for skills not yet cached + self.sync_from_filesystem() + for skill in self._cache.values(): + structure = skill.get_structure() + if structure is None: + continue + # Avoid re-scoring already scored skills + if any(s.id == skill.id for _, s in scored): + continue + sim = structural_similarity(query_structured, structure.structured_steps) + if sim >= min_similarity: + scored.append((sim, skill)) + + scored.sort(key=lambda x: x[0], reverse=True) + return [skill for _, skill in scored[:limit]] + + def get_by_structural_signature(self, sig_hash: str) -> Optional[Skill]: + """Find skill by structural signature hash (structural dedup check).""" + for skill in self._cache.values(): + structure = skill.get_structure() + if structure and structure.structural_signature == sig_hash: + return skill + + self.sync_from_filesystem() + for skill in self._cache.values(): + structure = skill.get_structure() + if structure and structure.structural_signature == sig_hash: + return skill + return None + def delete(self, skill_id: str) -> bool: """Delete a skill from filesystem and 
index.""" self._cache.pop(skill_id, None) diff --git a/engram/skills/structure.py b/engram/skills/structure.py new file mode 100644 index 0000000..5774588 --- /dev/null +++ b/engram/skills/structure.py @@ -0,0 +1,498 @@ +"""Structural Intelligence — recipe/ingredient decomposition for skills. + +Every task has ingredients (swappable components) and a recipe (transferable +structure). This module provides: + +- Slot/StructuredStep/SkillStructure dataclasses +- Heuristic and LLM-enhanced slot extraction from flat step lists +- Structural similarity (analogical retrieval via LCS) +- Gap analysis for transfer to new contexts +""" + +from __future__ import annotations + +import json +import logging +import re +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Sequence, Tuple + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Data types +# --------------------------------------------------------------------------- + + +@dataclass +class Slot: + """A named variable in a skill recipe (the 'ingredient').""" + + name: str + description: str = "" + slot_type: str = "string" # string | tool | path | config + default: Optional[str] = None + examples: List[str] = field(default_factory=list) + required: bool = True + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "description": self.description, + "slot_type": self.slot_type, + "default": self.default, + "examples": self.examples, + "required": self.required, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Slot": + return cls( + name=data.get("name", ""), + description=data.get("description", ""), + slot_type=data.get("slot_type", "string"), + default=data.get("default"), + examples=data.get("examples", []), + required=data.get("required", True), + ) + + +@dataclass +class StructuredStep: + """A single step in a skill recipe, with template slots.""" + + template: str # e.g. 
"Run tests using {test_framework}" + role: str = "structural" # "structural" (core recipe) | "variable" (context-dependent) + slot_refs: List[str] = field(default_factory=list) + confidence: float = 1.0 + success_count: int = 0 + fail_count: int = 0 + order_index: int = 0 + + def render(self, bindings: Dict[str, str]) -> str: + """Replace {slot} placeholders with bound values.""" + result = self.template + for slot_name, value in bindings.items(): + result = result.replace(f"{{{slot_name}}}", value) + return result + + def to_dict(self) -> Dict[str, Any]: + return { + "template": self.template, + "role": self.role, + "slot_refs": self.slot_refs, + "confidence": self.confidence, + "success_count": self.success_count, + "fail_count": self.fail_count, + "order_index": self.order_index, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "StructuredStep": + return cls( + template=data.get("template", ""), + role=data.get("role", "structural"), + slot_refs=data.get("slot_refs", []), + confidence=float(data.get("confidence", 1.0)), + success_count=int(data.get("success_count", 0)), + fail_count=int(data.get("fail_count", 0)), + order_index=int(data.get("order_index", 0)), + ) + + +@dataclass +class SkillStructure: + """The full structural decomposition of a skill.""" + + slots: List[Slot] = field(default_factory=list) + structured_steps: List[StructuredStep] = field(default_factory=list) + known_bindings: Dict[str, List[str]] = field(default_factory=dict) # slot_name -> [proven values] + context_bindings: Dict[str, Dict[str, str]] = field(default_factory=dict) # context_tag -> {slot: value} + structural_signature: str = "" + + def render_steps(self, bindings: Dict[str, str]) -> List[str]: + """Render all steps with the given slot bindings.""" + return [step.render(bindings) for step in self.structured_steps] + + def compute_structural_signature(self) -> str: + """Compute a hash that captures the recipe structure, ignoring slot values.""" + from 
engram.skills.hashing import structural_signature_hash + self.structural_signature = structural_signature_hash( + step_templates=[s.template for s in self.structured_steps], + step_roles=[s.role for s in self.structured_steps], + slot_names=[s.name for s in self.slots], + ) + return self.structural_signature + + def to_dict(self) -> Dict[str, Any]: + return { + "slots": [s.to_dict() for s in self.slots], + "structured_steps": [s.to_dict() for s in self.structured_steps], + "known_bindings": self.known_bindings, + "context_bindings": self.context_bindings, + "structural_signature": self.structural_signature, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SkillStructure": + if not data: + return cls() + return cls( + slots=[Slot.from_dict(s) for s in data.get("slots", [])], + structured_steps=[StructuredStep.from_dict(s) for s in data.get("structured_steps", [])], + known_bindings=data.get("known_bindings", {}), + context_bindings=data.get("context_bindings", {}), + structural_signature=data.get("structural_signature", ""), + ) + + +# --------------------------------------------------------------------------- +# Heuristic slot extraction patterns +# --------------------------------------------------------------------------- + +_SLOT_PATTERNS: Dict[str, re.Pattern] = { + "language": re.compile( + r"\b(python|javascript|typescript|go|golang|rust|java|ruby|c\+\+|" + r"csharp|c#|scala|kotlin|swift|php|perl|elixir|haskell|lua|r)\b", + re.IGNORECASE, + ), + "test_framework": re.compile( + r"\b(pytest|jest|mocha|vitest|go\s+test|cargo\s+test|junit|rspec|" + r"minitest|phpunit|nose|unittest|cypress|playwright)\b", + re.IGNORECASE, + ), + "package_manager": re.compile( + r"\b(pip|npm|yarn|pnpm|cargo|gem|maven|gradle|composer|bun|poetry|" + r"conda|brew|apt|dnf|pacman)\b", + re.IGNORECASE, + ), + "deploy_target": re.compile( + r"\b(aws|gcp|azure|heroku|docker|kubernetes|k8s|vercel|netlify|" + r"fly\.io|cloudflare|digitalocean|lambda|ecs|ec2|s3)\b", + 
re.IGNORECASE, + ), + "file_path": re.compile( + r"(? Tuple[List[Slot], List[StructuredStep]]: + """Extract slots and create structured steps from flat step list using heuristics. + + Returns (slots, structured_steps). + """ + tags = tags or [] + discovered_slots: Dict[str, Slot] = {} + structured_steps: List[StructuredStep] = [] + + for idx, step in enumerate(steps): + template = step + slot_refs: List[str] = [] + + for slot_name, pattern in _SLOT_PATTERNS.items(): + matches = list(pattern.finditer(template)) + for match in reversed(matches): # reverse so indices stay valid + matched_text = match.group(0) + # Replace match with {slot_name} placeholder + template = template[:match.start()] + f"{{{slot_name}}}" + template[match.end():] + + if slot_name not in slot_refs: + slot_refs.append(slot_name) + + if slot_name not in discovered_slots: + discovered_slots[slot_name] = Slot( + name=slot_name, + slot_type=_infer_slot_type(slot_name), + examples=[matched_text], + ) + else: + if matched_text not in discovered_slots[slot_name].examples: + discovered_slots[slot_name].examples.append(matched_text) + + # Classify role: variable if it has slots, structural if not + role = "variable" if slot_refs else "structural" + + structured_steps.append(StructuredStep( + template=template, + role=role, + slot_refs=slot_refs, + order_index=idx, + )) + + return list(discovered_slots.values()), structured_steps + + +def _infer_slot_type(slot_name: str) -> str: + """Infer slot type from its name.""" + if slot_name == "file_path": + return "path" + if slot_name == "tool": + return "tool" + if slot_name in ("deploy_target",): + return "config" + return "string" + + +# --------------------------------------------------------------------------- +# LLM-enhanced slot extraction +# --------------------------------------------------------------------------- + +_LLM_DECOMPOSE_PROMPT = """Analyze this skill and decompose it into a recipe with variable slots. 
+ +Skill: {name} +Description: {description} +Steps: +{steps_text} +Tags: {tags} + +Identify which parts of each step are: +1. STRUCTURAL (the core recipe pattern that transfers to any similar task) +2. VARIABLE (swappable components: languages, tools, frameworks, paths, targets) + +Return a JSON object: +{{ + "slots": [ + {{"name": "slot_name", "description": "what this slot represents", "slot_type": "string|tool|path|config", "examples": ["value1", "value2"]}} + ], + "structured_steps": [ + {{"template": "Step text with {{slot_name}} placeholders", "role": "structural|variable"}} + ] +}} + +Respond with ONLY the JSON object.""" + + +def extract_slots_llm( + name: str, + description: str, + steps: List[str], + tags: List[str], + llm: Any, +) -> Tuple[List[Slot], List[StructuredStep]]: + """Use an LLM to identify slots and rewrite steps as templates. + + Falls back to heuristic extraction on failure. + """ + steps_text = "\n".join(f" {i+1}. {s}" for i, s in enumerate(steps)) + prompt = _LLM_DECOMPOSE_PROMPT.format( + name=name, + description=description, + steps_text=steps_text, + tags=", ".join(tags), + ) + + try: + response = llm.generate(prompt) + response_text = response.strip() + # Strip code fences if present + if response_text.startswith("```"): + response_text = response_text.strip("`").strip() + if response_text.startswith("json"): + response_text = response_text[4:].strip() + + data = json.loads(response_text) + + slots = [] + for sd in data.get("slots", []): + slots.append(Slot( + name=sd.get("name", ""), + description=sd.get("description", ""), + slot_type=sd.get("slot_type", "string"), + examples=sd.get("examples", []), + )) + + structured_steps = [] + for idx, sd in enumerate(data.get("structured_steps", [])): + template = sd.get("template", "") + role = sd.get("role", "structural") + # Detect slot refs from template + slot_refs = re.findall(r"\{(\w+)\}", template) + structured_steps.append(StructuredStep( + template=template, + role=role, + 
slot_refs=slot_refs, + order_index=idx, + )) + + if slots or structured_steps: + return slots, structured_steps + + except Exception as e: + logger.warning("LLM slot extraction failed, falling back to heuristic: %s", e) + + return extract_slots_heuristic(steps, tags) + + +# --------------------------------------------------------------------------- +# Structural similarity (analogical retrieval) +# --------------------------------------------------------------------------- + +def _normalize_template(template: str) -> str: + """Normalize a template for structural comparison. + + Replaces all {slot_name} with a generic {SLOT} and lowercases. + """ + return re.sub(r"\{\w+\}", "{SLOT}", template).strip().lower() + + +def _lcs_length(a: Sequence[str], b: Sequence[str]) -> int: + """Longest Common Subsequence length (dynamic programming).""" + m, n = len(a), len(b) + if m == 0 or n == 0: + return 0 + # Space-optimized: only need two rows + prev = [0] * (n + 1) + curr = [0] * (n + 1) + for i in range(1, m + 1): + for j in range(1, n + 1): + if a[i - 1] == b[j - 1]: + curr[j] = prev[j - 1] + 1 + else: + curr[j] = max(prev[j], curr[j - 1]) + prev, curr = curr, [0] * (n + 1) + return prev[n] + + +def structural_similarity(a_steps: List[StructuredStep], b_steps: List[StructuredStep]) -> float: + """Compute structural similarity between two skill recipes. + + Normalizes templates (replacing slot names with generic {SLOT}), + then computes Dice coefficient over the LCS: + similarity = 2 * len(LCS) / (len(a) + len(b)) + + Returns 0.0..1.0. Two skills with the same recipe but different slot + values will score ~1.0. 
+ """ + if not a_steps and not b_steps: + return 1.0 + if not a_steps or not b_steps: + return 0.0 + + a_normalized = [_normalize_template(s.template) for s in a_steps] + b_normalized = [_normalize_template(s.template) for s in b_steps] + + lcs = _lcs_length(a_normalized, b_normalized) + return (2.0 * lcs) / (len(a_normalized) + len(b_normalized)) + + +# --------------------------------------------------------------------------- +# Gap analysis +# --------------------------------------------------------------------------- + + +@dataclass +class GapReport: + """Analysis of what transfers vs what needs experimentation.""" + + skill_id: str = "" + total_slots: int = 0 + bound_slots: List[Dict[str, Any]] = field(default_factory=list) + unbound_slots: List[Dict[str, Any]] = field(default_factory=list) + transfer_confidence: float = 0.0 + structural_coverage: float = 0.0 + variable_coverage: float = 0.0 + recommendations: List[str] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "skill_id": self.skill_id, + "total_slots": self.total_slots, + "bound_slots": self.bound_slots, + "unbound_slots": self.unbound_slots, + "transfer_confidence": round(self.transfer_confidence, 4), + "structural_coverage": round(self.structural_coverage, 4), + "variable_coverage": round(self.variable_coverage, 4), + "recommendations": self.recommendations, + } + + +def analyze_gaps( + structure: SkillStructure, + target_context: Dict[str, str], + skill_confidence: float = 0.5, +) -> GapReport: + """Analyze what transfers from a skill to a new context. 
+ + For each slot: + - target provides value AND value in known_bindings → proven (high confidence) + - target provides value but value not in known_bindings → untested + - target doesn't provide value → unknown (needs input) + """ + report = GapReport(total_slots=len(structure.slots)) + recommendations = [] + + structural_steps = [s for s in structure.structured_steps if s.role == "structural"] + variable_steps = [s for s in structure.structured_steps if s.role == "variable"] + + # Structural coverage: structural steps are always transferable + total_steps = len(structure.structured_steps) + if total_steps > 0: + report.structural_coverage = len(structural_steps) / total_steps + else: + report.structural_coverage = 1.0 + + bound_count = 0 + for slot in structure.slots: + target_value = target_context.get(slot.name) + known_values = structure.known_bindings.get(slot.name, []) + + if target_value and target_value.lower() in [v.lower() for v in known_values]: + # Proven binding + report.bound_slots.append({ + "slot": slot.name, + "value": target_value, + "status": "proven", + "confidence": "high", + }) + bound_count += 1 + elif target_value: + # Untested binding + report.bound_slots.append({ + "slot": slot.name, + "value": target_value, + "status": "untested", + "confidence": "low", + }) + recommendations.append( + f"Slot '{slot.name}' has untested value '{target_value}'. " + f"Known values: {known_values or ['none']}. Experiment carefully." + ) + bound_count += 0.5 # Partial credit + else: + # Unknown + report.unbound_slots.append({ + "slot": slot.name, + "status": "unknown", + "known_values": known_values, + "required": slot.required, + }) + if slot.required: + recommendations.append( + f"Slot '{slot.name}' needs a value. " + f"Known options: {known_values or ['none']}." 
+ ) + + # Variable coverage + if structure.slots: + report.variable_coverage = bound_count / len(structure.slots) + else: + report.variable_coverage = 1.0 + + # Transfer confidence: weighted combination + report.transfer_confidence = ( + 0.4 * report.structural_coverage + + 0.3 * report.variable_coverage + + 0.3 * skill_confidence + ) + + report.recommendations = recommendations + return report diff --git a/engram/utils/prompts.py b/engram/utils/prompts.py index 3dff7e0..2e94274 100644 --- a/engram/utils/prompts.py +++ b/engram/utils/prompts.py @@ -200,6 +200,97 @@ Return exactly {count} elements in the results array. """ +UNIFIED_ENRICHMENT_PROMPT = """You are enriching a memory for a long-term AI memory system. +Perform ALL analyses in a single pass. + +MEMORY: {content} +ECHO DEPTH: {depth} +ECHO INSTRUCTIONS: {depth_instructions} +EXISTING CATEGORIES: +{existing_categories} +INCLUDE ENTITIES: {include_entities} +INCLUDE PROFILES: {include_profiles} + +Return ONLY valid JSON matching this schema: +{{ + "echo": {{ + "paraphrases": ["diverse rephrasings"], + "keywords": ["core entities/tags"], + "implications": ["logical consequences"], + "questions": ["questions this answers"], + "question_form": "single question-form version of the memory or null", + "category": "fact|preference|goal|relationship|event", + "importance": 0.0-1.0 + }}, + "category": {{ + "action": "use_existing|create_child|create_new", + "category_id": "existing_category_id or null", + "new_category": {{ + "name": "Category Name", + "description": "Brief description", + "keywords": ["keyword1", "keyword2"], + "parent_id": "parent_category_id or null" + }}, + "confidence": 0.0-1.0 + }}, + "entities": [ + {{"name": "entity name", "type": "person|organization|technology|concept|location|project|tool|preference"}} + ], + "profiles": [ + {{"name": "person name", "type": "self|contact|entity", "facts": ["fact"], "preferences": ["pref"]}} + ] +}} + +Rules: +- Follow ECHO INSTRUCTIONS for which echo 
fields to populate +- For category: prefer use_existing when an existing category fits well +- For entities: extract named entities (people, tech, orgs, tools) +- For profiles: extract person mentions with their facts/preferences +- If INCLUDE ENTITIES or INCLUDE PROFILES is "no", return empty arrays for those +""" + +UNIFIED_ENRICHMENT_BATCH_PROMPT = """You are enriching multiple memories for a long-term AI memory system. +Perform ALL analyses in a single pass for each memory. + +MEMORIES: +{memories_block} + +ECHO DEPTH: {depth} +ECHO INSTRUCTIONS: {depth_instructions} +EXISTING CATEGORIES: +{existing_categories} +INCLUDE ENTITIES: {include_entities} +INCLUDE PROFILES: {include_profiles} + +Return ONLY valid JSON with a "results" array. Each element must include the memory index: +{{ + "results": [ + {{ + "index": 0, + "echo": {{ + "paraphrases": ["diverse rephrasings"], + "keywords": ["core entities/tags"], + "implications": ["logical consequences"], + "questions": ["questions this answers"], + "question_form": "single question-form version or null", + "category": "fact|preference|goal|relationship|event", + "importance": 0.0-1.0 + }}, + "category": {{ + "action": "use_existing|create_child|create_new", + "category_id": "existing_category_id or null", + "new_category": null, + "confidence": 0.0-1.0 + }}, + "entities": [{{"name": "entity name", "type": "person|technology|..."}}], + "profiles": [{{"name": "person name", "type": "self|contact|entity", "facts": [], "preferences": []}}] + }} + ] +}} + +IMPORTANT: Return exactly {count} elements in the results array, one per memory, in the same order. +""" + FUSION_PROMPT = """You are consolidating multiple related memories into a single, comprehensive memory. This is part of a biologically-inspired memory system that mimics how human brains consolidate related memories during sleep. 
The goal is to: diff --git a/tests/test_structural.py b/tests/test_structural.py new file mode 100644 index 0000000..d24f73a --- /dev/null +++ b/tests/test_structural.py @@ -0,0 +1,455 @@ +"""Tests for the Structural Intelligence layer. + +Covers: Slot, StructuredStep, SkillStructure, heuristic extraction, +structural similarity, gap analysis, integration with Skill schema, +and structural_signature_hash. +""" + +import pytest + +from engram.skills.structure import ( + GapReport, + Slot, + SkillStructure, + StructuredStep, + _lcs_length, + _normalize_template, + analyze_gaps, + extract_slots_heuristic, + structural_similarity, +) +from engram.skills.hashing import structural_signature_hash +from engram.skills.schema import Skill + + +# ── Slot tests ── + + +class TestSlot: + def test_roundtrip_serialization(self): + slot = Slot( + name="language", + description="Programming language", + slot_type="string", + default="python", + examples=["python", "go"], + required=True, + ) + d = slot.to_dict() + restored = Slot.from_dict(d) + assert restored.name == "language" + assert restored.description == "Programming language" + assert restored.slot_type == "string" + assert restored.default == "python" + assert restored.examples == ["python", "go"] + assert restored.required is True + + def test_from_dict_defaults(self): + slot = Slot.from_dict({"name": "x"}) + assert slot.slot_type == "string" + assert slot.default is None + assert slot.examples == [] + assert slot.required is True + + +# ── StructuredStep tests ── + + +class TestStructuredStep: + def test_render_with_bindings(self): + step = StructuredStep( + template="Build {language} app", + slot_refs=["language"], + ) + assert step.render({"language": "Go"}) == "Build Go app" + + def test_render_missing_slots_passthrough(self): + step = StructuredStep( + template="Run {test_framework} suite", + slot_refs=["test_framework"], + ) + result = step.render({}) + assert result == "Run {test_framework} suite" + + def 
test_render_multiple_slots(self): + step = StructuredStep( + template="Deploy {language} to {deploy_target}", + slot_refs=["language", "deploy_target"], + ) + result = step.render({"language": "Rust", "deploy_target": "AWS"}) + assert result == "Deploy Rust to AWS" + + def test_roundtrip_serialization(self): + step = StructuredStep( + template="Run {test_framework}", + role="variable", + slot_refs=["test_framework"], + confidence=0.8, + success_count=5, + fail_count=1, + order_index=2, + ) + d = step.to_dict() + restored = StructuredStep.from_dict(d) + assert restored.template == step.template + assert restored.role == "variable" + assert restored.confidence == 0.8 + assert restored.success_count == 5 + assert restored.fail_count == 1 + assert restored.order_index == 2 + + +# ── SkillStructure tests ── + + +class TestSkillStructure: + def test_render_steps(self): + structure = SkillStructure( + slots=[Slot(name="language"), Slot(name="deploy_target")], + structured_steps=[ + StructuredStep(template="Build {language} app", slot_refs=["language"]), + StructuredStep(template="Deploy to {deploy_target}", slot_refs=["deploy_target"]), + ], + ) + rendered = structure.render_steps({"language": "Go", "deploy_target": "GCP"}) + assert rendered == ["Build Go app", "Deploy to GCP"] + + def test_compute_structural_signature_deterministic(self): + structure = SkillStructure( + slots=[Slot(name="language"), Slot(name="test_framework")], + structured_steps=[ + StructuredStep(template="Build {language} app", role="variable"), + StructuredStep(template="Run {test_framework}", role="variable"), + ], + ) + sig1 = structure.compute_structural_signature() + sig2 = structure.compute_structural_signature() + assert sig1 == sig2 + assert len(sig1) == 64 # SHA-256 hex + + def test_roundtrip_serialization(self): + structure = SkillStructure( + slots=[Slot(name="language", examples=["python"])], + structured_steps=[ + StructuredStep(template="Build {language} app", role="variable", 
slot_refs=["language"]), + StructuredStep(template="Review PR", role="structural"), + ], + known_bindings={"language": ["python", "go"]}, + context_bindings={"prod": {"language": "go"}}, + ) + structure.compute_structural_signature() + d = structure.to_dict() + restored = SkillStructure.from_dict(d) + assert len(restored.slots) == 1 + assert len(restored.structured_steps) == 2 + assert restored.known_bindings["language"] == ["python", "go"] + assert restored.context_bindings["prod"] == {"language": "go"} + assert restored.structural_signature == structure.structural_signature + + def test_from_dict_empty(self): + s = SkillStructure.from_dict({}) + assert s.slots == [] + assert s.structured_steps == [] + + def test_from_dict_none(self): + s = SkillStructure.from_dict(None) + assert s.slots == [] + + +# ── Heuristic extraction tests ── + + +class TestExtractSlotsHeuristic: + def test_detects_language(self): + slots, steps = extract_slots_heuristic(["Build Python app"]) + slot_names = [s.name for s in slots] + assert "language" in slot_names + assert "{language}" in steps[0].template + assert steps[0].role == "variable" + + def test_detects_test_framework(self): + slots, steps = extract_slots_heuristic(["Run pytest"]) + slot_names = [s.name for s in slots] + assert "test_framework" in slot_names + + def test_detects_deploy_target(self): + slots, steps = extract_slots_heuristic(["Deploy to AWS"]) + slot_names = [s.name for s in slots] + assert "deploy_target" in slot_names + + def test_detects_package_manager(self): + slots, steps = extract_slots_heuristic(["Install with pip"]) + slot_names = [s.name for s in slots] + assert "package_manager" in slot_names + + def test_no_false_positives_generic_step(self): + slots, steps = extract_slots_heuristic(["Review the pull request"]) + assert len(slots) == 0 + assert steps[0].role == "structural" + + def test_structural_role_for_generic_steps(self): + slots, steps = extract_slots_heuristic([ + "Review code changes", + "Write 
documentation", + ]) + for step in steps: + assert step.role == "structural" + + def test_multiple_slots_in_step(self): + slots, steps = extract_slots_heuristic(["Build Python app and deploy to AWS"]) + slot_names = [s.name for s in slots] + assert "language" in slot_names + assert "deploy_target" in slot_names + assert len(steps[0].slot_refs) == 2 + + def test_examples_populated(self): + slots, _ = extract_slots_heuristic(["Run pytest", "Run jest"]) + tf_slot = next(s for s in slots if s.name == "test_framework") + assert "pytest" in tf_slot.examples + assert "jest" in tf_slot.examples + + def test_order_index(self): + _, steps = extract_slots_heuristic(["Step A", "Step B", "Step C"]) + for i, step in enumerate(steps): + assert step.order_index == i + + +# ── Structural similarity tests ── + + +class TestStructuralSimilarity: + def test_identical_steps(self): + steps = [ + StructuredStep(template="Build {language} app"), + StructuredStep(template="Run {test_framework}"), + StructuredStep(template="Deploy to {deploy_target}"), + ] + assert structural_similarity(steps, steps) == 1.0 + + def test_completely_different(self): + a = [StructuredStep(template="Build {language} app")] + b = [StructuredStep(template="Write documentation")] + sim = structural_similarity(a, b) + assert sim < 0.5 + + def test_same_structure_different_slot_names(self): + a = [ + StructuredStep(template="Build {language} app"), + StructuredStep(template="Run {test_framework}"), + ] + b = [ + StructuredStep(template="Build {lang} app"), + StructuredStep(template="Run {testing_tool}"), + ] + sim = structural_similarity(a, b) + assert sim == 1.0 # {SLOT} normalization makes them identical + + def test_empty_steps(self): + assert structural_similarity([], []) == 1.0 + assert structural_similarity([], [StructuredStep(template="x")]) == 0.0 + + def test_partial_overlap(self): + a = [ + StructuredStep(template="Build {SLOT} app"), + StructuredStep(template="Run tests"), + 
StructuredStep(template="Deploy to {SLOT}"), + ] + b = [ + StructuredStep(template="Build {SLOT} app"), + StructuredStep(template="Deploy to {SLOT}"), + ] + sim = structural_similarity(a, b) + # LCS = 2, dice = 2*2 / (3+2) = 0.8 + assert abs(sim - 0.8) < 0.01 + + +# ── LCS helper tests ── + + +class TestLCS: + def test_identical(self): + assert _lcs_length(["a", "b", "c"], ["a", "b", "c"]) == 3 + + def test_empty(self): + assert _lcs_length([], ["a"]) == 0 + assert _lcs_length(["a"], []) == 0 + + def test_partial(self): + assert _lcs_length(["a", "b", "c"], ["a", "c"]) == 2 + + +# ── Gap analysis tests ── + + +class TestGapAnalysis: + def test_all_bound_and_proven(self): + structure = SkillStructure( + slots=[Slot(name="language"), Slot(name="deploy_target")], + structured_steps=[ + StructuredStep(template="Build {language} app", role="variable"), + StructuredStep(template="Deploy to {deploy_target}", role="variable"), + ], + known_bindings={"language": ["python"], "deploy_target": ["aws"]}, + ) + report = analyze_gaps(structure, {"language": "python", "deploy_target": "aws"}) + assert len(report.bound_slots) == 2 + assert len(report.unbound_slots) == 0 + assert all(b["status"] == "proven" for b in report.bound_slots) + assert report.variable_coverage == 1.0 + + def test_unbound_slots(self): + structure = SkillStructure( + slots=[Slot(name="language", required=True), Slot(name="deploy_target", required=True)], + structured_steps=[], + known_bindings={"language": ["python"]}, + ) + report = analyze_gaps(structure, {"language": "python"}) + assert len(report.unbound_slots) == 1 + assert report.unbound_slots[0]["slot"] == "deploy_target" + assert len(report.recommendations) > 0 + + def test_untested_bindings(self): + structure = SkillStructure( + slots=[Slot(name="language")], + structured_steps=[], + known_bindings={"language": ["python"]}, + ) + report = analyze_gaps(structure, {"language": "go"}) + assert len(report.bound_slots) == 1 + assert 
report.bound_slots[0]["status"] == "untested" + assert any("untested" in r.lower() for r in report.recommendations) + + def test_structural_coverage(self): + structure = SkillStructure( + slots=[], + structured_steps=[ + StructuredStep(template="Review code", role="structural"), + StructuredStep(template="Build {language} app", role="variable"), + ], + ) + report = analyze_gaps(structure, {}) + assert report.structural_coverage == 0.5 + + def test_transfer_confidence_range(self): + structure = SkillStructure( + slots=[Slot(name="language")], + structured_steps=[StructuredStep(template="Build {language} app", role="variable")], + known_bindings={"language": ["python"]}, + ) + report = analyze_gaps(structure, {"language": "python"}, skill_confidence=0.8) + assert 0.0 <= report.transfer_confidence <= 1.0 + + def test_gap_report_to_dict(self): + report = GapReport( + skill_id="test-123", + total_slots=2, + transfer_confidence=0.75, + ) + d = report.to_dict() + assert d["skill_id"] == "test-123" + assert d["transfer_confidence"] == 0.75 + + +# ── Integration: Skill with structure roundtrip ── + + +class TestSkillStructureIntegration: + def test_skill_md_roundtrip_with_structure(self): + skill = Skill( + name="Deploy App", + description="Deploy an application", + steps=["Build Python app", "Run pytest", "Deploy to AWS"], + tags=["deploy", "python"], + ) + slots, steps = extract_slots_heuristic(skill.steps) + structure = SkillStructure( + slots=slots, + structured_steps=steps, + known_bindings={"language": ["python"]}, + ) + structure.compute_structural_signature() + skill.set_structure(structure) + + # Roundtrip through SKILL.md + md = skill.to_skill_md() + restored = Skill.from_skill_md(md) + + assert restored.structure is not None + restored_structure = restored.get_structure() + assert len(restored_structure.slots) == len(slots) + assert len(restored_structure.structured_steps) == len(steps) + assert restored_structure.structural_signature == 
structure.structural_signature + + def test_skill_to_dict_with_structure(self): + skill = Skill(name="Test Skill", steps=["Build Python app"]) + slots, steps = extract_slots_heuristic(skill.steps) + structure = SkillStructure(slots=slots, structured_steps=steps) + skill.set_structure(structure) + + d = skill.to_dict() + assert "structure" in d + assert d["structure"]["slots"][0]["name"] == "language" + + def test_skill_without_structure_backward_compatible(self): + skill = Skill(name="Flat Skill", steps=["Do something"]) + assert skill.structure is None + assert skill.get_structure() is None + + md = skill.to_skill_md() + restored = Skill.from_skill_md(md) + assert restored.structure is None + assert restored.get_structure() is None + + +# ── structural_signature_hash tests ── + + +class TestStructuralSignatureHash: + def test_deterministic(self): + h1 = structural_signature_hash( + ["build {language} app", "run {test_framework}"], + ["variable", "variable"], + ["language", "test_framework"], + ) + h2 = structural_signature_hash( + ["build {language} app", "run {test_framework}"], + ["variable", "variable"], + ["language", "test_framework"], + ) + assert h1 == h2 + assert len(h1) == 64 + + def test_different_for_different_structures(self): + h1 = structural_signature_hash( + ["build {language} app"], + ["variable"], + ["language"], + ) + h2 = structural_signature_hash( + ["deploy to {target}"], + ["variable"], + ["target"], + ) + assert h1 != h2 + + def test_slot_name_order_irrelevant(self): + h1 = structural_signature_hash( + ["t1"], ["structural"], ["b", "a"], + ) + h2 = structural_signature_hash( + ["t1"], ["structural"], ["a", "b"], + ) + assert h1 == h2 # slot_names are sorted + + +# ── Normalize template test ── + + +class TestNormalizeTemplate: + def test_replaces_all_slot_names(self): + result = _normalize_template("Build {language} with {tool}") + assert result == "build {slot} with {slot}" + + def test_lowercases(self): + result = 
_normalize_template("REVIEW PR") + assert result == "review pr" diff --git a/tests/test_unified_enrichment.py b/tests/test_unified_enrichment.py new file mode 100644 index 0000000..7a39353 --- /dev/null +++ b/tests/test_unified_enrichment.py @@ -0,0 +1,504 @@ +"""Tests for unified enrichment processor.""" + +import json +import pytest +from unittest.mock import MagicMock, patch + +from engram.core.category import CategoryMatch, CategoryProcessor +from engram.core.echo import EchoDepth, EchoProcessor, EchoResult +from engram.core.enrichment import ( + EnrichmentResult, + UnifiedCategoryOutput, + UnifiedEchoOutput, + UnifiedEnrichmentOutput, + UnifiedEnrichmentProcessor, + UnifiedEntityOutput, + UnifiedProfileOutput, + _extract_json_blob, + _normalize_unified_dict, + _robust_json_load, +) +from engram.core.graph import Entity, EntityType, KnowledgeGraph +from engram.core.profile import ProfileProcessor, ProfileUpdate + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +VALID_UNIFIED_RESPONSE = json.dumps({ + "echo": { + "paraphrases": ["I enjoy coding in Python", "Python is my preferred language"], + "keywords": ["Python", "programming", "preference"], + "implications": ["User is likely a developer"], + "questions": ["What programming language does the user prefer?"], + "question_form": "What is the user's preferred programming language?", + "category": "preference", + "importance": 0.8, + }, + "category": { + "action": "use_existing", + "category_id": "preferences", + "new_category": None, + "confidence": 0.9, + }, + "entities": [ + {"name": "Python", "type": "technology"}, + ], + "profiles": [ + {"name": "self", "type": "self", "facts": [], "preferences": ["Python programming"]}, + ], +}) + +VALID_BATCH_RESPONSE = json.dumps({ + "results": [ + { + "index": 0, + "echo": { + "paraphrases": ["User likes Python"], + "keywords": ["Python"], + 
"implications": [], + "questions": [], + "question_form": None, + "category": "preference", + "importance": 0.7, + }, + "category": {"action": "use_existing", "category_id": "preferences", "confidence": 0.8}, + "entities": [{"name": "Python", "type": "technology"}], + "profiles": [], + }, + { + "index": 1, + "echo": { + "paraphrases": ["User works at Acme"], + "keywords": ["Acme", "work"], + "implications": [], + "questions": [], + "question_form": None, + "category": "fact", + "importance": 0.6, + }, + "category": {"action": "use_existing", "category_id": "facts", "confidence": 0.7}, + "entities": [{"name": "Acme", "type": "organization"}], + "profiles": [{"name": "self", "type": "self", "facts": ["Works at Acme"], "preferences": []}], + }, + ] +}) + + +def _make_mock_llm(response=VALID_UNIFIED_RESPONSE): + llm = MagicMock() + llm.generate.return_value = response + return llm + + +def _make_processor(llm=None, echo=True, category=True, graph=True, profile=False): + llm = llm or _make_mock_llm() + echo_proc = EchoProcessor(llm) if echo else None + cat_proc = None + if category: + embedder = MagicMock() + embedder.embed.return_value = [0.1] * 384 + cat_proc = CategoryProcessor(llm=llm, embedder=embedder) + kg = KnowledgeGraph() if graph else None + pp = None + if profile: + db = MagicMock() + pp = ProfileProcessor(db=db, llm=llm) + return UnifiedEnrichmentProcessor( + llm=llm, + echo_processor=echo_proc, + category_processor=cat_proc, + knowledge_graph=kg, + profile_processor=pp, + ) + + +# --------------------------------------------------------------------------- +# Pydantic model tests +# --------------------------------------------------------------------------- + +class TestPydanticModels: + def test_echo_output_defaults(self): + echo = UnifiedEchoOutput() + assert echo.paraphrases == [] + assert echo.keywords == [] + assert echo.importance == 0.5 + + def test_echo_output_coerces_lists(self): + echo = UnifiedEchoOutput(paraphrases=None, keywords="single") + 
assert echo.paraphrases == [] + assert echo.keywords == ["single"] + + def test_echo_output_coerces_importance(self): + echo = UnifiedEchoOutput(importance="0.7") + assert echo.importance == 0.7 + + def test_category_output_defaults(self): + cat = UnifiedCategoryOutput() + assert cat.action == "use_existing" + assert cat.confidence == 0.5 + + def test_entity_output(self): + ent = UnifiedEntityOutput(name="Python", type="technology") + assert ent.name == "Python" + assert ent.type == "technology" + + def test_profile_output_coerces_lists(self): + prof = UnifiedProfileOutput(name="self", facts=None, preferences="coding") + assert prof.facts == [] + assert prof.preferences == ["coding"] + + def test_full_output_parsing(self): + data = json.loads(VALID_UNIFIED_RESPONSE) + output = UnifiedEnrichmentOutput.model_validate(data) + assert len(output.echo.paraphrases) == 2 + assert output.category.category_id == "preferences" + assert len(output.entities) == 1 + assert output.entities[0].name == "Python" + assert len(output.profiles) == 1 + + def test_output_extra_fields_ignored(self): + data = {"echo": {"extra_field": True, "paraphrases": ["a"], "keywords": ["b"], "importance": 0.5}, + "category": {}, "entities": [], "profiles": []} + output = UnifiedEnrichmentOutput.model_validate(data) + assert output.echo.paraphrases == ["a"] + + +# --------------------------------------------------------------------------- +# JSON parsing tests +# --------------------------------------------------------------------------- + +class TestJsonParsing: + def test_extract_json_blob_clean(self): + blob = _extract_json_blob('{"key": "value"}') + data = json.loads(blob) + assert data["key"] == "value" + + def test_extract_json_blob_code_fence(self): + response = '```json\n{"key": "value"}\n```' + blob = _extract_json_blob(response) + data = json.loads(blob) + assert data["key"] == "value" + + def test_extract_json_blob_with_preamble(self): + response = 'Here is the result:\n{"key": "value"}' + 
blob = _extract_json_blob(response) + data = json.loads(blob) + assert data["key"] == "value" + + def test_robust_json_load_trailing_comma(self): + text = '{"key": "value", "list": [1, 2, 3,],}' + data = _robust_json_load(text) + assert data["key"] == "value" + + def test_robust_json_load_comments(self): + text = '{"key": "value" // this is a comment\n}' + data = _robust_json_load(text) + assert data["key"] == "value" + + def test_normalize_echo_at_top_level(self): + data = {"paraphrases": ["a"], "keywords": ["b"], "importance": 0.5, "category": {}} + normalized = _normalize_unified_dict(data) + assert "echo" in normalized + assert normalized["echo"]["paraphrases"] == ["a"] + + +# --------------------------------------------------------------------------- +# Converter tests +# --------------------------------------------------------------------------- + +class TestConverters: + def test_to_echo_result_medium(self): + proc = _make_processor() + echo_out = UnifiedEchoOutput( + paraphrases=["rewrite"], + keywords=["key"], + implications=["impl"], + questions=["q?"], + question_form="What?", + category="fact", + importance=0.7, + ) + result = proc._to_echo_result(echo_out, "original", EchoDepth.MEDIUM) + assert isinstance(result, EchoResult) + assert result.echo_depth == EchoDepth.MEDIUM + assert result.paraphrases == ["rewrite"] + assert result.implications == [] # medium skips implications + assert result.questions == [] # medium skips questions + assert result.question_form == "What?" 
+ assert result.strength_multiplier == EchoProcessor.STRENGTH_MULTIPLIERS[EchoDepth.MEDIUM] + + def test_to_echo_result_deep(self): + proc = _make_processor() + echo_out = UnifiedEchoOutput( + paraphrases=["rewrite"], + keywords=["key"], + implications=["impl"], + questions=["q?"], + importance=0.7, + ) + result = proc._to_echo_result(echo_out, "original", EchoDepth.DEEP) + assert result.implications == ["impl"] + assert result.questions == ["q?"] + assert result.strength_multiplier == EchoProcessor.STRENGTH_MULTIPLIERS[EchoDepth.DEEP] + + def test_to_echo_result_question_form_from_questions(self): + proc = _make_processor() + echo_out = UnifiedEchoOutput( + paraphrases=["rewrite"], + keywords=["key"], + questions=["What is X?"], + importance=0.5, + ) + result = proc._to_echo_result(echo_out, "original", EchoDepth.DEEP) + assert result.question_form == "What is X?" + + def test_to_category_match_use_existing(self): + proc = _make_processor() + cat_out = UnifiedCategoryOutput( + action="use_existing", + category_id="preferences", + confidence=0.9, + ) + match = proc._to_category_match(cat_out) + assert isinstance(match, CategoryMatch) + assert match.category_id == "preferences" + assert match.confidence == 0.9 + assert not match.is_new + + def test_to_category_match_create_new(self): + proc = _make_processor() + cat_out = UnifiedCategoryOutput( + action="create_new", + new_category={"name": "Hobbies", "description": "User hobbies", "keywords": ["hobby"]}, + confidence=0.7, + ) + match = proc._to_category_match(cat_out) + assert isinstance(match, CategoryMatch) + assert match.is_new + assert match.confidence == 0.7 + + def test_to_category_match_fallback(self): + proc = _make_processor() + cat_out = UnifiedCategoryOutput() # defaults + match = proc._to_category_match(cat_out) + assert match.category_id == "context" + + def test_to_entities(self): + proc = _make_processor() + entity_outs = [ + UnifiedEntityOutput(name="Python", type="technology"), + 
UnifiedEntityOutput(name="Alice", type="person"), + UnifiedEntityOutput(name="", type="unknown"), # should be filtered + ] + entities = proc._to_entities(entity_outs) + assert len(entities) == 2 + assert entities[0].name == "Python" + assert entities[0].entity_type == EntityType.TECHNOLOGY + assert entities[1].entity_type == EntityType.PERSON + + def test_to_entities_invalid_type(self): + proc = _make_processor() + entity_outs = [UnifiedEntityOutput(name="Foo", type="invalid_type")] + entities = proc._to_entities(entity_outs) + assert len(entities) == 1 + assert entities[0].entity_type == EntityType.UNKNOWN + + def test_to_profile_updates(self): + proc = _make_processor() + profile_outs = [ + UnifiedProfileOutput(name="self", type="self", facts=["Name: Alice"], preferences=["Python"]), + UnifiedProfileOutput(name="", type="contact"), # should be filtered + ] + updates = proc._to_profile_updates(profile_outs) + assert len(updates) == 1 + assert isinstance(updates[0], ProfileUpdate) + assert updates[0].profile_name == "self" + assert updates[0].new_facts == ["Name: Alice"] + assert updates[0].new_preferences == ["Python"] + + +# --------------------------------------------------------------------------- +# Enrichment flow tests +# --------------------------------------------------------------------------- + +class TestEnrichFlow: + def test_enrich_single_valid(self): + proc = _make_processor() + result = proc.enrich("I prefer Python for backend development", EchoDepth.MEDIUM) + assert isinstance(result, EnrichmentResult) + assert result.echo_result is not None + assert result.echo_result.echo_depth == EchoDepth.MEDIUM + assert result.category_match is not None + assert len(result.entities) >= 1 + assert result.raw_response == VALID_UNIFIED_RESPONSE + + def test_enrich_fallback_on_invalid_json(self): + llm = _make_mock_llm("this is not json at all!!!") + proc = _make_processor(llm=llm) + # Should fall back to individual processors + result = proc.enrich("I prefer 
Python", EchoDepth.MEDIUM) + assert isinstance(result, EnrichmentResult) + # Fallback calls echo_processor.process which also calls the broken LLM, + # but EchoProcessor has its own fallback to shallow + # The important thing is no exception is raised + + def test_enrich_with_code_fenced_response(self): + response = f"```json\n{VALID_UNIFIED_RESPONSE}\n```" + llm = _make_mock_llm(response) + proc = _make_processor(llm=llm) + result = proc.enrich("I prefer Python", EchoDepth.MEDIUM) + assert result.echo_result is not None + assert result.category_match is not None + + +class TestEnrichBatch: + def test_batch_single_item(self): + proc = _make_processor() + results = proc.enrich_batch(["I prefer Python"], EchoDepth.MEDIUM) + assert len(results) == 1 + assert results[0].echo_result is not None + + def test_batch_multiple_items(self): + llm = _make_mock_llm(VALID_BATCH_RESPONSE) + proc = _make_processor(llm=llm) + results = proc.enrich_batch( + ["I prefer Python", "I work at Acme"], + EchoDepth.MEDIUM, + ) + assert len(results) == 2 + assert results[0].echo_result is not None + assert results[1].echo_result is not None + assert results[0].entities[0].name == "Python" + assert results[1].entities[0].name == "Acme" + + def test_batch_empty(self): + proc = _make_processor() + results = proc.enrich_batch([], EchoDepth.MEDIUM) + assert results == [] + + def test_batch_partial_failure(self): + """If one item fails in batch, it falls back to individual.""" + response = json.dumps({ + "results": [ + { + "index": 0, + "echo": {"paraphrases": ["ok"], "keywords": ["ok"], "importance": 0.5}, + "category": {"action": "use_existing", "category_id": "facts", "confidence": 0.5}, + "entities": [], + "profiles": [], + }, + # Index 1 is missing — should trigger fallback for that item + ] + }) + llm = MagicMock() + # First call returns batch response, subsequent calls return single response + llm.generate.side_effect = [response, VALID_UNIFIED_RESPONSE] + proc = _make_processor(llm=llm) + 
results = proc.enrich_batch(["Memory 1", "Memory 2"], EchoDepth.MEDIUM) + assert len(results) == 2 + + +# --------------------------------------------------------------------------- +# Fallback tests +# --------------------------------------------------------------------------- + +class TestFallback: + def test_fallback_calls_individual_processors(self): + llm = _make_mock_llm() + proc = _make_processor(llm=llm, echo=True, category=True, graph=True) + # Force JSON parse failure by passing garbage + result = proc._fallback_individual("I prefer Python", EchoDepth.MEDIUM, None) + assert isinstance(result, EnrichmentResult) + # Echo result should come from individual processor + # (it will call llm.generate which returns our valid response, + # but that's the echo prompt format not unified — it may or may not parse, + # the important thing is no exception bubbles up) + + +# --------------------------------------------------------------------------- +# Config toggle tests +# --------------------------------------------------------------------------- + +class TestConfigToggle: + def test_enrichment_config_defaults(self): + from engram.configs.base import EnrichmentConfig + config = EnrichmentConfig() + assert config.enable_unified is False + assert config.fallback_to_individual is True + assert config.include_entities is True + assert config.include_profiles is True + assert config.max_batch_size == 10 + + def test_memory_config_has_enrichment(self): + from engram.configs.base import MemoryConfig + config = MemoryConfig() + assert hasattr(config, "enrichment") + assert config.enrichment.enable_unified is False + + def test_full_preset_enables_unified(self): + from engram.configs.base import MemoryConfig + config = MemoryConfig.full() + assert config.enrichment.enable_unified is True + + def test_minimal_preset_unified_disabled(self): + from engram.configs.base import MemoryConfig + config = MemoryConfig.minimal() + assert config.enrichment.enable_unified is False + + +# 
--------------------------------------------------------------------------- +# Prompt generation tests +# --------------------------------------------------------------------------- + +class TestPromptGeneration: + def test_single_prompt_contains_content(self): + proc = _make_processor() + prompt = proc._build_prompt("I love Python", EchoDepth.MEDIUM) + assert "I love Python" in prompt + assert "MEMORY:" in prompt + assert "ECHO DEPTH: medium" in prompt + + def test_single_prompt_with_categories(self): + proc = _make_processor() + prompt = proc._build_prompt("test", EchoDepth.MEDIUM, "- preferences: User Preferences") + assert "preferences: User Preferences" in prompt + + def test_batch_prompt_contains_all_memories(self): + proc = _make_processor() + prompt = proc._build_batch_prompt( + ["Memory A", "Memory B", "Memory C"], + EchoDepth.DEEP, + ) + assert "[0] Memory A" in prompt + assert "[1] Memory B" in prompt + assert "[2] Memory C" in prompt + assert "ECHO DEPTH: deep" in prompt + + def test_prompt_entity_toggle(self): + proc = _make_processor() + prompt_yes = proc._build_prompt("test", EchoDepth.MEDIUM, include_entities=True) + prompt_no = proc._build_prompt("test", EchoDepth.MEDIUM, include_entities=False) + assert "INCLUDE ENTITIES: yes" in prompt_yes + assert "INCLUDE ENTITIES: no" in prompt_no + + +# --------------------------------------------------------------------------- +# Integration: EnrichmentResult metadata schema +# --------------------------------------------------------------------------- + +class TestEnrichmentResultSchema: + def test_echo_result_to_metadata(self): + """Unified echo result produces same metadata keys as EchoProcessor.""" + proc = _make_processor() + result = proc.enrich("I prefer Python", EchoDepth.MEDIUM) + assert result.echo_result is not None + metadata = result.echo_result.to_metadata() + expected_keys = { + "echo_paraphrases", "echo_keywords", "echo_implications", + "echo_questions", "echo_question_form", "echo_category", 
+ "echo_importance", "echo_depth", + } + assert set(metadata.keys()) == expected_keys + assert metadata["echo_depth"] == "medium" From ed56a57495fb540738a515e539d44c63f7c58789 Mon Sep 17 00:00:00 2001 From: Vivek Kumar Date: Thu, 19 Feb 2026 11:59:53 +0530 Subject: [PATCH 6/8] feat: fact decomposition, echo-augmented embeddings, and batch enrichment optimizations --- engram/benchmarks/longmemeval.py | 51 ++++-- engram/configs/base.py | 3 + engram/core/enrichment.py | 6 +- engram/memory/main.py | 293 +++++++++++++++++++++++++++---- engram/utils/prompts.py | 20 ++- 5 files changed, 324 insertions(+), 49 deletions(-) diff --git a/engram/benchmarks/longmemeval.py b/engram/benchmarks/longmemeval.py index 380d981..41c50f1 100644 --- a/engram/benchmarks/longmemeval.py +++ b/engram/benchmarks/longmemeval.py @@ -18,6 +18,7 @@ from engram import FullMemory as Memory from engram.configs.base import ( + BatchConfig, CategoryMemConfig, EchoMemConfig, EmbedderConfig, @@ -155,7 +156,7 @@ def build_memory( "embedding_model_dims": embedding_dims, } - llm_cfg: Dict[str, Any] = {} + llm_cfg: Dict[str, Any] = {"max_tokens": 8192, "timeout": 300, "model": "meta/llama-3.3-70b-instruct"} if llm_model: llm_cfg["model"] = llm_model embedder_cfg: Dict[str, Any] = {"embedding_dims": embedding_dims} @@ -168,12 +169,13 @@ def build_memory( embedder=EmbedderConfig(provider=embedder_provider, config=embedder_cfg), history_db_path=history_db_path, embedding_model_dims=embedding_dims, - echo=EchoMemConfig(enable_echo=full_potential), + echo=EchoMemConfig(enable_echo=full_potential, default_depth="deep"), category=CategoryMemConfig(use_llm_categorization=full_potential, enable_categories=full_potential), graph=KnowledgeGraphConfig(enable_graph=full_potential), scene=SceneConfig(use_llm_summarization=full_potential, enable_scenes=full_potential), profile=ProfileConfig(use_llm_extraction=full_potential, enable_profiles=full_potential), - enrichment=EnrichmentConfig(enable_unified=full_potential), + 
enrichment=EnrichmentConfig(enable_unified=full_potential, max_batch_size=10), + batch=BatchConfig(enable_batch=full_potential, max_batch_size=50), ) mem = Memory(config) # FullMemory features (categories, scenes, profiles) need FullSQLiteManager @@ -270,28 +272,49 @@ def run_longmemeval(args: argparse.Namespace) -> Dict[str, Any]: session_ids = entry.get("haystack_session_ids") or [] session_dates = entry.get("haystack_dates") or [] sessions = entry.get("haystack_sessions") or [] + + # Build batch items for all sessions + batch_items = [] for sess_id, sess_date, sess_turns in zip(session_ids, session_dates, sessions): payload = format_session_memory(str(sess_id), str(sess_date), sess_turns or []) + batch_items.append({ + "content": payload, + "metadata": { + "session_id": str(sess_id), + "session_date": str(sess_date), + "question_id": question_id, + }, + "categories": ["longmemeval", "session"], + }) + + # Use add_batch for fewer LLM calls; fallback to sequential on failure + if batch_items: try: - memory.add( - messages=payload, + memory.add_batch( + items=batch_items, user_id=args.user_id, - metadata={ - "session_id": str(sess_id), - "session_date": str(sess_date), - "question_id": question_id, - }, - categories=["longmemeval", "session"], - infer=False, ) except Exception as e: - logger.warning("Skipping session %s for question %s: %s", sess_id, question_id, e) + logger.warning("Batch add failed for question %s, retrying sequentially: %s", question_id, e) + for item in batch_items: + try: + memory.add( + messages=item["content"], + user_id=args.user_id, + metadata=item["metadata"], + categories=item["categories"], + infer=False, + ) + except Exception as e2: + logger.warning("Skipping session for question %s: %s", question_id, e2) query = str(entry.get("question", "")).strip() search_payload = memory.search( query=query, user_id=args.user_id, limit=args.top_k, + keyword_search=True, + hybrid_alpha=0.7, ) results = search_payload.get("results", []) @@ -383,7 
+406,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--max-questions", type=int, default=-1, help="Cap number of evaluated questions.") parser.add_argument("--skip-abstention", action="store_true", help="Skip *_abs questions.") - parser.add_argument("--top-k", type=int, default=8, help="Number of retrieved memories for context.") + parser.add_argument("--top-k", type=int, default=20, help="Number of retrieved memories for context.") parser.add_argument("--max-context-chars", type=int, default=12000, help="Maximum context size passed to reader.") parser.add_argument("--print-every", type=int, default=25, help="Progress print interval.") diff --git a/engram/configs/base.py b/engram/configs/base.py index ed9d023..318c543 100644 --- a/engram/configs/base.py +++ b/engram/configs/base.py @@ -88,6 +88,9 @@ class EchoMemConfig(BaseModel): deep_multiplier: float = 1.6 # Use question_form embedding for primary vector (better query matching) use_question_embedding: bool = True + # Echo-augmented embedding: compose primary text from content + echo data + # (question_form, keywords, first paraphrase) for richer retrieval vectors + use_echo_augmented_embedding: bool = True @field_validator("default_depth") @classmethod diff --git a/engram/core/enrichment.py b/engram/core/enrichment.py index 885dd19..d704e49 100644 --- a/engram/core/enrichment.py +++ b/engram/core/enrichment.py @@ -140,6 +140,7 @@ class UnifiedEnrichmentOutput(BaseModel): category: UnifiedCategoryOutput = Field(default_factory=UnifiedCategoryOutput) entities: List[UnifiedEntityOutput] = [] profiles: List[UnifiedProfileOutput] = [] + facts: List[str] = [] # --------------------------------------------------------------------------- @@ -154,6 +155,7 @@ class EnrichmentResult: category_match: Optional[CategoryMatch] = None entities: List[Entity] = field(default_factory=list) profile_updates: List[ProfileUpdate] = field(default_factory=list) + facts: List[str] = field(default_factory=list) 
raw_response: str = "" @@ -284,7 +286,7 @@ def _build_batch_prompt( cats = existing_categories or self._format_existing_categories() depth_instructions = _DEPTH_INSTRUCTIONS.get(depth, _DEPTH_INSTRUCTIONS[EchoDepth.MEDIUM]) memories_block = "\n".join( - f"[{i}] {c[:500]}" for i, c in enumerate(contents) + f"[{i}] {c[:2000]}" for i, c in enumerate(contents) ) return UNIFIED_ENRICHMENT_BATCH_PROMPT.format( memories_block=memories_block, @@ -333,6 +335,7 @@ def _parse_response( category_match=self._to_category_match(unified.category), entities=self._to_entities(unified.entities), profile_updates=self._to_profile_updates(unified.profiles), + facts=[f for f in unified.facts if isinstance(f, str) and f.strip()], raw_response=response, ) @@ -367,6 +370,7 @@ def _parse_batch_response( category_match=self._to_category_match(unified.category), entities=self._to_entities(unified.entities), profile_updates=self._to_profile_updates(unified.profiles), + facts=[f for f in unified.facts if isinstance(f, str) and f.strip()], )) continue except Exception: diff --git a/engram/memory/main.py b/engram/memory/main.py index d0ee8d5..d58ea57 100644 --- a/engram/memory/main.py +++ b/engram/memory/main.py @@ -690,38 +690,90 @@ def _process_memory_batch( item_meta.update(item.get("metadata") or {}) item_metadata_list.append(item_meta) - # 1. Batch echo encoding + # 0. 
Try unified enrichment (single LLM call for echo+category+entities+profiles) echo_results = [None] * len(contents) - if self.echo_processor and self.echo_config.enable_echo and batch_config.batch_echo: - try: - depth_override = EchoDepth(echo_depth) if echo_depth else None - echo_results = self.echo_processor.process_batch( - contents, depth=depth_override - ) - except Exception as e: - logger.warning("Batch echo failed, processing individually: %s", e) - for i, c in enumerate(contents): - if c: - try: - depth_override = EchoDepth(echo_depth) if echo_depth else None - echo_results[i] = self.echo_processor.process(c, depth=depth_override) - except Exception: - pass - - # 2. Batch category detection category_results = [None] * len(contents) - if ( - self.category_processor - and self.category_config.auto_categorize - and batch_config.batch_category - ): + enrichment_results = [None] * len(contents) # stash for post-store hooks + + enrichment_config = getattr(self.config, "enrichment", None) + _use_unified = ( + self.unified_enrichment is not None + and self.echo_config.enable_echo + and batch_config.batch_echo + ) + + if _use_unified: try: - category_results = self.category_processor.detect_categories_batch( - contents, - use_llm=self.category_config.use_llm_categorization, - ) + depth_override = EchoDepth(echo_depth) if echo_depth else EchoDepth(self.echo_config.default_depth) + existing_cats = None + if self.category_processor: + cats = self.category_processor.get_all_categories() + if cats: + existing_cats = "\n".join( + f"- {c['id']}: {c['name']} — {c.get('description', '')}" + for c in cats[:30] + ) + + # Process in sub-batches of enrichment_config.max_batch_size + enrich_batch_size = enrichment_config.max_batch_size if enrichment_config else 10 + for start in range(0, len(contents), enrich_batch_size): + end = min(start + enrich_batch_size, len(contents)) + sub_contents = contents[start:end] + sub_results = self.unified_enrichment.enrich_batch( + sub_contents, 
+ depth=depth_override, + existing_categories=existing_cats, + include_entities=enrichment_config.include_entities if enrichment_config else True, + include_profiles=enrichment_config.include_profiles if enrichment_config else True, + ) + for j, enrichment in enumerate(sub_results): + idx = start + j + if enrichment.echo_result: + echo_results[idx] = enrichment.echo_result + if enrichment.category_match: + category_results[idx] = enrichment.category_match + enrichment_results[idx] = enrichment + + logger.info("Unified batch enrichment completed for %d memories", len(contents)) except Exception as e: - logger.warning("Batch category failed: %s", e) + logger.warning("Unified batch enrichment failed, falling back to separate: %s", e) + # Reset — let the fallback below handle it + echo_results = [None] * len(contents) + category_results = [None] * len(contents) + enrichment_results = [None] * len(contents) + _use_unified = False + + # 1. Batch echo encoding (fallback if unified was not used or failed) + if not _use_unified: + if self.echo_processor and self.echo_config.enable_echo and batch_config.batch_echo: + try: + depth_override = EchoDepth(echo_depth) if echo_depth else None + echo_results = self.echo_processor.process_batch( + contents, depth=depth_override + ) + except Exception as e: + logger.warning("Batch echo failed, processing individually: %s", e) + for i, c in enumerate(contents): + if c: + try: + depth_override = EchoDepth(echo_depth) if echo_depth else None + echo_results[i] = self.echo_processor.process(c, depth=depth_override) + except Exception: + pass + + # 2. Batch category detection + if ( + self.category_processor + and self.category_config.auto_categorize + and batch_config.batch_category + ): + try: + category_results = self.category_processor.detect_categories_batch( + contents, + use_llm=self.category_config.use_llm_categorization, + ) + except Exception as e: + logger.warning("Batch category failed: %s", e) # 3. 
Batch embeddings primary_texts = [] @@ -731,7 +783,11 @@ def _process_memory_batch( if batch_config.batch_embed: try: - embeddings = self.embedder.embed_batch(primary_texts, memory_action="add") + # Sub-batch to stay within API limits (~50 per call) + embeddings: List[List[float]] = [] + for start in range(0, len(primary_texts), 50): + sub = primary_texts[start:start + 50] + embeddings.extend(self.embedder.embed_batch(sub, memory_action="add")) except Exception as e: logger.warning("Batch embed failed, falling back to sequential: %s", e) embeddings = [ @@ -742,6 +798,44 @@ def _process_memory_batch( self.embedder.embed(t, memory_action="add") for t in primary_texts ] + # 3b. Pre-embed all echo node texts (paraphrases, questions, content variants) + # so _build_index_vectors can use the cache instead of individual embed() calls. + echo_node_texts = [] + for i, content in enumerate(contents): + echo_result = echo_results[i] + pt = primary_texts[i] + if pt != content: + cleaned = content.strip() + if cleaned: + echo_node_texts.append(cleaned) + if echo_result: + for p in echo_result.paraphrases: + cleaned = str(p).strip() + if cleaned: + echo_node_texts.append(cleaned) + for q in echo_result.questions: + cleaned = str(q).strip() + if cleaned: + echo_node_texts.append(cleaned) + + embedding_cache: Dict[str, List[float]] = {} + if echo_node_texts: + # Deduplicate while preserving order for batch embedding + unique_texts = list(dict.fromkeys(echo_node_texts)) + try: + # Sub-batch to stay within NVIDIA API limits (~50 per call) + all_echo_embeddings: List[List[float]] = [] + for start in range(0, len(unique_texts), 50): + sub = unique_texts[start:start + 50] + sub_embs = self.embedder.embed_batch(sub, memory_action="add") + all_echo_embeddings.extend(sub_embs) + for text, emb in zip(unique_texts, all_echo_embeddings): + embedding_cache[text] = emb + logger.info("Batch-embedded %d echo node texts in %d API calls", + len(unique_texts), (len(unique_texts) + 49) // 50) + 
except Exception as e: + logger.warning("Batch echo node embedding failed, will embed individually: %s", e) + # 4. Build memory records and batch-insert into DB processed_metadata_base, effective_filters = build_filters_and_metadata( user_id=user_id, @@ -838,6 +932,7 @@ def _process_memory_batch( agent_id=agent_id, run_id=run_id, app_id=app_id, + embedding_cache=embedding_cache if embedding_cache else None, ) if vectors: vector_batch.append((vectors, payloads, vector_ids)) @@ -870,7 +965,7 @@ def _process_memory_batch( except Exception as e: logger.error("Vector insert failed in batch: %s", e) - # Post-store hooks + # Post-store hooks: category stats for i, record in enumerate(memory_records): if self.category_processor and record.get("categories"): for cat_id in record["categories"]: @@ -878,6 +973,79 @@ def _process_memory_batch( cat_id, record["strength"], is_addition=True ) + # Post-store hooks: fact decomposition (batch embed + insert) + all_fact_texts = [] + all_fact_meta = [] # (memory_id, fact_index) + for i, record in enumerate(memory_records): + enrichment = enrichment_results[i] if i < len(enrichment_results) else None + if enrichment and enrichment.facts: + for fi, fact_text in enumerate(enrichment.facts[:8]): + fact_text = fact_text.strip() + if fact_text and len(fact_text) >= 10: + all_fact_texts.append(fact_text) + all_fact_meta.append((record["id"], fi)) + + if all_fact_texts: + try: + # Sub-batch fact embeddings to stay within API limits + fact_embeddings: List[List[float]] = [] + for fs in range(0, len(all_fact_texts), 50): + sub = all_fact_texts[fs:fs + 50] + fact_embeddings.extend(self.embedder.embed_batch(sub, memory_action="add")) + fact_vectors = [] + fact_payloads = [] + fact_ids = [] + for (memory_id, fi), fact_text, fact_emb in zip(all_fact_meta, all_fact_texts, fact_embeddings): + fact_id = f"{memory_id}__fact_{fi}" + fact_vectors.append(fact_emb) + fact_payloads.append({ + "memory_id": memory_id, + "is_fact": True, + "fact_index": fi, + 
"fact_text": fact_text, + "user_id": user_id, + "agent_id": agent_id, + }) + fact_ids.append(fact_id) + if fact_vectors: + self.vector_store.insert(vectors=fact_vectors, payloads=fact_payloads, ids=fact_ids) + except Exception as e: + logger.warning("Batch fact embedding/insert failed: %s", e) + + # Post-store hooks: entity linking and profile updates + for i, record in enumerate(memory_records): + enrichment = enrichment_results[i] if i < len(enrichment_results) else None + if not enrichment: + continue + memory_id = record["id"] + content = record.get("memory", "") + + if self.knowledge_graph and enrichment.entities: + try: + for entity in enrichment.entities: + existing_ent = self.knowledge_graph._get_or_create_entity( + entity.name, entity.entity_type, + ) + existing_ent.memory_ids.add(memory_id) + self.knowledge_graph.memory_entities[memory_id] = { + e.name for e in enrichment.entities + } + if self.graph_config.auto_link_entities: + self.knowledge_graph.link_by_shared_entities(memory_id) + except Exception as e: + logger.warning("Entity linking failed for %s: %s", memory_id, e) + + if self.profile_processor and enrichment.profile_updates: + try: + for profile_update in enrichment.profile_updates: + self.profile_processor.apply_update( + profile_update=profile_update, + memory_id=memory_id, + user_id=record.get("user_id") or user_id or "default", + ) + except Exception as e: + logger.warning("Profile update failed for %s: %s", memory_id, e) + return results def _resolve_memory_metadata( @@ -1064,6 +1232,7 @@ def _process_single_memory( # Pre-extracted data from unified enrichment (used to skip redundant post-store calls) _unified_entities = None # List[Entity] or None _unified_profiles = None # List[ProfileUpdate] or None + _unified_facts = None # List[str] or None # Determine echo depth for unified path check _depth_for_echo = EchoDepth(echo_depth) if echo_depth else None @@ -1115,9 +1284,10 @@ def _process_single_memory( mem_metadata["category_confidence"] = 
enrichment.category_match.confidence mem_metadata["category_auto"] = True - # Stash entities + profiles for post-store hooks + # Stash entities + profiles + facts for post-store hooks _unified_entities = enrichment.entities _unified_profiles = enrichment.profile_updates + _unified_facts = enrichment.facts # Generate embedding primary_text = self._select_primary_text(content, echo_result) @@ -1342,6 +1512,40 @@ def _do_category(): ) raise + # Fact decomposition: store each extracted fact as a sub-vector for direct retrieval. + # Each fact gets its own embedding, linked back to the parent memory. + # Uses batch embedding (single API call) for efficiency. + if _unified_facts: + valid_facts = [] + for i, fact_text in enumerate(_unified_facts[:8]): # Cap at 8 facts + fact_text = fact_text.strip() + if fact_text and len(fact_text) >= 10: + valid_facts.append((i, fact_text)) + + if valid_facts: + try: + fact_texts = [ft for _, ft in valid_facts] + fact_embeddings = self.embedder.embed_batch(fact_texts, memory_action="add") + fact_vectors = [] + fact_payloads = [] + fact_ids = [] + for (i, fact_text), fact_embedding in zip(valid_facts, fact_embeddings): + fact_id = f"{effective_memory_id}__fact_{i}" + fact_vectors.append(fact_embedding) + fact_payloads.append({ + "memory_id": effective_memory_id, + "is_fact": True, + "fact_index": i, + "fact_text": fact_text, + "user_id": user_id, + "agent_id": store_agent_id, + }) + fact_ids.append(fact_id) + if fact_vectors: + self.vector_store.insert(vectors=fact_vectors, payloads=fact_payloads, ids=fact_ids) + except Exception as e: + logger.warning("Fact embedding/insert failed for %s: %s", effective_memory_id, e) + # Post-store hooks. 
if self.category_processor and mem_categories: for cat_id in mem_categories: @@ -2367,7 +2571,23 @@ def _classify_memory_type(self, metadata: Dict[str, Any], role: str) -> str: return "semantic" def _select_primary_text(self, content: str, echo_result: Optional[EchoResult]) -> str: - if self.echo_config.use_question_embedding and echo_result and echo_result.question_form: + if not echo_result: + return content + + # Echo-augmented embedding: compose content + echo data for richer vectors. + # Multiple retrieval paths in one embedding — like the brain's multi-path access. + if self.echo_config.use_echo_augmented_embedding: + parts = [content[:1500]] # Keep original content (capped to leave room) + if echo_result.question_form: + parts.append(echo_result.question_form) + if echo_result.keywords: + parts.append("Keywords: " + ", ".join(echo_result.keywords[:10])) + if echo_result.paraphrases: + parts.append(echo_result.paraphrases[0]) + return "\n".join(parts) + + # Legacy: replace content with question_form only + if self.echo_config.use_question_embedding and echo_result.question_form: return echo_result.question_form return content @@ -2522,6 +2742,7 @@ def _build_index_vectors( agent_id: Optional[str], run_id: Optional[str], app_id: Optional[str], + embedding_cache: Optional[Dict[str, List[float]]] = None, ) -> tuple[List[List[float]], List[Dict[str, Any]], List[str]]: base_payload = dict(metadata) base_payload.update( @@ -2571,7 +2792,13 @@ def add_node( if echo_result and echo_result.category: payload["category"] = echo_result.category - vectors.append(vector if vector is not None else self.embedder.embed(cleaned, memory_action="add")) + if vector is not None: + emb = vector + elif embedding_cache is not None and cleaned in embedding_cache: + emb = embedding_cache[cleaned] + else: + emb = self.embedder.embed(cleaned, memory_action="add") + vectors.append(emb) payloads.append(payload) vector_ids.append(node_id or str(uuid.uuid4())) diff --git 
a/engram/utils/prompts.py b/engram/utils/prompts.py index 2e94274..4f10950 100644 --- a/engram/utils/prompts.py +++ b/engram/utils/prompts.py @@ -238,15 +238,27 @@ ], "profiles": [ {{"name": "person name", "type": "self|contact|entity", "facts": ["fact"], "preferences": ["pref"]}} + ], + "facts": [ + "Atomic, self-contained fact 1 extracted from the memory", + "Atomic, self-contained fact 2 extracted from the memory" ] }} Rules: - Follow ECHO INSTRUCTIONS for which echo fields to populate +- For paraphrases: ensure EVERY distinct factual claim gets at least one paraphrase. Do NOT only rephrase the main topic — also rephrase secondary/minor details (e.g. if the memory mentions a degree, include a paraphrase about the degree even if the main topic is task management) +- For questions: generate questions that each factual claim in the memory ANSWERS. Example: memory says "graduated with an MBA" → include "What degree did the user graduate with?" Each fact should have a corresponding question. - For category: prefer use_existing when an existing category fits well - For entities: extract named entities (people, tech, orgs, tools) - For profiles: extract person mentions with their facts/preferences - If INCLUDE ENTITIES or INCLUDE PROFILES is "no", return empty arrays for those +- For facts: extract ALL distinct, searchable facts from the memory as standalone statements + - Each fact must be self-contained (understandable without the original context) + - Use third person ("User graduated with MBA" not "I graduated with MBA") + - Include specific details: names, places, numbers, dates + - Extract 3-8 facts per memory (more for longer/richer content) + - Facts should be diverse — each captures a DIFFERENT piece of information """ UNIFIED_ENRICHMENT_BATCH_PROMPT = """You are enriching multiple memories for a long-term AI memory system. 
@@ -283,11 +295,17 @@ "confidence": 0.0-1.0 }}, "entities": [{{"name": "entity name", "type": "person|technology|..."}}], - "profiles": [{{"name": "person name", "type": "self|contact|entity", "facts": [], "preferences": []}}] + "profiles": [{{"name": "person name", "type": "self|contact|entity", "facts": [], "preferences": []}}], + "facts": ["Atomic self-contained fact 1", "Atomic self-contained fact 2"] }} ] }} +Rules: +- For paraphrases: ensure EVERY distinct factual claim gets at least one paraphrase. Do NOT only rephrase the main topic — also rephrase secondary/minor details. +- For questions: generate questions that each factual claim ANSWERS. Example: "graduated with an MBA" → "What degree did the user graduate with?" +- For facts: extract ALL distinct, searchable facts as standalone statements. Use third person. Include specifics (names, places, dates). 3-8 facts per memory. + IMPORTANT: Return exactly {count} elements in the results array, one per memory, in the same order. """ From e18aed4dd7257b4fcc1c762e365cbbcd1e4cd25a Mon Sep 17 00:00:00 2001 From: Ashish-dwi99 Date: Thu, 19 Feb 2026 14:22:34 +0530 Subject: [PATCH 7/8] new changes --- engram-bus/tests/__init__.py | 0 .../engram_enterprise/api/app.py | 43 +- engram-enterprise/tests/__init__.py | 0 engram/benchmarks/longmemeval.py | 89 ++- engram/configs/base.py | 6 +- engram/db/sqlite.py | 99 +++- engram/mcp_server.py | 42 +- engram/memory/main.py | 433 ++++++++++++++ engram/observability.py | 53 +- pytest.ini | 11 + tests/__init__.py | 0 tests/longmemeval_test.json | 33 ++ tests/test_deferred_enrichment.py | 529 ++++++++++++++++++ tests/test_e2e_all_features.py | 11 + tests/test_mcp_tools_slim.py | 41 +- tests/test_power_packages.py | 9 + 16 files changed, 1344 insertions(+), 55 deletions(-) delete mode 100644 engram-bus/tests/__init__.py delete mode 100644 engram-enterprise/tests/__init__.py create mode 100644 pytest.ini delete mode 100644 tests/__init__.py create mode 100644 
tests/longmemeval_test.json create mode 100644 tests/test_deferred_enrichment.py diff --git a/engram-bus/tests/__init__.py b/engram-bus/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/engram-enterprise/engram_enterprise/api/app.py b/engram-enterprise/engram_enterprise/api/app.py index be33c46..402dcb5 100644 --- a/engram-enterprise/engram_enterprise/api/app.py +++ b/engram-enterprise/engram_enterprise/api/app.py @@ -14,7 +14,8 @@ from fastapi.staticfiles import StaticFiles from pydantic import BaseModel, Field -from engram import Memory +from engram.configs.base import EmbedderConfig, LLMConfig, MemoryConfig, VectorStoreConfig +from engram.memory.main import Memory from engram_enterprise.api.auth import ( enforce_session_issuer, get_token_from_request, @@ -41,6 +42,7 @@ SessionCreateResponse, ) from engram_enterprise.policy import feature_enabled +from engram_enterprise.kernel import PersonalMemoryKernel from engram.exceptions import FadeMemValidationError from engram.observability import add_metrics_routes, logger as structured_logger, metrics @@ -103,6 +105,28 @@ class DecayResponse(BaseModel): _memory: Optional[Memory] = None _memory_lock = threading.Lock() +_kernel: Optional[PersonalMemoryKernel] = None +_kernel_lock = threading.Lock() + + +def _fallback_memory_config() -> MemoryConfig: + data_dir = os.path.join(os.path.expanduser("~"), ".engram") + os.makedirs(data_dir, exist_ok=True) + dims = 384 + return MemoryConfig( + llm=LLMConfig(provider="mock", config={}), + embedder=EmbedderConfig(provider="simple", config={"embedding_dims": dims}), + vector_store=VectorStoreConfig( + provider="memory", + config={ + "collection_name": "engram_enterprise", + "embedding_model_dims": dims, + }, + ), + history_db_path=os.path.join(data_dir, "enterprise_history.db"), + collection_name="engram_enterprise", + embedding_model_dims=dims, + ) def get_memory() -> Memory: @@ -110,12 +134,25 @@ def get_memory() -> Memory: if _memory is None: with 
_memory_lock: if _memory is None: - _memory = Memory() + try: + _memory = Memory() + except Exception as exc: + logger.warning( + "Failed to initialize default Memory config (%s). " + "Falling back to mock/simple in-memory configuration.", + exc, + ) + _memory = Memory(config=_fallback_memory_config()) return _memory def get_kernel(): - return get_memory().kernel + global _kernel + if _kernel is None: + with _kernel_lock: + if _kernel is None: + _kernel = PersonalMemoryKernel(get_memory()) + return _kernel def _extract_content(messages: Optional[Union[str, List[Dict[str, Any]]]], content: Optional[str]) -> str: diff --git a/engram-enterprise/tests/__init__.py b/engram-enterprise/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/engram/benchmarks/longmemeval.py b/engram/benchmarks/longmemeval.py index 41c50f1..ccb7228 100644 --- a/engram/benchmarks/longmemeval.py +++ b/engram/benchmarks/longmemeval.py @@ -43,14 +43,27 @@ def extract_user_only_text(session_turns: Sequence[Dict[str, Any]]) -> str: return "\n".join([line for line in lines if line]) -def format_session_memory(session_id: str, session_date: str, session_turns: Sequence[Dict[str, Any]]) -> str: - """Create a memory payload that preserves session metadata in plain text.""" - user_text = extract_user_only_text(session_turns) +def format_session_memory(session_id: str, session_date: str, session_turns: Sequence[Dict[str, Any]], include_all_roles: bool = False) -> str: + """Create a memory payload that preserves session metadata in plain text. + + When include_all_roles=True, includes both user and assistant turns + for richer context in deferred enrichment mode. 
+ """ + if include_all_roles: + all_text = [] + for turn in session_turns: + role = turn.get("role", "user") + content = str(turn.get("content", "")).strip() + if content: + all_text.append(f"{role}: {content}") + full_text = "\n".join(all_text) + else: + full_text = extract_user_only_text(session_turns) return ( f"Session ID: {session_id}\n" f"Session Date: {session_date}\n" f"{HISTORY_HEADER}\n" - f"{user_text}" + f"{full_text}" ) @@ -149,8 +162,13 @@ def build_memory( llm_model: Optional[str] = None, embedder_model: Optional[str] = None, full_potential: bool = True, + defer_enrichment: bool = False, ) -> Memory: - """Build Engram Memory for LongMemEval. By default uses full potential (echo, categories, graph, scenes, profiles).""" + """Build Engram Memory for LongMemEval. By default uses full potential (echo, categories, graph, scenes, profiles). + + When defer_enrichment=True, ingestion uses 0 LLM calls (store fast), and + enrichment is done in batch after all sessions are loaded. + """ vector_cfg: Dict[str, Any] = { "collection_name": "engram_longmemeval", "embedding_model_dims": embedding_dims, @@ -174,8 +192,12 @@ def build_memory( graph=KnowledgeGraphConfig(enable_graph=full_potential), scene=SceneConfig(use_llm_summarization=full_potential, enable_scenes=full_potential), profile=ProfileConfig(use_llm_extraction=full_potential, enable_profiles=full_potential), - enrichment=EnrichmentConfig(enable_unified=full_potential, max_batch_size=10), - batch=BatchConfig(enable_batch=full_potential, max_batch_size=50), + enrichment=EnrichmentConfig( + enable_unified=full_potential, + max_batch_size=10, + defer_enrichment=defer_enrichment, + ), + batch=BatchConfig(enable_batch=full_potential and not defer_enrichment, max_batch_size=50), ) mem = Memory(config) # FullMemory features (categories, scenes, profiles) need FullSQLiteManager @@ -234,6 +256,7 @@ def run_longmemeval(args: argparse.Namespace) -> Dict[str, Any]: if args.skip_abstention: selected = [entry for entry 
in selected if "_abs" not in str(entry.get("question_id", ""))] + use_deferred = getattr(args, "defer_enrichment", False) memory = build_memory( llm_provider=args.llm_provider, embedder_provider=args.embedder_provider, @@ -243,6 +266,7 @@ def run_longmemeval(args: argparse.Namespace) -> Dict[str, Any]: llm_model=args.llm_model, embedder_model=args.embedder_model, full_potential=args.full_potential, + defer_enrichment=use_deferred, ) hf_responder: Optional[HFResponder] = None @@ -276,7 +300,17 @@ def run_longmemeval(args: argparse.Namespace) -> Dict[str, Any]: # Build batch items for all sessions batch_items = [] for sess_id, sess_date, sess_turns in zip(session_ids, session_dates, sessions): - payload = format_session_memory(str(sess_id), str(sess_date), sess_turns or []) + payload = format_session_memory( + str(sess_id), str(sess_date), sess_turns or [], + include_all_roles=use_deferred, + ) + # Build context_messages from session turns for deferred mode + ctx_msgs = None + if use_deferred and sess_turns: + ctx_msgs = [ + {"role": t.get("role", "user"), "content": str(t.get("content", "")).strip()} + for t in sess_turns if str(t.get("content", "")).strip() + ] batch_items.append({ "content": payload, "metadata": { @@ -285,17 +319,13 @@ def run_longmemeval(args: argparse.Namespace) -> Dict[str, Any]: "question_id": question_id, }, "categories": ["longmemeval", "session"], + "_context_messages": ctx_msgs, }) # Use add_batch for fewer LLM calls; fallback to sequential on failure if batch_items: - try: - memory.add_batch( - items=batch_items, - user_id=args.user_id, - ) - except Exception as e: - logger.warning("Batch add failed for question %s, retrying sequentially: %s", question_id, e) + if use_deferred: + # Deferred mode: sequential add with context_messages for item in batch_items: try: memory.add( @@ -304,9 +334,36 @@ def run_longmemeval(args: argparse.Namespace) -> Dict[str, Any]: metadata=item["metadata"], categories=item["categories"], infer=False, + 
context_messages=item.get("_context_messages"), ) except Exception as e2: logger.warning("Skipping session for question %s: %s", question_id, e2) + else: + try: + memory.add_batch( + items=batch_items, + user_id=args.user_id, + ) + except Exception as e: + logger.warning("Batch add failed for question %s, retrying sequentially: %s", question_id, e) + for item in batch_items: + try: + memory.add( + messages=item["content"], + user_id=args.user_id, + metadata=item["metadata"], + categories=item["categories"], + infer=False, + ) + except Exception as e2: + logger.warning("Skipping session for question %s: %s", question_id, e2) + + # Batch enrich after all sessions loaded (deferred mode) + if use_deferred: + try: + memory.enrich_pending(user_id=args.user_id, batch_size=10, max_batches=50) + except Exception as e: + logger.warning("Enrichment failed for question %s: %s", question_id, e) query = str(entry.get("question", "")).strip() search_payload = memory.search( @@ -436,8 +493,10 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--embedding-dims", type=int, default=1536, help="Embedding dimensions for simple/memory configs.") parser.add_argument("--vector-store-provider", choices=["memory", "sqlite_vec"], default="memory") parser.add_argument("--history-db-path", default="/content/engram-longmemeval.db", help="SQLite db path.") + parser.add_argument("--defer-enrichment", action="store_true", default=False, help="Use deferred enrichment (0 LLM calls at ingestion, batch enrich after).") args = parser.parse_args() args.full_potential = not args.minimal + args.defer_enrichment = args.defer_enrichment return args diff --git a/engram/configs/base.py b/engram/configs/base.py index 318c543..271be63 100644 --- a/engram/configs/base.py +++ b/engram/configs/base.py @@ -33,7 +33,7 @@ class LLMConfig(BaseModel): provider: str = Field(default="nvidia") config: Dict[str, Any] = Field( default_factory=lambda: { - "model": "qwen/qwen3.5-397b-a17b", + "model": 
"minimaxai/minimax-m2.1", "temperature": 0.2, "max_tokens": 4096, } @@ -383,6 +383,10 @@ class EnrichmentConfig(BaseModel): include_entities: bool = True # Include entity extraction in unified call include_profiles: bool = True # Include profile extraction in unified call max_batch_size: int = 10 # Max memories per unified batch call + # Deferred enrichment: store with 0 LLM calls, enrich later in batch + defer_enrichment: bool = False # When True: 0 LLM calls at ingestion + context_window_turns: int = 10 # Store last N conversation turns with each memory + enrich_on_access: bool = False # Auto-enrich pending memories when retrieved in search @field_validator("max_batch_size") @classmethod diff --git a/engram/db/sqlite.py b/engram/db/sqlite.py index 8c0bc31..b331bae 100644 --- a/engram/db/sqlite.py +++ b/engram/db/sqlite.py @@ -19,6 +19,7 @@ "access_count", "last_accessed", "immutable", "expiration_date", "scene_id", "user_id", "agent_id", "run_id", "app_id", "memory_type", "s_fast", "s_mid", "s_slow", "content_hash", + "conversation_context", "enrichment_status", }) VALID_SCENE_COLUMNS = frozenset({ @@ -534,10 +535,6 @@ def purge_tombstoned(self) -> int: return count -# Backward compatibility alias -SQLiteManager = CoreSQLiteManager - - class FullSQLiteManager(CoreSQLiteManager): def __init__(self, db_path: str): self.db_path = db_path @@ -881,6 +878,26 @@ def _ensure_v2_schema(self, conn: sqlite3.Connection) -> None: # Content-hash dedup column (idempotent). self._ensure_content_hash_column(conn) + # Deferred enrichment columns (idempotent). 
+ self._ensure_deferred_enrichment_columns(conn) + + def _ensure_deferred_enrichment_columns(self, conn: sqlite3.Connection) -> None: + """Add conversation_context and enrichment_status columns for deferred enrichment.""" + if self._is_migration_applied(conn, "v2_deferred_enrichment"): + return + self._migrate_add_column_conn(conn, "memories", "conversation_context", "TEXT") + self._migrate_add_column_conn(conn, "memories", "enrichment_status", "TEXT DEFAULT 'complete'") + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_enrichment_status ON memories(enrichment_status)" + ) + # Backfill: existing memories are already enriched. + conn.execute( + "UPDATE memories SET enrichment_status = 'complete' WHERE enrichment_status IS NULL" + ) + conn.execute( + "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_deferred_enrichment')" + ) + def _ensure_content_hash_column(self, conn: sqlite3.Connection) -> None: """Add content_hash column + index for SHA-256 dedup.""" if self._is_migration_applied(conn, "v2_content_hash"): @@ -962,8 +979,9 @@ def add_memory(self, memory_data: Dict[str, Any]) -> str: last_accessed, embedding, related_memories, source_memories, tombstone, confidentiality_scope, namespace, source_type, source_app, source_event_id, decay_lambda, status, importance, sensitivity, - memory_type, s_fast, s_mid, s_slow, content_hash - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + memory_type, s_fast, s_mid, s_slow, content_hash, + conversation_context, enrichment_status + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
""", ( memory_id, @@ -1000,6 +1018,8 @@ def add_memory(self, memory_data: Dict[str, Any]) -> str: memory_data.get("s_mid"), memory_data.get("s_slow"), memory_data.get("content_hash"), + memory_data.get("conversation_context"), + memory_data.get("enrichment_status", "complete"), ), ) @@ -1068,6 +1088,8 @@ def add_memories_batch(self, memories: List[Dict[str, Any]]) -> List[str]: memory_data.get("s_fast"), memory_data.get("s_mid"), memory_data.get("s_slow"), + memory_data.get("conversation_context"), + memory_data.get("enrichment_status", "complete"), )) history_rows.append(( memory_id, "ADD", None, memory_data.get("memory"), None, None, None, None, @@ -1083,8 +1105,9 @@ def add_memories_batch(self, memories: List[Dict[str, Any]]) -> List[str]: last_accessed, embedding, related_memories, source_memories, tombstone, confidentiality_scope, namespace, source_type, source_app, source_event_id, decay_lambda, status, importance, sensitivity, - memory_type, s_fast, s_mid, s_slow - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + memory_type, s_fast, s_mid, s_slow, + conversation_context, enrichment_status + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, insert_rows, ) @@ -1331,6 +1354,60 @@ def update_strength_bulk(self, updates: Dict[str, float]) -> None: [(strength, now, memory_id) for memory_id, strength in updates.items()], ) + def get_pending_enrichment(self, user_id: Optional[str] = None, limit: int = 50) -> List[Dict[str, Any]]: + """Return memories with enrichment_status='pending', ordered oldest first.""" + with self._get_connection() as conn: + if user_id: + rows = conn.execute( + "SELECT * FROM memories WHERE enrichment_status = 'pending' AND user_id = ? 
" + "AND tombstone = 0 ORDER BY created_at ASC LIMIT ?", + (user_id, limit), + ).fetchall() + else: + rows = conn.execute( + "SELECT * FROM memories WHERE enrichment_status = 'pending' " + "AND tombstone = 0 ORDER BY created_at ASC LIMIT ?", + (limit,), + ).fetchall() + return [self._row_to_dict(row) for row in rows] + + def update_enrichment_status(self, memory_id: str, status: str) -> None: + """Mark a memory's enrichment_status (e.g. 'complete').""" + now = _utcnow_iso() + with self._get_connection() as conn: + conn.execute( + "UPDATE memories SET enrichment_status = ?, updated_at = ? WHERE id = ?", + (status, now, memory_id), + ) + + def update_enrichment_bulk(self, updates: List[Dict[str, Any]]) -> None: + """Batch-update enrichment results for multiple memories. + + Each dict: {id, metadata, categories, enrichment_status}. + """ + if not updates: + return + now = _utcnow_iso() + with self._get_connection() as conn: + for upd in updates: + mid = upd["id"] + sets = ["updated_at = ?"] + params: list = [now] + if "metadata" in upd: + sets.append("metadata = ?") + params.append(json.dumps(upd["metadata"])) + if "categories" in upd: + sets.append("categories = ?") + params.append(json.dumps(upd["categories"])) + if "enrichment_status" in upd: + sets.append("enrichment_status = ?") + params.append(upd["enrichment_status"]) + params.append(mid) + conn.execute( + f"UPDATE memories SET {', '.join(sets)} WHERE id = ?", + params, + ) + _MEMORY_JSON_FIELDS = ("metadata", "categories", "related_memories", "source_memories") def _row_to_dict(self, row: sqlite3.Row, *, skip_embedding: bool = False) -> Dict[str, Any]: @@ -2070,3 +2147,9 @@ def _parse_json_value(value: Any, default: Any) -> Any: return json.loads(value) except Exception: return default + + +# Backward compatibility alias +# Keep SQLiteManager mapped to the full-capability manager so legacy call sites +# that expect category/scene/profile APIs continue to work. 
+SQLiteManager = FullSQLiteManager diff --git a/engram/mcp_server.py b/engram/mcp_server.py index d109b36..d6c1963 100644 --- a/engram/mcp_server.py +++ b/engram/mcp_server.py @@ -1,4 +1,4 @@ -"""Engram MCP Server — 18 tools, minimal boilerplate. +"""Engram MCP Server — 19 tools, minimal boilerplate. Tools: 1. remember — Quick-save (content → memory, infer=False) @@ -19,6 +19,7 @@ 16. analyze_skill_gaps — Show what transfers vs what needs experimentation 17. decompose_skill — Trigger structural decomposition of a flat skill 18. apply_skill_with_bindings — Apply skill with slot values, includes gap analysis +19. enrich_pending — Batch-enrich deferred memories """ import json @@ -176,6 +177,17 @@ def get_memory() -> Memory: "properties": { "content": {"type": "string", "description": "The fact or preference to remember"}, "categories": {"type": "array", "items": {"type": "string"}, "description": "Optional categories to tag this memory with (e.g., ['preferences', 'coding'])"}, + "context": { + "type": "array", + "items": { + "type": "object", + "properties": { + "role": {"type": "string"}, + "content": {"type": "string"}, + }, + }, + "description": "Recent conversation turns (sliding window) for richer memory context", + }, }, "required": ["content"], }, @@ -432,6 +444,17 @@ def get_memory() -> Memory: "required": ["skill_id", "bindings"], }, ), + Tool( + name="enrich_pending", + description="Batch-enrich memories stored with deferred enrichment. Runs echo, category, entity, and profile extraction in batched LLM calls. 
Use after bulk ingestion to retroactively enrich memories.", + inputSchema={ + "type": "object", + "properties": { + "batch_size": {"type": "integer", "description": "Memories per LLM call (default: 10)"}, + "max_batches": {"type": "integer", "description": "Max batches to process (default: 5)"}, + }, + }, + ), ] @@ -445,6 +468,7 @@ def _handle_remember(memory, args): categories=args.get("categories"), source_app="claude-code", infer=False, + context_messages=args.get("context"), ) @@ -676,6 +700,21 @@ def _handle_apply_skill_with_bindings(memory, args): ) +def _handle_enrich_pending(memory, args): + try: + batch_size = max(1, min(50, int(args.get("batch_size", 10)))) + except (ValueError, TypeError): + batch_size = 10 + try: + max_batches = max(1, min(100, int(args.get("max_batches", 5)))) + except (ValueError, TypeError): + max_batches = 5 + return memory.enrich_pending( + batch_size=batch_size, + max_batches=max_batches, + ) + + HANDLERS = { "remember": _handle_remember, "search_memory": _handle_search_memory, @@ -695,6 +734,7 @@ def _handle_apply_skill_with_bindings(memory, args): "analyze_skill_gaps": _handle_analyze_skill_gaps, "decompose_skill": _handle_decompose_skill, "apply_skill_with_bindings": _handle_apply_skill_with_bindings, + "enrich_pending": _handle_enrich_pending, } _MEMORY_FREE_TOOLS = {"get_last_session", "save_session_digest"} diff --git a/engram/memory/main.py b/engram/memory/main.py index d58ea57..7e35ebf 100644 --- a/engram/memory/main.py +++ b/engram/memory/main.py @@ -511,6 +511,7 @@ def add( scope: Optional[str] = None, source_app: Optional[str] = None, memory_id: Optional[str] = None, + context_messages: Optional[List[Dict[str, str]]] = None, **kwargs: Any, ) -> Dict[str, Any]: processed_metadata, effective_filters = build_filters_and_metadata( @@ -570,6 +571,7 @@ def add( initial_strength=initial_strength, echo_depth=echo_depth, memory_id=memory_id, + context_messages=context_messages, ) if result is not None: results.append(result) @@ 
-1153,6 +1155,7 @@ def _process_single_memory( initial_strength: float, echo_depth: Optional[str], memory_id: Optional[str] = None, + context_messages: Optional[List[Dict[str, str]]] = None, ) -> Optional[Dict[str, Any]]: """Process and store a single memory item. Returns result dict or None if skipped.""" content = mem.get("content", "").strip() @@ -1203,6 +1206,31 @@ def _process_single_memory( "memory": content, } + # --- Deferred enrichment: lite path (0 LLM calls) --- + enrichment_config = getattr(self.config, "enrichment", None) + if enrichment_config and enrichment_config.defer_enrichment: + return self._process_single_memory_lite( + content=content, + mem_metadata=mem_metadata, + mem_categories=mem_categories, + context_messages=context_messages, + user_id=user_id, + agent_id=agent_id, + run_id=run_id, + app_id=app_id, + effective_filters=effective_filters, + agent_category=agent_category, + connector_id=connector_id, + scope=scope, + source_app=source_app, + immutable=immutable, + expiration_date=expiration_date, + initial_layer=initial_layer, + initial_strength=initial_strength, + explicit_remember=explicit_remember, + memory_id=memory_id, + ) + # Resolve store identifiers and scope metadata. 
store_agent_id, store_run_id, store_app_id, store_filters = self._resolve_memory_metadata( content=content, @@ -1607,6 +1635,409 @@ def _do_category(): "memory_type": memory_type, } + def _process_single_memory_lite( + self, + *, + content: str, + mem_metadata: Dict[str, Any], + mem_categories: List[str], + context_messages: Optional[List[Dict[str, str]]], + user_id: Optional[str], + agent_id: Optional[str], + run_id: Optional[str], + app_id: Optional[str], + effective_filters: Dict[str, Any], + agent_category: Optional[str], + connector_id: Optional[str], + scope: Optional[str], + source_app: Optional[str], + immutable: bool, + expiration_date: Optional[str], + initial_layer: str, + initial_strength: float, + explicit_remember: bool, + memory_id: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: + """Lite processing path for deferred enrichment — 0 LLM calls. + + Stores the memory with regex-extracted keywords, context-enriched + embedding, and enrichment_status='pending'. All heavy LLM processing + (echo, category, conflict, entities, profiles) is deferred to + enrich_pending(). + """ + # Resolve store identifiers and scope metadata. 
+ store_agent_id, store_run_id, store_app_id, store_filters = self._resolve_memory_metadata( + content=content, + mem_metadata=mem_metadata, + explicit_remember=explicit_remember, + agent_id=agent_id, + run_id=run_id, + app_id=app_id, + effective_filters=effective_filters, + agent_category=agent_category, + connector_id=connector_id, + scope=scope, + source_app=source_app, + ) + + high_confidence = explicit_remember or looks_high_confidence(content, mem_metadata) + + # --- Regex keyword extraction (0 LLM calls) --- + extracted_keywords: List[str] = [] + content_lower = content.lower() + + # Extract preference/routine/goal hints + for regex, tag in [ + (_PREFERENCE_HINT_RE, "preference"), + (_ROUTINE_HINT_RE, "routine"), + (_GOAL_HINT_RE, "goal"), + ]: + if regex.search(content): + extracted_keywords.append(tag) + + # Simple word tokenization for top keywords (skip stopwords) + _STOPWORDS = { + "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", + "have", "has", "had", "do", "does", "did", "will", "would", "could", + "should", "may", "might", "can", "shall", "to", "of", "in", "for", + "on", "with", "at", "by", "from", "as", "into", "through", "during", + "before", "after", "above", "below", "between", "and", "but", "or", + "nor", "not", "so", "yet", "both", "either", "neither", "each", + "every", "all", "any", "few", "more", "most", "other", "some", "such", + "no", "only", "own", "same", "than", "too", "very", "just", "i", "me", + "my", "we", "our", "you", "your", "he", "she", "it", "they", "them", + "this", "that", "these", "those", "am", "his", "her", "its", + } + words = re.findall(r"\b[a-z][a-z0-9_-]{2,}\b", content_lower) + word_freq: Dict[str, int] = {} + for w in words: + if w not in _STOPWORDS: + word_freq[w] = word_freq.get(w, 0) + 1 + top_words = sorted(word_freq, key=lambda w: word_freq[w], reverse=True)[:15] + extracted_keywords.extend(top_words) + + # Regex entity extraction (names, dates) + name_match = _NAME_HINT_RE.search(content) + 
if name_match: + extracted_keywords.append(f"name:{name_match.group(1).strip()}") + + mem_metadata["echo_keywords"] = extracted_keywords + mem_metadata["enrichment_status"] = "pending" + + # --- Build rich embedding text (content + context summary) --- + context_window = getattr(self.config.enrichment, "context_window_turns", 10) + context_summary = "" + if context_messages: + recent = context_messages[-context_window:] + context_lines = [ + f"{m.get('role', 'user')}: {str(m.get('content', ''))[:200]}" + for m in recent + ] + context_summary = " | ".join(context_lines) + + embed_text = content + if context_summary: + embed_text += f" [Context: {context_summary[:500]}]" + + # --- Generate embedding (1 API call, NOT an LLM call) --- + embedding = self.embedder.embed(embed_text, memory_action="add") + + # --- Confidence and layer --- + effective_strength = initial_strength + if not explicit_remember and not high_confidence: + mem_metadata["policy_low_confidence"] = True + effective_strength = min(effective_strength, 0.4) + + layer = initial_layer + if layer == "auto": + layer = "sml" + + # --- Metadata --- + confidentiality_scope = str( + mem_metadata.get("confidentiality_scope") + or mem_metadata.get("privacy_scope") + or "work" + ).lower() + source_type = ( + mem_metadata.get("source_type") + or ("cli" if (source_app or "").lower() == "cli" else "mcp") + ) + namespace_value = str(mem_metadata.get("namespace", "default") or "default").strip() or "default" + memory_type = self._classify_memory_type(mem_metadata, mem_metadata.get("role", "user")) + + # Multi-trace strength + s_fast_val = s_mid_val = s_slow_val = None + if self.distillation_config and self.distillation_config.enable_multi_trace: + s_fast_val, s_mid_val, s_slow_val = initialize_traces(effective_strength, is_new=True) + + # Content hash for dedup + from engram.memory.core import _content_hash + ch = _content_hash(content) + existing = self.db.get_memory_by_content_hash(ch, user_id) if hasattr(self.db, 
'get_memory_by_content_hash') else None + if existing: + self.db.increment_access(existing["id"]) + return { + "id": existing["id"], + "memory": existing.get("memory", ""), + "event": "DEDUPLICATED", + "layer": existing.get("layer", "sml"), + "strength": existing.get("strength", 1.0), + } + + effective_memory_id = memory_id or str(uuid.uuid4()) + now = datetime.now(timezone.utc).isoformat() + + # Serialize conversation context + context_json = None + if context_messages: + recent = context_messages[-context_window:] + context_json = json.dumps(recent) + + memory_data = { + "id": effective_memory_id, + "memory": content, + "user_id": user_id, + "agent_id": store_agent_id, + "run_id": store_run_id, + "app_id": store_app_id, + "metadata": mem_metadata, + "categories": mem_categories, + "immutable": immutable, + "expiration_date": expiration_date, + "created_at": now, + "updated_at": now, + "layer": layer, + "strength": effective_strength, + "access_count": 0, + "last_accessed": now, + "embedding": embedding, + "confidentiality_scope": confidentiality_scope, + "source_type": source_type, + "source_app": source_app or mem_metadata.get("source_app"), + "source_event_id": mem_metadata.get("source_event_id"), + "decay_lambda": self.fadem_config.sml_decay_rate, + "status": "active", + "importance": mem_metadata.get("importance", 0.5), + "sensitivity": mem_metadata.get("sensitivity", "normal"), + "namespace": namespace_value, + "memory_type": memory_type, + "s_fast": s_fast_val, + "s_mid": s_mid_val, + "s_slow": s_slow_val, + "content_hash": ch, + "conversation_context": context_json, + "enrichment_status": "pending", + } + + # Build vector index (single primary vector, no echo nodes) + base_payload = { + "memory_id": effective_memory_id, + "user_id": user_id, + "agent_id": store_agent_id, + "run_id": store_run_id, + "app_id": store_app_id, + "categories": mem_categories, + "text": embed_text, + "type": "primary", + "memory": content, + } + vectors = [embedding] + payloads = 
[base_payload] + vector_ids = [effective_memory_id] + + self.db.add_memory(memory_data) + try: + self.vector_store.insert(vectors=vectors, payloads=payloads, ids=vector_ids) + except Exception as e: + logger.error("Vector insert failed for memory %s (lite), rolling back: %s", effective_memory_id, e) + try: + self.db.delete_memory(effective_memory_id, use_tombstone=False) + except Exception as rollback_err: + logger.critical("DB rollback also failed for %s: %s", effective_memory_id, rollback_err) + raise + + # Scene assignment still works (embedding-based, no LLM) + if self.scene_processor: + try: + self._assign_to_scene(effective_memory_id, content, embedding, user_id, now) + except Exception as e: + logger.warning("Scene assignment failed for %s (lite): %s", effective_memory_id, e) + + return { + "id": effective_memory_id, + "memory": content, + "event": "ADD", + "layer": layer, + "strength": effective_strength, + "echo_depth": None, + "categories": mem_categories, + "namespace": namespace_value, + "vector_nodes": 1, + "memory_type": memory_type, + "enrichment_status": "pending", + } + + def enrich_pending( + self, + user_id: str = "default", + batch_size: int = 10, + max_batches: int = 5, + ) -> Dict[str, Any]: + """Batch-enrich memories that were stored with deferred enrichment. + + Uses unified enrichment: 1 LLM call per batch_size memories. + Returns {enriched_count, batches, remaining}. 
+ """ + limit = batch_size * max_batches + pending = self.db.get_pending_enrichment(user_id=user_id, limit=limit) + if not pending: + return {"enriched_count": 0, "batches": 0, "remaining": 0} + + enriched_count = 0 + batches_processed = 0 + + for start in range(0, len(pending), batch_size): + batch = pending[start:start + batch_size] + contents = [m.get("memory", "") for m in batch] + + # Try unified enrichment (single LLM call for the batch) + enrichment_results = None + if self.unified_enrichment is not None: + try: + existing_cats = None + if self.category_processor: + cats = self.category_processor.get_all_categories() + if cats: + existing_cats = "\n".join( + f"- {c['id']}: {c['name']} — {c.get('description', '')}" + for c in cats[:30] + ) + + enrichment_results = self.unified_enrichment.enrich_batch( + contents, + depth=EchoDepth.MEDIUM, + existing_categories=existing_cats, + include_entities=True, + include_profiles=True, + ) + except Exception as e: + logger.warning("Unified batch enrichment failed in enrich_pending: %s", e) + enrichment_results = None + + # Fallback: individual enrichment per memory + if enrichment_results is None: + enrichment_results = [] + for c in contents: + if self.unified_enrichment is not None: + try: + enrichment_results.append( + self.unified_enrichment.enrich(c, depth=EchoDepth.MEDIUM) + ) + except Exception: + enrichment_results.append(None) + else: + enrichment_results.append(None) + + # Apply enrichment results and update DB + db_updates: List[Dict[str, Any]] = [] + for mem, enrichment in zip(batch, enrichment_results): + mem_id = mem["id"] + mem_meta = mem.get("metadata", {}) or {} + mem_cats = mem.get("categories", []) or [] + + if enrichment: + # Apply echo result + if enrichment.echo_result: + mem_meta.update(enrichment.echo_result.to_metadata()) + if not mem_cats and enrichment.echo_result.category: + mem_cats = [enrichment.echo_result.category] + + # Apply category result + if enrichment.category_match and not 
mem_cats: + mem_cats = [enrichment.category_match.category_id] + mem_meta["category_confidence"] = enrichment.category_match.confidence + mem_meta["category_auto"] = True + + # Apply extracted facts to metadata + if enrichment.facts: + mem_meta["enrichment_facts"] = enrichment.facts[:8] + + # Post-store hooks: entities + if self.knowledge_graph and enrichment.entities: + for entity in enrichment.entities: + existing_ent = self.knowledge_graph._get_or_create_entity( + entity.name, entity.entity_type, + ) + existing_ent.memory_ids.add(mem_id) + self.knowledge_graph.memory_entities[mem_id] = { + e.name for e in enrichment.entities + } + + # Post-store hooks: profiles + if self.profile_processor and enrichment.profile_updates: + for profile_update in enrichment.profile_updates: + try: + self.profile_processor.apply_update( + profile_update=profile_update, + memory_id=mem_id, + user_id=user_id, + ) + except Exception as e: + logger.warning("Profile update failed during enrichment for %s: %s", mem_id, e) + + # Generate fact decomposition vectors + if enrichment.facts: + valid_facts = [ + (i, f.strip()) for i, f in enumerate(enrichment.facts[:8]) + if f.strip() and len(f.strip()) >= 10 + ] + if valid_facts: + try: + fact_texts = [ft for _, ft in valid_facts] + fact_embeddings = self.embedder.embed_batch(fact_texts, memory_action="add") + fact_vectors, fact_payloads, fact_ids = [], [], [] + for (i, fact_text), fact_emb in zip(valid_facts, fact_embeddings): + fact_id = f"{mem_id}__fact_{i}" + fact_vectors.append(fact_emb) + fact_payloads.append({ + "memory_id": mem_id, + "is_fact": True, + "fact_index": i, + "fact_text": fact_text, + "user_id": user_id, + }) + fact_ids.append(fact_id) + if fact_vectors: + self.vector_store.insert( + vectors=fact_vectors, + payloads=fact_payloads, + ids=fact_ids, + ) + except Exception as e: + logger.warning("Fact embedding failed during enrichment for %s: %s", mem_id, e) + + mem_meta["enrichment_status"] = "complete" + db_updates.append({ + 
"id": mem_id, + "metadata": mem_meta, + "categories": mem_cats, + "enrichment_status": "complete", + }) + enriched_count += 1 + + # Batch DB update + self.db.update_enrichment_bulk(db_updates) + batches_processed += 1 + + # Check remaining + remaining_count = len(self.db.get_pending_enrichment(user_id=user_id, limit=1)) + + return { + "enriched_count": enriched_count, + "batches": batches_processed, + "remaining": remaining_count, + } + def search( self, query: str, @@ -1900,6 +2331,8 @@ def search( "memory_type": mem_type, "query_intent": query_intent.value if query_intent else None, "confidence": metadata.get("mm_confidence"), + "conversation_context": memory.get("conversation_context"), + "enrichment_status": memory.get("enrichment_status", "complete"), } ) diff --git a/engram/observability.py b/engram/observability.py index 5199e32..94e8ce2 100644 --- a/engram/observability.py +++ b/engram/observability.py @@ -1,24 +1,49 @@ -"""Engram Observability — lightweight no-op stub. +"""Engram Observability — compatibility-safe no-op implementation. -Full observability (Prometheus export, structured logging, etc.) -lives in engram-enterprise. +Core Engram does not require metrics infrastructure at runtime, but enterprise +and API layers import symbols from this module. Keep this interface stable and +side-effect free so those imports always succeed. 
""" +from __future__ import annotations + +import logging +from contextlib import contextmanager +from typing import Any, Dict, Iterator + +logger = logging.getLogger("engram") + class _NoOpMetrics: """Drop-in replacement that silently discards all metric calls.""" - def record_add(self, *a, **kw): pass - def record_search(self, *a, **kw): pass - def record_decay(self, *a, **kw): pass - def record_get(self, *a, **kw): pass - def record_delete(self, *a, **kw): pass - def record_masked_hits(self, *a, **kw): pass - def record_staged_commit(self, *a, **kw): pass - def record_commit_approval(self, *a, **kw): pass - def record_commit_rejection(self, *a, **kw): pass - def record_ref_protected_skip(self, *a, **kw): pass - def get_summary(self): return {} + def __getattr__(self, _: str): + def _noop(*args, **kwargs): + return None + + return _noop + + @contextmanager + def measure(self, *args, **kwargs) -> Iterator[None]: + yield + + def get_summary(self) -> Dict[str, Any]: + return {} metrics = _NoOpMetrics() + + +def add_metrics_routes(app: Any) -> None: + """Register a lightweight /metrics endpoint if FastAPI is available.""" + try: + routes = getattr(app, "routes", []) + if any(getattr(route, "path", None) == "/metrics" for route in routes): + return + + @app.get("/metrics") + async def _metrics_endpoint() -> Dict[str, Any]: + return metrics.get_summary() + except Exception: + # Keep observability strictly non-blocking. + return diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..1d4fe0a --- /dev/null +++ b/pytest.ini @@ -0,0 +1,11 @@ +[pytest] +testpaths = + tests + engram-bus/tests + engram-enterprise/tests +pythonpath = + . 
+ engram-bus + engram-enterprise +markers = + integration: tests that require external services or credentials diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/longmemeval_test.json b/tests/longmemeval_test.json new file mode 100644 index 0000000..5b25a9a --- /dev/null +++ b/tests/longmemeval_test.json @@ -0,0 +1,33 @@ +[ + { + "question_id": "test_q1", + "question": "What programming language does the user prefer for machine learning projects?", + "haystack_session_ids": ["sess_001", "sess_002", "sess_003", "sess_004"], + "haystack_dates": ["2024-01-15", "2024-02-20", "2024-03-10", "2024-04-05"], + "haystack_sessions": [ + [ + {"role": "user", "content": "I've been thinking about getting into machine learning"}, + {"role": "assistant", "content": "That's exciting! What interests you most about ML?"}, + {"role": "user", "content": "I want to build models for image classification. Which language should I start with?"}, + {"role": "assistant", "content": "Python is the most popular choice for ML. It has great libraries like TensorFlow and PyTorch."} + ], + [ + {"role": "user", "content": "I'm setting up my development environment for ML work"}, + {"role": "assistant", "content": "Great! Make sure to install Python 3.10 and set up a virtual environment."}, + {"role": "user", "content": "Should I use Conda or venv?"}, + {"role": "assistant", "content": "Both work well. Conda is convenient for data science as it comes with many scientific packages pre-installed."} + ], + [ + {"role": "user", "content": "I'm struggling with deploying my models to production"}, + {"role": "assistant", "content": "What framework are you currently using with Python?"}, + {"role": "user", "content": "I'm using PyTorch with FastAPI for serving predictions. 
It's working well."} + ], + [ + {"role": "user", "content": "Can you help me optimize a neural network?"}, + {"role": "assistant", "content": "What kind of optimization are you looking for - training speed or model size?"}, + {"role": "user", "content": "Training speed mainly. My Python training scripts are taking too long on large datasets."} + ] + ], + "answer_session_ids": ["sess_001", "sess_003"] + } +] diff --git a/tests/test_deferred_enrichment.py b/tests/test_deferred_enrichment.py new file mode 100644 index 0000000..7a87f1e --- /dev/null +++ b/tests/test_deferred_enrichment.py @@ -0,0 +1,529 @@ +"""Tests for conversation-aware memory with deferred enrichment. + +Covers: +1. Lite path stores memory with 0 LLM calls +2. Conversation context stored in dedicated column +3. Regex keyword extraction (preferences, routines, goals, entities) +4. Embedding includes context summary +5. enrichment_status set to "pending" +6. enrich_pending() processes batch correctly +7. After enrichment: echo_keywords, categories populated +8. enrichment_status updated to "complete" +9. Search returns conversation_context +10. Scene assignment still works in lite path +11. Content dedup (content_hash) still works in lite path +12. 
Integration: add N memories lite → enrich_pending → search +""" + +import json +import os +import tempfile + +import pytest + +from engram.configs.base import ( + BatchConfig, + CategoryMemConfig, + EchoMemConfig, + EmbedderConfig, + EnrichmentConfig, + KnowledgeGraphConfig, + LLMConfig, + MemoryConfig, + ProfileConfig, + SceneConfig, + VectorStoreConfig, +) +from engram.memory.main import FullMemory as Memory + + +def _make_deferred_memory(tmpdir, defer=True, echo=False, categories=False): + """Create a Memory instance with deferred enrichment enabled.""" + config = MemoryConfig( + llm=LLMConfig(provider="mock", config={}), + embedder=EmbedderConfig(provider="simple", config={"embedding_dims": 384}), + vector_store=VectorStoreConfig( + provider="memory", + config={"embedding_model_dims": 384}, + ), + history_db_path=os.path.join(tmpdir, "test.db"), + embedding_model_dims=384, + echo=EchoMemConfig(enable_echo=echo), + category=CategoryMemConfig(enable_categories=categories, use_llm_categorization=False), + graph=KnowledgeGraphConfig(enable_graph=False), + scene=SceneConfig(enable_scenes=False), + profile=ProfileConfig(enable_profiles=False), + enrichment=EnrichmentConfig( + enable_unified=False, + defer_enrichment=defer, + context_window_turns=5, + ), + batch=BatchConfig(enable_batch=False), + ) + return Memory(config) + + +class TestDeferredEnrichmentLitePath: + """Test the lite processing path (0 LLM calls).""" + + def test_lite_path_stores_memory(self): + """Lite path stores memory and returns ADD event.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + result = m.add( + messages="I prefer using Python for data science", + user_id="test_user", + infer=False, + ) + assert "results" in result + assert len(result["results"]) == 1 + r = result["results"][0] + assert r["event"] == "ADD" + assert r["memory"] == "I prefer using Python for data science" + assert r["enrichment_status"] == "pending" + m.close() + + def 
test_enrichment_status_pending(self): + """Memories stored via lite path have enrichment_status='pending'.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + result = m.add( + messages="Remember my favorite color is blue", + user_id="test_user", + infer=False, + ) + r = result["results"][0] + assert r["enrichment_status"] == "pending" + + # Also verify in DB + mem = m.db.get_memory(r["id"]) + assert mem is not None + assert mem.get("enrichment_status") == "pending" + m.close() + + def test_non_deferred_is_complete(self): + """Without defer_enrichment, status is 'complete'.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir, defer=False) + result = m.add( + messages="I like cats", + user_id="test_user", + infer=False, + ) + r = result["results"][0] + # Non-deferred path doesn't set enrichment_status in result + # but the DB should have 'complete' + mem = m.db.get_memory(r["id"]) + assert mem.get("enrichment_status") in ("complete", None) + m.close() + + +class TestRegexKeywordExtraction: + """Test regex-based keyword extraction in the lite path.""" + + def test_preference_keywords(self): + """Preference hints are extracted as keywords.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + result = m.add( + messages="I prefer Python over JavaScript and always use type hints", + user_id="test_user", + infer=False, + ) + r = result["results"][0] + mem = m.db.get_memory(r["id"]) + metadata = mem.get("metadata", {}) + keywords = metadata.get("echo_keywords", []) + assert "preference" in keywords + m.close() + + def test_routine_keywords(self): + """Routine hints are extracted.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + result = m.add( + messages="Every morning I review pull requests before standup", + user_id="test_user", + infer=False, + ) + r = result["results"][0] + mem = m.db.get_memory(r["id"]) + metadata = 
mem.get("metadata", {}) + keywords = metadata.get("echo_keywords", []) + assert "routine" in keywords + m.close() + + def test_goal_keywords(self): + """Goal hints are extracted.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + result = m.add( + messages="My goal is to learn Rust this year for systems programming", + user_id="test_user", + infer=False, + ) + r = result["results"][0] + mem = m.db.get_memory(r["id"]) + metadata = mem.get("metadata", {}) + keywords = metadata.get("echo_keywords", []) + assert "goal" in keywords + m.close() + + def test_name_entity_extraction(self): + """Names are extracted via regex.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + result = m.add( + messages="My name is Alice and I work at Anthropic", + user_id="test_user", + metadata={"allow_sensitive": True}, + infer=False, + ) + r = result["results"][0] + mem = m.db.get_memory(r["id"]) + metadata = mem.get("metadata", {}) + keywords = metadata.get("echo_keywords", []) + # Should contain name:Alice + assert any("name:Alice" in k for k in keywords) + m.close() + + def test_word_tokenization_keywords(self): + """Top content words are extracted as keywords.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + result = m.add( + messages="Python machine learning tensorflow pytorch neural networks deep learning", + user_id="test_user", + infer=False, + ) + r = result["results"][0] + mem = m.db.get_memory(r["id"]) + metadata = mem.get("metadata", {}) + keywords = metadata.get("echo_keywords", []) + # Should contain domain-specific words + assert "python" in keywords + assert "learning" in keywords + m.close() + + +class TestConversationContext: + """Test conversation context storage and retrieval.""" + + def test_context_stored_in_db(self): + """Context messages are stored in the conversation_context column.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = 
_make_deferred_memory(tmpdir) + ctx = [ + {"role": "user", "content": "What language should I learn?"}, + {"role": "assistant", "content": "It depends on your goals."}, + {"role": "user", "content": "I want to do data science"}, + ] + result = m.add( + messages="User prefers Python for data science", + user_id="test_user", + infer=False, + context_messages=ctx, + ) + r = result["results"][0] + mem = m.db.get_memory(r["id"]) + ctx_raw = mem.get("conversation_context") + assert ctx_raw is not None + parsed = json.loads(ctx_raw) + assert len(parsed) == 3 + assert parsed[0]["role"] == "user" + m.close() + + def test_context_window_truncation(self): + """Only last N turns are stored based on context_window_turns config.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + # Config has context_window_turns=5, send 10 turns + ctx = [{"role": "user", "content": f"Message {i}"} for i in range(10)] + result = m.add( + messages="Summary of the conversation", + user_id="test_user", + infer=False, + context_messages=ctx, + ) + r = result["results"][0] + mem = m.db.get_memory(r["id"]) + parsed = json.loads(mem.get("conversation_context", "[]")) + assert len(parsed) == 5 # Only last 5 + assert parsed[0]["content"] == "Message 5" + m.close() + + def test_no_context_is_null(self): + """When no context_messages provided, conversation_context is None.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + result = m.add( + messages="Standalone fact", + user_id="test_user", + infer=False, + ) + r = result["results"][0] + mem = m.db.get_memory(r["id"]) + assert mem.get("conversation_context") is None + m.close() + + def test_context_in_search_results(self): + """Search results include conversation_context field.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + ctx = [ + {"role": "user", "content": "Tell me about Python"}, + {"role": "assistant", "content": "Python is a great 
language"}, + ] + m.add( + messages="User loves Python programming", + user_id="test_user", + infer=False, + context_messages=ctx, + ) + search_result = m.search("Python", user_id="test_user") + results = search_result.get("results", []) + assert len(results) >= 1 + # conversation_context should be present in search results + assert "conversation_context" in results[0] + m.close() + + +class TestContentDedup: + """Test content hash deduplication in lite path.""" + + def test_duplicate_content_deduplicated(self): + """Adding the same content twice returns DEDUPLICATED.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + r1 = m.add( + messages="I prefer dark mode in all editors", + user_id="test_user", + infer=False, + ) + assert r1["results"][0]["event"] == "ADD" + + r2 = m.add( + messages="I prefer dark mode in all editors", + user_id="test_user", + infer=False, + ) + assert r2["results"][0]["event"] == "DEDUPLICATED" + m.close() + + +class TestDBMigration: + """Test database migration for deferred enrichment columns.""" + + def test_columns_exist_after_migration(self): + """New columns are created during DB init.""" + with tempfile.TemporaryDirectory() as tmpdir: + from engram.db.sqlite import SQLiteManager + db = SQLiteManager(os.path.join(tmpdir, "test.db")) + # Try inserting with the new columns + import uuid + mid = str(uuid.uuid4()) + db.add_memory({ + "id": mid, + "memory": "test", + "user_id": "u1", + "conversation_context": json.dumps([{"role": "user", "content": "hi"}]), + "enrichment_status": "pending", + }) + mem = db.get_memory(mid) + assert mem is not None + assert mem.get("enrichment_status") == "pending" + assert mem.get("conversation_context") is not None + db.close() + + def test_get_pending_enrichment(self): + """get_pending_enrichment returns only pending memories.""" + with tempfile.TemporaryDirectory() as tmpdir: + from engram.db.sqlite import SQLiteManager + db = SQLiteManager(os.path.join(tmpdir, 
"test.db")) + import uuid + + # Add a pending memory + mid1 = str(uuid.uuid4()) + db.add_memory({ + "id": mid1, "memory": "pending memory", + "user_id": "u1", "enrichment_status": "pending", + }) + # Add a complete memory + mid2 = str(uuid.uuid4()) + db.add_memory({ + "id": mid2, "memory": "complete memory", + "user_id": "u1", "enrichment_status": "complete", + }) + + pending = db.get_pending_enrichment(user_id="u1", limit=10) + assert len(pending) == 1 + assert pending[0]["id"] == mid1 + db.close() + + def test_update_enrichment_status(self): + """update_enrichment_status marks a memory as complete.""" + with tempfile.TemporaryDirectory() as tmpdir: + from engram.db.sqlite import SQLiteManager + db = SQLiteManager(os.path.join(tmpdir, "test.db")) + import uuid + + mid = str(uuid.uuid4()) + db.add_memory({ + "id": mid, "memory": "test", + "user_id": "u1", "enrichment_status": "pending", + }) + + db.update_enrichment_status(mid, "complete") + mem = db.get_memory(mid) + assert mem["enrichment_status"] == "complete" + + # Should no longer appear in pending + pending = db.get_pending_enrichment(user_id="u1") + assert len(pending) == 0 + db.close() + + +class TestEnrichPending: + """Test batch enrichment of pending memories.""" + + def test_enrich_pending_empty(self): + """enrich_pending with no pending memories returns 0.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + result = m.enrich_pending(user_id="test_user") + assert result["enriched_count"] == 0 + assert result["batches"] == 0 + m.close() + + def test_enrich_pending_marks_complete(self): + """After enrich_pending, memories are marked as complete.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + # Add memories via lite path + for i in range(3): + m.add( + messages=f"Fact number {i}: Python is great for data analysis", + user_id="test_user", + infer=False, + ) + + # Verify they're pending + pending = 
m.db.get_pending_enrichment(user_id="test_user") + assert len(pending) == 3 + + # Enrich + result = m.enrich_pending(user_id="test_user", batch_size=10) + assert result["enriched_count"] == 3 + assert result["remaining"] == 0 + + # Verify they're now complete + pending_after = m.db.get_pending_enrichment(user_id="test_user") + assert len(pending_after) == 0 + m.close() + + def test_enrich_pending_respects_batch_size(self): + """enrich_pending processes in batches respecting max_batches.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + # Add 5 memories + for i in range(5): + m.add( + messages=f"Memory item {i} about machine learning", + user_id="test_user", + infer=False, + ) + + # Process only 1 batch of 2 + result = m.enrich_pending( + user_id="test_user", + batch_size=2, + max_batches=1, + ) + assert result["enriched_count"] == 2 + assert result["remaining"] > 0 + m.close() + + +class TestIntegration: + """Integration tests: add → enrich → search.""" + + def test_add_enrich_search_flow(self): + """Full flow: add memories with deferred enrichment, enrich, then search.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + + # Add several memories + m.add(messages="I prefer Python for backend development", user_id="alice", infer=False) + m.add(messages="My favorite editor is VS Code with vim keybindings", user_id="alice", infer=False) + m.add(messages="Every morning I review GitHub notifications first", user_id="alice", infer=False) + + # Search should work even before enrichment (embedding-based) + pre_results = m.search("Python backend", user_id="alice") + assert len(pre_results.get("results", [])) >= 1 + + # Enrich + enrich_result = m.enrich_pending(user_id="alice") + assert enrich_result["enriched_count"] == 3 + + # Search after enrichment + post_results = m.search("Python backend", user_id="alice") + assert len(post_results.get("results", [])) >= 1 + m.close() + + def 
test_add_with_context_and_search(self): + """Add with context, verify context is retrievable via search.""" + with tempfile.TemporaryDirectory() as tmpdir: + m = _make_deferred_memory(tmpdir) + ctx = [ + {"role": "user", "content": "What's the best language for web scraping?"}, + {"role": "assistant", "content": "Python with Beautiful Soup or Scrapy is excellent."}, + ] + m.add( + messages="User wants to learn Python web scraping with Beautiful Soup", + user_id="bob", + infer=False, + context_messages=ctx, + ) + + results = m.search("web scraping Python", user_id="bob") + hits = results.get("results", []) + assert len(hits) >= 1 + # Context should be in the result + ctx_field = hits[0].get("conversation_context") + assert ctx_field is not None + parsed_ctx = json.loads(ctx_field) if isinstance(ctx_field, str) else ctx_field + assert len(parsed_ctx) == 2 + m.close() + + def test_mixed_deferred_and_normal(self): + """Switching defer_enrichment off still works for normal add.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create with deferred OFF + m = _make_deferred_memory(tmpdir, defer=False) + result = m.add( + messages="Normal memory without deferred enrichment", + user_id="test_user", + infer=False, + ) + r = result["results"][0] + assert r["event"] == "ADD" + # Should not have enrichment_status in the result (normal path) + assert r.get("enrichment_status") is None + m.close() + + +class TestEnrichmentConfig: + """Test EnrichmentConfig deferred fields.""" + + def test_default_values(self): + """Default config has defer_enrichment=False.""" + config = EnrichmentConfig() + assert config.defer_enrichment is False + assert config.context_window_turns == 10 + assert config.enrich_on_access is False + + def test_enable_deferred(self): + """Can enable deferred enrichment via config.""" + config = EnrichmentConfig(defer_enrichment=True, context_window_turns=5) + assert config.defer_enrichment is True + assert config.context_window_turns == 5 diff --git 
a/tests/test_e2e_all_features.py b/tests/test_e2e_all_features.py index aeb69f6..42cad7a 100644 --- a/tests/test_e2e_all_features.py +++ b/tests/test_e2e_all_features.py @@ -15,6 +15,8 @@ import traceback from datetime import datetime, timezone +import pytest + # ── Setup ────────────────────────────────────────────────────── _ENV_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env") @@ -26,6 +28,15 @@ key, _, value = line.partition("=") os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'")) +_NVIDIA_KEYS = ( + "NVIDIA_API_KEY", + "NVIDIA_EMBEDDING_API_KEY", + "NVIDIA_QWEN_API_KEY", + "LLAMA_API_KEY", +) +if not any(os.environ.get(key) for key in _NVIDIA_KEYS): + pytest.skip("requires NVIDIA API credentials", allow_module_level=True) + from engram.configs.base import ( MemoryConfig, LLMConfig, EmbedderConfig, VectorStoreConfig, EchoMemConfig, CategoryMemConfig, ProfileConfig, diff --git a/tests/test_mcp_tools_slim.py b/tests/test_mcp_tools_slim.py index 51b6d13..2fceb52 100644 --- a/tests/test_mcp_tools_slim.py +++ b/tests/test_mcp_tools_slim.py @@ -1,23 +1,38 @@ -"""Verify MCP server has exactly 8 tools.""" +"""Verify MCP server tool contract.""" from engram import mcp_server +EXPECTED_TOOL_NAMES = { + "remember", + "search_memory", + "get_memory", + "get_all_memories", + "engram_context", + "get_last_session", + "save_session_digest", + "get_memory_stats", + "search_skills", + "apply_skill", + "log_skill_outcome", + "record_trajectory_step", + "mine_skills", + "get_skill_stats", + "search_skills_structural", + "analyze_skill_gaps", + "decompose_skill", + "apply_skill_with_bindings", + "enrich_pending", +} + class TestMCPToolsSlim: - def test_exactly_8_tools(self): + def test_expected_tool_contract(self): tools = mcp_server.TOOLS - assert len(tools) == 8, f"Expected 8 tools, got {len(tools)}: {[t.name for t in tools]}" - - def test_core_tools_present(self): tool_names = [t.name for t in mcp_server.TOOLS] - assert "remember" in 
tool_names - assert "search_memory" in tool_names - assert "get_memory" in tool_names - assert "get_all_memories" in tool_names - assert "get_memory_stats" in tool_names - assert "engram_context" in tool_names - assert "get_last_session" in tool_names - assert "save_session_digest" in tool_names + assert len(tools) == len(EXPECTED_TOOL_NAMES), ( + f"Expected {len(EXPECTED_TOOL_NAMES)} tools, got {len(tools)}: {tool_names}" + ) + assert set(tool_names) == EXPECTED_TOOL_NAMES def test_no_duplicate_tool_names(self): tool_names = [t.name for t in mcp_server.TOOLS] diff --git a/tests/test_power_packages.py b/tests/test_power_packages.py index 11e99c0..3f6d12f 100644 --- a/tests/test_power_packages.py +++ b/tests/test_power_packages.py @@ -22,6 +22,15 @@ value = value.strip().strip('"').strip("'") os.environ.setdefault(key.strip(), value) +_NVIDIA_KEYS = ( + "NVIDIA_API_KEY", + "NVIDIA_EMBEDDING_API_KEY", + "NVIDIA_QWEN_API_KEY", + "LLAMA_API_KEY", +) +if not any(os.environ.get(key) for key in _NVIDIA_KEYS): + pytest.skip("requires NVIDIA API credentials", allow_module_level=True) + @pytest.fixture(scope="session") def memory(): From 36f705d243d026b569a5855effeee22b2bff45ca Mon Sep 17 00:00:00 2001 From: Vivek Kumar Date: Fri, 20 Feb 2026 12:57:41 +0530 Subject: [PATCH 8/8] feat: neural reranker and NVIDIA nemotron-embed support Add cross-encoder reranking stage using NVIDIA NIM API for improved retrieval accuracy. Add nemotron-embed model support with proper modality list handling in batch embeddings. 
Co-Authored-By: Claude Opus 4.6 --- engram/benchmarks/longmemeval.py | 16 +++- engram/configs/base.py | 11 +++ engram/embeddings/nvidia.py | 15 +++- engram/memory/main.py | 38 +++++++++ engram/retrieval/__init__.py | 9 ++- engram/retrieval/reranker.py | 135 +++++++++++++++++++++++++++++++ 6 files changed, 217 insertions(+), 7 deletions(-) create mode 100644 engram/retrieval/reranker.py diff --git a/engram/benchmarks/longmemeval.py b/engram/benchmarks/longmemeval.py index ccb7228..ea664ec 100644 --- a/engram/benchmarks/longmemeval.py +++ b/engram/benchmarks/longmemeval.py @@ -27,6 +27,7 @@ LLMConfig, MemoryConfig, ProfileConfig, + RerankConfig, SceneConfig, VectorStoreConfig, ) @@ -163,6 +164,8 @@ def build_memory( embedder_model: Optional[str] = None, full_potential: bool = True, defer_enrichment: bool = False, + enable_rerank: bool = False, + rerank_model: Optional[str] = None, ) -> Memory: """Build Engram Memory for LongMemEval. By default uses full potential (echo, categories, graph, scenes, profiles). 
@@ -174,13 +177,17 @@ def build_memory( "embedding_model_dims": embedding_dims, } - llm_cfg: Dict[str, Any] = {"max_tokens": 8192, "timeout": 300, "model": "meta/llama-3.3-70b-instruct"} + llm_cfg: Dict[str, Any] = {"max_tokens": 16384, "timeout": 300, "model": "meta/llama-3.3-70b-instruct"} if llm_model: llm_cfg["model"] = llm_model embedder_cfg: Dict[str, Any] = {"embedding_dims": embedding_dims} if embedder_model: embedder_cfg["model"] = embedder_model + rerank_cfg = RerankConfig(enable_rerank=enable_rerank) + if rerank_model: + rerank_cfg = RerankConfig(enable_rerank=enable_rerank, model=rerank_model) + config = MemoryConfig( vector_store=VectorStoreConfig(provider=vector_store_provider, config=vector_cfg), llm=LLMConfig(provider=llm_provider, config=llm_cfg), @@ -194,10 +201,11 @@ def build_memory( profile=ProfileConfig(use_llm_extraction=full_potential, enable_profiles=full_potential), enrichment=EnrichmentConfig( enable_unified=full_potential, - max_batch_size=10, + max_batch_size=5, defer_enrichment=defer_enrichment, ), batch=BatchConfig(enable_batch=full_potential and not defer_enrichment, max_batch_size=50), + rerank=rerank_cfg, ) mem = Memory(config) # FullMemory features (categories, scenes, profiles) need FullSQLiteManager @@ -267,6 +275,8 @@ def run_longmemeval(args: argparse.Namespace) -> Dict[str, Any]: embedder_model=args.embedder_model, full_potential=args.full_potential, defer_enrichment=use_deferred, + enable_rerank=getattr(args, "enable_rerank", False), + rerank_model=getattr(args, "rerank_model", None), ) hf_responder: Optional[HFResponder] = None @@ -494,6 +504,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--vector-store-provider", choices=["memory", "sqlite_vec"], default="memory") parser.add_argument("--history-db-path", default="/content/engram-longmemeval.db", help="SQLite db path.") parser.add_argument("--defer-enrichment", action="store_true", default=False, help="Use deferred enrichment (0 LLM calls at ingestion, 
batch enrich after).") + parser.add_argument("--enable-rerank", action="store_true", default=False, help="Enable neural reranking (cross-encoder second stage on retrieved results).") + parser.add_argument("--rerank-model", default=None, help="Reranker model override (default: nvidia/llama-3.2-nv-rerankqa-1b-v2).") args = parser.parse_args() args.full_potential = not args.minimal args.defer_enrichment = args.defer_enrichment diff --git a/engram/configs/base.py b/engram/configs/base.py index 271be63..44cc199 100644 --- a/engram/configs/base.py +++ b/engram/configs/base.py @@ -376,6 +376,16 @@ def _valid_priority(cls, v: str) -> str: return v +class RerankConfig(BaseModel): + """Configuration for neural reranking (cross-encoder second stage).""" + enable_rerank: bool = False + provider: str = "nvidia" # Currently only nvidia supported + model: str = "nvidia/llama-3.2-nv-rerankqa-1b-v2" + api_key_env: str = "NVIDIA_API_KEY" # Env var name for API key + top_n: int = 0 # Number of results to return after reranking (0 = return all, re-sorted) + config: Dict[str, Any] = Field(default_factory=dict) + + class EnrichmentConfig(BaseModel): """Configuration for unified enrichment (single LLM call for echo+category+entities+profiles).""" enable_unified: bool = False # Off by default for backward compat @@ -481,6 +491,7 @@ class MemoryConfig(BaseModel): parallel: ParallelConfig = Field(default_factory=ParallelConfig) batch: BatchConfig = Field(default_factory=BatchConfig) enrichment: EnrichmentConfig = Field(default_factory=EnrichmentConfig) + rerank: RerankConfig = Field(default_factory=RerankConfig) skill: SkillConfig = Field(default_factory=SkillConfig) task: TaskConfig = Field(default_factory=TaskConfig) metamemory: MetamemoryInlineConfig = Field(default_factory=MetamemoryInlineConfig) diff --git a/engram/embeddings/nvidia.py b/engram/embeddings/nvidia.py index 43b6914..52837df 100644 --- a/engram/embeddings/nvidia.py +++ b/engram/embeddings/nvidia.py @@ -33,11 +33,20 @@ def 
__init__(self, config: Optional[dict] = None): self.client = OpenAI(base_url=base_url, api_key=api_key, timeout=timeout) self.model = self.config.get("model", "nvidia/nv-embed-v1") - def _extra_body(self, memory_action: Optional[str] = None) -> dict: - """Build extra_body for E5/embedqa models.""" + def _extra_body(self, memory_action: Optional[str] = None, count: int = 1) -> dict: + """Build extra_body for models that need input_type differentiation. + + Args: + memory_action: The action type (search, forget, etc.) + count: Number of texts in the batch. nemotron-embed requires + modality list length to match input length. + """ if "e5" in self.model or "embedqa" in self.model: input_type = "query" if memory_action in ("search", "forget") else "passage" return {"input_type": input_type, "truncate": "END"} + if "nemotron-embed" in self.model: + input_type = "query" if memory_action in ("search", "forget") else "passage" + return {"modality": ["text"] * count, "input_type": input_type, "truncate": "END"} return {} def _truncate_if_needed(self, text: str) -> str: @@ -89,7 +98,7 @@ def embed_batch( if len(texts) == 1: return [self.embed(texts[0], memory_action=memory_action)] try: - extra_body = self._extra_body(memory_action) + extra_body = self._extra_body(memory_action, count=len(texts)) response = self.client.embeddings.create( input=texts, model=self.model, diff --git a/engram/memory/main.py b/engram/memory/main.py index 7e35ebf..eadd0d9 100644 --- a/engram/memory/main.py +++ b/engram/memory/main.py @@ -254,6 +254,8 @@ def __init__(self, config: Optional[MemoryConfig] = None, preset: Optional[str] self._profile_processor: Optional[ProfileProcessor] = None self._task_manager: Optional[Any] = None self._project_manager: Optional[Any] = None + # Neural reranker (lazy init) + self._reranker: Optional[Any] = None # Trajectory recording and skill mining self._trajectory_store: Optional[Any] = None self._skill_miner: Optional[Any] = None @@ -327,6 +329,20 @@ def 
skill_miner(self): ) return self._skill_miner + @property + def reranker(self): + """Lazy-initialized neural reranker (only if enabled in config).""" + rerank_cfg = getattr(self.config, "rerank", None) + if self._reranker is None and rerank_cfg and rerank_cfg.enable_rerank: + from engram.retrieval.reranker import create_reranker + self._reranker = create_reranker({ + "provider": rerank_cfg.provider, + "model": rerank_cfg.model, + "api_key_env": rerank_cfg.api_key_env, + **rerank_cfg.config, + }) + return self._reranker + def start_trajectory( self, task_description: str, @@ -2367,6 +2383,28 @@ def search( results.sort(key=lambda x: x["composite_score"], reverse=True) + # Neural reranking: cross-encoder second stage on top candidates + rerank_cfg = getattr(self.config, "rerank", None) + if rerank and self.reranker and results: + try: + passages = [r.get("memory", "") for r in results[:limit]] + reranked = self.reranker.rerank( + query=query, + passages=passages, + top_n=rerank_cfg.top_n if rerank_cfg and rerank_cfg.top_n > 0 else 0, + ) + # Re-order results by reranker logits + idx_to_logit = {r["index"]: r["logit"] for r in reranked} + for i, result in enumerate(results[:limit]): + result["rerank_logit"] = idx_to_logit.get(i, float("-inf")) + results[:limit] = sorted( + results[:limit], + key=lambda x: x.get("rerank_logit", float("-inf")), + reverse=True, + ) + except Exception as e: + logger.warning("Reranking failed, using composite_score order: %s", e) + # Metamemory: auto-log knowledge gap when search returns no results if not results and self.config.metamemory.auto_log_gaps: try: diff --git a/engram/retrieval/__init__.py b/engram/retrieval/__init__.py index 1942d25..10ed676 100644 --- a/engram/retrieval/__init__.py +++ b/engram/retrieval/__init__.py @@ -1,5 +1,10 @@ """Engram v2 retrieval components.""" -from engram.retrieval.dual_search import DualSearchEngine +try: + from engram.retrieval.dual_search import DualSearchEngine +except ImportError: + 
DualSearchEngine = None -__all__ = ["DualSearchEngine"] +from engram.retrieval.reranker import NvidiaReranker, create_reranker + +__all__ = ["DualSearchEngine", "NvidiaReranker", "create_reranker"] diff --git a/engram/retrieval/reranker.py b/engram/retrieval/reranker.py new file mode 100644 index 0000000..22842ba --- /dev/null +++ b/engram/retrieval/reranker.py @@ -0,0 +1,135 @@ +"""Neural reranker for second-stage retrieval refinement. + +Uses a cross-encoder model to re-score (query, passage) pairs with full +attention, producing much more accurate relevance scores than embedding +cosine similarity alone. +""" + +import logging +import os +import time +from typing import Any, Dict, List, Optional + +import requests + +logger = logging.getLogger(__name__) + + +class NvidiaReranker: + """NVIDIA NIM reranker using the /reranking endpoint.""" + + _DEFAULT_URL = ( + "https://ai.api.nvidia.com/v1/retrieval/" + "nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking" + ) + + def __init__(self, config: Optional[Dict[str, Any]] = None): + config = config or {} + self.model = config.get("model", "nvidia/llama-3.2-nv-rerankqa-1b-v2") + api_key_env = config.get("api_key_env", "NVIDIA_API_KEY") + self.api_key = config.get("api_key") or os.getenv(api_key_env) + if not self.api_key: + raise ValueError( + f"NVIDIA API key required for reranker. Set config['api_key'] or {api_key_env} env var." + ) + # Build URL from model name: replace / with _ and dots with _ + # e.g. nvidia/llama-3.2-nv-rerankqa-1b-v2 -> nvidia/llama-3_2-nv-rerankqa-1b-v2 + model_path = self.model.replace(".", "_") + self.url = config.get( + "url", + f"https://ai.api.nvidia.com/v1/retrieval/{model_path}/reranking", + ) + self.timeout = config.get("timeout", 30) + self.max_retries = config.get("max_retries", 2) + + def rerank( + self, + query: str, + passages: List[str], + top_n: int = 0, + ) -> List[Dict[str, Any]]: + """Rerank passages against a query. + + Args: + query: The search query. 
+ passages: List of passage texts to rerank. + top_n: Number of top results to return (0 = return all, re-sorted). + + Returns: + List of dicts with keys: index (original position), logit, text. + Sorted by logit descending. + """ + if not passages: + return [] + if len(passages) == 1: + return [{"index": 0, "logit": 0.0, "text": passages[0]}] + + payload = { + "model": self.model, + "query": {"text": query}, + "passages": [{"text": p} for p in passages], + } + if top_n > 0: + payload["top_n"] = top_n + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + "Accept": "application/json", + } + + last_exc = None + for attempt in range(self.max_retries + 1): + try: + t0 = time.monotonic() + resp = requests.post( + self.url, + json=payload, + headers=headers, + timeout=self.timeout, + ) + elapsed_ms = (time.monotonic() - t0) * 1000 + resp.raise_for_status() + data = resp.json() + + rankings = data.get("rankings", []) + results = [] + for r in rankings: + idx = r.get("index", 0) + results.append({ + "index": idx, + "logit": r.get("logit", 0.0), + "text": passages[idx] if idx < len(passages) else "", + }) + results.sort(key=lambda x: x["logit"], reverse=True) + logger.debug( + "Reranked %d passages in %.0fms (top logit=%.2f)", + len(passages), elapsed_ms, + results[0]["logit"] if results else 0.0, + ) + return results + + except Exception as exc: + last_exc = exc + if attempt < self.max_retries: + delay = min(2 ** attempt, 4) + logger.warning( + "Reranker retry %d/%d after %ss: %s", + attempt + 1, self.max_retries, delay, exc, + ) + time.sleep(delay) + else: + logger.error("Reranker failed after %d attempts: %s", self.max_retries + 1, exc) + + raise RuntimeError(f"Reranker failed: {last_exc}") from last_exc + + +def create_reranker(config: Optional[Dict[str, Any]] = None) -> Optional[NvidiaReranker]: + """Factory: create a reranker from config, or return None if disabled.""" + if not config: + return None + provider = 
config.get("provider", "nvidia") + if provider == "nvidia": + return NvidiaReranker(config) + logger.warning("Unknown reranker provider: %s", provider) + return None