import json
import re
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Protocol

from datasets import load_dataset
from tqdm import tqdm

from athena.app.services.base_service import BaseService
from athena.app.services.llm_service import LLMService
from athena.models import (
    Action,
    MemoryContext,
    MemoryTimestamp,
    MemoryUnit,
    Message,
    Result,
    State,
    Task,
)
from athena.prompts.memory_extraction import (
    ACTION_EXTRACTION_PROMPT,
    ACTION_JUDGE_PROMPT,
    RESULT_EXTRACTION_PROMPT,
    STATE_DONE_SUMMARY_PROMPT,
    STATE_TODO_SYNTHESIS_PROMPT,
    TASK_EXTRACTION_PROMPT,
)


class TrajectoryDataSource(Protocol):
    """Structural (duck-typed) interface for pluggable trajectory providers.

    Any object that exposes these two methods can feed trajectories into the
    memory-extraction pipeline, regardless of where the data actually lives
    (HuggingFace dataset, local files, remote API, ...).
    """

    def load_trajectories(self) -> List[List[Message]]:
        """Return every trajectory as an ordered list of messages."""
        ...

    def get_metadata(self, trajectory_id: str) -> Dict[str, Any]:
        """Return provider-specific metadata for one trajectory."""
        ...
class ExtractionStrategy(Protocol):
    """Structural interface for component-extraction back ends (LLM-based, rule-based, hybrid)."""

    def extract_task(self, messages: List[Message]) -> Task:
        """Extract task information from trajectory messages."""
        ...

    def extract_action(self, message: Message) -> Action:
        """Extract action information from a single message."""
        ...

    def extract_result(self, messages: List[Message], action: Action) -> Result:
        """Extract result information from a message sequence."""
        ...

    def synthesize_state(
        self, prior_units: List[MemoryUnit], current_messages: List[Message], task: Task
    ) -> State:
        """Synthesize state information from context."""
        ...


class MemoryExtractionService(BaseService):
    """
    Service for extracting structured memory units from interaction trajectories.

    The service converts unstructured interaction data (conversation logs,
    agent execution traces) into structured ``MemoryUnit`` objects for the
    Athena memory system. Pipeline: load trajectories from a data source,
    extract components (task, action, result, state), assemble and validate
    memory units, then persist them.

    Usage:
        service = MemoryExtractionService(llm_service, extraction_strategy)
        memory_units = service.extract_from_huggingface_trajectory_repository(repo, split)
    """

    def __init__(
        self,
        llm_service: LLMService,
        extraction_strategy: Optional[ExtractionStrategy] = None,
        batch_size: int = 100,
        max_retries: int = 3,
        enable_deduplication: bool = True,
    ):
        """
        Initialize the Memory Extraction Service.

        Args:
            llm_service: LLM service for AI-powered extraction
            extraction_strategy: Strategy for extracting components (optional)
            batch_size: Number of trajectories to process in each batch
            max_retries: Maximum retry attempts for failed extractions
            enable_deduplication: Whether to deduplicate memory units
        """
        self.llm_service = llm_service
        self.extraction_strategy = extraction_strategy
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.enable_deduplication = enable_deduplication
        self._extraction_cache: Dict[str, MemoryUnit] = {}
        # BUG FIX: other methods read `self.memory_store`, but it was never
        # initialized, causing AttributeError. Default to None until a
        # PostgreSqlMemoryStore is wired in. TODO confirm BaseService does
        # not already define this attribute.
        self.memory_store: Optional[Any] = None

    def start(self):
        """Start the memory extraction service (lazy-init the default strategy)."""
        if self.extraction_strategy is None:
            self.extraction_strategy = self._create_default_strategy()

    def close(self):
        """Close the memory extraction service and release cached units."""
        self._extraction_cache.clear()

    def extract_from_huggingface_trajectory_repository(
        self, repo_name: str, split: str
    ) -> List[MemoryUnit]:
        """
        Extract memory units from a HuggingFace trajectory repository.

        Args:
            repo_name: Name of the HuggingFace trajectory repository
            split: Split of the HuggingFace trajectory repository

        Returns:
            List of extracted memory units
        """
        dataset = load_dataset(repo_name, split=split)

        traj_keys = ["messages", "history", "trajectory", "chat", "conversation", "dialog"]
        run_id_keys = ["run_id", "traj_id", "id"]
        inst_id_keys = ["instance_id"]
        model_keys = ["model", "model_name"]
        # bool value: whether the model solved the issue in this trajectory.
        resolved_keys = ["resolved", "target"]

        def _pick(d: Dict[str, Any], keys) -> Optional[Any]:
            # Return the value of the first key present in the row, else None.
            for k in keys:
                if k in d:
                    return d[k]
            return None

        for idx in range(len(dataset)):
            row = dataset[idx]
            # BUG FIX: a missing trajectory key returned None and crashed the
            # loop below; fall back to an empty sequence instead.
            trajectory = _pick(row, traj_keys) or []
            messages: List[Message] = []
            for m in trajectory:
                content = str(
                    m.get("content")
                    or m.get("text")
                    or m.get("message")
                    or m.get("utterance")
                    or ""
                )
                role = str(m.get("role", "unknown"))
                metadata = {
                    k: v
                    for k, v in m.items()
                    if k not in {"content", "role", "text", "message", "utterance"}
                }
                messages.append(Message(content=content, role=role, metadata=metadata))

            # BUG FIX: the original wrapped _pick(...) in str() BEFORE testing
            # for None, so a missing field became the truthy string "None":
            # the idx fallback never fired and "None" leaked into metadata.
            raw_run_id = _pick(row, run_id_keys)
            run_id = str(raw_run_id) if raw_run_id is not None else str(idx)
            instance_id = _pick(row, inst_id_keys)
            model = _pick(row, model_keys)
            resolved = _pick(row, resolved_keys)
            context = self._create_memory_context(
                repo_name,
                run_id,
                metadata={
                    **({"instance_id": str(instance_id)} if instance_id is not None else {}),
                    **({"model": str(model)} if model is not None else {}),
                    **({"resolved": bool(resolved)} if resolved is not None else {}),
                },
            )
            memory_units = self._extract_memory_units_by_action_windows(messages, context)
            self._extraction_cache.update(
                {mu.context.memory_id: mu for mu in memory_units}
            )  # TODO: use PostgreSqlMemoryStore to store memory units.
        return list(self._extraction_cache.values())
+ + Args: + repo_name: Name of the HuggingFace trajectory repository + split: Split of the HuggingFace trajectory repository + + Returns: + List of extracted memory units + + Raises: + ExtractionError: If extraction fails for critical trajectories + """ + dataset = load_dataset(repo_name, split=split) + + traj_keys = ["messages", "history", "trajectory", "chat", "conversation", "dialog"] + run_id_keys = ["run_id", "traj_id", "id"] + inst_id_keys = ["instance_id"] + model_keys = ["model", "model_name"] + resolved_keys = [ + "resolved", + "target", + ] # bool value, whether the model solved the issue in this trajectory. + + def _pick(d: Dict[str, Any], keys) -> Optional[Any]: + for k in keys: + if k in d: + return d[k] + return None + + for idx in range(len(dataset)): + row = dataset[idx] + trajectory = _pick(row, traj_keys) + messages: List[Message] = [] + for m in trajectory: + content = str( + m.get("content") + or m.get("text") + or m.get("message") + or m.get("utterance") + or "" + ) + role = str(m.get("role", "unknown")) + metadata = { + k: v + for k, v in m.items() + if k not in {"content", "role", "text", "message", "utterance"} + } + messages.append(Message(content=content, role=role, metadata=metadata)) + + run_id = str(_pick(row, run_id_keys)) or f"{idx}" + instance_id = str(_pick(row, inst_id_keys)) + model = str(_pick(row, model_keys)) + resolved = bool(_pick(row, resolved_keys)) + context = self._create_memory_context( + repo_name, + run_id, + metadata={ + **({"instance_id": instance_id} if instance_id is not None else {}), + **({"model": model} if model is not None else {}), + **({"resolved": resolved} if resolved is not None else {}), + }, + ) + memory_units = self._extract_memory_units_by_action_windows(messages, context) + self._extraction_cache.update( + {mu.context.memory_id: mu for mu in memory_units} + ) # TODO: use PostgreSqlMemoryStore to store memory units. 
+ return list(self._extraction_cache.values()) + + def extract_from_data_source(self, data_source: TrajectoryDataSource) -> List[MemoryUnit]: + """ + Extract memory units from a configured data source. + + Args: + data_source: Data source implementing TrajectoryDataSource protocol + + Returns: + List of all extracted memory units + """ + pass + + def batch_extract( + self, + trajectory_batches: List[List[List[Message]]], + progress_callback: Optional[callable] = None, + ) -> List[MemoryUnit]: + """ + Extract memory units from multiple trajectory batches with progress tracking. + + Args: + trajectory_batches: List of trajectory batches to process + progress_callback: Optional callback for progress updates + + Returns: + List of all extracted memory units + """ + pass + + def _create_default_strategy(self) -> ExtractionStrategy: + """ + Create a default extraction strategy. + + Returns: + Default extraction strategy implementation + """ + pass + + def _create_memory_context( + self, source: str, run_id: str, metadata: Optional[Dict[str, Any]] = {} + ) -> MemoryContext: + """ + Create memory context for a trajectory. 
+ + Args: + source: Data source identifier + run_id: Unique run identifier + metadata: Optional additional metadata + + Returns: + MemoryContext object + """ + return MemoryContext( + memory_id=str(uuid.uuid4()), + source=source, + run_id=run_id, + timestamp=MemoryTimestamp( + created_at=datetime.now(timezone.utc).isoformat(), + updated_at=None, + invalid_at=None, + ), + metadata=metadata, + ) + + def _extract_memory_units_by_action_windows( + self, + messages: List[Message], + context: MemoryContext, + ) -> List[MemoryUnit]: + ordered_memory_units: List[MemoryUnit] = [] + window_msgs: List[Message] = [] + window_first_action: Optional[Message] = None + + task = self._extract_task_from_messages(messages) + + for msg in tqdm( + messages, desc=f"Extracting memory units for {context.run_id} in {context.source}" + ): + if self._is_action_message(msg): + if window_msgs and window_first_action is not None: + mu = self._create_memory_unit( + context, + task, + window_msgs, + ordered_memory_units, + ) + if mu: + ordered_memory_units.append(mu) + + window_msgs = [msg] + window_first_action = msg + else: + if window_msgs: + window_msgs.append(msg) + else: + continue + + if window_msgs and window_first_action is not None: + mu = self._create_memory_unit( + context, + task, + window_msgs, + ordered_memory_units, + ) + if mu: + ordered_memory_units.append(mu) + + if ( + self.memory_store is not None and ordered_memory_units + ): # TODO: use PostgreSqlMemoryStore to store memory units. + self.memory_store.upsert(ordered_memory_units) + + return ordered_memory_units + + def _is_action_message(self, message: Message) -> bool: + """ + Determine if a message represents the start of an action. 
+ + Args: + message: Message to analyze + + Returns: + True if message starts an action, False otherwise + """ + if getattr(message, "role", "") != "assistant": + return False + system_prompt = ACTION_JUDGE_PROMPT + user_prompt = ( + "Decide if the following assistant message represents a concrete ACTION.\n" + "Use ONLY the content between and as input.\n" + "Output exactly one token: True or False.\n\n" + "\n" + message.model_dump_json() + "\n" + ) + raw = self.llm_service.model.invoke( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + ) + return self._normalize_llm_bool_output(raw) + + def _normalize_llm_bool_output(self, raw: str) -> bool: + """ + Normalize raw LLM output into a boolean. + Accepts variations like True/False, yes/no, 1/0, + possibly wrapped in quotes, code fences, or punctuation. + """ + text = raw.strip().lower() + + # remove common wrappers + text = text.strip(" \t\n\r\"'`") + if text.startswith("```"): + text = text.strip("`").strip() + + # direct boolean keywords + if text in {"true", "yes", "y", "1"}: + return True + if text in {"false", "no", "n", "0"}: + return False + + # fallback: fuzzy match + if text.startswith("true"): + return True + if text.startswith("false"): + return False + + return False + + def _create_memory_unit( + self, + context: MemoryContext, + task: Task, + window_msgs: List[Message], + prior_units: List[MemoryUnit], + ) -> MemoryUnit: + """ + Extract a single memory unit from a window of messages. 
+ - Synthesize state.done from prior actions + - Synthesize state.todo by extracting intents from window_msgs + - Extract action from first_action_msg + - Extract result from the last message in window_msgs + """ + action = self._extract_action_from_message(window_msgs[0]) + state_done = self._synthesize_state_done_from_context(prior_units, task) + state_todo = self._synthesize_state_todo_from_window(window_msgs, task, state_done, action) + result = self._extract_result_from_window(window_msgs, action) + return MemoryUnit( + context=context, + task=task, + state=State( + done=state_done, + todo=state_todo, + ), + action=action, + result=result, + ) + + def _extract_task_from_messages(self, messages: List[Message]) -> Task: + """ + Extract task information from trajectory messages. + + This method identifies the first user message and extracts: + - Issue title and description + - Issue type classification + - Repository information + - Related comments or context + + Args: + messages: List of messages in the trajectory + + Returns: + Task object with extracted information + """ + # 1) find first user message + first_user_msg = next((m for m in messages if getattr(m, "role", "") == "user"), None) + if first_user_msg is None: + return Task( + issue_title="", + issue_body="", + issue_comments="", + issue_type="", + repository="", + ) + + # 2) build a schema-constrained prompt for ONLY Task fields + system_prompt = TASK_EXTRACTION_PROMPT + user_prompt = ( + "Extract the `task` object from the SINGLE user message below.\n" + "Use only the text between and .\n\n" + "\n" + first_user_msg.model_dump_json() + "\n" + ) + + # 3) call LLM service + raw = self.llm_service.model.invoke( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + ) + + # 4) robust JSON parse, with a single repair attempt if needed + try: + payload = json.loads(raw) + except json.JSONDecodeError: + fixed = self.llm_service.model.invoke( + [ + {"role": 
"system", "content": "Fix to valid JSON object. No comments."}, + {"role": "user", "content": raw}, + ] + ) + payload = json.loads(fixed) + + # 5) normalize fields to strings (no extra heuristics) + task_obj = payload.get("task", {}) + + def s(x): + return x if isinstance(x, str) else ("" if x is None else str(x)) + + repo = s(task_obj.get("repository")) + if not repo: + repo = "" + + return Task( + issue_title=s(task_obj.get("issue_title")), + issue_body=s(task_obj.get("issue_body")), + issue_comments=s(task_obj.get("issue_comments")), + issue_type=s(task_obj.get("issue_type")), + repository=repo, + ) + + def _synthesize_state_from_context( + self, + prior_units: List[MemoryUnit], + task: Task, + window_msgs: List[Message], + current_action: Optional[Action] = None, + ) -> State: + """ + Synthesize state information from context. + """ + state_done = self._synthesize_state_done_from_context(prior_units, task) + state_todo = self._synthesize_state_todo_from_window( + window_msgs, task, state_done, current_action + ) + return State(done=state_done, todo=state_todo) + + def _synthesize_state_done_from_context(self, prior_units: List[MemoryUnit], task: Task) -> str: + """ + Summarize previous context into a concise `state.done` string. + Summary of what has ALREADY BEEN COMPLETED (no plans). 
+ """ + if not prior_units: + return "(none)" + + # Keep the window modest to control prompt size + window = prior_units[-1:] # state typically depends only on the last state in an agent run + + # Pack minimal, evidence-bound context for the LLM (no heavy heuristics) + history = [] + for u in window: + history.append( + { + "state": { + "done": u.state.done, + "todo": u.state.todo, + }, + "action": { + "name": u.action.name, + "description": u.action.description, + "target": u.action.target, + "tool": u.action.tool, + }, + "result": { + "type": u.result.type, + "description": u.result.description, + "exit_code": u.result.exit_code, + }, + } + ) + + system_prompt = STATE_DONE_SUMMARY_PROMPT + user_prompt = ( + "Summarize ONLY what has ALREADY BEEN COMPLETED into `state.done`.\n" + "TASK (for deriving only):\n" + "\n" + task.model_dump_json() + "\n\n\n" + "PRIOR UNITS (used ONLY for completed work evidence):\n" + "\n" + json.dumps(history, ensure_ascii=False) + "\n\n\n" + "Return the final summary paragraph ONLY (no explanations)." + ) + + try: + summary = self.llm_service.model.invoke( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + ) + # Be lenient if the model wraps the text in quotes/newlines + return summary.strip().strip('"').strip() + except Exception: + # Minimal, defensive fallback (kept short); avoids complex heuristics + return window[-1].state.done + + def _synthesize_state_todo_from_window( + self, + window_msgs: List[Message], + task: Task, + prior_done_text: str = "", + current_action: Optional[Action] = None, + ) -> str: + """ + Synthesize state.todo from the window of messages. 
+ """ + system_prompt = STATE_TODO_SYNTHESIS_PROMPT + user_prompt = ( + "TASK (overall purpose):\n" + "\n" + task.model_dump_json() + "\n\n\n" + "PRIOR_DONE (what has already been completed; avoid repeating it):\n" + "\n" + (prior_done_text or "") + "\n\n\n" + "WINDOW_MSGS (current window to derive the immediate intent of the next step; do NOT output tools/commands/paths):\n" + "\n" + + json.dumps([msg.model_dump() for msg in window_msgs], ensure_ascii=False) + + "\n\n\n" + "CURRENT_ACTION (use ONLY to ensure the intent refers to the same component/area; do NOT include or paraphrase any of its details):\n" + "\n" + current_action.model_dump_json() + "\n\n\n" + "Return ONLY the final paragraph." + ) + + try: + text = self.llm_service.model.invoke( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + ) + return text.strip().strip('"').replace("\n", " ").strip() + except Exception: + # Minimal deterministic fallback using available CURRENT_ACTION fields (no speculation) + parts = [] + task_short = ( + task.issue_title or task.repository or task.issue_type or "the task" + ).strip() + parts.append(f"To progress on {task_short},") + # Build an imperative clause from action fields + verb = (current_action.name or "execute action").replace("_", " ") + tgt = f" {current_action.target}" if current_action.target else "" + via = f" via {current_action.tool}" if current_action.tool else "" + desc = f" to {current_action.description}" if current_action.description else "" + parts.append(f" {verb}{tgt}{via}{desc}.") + return " ".join(parts).replace(" ", " ").strip() + + def _extract_action_from_message(self, message: Message) -> Action: + """ + Extract action information from an assistant message that starts an action call. + + This method analyzes assistant messages to identify: + - Action type (read_file, edit_file, run_test, etc.) 
+ - Target file or resource + - Tool being used + - Action description + + Args: + message: Assistant message containing action information + + Returns: + Action object with extracted information + """ + system_prompt = ACTION_EXTRACTION_PROMPT + user_prompt = ( + "Extract the `action` from the SINGLE assistant message below.\n" + "Use ONLY the content between and .\n\n" + "\n" + message.model_dump_json() + "\n" + ) + + raw = self.llm_service.model.invoke( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + ) + + # Robust JSON parse with one repair attempt + try: + payload = json.loads(raw) + except json.JSONDecodeError: + fixed = self.llm_service.model.invoke( + [ + {"role": "system", "content": "Return a valid JSON object only. No comments."}, + {"role": "user", "content": raw}, + ] + ) + payload = json.loads(fixed) + + action_obj = payload.get("action", {}) or {} + + def s(x): + return x if isinstance(x, str) else ("" if x is None else str(x)) + + return Action( + name=s(action_obj.get("name")), + description=s(action_obj.get("description")), + target=s(action_obj.get("target")), + tool=s(action_obj.get("tool")), + ) + + def _extract_result_from_window( + self, window_msgs: List[Message], current_action: Action + ) -> Result: + """ + Extract result information from message sequence. 
+ + This method analyzes the outcome of an action by examining: + - Success/failure indicators + - Error messages or exceptions + - Exit codes + - Output descriptions + + Args: + messages: Messages containing result information + action: The action that was executed + + Returns: + Result object with outcome information + """ + last_msg = window_msgs[-1] + + system_prompt = RESULT_EXTRACTION_PROMPT + user_prompt = ( + "CURRENT_ACTION (reference for alignment; do NOT invent beyond LAST_MESSAGE):\n" + "\n" + current_action.model_dump_json() + "\n\n\n" + "LAST_MESSAGE (extract the definitive outcome from here ONLY):\n" + "\n" + last_msg.model_dump_json() + "\n\n\n" + "Return ONLY the JSON object." + ) + + try: + raw = self.llm_service.model.invoke( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + ) + # Parse JSON; if failed, try to fix it to a valid JSON + try: + payload = json.loads(raw) + except json.JSONDecodeError: + fixed = self.llm_service.model.invoke( + [ + { + "role": "system", + "content": "Return a valid JSON object only. 
No comments.", + }, + {"role": "user", "content": raw}, + ] + ) + payload = json.loads(fixed) + + r = (payload or {}).get("result", {}) or {} + + # Utility: convert to string + def s(x): + return x if isinstance(x, str) else ("" if x is None else str(x)) + + # Result type correction + r_type = s(r.get("type")).lower() + if r_type not in {"success", "failure", "partial", "unknown"}: + r_type = "unknown" + + # Text: one sentence, remove wrapping quotes + description = s(r.get("description")).strip().strip('"') + if not description: + description = "(none)" + + # exit_code: allow returning number or number string + exit_raw = r.get("exit_code") + exit_code = None + if ( + isinstance(exit_raw, (int, float)) + and str(int(exit_raw)) == str(exit_raw).split(".")[0] + ): + exit_code = int(exit_raw) + elif isinstance(exit_raw, str) and re.fullmatch(r"-?\d+", exit_raw.strip()): + exit_code = int(exit_raw.strip()) + + return Result(type=r_type, description=description, exit_code=exit_code) + + except Exception: + # Minimal fallback: only based on the last message (no LLM call) + src = (getattr(last_msg, "content", "") or "").strip() + lt = src.lower() + + if re.search(r"(error|exception|traceback|failed)", lt): + rtype = "failure" + elif re.search(r"(exit code\s*0|succeeded|completed|all tests passed|passed\b)", lt): + rtype = "success" + elif re.search(r"(some tests failed|failures?:\s*[1-9]|errors?:\s*[1-9])", lt): + rtype = "partial" + else: + rtype = "unknown" + + # Extract the last non-empty line as one sentence result, and truncate to ~60 words + lines = [ln.strip() for ln in src.splitlines() if ln.strip()] + summary = lines[-1] if lines else "(none)" + words = summary.split() + if len(words) > 60: + summary = " ".join(words[:60]) + + m = re.search(r"exit code\s+(-?\d+)", src, flags=re.I) + code = int(m.group(1)) if m else None + + return Result(type=rtype, description=summary, exit_code=code) + + +class ExtractionError(Exception): + """Exception raised when memory 
# Prompt constants for memory extraction.
# NOTE(review): the angle-bracket placeholder tokens (e.g. <TASK_SHORT>,
# <MSG>, <function=...>) were stripped by a text-mangling step in the source;
# they are reconstructed here and should be confirmed against the original.

TASK_EXTRACTION_PROMPT = """
You are an information-extraction model that outputs ONLY the `task` object
from ONE user message.

Return STRICT JSON with EXACTLY this shape and NOTHING else:
{
"task": {
  "issue_title": string,
  "issue_body": string,
  "issue_comments": string,
  "issue_type": "bug" | "feature" | "documentation" | "question" | "other",
  "repository": string
}
}

Extraction rules:
1) Evidence-bound: Base ALL fields SOLELY on the given user message. Do NOT invent.
2) Unknown -> empty string "" (do not write null, N/A, or placeholders).
3) Plain text; preserve key technical details. If the message contains code or traceback,
keep them in `issue_body` using triple backticks. Do NOT discard reproduction steps.
4) Field guidance:
- issue_title: ≤ 30 words; concise headline of the main problem/request.
- issue_body: main description (symptom, reproduction, expected vs. actual, relevant code/traceback).
- issue_comments: secondary remarks, greetings, environment notes, or side comments not central to the body.
- issue_type (choose the FIRST matching category in this priority):
  a) "bug": mentions error/exception/failure, stack trace, "expected vs actual", or reproducible defect.
  b) "feature": requests adding/changing functionality without a failure symptom.
  c) "documentation": asks to fix/clarify docs/examples/tutorials only.
  d) "question": pure inquiry without requesting changes.
  e) otherwise "other".
- repository: extract explicit repo identifiers (e.g., "owner/repo") or a clearly named local repo
  referenced as the main target. If none is clear, use "".
5) Output policy:
- EXACTLY one JSON object; no markdown fences, no comments, no extra keys.
- Valid JSON only: double quotes, escaped newlines if any, no trailing commas.
"""


STATE_DONE_SUMMARY_PROMPT = """
You summarize PRIOR memory units into a concise `state.done` text that ties completed work to the overall task.

OUTPUT FORMAT (STRICT):
- Return EXACTLY one plain-text paragraph (no bullets, no JSON, no code fences).
- 5–7 sentences, ≤300 words total.
- No line breaks; no leading/trailing quotes; no extra commentary.
- The FIRST sentence MUST begin with: To resolve <TASK_SHORT>, we have ...
- If nothing is verifiably completed, return exactly: (none)

TASK_SHORT (how to derive):
- Derive a concise (≤15 words) clause from the TASK object capturing the main goal.
- Prefer `issue_title`; add minimal specificity from other fields (e.g., key API/file/repo) if helpful.
- Do NOT invent entities; evidence-bound to TASK fields; plain text only (no quotes)

WHAT COUNTS AS COMPLETED (content scope):
- Describe completed subgoals and observable outcomes (e.g., files created/edited, tests run with pass/fail, errors reproduced/fixed, resources updated).
- ALLOWED: referencing abstract artifacts (module/class/filename level, or top-level paths), and mentioning very key tool usage/results in generic terms (e.g., “validated with pytest”, “mypy reported no errors”, “git commit applied”).
- Keep references compact and evidence-bound.
- Use past tense and neutral tone.

WHAT TO EXCLUDE (forbidden detail level):
- Specific commands, CLI flags, test selectors, exact invocation strings, commit SHAs, line numbers, function signature changes, patch/diff hunks, stack traces beyond the error name/message.
- Plans, intentions, hypotheses, open questions, TODO/next steps, speculative language (e.g., “will”, “plan to”, “should”, “next”).
- Duplicated or conflicting earlier statements (prefer the most recent unit for the same aspect).
- Irrelevant greetings or meta commentary.

CONFLICT & SCOPE:
- Units are ordered oldest→newest; when facts conflict, prefer the newest.
- Be evidence-bound: do NOT invent information beyond the units provided.

QUALITY CHECK BEFORE ANSWERING:
- Contains only completed work and outcomes.
- Mentions important artifacts at module/class/filename level and—only if essential—high-level tool usage/results.
- No specific commands/flags/line numbers/function-signature details.
- Single paragraph, ≤300 words, no newlines, no quotes, no JSON.
- The FIRST sentence MUST begin with: 'To resolve <TASK_SHORT>, we have ...'
"""


STATE_TODO_SYNTHESIS_PROMPT = """
Compose `state.todo` as ONE high-level, tool-agnostic intent paragraph that guides the next step toward resolving the task.

OUTPUT (STRICT):
- EXACTLY one plain-text paragraph (no bullets, no JSON, no code fences).
- 2-3 sentences, ≤150 words total.
- No line breaks; no leading/trailing quotes; no extra commentary.

SCOPE & BOUNDARY
- Evidence sources: TASK (overall goal), PRIOR_DONE (what is already completed), WINDOW_MSGS (context for inferring the intent of the next step).
- CURRENT_ACTION is ONLY for domain alignment (same component/area), NEVER for details; do NOT copy or paraphrase its name/tool/command/flags/paths/function names.

WHAT COUNTS AS INTENT (content scope)
- State ONE high-level objective that moves the task forward (e.g., verify behavior, diagnose cause, validate fix, run regression checks, broaden coverage, document behavior, gather evidence).
- You MAY mention the relevant component/feature/module/class/filename at an abstract level (e.g., “ECS label merge logic”, “mypy type checking”).
- Keep phrasing outcome-oriented and generalizable; avoid procedural steps.

WHAT TO EXCLUDE (forbidden detail level)
- Any tools, commands, flags, exact invocations, API endpoints, function names, precise paths/parameters, line numbers, stack traces, test selectors, commit SHAs, patch/diff hunks.
- Multi-step plans or sequences (“then/after that”); more than one objective.
- Hedging/speculation (“might/should/maybe”); meta commentary.
- Invented entities.

CONTENT ORDER:
1) Derive a ≤15-word TASK_SHORT from TASK (prefer issue_title) and begin with a high-level objective:
"To resolve <TASK_SHORT>, ..."
2) If PRIOR_DONE is non-empty, briefly acknowledge it (≤12 words) without repeating details:
"having completed <DONE_SHORT>, ..."
3) Then state the high-level objective/intent (WHAT, not HOW).
"we need to <OBJECTIVE>."

CONSISTENCY & STYLE:
- Domain must be consistent with CURRENT_ACTION’s area/component, but MUST NOT include any of its concrete details.
- If any conflicts exist in PRIOR_DONE, prefer the most recent completion.
- Neutral, task-oriented; imperative (“run…”, “open…”, “apply…”).

QUALITY CHECK BEFORE ANSWERING
- Single paragraph; 2–3 sentences; ≤150 words.
- Begin with the selected template; no newlines/quotes/JSON.
- Tool/action-agnostic; no concrete commands/flags/paths/function names.
- Domain-aligned with CURRENT_ACTION; evidence-bound to inputs.
"""


ACTION_JUDGE_PROMPT = """
You are a STRICT binary classifier for assistant messages. Decide whether the message initiates or executes a concrete ACTION in an agent environment.

## Decision Target
Return **True** if the message triggers, schedules, or immediately proceeds to perform an operation — even when the operation is written in natural language or embedded in a code/markdown block.

### ACTION indicators (non-exhaustive; if ambiguous, prefer **True**)
- File/FS ops: create/open/read/write/edit/delete/rename/move/copy, save, apply patch, replace/insert, scroll_up/scroll_down, navigate.
- Search/explore ops: find_file/grep/search/locate/list (files/dirs/symbols), project-wide search.
- Execution/invocation: run/execute/launch/call/invoke a tool or command (pytest/python/bash/node/npm/pip/curl/sql/git), function_call objects, start/stop processes, HTTP requests.
- Editor/tool calls: str_replace_editor.*, execute_bash, tool:{...}, Action: <name>, Action Input: {...}, OpenAI function_call objects, XML/JSON tool blocks.
- Repair/fix/change steps: "Let's fix...", "Edit the constructor/method...", "Implement changes", "Apply the patch".
- Committed immediate steps in first person or imperative: "Now, let's ...", "First, I'll create ...", "Next, run ...", "Proceed to ...".
- Fenced code/markdown blocks that contain actionable commands (even if preceded or followed by explanations), e.g.:
  - ```find_file "dense_ndim_array.py"```
  - ```scroll_down```
  - ```create reproduce_bug.py```
  - ```pytest -q```
  - ```python reproduce_error.py```

### NOT actions (return **False**) — only if clearly none of the above:
- Pure analysis, planning, or summarization with no concrete operation.
- Pure thought process or reasoning steps without executing/triggering any operation.
- Questions, confirmations, or instructions to a human with no tool/command to be executed.
- Code shown purely for illustration (no intent to execute/apply/edit), e.g., "Here is an example snippet:" followed by code not framed as an immediate step.

### Ambiguity policy
If the case is ambiguous, **prefer True** (maximize recall for action messages).

## Output
Output EXACTLY one token on a single line: True or False
No punctuation, quotes, or explanations.
"""


ACTION_EXTRACTION_PROMPT = """
Extract ONLY the `action` object from ONE assistant message that starts a tool/action call.

OUTPUT (STRICT):
{
"action": {
  "name": string,
  "description": string,
  "target": string,
  "tool": string
}
}
- Exactly one JSON object, no extra keys, no markdown/code fences, valid JSON (double quotes, no trailing commas).
- All fields are strings. If unknown, use "" (never null).

EVIDENCE SCOPE:
- Use ONLY the text between <MSG> and </MSG>. Be evidence-bound; do not invent.

PRIMARY PATTERNS (use the first that matches):
1) XML-like function call:
- Example: <function=str_replace_editor><parameter=path>/a/b.py</parameter>...</function>
- name := the identifier after "function="
- tool := infer concise tool from name or parameters (e.g., "str_replace_editor"→"editor"; "execute_bash"→"bash").
  If unclear, use "".
- target := the primary operand (prefer parameters named path/file/target/command). Preserve the exact string.
- description := ≤30 words summarizing the intent (e.g., "view file", "replace string in file", "run command").

2) JSON-like/arg-style tool invocation (e.g., {"tool":"bash","command":"..."}):
- name := the operation verb-noun if present (or reuse tool name if no better choice).
- tool := explicit tool field if present, else infer from command/tool tokens; else "".
- target := primary file/path/command; keep flags/args verbatim.
- description := ≤30 words summarizing the intent.

3) Plain imperative text (no explicit tags), e.g., "Run pytest -k t":
- name := an action verb-noun (e.g., "run_test", "open_file", "edit_file") derived from the instruction.
- tool := infer from tokens ("pytest","git","bash","editor","curl", etc.); else "".
- target := the concrete operand (path/command/resource). Preserve literal text.
- description := ≤30 words.

MULTIPLE ITEMS:
- If multiple candidate targets exist, pick the most central as `target` and include the rest briefly in `description` (e.g., "also: X, Y").
- If the message contains several distinct action calls, extract the FIRST action only.

PRESERVATION & PRECISION:
- Preserve exact paths/commands/API names in `target` (no normalization).
- Prefer parameter values over paraphrases.
- Keep `description` concise and factual; no plans/speculation.

FINAL CHECKS:
- One JSON object only; all fields strings; unknown→""; ≤30-word description; evidence-bound.
"""
MULTIPLE ITEMS:
- If multiple candidate targets exist, pick the most central as `target` and include the rest briefly in `description` (e.g., "also: X, Y").
- If the message contains several distinct action calls, extract the FIRST action only.

PRESERVATION & PRECISION:
- Preserve exact paths/commands/API names in `target` (no normalization).
- Prefer parameter values over paraphrases.
- Keep `description` concise and factual; no plans/speculation.

FINAL CHECKS:
- One JSON object only; all fields strings; unknown→""; ≤30-word description; evidence-bound.
"""


RESULT_EXTRACTION_PROMPT = """
You are an execution-result extractor.
Extract the definitive outcome of executing the CURRENT_ACTION from the LAST_MESSAGE only.

OUTPUT (STRICT JSON)
{
"result": {
  "type": "success" | "failure" | "partial" | "unknown",
  "description": string, // EXACTLY one sentence, ≤60 words, plain text
  "exit_code": string // optional; keep digits as string if present; else ""
}
}
- Exactly one JSON object; no extra keys; valid JSON; double quotes; no trailing commas.
- The "description" MUST describe the outcome of CURRENT_ACTION (not other actions/logs).
- If <LAST_MESSAGE> has no verifiable outcome about CURRENT_ACTION, return:
{"result":{"type":"unknown","description":"(none)","exit_code":""}}

ALIGNMENT & EVIDENCE
- Evidence source: ONLY the text in <LAST_MESSAGE>.
- Use <CURRENT_ACTION> ONLY to align/filter which outcome to report (match by tool/name/target/command).
- Do NOT invent details absent from <LAST_MESSAGE>; preserve literals (paths, commands, API names, ARNs, exit codes).

DECISION RULES (priority)
1) Error/exception lines (e.g., “Error”, “Exception”, “Traceback”, “failed”) → type="failure".
2) Explicit success lines (“succeeded”, “completed”, tests passed, “exit code 0”) → type="success".
3) Mixed outcomes (some pass and some fail; partial updates) → type="partial".
4) Otherwise → type="unknown".
DECISION PROCESS (follow in order)
1) Locate lines in LAST_MESSAGE that explicitly mention CURRENT_ACTION’s tool/name/target/command.
2) If multiple candidates, select the most recent summary/conclusive line related to CURRENT_ACTION.
3) If none mention the action explicitly, use the final conclusive line (e.g., overall success/failure/exit code) that still plausibly refers to the same execution context.
4) If still no verifiable outcome → output (none).

WHAT TO CAPTURE (priority)
- Success/failure/exception (error type/message).
- Test/command outcomes (pass/fail counts; “finished”, “succeeded”, “failed”).
- Side effects (files/resources created/edited/updated; IDs/ARNs).
- Exit status (e.g., “exit code 0/1”).
- Keep literals verbatim; avoid paraphrasing targets/commands.

STYLE & TONE
- Neutral, past tense, concise and factual.
- Prefer mentioning the primary target/command if present.
- "description" must be ONE sentence, ≤60 words, neutral, past tense; semicolons allowed.

SENTENCE TEMPLATE (guideline, not to be printed verbatim)
- "<ACTION> <OUTCOME>; (exit code X)"
or "Execution <STATUS> for <ACTION> on <TARGET> with <DETAILS>".

EDGE CASES
- If the message is mostly wrapper noise (prompts, timestamps), but includes an exit code or clear success/failure line, report that.
- If logs are truncated or clipped and no conclusive line exists → (none).
- Do NOT restate next steps or advice; only the observed outcome.
"""
diff --git a/init_memory_base/trajectory_memory_extractor_deepseek.py b/init_memory_base/trajectory_memory_extractor_deepseek.py
index 5f084dd..6bdcf12 100644
--- a/init_memory_base/trajectory_memory_extractor_deepseek.py
+++ b/init_memory_base/trajectory_memory_extractor_deepseek.py
@@ -232,206 +232,68 @@ def call_deepseek(
 # =====================================
-def extract_result_from_last_message(
-    window_msgs: List[TrajectoryMessage], current_action: Action
-) -> Result:
-    """
-    Use DeepSeek to extract the `result` from the LAST message of a window.
- - Evidence-bound to the last message ONLY. - - Output must be ONE plain-text sentence capturing the definitive outcome. - - Minimal heuristics; a tiny fallback is used only if the API fails. - - Assumes available: - - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str - - json imported - """ - if not window_msgs: - return Result(type="unknown", description="(none)") - last_msg = window_msgs[-1] - - system_prompt = RESULT_EXTRACTION_PROMPT - user_prompt = ( - "CURRENT_ACTION (reference for alignment; do NOT invent beyond LAST_MESSAGE):\n" - "\n" + current_action.model_dump_json() + "\n\n\n" - "LAST_MESSAGE (extract the definitive outcome from here ONLY):\n" - "\n" + last_msg.model_dump_json() + "\n\n\n" - "Return ONLY the JSON object." - ) - - try: - raw = call_deepseek( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - temperature=0.0, - ) - # Parse JSON; if failed, try to fix it to a valid JSON - try: - payload = json.loads(raw) - except json.JSONDecodeError: - fixed = call_deepseek( - [ - {"role": "system", "content": "Return a valid JSON object only. 
No comments."}, - {"role": "user", "content": raw}, - ], - temperature=0.0, - ) - payload = json.loads(fixed) - - r = (payload or {}).get("result", {}) or {} - - # Utility: convert to string - def s(x): - return x if isinstance(x, str) else ("" if x is None else str(x)) - - # Result type correction - r_type = s(r.get("type")).lower() - if r_type not in {"success", "failure", "partial", "unknown"}: - r_type = "unknown" - - # Text: one sentence, remove wrapping quotes - description = s(r.get("description")).strip().strip('"') - if not description: - description = "(none)" - - # exit_code: allow returning number or number string - exit_raw = r.get("exit_code") - exit_code = None - if isinstance(exit_raw, (int, float)) and str(int(exit_raw)) == str(exit_raw).split(".")[0]: - exit_code = int(exit_raw) - elif isinstance(exit_raw, str) and re.fullmatch(r"-?\d+", exit_raw.strip()): - exit_code = int(exit_raw.strip()) - - return Result(type=r_type, description=description, exit_code=exit_code) - - except Exception: - # Minimal fallback: only based on the last message (no LLM call) - src = (getattr(last_msg, "content", "") or "").strip() - lt = src.lower() - - if re.search(r"(error|exception|traceback|failed)", lt): - rtype = "failure" - elif re.search(r"(exit code\s*0|succeeded|completed|all tests passed|passed\b)", lt): - rtype = "success" - elif re.search(r"(some tests failed|failures?:\s*[1-9]|errors?:\s*[1-9])", lt): - rtype = "partial" - else: - rtype = "unknown" - - # Extract the last non-empty line as one sentence result, and truncate to ~60 words - lines = [ln.strip() for ln in src.splitlines() if ln.strip()] - summary = lines[-1] if lines else "(none)" - words = summary.split() - if len(words) > 60: - summary = " ".join(words[:60]) - - m = re.search(r"exit code\s+(-?\d+)", src, flags=re.I) - code = int(m.group(1)) if m else None - - return Result(type=rtype, description=summary, exit_code=code) +def row_to_trajectory_messages(row): + traj = row["messages"] + msgs = 
[] + for m in traj: + content = m.get("content", "") if isinstance(m, dict) else str(m) + role = m.get("role", "user") if isinstance(m, dict) else "user" + msgs.append(TrajectoryMessage(content=content, role=role)) + return msgs -def synthesize_state_todo_paragraph( - window_msgs: List[TrajectoryMessage], +def extract_memory_units_by_action_windows( + meta: TrajectoryMeta, task: Task, - prior_done_text: str = "", - current_action: Optional[Action] = None, -) -> str: - """ - Call DeepSeek to synthesize `state.todo` as ONE highly condensed paragraph - capturing the intent of the CURRENT message window. + trajectory: List[TrajectoryMessage], + *, + memory_base: Optional[MemoryBaseProtocol] = None, +) -> List[MemoryUnit]: + ordered_memory_units: List[MemoryUnit] = [] + window_msgs: List[TrajectoryMessage] = [] + window_first_action: Optional[TrajectoryMessage] = None - Dependencies assumed available: - - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str - - json imported - """ + for msg in tqdm(trajectory, desc="Extracting memory units"): + if is_action_start(msg): + if window_msgs and window_first_action is not None: + mu = finalize_window_with_llm( + window_msgs, + window_first_action, + prior_units=ordered_memory_units, + meta=meta, + task=task, + ) + if mu: + ordered_memory_units.append(mu) - system_prompt = STATE_TODO_SYNTHESIS_PROMPT - user_prompt = ( - "TASK (overall purpose):\n" - "\n" + task.model_dump_json() + "\n\n\n" - "PRIOR_DONE (what has already been completed; avoid repeating it):\n" - "\n" + (prior_done_text or "") + "\n\n\n" - "WINDOW_MSGS (current window to derive the immediate intent of the next step; do NOT output tools/commands/paths):\n" - "\n" - + json.dumps([msg.model_dump() for msg in window_msgs], ensure_ascii=False) - + "\n\n\n" - "CURRENT_ACTION (use ONLY to ensure the intent refers to the same component/area; do NOT include or paraphrase any of its details):\n" - "\n" + current_action.model_dump_json() + "\n\n\n" - 
"Return ONLY the final paragraph." - ) + window_msgs = [msg] + window_first_action = msg + else: + if window_msgs: + window_msgs.append(msg) + else: + continue - try: - text = call_deepseek( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - temperature=0.0, + if window_msgs and window_first_action is not None: + mu = finalize_window_with_llm( + window_msgs, window_first_action, prior_units=ordered_memory_units, meta=meta, task=task ) - return text.strip().strip('"').replace("\n", " ").strip() - except Exception: - # Minimal deterministic fallback using available CURRENT_ACTION fields (no speculation) - parts = [] - task_short = (task.issue_title or task.repository or task.issue_type or "the task").strip() - parts.append(f"To progress on {task_short},") - # Build an imperative clause from action fields - verb = (current_action.name or "execute action").replace("_", " ") - tgt = f" {current_action.target}" if current_action.target else "" - via = f" via {current_action.tool}" if current_action.tool else "" - desc = f" to {current_action.description}" if current_action.description else "" - parts.append(f" {verb}{tgt}{via}{desc}.") - return " ".join(parts).replace(" ", " ").strip() - - -def extract_action_from_first_action_msg(first_action_msg: TrajectoryMessage) -> Action: - """ - Use DeepSeek to extract the `action` object from a SINGLE assistant message - that starts an action call. Minimal heuristics; evidence-bound to the message. 
- Dependencies assumed available: - - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str - - Action Pydantic model - - json imported - """ - system_prompt = ACTION_EXTRACTION_PROMPT - user_prompt = ( - "Extract the `action` from the SINGLE assistant message below.\n" - "Use ONLY the content between and .\n\n" - "\n" + first_action_msg.model_dump_json() + "\n" - ) - - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ] - - raw = call_deepseek(messages, temperature=0.0) + if mu: + ordered_memory_units.append(mu) - # Robust JSON parse with one repair attempt - try: - payload = json.loads(raw) - except json.JSONDecodeError: - fixed = call_deepseek( - [ - {"role": "system", "content": "Return a valid JSON object only. No comments."}, - {"role": "user", "content": raw}, - ], - temperature=0.0, - ) - payload = json.loads(fixed) + if memory_base is not None and ordered_memory_units: + memory_base.upsert(ordered_memory_units) - action_obj = payload.get("action", {}) or {} + return ordered_memory_units - def s(x): - return x if isinstance(x, str) else ("" if x is None else str(x)) - return Action( - name=s(action_obj.get("name")), - description=s(action_obj.get("description")), - target=s(action_obj.get("target")), - tool=s(action_obj.get("tool")), - ) +def is_action_start(msg: TrajectoryMessage) -> bool: + """Return True if this message starts an action call, else False.""" + if getattr(msg, "role", "") != "assistant": + return False + content = getattr(msg, "content", "") or "" + return bool(re.search(r"", content, flags=re.DOTALL)) + # TODO: content check, different trajectory style def finalize_window_with_llm( @@ -464,58 +326,79 @@ def finalize_window_with_llm( ) -def is_action_start(msg: TrajectoryMessage) -> bool: - """Return True if this message starts an action call, else False.""" - if getattr(msg, "role", "") != "assistant": - return False - content = getattr(msg, "content", "") or "" - 
return bool(re.search(r"", content, flags=re.DOTALL)) - # TODO: content check, different trajectory style +def extract_task_from_first_user_message( + trajectory: List[TrajectoryMessage], *, default_repository: str = "" +) -> Task: + """ + From a trajectory (list of TrajectoryMessage), take the first user message, + call DeepSeek to extract the Task fields, and return a Task. + Requirements: + - Uses the LLM to infer fields (minimal heuristics). + - Returns empty strings when unknown. + - If `default_repository` is provided and the model leaves repository empty, + fill it with `default_repository`. -def extract_memory_units_by_action_windows( - meta: TrajectoryMeta, - task: Task, - trajectory: List[TrajectoryMessage], - *, - memory_base: Optional[MemoryBaseProtocol] = None, -) -> List[MemoryUnit]: - ordered_memory_units: List[MemoryUnit] = [] - window_msgs: List[TrajectoryMessage] = [] - window_first_action: Optional[TrajectoryMessage] = None + Dependencies assumed available in your file: + - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str + - Task Pydantic model + - json imported + """ + # 1) find first user message + first_user_msg = next((m for m in trajectory if getattr(m, "role", "") == "user"), None) + if first_user_msg is None: + return Task( + issue_title="", + issue_body="", + issue_comments="", + issue_type="", + repository=default_repository, + ) - for msg in tqdm(trajectory, desc="Extracting memory units"): - if is_action_start(msg): - if window_msgs and window_first_action is not None: - mu = finalize_window_with_llm( - window_msgs, - window_first_action, - prior_units=ordered_memory_units, - meta=meta, - task=task, - ) - if mu: - ordered_memory_units.append(mu) + # 2) build a schema-constrained prompt for ONLY Task fields + system_prompt = TASK_EXTRACTION_PROMPT + user_prompt = ( + "Extract the `task` object from the SINGLE user message below.\n" + "Use only the text between and .\n\n" + "\n" + first_user_msg.model_dump_json() + 
"\n" + ) + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + # 3) call DeepSeek + raw = call_deepseek(messages, temperature=0.0) + + # 4) robust JSON parse, with a single repair attempt if needed + try: + payload = json.loads(raw) + except json.JSONDecodeError: + fix_messages = [ + {"role": "system", "content": "Fix to valid JSON object. No comments."}, + {"role": "user", "content": raw}, + ] + fixed = call_deepseek(fix_messages, temperature=0.0) + payload = json.loads(fixed) - window_msgs = [msg] - window_first_action = msg - else: - if window_msgs: - window_msgs.append(msg) - else: - continue + # 5) normalize fields to strings (no extra heuristics) + task_obj = payload.get("task", {}) - if window_msgs and window_first_action is not None: - mu = finalize_window_with_llm( - window_msgs, window_first_action, prior_units=ordered_memory_units, meta=meta, task=task - ) - if mu: - ordered_memory_units.append(mu) + def s(x): + return x if isinstance(x, str) else ("" if x is None else str(x)) - if memory_base is not None and ordered_memory_units: - memory_base.upsert(ordered_memory_units) + repo = s(task_obj.get("repository")) + if not repo and default_repository: + repo = default_repository # optional fallback; minimal deviation - return ordered_memory_units + return Task( + issue_title=s(task_obj.get("issue_title")), + issue_body=s(task_obj.get("issue_body")), + issue_comments=s(task_obj.get("issue_comments")), + issue_type=s(task_obj.get("issue_type")), + repository=repo, + ) def summarize_context_from_units(units: List[MemoryUnit], task: Task) -> str: @@ -583,41 +466,73 @@ def summarize_context_from_units(units: List[MemoryUnit], task: Task) -> str: return window[-1].state.done -def extract_task_from_first_user_message( - trajectory: List[TrajectoryMessage], *, default_repository: str = "" -) -> Task: +def synthesize_state_todo_paragraph( + window_msgs: List[TrajectoryMessage], + task: Task, + 
prior_done_text: str = "", + current_action: Optional[Action] = None, +) -> str: """ - From a trajectory (list of TrajectoryMessage), take the first user message, - call DeepSeek to extract the Task fields, and return a Task. - - Requirements: - - Uses the LLM to infer fields (minimal heuristics). - - Returns empty strings when unknown. - - If `default_repository` is provided and the model leaves repository empty, - fill it with `default_repository`. + Call DeepSeek to synthesize `state.todo` as ONE highly condensed paragraph + capturing the intent of the CURRENT message window. - Dependencies assumed available in your file: + Dependencies assumed available: - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str - - Task Pydantic model - json imported """ - # 1) find first user message - first_user_msg = next((m for m in trajectory if getattr(m, "role", "") == "user"), None) - if first_user_msg is None: - return Task( - issue_title="", - issue_body="", - issue_comments="", - issue_type="", - repository=default_repository, + + system_prompt = STATE_TODO_SYNTHESIS_PROMPT + user_prompt = ( + "TASK (overall purpose):\n" + "\n" + task.model_dump_json() + "\n\n\n" + "PRIOR_DONE (what has already been completed; avoid repeating it):\n" + "\n" + (prior_done_text or "") + "\n\n\n" + "WINDOW_MSGS (current window to derive the immediate intent of the next step; do NOT output tools/commands/paths):\n" + "\n" + + json.dumps([msg.model_dump() for msg in window_msgs], ensure_ascii=False) + + "\n\n\n" + "CURRENT_ACTION (use ONLY to ensure the intent refers to the same component/area; do NOT include or paraphrase any of its details):\n" + "\n" + current_action.model_dump_json() + "\n\n\n" + "Return ONLY the final paragraph." 
+ ) + + try: + text = call_deepseek( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + temperature=0.0, ) + return text.strip().strip('"').replace("\n", " ").strip() + except Exception: + # Minimal deterministic fallback using available CURRENT_ACTION fields (no speculation) + parts = [] + task_short = (task.issue_title or task.repository or task.issue_type or "the task").strip() + parts.append(f"To progress on {task_short},") + # Build an imperative clause from action fields + verb = (current_action.name or "execute action").replace("_", " ") + tgt = f" {current_action.target}" if current_action.target else "" + via = f" via {current_action.tool}" if current_action.tool else "" + desc = f" to {current_action.description}" if current_action.description else "" + parts.append(f" {verb}{tgt}{via}{desc}.") + return " ".join(parts).replace(" ", " ").strip() - # 2) build a schema-constrained prompt for ONLY Task fields - system_prompt = TASK_EXTRACTION_PROMPT + +def extract_action_from_first_action_msg(first_action_msg: TrajectoryMessage) -> Action: + """ + Use DeepSeek to extract the `action` object from a SINGLE assistant message + that starts an action call. Minimal heuristics; evidence-bound to the message. 
+ Dependencies assumed available: + - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str + - Action Pydantic model + - json imported + """ + system_prompt = ACTION_EXTRACTION_PROMPT user_prompt = ( - "Extract the `task` object from the SINGLE user message below.\n" - "Use only the text between and .\n\n" - "\n" + first_user_msg.model_dump_json() + "\n" + "Extract the `action` from the SINGLE assistant message below.\n" + "Use ONLY the content between and .\n\n" + "\n" + first_action_msg.model_dump_json() + "\n" ) messages = [ @@ -625,47 +540,132 @@ def extract_task_from_first_user_message( {"role": "user", "content": user_prompt}, ] - # 3) call DeepSeek raw = call_deepseek(messages, temperature=0.0) - # 4) robust JSON parse, with a single repair attempt if needed + # Robust JSON parse with one repair attempt try: payload = json.loads(raw) except json.JSONDecodeError: - fix_messages = [ - {"role": "system", "content": "Fix to valid JSON object. No comments."}, - {"role": "user", "content": raw}, - ] - fixed = call_deepseek(fix_messages, temperature=0.0) + fixed = call_deepseek( + [ + {"role": "system", "content": "Return a valid JSON object only. 
No comments."}, + {"role": "user", "content": raw}, + ], + temperature=0.0, + ) payload = json.loads(fixed) - # 5) normalize fields to strings (no extra heuristics) - task_obj = payload.get("task", {}) + action_obj = payload.get("action", {}) or {} def s(x): return x if isinstance(x, str) else ("" if x is None else str(x)) - repo = s(task_obj.get("repository")) - if not repo and default_repository: - repo = default_repository # optional fallback; minimal deviation + return Action( + name=s(action_obj.get("name")), + description=s(action_obj.get("description")), + target=s(action_obj.get("target")), + tool=s(action_obj.get("tool")), + ) - return Task( - issue_title=s(task_obj.get("issue_title")), - issue_body=s(task_obj.get("issue_body")), - issue_comments=s(task_obj.get("issue_comments")), - issue_type=s(task_obj.get("issue_type")), - repository=repo, + +def extract_result_from_last_message( + window_msgs: List[TrajectoryMessage], current_action: Action +) -> Result: + """ + Use DeepSeek to extract the `result` from the LAST message of a window. + - Evidence-bound to the last message ONLY. + - Output must be ONE plain-text sentence capturing the definitive outcome. + - Minimal heuristics; a tiny fallback is used only if the API fails. + + Assumes available: + - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str + - json imported + """ + if not window_msgs: + return Result(type="unknown", description="(none)") + last_msg = window_msgs[-1] + + system_prompt = RESULT_EXTRACTION_PROMPT + user_prompt = ( + "CURRENT_ACTION (reference for alignment; do NOT invent beyond LAST_MESSAGE):\n" + "\n" + current_action.model_dump_json() + "\n\n\n" + "LAST_MESSAGE (extract the definitive outcome from here ONLY):\n" + "\n" + last_msg.model_dump_json() + "\n\n\n" + "Return ONLY the JSON object." 
) + try: + raw = call_deepseek( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + temperature=0.0, + ) + # Parse JSON; if failed, try to fix it to a valid JSON + try: + payload = json.loads(raw) + except json.JSONDecodeError: + fixed = call_deepseek( + [ + {"role": "system", "content": "Return a valid JSON object only. No comments."}, + {"role": "user", "content": raw}, + ], + temperature=0.0, + ) + payload = json.loads(fixed) + + r = (payload or {}).get("result", {}) or {} -def row_to_trajectory_messages(row): - traj = row["messages"] - msgs = [] - for m in traj: - content = m.get("content", "") if isinstance(m, dict) else str(m) - role = m.get("role", "user") if isinstance(m, dict) else "user" - msgs.append(TrajectoryMessage(content=content, role=role)) - return msgs + # Utility: convert to string + def s(x): + return x if isinstance(x, str) else ("" if x is None else str(x)) + + # Result type correction + r_type = s(r.get("type")).lower() + if r_type not in {"success", "failure", "partial", "unknown"}: + r_type = "unknown" + + # Text: one sentence, remove wrapping quotes + description = s(r.get("description")).strip().strip('"') + if not description: + description = "(none)" + + # exit_code: allow returning number or number string + exit_raw = r.get("exit_code") + exit_code = None + if isinstance(exit_raw, (int, float)) and str(int(exit_raw)) == str(exit_raw).split(".")[0]: + exit_code = int(exit_raw) + elif isinstance(exit_raw, str) and re.fullmatch(r"-?\d+", exit_raw.strip()): + exit_code = int(exit_raw.strip()) + + return Result(type=r_type, description=description, exit_code=exit_code) + + except Exception: + # Minimal fallback: only based on the last message (no LLM call) + src = (getattr(last_msg, "content", "") or "").strip() + lt = src.lower() + + if re.search(r"(error|exception|traceback|failed)", lt): + rtype = "failure" + elif re.search(r"(exit code\s*0|succeeded|completed|all tests 
passed|passed\b)", lt): + rtype = "success" + elif re.search(r"(some tests failed|failures?:\s*[1-9]|errors?:\s*[1-9])", lt): + rtype = "partial" + else: + rtype = "unknown" + + # Extract the last non-empty line as one sentence result, and truncate to ~60 words + lines = [ln.strip() for ln in src.splitlines() if ln.strip()] + summary = lines[-1] if lines else "(none)" + words = summary.split() + if len(words) > 60: + summary = " ".join(words[:60]) + + m = re.search(r"exit code\s+(-?\d+)", src, flags=re.I) + code = int(m.group(1)) if m else None + + return Result(type=rtype, description=summary, exit_code=code) def main(): From 8ac2b81866c4bd278c05b7f5dee57e0fe4ad05db Mon Sep 17 00:00:00 2001 From: Zhaoyang-Chu Date: Sat, 13 Sep 2025 21:13:37 +0800 Subject: [PATCH 2/9] feat: implement memory extraction --- athena/app/services/llm_service.py | 12 + .../app/services/memory_extraction_service.py | 480 +++++++++++------- athena/models/__init__.py | 4 +- athena/models/memory.py | 139 ++--- 4 files changed, 384 insertions(+), 251 deletions(-) diff --git a/athena/app/services/llm_service.py b/athena/app/services/llm_service.py index 6c28d5a..fe39fbd 100644 --- a/athena/app/services/llm_service.py +++ b/athena/app/services/llm_service.py @@ -3,6 +3,7 @@ from langchain_anthropic import ChatAnthropic from langchain_core.language_models.chat_models import BaseChatModel from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_google_vertexai import ChatVertexAI from athena.app.services.base_service import BaseService from athena.chat_models.custom_chat_openai import CustomChatOpenAI @@ -50,6 +51,17 @@ def get_model( max_tokens_to_sample=max_output_tokens, max_retries=3, ) + elif model_name.startswith("vertex:"): + # example: model_name="vertex:gemini-2.5-pro" + vertex_model = model_name.split("vertex:", 1)[1] + return ChatVertexAI( + model_name=vertex_model, + project="prometheus-code-agent", + location="us-central1", + temperature=temperature, + 
max_output_tokens=max_output_tokens, + max_retries=3, + ) elif "gemini" in model_name: return ChatGoogleGenerativeAI( model=model_name, diff --git a/athena/app/services/memory_extraction_service.py b/athena/app/services/memory_extraction_service.py index 90888bf..fdb3442 100644 --- a/athena/app/services/memory_extraction_service.py +++ b/athena/app/services/memory_extraction_service.py @@ -1,8 +1,9 @@ +import html import json import re import uuid from datetime import datetime, timezone -from typing import Any, Dict, List, Optional, Protocol +from typing import Any, Dict, List, Optional from datasets import load_dataset from tqdm import tqdm @@ -11,7 +12,7 @@ from athena.app.services.llm_service import LLMService from athena.models import ( Action, - MemoryContext, + MemorySource, MemoryTimestamp, MemoryUnit, Message, @@ -28,39 +29,16 @@ TASK_EXTRACTION_PROMPT, ) +# from athena.utils.logger_manager import get_logger -class TrajectoryDataSource(Protocol): - """Protocol for different trajectory data sources.""" - def load_trajectories(self) -> List[List[Message]]: - """Load trajectories from the data source.""" - ... - - def get_metadata(self, trajectory_id: str) -> Dict[str, Any]: - """Get metadata for a specific trajectory.""" - ... - - -class ExtractionStrategy(Protocol): - """Protocol for different extraction strategies.""" - - def extract_task(self, messages: List[Message]) -> Task: - """Extract task information from trajectory messages.""" - ... - - def extract_action(self, message: Message) -> Action: - """Extract action information from a single message.""" - ... +class ExtractionError(Exception): + """Exception raised when memory extraction fails.""" - def extract_result(self, messages: List[Message], action: Action) -> Result: - """Extract result information from message sequence.""" - ... 
- - def synthesize_state( - self, prior_units: List[MemoryUnit], current_messages: List[Message], task: Task - ) -> State: - """Synthesize state information from context.""" - ... + def __init__(self, source_name: str, run_id: str): + self.source_name = source_name + self.run_id = run_id + super().__init__(f"Failed to extract memory units from {run_id} in {source_name}") class MemoryExtractionService(BaseService): @@ -78,7 +56,6 @@ class MemoryExtractionService(BaseService): - Configurable extraction strategies (LLM-based, rule-based, hybrid) - Batch processing with progress tracking - Error handling and retry mechanisms - - Memory unit deduplication and validation - Integration with existing memory storage systems Architecture: @@ -90,41 +67,32 @@ class MemoryExtractionService(BaseService): 5. Storage -> Persist to memory system Usage: - service = MemoryExtractionService(llm_service, extraction_strategy) + service = MemoryExtractionService(llm_service) memory_units = service.extract_from_trajectories(trajectory_data) """ def __init__( self, llm_service: LLMService, - extraction_strategy: Optional[ExtractionStrategy] = None, batch_size: int = 100, max_retries: int = 3, - enable_deduplication: bool = True, ): """ Initialize the Memory Extraction Service. 
Args: llm_service: LLM service for AI-powered extraction - extraction_strategy: Strategy for extracting components (optional) batch_size: Number of trajectories to process in each batch max_retries: Maximum retry attempts for failed extractions - enable_deduplication: Whether to deduplicate memory units """ self.llm_service = llm_service - self.extraction_strategy = extraction_strategy self.batch_size = batch_size self.max_retries = max_retries - self.enable_deduplication = enable_deduplication self._extraction_cache: Dict[str, MemoryUnit] = {} + # self._logger = get_logger(__name__) def start(self): """Start the memory extraction service.""" - # Initialize extraction strategy if not provided - if self.extraction_strategy is None: - self.extraction_strategy = self._create_default_strategy() - # Initialize any required resources pass @@ -132,7 +100,7 @@ def close(self): """Close the memory extraction service and cleanup resources.""" self._extraction_cache.clear() - def extract_from_huggingface_trajectory_repository( + def extract_from_huggingface_trajectory_repository( # TODO: batch extraction self, repo_name: str, split: str ) -> List[MemoryUnit]: """ @@ -185,11 +153,13 @@ def _pick(d: Dict[str, Any], keys) -> Optional[Any]: } messages.append(Message(content=content, role=role, metadata=metadata)) - run_id = str(_pick(row, run_id_keys)) or f"{idx}" - instance_id = str(_pick(row, inst_id_keys)) - model = str(_pick(row, model_keys)) - resolved = bool(_pick(row, resolved_keys)) - context = self._create_memory_context( + run_id = _pick(row, run_id_keys) or f"{idx}" + instance_id = _pick(row, inst_id_keys) + model = _pick(row, model_keys) + resolved = _pick( + row, resolved_keys + ) # TODO: use resolved to filter memory units to successful/failed ones. 
+ memory_source = self._extract_memory_source( repo_name, run_id, metadata={ @@ -198,55 +168,25 @@ def _pick(d: Dict[str, Any], keys) -> Optional[Any]: **({"resolved": resolved} if resolved is not None else {}), }, ) - memory_units = self._extract_memory_units_by_action_windows(messages, context) + try: + memory_units = self._extract_memory_units_by_action_windows(messages, memory_source) + print(memory_units) + assert False + except ExtractionError: + # self._logger.error(f"Failed to extract memory units from {run_id} in {repo_name}: {e}") # TODO: add logger + assert False + continue + self._extraction_cache.update( - {mu.context.memory_id: mu for mu in memory_units} + {mu.memory_id: mu for mu in memory_units} ) # TODO: use PostgreSqlMemoryStore to store memory units. return list(self._extraction_cache.values()) - def extract_from_data_source(self, data_source: TrajectoryDataSource) -> List[MemoryUnit]: + def _extract_memory_source( + self, source: str, run_id: str, metadata: Optional[Dict[str, Any]] = None + ) -> MemorySource: """ - Extract memory units from a configured data source. - - Args: - data_source: Data source implementing TrajectoryDataSource protocol - - Returns: - List of all extracted memory units - """ - pass - - def batch_extract( - self, - trajectory_batches: List[List[List[Message]]], - progress_callback: Optional[callable] = None, - ) -> List[MemoryUnit]: - """ - Extract memory units from multiple trajectory batches with progress tracking. - - Args: - trajectory_batches: List of trajectory batches to process - progress_callback: Optional callback for progress updates - - Returns: - List of all extracted memory units - """ - pass - - def _create_default_strategy(self) -> ExtractionStrategy: - """ - Create a default extraction strategy. 
- - Returns: - Default extraction strategy implementation - """ - pass - - def _create_memory_context( - self, source: str, run_id: str, metadata: Optional[Dict[str, Any]] = {} - ) -> MemoryContext: - """ - Create memory context for a trajectory. + Extract memory source for a trajectory. Args: source: Data source identifier @@ -254,24 +194,18 @@ def _create_memory_context( metadata: Optional additional metadata Returns: - MemoryContext object + MemorySource object """ - return MemoryContext( - memory_id=str(uuid.uuid4()), - source=source, + return MemorySource( + source_name=source, run_id=run_id, - timestamp=MemoryTimestamp( - created_at=datetime.now(timezone.utc).isoformat(), - updated_at=None, - invalid_at=None, - ), - metadata=metadata, + metadata=metadata or {}, ) def _extract_memory_units_by_action_windows( self, messages: List[Message], - context: MemoryContext, + memory_source: MemorySource, ) -> List[MemoryUnit]: ordered_memory_units: List[MemoryUnit] = [] window_msgs: List[Message] = [] @@ -280,12 +214,13 @@ def _extract_memory_units_by_action_windows( task = self._extract_task_from_messages(messages) for msg in tqdm( - messages, desc=f"Extracting memory units for {context.run_id} in {context.source}" + messages, + desc=f"Extracting memory units for {memory_source.run_id} in {memory_source.source_name}", ): if self._is_action_message(msg): if window_msgs and window_first_action is not None: mu = self._create_memory_unit( - context, + memory_source, task, window_msgs, ordered_memory_units, @@ -303,7 +238,7 @@ def _extract_memory_units_by_action_windows( if window_msgs and window_first_action is not None: mu = self._create_memory_unit( - context, + memory_source, task, window_msgs, ordered_memory_units, @@ -342,39 +277,12 @@ def _is_action_message(self, message: Message) -> bool: {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] - ) + ).content return self._normalize_llm_bool_output(raw) - def 
_normalize_llm_bool_output(self, raw: str) -> bool: - """ - Normalize raw LLM output into a boolean. - Accepts variations like True/False, yes/no, 1/0, - possibly wrapped in quotes, code fences, or punctuation. - """ - text = raw.strip().lower() - - # remove common wrappers - text = text.strip(" \t\n\r\"'`") - if text.startswith("```"): - text = text.strip("`").strip() - - # direct boolean keywords - if text in {"true", "yes", "y", "1"}: - return True - if text in {"false", "no", "n", "0"}: - return False - - # fallback: fuzzy match - if text.startswith("true"): - return True - if text.startswith("false"): - return False - - return False - def _create_memory_unit( self, - context: MemoryContext, + source: MemorySource, task: Task, window_msgs: List[Message], prior_units: List[MemoryUnit], @@ -391,7 +299,13 @@ def _create_memory_unit( state_todo = self._synthesize_state_todo_from_window(window_msgs, task, state_done, action) result = self._extract_result_from_window(window_msgs, action) return MemoryUnit( - context=context, + memory_id=str(uuid.uuid4()), + timestamp=MemoryTimestamp( + created_at=datetime.now(timezone.utc), + updated_at=None, + invalid_at=None, + ), + source=source, task=task, state=State( done=state_done, @@ -442,19 +356,10 @@ def _extract_task_from_messages(self, messages: List[Message]) -> Task: {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] - ) + ).content # 4) robust JSON parse, with a single repair attempt if needed - try: - payload = json.loads(raw) - except json.JSONDecodeError: - fixed = self.llm_service.model.invoke( - [ - {"role": "system", "content": "Fix to valid JSON object. 
No comments."}, - {"role": "user", "content": raw}, - ] - ) - payload = json.loads(fixed) + payload = self._normalize_llm_json_output(raw) # 5) normalize fields to strings (no extra heuristics) task_obj = payload.get("task", {}) @@ -488,7 +393,7 @@ def _synthesize_state_from_context( state_todo = self._synthesize_state_todo_from_window( window_msgs, task, state_done, current_action ) - return State(done=state_done, todo=state_todo) + return State(done=state_done, todo=state_todo) # TODO: open_file, working_dir def _synthesize_state_done_from_context(self, prior_units: List[MemoryUnit], task: Task) -> str: """ @@ -540,7 +445,7 @@ def _synthesize_state_done_from_context(self, prior_units: List[MemoryUnit], tas {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] - ) + ).content # Be lenient if the model wraps the text in quotes/newlines return summary.strip().strip('"').strip() except Exception: @@ -578,7 +483,7 @@ def _synthesize_state_todo_from_window( {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] - ) + ).content return text.strip().strip('"').replace("\n", " ").strip() except Exception: # Minimal deterministic fallback using available CURRENT_ACTION fields (no speculation) @@ -623,19 +528,10 @@ def _extract_action_from_message(self, message: Message) -> Action: {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] - ) + ).content # Robust JSON parse with one repair attempt - try: - payload = json.loads(raw) - except json.JSONDecodeError: - fixed = self.llm_service.model.invoke( - [ - {"role": "system", "content": "Return a valid JSON object only. 
No comments."}, - {"role": "user", "content": raw}, - ] - ) - payload = json.loads(fixed) + payload = self._normalize_llm_json_output(raw) action_obj = payload.get("action", {}) or {} @@ -685,21 +581,8 @@ def _extract_result_from_window( {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] - ) - # Parse JSON; if failed, try to fix it to a valid JSON - try: - payload = json.loads(raw) - except json.JSONDecodeError: - fixed = self.llm_service.model.invoke( - [ - { - "role": "system", - "content": "Return a valid JSON object only. No comments.", - }, - {"role": "user", "content": raw}, - ] - ) - payload = json.loads(fixed) + ).content + payload = self._normalize_llm_json_output(raw) r = (payload or {}).get("result", {}) or {} @@ -756,10 +639,241 @@ def s(x): return Result(type=rtype, description=summary, exit_code=code) + def _normalize_llm_bool_output(self, raw: str) -> bool: + """ + Normalize raw LLM output into a boolean. + Accepts variations like True/False, yes/no, 1/0, + possibly wrapped in quotes, code fences, or punctuation. + """ + text = raw.strip().lower() + + # remove common wrappers + text = text.strip(" \t\n\r\"'`") + if text.startswith("```"): + text = text.strip("`").strip() -class ExtractionError(Exception): - """Exception raised when memory extraction fails.""" + # direct boolean keywords + if text in {"true", "yes", "y", "1"}: + return True + if text in {"false", "no", "n", "0"}: + return False + + # fallback: fuzzy match + if text.startswith("true"): + return True + if text.startswith("false"): + return False + + return False + + def _normalize_llm_json_output(self, raw: str) -> Dict[str, Any]: + """ + Normalize raw LLM output into a JSON object (dict-like). + Single-argument API. All robustness heuristics are internal. 
+ + Strategy (built-in, no extra args): + - Clean BOM/zero-width chars/HTML entities + - Strip markdown code fences (``` / ```json) + - Try direct json.loads + - If it is a quoted JSON string, unquote then parse + - Extract the first balanced JSON object/array (quote-aware) + - Apply limited trailing comma fixes + - Minify whitespace and retry + - Optionally do ONE LLM repair round if self.llm_service is available + - Always return a dict; non-dict top-level will be wrapped as {"_root": value} + """ + EXPECT_OBJECT = True # always return Dict[str, Any] + EXPECT_KEY: Optional[str] = None # no required top-level key + MAX_SCAN_CHARS = 200_000 # cap scanning for safety + ALLOW_REPAIR_WITH_LLM = True # try one LLM repair if available + + if not isinstance(raw, str): + raise ValueError("LLM output is not a string") + text = raw.strip() + if not text: + raise ValueError("Empty LLM output") + + # Hygiene + text = self._strip_bom_zwsp(text) + text = html.unescape(text) + text = self._strip_code_fences(text).strip() + + # 1) direct parse + obj = self._try_parse_direct(text, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY) + if obj is not None: + return obj + + # 2) quoted JSON -> unquote then parse + obj = self._try_unquote_then_parse(text, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY) + if obj is not None: + return obj + + # 3) extract first balanced JSON block (object/array), quote-aware + candidate = self._extract_first_json_balanced(text[:MAX_SCAN_CHARS]) + if candidate: + obj = self._try_parse_direct( + candidate, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY + ) + if obj is not None: + return obj + fixed = self._fix_trailing_commas(candidate) + obj = self._try_parse_direct(fixed, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY) + if obj is not None: + return obj + + # 4) minify whitespace and retry + compact = self._minify_ws(candidate or text) + obj = self._try_parse_direct(compact, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY) + if obj is not 
None: + return obj + + # 5) optional single LLM repair (if service available) + if ALLOW_REPAIR_WITH_LLM and getattr(self, "llm_service", None) is not None: + repaired = self._repair_json_with_llm(text, expect_key=EXPECT_KEY) + obj = self._try_parse_direct( + repaired, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY + ) + if obj is not None: + return obj + + snippet = (text[:800] + "...") if len(text) > 800 else text + raise ValueError(f"Failed to parse JSON from LLM output. Snippet: {snippet}") + + def _try_parse_direct( + self, s: str, *, expect_object: bool, expect_key: Optional[str] + ) -> Optional[Dict[str, Any]]: + """Try json.loads; coerce non-dict to {'_root': value}; enforce expect_key if given.""" + try: + val = json.loads(s) + except Exception: + return None + if expect_object and not isinstance(val, dict): + val = {"_root": val} + elif not isinstance(val, dict): + val = {"_root": val} + if expect_key and expect_key not in val: + return None + return val + + def _try_unquote_then_parse( + self, s: str, *, expect_object: bool, expect_key: Optional[str] + ) -> Optional[Dict[str, Any]]: + """If s is a JSON string containing JSON (e.g., '"{...}"'), unquote then parse.""" + try: + inner = json.loads(s) + except Exception: + return None + if isinstance(inner, str): + return self._try_parse_direct(inner, expect_object=expect_object, expect_key=expect_key) + if expect_object and not isinstance(inner, dict): + inner = {"_root": inner} + elif not isinstance(inner, dict): + inner = {"_root": inner} + if expect_key and expect_key not in inner: + return None + return inner + + def _strip_code_fences(self, s: str) -> str: + """Remove markdown code fences and leading 'json' language tag.""" + t = s.strip() + if t.startswith("```"): + t = re.sub(r"^\s*```(?:json)?\s*", "", t, flags=re.I) + t = re.sub(r"\s*```\s*$", "", t) + t = re.sub(r"```(?:json)?\s*([\s\S]*?)\s*```", r"\1", t, flags=re.I) + return t + + def _strip_bom_zwsp(self, s: str) -> str: + """Remove BOM and 
zero-width characters that can break parsers.""" + s = s.lstrip("\ufeff") + return re.sub(r"[\u200B-\u200D\u2060\uFEFF]", "", s) + + def _minify_ws(self, s: str) -> str: + """Collapse excessive newlines and surrounding whitespace (keeps spaces inside strings).""" + return re.sub(r"\s*\n\s*", " ", s).strip() + + def _fix_trailing_commas(self, s: str) -> str: + """Limited safe fix: ',}' -> '}' and ',]' -> ']'.""" + t = re.sub(r",\s*}", "}", s) + t = re.sub(r",\s*]", "]", t) + return t + + def _extract_first_json_balanced(self, text: str) -> str: + """Find the first balanced {...} or [...] block (quote/escape-aware).""" + for opener, closer in (("{", "}"), ("[", "]")): + block = self._scan_balanced(text, opener, closer) + if block: + return block + return "" + + def _scan_balanced(self, t: str, opener: str, closer: str) -> str: + """Quote/escape-aware balanced scanner to locate a JSON block.""" + start = t.find(opener) + if start == -1: + return "" + depth = 0 + in_str = False + esc = False + quote_char = "" + for i in range(start, len(t)): + ch = t[i] + if in_str: + if esc: + esc = False + elif ch == "\\": + esc = True + elif ch == quote_char: + in_str = False + continue + if ch in ('"', "'"): + in_str = True + quote_char = ch + esc = False + continue + if ch == opener: + depth += 1 + elif ch == closer: + depth -= 1 + if depth == 0: + return t[start : i + 1] + return "" + + def _repair_json_with_llm(self, raw: str, expect_key: Optional[str]) -> str: + """One-shot LLM repair to return MINIFIED valid JSON only (no prose/fences).""" + constraint = f" It MUST contain the top-level key '{expect_key}'." if expect_key else "" + try: + repaired = self.llm_service.model.invoke( + [ + { + "role": "system", + "content": ( + "You are a JSON normalizer. " + "Return ONLY a valid MINIFIED JSON object. " + "No explanations, no code fences, no comments." 
+ constraint + ), + }, + {"role": "user", "content": raw}, + ] + ) + if isinstance(repaired, str): + return repaired.strip() + content = getattr(repaired, "content", None) + if isinstance(content, str): + return content.strip() + return str(repaired).strip() + except Exception: + return raw - def __init__(self, message: str, trajectory_id: Optional[str] = None): - self.trajectory_id = trajectory_id - super().__init__(message) + +if __name__ == "__main__": + service = MemoryExtractionService( + llm_service=LLMService( + model_name="vertex:gemini-2.5-flash", + model_temperature=0.0, + model_max_input_tokens=8192, + model_max_output_tokens=8192, + ) + ) + service.extract_from_huggingface_trajectory_repository( + repo_name="SWE-Gym/OpenHands-SFT-Trajectories", + split="train.success.oss", + ) diff --git a/athena/models/__init__.py b/athena/models/__init__.py index 7d32594..07f5255 100644 --- a/athena/models/__init__.py +++ b/athena/models/__init__.py @@ -1,6 +1,6 @@ from .memory import ( Action, - MemoryContext, + MemorySource, MemoryTimestamp, MemoryUnit, MemoryUnitDB, @@ -13,7 +13,7 @@ __all__ = [ "Message", "MemoryUnit", - "MemoryContext", + "MemorySource", "MemoryTimestamp", "Task", "State", diff --git a/athena/models/memory.py b/athena/models/memory.py index 11fe613..2d7c8c4 100644 --- a/athena/models/memory.py +++ b/athena/models/memory.py @@ -1,38 +1,38 @@ import json import uuid -from datetime import datetime +from datetime import datetime, timezone from typing import Any, Dict, Literal, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel +from pydantic import Field as PydField from sqlalchemy import Column, DateTime, Text +from sqlmodel import Field as SQLField from sqlmodel import SQLModel, UniqueConstraint class MemoryTimestamp(BaseModel): """Lifecycle timestamps for a memory unit.""" - created_at: datetime = Field(None, description="When the memory unit was first created") - updated_at: Optional[datetime] = Field( + created_at: 
datetime = PydField( + default_factory=lambda: datetime.now(timezone.utc), + description="When the memory unit was first created", + ) + updated_at: Optional[datetime] = PydField( None, description="When the memory was last updated/refreshed" ) - invalid_at: Optional[datetime] = Field( + invalid_at: Optional[datetime] = PydField( None, description="When the memory was invalidated or expired" ) -class MemoryContext(BaseModel): - """Context metadata for a memory unit.""" +class MemorySource(BaseModel): + """Source information for a memory unit.""" - memory_id: str = Field( - default_factory=lambda: str(uuid.uuid4()), - description="Unique identifier for this memory unit", - ) - source: str = Field( + source_name: str = PydField( ..., description="Memory source, e.g., agent name, model name, dataset name, or file path" ) - run_id: str = Field(..., description="Unique id for agent run or trajectory dataset") - timestamp: MemoryTimestamp = Field(..., description="Timestamp of the memory") - metadata: Dict[str, Any] = Field( + run_id: str = PydField(..., description="Unique id for agent run or trajectory dataset") + metadata: Dict[str, Any] = PydField( default_factory=dict, description="Optional extension field for custom metadata" ) @@ -40,28 +40,28 @@ class MemoryContext(BaseModel): class Task(BaseModel): """Task information extracted from messages or issues.""" - issue_title: str = Field(..., description="Short title of the issue/task") - issue_body: str = Field(..., description="Detailed description of the issue/task") - issue_comments: str = Field(..., description="Relevant comments or discussions") - issue_type: str = Field( + issue_title: str = PydField(..., description="Short title of the issue/task") + issue_body: str = PydField(..., description="Detailed description of the issue/task") + issue_comments: str = PydField(..., description="Relevant comments or discussions") + issue_type: str = PydField( ..., description="Type of issue: bug | feature | documentation 
| question | other" ) - repository: str = Field(..., description="Repository where the issue/task belongs") + repository: str = PydField(..., description="Repository where the issue/task belongs") class State(BaseModel): """State information including completed and pending work.""" - done: str = Field(..., description="Summary of what has already been completed") - todo: str = Field(..., description="Summary of what remains to be done") - open_file: Optional[str] = Field( + done: str = PydField(..., description="Summary of what has already been completed") + todo: str = PydField(..., description="Summary of what remains to be done") + open_file: Optional[str] = PydField( None, description="Path of the currently opened or edited file, if any (e.g., /sympy__sympy/reproduce_bug.py)", ) - working_dir: Optional[str] = Field( + working_dir: Optional[str] = PydField( None, description="Path of the current working directory (e.g., /sympy__sympy)" ) - extra_environment: Dict[str, str] = Field( + extra_environment: Dict[str, str] = PydField( default_factory=dict, description="Other SWE-specific runtime state, e.g. 
active branch, virtualenv, configs, etc.", ) @@ -70,15 +70,15 @@ class State(BaseModel): class Action(BaseModel): """Action taken by the agent.""" - name: str = Field( + name: str = PydField( ..., description="Action type (read_file | edit_file | run_test | invoke_tool | etc.)" ) - description: str = Field(..., description="Detailed description of the action") - target: str = Field( + description: str = PydField(..., description="Detailed description of the action") + target: str = PydField( ..., description="Target of the action (e.g., utils/math.py, tests/test_math.py, pytest, git)", ) - tool: str = Field( + tool: str = PydField( ..., description="Tool used for the action (e.g., pytest, git, search, editor, bash)" ) @@ -89,11 +89,11 @@ class Action(BaseModel): class Result(BaseModel): """Execution result produced by an action.""" - type: ResultType = Field( + type: ResultType = PydField( "unknown", description="Execution outcome: success | failure | partial | unknown" ) - description: str = Field("", description="Summary of the result") - exit_code: Optional[int] = Field( + description: str = PydField("", description="Summary of the result") + exit_code: Optional[int] = PydField( None, description="Exit code if extractable from logs (optional)" ) @@ -103,14 +103,21 @@ class MemoryUnit(BaseModel): Core memory unit capturing one action of agent execution. 
This includes: - - The memory context (memory_id, source, run_id, timestamps, metadata) + - The memory id (memory_id) + - The memory timestamp (timestamp) + - The memory source (source_name, run_id, metadata) - The task being worked on (issue and repository details) - The current state (what's done, what's todo) - The action taken by the agent - The result of the action """ - context: MemoryContext + memory_id: str = PydField( + default_factory=lambda: str(uuid.uuid4()), + description="Unique identifier for this memory unit", + ) + timestamp: MemoryTimestamp = PydField(..., description="Timestamp of the memory") + source: MemorySource task: Task state: State action: Action @@ -122,64 +129,64 @@ class MemoryUnitDB(SQLModel, table=True): """Database model for persistent storage of memory units.""" __tablename__ = "memory_units" - __table_args__ = UniqueConstraint("memory_id", name="uq_memory_id") + __table_args__ = (UniqueConstraint("memory_id", name="uq_memory_id"),) - id: Optional[int] = Field(default=None, primary_key=True) + id: Optional[int] = SQLField(default=None, primary_key=True) - # Context information + # Source information memory_id: str - memory_source: str + memory_source_name: str memory_run_id: str - memory_created_at: datetime = Field(sa_column=Column(DateTime(timezone=True))) - memory_updated_at: Optional[datetime] = Field( + memory_created_at: datetime = SQLField(sa_column=Column(DateTime(timezone=True))) + memory_updated_at: Optional[datetime] = SQLField( default=None, sa_column=Column(DateTime(timezone=True)) ) - memory_invalid_at: Optional[datetime] = Field( + memory_invalid_at: Optional[datetime] = SQLField( default=None, sa_column=Column(DateTime(timezone=True)) ) - memory_metadata: str = Field( + memory_metadata: str = SQLField( default="{}", sa_column=Column(Text) ) # JSON string for Dict[str, Any] # Task information - task_issue_title: str = Field(sa_column=Column(Text)) - task_issue_body: str = Field(sa_column=Column(Text)) - 
task_issue_comments: str = Field(sa_column=Column(Text)) + task_issue_title: str = SQLField(sa_column=Column(Text)) + task_issue_body: str = SQLField(sa_column=Column(Text)) + task_issue_comments: str = SQLField(sa_column=Column(Text)) task_issue_type: str task_repository: str # State information - state_done: str = Field(sa_column=Column(Text)) - state_todo: str = Field(sa_column=Column(Text)) + state_done: str = SQLField(sa_column=Column(Text)) + state_todo: str = SQLField(sa_column=Column(Text)) state_open_file: Optional[str] = None state_working_dir: Optional[str] = None - state_extra_environment: str = Field( + state_extra_environment: str = SQLField( default="{}", sa_column=Column(Text) ) # JSON string for Dict[str, str] # Action information action_name: str - action_description: str = Field(sa_column=Column(Text)) + action_description: str = SQLField(sa_column=Column(Text)) action_target: str action_tool: str # Result information result_type: str - result_description: str = Field(sa_column=Column(Text)) + result_description: str = SQLField(sa_column=Column(Text)) result_exit_code: Optional[int] = None @classmethod def from_memory_unit(cls, memory_unit: MemoryUnit) -> "MemoryUnitDB": """Create a database model from a MemoryUnit.""" return cls( - memory_id=memory_unit.context.memory_id, - memory_source=memory_unit.context.source, - memory_run_id=memory_unit.context.run_id, - memory_created_at=memory_unit.context.timestamp.created_at, - memory_updated_at=memory_unit.context.timestamp.updated_at, - memory_invalid_at=memory_unit.context.timestamp.invalid_at, - memory_metadata=json.dumps(memory_unit.context.metadata) - if memory_unit.context.metadata + memory_id=memory_unit.memory_id, + memory_source_name=memory_unit.source.source_name, + memory_run_id=memory_unit.source.run_id, + memory_created_at=memory_unit.timestamp.created_at, + memory_updated_at=memory_unit.timestamp.updated_at, + memory_invalid_at=memory_unit.timestamp.invalid_at, + 
memory_metadata=json.dumps(memory_unit.source.metadata) + if memory_unit.source.metadata else "{}", task_issue_title=memory_unit.task.issue_title, task_issue_body=memory_unit.task.issue_body, @@ -205,15 +212,15 @@ def from_memory_unit(cls, memory_unit: MemoryUnit) -> "MemoryUnitDB": def to_memory_unit(self) -> MemoryUnit: """Convert database model back to MemoryUnit.""" return MemoryUnit( - context=MemoryContext( - memory_id=self.memory_id, - source=self.memory_source, + memory_id=self.memory_id, + timestamp=MemoryTimestamp( + created_at=self.memory_created_at, + updated_at=self.memory_updated_at, + invalid_at=self.memory_invalid_at, + ), + source=MemorySource( + source_name=self.memory_source_name, run_id=self.memory_run_id, - timestamp=MemoryTimestamp( - created_at=self.memory_created_at, - updated_at=self.memory_updated_at, - invalid_at=self.memory_invalid_at, - ), metadata=json.loads(self.memory_metadata) if self.memory_metadata not in (None, "", "null") else {}, From 2014fdf62aee1dd44ce28c383a14a0bbcf6b6b08 Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Sun, 14 Sep 2025 14:31:18 +0800 Subject: [PATCH 3/9] feat: add MemoryUnit and related models for memory extraction service --- athena/entity/__init__.py | 0 athena/entity/memory.py | 141 ++++++++++++++++++++++++++++ athena/models/memory.py | 190 ++++++-------------------------------- pyproject.toml | 1 + 4 files changed, 168 insertions(+), 164 deletions(-) create mode 100644 athena/entity/__init__.py create mode 100644 athena/entity/memory.py diff --git a/athena/entity/__init__.py b/athena/entity/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/athena/entity/memory.py b/athena/entity/memory.py new file mode 100644 index 0000000..7b9c87d --- /dev/null +++ b/athena/entity/memory.py @@ -0,0 +1,141 @@ +import json +from datetime import datetime +from typing import Optional + +from sqlalchemy import Column, DateTime, Text +from sqlmodel import Field, 
SQLModel, UniqueConstraint + +from athena.models import Action, MemorySource, MemoryTimestamp, MemoryUnit, Result, State, Task + + +# Database models for persistent storage +class MemoryUnitDB(SQLModel, table=True): + """Database model for persistent storage of memory units.""" + + __tablename__ = "memory_units" + __table_args__ = (UniqueConstraint("memory_id", name="uq_memory_id"),) + + id: Optional[int] = Field(default=None, primary_key=True) + + # Source information + memory_id: str + memory_source_name: str + memory_run_id: str + memory_created_at: datetime = Field(sa_column=Column(DateTime(timezone=True))) + memory_updated_at: Optional[datetime] = Field( + default=None, sa_column=Column(DateTime(timezone=True)) + ) + memory_invalid_at: Optional[datetime] = Field( + default=None, sa_column=Column(DateTime(timezone=True)) + ) + memory_metadata: str = Field( + default="{}", sa_column=Column(Text) + ) # JSON string for Dict[str, Any] + + # Task information + task_issue_title: str = Field(sa_column=Column(Text)) + task_issue_body: str = Field(sa_column=Column(Text)) + task_issue_comments: str = Field(sa_column=Column(Text)) + task_issue_type: str + task_repository: str + + # State information + state_done: str = Field(sa_column=Column(Text)) + state_todo: str = Field(sa_column=Column(Text)) + state_open_file: Optional[str] = None + state_working_dir: Optional[str] = None + state_extra_environment: str = Field( + default="{}", sa_column=Column(Text) + ) # JSON string for Dict[str, str] + + # Action information + action_name: str + action_description: str = Field(sa_column=Column(Text)) + action_target: str + action_tool: str + + # Result information + result_type: str + result_description: str = Field(sa_column=Column(Text)) + result_exit_code: Optional[int] = None + + @classmethod + def from_memory_unit(cls, memory_unit: MemoryUnit) -> "MemoryUnitDB": + """Create a database model from a MemoryUnit.""" + return cls( + memory_id=memory_unit.memory_id, + 
memory_source_name=memory_unit.source.source_name, + memory_run_id=memory_unit.source.run_id, + memory_created_at=memory_unit.timestamp.created_at, + memory_updated_at=memory_unit.timestamp.updated_at, + memory_invalid_at=memory_unit.timestamp.invalid_at, + memory_metadata=json.dumps(memory_unit.source.metadata) + if memory_unit.source.metadata + else "{}", + task_issue_title=memory_unit.task.issue_title, + task_issue_body=memory_unit.task.issue_body, + task_issue_comments=memory_unit.task.issue_comments, + task_issue_type=memory_unit.task.issue_type, + task_repository=memory_unit.task.repository, + state_done=memory_unit.state.done, + state_todo=memory_unit.state.todo, + state_open_file=memory_unit.state.open_file, + state_working_dir=memory_unit.state.working_dir, + state_extra_environment=json.dumps(memory_unit.state.extra_environment) + if memory_unit.state.extra_environment + else "{}", + action_name=memory_unit.action.name, + action_description=memory_unit.action.description, + action_target=memory_unit.action.target, + action_tool=memory_unit.action.tool, + result_type=memory_unit.result.type, + result_description=memory_unit.result.description, + result_exit_code=memory_unit.result.exit_code, + ) + + def to_memory_unit(self) -> MemoryUnit: + """Convert database model back to MemoryUnit.""" + return MemoryUnit( + memory_id=self.memory_id, + timestamp=MemoryTimestamp( + created_at=self.memory_created_at, + updated_at=self.memory_updated_at, + invalid_at=self.memory_invalid_at, + ), + source=MemorySource( + source_name=self.memory_source_name, + run_id=self.memory_run_id, + metadata=json.loads(self.memory_metadata) + if self.memory_metadata not in (None, "", "null") + else {}, + ), + task=Task( + issue_title=self.task_issue_title, + issue_body=self.task_issue_body, + issue_comments=self.task_issue_comments, + issue_type=self.task_issue_type, + repository=self.task_repository, + ), + state=State( + done=self.state_done, + todo=self.state_todo, + 
open_file=self.state_open_file, + working_dir=self.state_working_dir, + extra_environment=json.loads(self.state_extra_environment) + if self.state_extra_environment not in (None, "", "null") + else {}, + ), + action=Action( + name=self.action_name, + description=self.action_description, + target=self.action_target, + tool=self.action_tool, + ), + result=Result( + type=self.result_type + if self.result_type in ["success", "failure", "partial", "unknown"] + else "unknown", + description=self.result_description, + exit_code=self.result_exit_code, + ), + ) diff --git a/athena/models/memory.py b/athena/models/memory.py index 2d7c8c4..4776665 100644 --- a/athena/models/memory.py +++ b/athena/models/memory.py @@ -1,26 +1,21 @@ -import json import uuid from datetime import datetime, timezone from typing import Any, Dict, Literal, Optional -from pydantic import BaseModel -from pydantic import Field as PydField -from sqlalchemy import Column, DateTime, Text -from sqlmodel import Field as SQLField -from sqlmodel import SQLModel, UniqueConstraint +from pydantic import BaseModel, Field class MemoryTimestamp(BaseModel): """Lifecycle timestamps for a memory unit.""" - created_at: datetime = PydField( + created_at: datetime = Field( default_factory=lambda: datetime.now(timezone.utc), description="When the memory unit was first created", ) - updated_at: Optional[datetime] = PydField( + updated_at: Optional[datetime] = Field( None, description="When the memory was last updated/refreshed" ) - invalid_at: Optional[datetime] = PydField( + invalid_at: Optional[datetime] = Field( None, description="When the memory was invalidated or expired" ) @@ -28,11 +23,11 @@ class MemoryTimestamp(BaseModel): class MemorySource(BaseModel): """Source information for a memory unit.""" - source_name: str = PydField( + source_name: str = Field( ..., description="Memory source, e.g., agent name, model name, dataset name, or file path" ) - run_id: str = PydField(..., description="Unique id for agent run or 
trajectory dataset") - metadata: Dict[str, Any] = PydField( + run_id: str = Field(..., description="Unique id for agent run or trajectory dataset") + metadata: Dict[str, Any] = Field( default_factory=dict, description="Optional extension field for custom metadata" ) @@ -40,28 +35,28 @@ class MemorySource(BaseModel): class Task(BaseModel): """Task information extracted from messages or issues.""" - issue_title: str = PydField(..., description="Short title of the issue/task") - issue_body: str = PydField(..., description="Detailed description of the issue/task") - issue_comments: str = PydField(..., description="Relevant comments or discussions") - issue_type: str = PydField( + issue_title: str = Field(..., description="Short title of the issue/task") + issue_body: str = Field(..., description="Detailed description of the issue/task") + issue_comments: str = Field(..., description="Relevant comments or discussions") + issue_type: str = Field( ..., description="Type of issue: bug | feature | documentation | question | other" ) - repository: str = PydField(..., description="Repository where the issue/task belongs") + repository: str = Field(..., description="Repository where the issue/task belongs") class State(BaseModel): """State information including completed and pending work.""" - done: str = PydField(..., description="Summary of what has already been completed") - todo: str = PydField(..., description="Summary of what remains to be done") - open_file: Optional[str] = PydField( + done: str = Field(..., description="Summary of what has already been completed") + todo: str = Field(..., description="Summary of what remains to be done") + open_file: Optional[str] = Field( None, description="Path of the currently opened or edited file, if any (e.g., /sympy__sympy/reproduce_bug.py)", ) - working_dir: Optional[str] = PydField( + working_dir: Optional[str] = Field( None, description="Path of the current working directory (e.g., /sympy__sympy)" ) - extra_environment: 
Dict[str, str] = PydField( + extra_environment: Dict[str, str] = Field( default_factory=dict, description="Other SWE-specific runtime state, e.g. active branch, virtualenv, configs, etc.", ) @@ -70,15 +65,15 @@ class State(BaseModel): class Action(BaseModel): """Action taken by the agent.""" - name: str = PydField( + name: str = Field( ..., description="Action type (read_file | edit_file | run_test | invoke_tool | etc.)" ) - description: str = PydField(..., description="Detailed description of the action") - target: str = PydField( + description: str = Field(..., description="Detailed description of the action") + target: str = Field( ..., description="Target of the action (e.g., utils/math.py, tests/test_math.py, pytest, git)", ) - tool: str = PydField( + tool: str = Field( ..., description="Tool used for the action (e.g., pytest, git, search, editor, bash)" ) @@ -89,11 +84,11 @@ class Action(BaseModel): class Result(BaseModel): """Execution result produced by an action.""" - type: ResultType = PydField( + type: ResultType = Field( "unknown", description="Execution outcome: success | failure | partial | unknown" ) - description: str = PydField("", description="Summary of the result") - exit_code: Optional[int] = PydField( + description: str = Field("", description="Summary of the result") + exit_code: Optional[int] = Field( None, description="Exit code if extractable from logs (optional)" ) @@ -112,146 +107,13 @@ class MemoryUnit(BaseModel): - The result of the action """ - memory_id: str = PydField( + memory_id: str = Field( default_factory=lambda: str(uuid.uuid4()), description="Unique identifier for this memory unit", ) - timestamp: MemoryTimestamp = PydField(..., description="Timestamp of the memory") + timestamp: MemoryTimestamp = Field(..., description="Timestamp of the memory") source: MemorySource task: Task state: State action: Action result: Result - - -# Database models for persistent storage -class MemoryUnitDB(SQLModel, table=True): - """Database 
model for persistent storage of memory units.""" - - __tablename__ = "memory_units" - __table_args__ = (UniqueConstraint("memory_id", name="uq_memory_id"),) - - id: Optional[int] = SQLField(default=None, primary_key=True) - - # Source information - memory_id: str - memory_source_name: str - memory_run_id: str - memory_created_at: datetime = SQLField(sa_column=Column(DateTime(timezone=True))) - memory_updated_at: Optional[datetime] = SQLField( - default=None, sa_column=Column(DateTime(timezone=True)) - ) - memory_invalid_at: Optional[datetime] = SQLField( - default=None, sa_column=Column(DateTime(timezone=True)) - ) - memory_metadata: str = SQLField( - default="{}", sa_column=Column(Text) - ) # JSON string for Dict[str, Any] - - # Task information - task_issue_title: str = SQLField(sa_column=Column(Text)) - task_issue_body: str = SQLField(sa_column=Column(Text)) - task_issue_comments: str = SQLField(sa_column=Column(Text)) - task_issue_type: str - task_repository: str - - # State information - state_done: str = SQLField(sa_column=Column(Text)) - state_todo: str = SQLField(sa_column=Column(Text)) - state_open_file: Optional[str] = None - state_working_dir: Optional[str] = None - state_extra_environment: str = SQLField( - default="{}", sa_column=Column(Text) - ) # JSON string for Dict[str, str] - - # Action information - action_name: str - action_description: str = SQLField(sa_column=Column(Text)) - action_target: str - action_tool: str - - # Result information - result_type: str - result_description: str = SQLField(sa_column=Column(Text)) - result_exit_code: Optional[int] = None - - @classmethod - def from_memory_unit(cls, memory_unit: MemoryUnit) -> "MemoryUnitDB": - """Create a database model from a MemoryUnit.""" - return cls( - memory_id=memory_unit.memory_id, - memory_source_name=memory_unit.source.source_name, - memory_run_id=memory_unit.source.run_id, - memory_created_at=memory_unit.timestamp.created_at, - memory_updated_at=memory_unit.timestamp.updated_at, - 
memory_invalid_at=memory_unit.timestamp.invalid_at, - memory_metadata=json.dumps(memory_unit.source.metadata) - if memory_unit.source.metadata - else "{}", - task_issue_title=memory_unit.task.issue_title, - task_issue_body=memory_unit.task.issue_body, - task_issue_comments=memory_unit.task.issue_comments, - task_issue_type=memory_unit.task.issue_type, - task_repository=memory_unit.task.repository, - state_done=memory_unit.state.done, - state_todo=memory_unit.state.todo, - state_open_file=memory_unit.state.open_file, - state_working_dir=memory_unit.state.working_dir, - state_extra_environment=json.dumps(memory_unit.state.extra_environment) - if memory_unit.state.extra_environment - else "{}", - action_name=memory_unit.action.name, - action_description=memory_unit.action.description, - action_target=memory_unit.action.target, - action_tool=memory_unit.action.tool, - result_type=memory_unit.result.type, - result_description=memory_unit.result.description, - result_exit_code=memory_unit.result.exit_code, - ) - - def to_memory_unit(self) -> MemoryUnit: - """Convert database model back to MemoryUnit.""" - return MemoryUnit( - memory_id=self.memory_id, - timestamp=MemoryTimestamp( - created_at=self.memory_created_at, - updated_at=self.memory_updated_at, - invalid_at=self.memory_invalid_at, - ), - source=MemorySource( - source_name=self.memory_source_name, - run_id=self.memory_run_id, - metadata=json.loads(self.memory_metadata) - if self.memory_metadata not in (None, "", "null") - else {}, - ), - task=Task( - issue_title=self.task_issue_title, - issue_body=self.task_issue_body, - issue_comments=self.task_issue_comments, - issue_type=self.task_issue_type, - repository=self.task_repository, - ), - state=State( - done=self.state_done, - todo=self.state_todo, - open_file=self.state_open_file, - working_dir=self.state_working_dir, - extra_environment=json.loads(self.state_extra_environment) - if self.state_extra_environment not in (None, "", "null") - else {}, - ), - 
action=Action( - name=self.action_name, - description=self.action_description, - target=self.action_target, - tool=self.action_tool, - ), - result=Result( - type=self.result_type - if self.result_type in ["success", "failure", "partial", "unknown"] - else "unknown", - description=self.result_description, - exit_code=self.result_exit_code, - ), - ) diff --git a/pyproject.toml b/pyproject.toml index 9634060..89610b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "langchain-openai==0.2.8", "langchain-google-genai==2.0.4", "langchain_community==0.3.2", + "langchain_google_vertexai==2.1.0" ] requires-python = ">= 3.11" From a6d7375da443d312c22fbe254986df4ce1e0b272 Mon Sep 17 00:00:00 2001 From: Zhaoyang-Chu Date: Mon, 15 Sep 2025 12:04:10 +0800 Subject: [PATCH 4/9] feat: implement memory extraction --- athena/models/memory.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/athena/models/memory.py b/athena/models/memory.py index 2d7c8c4..5be9c58 100644 --- a/athena/models/memory.py +++ b/athena/models/memory.py @@ -3,6 +3,7 @@ from datetime import datetime, timezone from typing import Any, Dict, Literal, Optional +from pgvector.sqlalchemy import Vector from pydantic import BaseModel from pydantic import Field as PydField from sqlalchemy import Column, DateTime, Text @@ -175,6 +176,11 @@ class MemoryUnitDB(SQLModel, table=True): result_description: str = SQLField(sa_column=Column(Text)) result_exit_code: Optional[int] = None + # Embeddings for semantic retrieval (optional, pgvector recommended). 
Stored as float arrays + task_embedding: Optional[list[float]] = SQLField(default=None, sa_column=Column(Vector())) + state_embedding: Optional[list[float]] = SQLField(default=None, sa_column=Column(Vector())) + task_state_embedding: Optional[list[float]] = SQLField(default=None, sa_column=Column(Vector())) + @classmethod def from_memory_unit(cls, memory_unit: MemoryUnit) -> "MemoryUnitDB": """Create a database model from a MemoryUnit.""" From 5232bf7e9a39fe37da589160fc5fdca5e578dda6 Mon Sep 17 00:00:00 2001 From: Zhaoyang-Chu Date: Mon, 15 Sep 2025 12:05:24 +0800 Subject: [PATCH 5/9] feat: implement memory extraction --- athena/configuration/config.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/athena/configuration/config.py b/athena/configuration/config.py index 62366ed..7c169ac 100644 --- a/athena/configuration/config.py +++ b/athena/configuration/config.py @@ -40,5 +40,11 @@ class Settings(BaseSettings): MEMORY_MAX_STORED_UNITS: int = 1000 # Maximum number of memory units to store MEMORY_SEARCH_LIMIT: int = 10 # Default search result limit + # Embeddings (OpenAI-format, works with Codestral embed via a compatible gateway) + EMBEDDINGS_MODEL: Optional[str] = None + EMBEDDINGS_API_KEY: Optional[str] = None + EMBEDDINGS_BASE_URL: Optional[str] = None + EMBEDDINGS_DIM: Optional[int] = None + settings = Settings() From 27baa8e18fe51c25fbd7bd45b268169913305c97 Mon Sep 17 00:00:00 2001 From: Zhaoyang-Chu Date: Mon, 15 Sep 2025 12:11:02 +0800 Subject: [PATCH 6/9] feat: implement memory extraction --- .../app/services/memory_extraction_service.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/athena/app/services/memory_extraction_service.py b/athena/app/services/memory_extraction_service.py index fdb3442..6cba676 100644 --- a/athena/app/services/memory_extraction_service.py +++ b/athena/app/services/memory_extraction_service.py @@ -1,3 +1,4 @@ +import asyncio import html import json import re @@ -10,6 +11,7 @@ from 
athena.app.services.base_service import BaseService from athena.app.services.llm_service import LLMService +from athena.app.services.postgres_memory_store import PostgreSqlMemoryStore from athena.models import ( Action, MemorySource, @@ -76,6 +78,7 @@ def __init__( llm_service: LLMService, batch_size: int = 100, max_retries: int = 3, + memory_store: Optional[PostgreSqlMemoryStore] = None, ): """ Initialize the Memory Extraction Service. @@ -89,6 +92,7 @@ def __init__( self.batch_size = batch_size self.max_retries = max_retries self._extraction_cache: Dict[str, MemoryUnit] = {} + self.memory_store = memory_store # self._logger = get_logger(__name__) def start(self): @@ -170,11 +174,8 @@ def _pick(d: Dict[str, Any], keys) -> Optional[Any]: ) try: memory_units = self._extract_memory_units_by_action_windows(messages, memory_source) - print(memory_units) - assert False except ExtractionError: # self._logger.error(f"Failed to extract memory units from {run_id} in {repo_name}: {e}") # TODO: add logger - assert False continue self._extraction_cache.update( @@ -246,10 +247,12 @@ def _extract_memory_units_by_action_windows( if mu: ordered_memory_units.append(mu) - if ( - self.memory_store is not None and ordered_memory_units - ): # TODO: use PostgreSqlMemoryStore to store memory units. 
- self.memory_store.upsert(ordered_memory_units) + if self.memory_store is not None and ordered_memory_units: # TODO: refine memory store + # Fire-and-forget is acceptable in most flows; caller may await if desired + if asyncio.iscoroutinefunction(self.memory_store.upsert): + asyncio.create_task(self.memory_store.upsert(ordered_memory_units)) + else: + self.memory_store.upsert(ordered_memory_units) # type: ignore[arg-type] return ordered_memory_units From b683266e8e7357029e836496e0bdd9abbd7bff8d Mon Sep 17 00:00:00 2001 From: Zhaoyang-Chu Date: Mon, 15 Sep 2025 12:16:05 +0800 Subject: [PATCH 7/9] feat: implement memory extraction --- init_memory_base/dataset_pre.py | 10 - init_memory_base/prompt.py | 230 ------ init_memory_base/prompts/code_skeleton.txt | 31 - init_memory_base/prompts/text_summary.txt | 31 - .../trajectory_memory_extractor_deepseek.py | 692 ------------------ 5 files changed, 994 deletions(-) delete mode 100644 init_memory_base/dataset_pre.py delete mode 100644 init_memory_base/prompt.py delete mode 100644 init_memory_base/prompts/code_skeleton.txt delete mode 100644 init_memory_base/prompts/text_summary.txt delete mode 100644 init_memory_base/trajectory_memory_extractor_deepseek.py diff --git a/init_memory_base/dataset_pre.py b/init_memory_base/dataset_pre.py deleted file mode 100644 index 858608c..0000000 --- a/init_memory_base/dataset_pre.py +++ /dev/null @@ -1,10 +0,0 @@ -from datasets import load_dataset - -dataset = load_dataset("SWE-Gym/OpenHands-SFT-Trajectories", split="train.success.oss") -dataset = load_dataset("ZhengyanShi/Poutine-OpenHands-SFT-Trajectories", split="train") -dataset = load_dataset("SWE-Gym/OpenHands-Sampled-Trajectories", split="train.raw") -dataset = load_dataset("SWE-bench/SWE-smith-trajectories") -dataset = load_dataset("nebius/SWE-agent-trajectories", split="train") -dataset = load_dataset("zai-org/SWE-Dev-train") -dataset = load_dataset("JetBrains-Research/swe-traj-complete", split="train") -# git clone 
https://huggingface.co/datasets/ByteDance-Seed/Multi-SWE-bench_trajs diff --git a/init_memory_base/prompt.py b/init_memory_base/prompt.py deleted file mode 100644 index 3ce2a80..0000000 --- a/init_memory_base/prompt.py +++ /dev/null @@ -1,230 +0,0 @@ -TASK_EXTRACTION_PROMPT = """ -You are an information-extraction model that outputs ONLY the `task` object -from ONE user message. - -Return STRICT JSON with EXACTLY this shape and NOTHING else: -{ -"task": { - "issue_title": string, - "issue_body": string, - "issue_comments": string, - "issue_type": "bug" | "feature" | "documentation" | "question" | "other", - "repository": string -} -} - -Extraction rules: -1) Evidence-bound: Base ALL fields SOLELY on the given user message. Do NOT invent. -2) Unknown -> empty string "" (do not write null, N/A, or placeholders). -3) Plain text; preserve key technical details. If the message contains code or traceback, -keep them in `issue_body` using triple backticks. Do NOT discard reproduction steps. -4) Field guidance: -- issue_title: ≤ 30 words; concise headline of the main problem/request. -- issue_body: main description (symptom, reproduction, expected vs. actual, relevant code/traceback). -- issue_comments: secondary remarks, greetings, environment notes, or side comments not central to the body. -- issue_type (choose the FIRST matching category in this priority): - a) "bug": mentions error/exception/failure, stack trace, "expected vs actual", or reproducible defect. - b) "feature": requests adding/changing functionality without a failure symptom. - c) "documentation": asks to fix/clarify docs/examples/tutorials only. - d) "question": pure inquiry without requesting changes. - e) otherwise "other". -- repository: extract explicit repo identifiers (e.g., "owner/repo") or a clearly named local repo - referenced as the main target. If none is clear, use "". -5) Output policy: -- EXACTLY one JSON object; no markdown fences, no comments, no extra keys. 
-- Valid JSON only: double quotes, escaped newlines if any, no trailing commas. -""" - - -STATE_DONE_SUMMARY_PROMPT = """ -You summarize PRIOR memory units into a concise `state.done` text that ties completed work to the overall task. - -OUTPUT FORMAT (STRICT): -- Return EXACTLY one plain-text paragraph (no bullets, no JSON, no code fences). -- 5–7 sentences, ≤300 words total. -- No line breaks; no leading/trailing quotes; no extra commentary. -- The FIRST sentence MUST begin with: To resolve , we have ... -- If nothing is verifiably completed, return exactly: (none) - -TASK_SHORT (how to derive): -- Derive a concise (≤15 words) clause from the TASK object capturing the main goal. -- Prefer `issue_title`; add minimal specificity from other fields (e.g., key API/file/repo) if helpful. -- Do NOT invent entities; evidence-bound to TASK fields; plain text only (no quotes) - -WHAT COUNTS AS COMPLETED (content scope): -- Describe completed subgoals and observable outcomes (e.g., files created/edited, tests run with pass/fail, errors reproduced/fixed, resources updated). -- ALLOWED: referencing abstract artifacts (module/class/filename level, or top-level paths), and mentioning very key tool usage/results in generic terms (e.g., “validated with pytest”, “mypy reported no errors”, “git commit applied”). -- Keep references compact and evidence-bound. -- Use past tense and neutral tone. - -WHAT TO EXCLUDE (forbidden detail level): -- Specific commands, CLI flags, test selectors, exact invocation strings, commit SHAs, line numbers, function signature changes, patch/diff hunks, stack traces beyond the error name/message. -- Plans, intentions, hypotheses, open questions, TODO/next steps, speculative language (e.g., “will”, “plan to”, “should”, “next”). -- Duplicated or conflicting earlier statements (prefer the most recent unit for the same aspect). -- Irrelevant greetings or meta commentary. 
- -CONFLICT & SCOPE: -- Units are ordered oldest→newest; when facts conflict, prefer the newest. -- Be evidence-bound: do NOT invent information beyond the units provided. - -QUALITY CHECK BEFORE ANSWERING: -- Contains only completed work and outcomes. -- Mentions important artifacts at module/class/filename level and—only if essential—high-level tool usage/results. -- No specific commands/flags/line numbers/function-signature details. -- Single paragraph, ≤300 words, no newlines, no quotes, no JSON. -- The FIRST sentence MUST begin with: 'To resolve , we have ...' -""" - - -STATE_TODO_SYNTHESIS_PROMPT = """ -Compose `state.todo` as ONE high-level, tool-agnostic intent paragraph that guides the next step toward resolving the task. - -OUTPUT (STRICT): -- EXACTLY one plain-text paragraph (no bullets, no JSON, no code fences). -- 2-3 sentences, ≤150 words total. -- No line breaks; no leading/trailing quotes; no extra commentary. - -SCOPE & BOUNDARY -- Evidence sources: TASK (overall goal), PRIOR_DONE (what is already completed), WINDOW_MSGS (context for inferring the intent of the next step). -- CURRENT_ACTION is ONLY for domain alignment (same component/area), NEVER for details; do NOT copy or paraphrase its name/tool/command/flags/paths/function names. - -WHAT COUNTS AS INTENT (content scope) -- State ONE high-level objective that moves the task forward (e.g., verify behavior, diagnose cause, validate fix, run regression checks, broaden coverage, document behavior, gather evidence). -- You MAY mention the relevant component/feature/module/class/filename at an abstract level (e.g., “ECS label merge logic”, “mypy type checking”). -- Keep phrasing outcome-oriented and generalizable; avoid procedural steps. - -WHAT TO EXCLUDE (forbidden detail level) -- Any tools, commands, flags, exact invocations, API endpoints, function names, precise paths/parameters, line numbers, stack traces, test selectors, commit SHAs, patch/diff hunks. 
-- Multi-step plans or sequences (“then/after that”); more than one objective. -- Hedging/speculation (“might/should/maybe”); meta commentary. -- Invented entities. - -CONTENT ORDER: -1) Derive a ≤15-word TASK_SHORT from TASK (prefer issue_title) and begin with a high-level objective: -"To resolve , ..." -2) If PRIOR_DONE is non-empty, briefly acknowledge it (≤12 words) without repeating details: -"having completed , ..." -3) Then state the high-level objective/intent (WHAT, not HOW). -"we need to ." - -CONSISTENCY & STYLE: -- Domain must be consistent with CURRENT_ACTION’s area/component, but MUST NOT include any of its concrete details. -- If any conflicts exist in PRIOR_DONE, prefer the most recent completion. -- Neutral, task-oriented; imperative (“run…”, “open…”, “apply…”). - -QUALITY CHECK BEFORE ANSWERING -- Single paragraph; 2–3 sentences; ≤150 words. -- Begin with the selected template; no newlines/quotes/JSON. -- Tool/action-agnostic; no concrete commands/flags/paths/function names. -- Domain-aligned with CURRENT_ACTION; evidence-bound to inputs. -""" - - -ACTION_EXTRACTION_PROMPT = """ -Extract ONLY the `action` object from ONE assistant message that starts a tool/action call. - -OUTPUT (STRICT): -{ -"action": { - "name": string, - "description": string, - "target": string, - "tool": string -} -} -- Exactly one JSON object, no extra keys, no markdown/code fences, valid JSON (double quotes, no trailing commas). -- All fields are strings. If unknown, use "" (never null). - -EVIDENCE SCOPE: -- Use ONLY the text between and . Be evidence-bound; do not invent. - -PRIMARY PATTERNS (use the first that matches): -1) XML-like function call: -- Example: ... /a/b.py ... -- name := the identifier after "" -- tool := infer concise tool from name or parameters (e.g., "str_replace_editor"→"editor"; "execute_bash"→"bash"). - If unclear, use "". -- target := the primary operand (prefer parameters named path/file/target/command). Preserve the exact string. 
-- description := ≤30 words summarizing the intent (e.g., "view file", "replace string in file", "run command"). - -2) JSON-like/arg-style tool invocation (e.g., {"tool":"bash","command":"..."}): -- name := the operation verb-noun if present (or reuse tool name if no better choice). -- tool := explicit tool field if present, else infer from command/tool tokens; else "". -- target := primary file/path/command; keep flags/args verbatim. -- description := ≤30 words summarizing the intent. - -3) Plain imperative text (no explicit tags), e.g., "Run pytest -k t": -- name := an action verb-noun (e.g., "run_test", "open_file", "edit_file") derived from the instruction. -- tool := infer from tokens ("pytest","git","bash","editor","curl", etc.); else "". -- target := the concrete operand (path/command/resource). Preserve literal text. -- description := ≤30 words. - -MULTIPLE ITEMS: -- If multiple candidate targets exist, pick the most central as `target` and include the rest briefly in `description` (e.g., "also: X, Y"). -- If the message contains several distinct action calls, extract the FIRST action only. - -PRESERVATION & PRECISION: -- Preserve exact paths/commands/API names in `target` (no normalization). -- Prefer parameter values over paraphrases. -- Keep `description` concise and factual; no plans/speculation. - -FINAL CHECKS: -- One JSON object only; all fields strings; unknown→""; ≤30-word description; evidence-bound. -""" - - -RESULT_EXTRACTION_PROMPT = """ -You are an execution-result extractor. -Extract the definitive outcome of executing the CURRENT_ACTION from the LAST_MESSAGE only. - -OUTPUT (STRICT JSON) -{ -"result": { - "type": "success" | "failure" | "partial" | "unknown", - "description": string, // EXACTLY one sentence, ≤60 words, plain text - "exit_code": string // optional; keep digits as string if present; else "" -} -} -- Exactly one JSON object; no extra keys; valid JSON; double quotes; no trailing commas. 
-- The "description" MUST describe the outcome of CURRENT_ACTION (not other actions/logs). -- If has no verifiable outcome about CURRENT_ACTION, return: -{"result":{"type":"unknown","description":"(none)","exit_code":""}} - -ALIGNMENT & EVIDENCE -- Evidence source: ONLY the text in . -- Use ONLY to align/filter which outcome to report (match by tool/name/target/command). -- Do NOT invent details absent from ; preserve literals (paths, commands, API names, ARNs, exit codes). - -DECISION RULES (priority) -1) Error/exception lines (e.g., “Error”, “Exception”, “Traceback”, “failed”) → type="failure". -2) Explicit success lines (“succeeded”, “completed”, tests passed, “exit code 0”) → type="success". -3) Mixed outcomes (some pass and some fail; partial updates) → type="partial". -4) Otherwise → type="unknown". - -DECISION PROCESS (follow in order) -1) Locate lines in LAST_MESSAGE that explicitly mention CURRENT_ACTION’s tool/name/target/command. -2) If multiple candidates, select the most recent summary/conclusive line related to CURRENT_ACTION. -3) If none mention the action explicitly, use the final conclusive line (e.g., overall success/failure/exit code) that still plausibly refers to the same execution context. -4) If still no verifiable outcome → output (none). - -WHAT TO CAPTURE (priority) -- Success/failure/exception (error type/message). -- Test/command outcomes (pass/fail counts; “finished”, “succeeded”, “failed”). -- Side effects (files/resources created/edited/updated; IDs/ARNs). -- Exit status (e.g., “exit code 0/1”). -- Keep literals verbatim; avoid paraphrasing targets/commands. - -STYLE & TONE -- Neutral, past tense, concise and factual. -- Prefer mentioning the primary target/command if present. -- "description" must be ONE sentence, ≤60 words, neutral, past tense; semicolons allowed. - -SENTENCE TEMPLATE (guideline, not to be printed verbatim) -- " ; (exit code X)" -or "Execution for on with ". 
- -EDGE CASES -- If the message is mostly wrapper noise (prompts, timestamps), but includes an exit code or clear success/failure line, report that. -- If logs are truncated or clipped and no conclusive line exists → (none). -- Do NOT restate next steps or advice; only the observed outcome. -""" diff --git a/init_memory_base/prompts/code_skeleton.txt b/init_memory_base/prompts/code_skeleton.txt deleted file mode 100644 index 1b9b25a..0000000 --- a/init_memory_base/prompts/code_skeleton.txt +++ /dev/null @@ -1,31 +0,0 @@ -You are an expert software engineering assistant. -Given an **observation** about a coding task and the **source code** of a function/class, -your goal is to extract a **code skeleton** that highlights only the focused regions relevant to the observation. - -### Instructions -1. Keep only the **main import statements** at the top of the file: - - Keep imports that are directly used in the focused regions. - - Keep imports that define critical dependencies (e.g., standard libraries like `os`, `sys`, or key frameworks like `torch`, `pandas`). - - Collapse all other imports into a single line: `... # other imports omitted`. -2. Keep class and function headers **with full signatures**: - - Include parameter names and types (if available). - - Include return type or return value description (infer if not explicit; keep concise, e.g., "-> bool"). -3. Preserve necessary surrounding structure for readability (class/def blocks, braces/indentation). -4. Inside classes/functions: - - Keep only lines relevant to the current observation (**focused logic** and **key invocations**). - - Replace non-relevant lines with a single line containing `...`. -5. Provide ±3 lines of context around each focused region, if available. -6. Do **not** add region markers or any textual summary. Output only the code skeleton. -7. Do not introduce new identifiers, logic, control flow, or API calls that are not present in the source. -8. 
Do not infer types/return types that are not explicit; if unknown, use -> Unknown or omit. -9. Output one fenced code block with correct language tag. No prose outside. - -### Input -Observation: - - -Source Code: - - -### Output (Code Skeleton Only) - diff --git a/init_memory_base/prompts/text_summary.txt b/init_memory_base/prompts/text_summary.txt deleted file mode 100644 index f37456a..0000000 --- a/init_memory_base/prompts/text_summary.txt +++ /dev/null @@ -1,31 +0,0 @@ -You are an expert software engineering assistant. -Given an **observation** about a coding task and the **source code** of a function/class, -your goal is to extract a **concise textual summary** that captures the essential intent, behavior, and role of the code in relation to the observation. - -### Instructions -1. **Focus on relevance**: Summaries must highlight aspects of the code most related to the observation (e.g., functionality, data flow, critical logic). -2. **Content to include**: - - The **purpose** of the function/class (what it does, high-level intent). - - The **inputs/outputs** (parameters, return values, side effects). - - The **key operations or algorithms** performed (control flow, API/library usage). - - Any **dependencies** or notable relationships (calls to other functions/classes). -3. **Conciseness**: - - Write in **2–4 sentences**. - - Avoid unnecessary details (variable names, trivial operations). - - Use plain, precise technical language. -4. **No speculation**: - - If the code’s intent is unclear, state it as *“uncertain”* rather than guessing. - - Do not introduce functionality that is not evident in the code. -5. **Output format**: - - One concise **paragraph of plain text only** (no code, no bullet points, no extra markup). - - Do not repeat the observation verbatim. 
- -### Input -Observation: - - -Source Code: - - -### Output (Textual Summary Only) - diff --git a/init_memory_base/trajectory_memory_extractor_deepseek.py b/init_memory_base/trajectory_memory_extractor_deepseek.py deleted file mode 100644 index 6bdcf12..0000000 --- a/init_memory_base/trajectory_memory_extractor_deepseek.py +++ /dev/null @@ -1,692 +0,0 @@ -import hashlib -import json -import os -import random -import re -import time -from typing import Dict, Iterable, List, Literal, Optional, Protocol - -import requests -from datasets import load_dataset -from prompt import ( - ACTION_EXTRACTION_PROMPT, - RESULT_EXTRACTION_PROMPT, - STATE_DONE_SUMMARY_PROMPT, - STATE_TODO_SYNTHESIS_PROMPT, - TASK_EXTRACTION_PROMPT, -) -from pydantic import BaseModel, Field -from tqdm import tqdm - -# ============================= -# 1) Pydantic data structures -# ============================= - -# memory_unit_schema = { -# "trajectory": { -# "source": str, -# "id": str, -# }, -# "task": { -# "issue_title": str, -# "issue_body": str, -# "issue_comments": str, -# "issue_type": str, # e.g., bug, feature, documentation, question -# "repository": str, -# }, -# "state": { -# "done": str, # summaries of the completed work -# "todo": str, # the work to be done, decided on the next action -# # "accessed_code": str, -# }, -# "action": { -# "name": str, # e.g., read_file, edit_file, run_test, invoke_tool -# "description": str, -# "target": str, # e.g., utils/math.py, tests/test_math.py, pytest, git -# "tool": str, # e.g., pytest, git, search -# }, -# "result": { -# "type": str, # success | failure | partial | unknown -# "description": str, -# "exit_code": str, -# }, -# } - - -class TrajectoryMessage(BaseModel): - content: str - role: str # "system" | "user" | "assistant" | tool etc. 
- # TODO: chech style of trajectory message - - -class TrajectoryMeta(BaseModel): - source: str = Field(..., description="e.g., dataset name, run id, or file path") - id: str = Field(..., description="unique id for the trajectory") - - -class Task(BaseModel): - issue_title: str - issue_body: str - issue_comments: str - issue_type: str # bug | feature | documentation | question | other - repository: str - - -class State(BaseModel): - done: str - todo: str - - -class Action(BaseModel): - name: str # read_file | edit_file | run_test | invoke_tool | etc. - description: str - target: str # utils/math.py | tests/test_math.py | pytest | git - tool: str # pytest | git | search | editor | bash - - -ResultType = Literal["success", "failure", "partial", "unknown"] - - -class Result(BaseModel): - type: ResultType = "unknown" # success | failure | partial | unknown - description: str = "" # one-sentence result summary (for human) - exit_code: Optional[int] = None # if extractable from logs - - -class MemoryUnit(BaseModel): - trajectory: TrajectoryMeta - task: Task - state: State - action: Action - result: Result - - def key(self) -> str: - # Canonical key for dedup (task+state) as requested - payload = json.dumps( - { - "task": self.task.model_dump(), - "state": self.state.model_dump(), - }, - sort_keys=True, - ) - return hashlib.sha256(payload.encode("utf-8")).hexdigest() - - -# Optional: In-memory base and protocol you can replace with your own DB/Vector store -class MemoryBaseProtocol(Protocol): - def upsert(self, units: Iterable[MemoryUnit]) -> None: ... 
- - -class InMemoryMemoryBase: - def __init__(self): - self._store: Dict[str, MemoryUnit] = {} - - def upsert(self, units: Iterable[MemoryUnit]) -> None: - for u in units: - self._store[u.key()] = u - - def all(self) -> List[MemoryUnit]: - return list(self._store.values()) - - -# ============================================= -# 2) DeepSeek (OpenAI-compatible) API utilities -# ============================================= - -DEEPSEEK_API_BASE = os.environ.get("DEEPSEEK_API_BASE", "https://api.deepseek.com") -DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "") -DEEPSEEK_MODEL = os.environ.get("DEEPSEEK_MODEL", "deepseek-chat") - -HEADERS = { - "Authorization": f"Bearer {DEEPSEEK_API_KEY}", - "Content-Type": "application/json", -} - - -class RateLimit(Exception): - pass - - -_SESSION: Optional[requests.Session] = None - - -def _get_session() -> requests.Session: - global _SESSION - if _SESSION is None: - s = requests.Session() - adapter = requests.adapters.HTTPAdapter( - pool_connections=10, pool_maxsize=50, max_retries=0, pool_block=True - ) - s.mount("https://", adapter) - s.mount("http://", adapter) - _SESSION = s - return _SESSION - - -def _post_with_retries( - session: requests.Session, - url: str, - payload: dict, - *, - max_retries: int = 5, - retry_base: float = 1.25, - timeout=(10, 60), -) -> str: - if not DEEPSEEK_API_KEY: - raise RuntimeError("DEEPSEEK_API_KEY is not set.") - - for attempt in range(max_retries): - try: - headers = { - "Authorization": f"Bearer {os.environ.get('DEEPSEEK_API_KEY', '')}", - "Content-Type": "application/json", - } - resp = session.post(url, headers=headers, json=payload, timeout=timeout) - if resp.status_code == 429: - raise RateLimit("Rate limited") - resp.raise_for_status() - data = resp.json() - content = data.get("choices", [{}])[0].get("message", {}).get("content", "") - if not content: - raise ValueError("Empty content from API") - return content - - except (requests.HTTPError, requests.ConnectionError, 
requests.Timeout, RateLimit) as e: - if attempt == max_retries - 1: - raise - # Optional: Rebuild Session on connection errors - if isinstance(e, (requests.ConnectionError,)): - global _SESSION - _SESSION = None - session = _get_session() - time.sleep((retry_base**attempt) + random.uniform(0, 0.25)) - - -def call_deepseek( - messages: List[Dict[str, str]], - *, - temperature: float = 0.0, - max_retries: int = 5, - retry_base: float = 1.25, - timeout: int = 60, -) -> str: - """Call DeepSeek Chat Completions endpoint and return assistant text.""" - url = f"{DEEPSEEK_API_BASE.rstrip('/')}/v1/chat/completions" - payload = { - "model": DEEPSEEK_MODEL, - "messages": messages, - "temperature": temperature, - "response_format": {"type": "json_object"}, - # Optional: Limit return length to further reduce latency - # "max_tokens": 512, - # "n": 1, # number of completions to generate - } - session = _get_session() - return _post_with_retries( - session, url, payload, max_retries=max_retries, retry_base=retry_base, timeout=timeout - ) - - -# ===================================== -# 3) Orchestration over a trajectory -# ===================================== - - -def row_to_trajectory_messages(row): - traj = row["messages"] - msgs = [] - for m in traj: - content = m.get("content", "") if isinstance(m, dict) else str(m) - role = m.get("role", "user") if isinstance(m, dict) else "user" - msgs.append(TrajectoryMessage(content=content, role=role)) - return msgs - - -def extract_memory_units_by_action_windows( - meta: TrajectoryMeta, - task: Task, - trajectory: List[TrajectoryMessage], - *, - memory_base: Optional[MemoryBaseProtocol] = None, -) -> List[MemoryUnit]: - ordered_memory_units: List[MemoryUnit] = [] - window_msgs: List[TrajectoryMessage] = [] - window_first_action: Optional[TrajectoryMessage] = None - - for msg in tqdm(trajectory, desc="Extracting memory units"): - if is_action_start(msg): - if window_msgs and window_first_action is not None: - mu = 
finalize_window_with_llm( - window_msgs, - window_first_action, - prior_units=ordered_memory_units, - meta=meta, - task=task, - ) - if mu: - ordered_memory_units.append(mu) - - window_msgs = [msg] - window_first_action = msg - else: - if window_msgs: - window_msgs.append(msg) - else: - continue - - if window_msgs and window_first_action is not None: - mu = finalize_window_with_llm( - window_msgs, window_first_action, prior_units=ordered_memory_units, meta=meta, task=task - ) - if mu: - ordered_memory_units.append(mu) - - if memory_base is not None and ordered_memory_units: - memory_base.upsert(ordered_memory_units) - - return ordered_memory_units - - -def is_action_start(msg: TrajectoryMessage) -> bool: - """Return True if this message starts an action call, else False.""" - if getattr(msg, "role", "") != "assistant": - return False - content = getattr(msg, "content", "") or "" - return bool(re.search(r"", content, flags=re.DOTALL)) - # TODO: content check, different trajectory style - - -def finalize_window_with_llm( - window_msgs: List[TrajectoryMessage], - first_action_msg: TrajectoryMessage, - prior_units: List[MemoryUnit], - meta: TrajectoryMeta, - task: Task, -) -> Optional[MemoryUnit]: - """ - Extract a single memory unit from a window of messages. 
- - Synthesize state.done from prior actions - - Synthesize state.todo by extracting intents from window_msgs - - Extract action from first_action_msg - - Extract result from the last message in window_msgs - """ - action = extract_action_from_first_action_msg(first_action_msg) - state_done = summarize_context_from_units(prior_units, task) - state_todo = synthesize_state_todo_paragraph(window_msgs, task, state_done, action) - result = extract_result_from_last_message(window_msgs, action) - return MemoryUnit( - trajectory=meta, - task=task, - state=State( - done=state_done, - todo=state_todo, - ), - action=action, - result=result, - ) - - -def extract_task_from_first_user_message( - trajectory: List[TrajectoryMessage], *, default_repository: str = "" -) -> Task: - """ - From a trajectory (list of TrajectoryMessage), take the first user message, - call DeepSeek to extract the Task fields, and return a Task. - - Requirements: - - Uses the LLM to infer fields (minimal heuristics). - - Returns empty strings when unknown. - - If `default_repository` is provided and the model leaves repository empty, - fill it with `default_repository`. 
- - Dependencies assumed available in your file: - - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str - - Task Pydantic model - - json imported - """ - # 1) find first user message - first_user_msg = next((m for m in trajectory if getattr(m, "role", "") == "user"), None) - if first_user_msg is None: - return Task( - issue_title="", - issue_body="", - issue_comments="", - issue_type="", - repository=default_repository, - ) - - # 2) build a schema-constrained prompt for ONLY Task fields - system_prompt = TASK_EXTRACTION_PROMPT - user_prompt = ( - "Extract the `task` object from the SINGLE user message below.\n" - "Use only the text between and .\n\n" - "\n" + first_user_msg.model_dump_json() + "\n" - ) - - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ] - - # 3) call DeepSeek - raw = call_deepseek(messages, temperature=0.0) - - # 4) robust JSON parse, with a single repair attempt if needed - try: - payload = json.loads(raw) - except json.JSONDecodeError: - fix_messages = [ - {"role": "system", "content": "Fix to valid JSON object. No comments."}, - {"role": "user", "content": raw}, - ] - fixed = call_deepseek(fix_messages, temperature=0.0) - payload = json.loads(fixed) - - # 5) normalize fields to strings (no extra heuristics) - task_obj = payload.get("task", {}) - - def s(x): - return x if isinstance(x, str) else ("" if x is None else str(x)) - - repo = s(task_obj.get("repository")) - if not repo and default_repository: - repo = default_repository # optional fallback; minimal deviation - - return Task( - issue_title=s(task_obj.get("issue_title")), - issue_body=s(task_obj.get("issue_body")), - issue_comments=s(task_obj.get("issue_comments")), - issue_type=s(task_obj.get("issue_type")), - repository=repo, - ) - - -def summarize_context_from_units(units: List[MemoryUnit], task: Task) -> str: - """ - Summarize previous context into a concise `state.done` string by calling DeepSeek. 
- Uses only the last 10 memory units and asks the model to produce a single, plain-text - summary of what has ALREADY BEEN COMPLETED (no plans). - - Dependencies assumed available: - - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str - - MemoryUnit Pydantic model - - json imported - """ - if not units: - return "(none)" - - # Keep the window modest to control prompt size - window = units[-10:] - - # Pack minimal, evidence-bound context for the LLM (no heavy heuristics) - history = [] - for u in window: - history.append( - { - "state": { - "done": u.state.done, - "todo": u.state.todo, - }, - "action": { - "name": u.action.name, - "description": u.action.description, - "target": u.action.target, - "tool": u.action.tool, - }, - "result": { - "type": u.result.type, - "description": u.result.description, - "exit_code": u.result.exit_code, - }, - } - ) - - system_prompt = STATE_DONE_SUMMARY_PROMPT - user_prompt = ( - "Summarize ONLY what has ALREADY BEEN COMPLETED into `state.done`.\n" - "TASK (for deriving only):\n" - "\n" + task.model_dump_json() + "\n\n\n" - "PRIOR UNITS (used ONLY for completed work evidence):\n" - "\n" + json.dumps(history, ensure_ascii=False) + "\n\n\n" - "Return the final summary paragraph ONLY (no explanations)." - ) - - try: - summary = call_deepseek( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - temperature=0.0, - ) - # Be lenient if the model wraps the text in quotes/newlines - return summary.strip().strip('"').strip() - except Exception: - # Minimal, defensive fallback (kept short); avoids complex heuristics - return window[-1].state.done - - -def synthesize_state_todo_paragraph( - window_msgs: List[TrajectoryMessage], - task: Task, - prior_done_text: str = "", - current_action: Optional[Action] = None, -) -> str: - """ - Call DeepSeek to synthesize `state.todo` as ONE highly condensed paragraph - capturing the intent of the CURRENT message window. 
- - Dependencies assumed available: - - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str - - json imported - """ - - system_prompt = STATE_TODO_SYNTHESIS_PROMPT - user_prompt = ( - "TASK (overall purpose):\n" - "\n" + task.model_dump_json() + "\n\n\n" - "PRIOR_DONE (what has already been completed; avoid repeating it):\n" - "\n" + (prior_done_text or "") + "\n\n\n" - "WINDOW_MSGS (current window to derive the immediate intent of the next step; do NOT output tools/commands/paths):\n" - "\n" - + json.dumps([msg.model_dump() for msg in window_msgs], ensure_ascii=False) - + "\n\n\n" - "CURRENT_ACTION (use ONLY to ensure the intent refers to the same component/area; do NOT include or paraphrase any of its details):\n" - "\n" + current_action.model_dump_json() + "\n\n\n" - "Return ONLY the final paragraph." - ) - - try: - text = call_deepseek( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - temperature=0.0, - ) - return text.strip().strip('"').replace("\n", " ").strip() - except Exception: - # Minimal deterministic fallback using available CURRENT_ACTION fields (no speculation) - parts = [] - task_short = (task.issue_title or task.repository or task.issue_type or "the task").strip() - parts.append(f"To progress on {task_short},") - # Build an imperative clause from action fields - verb = (current_action.name or "execute action").replace("_", " ") - tgt = f" {current_action.target}" if current_action.target else "" - via = f" via {current_action.tool}" if current_action.tool else "" - desc = f" to {current_action.description}" if current_action.description else "" - parts.append(f" {verb}{tgt}{via}{desc}.") - return " ".join(parts).replace(" ", " ").strip() - - -def extract_action_from_first_action_msg(first_action_msg: TrajectoryMessage) -> Action: - """ - Use DeepSeek to extract the `action` object from a SINGLE assistant message - that starts an action call. 
Minimal heuristics; evidence-bound to the message. - Dependencies assumed available: - - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str - - Action Pydantic model - - json imported - """ - system_prompt = ACTION_EXTRACTION_PROMPT - user_prompt = ( - "Extract the `action` from the SINGLE assistant message below.\n" - "Use ONLY the content between and .\n\n" - "\n" + first_action_msg.model_dump_json() + "\n" - ) - - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ] - - raw = call_deepseek(messages, temperature=0.0) - - # Robust JSON parse with one repair attempt - try: - payload = json.loads(raw) - except json.JSONDecodeError: - fixed = call_deepseek( - [ - {"role": "system", "content": "Return a valid JSON object only. No comments."}, - {"role": "user", "content": raw}, - ], - temperature=0.0, - ) - payload = json.loads(fixed) - - action_obj = payload.get("action", {}) or {} - - def s(x): - return x if isinstance(x, str) else ("" if x is None else str(x)) - - return Action( - name=s(action_obj.get("name")), - description=s(action_obj.get("description")), - target=s(action_obj.get("target")), - tool=s(action_obj.get("tool")), - ) - - -def extract_result_from_last_message( - window_msgs: List[TrajectoryMessage], current_action: Action -) -> Result: - """ - Use DeepSeek to extract the `result` from the LAST message of a window. - - Evidence-bound to the last message ONLY. - - Output must be ONE plain-text sentence capturing the definitive outcome. - - Minimal heuristics; a tiny fallback is used only if the API fails. 
- - Assumes available: - - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str - - json imported - """ - if not window_msgs: - return Result(type="unknown", description="(none)") - last_msg = window_msgs[-1] - - system_prompt = RESULT_EXTRACTION_PROMPT - user_prompt = ( - "CURRENT_ACTION (reference for alignment; do NOT invent beyond LAST_MESSAGE):\n" - "\n" + current_action.model_dump_json() + "\n\n\n" - "LAST_MESSAGE (extract the definitive outcome from here ONLY):\n" - "\n" + last_msg.model_dump_json() + "\n\n\n" - "Return ONLY the JSON object." - ) - - try: - raw = call_deepseek( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - temperature=0.0, - ) - # Parse JSON; if failed, try to fix it to a valid JSON - try: - payload = json.loads(raw) - except json.JSONDecodeError: - fixed = call_deepseek( - [ - {"role": "system", "content": "Return a valid JSON object only. No comments."}, - {"role": "user", "content": raw}, - ], - temperature=0.0, - ) - payload = json.loads(fixed) - - r = (payload or {}).get("result", {}) or {} - - # Utility: convert to string - def s(x): - return x if isinstance(x, str) else ("" if x is None else str(x)) - - # Result type correction - r_type = s(r.get("type")).lower() - if r_type not in {"success", "failure", "partial", "unknown"}: - r_type = "unknown" - - # Text: one sentence, remove wrapping quotes - description = s(r.get("description")).strip().strip('"') - if not description: - description = "(none)" - - # exit_code: allow returning number or number string - exit_raw = r.get("exit_code") - exit_code = None - if isinstance(exit_raw, (int, float)) and str(int(exit_raw)) == str(exit_raw).split(".")[0]: - exit_code = int(exit_raw) - elif isinstance(exit_raw, str) and re.fullmatch(r"-?\d+", exit_raw.strip()): - exit_code = int(exit_raw.strip()) - - return Result(type=r_type, description=description, exit_code=exit_code) - - except Exception: - # Minimal fallback: only 
based on the last message (no LLM call) - src = (getattr(last_msg, "content", "") or "").strip() - lt = src.lower() - - if re.search(r"(error|exception|traceback|failed)", lt): - rtype = "failure" - elif re.search(r"(exit code\s*0|succeeded|completed|all tests passed|passed\b)", lt): - rtype = "success" - elif re.search(r"(some tests failed|failures?:\s*[1-9]|errors?:\s*[1-9])", lt): - rtype = "partial" - else: - rtype = "unknown" - - # Extract the last non-empty line as one sentence result, and truncate to ~60 words - lines = [ln.strip() for ln in src.splitlines() if ln.strip()] - summary = lines[-1] if lines else "(none)" - words = summary.split() - if len(words) > 60: - summary = " ".join(words[:60]) - - m = re.search(r"exit code\s+(-?\d+)", src, flags=re.I) - code = int(m.group(1)) if m else None - - return Result(type=rtype, description=summary, exit_code=code) - - -def main(): - trajectory_repo_name = "SWE-Gym/OpenHands-SFT-Trajectories" - trajectory_split = "train.success.oss" - dataset = load_dataset(trajectory_repo_name, split=trajectory_split) - trajectories = [row_to_trajectory_messages(row) for row in dataset] - memory_base = InMemoryMemoryBase() - print(f"\nExtracting memory units for {len(trajectories)} trajectories...\n") - - for idx, trajectory in enumerate(trajectories): - print(f"\nProcessing trajectory {idx + 1} of {len(trajectories)}...\n") - meta = TrajectoryMeta(source=trajectory_repo_name, id=f"{idx}") - task = extract_task_from_first_user_message(trajectory) - units = extract_memory_units_by_action_windows( - meta, task, trajectory, memory_base=memory_base - ) - print(f"\nExtracted {len(units)} memory units. 
Showing JSON:\n") - print(json.dumps([u.model_dump() for u in units], ensure_ascii=False, indent=2)) - assert False - - -if __name__ == "__main__": - main() From 845c7e75a07c3b349f7c10a04521e50f977da4ac Mon Sep 17 00:00:00 2001 From: Zhaoyang-Chu Date: Mon, 15 Sep 2025 17:12:05 +0800 Subject: [PATCH 8/9] feat: implement memory extraction --- athena/app/services/memory_extraction_service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/athena/app/services/memory_extraction_service.py b/athena/app/services/memory_extraction_service.py index 6cba676..b6a5071 100644 --- a/athena/app/services/memory_extraction_service.py +++ b/athena/app/services/memory_extraction_service.py @@ -11,7 +11,7 @@ from athena.app.services.base_service import BaseService from athena.app.services.llm_service import LLMService -from athena.app.services.postgres_memory_store import PostgreSqlMemoryStore +from athena.app.services.memory_storage_service import MemoryStorageService from athena.models import ( Action, MemorySource, @@ -78,7 +78,7 @@ def __init__( llm_service: LLMService, batch_size: int = 100, max_retries: int = 3, - memory_store: Optional[PostgreSqlMemoryStore] = None, + memory_store: Optional[MemoryStorageService] = None, ): """ Initialize the Memory Extraction Service. 
From 9d17039a91b1cee7c3df4572dd3f59ffd790380f Mon Sep 17 00:00:00 2001 From: Zhaoyang-Chu Date: Mon, 15 Sep 2025 17:50:26 +0800 Subject: [PATCH 9/9] feat: add embedding in MemoryUnitDB --- athena/entity/__init__.py | 3 +++ athena/entity/memory.py | 6 ++++++ athena/models/__init__.py | 2 -- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/athena/entity/__init__.py b/athena/entity/__init__.py index e69de29..f5b71d1 100644 --- a/athena/entity/__init__.py +++ b/athena/entity/__init__.py @@ -0,0 +1,3 @@ +from .memory import MemoryUnitDB + +__all__ = ["MemoryUnitDB"] diff --git a/athena/entity/memory.py b/athena/entity/memory.py index 7b9c87d..a816010 100644 --- a/athena/entity/memory.py +++ b/athena/entity/memory.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import Optional +from pgvector.sqlalchemy import Vector from sqlalchemy import Column, DateTime, Text from sqlmodel import Field, SQLModel, UniqueConstraint @@ -59,6 +60,11 @@ class MemoryUnitDB(SQLModel, table=True): result_description: str = Field(sa_column=Column(Text)) result_exit_code: Optional[int] = None + # Embeddings for semantic retrieval (optional, pgvector recommended). Stored as float arrays + task_embedding: Optional[list[float]] = Field(default=None, sa_column=Column(Vector())) + state_embedding: Optional[list[float]] = Field(default=None, sa_column=Column(Vector())) + task_state_embedding: Optional[list[float]] = Field(default=None, sa_column=Column(Vector())) + @classmethod def from_memory_unit(cls, memory_unit: MemoryUnit) -> "MemoryUnitDB": """Create a database model from a MemoryUnit.""" diff --git a/athena/models/__init__.py b/athena/models/__init__.py index 07f5255..5228135 100644 --- a/athena/models/__init__.py +++ b/athena/models/__init__.py @@ -3,7 +3,6 @@ MemorySource, MemoryTimestamp, MemoryUnit, - MemoryUnitDB, Result, State, Task, @@ -19,5 +18,4 @@ "State", "Action", "Result", - "MemoryUnitDB", ]