diff --git a/.gitignore b/.gitignore index cb074d4..75b8d79 100644 --- a/.gitignore +++ b/.gitignore @@ -235,3 +235,10 @@ pyrightconfig.json # Local roadmap files ROADMAP.md agentmemory-roadmap.md + +# Hookify rules (personal) +.claude/*.local.md + +# Benchmark data and results +benchmarks/data/*.json +benchmarks/results/*/ diff --git a/Makefile b/Makefile index 72d7e55..87d73bc 100644 --- a/Makefile +++ b/Makefile @@ -43,6 +43,18 @@ docs: ## Build the documentation docs-serve: ## Build and serve the documentation uv run mkdocs serve +.PHONY: benchmark +benchmark: ## Run LongMemEval benchmark (all 3 stages) + uv run python -m benchmarks.longmemeval.run + +.PHONY: benchmark-smoke +benchmark-smoke: ## Quick 3-question benchmark sanity check + uv run python -m benchmarks.longmemeval.run --num-questions 3 --run-name smoke --config fast + +.PHONY: benchmark-baseline +benchmark-baseline: ## Full baseline benchmark run (concurrent) + uv run python -m benchmarks.longmemeval.run --run-name baseline --max-concurrent 20 + .PHONY: all all: format lint typecheck test ## Run formatting, linting, type checks, and tests diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/data/.gitkeep b/benchmarks/data/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/longmemeval/__init__.py b/benchmarks/longmemeval/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/longmemeval/_checkpoint.py b/benchmarks/longmemeval/_checkpoint.py new file mode 100644 index 0000000..bf5471c --- /dev/null +++ b/benchmarks/longmemeval/_checkpoint.py @@ -0,0 +1,47 @@ +"""JSONL checkpoint helpers for crash-safe benchmark runs.""" + +from __future__ import annotations + +import json +from pathlib import Path + +RESULTS_DIR = Path(__file__).parent.parent / "results" + + +def load_completed(jsonl_path: Path) -> set[str]: + """Load completed question IDs from checkpoint JSONL file.""" + completed: set[str] = set() + if not jsonl_path.exists(): + return completed + for line in jsonl_path.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line: + continue + try: + obj = json.loads(line) + completed.add(obj["question_id"]) + except (json.JSONDecodeError, KeyError): + continue + return completed + + +def append_jsonl(jsonl_path: Path, result: dict) -> None: + """Append a single result as a JSONL line (atomic append).""" + with jsonl_path.open("a", encoding="utf-8") as f: + f.write(json.dumps(result, ensure_ascii=False) + "\n") + + +def load_all_results(jsonl_path: Path) -> list[dict]: + """Load all results from checkpoint JSONL file.""" + results: list[dict] = [] + if not jsonl_path.exists(): + return results + for line in jsonl_path.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line: + continue + try: + results.append(json.loads(line)) + except json.JSONDecodeError: + continue + return results diff --git a/benchmarks/longmemeval/add.py b/benchmarks/longmemeval/add.py new file mode 100644 index 0000000..8f74134 --- /dev/null +++ b/benchmarks/longmemeval/add.py @@ -0,0 +1,232 @@ +"""Stage 1: Ingest LongMemEval conversation histories into memv.""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import time +from datetime import datetime, timezone +from pathlib import Path + +from memv.memory.memory import Memory +from memv.models import Message, MessageRole + +from ._checkpoint import RESULTS_DIR, append_jsonl, load_all_results, load_completed +from .config import get_config +from .dataset import LongMemEvalQuestion, load_dataset + +logger = logging.getLogger(__name__) + + +def parse_longmemeval_date(date_str: str) -> datetime: + """Parse LongMemEval date format: '2023/05/20 (Sat) 02:21' → datetime (UTC).""" + try: + dt = datetime.strptime(date_str, "%Y/%m/%d (%a) %H:%M") + return dt.replace(tzinfo=timezone.utc) + except ValueError: + logger.warning("Failed to parse date '%s', using epoch", date_str) + return datetime(2023, 1, 1, tzinfo=timezone.utc) + + +async def process_question( + question_idx: int, + question_data: LongMemEvalQuestion, + db_dir: Path, + config_name: str, + embedding_client, + llm_client, +) -> dict: + """Process a single LongMemEval question: ingest all sessions, extract knowledge.""" + question_id = question_data.question_id + user_id = f"question_{question_id}" + db_path = str(db_dir / f"{question_id}.db") + + config = get_config(config_name) + + memory = Memory( + db_path=db_path, + config=config, + embedding_client=embedding_client, + llm_client=llm_client, + enable_embedding_cache=True, + ) + + start_time = time.monotonic() + total_messages = 0 + + async with memory: + # Ingest each session + for session, date_str in zip(question_data.haystack_sessions, question_data.haystack_dates, strict=True): + timestamp = parse_longmemeval_date(date_str) + for turn in session: + role = MessageRole.USER if turn["role"] == "user" else MessageRole.ASSISTANT + msg = Message( + user_id=user_id, + role=role, + content=turn["content"], + sent_at=timestamp, + ) + await memory.add_message(msg) + total_messages += 1 + + # Extract knowledge + knowledge_count = await memory.process(user_id) + + elapsed = time.monotonic() - start_time + + return { + "question_id": question_id, + "question_type": question_data.question_type, + "messages_count": total_messages, + "knowledge_count": knowledge_count, + "sessions_count": len(question_data.haystack_sessions), + "construction_time_s": round(elapsed, 2), + } + + +async def run( + run_name: str = "baseline", + config_name: str = "default", + data_path: str | None = None, + num_questions: int | None = None, + max_concurrent: int = 5, + timeout: int = 1200, + resume: bool = True, + embedding_client=None, + llm_client=None, +): + """Run ingestion stage for all questions. + + Args: + run_name: Name for this benchmark run. + config_name: Config preset name from config.py. + data_path: Path to dataset JSON (None = default location). + num_questions: Limit number of questions (None = all). + max_concurrent: Max concurrent question processing tasks. + timeout: Per-question timeout in seconds. + resume: Resume from checkpoint if prior results exist. + embedding_client: EmbeddingClient instance. + llm_client: LLMClient instance. + """ + if embedding_client is None or llm_client is None: + raise RuntimeError("embedding_client and llm_client are required. Pass them directly or set up default clients.") + + dataset = load_dataset(data_path) + if num_questions is not None: + dataset = dataset[:num_questions] + + run_dir = RESULTS_DIR / run_name + db_dir = run_dir / "dbs" + db_dir.mkdir(parents=True, exist_ok=True) + + jsonl_path = run_dir / "add.jsonl" + + # Load checkpoint + completed_ids = load_completed(jsonl_path) if resume else set() + if not resume and jsonl_path.exists(): + jsonl_path.unlink() + + remaining = [q for q in dataset if q.question_id not in completed_ids] + + print( + f"LongMemEval Add | run={run_name} config={config_name} " + f"questions={len(dataset)} remaining={len(remaining)} concurrent={max_concurrent}" + ) + if completed_ids: + print(f" Resuming: {len(completed_ids)} already completed") + + semaphore = asyncio.Semaphore(max_concurrent) + completed_count = len(completed_ids) + total_count = len(dataset) + + async def process_with_guard(idx: int, question: LongMemEvalQuestion) -> dict | None: + nonlocal completed_count + async with semaphore: + try: + result = await asyncio.wait_for( + process_question(idx, question, db_dir, config_name, embedding_client, llm_client), + timeout=timeout, + ) + except asyncio.TimeoutError: + result = { + "question_id": question.question_id, + "question_type": question.question_type, + "error": "timeout", + "construction_time_s": timeout, + } + except Exception as e: + logger.exception("Failed to process question %s", question.question_id) + result = { + "question_id": question.question_id, + "question_type": question.question_type, + "error": str(e), + "construction_time_s": 0, + } + + append_jsonl(jsonl_path, result) + completed_count += 1 + error = result.get("error") + if error: + print(f" [{completed_count}/{total_count}] {question.question_id} ERROR: {error}") + else: + print( + f" [{completed_count}/{total_count}] {question.question_id} " + f"→ {result['knowledge_count']} facts in {result['construction_time_s']}s" + ) + return result + + tasks = [process_with_guard(idx, q) for idx, q in enumerate(remaining)] + await asyncio.gather(*tasks) + + # Write compatibility JSON from all JSONL results + all_results = load_all_results(jsonl_path) + output_path = run_dir / "add.json" + output_path.write_text(json.dumps(all_results, indent=2), encoding="utf-8") + print(f"\nResults saved to {output_path}") + + total_knowledge = sum(r.get("knowledge_count", 0) for r in all_results) + total_time = sum(r.get("construction_time_s", 0) for r in all_results) + print(f"Total: {total_knowledge} facts extracted in {total_time:.1f}s") + + return all_results + + +def _make_clients(): + """Create default OpenAI-based clients for CLI usage.""" + from memv.embeddings.openai import OpenAIEmbedAdapter + from memv.llm.pydantic_ai import PydanticAIAdapter + + return OpenAIEmbedAdapter(), PydanticAIAdapter() + + +def main(): + parser = argparse.ArgumentParser(description="LongMemEval Stage 1: Ingestion") + parser.add_argument("--run-name", default="baseline", help="Name for this run") + parser.add_argument("--config", default="default", help="Config preset name") + parser.add_argument("--data-path", default=None, help="Path to dataset JSON") + parser.add_argument("--num-questions", type=int, default=None, help="Limit number of questions") + parser.add_argument("--max-concurrent", type=int, default=5, help="Max concurrent question processing") + parser.add_argument("--timeout", type=int, default=1200, help="Per-question timeout in seconds") + parser.add_argument("--no-resume", action="store_true", help="Start fresh, ignore prior checkpoint") + args = parser.parse_args() + + embedding_client, llm_client = _make_clients() + asyncio.run( + run( + run_name=args.run_name, + config_name=args.config, + data_path=args.data_path, + num_questions=args.num_questions, + max_concurrent=args.max_concurrent, + timeout=args.timeout, + resume=not args.no_resume, + embedding_client=embedding_client, + llm_client=llm_client, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/longmemeval/config.py b/benchmarks/longmemeval/config.py new file mode 100644 index 0000000..0814b73 --- /dev/null +++ b/benchmarks/longmemeval/config.py @@ -0,0 +1,33 @@ +"""Named MemoryConfig presets for LongMemEval benchmark ablations.""" + +from __future__ import annotations + +from memv.config import MemoryConfig + +CONFIGS: dict[str, MemoryConfig] = { + "default": MemoryConfig(), + # Fast: skips predict-calibrate, dedup, and merging. For iteration speed only — + # results are NOT comparable to 'default' config. + "fast": MemoryConfig( + max_statements_for_prediction=0, + enable_knowledge_dedup=False, + enable_episode_merging=False, + ), + "no_predict_calibrate": MemoryConfig(max_statements_for_prediction=0), + "no_segmentation": MemoryConfig(use_legacy_segmentation=True, segmentation_threshold=9999), + "no_dedup": MemoryConfig(enable_knowledge_dedup=False, enable_episode_merging=False), +} + + +def get_config(name: str) -> MemoryConfig: + """Get a named config preset. + + Args: + name: One of: default, fast, no_predict_calibrate, no_segmentation, no_dedup. + + Returns: + MemoryConfig for the named preset. + """ + if name not in CONFIGS: + raise ValueError(f"Unknown config '{name}'. Available: {', '.join(CONFIGS)}") + return CONFIGS[name] diff --git a/benchmarks/longmemeval/dataset.py b/benchmarks/longmemeval/dataset.py new file mode 100644 index 0000000..78ce1b7 --- /dev/null +++ b/benchmarks/longmemeval/dataset.py @@ -0,0 +1,49 @@ +"""LongMemEval dataset loader and Pydantic models.""" + +from __future__ import annotations + +import json +from pathlib import Path + +from pydantic import BaseModel, field_validator + + +class LongMemEvalQuestion(BaseModel): + question_id: str + question_type: str + question: str + answer: str + + @field_validator("answer", mode="before") + @classmethod + def _coerce_answer(cls, v: object) -> str: + return str(v) + + question_date: str # "2023/05/20 (Sat) 02:21" + haystack_session_ids: list[str] + haystack_dates: list[str] + haystack_sessions: list[list[dict]] # list of sessions, each is list of {role, content} + answer_session_ids: list[str] + + +DEFAULT_DATA_PATH = Path(__file__).parent.parent / "data" / "longmemeval_s_cleaned.json" + + +def load_dataset(path: Path | str | None = None) -> list[LongMemEvalQuestion]: + """Load LongMemEval dataset from JSON file. + + Args: + path: Path to longmemeval_s_cleaned.json. Defaults to benchmarks/data/. + + Returns: + List of parsed questions. + """ + path = Path(path) if path else DEFAULT_DATA_PATH + if not path.exists(): + raise FileNotFoundError( + f"Dataset not found at {path}. Download it with:\n" + f" wget -P benchmarks/data/ " + f"https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned/resolve/main/longmemeval_s_cleaned.json" + ) + raw = json.loads(path.read_text(encoding="utf-8")) + return [LongMemEvalQuestion.model_validate(item) for item in raw] diff --git a/benchmarks/longmemeval/evaluate.py b/benchmarks/longmemeval/evaluate.py new file mode 100644 index 0000000..5215ff0 --- /dev/null +++ b/benchmarks/longmemeval/evaluate.py @@ -0,0 +1,276 @@ +"""Stage 3: LLM-judge evaluation of LongMemEval search results.""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +from datetime import datetime, timezone + +from ._checkpoint import RESULTS_DIR, append_jsonl, load_all_results, load_completed + +logger = logging.getLogger(__name__) + +# --- Type-specific judge prompts (adapted from Nemori/Zep LongMemEval evals) --- + +TEMPORAL_REASONING_PROMPT = """I will give you a question, a correct answer, and a response from a model. \ +Please answer yes if the response contains the correct answer. Otherwise, answer no. \ +If the response is equivalent to the correct answer or contains all the intermediate steps to get the correct answer, \ +you should also answer yes. If the response only contains a subset of the information required by the answer, answer no. \ +In addition, do not penalize off-by-one errors for the number of days. \ +If the question asks for the number of days/weeks/months, etc., and the model makes off-by-one errors \ +(e.g., predicting 19 days when the answer is 18), the model's response is still correct. + + +{question} + + +{gold_answer} + + +{response} +""" + +KNOWLEDGE_UPDATE_PROMPT = """I will give you a question, a correct answer, and a response from a model. \ +Please answer yes if the response contains the correct answer. Otherwise, answer no. \ +If the response contains some previous information along with an updated answer, \ +the response should be considered as correct as long as the updated answer is the required answer. + + +{question} + + +{gold_answer} + + +{response} +""" + +SINGLE_SESSION_PREFERENCE_PROMPT = """I will give you a question, a rubric for desired personalized response, \ +and a response from a model. Please answer yes if the response satisfies the desired response. Otherwise, answer no. \ +The model does not need to reflect all the points in the rubric. \ +The response is correct as long as it recalls and utilizes the user's personal information correctly. + + +{question} + + +{gold_answer} + + +{response} +""" + +DEFAULT_PROMPT = """I will give you a question, a correct answer, and a response from a model. \ +Please answer yes if the response contains the correct answer. Otherwise, answer no. \ +If the response is equivalent to the correct answer or contains all the intermediate steps \ +to get the correct answer, you should also answer yes. \ +If the response only contains a subset of the information required by the answer, answer no. + + +{question} + + +{gold_answer} + + +{response} +""" + +SYSTEM_PROMPT = "You are an expert grader. Respond with ONLY 'yes' or 'no'." + +PROMPTS_BY_TYPE = { + "temporal-reasoning": TEMPORAL_REASONING_PROMPT, + "knowledge-update": KNOWLEDGE_UPDATE_PROMPT, + "single-session-preference": SINGLE_SESSION_PREFERENCE_PROMPT, +} + + +async def evaluate_single( + llm_client, + question: str, + gold_answer: str, + response: str, + question_type: str, +) -> bool: + """Evaluate a single question-response pair using LLM judge.""" + template = PROMPTS_BY_TYPE.get(question_type, DEFAULT_PROMPT) + prompt = template.format(question=question, gold_answer=gold_answer, response=response) + full_prompt = f"{SYSTEM_PROMPT}\n\n{prompt}" + + result = await llm_client.generate(full_prompt) + return result.strip().lower().startswith("yes") + + +async def run( + run_name: str = "baseline", + llm_client=None, + max_concurrent: int = 10, + resume: bool = True, +): + """Run evaluation on search results. + + Args: + run_name: Name for this benchmark run (must match search stage). + llm_client: LLMClient instance for LLM-judge. + max_concurrent: Max concurrent LLM calls. + resume: Resume from checkpoint if prior results exist. + """ + if llm_client is None: + raise RuntimeError("llm_client is required.") + + run_dir = RESULTS_DIR / run_name + search_path = run_dir / "search.json" + if not search_path.exists(): + raise FileNotFoundError(f"No search results at {search_path}. Run search stage first.") + + data = json.loads(search_path.read_text(encoding="utf-8")) + + jsonl_path = run_dir / "evaluate.jsonl" + + # Load checkpoint + completed_ids = load_completed(jsonl_path) if resume else set() + if not resume and jsonl_path.exists(): + jsonl_path.unlink() + + remaining = [item for item in data if item["question_id"] not in completed_ids] + + print(f"LongMemEval Evaluate | run={run_name} questions={len(data)} remaining={len(remaining)}") + if completed_ids: + print(f" Resuming: {len(completed_ids)} already completed") + + # Evaluate with concurrency control + semaphore = asyncio.Semaphore(max_concurrent) + + async def eval_with_semaphore(item: dict) -> None: + async with semaphore: + # Skip items that errored in search stage + if item.get("error"): + scored = { + "question_id": item["question_id"], + "question_type": item.get("question_type"), + "is_correct": None, + "error": item["error"], + "question": item.get("question", ""), + "gold_answer": item.get("answer", ""), + "response": item.get("response", ""), + } + else: + try: + is_correct = await evaluate_single( + llm_client, + item["question"], + item["answer"], + item["response"], + item.get("question_type", "default"), + ) + scored = { + "question_id": item["question_id"], + "question_type": item.get("question_type"), + "is_correct": is_correct, + "question": item["question"], + "gold_answer": item["answer"], + "response": item["response"], + } + except Exception as e: + logger.error("Evaluation failed for %s: %s", item["question_id"], e) + scored = { + "question_id": item["question_id"], + "question_type": item.get("question_type"), + "is_correct": None, + "error": f"eval_failed: {e}", + "question": item.get("question", ""), + "gold_answer": item.get("answer", ""), + "response": item.get("response", ""), + } + append_jsonl(jsonl_path, scored) + + tasks = [eval_with_semaphore(item) for item in remaining] + await asyncio.gather(*tasks) + + # Load all results (checkpoint + new) + all_scored = load_all_results(jsonl_path) + + # Aggregate scores — exclude errored items + type_stats: dict[str, dict[str, int]] = {} + total_correct = 0 + total_scored = 0 + total_errors = 0 + + for scored in all_scored: + qtype = scored.get("question_type", "unknown") + if qtype not in type_stats: + type_stats[qtype] = {"correct": 0, "total": 0} + + if scored.get("is_correct") is None: + total_errors += 1 + continue + + type_stats[qtype]["total"] += 1 + total_scored += 1 + + if scored["is_correct"]: + type_stats[qtype]["correct"] += 1 + total_correct += 1 + + # Calculate accuracies + overall_accuracy = total_correct / total_scored if total_scored > 0 else 0 + accuracy_by_type = {} + for qtype, stats in sorted(type_stats.items()): + acc = stats["correct"] / stats["total"] if stats["total"] > 0 else 0 + accuracy_by_type[qtype] = { + "correct": stats["correct"], + "total": stats["total"], + "accuracy": round(acc, 4), + } + + scores = { + "run_name": run_name, + "total_questions": len(all_scored), + "scored_questions": total_scored, + "errors": total_errors, + "correct_answers": total_correct, + "overall_accuracy": round(overall_accuracy, 4), + "accuracy_by_type": accuracy_by_type, + "evaluation_timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC"), + "scored_items": all_scored, + } + + # Print summary + print(f"\n{'=' * 50}") + print(f"Overall: {total_correct}/{total_scored} = {overall_accuracy:.1%}") + if total_errors: + print(f"Errors (excluded from scoring): {total_errors}") + print(f"{'=' * 50}") + for qtype, stats in sorted(accuracy_by_type.items()): + print(f" {qtype}: {stats['correct']}/{stats['total']} = {stats['accuracy']:.1%}") + print(f"{'=' * 50}") + + # Save + output_path = run_dir / "scores.json" + output_path.write_text(json.dumps(scores, indent=2, ensure_ascii=False), encoding="utf-8") + print(f"\nScores saved to {output_path}") + + return scores + + +def _make_llm_client(): + from memv.llm.pydantic_ai import PydanticAIAdapter + + return PydanticAIAdapter() + + +def main(): + parser = argparse.ArgumentParser(description="LongMemEval Stage 3: Evaluation") + parser.add_argument("--run-name", default="baseline", help="Name for this run") + parser.add_argument("--max-concurrent", type=int, default=10, help="Max concurrent LLM calls") + parser.add_argument("--no-resume", action="store_true", help="Start fresh, ignore prior checkpoint") + args = parser.parse_args() + + llm_client = _make_llm_client() + asyncio.run(run(run_name=args.run_name, llm_client=llm_client, max_concurrent=args.max_concurrent, resume=not args.no_resume)) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/longmemeval/run.py b/benchmarks/longmemeval/run.py new file mode 100644 index 0000000..19b650e --- /dev/null +++ b/benchmarks/longmemeval/run.py @@ -0,0 +1,121 @@ +"""End-to-end runner for LongMemEval benchmark.""" + +from __future__ import annotations + +import argparse +import asyncio +import time + +from . import add, evaluate, search + + +def _make_clients(model: str = "openai:gpt-4.1-mini"): + from memv.embeddings.openai import OpenAIEmbedAdapter + from memv.llm.pydantic_ai import PydanticAIAdapter + + return OpenAIEmbedAdapter(), PydanticAIAdapter(model=model) + + +async def run( + run_name: str = "baseline", + config_name: str = "default", + data_path: str | None = None, + num_questions: int | None = None, + max_concurrent: int = 5, + timeout: int = 1200, + top_k: int = 10, + model: str = "openai:gpt-4.1-mini", + stages: list[str] | None = None, + resume: bool = True, +): + stages = stages or ["add", "search", "evaluate"] + embedding_client, llm_client = _make_clients(model=model) + print(f"Model: {model}") + + total_start = time.monotonic() + + if "add" in stages: + print(f"\n{'=' * 60}") + print("STAGE 1: ADD") + print(f"{'=' * 60}\n") + await add.run( + run_name=run_name, + config_name=config_name, + data_path=data_path, + num_questions=num_questions, + max_concurrent=max_concurrent, + timeout=timeout, + resume=resume, + embedding_client=embedding_client, + llm_client=llm_client, + ) + + if "search" in stages: + print(f"\n{'=' * 60}") + print("STAGE 2: SEARCH") + print(f"{'=' * 60}\n") + await search.run( + run_name=run_name, + config_name=config_name, + data_path=data_path, + num_questions=num_questions, + top_k=top_k, + max_concurrent=max_concurrent * 2, # search is lighter than add + timeout=timeout, + resume=resume, + embedding_client=embedding_client, + llm_client=llm_client, + ) + + if "evaluate" in stages: + print(f"\n{'=' * 60}") + print("STAGE 3: EVALUATE") + print(f"{'=' * 60}\n") + await evaluate.run( + run_name=run_name, + llm_client=llm_client, + resume=resume, + ) + + total_elapsed = time.monotonic() - total_start + print(f"\n{'=' * 60}") + print(f"Done. Total time: {total_elapsed / 60:.1f} min") + print(f"{'=' * 60}") + + +def main(): + parser = argparse.ArgumentParser(description="LongMemEval Benchmark Runner") + parser.add_argument("--run-name", default="baseline", help="Name for this run") + parser.add_argument("--config", default="default", help="Config preset name") + parser.add_argument("--data-path", default=None, help="Path to dataset JSON") + parser.add_argument("--num-questions", type=int, default=None, help="Limit number of questions") + parser.add_argument("--max-concurrent", type=int, default=5, help="Max concurrent question processing") + parser.add_argument("--timeout", type=int, default=1200, help="Per-question timeout in seconds") + parser.add_argument("--top-k", type=int, default=10, help="Number of memories to retrieve") + parser.add_argument( + "--model", + default="openai:gpt-4.1-mini", + help="PydanticAI model string (e.g. google-gla:gemini-2.5-flash, groq:llama-3.3-70b-versatile)", + ) + parser.add_argument("--stages", default="add,search,evaluate", help="Comma-separated stages to run") + parser.add_argument("--no-resume", action="store_true", help="Start fresh, ignore prior checkpoints") + args = parser.parse_args() + + asyncio.run( + run( + run_name=args.run_name, + config_name=args.config, + data_path=args.data_path, + num_questions=args.num_questions, + max_concurrent=args.max_concurrent, + timeout=args.timeout, + top_k=args.top_k, + model=args.model, + stages=args.stages.split(","), + resume=not args.no_resume, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/longmemeval/search.py b/benchmarks/longmemeval/search.py new file mode 100644 index 0000000..4df2b3e --- /dev/null +++ b/benchmarks/longmemeval/search.py @@ -0,0 +1,264 @@ +"""Stage 2: Retrieve memories and generate answers for LongMemEval questions.""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import time +from pathlib import Path + +from memv.memory.memory import Memory + +from ._checkpoint import RESULTS_DIR, append_jsonl, load_all_results, load_completed +from .config import get_config +from .dataset import LongMemEvalQuestion, load_dataset + +logger = logging.getLogger(__name__) + +ANSWER_PROMPT = """You are a memory assistant that retrieves accurate information from conversation memories. + +## Instructions +1. Carefully analyze all provided memories +2. Pay special attention to timestamps to determine the correct answer +3. If memories contain contradictory information, prioritize the most recent memory +4. Convert relative time references to specific dates using the question date as reference +5. The answer should be concise (less than 5-6 words) + +## Memories +{memories} + +## Question Date +{question_date} + +## Question +{question} + +Answer:""" + + +async def process_question( + question_data: LongMemEvalQuestion, + db_dir: Path, + config_name: str, + embedding_client, + llm_client, + top_k: int = 10, +) -> dict: + """Retrieve and answer a single question.""" + question_id = question_data.question_id + user_id = f"question_{question_id}" + db_path = str(db_dir / f"{question_id}.db") + + if not Path(db_path).exists(): + return { + "question_id": question_id, + "question": question_data.question, + "question_type": question_data.question_type, + "answer": question_data.answer, + "question_date": question_data.question_date, + "response": "", + "retrieval_time_s": 0, + "error": f"DB not found: {db_path}", + } + + config = get_config(config_name) + + memory = Memory( + db_path=db_path, + config=config, + embedding_client=embedding_client, + llm_client=llm_client, + ) + + start_time = time.monotonic() + + async with memory: + result = await memory.retrieve(question_data.question, user_id=user_id, top_k=top_k) + retrieval_time = time.monotonic() - start_time + + # Format memories for the answer prompt + memory_lines = [] + for k in result.retrieved_knowledge: + validity = "" + if k.valid_at: + validity = f" [valid from {k.valid_at.strftime('%Y-%m-%d')}]" + if k.invalid_at: + validity += f" [invalid after {k.invalid_at.strftime('%Y-%m-%d')}]" + memory_lines.append(f"- {k.statement}{validity}") + + memories_text = "\n".join(memory_lines) if memory_lines else "No relevant memories found." + + # Generate answer + prompt = ANSWER_PROMPT.format( + memories=memories_text, + question_date=question_data.question_date, + question=question_data.question, + ) + response = await llm_client.generate(prompt) + + return { + "question_id": question_id, + "question": question_data.question, + "question_type": question_data.question_type, + "answer": question_data.answer, + "question_date": question_data.question_date, + "response": response.strip(), + "retrieved_count": len(result.retrieved_knowledge), + "retrieval_time_s": round(retrieval_time, 3), + } + + +async def run( + run_name: str = "baseline", + config_name: str = "default", + data_path: str | None = None, + num_questions: int | None = None, + top_k: int = 10, + max_concurrent: int = 10, + timeout: int = 1200, + resume: bool = True, + embedding_client=None, + llm_client=None, +): + """Run search stage for all questions. + + Args: + run_name: Name for this benchmark run (must match add stage). + config_name: Config preset name. + data_path: Path to dataset JSON. + num_questions: Limit number of questions. + top_k: Number of memories to retrieve per question. + max_concurrent: Max concurrent question processing tasks. + timeout: Per-question timeout in seconds. + resume: Resume from checkpoint if prior results exist. + embedding_client: EmbeddingClient instance. + llm_client: LLMClient instance. + """ + if embedding_client is None or llm_client is None: + raise RuntimeError("embedding_client and llm_client are required.") + + dataset = load_dataset(data_path) + if num_questions is not None: + dataset = dataset[:num_questions] + + run_dir = RESULTS_DIR / run_name + db_dir = run_dir / "dbs" + if not db_dir.exists(): + raise FileNotFoundError(f"No DBs found at {db_dir}. Run add stage first.") + + jsonl_path = run_dir / "search.jsonl" + + # Load checkpoint + completed_ids = load_completed(jsonl_path) if resume else set() + if not resume and jsonl_path.exists(): + jsonl_path.unlink() + + remaining = [q for q in dataset if q.question_id not in completed_ids] + + print( + f"LongMemEval Search | run={run_name} config={config_name} " + f"questions={len(dataset)} remaining={len(remaining)} top_k={top_k} concurrent={max_concurrent}" + ) + if completed_ids: + print(f" Resuming: {len(completed_ids)} already completed") + + semaphore = asyncio.Semaphore(max_concurrent) + completed_count = len(completed_ids) + total_count = len(dataset) + + async def process_with_guard(question: LongMemEvalQuestion) -> dict | None: + nonlocal completed_count + async with semaphore: + try: + result = await asyncio.wait_for( + process_question(question, db_dir, config_name, embedding_client, llm_client, top_k), + timeout=timeout, + ) + except asyncio.TimeoutError: + result = { + "question_id": question.question_id, + "question": question.question, + "question_type": question.question_type, + "answer": question.answer, + "question_date": question.question_date, + "response": "", + "error": "timeout", + "retrieval_time_s": timeout, + } + except Exception as e: + logger.exception("Failed to process question %s", question.question_id) + result = { + "question_id": question.question_id, + "question": question.question, + "question_type": question.question_type, + "answer": question.answer, + "question_date": question.question_date, + "response": "", + "error": str(e), + "retrieval_time_s": 0, + } + + append_jsonl(jsonl_path, result) + completed_count += 1 + error = result.get("error") + if error: + print(f" [{completed_count}/{total_count}] {question.question_id} ERROR: {error}") + else: + print( + f" [{completed_count}/{total_count}] {question.question_id} " + f"→ {result['retrieved_count']} memories, {result['retrieval_time_s']}s" + ) + return result + + tasks = [process_with_guard(q) for q in remaining] + await asyncio.gather(*tasks) + + # Write compatibility JSON from all JSONL results + all_results = load_all_results(jsonl_path) + output_path = run_dir / "search.json" + output_path.write_text(json.dumps(all_results, indent=2, ensure_ascii=False), encoding="utf-8") + print(f"\nResults saved to {output_path}") + + return all_results + + +def _make_clients(): + from memv.embeddings.openai import OpenAIEmbedAdapter + from memv.llm.pydantic_ai import PydanticAIAdapter + + return OpenAIEmbedAdapter(), PydanticAIAdapter() + + +def main(): + parser = argparse.ArgumentParser(description="LongMemEval Stage 2: Search + Answer") + parser.add_argument("--run-name", default="baseline", help="Name for this run") + parser.add_argument("--config", default="default", help="Config preset name") + parser.add_argument("--data-path", default=None, help="Path to dataset JSON") + parser.add_argument("--num-questions", type=int, default=None, help="Limit number of questions") + parser.add_argument("--top-k", type=int, default=10, help="Number of memories to retrieve") + parser.add_argument("--max-concurrent", type=int, default=10, help="Max concurrent question processing") + parser.add_argument("--timeout", type=int, default=1200, help="Per-question timeout in seconds") + parser.add_argument("--no-resume", action="store_true", help="Start fresh, ignore prior checkpoint") + args = parser.parse_args() + + embedding_client, llm_client = _make_clients() + asyncio.run( + run( + run_name=args.run_name, + config_name=args.config, + data_path=args.data_path, + num_questions=args.num_questions, + top_k=args.top_k, + max_concurrent=args.max_concurrent, + timeout=args.timeout, + resume=not args.no_resume, + embedding_client=embedding_client, + llm_client=llm_client, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/results/.gitkeep b/benchmarks/results/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/src/memv/memory/_pipeline.py b/src/memv/memory/_pipeline.py index 256364d..a703298 100644 --- a/src/memv/memory/_pipeline.py +++ b/src/memv/memory/_pipeline.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import logging from typing import TYPE_CHECKING @@ -18,6 +19,8 @@ logger = logging.getLogger(__name__) +MAX_CONCURRENT_EPISODES = 10 + class Pipeline: """Handles message processing, episode generation, and knowledge extraction.""" @@ -52,13 +55,20 @@ async def process(self, user_id: str) -> int: # Segment into episodes episodes_messages = await self._segment_messages(unprocessed) - # Process episodes sequentially to ensure each sees prior extractions - total_extracted = 0 - for messages in episodes_messages: - count = await self._process_episode(messages, user_id) - total_extracted += count + # Process episodes with concurrent LLM/embedding calls. aiosqlite serializes + # DB writes through a single thread, so KB reads/writes are sequential — the + # actual parallelism is on API I/O (segmentation, extraction, embedding). + # Trade-off: episodes see stale KB state (predict-calibrate can't suppress + # intra-batch duplicates). Dedup catches overlap post-extraction. + # This matches Nemori's approach — parallelism + dedup over sequential fidelity. + semaphore = asyncio.Semaphore(MAX_CONCURRENT_EPISODES) + + async def _guarded(msgs: list[Message]) -> int: + async with semaphore: + return await self._process_episode(msgs, user_id) - return total_extracted + counts = await asyncio.gather(*[_guarded(msgs) for msgs in episodes_messages]) + return sum(counts) async def process_messages(self, messages: list[Message], user_id: str) -> int: """ diff --git a/src/memv/processing/batch_segmenter.py b/src/memv/processing/batch_segmenter.py index 5be3084..3470807 100644 --- a/src/memv/processing/batch_segmenter.py +++ b/src/memv/processing/batch_segmenter.py @@ -5,6 +5,7 @@ handling interleaved topics and time gaps correctly. """ +import asyncio import json from datetime import timedelta @@ -12,6 +13,8 @@ from memv.processing.prompts import batch_segmentation_prompt from memv.protocols import LLMClient +MAX_CONCURRENT_SEGMENTATIONS = 10 + class BatchSegmenter: """ @@ -44,7 +47,7 @@ async def segment(self, messages: list[Message]) -> list[list[Message]]: Flow: 1. Split on time gaps first (creates independent batches) - 2. For each batch, use LLM to group by topic + 2. For each batch, use LLM to group by topic (concurrently) 3. Return all episode groups Args: @@ -62,16 +65,20 @@ async def segment(self, messages: list[Message]) -> list[list[Message]]: # Step 1: Split on time gaps time_batches = self._split_on_time_gaps(messages) - # Step 2: Segment each batch semantically - all_episodes: list[list[Message]] = [] - for batch in time_batches: + # Step 2: Segment each batch semantically (concurrently) + semaphore = asyncio.Semaphore(MAX_CONCURRENT_SEGMENTATIONS) + + async def _segment_or_passthrough(batch: list[Message]) -> list[list[Message]]: if len(batch) <= 2: - # Small batches don't need LLM segmentation - all_episodes.append(batch) - else: - # Use LLM to group by topic - episode_groups = await self._segment_batch(batch) - all_episodes.extend(episode_groups) + return [batch] + async with semaphore: + return await self._segment_batch(batch) + + batch_results = await asyncio.gather(*[_segment_or_passthrough(b) for b in time_batches]) + + all_episodes: list[list[Message]] = [] + for groups in batch_results: + all_episodes.extend(groups) return all_episodes diff --git a/src/memv/processing/prompts.py b/src/memv/processing/prompts.py index 4a609c4..65d0ecb 100644 --- a/src/memv/processing/prompts.py +++ b/src/memv/processing/prompts.py @@ -39,15 +39,18 @@ # ============================================================================= KNOWLEDGE_CATEGORIES = """ -Extract knowledge that fits these categories: - -- **Identity & Background**: Name, profession, location, education, demographics -- **Persistent Preferences**: Technology choices, communication style, work patterns -- **Technical Details**: Stack, tools, projects, codebases, technical constraints -- **Relationships**: Family, colleagues, pets, organizations they belong to -- **Goals & Plans**: Short and long-term objectives, deadlines, milestones -- **Beliefs & Values**: Opinions, priorities, decision-making criteria -- **Habits & Patterns**: Recurring behaviors, routines, typical responses +Extract knowledge ABOUT THE USER that fits these categories: + +- **Identity & Background**: User's name, profession, location, education, demographics +- **Persistent Preferences**: User's technology choices, communication style, work patterns +- **Technical Details**: User's stack, tools, projects, codebases, technical constraints +- **Relationships**: User's family, colleagues, pets, organizations they belong to +- **Goals & Plans**: User's short and long-term objectives, deadlines, milestones +- **Beliefs & Values**: User's opinions, priorities, decision-making criteria +- **Habits & Patterns**: User's recurring behaviors, routines, typical responses + +CRITICAL: Only extract facts that help understand the USER long-term. +Do NOT extract general knowledge, topic content, or information the assistant provided as educational material. """ # ============================================================================= @@ -58,12 +61,17 @@ EXCLUSIONS = """ Do NOT extract: +- **General/topical knowledge**: Facts about the world, science, history, technology, etc. + (e.g., "Radiation therapy uses ionizing radiation", "Bitcoin uses blockchain", "Python is a programming language") +- **Educational content from assistant**: Information the assistant explained or taught + (e.g., "HTTP uses TCP", "Kubernetes orchestrates containers") +- **Conversation topic summaries**: What the conversation was about, not facts about the user + (e.g., "The conversation covered cooking techniques", "They discussed radiation therapy") - Temporary emotions or reactions ("user seems frustrated") - Single conversation acknowledgments ("user said thanks") - Vague statements without specifics ("user likes food") - Context-dependent information ("user prefers this one") - Generic pleasantries or filler -- Obvious or common knowledge - Speculative or uncertain claims - Conversation events ("User asked about X", "User requested Y") - extract the FACT, not the action @@ -283,6 +291,10 @@ def cold_start_extraction_prompt(episode_title: str, original_messages: list[dic - "User moved to Berlin in 2023" (resolved, not "last year") ### BAD Extractions: +- "Radiation therapy uses ionizing radiation to kill cancer cells" (general knowledge, not about the user) +- "Bitcoin is a decentralized cryptocurrency" (topic content, not about the user) +- "A kitchen knife should be sharpened at a 15-20 degree angle" (educational content from assistant) +- "The fox-chicken-grain riddle is a classic river crossing puzzle" (general knowledge) - "I use JavaScript" (raw copy - should be "User uses JavaScript") - "He started using it yesterday" (unresolved pronoun + relative time → "User started using FastAPI on 2024-06-14") - "They moved there last year" (unresolved pronoun + relative time → "User moved to Berlin in 2023") @@ -308,7 +320,9 @@ def cold_start_extraction_prompt(episode_title: str, original_messages: list[dic - invalid_at: ISO 8601 datetime when fact stops being true, or null if still true (e.g., "2024-12-31T23:59:59Z") - confidence: 0.0-1.0 -Extract ALL concrete facts. Multiple extractions from one episode is expected.""" +Quality over quantity — fewer valuable statements about the USER are better than many generic ones. +Only extract facts that help understand the user long-term. If a conversation is about a general topic +(cooking, physics, history) but reveals nothing personal about the user, return an EMPTY list.""" def extraction_prompt_with_prediction(prediction: str, conversation: str, reference_timestamp: str | None = None) -> str: @@ -367,6 +381,8 @@ def extraction_prompt_with_prediction(prediction: str, conversation: str, refere - "User moved to Berlin in 2023" (resolved, not "last year") ### BAD Extractions: +- "Radiation therapy uses ionizing radiation" (general knowledge, not about user) +- "Bitcoin uses proof-of-work consensus" (topic content, not about user) - "He started using it yesterday" (unresolved pronoun + relative time) - "They moved there last year" (unresolved pronoun + relative time) - "User is interested in X" (too vague) @@ -387,7 +403,9 @@ def extraction_prompt_with_prediction(prediction: str, conversation: str, refere - invalid_at: ISO 8601 datetime when fact stops being true, or null if still true (e.g., "2024-12-31T23:59:59Z") - confidence: 0.0-1.0 -Return EMPTY LIST if no concrete facts found beyond the prediction.""" +Quality over quantity — fewer valuable statements about the USER are better than many generic ones. +Return EMPTY LIST if no facts about the user are found beyond the prediction. +General knowledge or topic content discussed in conversation is NOT extractable.""" # =============================================================================