diff --git a/athena/app/services/llm_service.py b/athena/app/services/llm_service.py
index 6c28d5a..fe39fbd 100644
--- a/athena/app/services/llm_service.py
+++ b/athena/app/services/llm_service.py
@@ -3,6 +3,7 @@
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_google_vertexai import ChatVertexAI
from athena.app.services.base_service import BaseService
from athena.chat_models.custom_chat_openai import CustomChatOpenAI
@@ -50,6 +51,17 @@ def get_model(
max_tokens_to_sample=max_output_tokens,
max_retries=3,
)
+ elif model_name.startswith("vertex:"):
+ # example: model_name="vertex:gemini-2.5-pro"
+ vertex_model = model_name.split("vertex:", 1)[1]
+ return ChatVertexAI(
+ model_name=vertex_model,
+ project="prometheus-code-agent",
+ location="us-central1",
+ temperature=temperature,
+ max_output_tokens=max_output_tokens,
+ max_retries=3,
+ )
elif "gemini" in model_name:
return ChatGoogleGenerativeAI(
model=model_name,
diff --git a/athena/app/services/memory_extraction_service.py b/athena/app/services/memory_extraction_service.py
new file mode 100644
index 0000000..b6a5071
--- /dev/null
+++ b/athena/app/services/memory_extraction_service.py
@@ -0,0 +1,882 @@
+import asyncio
+import html
+import json
+import re
+import uuid
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional
+
+from datasets import load_dataset
+from tqdm import tqdm
+
+from athena.app.services.base_service import BaseService
+from athena.app.services.llm_service import LLMService
+from athena.app.services.memory_storage_service import MemoryStorageService
+from athena.models import (
+ Action,
+ MemorySource,
+ MemoryTimestamp,
+ MemoryUnit,
+ Message,
+ Result,
+ State,
+ Task,
+)
+from athena.prompts.memory_extraction import (
+ ACTION_EXTRACTION_PROMPT,
+ ACTION_JUDGE_PROMPT,
+ RESULT_EXTRACTION_PROMPT,
+ STATE_DONE_SUMMARY_PROMPT,
+ STATE_TODO_SYNTHESIS_PROMPT,
+ TASK_EXTRACTION_PROMPT,
+)
+
+# from athena.utils.logger_manager import get_logger
+
+
+class ExtractionError(Exception):
+ """Exception raised when memory extraction fails."""
+
+ def __init__(self, source_name: str, run_id: str):
+ self.source_name = source_name
+ self.run_id = run_id
+ super().__init__(f"Failed to extract memory units from {run_id} in {source_name}")
+
+
+class MemoryExtractionService(BaseService):
+ """
+ Service for extracting structured memory units from interaction trajectories.
+
+ This service orchestrates the extraction of memory units from various data
+ sources including conversation logs, agent execution traces, and user
+ interaction histories. It provides a unified interface for converting
+ unstructured interaction data into structured memory units that can be
+ stored and queried by the Athena memory system.
+
+ Key Features:
+ - Multi-source trajectory loading (datasets, files, APIs)
+ - Configurable extraction strategies (LLM-based, rule-based, hybrid)
+ - Batch processing with progress tracking
+ - Error handling and retry mechanisms
+ - Integration with existing memory storage systems
+
+ Architecture:
+ The service follows a pipeline pattern:
+ 1. Data Source -> Load trajectories
+ 2. Extraction Strategy -> Extract components (task, action, result, state)
+ 3. Memory Unit Assembly -> Combine components into MemoryUnit
+ 4. Validation -> Ensure data quality and consistency
+ 5. Storage -> Persist to memory system
+
+    Usage:
+        service = MemoryExtractionService(llm_service)
+        memory_units = service.extract_from_huggingface_trajectory_repository(
+            repo_name="org/dataset", split="train"
+        )
+ """
+
+ def __init__(
+ self,
+ llm_service: LLMService,
+ batch_size: int = 100,
+ max_retries: int = 3,
+ memory_store: Optional[MemoryStorageService] = None,
+ ):
+ """
+ Initialize the Memory Extraction Service.
+
+ Args:
+ llm_service: LLM service for AI-powered extraction
+ batch_size: Number of trajectories to process in each batch
+            max_retries: Maximum retry attempts for failed extractions
+            memory_store: Optional storage backend for persisting extracted units
+ """
+ self.llm_service = llm_service
+ self.batch_size = batch_size
+ self.max_retries = max_retries
+ self._extraction_cache: Dict[str, MemoryUnit] = {}
+ self.memory_store = memory_store
+ # self._logger = get_logger(__name__)
+
+ def start(self):
+ """Start the memory extraction service."""
+ # Initialize any required resources
+ pass
+
+ def close(self):
+ """Close the memory extraction service and cleanup resources."""
+ self._extraction_cache.clear()
+
+ def extract_from_huggingface_trajectory_repository( # TODO: batch extraction
+ self, repo_name: str, split: str
+ ) -> List[MemoryUnit]:
+ """
+ Extract memory units from a HuggingFace trajectory repository.
+
+ Args:
+ repo_name: Name of the HuggingFace trajectory repository
+ split: Split of the HuggingFace trajectory repository
+
+ Returns:
+ List of extracted memory units
+
+ Raises:
+ ExtractionError: If extraction fails for critical trajectories
+ """
+ dataset = load_dataset(repo_name, split=split)
+
+ traj_keys = ["messages", "history", "trajectory", "chat", "conversation", "dialog"]
+ run_id_keys = ["run_id", "traj_id", "id"]
+ inst_id_keys = ["instance_id"]
+ model_keys = ["model", "model_name"]
+ resolved_keys = [
+ "resolved",
+ "target",
+ ] # bool value, whether the model solved the issue in this trajectory.
+
+ def _pick(d: Dict[str, Any], keys) -> Optional[Any]:
+ for k in keys:
+ if k in d:
+ return d[k]
+ return None
+
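+        # Illustrative sketch (hypothetical row shape): for row = {"messages": [...], "id": "r1"},
+        # _pick(row, traj_keys) returns the messages list and _pick(row, run_id_keys) returns "r1".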
+ for idx in range(len(dataset)):
+ row = dataset[idx]
+            trajectory = _pick(row, traj_keys)
+            if not trajectory:
+                continue
+ messages: List[Message] = []
+ for m in trajectory:
+ content = str(
+ m.get("content")
+ or m.get("text")
+ or m.get("message")
+ or m.get("utterance")
+ or ""
+ )
+ role = str(m.get("role", "unknown"))
+ metadata = {
+ k: v
+ for k, v in m.items()
+ if k not in {"content", "role", "text", "message", "utterance"}
+ }
+ messages.append(Message(content=content, role=role, metadata=metadata))
+
+ run_id = _pick(row, run_id_keys) or f"{idx}"
+ instance_id = _pick(row, inst_id_keys)
+ model = _pick(row, model_keys)
+ resolved = _pick(
+ row, resolved_keys
+ ) # TODO: use resolved to filter memory units to successful/failed ones.
+ memory_source = self._extract_memory_source(
+ repo_name,
+ run_id,
+ metadata={
+ **({"instance_id": instance_id} if instance_id is not None else {}),
+ **({"model": model} if model is not None else {}),
+ **({"resolved": resolved} if resolved is not None else {}),
+ },
+ )
+ try:
+ memory_units = self._extract_memory_units_by_action_windows(messages, memory_source)
+ except ExtractionError:
+ # self._logger.error(f"Failed to extract memory units from {run_id} in {repo_name}: {e}") # TODO: add logger
+ continue
+
+ self._extraction_cache.update(
+ {mu.memory_id: mu for mu in memory_units}
+ ) # TODO: use PostgreSqlMemoryStore to store memory units.
+ return list(self._extraction_cache.values())
+
+ def _extract_memory_source(
+ self, source: str, run_id: str, metadata: Optional[Dict[str, Any]] = None
+ ) -> MemorySource:
+ """
+ Extract memory source for a trajectory.
+
+ Args:
+ source: Data source identifier
+ run_id: Unique run identifier
+ metadata: Optional additional metadata
+
+ Returns:
+ MemorySource object
+ """
+ return MemorySource(
+ source_name=source,
+ run_id=run_id,
+ metadata=metadata or {},
+ )
+
+ def _extract_memory_units_by_action_windows(
+ self,
+ messages: List[Message],
+ memory_source: MemorySource,
+ ) -> List[MemoryUnit]:
+ ordered_memory_units: List[MemoryUnit] = []
+ window_msgs: List[Message] = []
+ window_first_action: Optional[Message] = None
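+        # Windowing sketch (assumed trajectory shape): for roles
+        #   [user, assistant(action A1), tool, assistant(action A2), tool]
+        # the loop below yields the windows [A1, tool] and [A2, tool],
+        # and each window is condensed into one MemoryUnit.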
+
+ task = self._extract_task_from_messages(messages)
+
+ for msg in tqdm(
+ messages,
+ desc=f"Extracting memory units for {memory_source.run_id} in {memory_source.source_name}",
+ ):
+ if self._is_action_message(msg):
+ if window_msgs and window_first_action is not None:
+ mu = self._create_memory_unit(
+ memory_source,
+ task,
+ window_msgs,
+ ordered_memory_units,
+ )
+ if mu:
+ ordered_memory_units.append(mu)
+
+ window_msgs = [msg]
+ window_first_action = msg
+            elif window_msgs:
+                window_msgs.append(msg)
+
+ if window_msgs and window_first_action is not None:
+ mu = self._create_memory_unit(
+ memory_source,
+ task,
+ window_msgs,
+ ordered_memory_units,
+ )
+ if mu:
+ ordered_memory_units.append(mu)
+
+        if self.memory_store is not None and ordered_memory_units:  # TODO: refine memory store
+            if asyncio.iscoroutinefunction(self.memory_store.upsert):
+                try:
+                    # Fire-and-forget when an event loop is already running; callers may await instead
+                    asyncio.get_running_loop().create_task(
+                        self.memory_store.upsert(ordered_memory_units)
+                    )
+                except RuntimeError:
+                    # No running loop (synchronous caller): run the coroutine to completion
+                    asyncio.run(self.memory_store.upsert(ordered_memory_units))
+            else:
+                self.memory_store.upsert(ordered_memory_units)  # type: ignore[arg-type]
+
+ return ordered_memory_units
+
+ def _is_action_message(self, message: Message) -> bool:
+ """
+ Determine if a message represents the start of an action.
+
+ Args:
+ message: Message to analyze
+
+ Returns:
+ True if message starts an action, False otherwise
+ """
+ if getattr(message, "role", "") != "assistant":
+ return False
+ system_prompt = ACTION_JUDGE_PROMPT
+        user_prompt = (
+            "Decide if the following assistant message represents a concrete ACTION.\n"
+            "Use ONLY the content between <MESSAGE> and </MESSAGE> as input.\n"
+            "Output exactly one token: True or False.\n\n"
+            "<MESSAGE>\n" + message.model_dump_json() + "\n</MESSAGE>"
+        )
+ raw = self.llm_service.model.invoke(
+ [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+ ]
+ ).content
+ return self._normalize_llm_bool_output(raw)
+
+ def _create_memory_unit(
+ self,
+ source: MemorySource,
+ task: Task,
+ window_msgs: List[Message],
+ prior_units: List[MemoryUnit],
+ ) -> MemoryUnit:
+ """
+ Extract a single memory unit from a window of messages.
+ - Synthesize state.done from prior actions
+ - Synthesize state.todo by extracting intents from window_msgs
+ - Extract action from first_action_msg
+ - Extract result from the last message in window_msgs
+ """
+ action = self._extract_action_from_message(window_msgs[0])
+ state_done = self._synthesize_state_done_from_context(prior_units, task)
+ state_todo = self._synthesize_state_todo_from_window(window_msgs, task, state_done, action)
+ result = self._extract_result_from_window(window_msgs, action)
+ return MemoryUnit(
+ memory_id=str(uuid.uuid4()),
+ timestamp=MemoryTimestamp(
+ created_at=datetime.now(timezone.utc),
+ updated_at=None,
+ invalid_at=None,
+ ),
+ source=source,
+ task=task,
+ state=State(
+ done=state_done,
+ todo=state_todo,
+ ),
+ action=action,
+ result=result,
+ )
+
+ def _extract_task_from_messages(self, messages: List[Message]) -> Task:
+ """
+ Extract task information from trajectory messages.
+
+ This method identifies the first user message and extracts:
+ - Issue title and description
+ - Issue type classification
+ - Repository information
+ - Related comments or context
+
+ Args:
+ messages: List of messages in the trajectory
+
+ Returns:
+ Task object with extracted information
+ """
+ # 1) find first user message
+ first_user_msg = next((m for m in messages if getattr(m, "role", "") == "user"), None)
+ if first_user_msg is None:
+ return Task(
+ issue_title="",
+ issue_body="",
+ issue_comments="",
+ issue_type="",
+ repository="",
+ )
+
+ # 2) build a schema-constrained prompt for ONLY Task fields
+ system_prompt = TASK_EXTRACTION_PROMPT
+        user_prompt = (
+            "Extract the `task` object from the SINGLE user message below.\n"
+            "Use only the text between <MESSAGE> and </MESSAGE>.\n\n"
+            "<MESSAGE>\n" + first_user_msg.model_dump_json() + "\n</MESSAGE>"
+        )
+
+ # 3) call LLM service
+ raw = self.llm_service.model.invoke(
+ [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+ ]
+ ).content
+
+ # 4) robust JSON parse, with a single repair attempt if needed
+ payload = self._normalize_llm_json_output(raw)
+
+ # 5) normalize fields to strings (no extra heuristics)
+ task_obj = payload.get("task", {})
+
+ def s(x):
+ return x if isinstance(x, str) else ("" if x is None else str(x))
+
+        repo = s(task_obj.get("repository"))
+
+ return Task(
+ issue_title=s(task_obj.get("issue_title")),
+ issue_body=s(task_obj.get("issue_body")),
+ issue_comments=s(task_obj.get("issue_comments")),
+ issue_type=s(task_obj.get("issue_type")),
+ repository=repo,
+ )
+
+ def _synthesize_state_from_context(
+ self,
+ prior_units: List[MemoryUnit],
+ task: Task,
+ window_msgs: List[Message],
+ current_action: Optional[Action] = None,
+ ) -> State:
+ """
+ Synthesize state information from context.
+ """
+ state_done = self._synthesize_state_done_from_context(prior_units, task)
+ state_todo = self._synthesize_state_todo_from_window(
+ window_msgs, task, state_done, current_action
+ )
+ return State(done=state_done, todo=state_todo) # TODO: open_file, working_dir
+
+ def _synthesize_state_done_from_context(self, prior_units: List[MemoryUnit], task: Task) -> str:
+ """
+ Summarize previous context into a concise `state.done` string.
+ Summary of what has ALREADY BEEN COMPLETED (no plans).
+ """
+ if not prior_units:
+ return "(none)"
+
+ # Keep the window modest to control prompt size
+ window = prior_units[-1:] # state typically depends only on the last state in an agent run
+
+ # Pack minimal, evidence-bound context for the LLM (no heavy heuristics)
+ history = []
+ for u in window:
+ history.append(
+ {
+ "state": {
+ "done": u.state.done,
+ "todo": u.state.todo,
+ },
+ "action": {
+ "name": u.action.name,
+ "description": u.action.description,
+ "target": u.action.target,
+ "tool": u.action.tool,
+ },
+ "result": {
+ "type": u.result.type,
+ "description": u.result.description,
+ "exit_code": u.result.exit_code,
+ },
+ }
+ )
+
+ system_prompt = STATE_DONE_SUMMARY_PROMPT
+        user_prompt = (
+            "Summarize ONLY what has ALREADY BEEN COMPLETED into `state.done`.\n"
+            "TASK (for deriving only):\n"
+            "<TASK>\n" + task.model_dump_json() + "\n</TASK>\n\n"
+            "PRIOR UNITS (used ONLY for completed work evidence):\n"
+            "<PRIOR_UNITS>\n" + json.dumps(history, ensure_ascii=False) + "\n</PRIOR_UNITS>\n\n"
+            "Return the final summary paragraph ONLY (no explanations)."
+        )
+
+ try:
+ summary = self.llm_service.model.invoke(
+ [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+ ]
+ ).content
+ # Be lenient if the model wraps the text in quotes/newlines
+ return summary.strip().strip('"').strip()
+ except Exception:
+ # Minimal, defensive fallback (kept short); avoids complex heuristics
+ return window[-1].state.done
+
+ def _synthesize_state_todo_from_window(
+ self,
+ window_msgs: List[Message],
+ task: Task,
+ prior_done_text: str = "",
+ current_action: Optional[Action] = None,
+ ) -> str:
+ """
+ Synthesize state.todo from the window of messages.
+ """
+        system_prompt = STATE_TODO_SYNTHESIS_PROMPT
+        action_json = current_action.model_dump_json() if current_action else "{}"
+        user_prompt = (
+            "TASK (overall purpose):\n"
+            "<TASK>\n" + task.model_dump_json() + "\n</TASK>\n\n"
+            "PRIOR_DONE (what has already been completed; avoid repeating it):\n"
+            "<PRIOR_DONE>\n" + (prior_done_text or "") + "\n</PRIOR_DONE>\n\n"
+            "WINDOW_MSGS (current window to derive the immediate intent of the next step; do NOT output tools/commands/paths):\n"
+            "<WINDOW_MSGS>\n"
+            + json.dumps([msg.model_dump() for msg in window_msgs], ensure_ascii=False)
+            + "\n</WINDOW_MSGS>\n\n"
+            "CURRENT_ACTION (use ONLY to ensure the intent refers to the same component/area; do NOT include or paraphrase any of its details):\n"
+            "<CURRENT_ACTION>\n" + action_json + "\n</CURRENT_ACTION>\n\n"
+            "Return ONLY the final paragraph."
+        )
+
+ try:
+ text = self.llm_service.model.invoke(
+ [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+ ]
+ ).content
+ return text.strip().strip('"').replace("\n", " ").strip()
+        except Exception:
+            # Minimal deterministic fallback using available CURRENT_ACTION fields (no speculation)
+            task_short = (
+                task.issue_title or task.repository or task.issue_type or "the task"
+            ).strip()
+            if current_action is None:
+                return f"To progress on {task_short}, take the next step toward the task goal."
+            parts = [f"To progress on {task_short},"]
+            # Build an imperative clause from action fields
+            verb = (current_action.name or "execute action").replace("_", " ")
+            tgt = f" {current_action.target}" if current_action.target else ""
+            via = f" via {current_action.tool}" if current_action.tool else ""
+            desc = f" to {current_action.description}" if current_action.description else ""
+            parts.append(f" {verb}{tgt}{via}{desc}.")
+            return " ".join(parts).replace("  ", " ").strip()
+
+ def _extract_action_from_message(self, message: Message) -> Action:
+ """
+ Extract action information from an assistant message that starts an action call.
+
+ This method analyzes assistant messages to identify:
+ - Action type (read_file, edit_file, run_test, etc.)
+ - Target file or resource
+ - Tool being used
+ - Action description
+
+ Args:
+ message: Assistant message containing action information
+
+ Returns:
+ Action object with extracted information
+ """
+ system_prompt = ACTION_EXTRACTION_PROMPT
+        user_prompt = (
+            "Extract the `action` from the SINGLE assistant message below.\n"
+            "Use ONLY the content between <MESSAGE> and </MESSAGE>.\n\n"
+            "<MESSAGE>\n" + message.model_dump_json() + "\n</MESSAGE>"
+        )
+
+ raw = self.llm_service.model.invoke(
+ [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+ ]
+ ).content
+
+ # Robust JSON parse with one repair attempt
+ payload = self._normalize_llm_json_output(raw)
+
+ action_obj = payload.get("action", {}) or {}
+
+ def s(x):
+ return x if isinstance(x, str) else ("" if x is None else str(x))
+
+ return Action(
+ name=s(action_obj.get("name")),
+ description=s(action_obj.get("description")),
+ target=s(action_obj.get("target")),
+ tool=s(action_obj.get("tool")),
+ )
+
+ def _extract_result_from_window(
+ self, window_msgs: List[Message], current_action: Action
+ ) -> Result:
+ """
+ Extract result information from message sequence.
+
+ This method analyzes the outcome of an action by examining:
+ - Success/failure indicators
+ - Error messages or exceptions
+ - Exit codes
+ - Output descriptions
+
+        Args:
+            window_msgs: Messages containing result information
+            current_action: The action that was executed
+
+ Returns:
+ Result object with outcome information
+ """
+ last_msg = window_msgs[-1]
+
+ system_prompt = RESULT_EXTRACTION_PROMPT
+        user_prompt = (
+            "CURRENT_ACTION (reference for alignment; do NOT invent beyond LAST_MESSAGE):\n"
+            "<CURRENT_ACTION>\n" + current_action.model_dump_json() + "\n</CURRENT_ACTION>\n\n"
+            "LAST_MESSAGE (extract the definitive outcome from here ONLY):\n"
+            "<LAST_MESSAGE>\n" + last_msg.model_dump_json() + "\n</LAST_MESSAGE>\n\n"
+            "Return ONLY the JSON object."
+        )
+
+ try:
+ raw = self.llm_service.model.invoke(
+ [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+ ]
+ ).content
+ payload = self._normalize_llm_json_output(raw)
+
+ r = (payload or {}).get("result", {}) or {}
+
+ # Utility: convert to string
+ def s(x):
+ return x if isinstance(x, str) else ("" if x is None else str(x))
+
+ # Result type correction
+ r_type = s(r.get("type")).lower()
+ if r_type not in {"success", "failure", "partial", "unknown"}:
+ r_type = "unknown"
+
+ # Text: one sentence, remove wrapping quotes
+ description = s(r.get("description")).strip().strip('"')
+ if not description:
+ description = "(none)"
+
+            # exit_code: accept an integral number or an integer string
+            exit_raw = r.get("exit_code")
+            exit_code = None
+            if isinstance(exit_raw, bool):
+                pass  # bool is a subclass of int; True/False are not exit codes
+            elif isinstance(exit_raw, int):
+                exit_code = exit_raw
+            elif isinstance(exit_raw, float) and exit_raw.is_integer():
+                exit_code = int(exit_raw)
+            elif isinstance(exit_raw, str) and re.fullmatch(r"-?\d+", exit_raw.strip()):
+                exit_code = int(exit_raw.strip())
+
+ return Result(type=r_type, description=description, exit_code=exit_code)
+
+ except Exception:
+ # Minimal fallback: only based on the last message (no LLM call)
+ src = (getattr(last_msg, "content", "") or "").strip()
+ lt = src.lower()
+
+ if re.search(r"(error|exception|traceback|failed)", lt):
+ rtype = "failure"
+ elif re.search(r"(exit code\s*0|succeeded|completed|all tests passed|passed\b)", lt):
+ rtype = "success"
+ elif re.search(r"(some tests failed|failures?:\s*[1-9]|errors?:\s*[1-9])", lt):
+ rtype = "partial"
+ else:
+ rtype = "unknown"
+
+ # Extract the last non-empty line as one sentence result, and truncate to ~60 words
+ lines = [ln.strip() for ln in src.splitlines() if ln.strip()]
+ summary = lines[-1] if lines else "(none)"
+ words = summary.split()
+ if len(words) > 60:
+ summary = " ".join(words[:60])
+
+ m = re.search(r"exit code\s+(-?\d+)", src, flags=re.I)
+ code = int(m.group(1)) if m else None
+
+ return Result(type=rtype, description=summary, exit_code=code)
+
+ def _normalize_llm_bool_output(self, raw: str) -> bool:
+ """
+ Normalize raw LLM output into a boolean.
+ Accepts variations like True/False, yes/no, 1/0,
+ possibly wrapped in quotes, code fences, or punctuation.
+ """
+ text = raw.strip().lower()
+
+ # remove common wrappers
+ text = text.strip(" \t\n\r\"'`")
+ if text.startswith("```"):
+ text = text.strip("`").strip()
+
+ # direct boolean keywords
+ if text in {"true", "yes", "y", "1"}:
+ return True
+ if text in {"false", "no", "n", "0"}:
+ return False
+
+ # fallback: fuzzy match
+ if text.startswith("true"):
+ return True
+ if text.startswith("false"):
+ return False
+
+ return False
+
+ def _normalize_llm_json_output(self, raw: str) -> Dict[str, Any]:
+ """
+ Normalize raw LLM output into a JSON object (dict-like).
+ Single-argument API. All robustness heuristics are internal.
+
+ Strategy (built-in, no extra args):
+ - Clean BOM/zero-width chars/HTML entities
+ - Strip markdown code fences (``` / ```json)
+ - Try direct json.loads
+ - If it is a quoted JSON string, unquote then parse
+ - Extract the first balanced JSON object/array (quote-aware)
+ - Apply limited trailing comma fixes
+ - Minify whitespace and retry
+ - Optionally do ONE LLM repair round if self.llm_service is available
+ - Always return a dict; non-dict top-level will be wrapped as {"_root": value}
+ """
+ EXPECT_OBJECT = True # always return Dict[str, Any]
+ EXPECT_KEY: Optional[str] = None # no required top-level key
+ MAX_SCAN_CHARS = 200_000 # cap scanning for safety
+ ALLOW_REPAIR_WITH_LLM = True # try one LLM repair if available
+
+ if not isinstance(raw, str):
+ raise ValueError("LLM output is not a string")
+ text = raw.strip()
+ if not text:
+ raise ValueError("Empty LLM output")
+
+ # Hygiene
+ text = self._strip_bom_zwsp(text)
+ text = html.unescape(text)
+ text = self._strip_code_fences(text).strip()
+
+ # 1) direct parse
+ obj = self._try_parse_direct(text, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY)
+ if obj is not None:
+ return obj
+
+ # 2) quoted JSON -> unquote then parse
+ obj = self._try_unquote_then_parse(text, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY)
+ if obj is not None:
+ return obj
+
+ # 3) extract first balanced JSON block (object/array), quote-aware
+ candidate = self._extract_first_json_balanced(text[:MAX_SCAN_CHARS])
+ if candidate:
+ obj = self._try_parse_direct(
+ candidate, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY
+ )
+ if obj is not None:
+ return obj
+ fixed = self._fix_trailing_commas(candidate)
+ obj = self._try_parse_direct(fixed, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY)
+ if obj is not None:
+ return obj
+
+ # 4) minify whitespace and retry
+ compact = self._minify_ws(candidate or text)
+ obj = self._try_parse_direct(compact, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY)
+ if obj is not None:
+ return obj
+
+ # 5) optional single LLM repair (if service available)
+ if ALLOW_REPAIR_WITH_LLM and getattr(self, "llm_service", None) is not None:
+ repaired = self._repair_json_with_llm(text, expect_key=EXPECT_KEY)
+ obj = self._try_parse_direct(
+ repaired, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY
+ )
+ if obj is not None:
+ return obj
+
+ snippet = (text[:800] + "...") if len(text) > 800 else text
+ raise ValueError(f"Failed to parse JSON from LLM output. Snippet: {snippet}")
+
+ def _try_parse_direct(
+ self, s: str, *, expect_object: bool, expect_key: Optional[str]
+ ) -> Optional[Dict[str, Any]]:
+ """Try json.loads; coerce non-dict to {'_root': value}; enforce expect_key if given."""
+ try:
+ val = json.loads(s)
+ except Exception:
+ return None
+        if isinstance(val, str):
+            return None  # bare JSON strings are handled by the unquote step
+        if not isinstance(val, dict):
+            val = {"_root": val}
+ if expect_key and expect_key not in val:
+ return None
+ return val
+
+ def _try_unquote_then_parse(
+ self, s: str, *, expect_object: bool, expect_key: Optional[str]
+ ) -> Optional[Dict[str, Any]]:
+ """If s is a JSON string containing JSON (e.g., '"{...}"'), unquote then parse."""
+ try:
+ inner = json.loads(s)
+ except Exception:
+ return None
+ if isinstance(inner, str):
+ return self._try_parse_direct(inner, expect_object=expect_object, expect_key=expect_key)
+        if not isinstance(inner, dict):
+            inner = {"_root": inner}
+ if expect_key and expect_key not in inner:
+ return None
+ return inner
+
+ def _strip_code_fences(self, s: str) -> str:
+ """Remove markdown code fences and leading 'json' language tag."""
+ t = s.strip()
+ if t.startswith("```"):
+ t = re.sub(r"^\s*```(?:json)?\s*", "", t, flags=re.I)
+ t = re.sub(r"\s*```\s*$", "", t)
+ t = re.sub(r"```(?:json)?\s*([\s\S]*?)\s*```", r"\1", t, flags=re.I)
+ return t
+
+ def _strip_bom_zwsp(self, s: str) -> str:
+ """Remove BOM and zero-width characters that can break parsers."""
+ s = s.lstrip("\ufeff")
+ return re.sub(r"[\u200B-\u200D\u2060\uFEFF]", "", s)
+
+ def _minify_ws(self, s: str) -> str:
+ """Collapse excessive newlines and surrounding whitespace (keeps spaces inside strings)."""
+ return re.sub(r"\s*\n\s*", " ", s).strip()
+
+ def _fix_trailing_commas(self, s: str) -> str:
+ """Limited safe fix: ',}' -> '}' and ',]' -> ']'."""
+ t = re.sub(r",\s*}", "}", s)
+ t = re.sub(r",\s*]", "]", t)
+ return t
+
+ def _extract_first_json_balanced(self, text: str) -> str:
+ """Find the first balanced {...} or [...] block (quote/escape-aware)."""
+ for opener, closer in (("{", "}"), ("[", "]")):
+ block = self._scan_balanced(text, opener, closer)
+ if block:
+ return block
+ return ""
+
+ def _scan_balanced(self, t: str, opener: str, closer: str) -> str:
+ """Quote/escape-aware balanced scanner to locate a JSON block."""
+ start = t.find(opener)
+ if start == -1:
+ return ""
+ depth = 0
+ in_str = False
+ esc = False
+ quote_char = ""
+ for i in range(start, len(t)):
+ ch = t[i]
+ if in_str:
+ if esc:
+ esc = False
+ elif ch == "\\":
+ esc = True
+ elif ch == quote_char:
+ in_str = False
+ continue
+ if ch in ('"', "'"):
+ in_str = True
+ quote_char = ch
+ esc = False
+ continue
+ if ch == opener:
+ depth += 1
+ elif ch == closer:
+ depth -= 1
+ if depth == 0:
+ return t[start : i + 1]
+ return ""
+
+ def _repair_json_with_llm(self, raw: str, expect_key: Optional[str]) -> str:
+ """One-shot LLM repair to return MINIFIED valid JSON only (no prose/fences)."""
+ constraint = f" It MUST contain the top-level key '{expect_key}'." if expect_key else ""
+ try:
+ repaired = self.llm_service.model.invoke(
+ [
+ {
+ "role": "system",
+ "content": (
+ "You are a JSON normalizer. "
+ "Return ONLY a valid MINIFIED JSON object. "
+ "No explanations, no code fences, no comments." + constraint
+ ),
+ },
+ {"role": "user", "content": raw},
+ ]
+ )
+ if isinstance(repaired, str):
+ return repaired.strip()
+ content = getattr(repaired, "content", None)
+ if isinstance(content, str):
+ return content.strip()
+ return str(repaired).strip()
+ except Exception:
+ return raw
+
+
+if __name__ == "__main__":
+ service = MemoryExtractionService(
+ llm_service=LLMService(
+ model_name="vertex:gemini-2.5-flash",
+ model_temperature=0.0,
+ model_max_input_tokens=8192,
+ model_max_output_tokens=8192,
+ )
+ )
+ service.extract_from_huggingface_trajectory_repository(
+ repo_name="SWE-Gym/OpenHands-SFT-Trajectories",
+ split="train.success.oss",
+ )
diff --git a/athena/configuration/config.py b/athena/configuration/config.py
index 62366ed..7c169ac 100644
--- a/athena/configuration/config.py
+++ b/athena/configuration/config.py
@@ -40,5 +40,11 @@ class Settings(BaseSettings):
MEMORY_MAX_STORED_UNITS: int = 1000 # Maximum number of memory units to store
MEMORY_SEARCH_LIMIT: int = 10 # Default search result limit
+ # Embeddings (OpenAI-format, works with Codestral embed via a compatible gateway)
+ EMBEDDINGS_MODEL: Optional[str] = None
+ EMBEDDINGS_API_KEY: Optional[str] = None
+ EMBEDDINGS_BASE_URL: Optional[str] = None
+ EMBEDDINGS_DIM: Optional[int] = None
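+    # Illustrative .env values (hypothetical endpoint/model names, not defaults):
+    #   EMBEDDINGS_MODEL=codestral-embed
+    #   EMBEDDINGS_BASE_URL=https://llm-gateway.example.com/v1
+    #   EMBEDDINGS_API_KEY=sk-...
+    #   EMBEDDINGS_DIM=1536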
+
settings = Settings()
diff --git a/athena/entity/__init__.py b/athena/entity/__init__.py
new file mode 100644
index 0000000..f5b71d1
--- /dev/null
+++ b/athena/entity/__init__.py
@@ -0,0 +1,3 @@
+from .memory import MemoryUnitDB
+
+__all__ = ["MemoryUnitDB"]
diff --git a/athena/entity/memory.py b/athena/entity/memory.py
new file mode 100644
index 0000000..a816010
--- /dev/null
+++ b/athena/entity/memory.py
@@ -0,0 +1,147 @@
+import json
+from datetime import datetime
+from typing import Optional
+
+from pgvector.sqlalchemy import Vector
+from sqlalchemy import Column, DateTime, Text
+from sqlmodel import Field, SQLModel, UniqueConstraint
+
+from athena.models import Action, MemorySource, MemoryTimestamp, MemoryUnit, Result, State, Task
+
+
+# Database models for persistent storage
+class MemoryUnitDB(SQLModel, table=True):
+ """Database model for persistent storage of memory units."""
+
+ __tablename__ = "memory_units"
+ __table_args__ = (UniqueConstraint("memory_id", name="uq_memory_id"),)
+
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+ # Source information
+ memory_id: str
+ memory_source_name: str
+ memory_run_id: str
+ memory_created_at: datetime = Field(sa_column=Column(DateTime(timezone=True)))
+ memory_updated_at: Optional[datetime] = Field(
+ default=None, sa_column=Column(DateTime(timezone=True))
+ )
+ memory_invalid_at: Optional[datetime] = Field(
+ default=None, sa_column=Column(DateTime(timezone=True))
+ )
+ memory_metadata: str = Field(
+ default="{}", sa_column=Column(Text)
+ ) # JSON string for Dict[str, Any]
+
+ # Task information
+ task_issue_title: str = Field(sa_column=Column(Text))
+ task_issue_body: str = Field(sa_column=Column(Text))
+ task_issue_comments: str = Field(sa_column=Column(Text))
+ task_issue_type: str
+ task_repository: str
+
+ # State information
+ state_done: str = Field(sa_column=Column(Text))
+ state_todo: str = Field(sa_column=Column(Text))
+ state_open_file: Optional[str] = None
+ state_working_dir: Optional[str] = None
+ state_extra_environment: str = Field(
+ default="{}", sa_column=Column(Text)
+ ) # JSON string for Dict[str, str]
+
+ # Action information
+ action_name: str
+ action_description: str = Field(sa_column=Column(Text))
+ action_target: str
+ action_tool: str
+
+ # Result information
+ result_type: str
+ result_description: str = Field(sa_column=Column(Text))
+ result_exit_code: Optional[int] = None
+
+ # Embeddings for semantic retrieval (optional, pgvector recommended). Stored as float arrays
+ task_embedding: Optional[list[float]] = Field(default=None, sa_column=Column(Vector()))
+ state_embedding: Optional[list[float]] = Field(default=None, sa_column=Column(Vector()))
+ task_state_embedding: Optional[list[float]] = Field(default=None, sa_column=Column(Vector()))
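+    # NOTE (assumed deployment): the pgvector extension must be enabled in the target
+    # database before these columns can be created, e.g.:
+    #   CREATE EXTENSION IF NOT EXISTS vector;
+    # Vector() is dimensionless here; a fixed size (e.g. Vector(settings.EMBEDDINGS_DIM))
+    # is required if vector indexes such as ivfflat/hnsw are added later.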
+
+ @classmethod
+ def from_memory_unit(cls, memory_unit: MemoryUnit) -> "MemoryUnitDB":
+ """Create a database model from a MemoryUnit."""
+ return cls(
+ memory_id=memory_unit.memory_id,
+ memory_source_name=memory_unit.source.source_name,
+ memory_run_id=memory_unit.source.run_id,
+ memory_created_at=memory_unit.timestamp.created_at,
+ memory_updated_at=memory_unit.timestamp.updated_at,
+ memory_invalid_at=memory_unit.timestamp.invalid_at,
+ memory_metadata=json.dumps(memory_unit.source.metadata)
+ if memory_unit.source.metadata
+ else "{}",
+ task_issue_title=memory_unit.task.issue_title,
+ task_issue_body=memory_unit.task.issue_body,
+ task_issue_comments=memory_unit.task.issue_comments,
+ task_issue_type=memory_unit.task.issue_type,
+ task_repository=memory_unit.task.repository,
+ state_done=memory_unit.state.done,
+ state_todo=memory_unit.state.todo,
+ state_open_file=memory_unit.state.open_file,
+ state_working_dir=memory_unit.state.working_dir,
+ state_extra_environment=json.dumps(memory_unit.state.extra_environment)
+ if memory_unit.state.extra_environment
+ else "{}",
+ action_name=memory_unit.action.name,
+ action_description=memory_unit.action.description,
+ action_target=memory_unit.action.target,
+ action_tool=memory_unit.action.tool,
+ result_type=memory_unit.result.type,
+ result_description=memory_unit.result.description,
+ result_exit_code=memory_unit.result.exit_code,
+ )
+
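+    # Round-trip sketch: MemoryUnitDB.from_memory_unit(mu).to_memory_unit() is expected to
+    # reproduce `mu` field-for-field (embeddings excluded; they are filled in by storage code).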
+ def to_memory_unit(self) -> MemoryUnit:
+ """Convert database model back to MemoryUnit."""
+ return MemoryUnit(
+ memory_id=self.memory_id,
+ timestamp=MemoryTimestamp(
+ created_at=self.memory_created_at,
+ updated_at=self.memory_updated_at,
+ invalid_at=self.memory_invalid_at,
+ ),
+ source=MemorySource(
+ source_name=self.memory_source_name,
+ run_id=self.memory_run_id,
+ metadata=json.loads(self.memory_metadata)
+ if self.memory_metadata not in (None, "", "null")
+ else {},
+ ),
+ task=Task(
+ issue_title=self.task_issue_title,
+ issue_body=self.task_issue_body,
+ issue_comments=self.task_issue_comments,
+ issue_type=self.task_issue_type,
+ repository=self.task_repository,
+ ),
+ state=State(
+ done=self.state_done,
+ todo=self.state_todo,
+ open_file=self.state_open_file,
+ working_dir=self.state_working_dir,
+ extra_environment=json.loads(self.state_extra_environment)
+ if self.state_extra_environment not in (None, "", "null")
+ else {},
+ ),
+ action=Action(
+ name=self.action_name,
+ description=self.action_description,
+ target=self.action_target,
+ tool=self.action_tool,
+ ),
+ result=Result(
+ type=self.result_type
+ if self.result_type in ["success", "failure", "partial", "unknown"]
+ else "unknown",
+ description=self.result_description,
+ exit_code=self.result_exit_code,
+ ),
+ )
diff --git a/athena/models/__init__.py b/athena/models/__init__.py
index 7d32594..5228135 100644
--- a/athena/models/__init__.py
+++ b/athena/models/__init__.py
@@ -1,9 +1,8 @@
from .memory import (
Action,
- MemoryContext,
+ MemorySource,
MemoryTimestamp,
MemoryUnit,
- MemoryUnitDB,
Result,
State,
Task,
@@ -13,11 +12,10 @@
__all__ = [
"Message",
"MemoryUnit",
- "MemoryContext",
+ "MemorySource",
"MemoryTimestamp",
"Task",
"State",
"Action",
"Result",
- "MemoryUnitDB",
]
diff --git a/athena/models/memory.py b/athena/models/memory.py
index 11fe613..4776665 100644
--- a/athena/models/memory.py
+++ b/athena/models/memory.py
@@ -1,17 +1,17 @@
-import json
import uuid
-from datetime import datetime
+from datetime import datetime, timezone
from typing import Any, Dict, Literal, Optional
from pydantic import BaseModel, Field
-from sqlalchemy import Column, DateTime, Text
-from sqlmodel import SQLModel, UniqueConstraint
class MemoryTimestamp(BaseModel):
"""Lifecycle timestamps for a memory unit."""
- created_at: datetime = Field(None, description="When the memory unit was first created")
+ created_at: datetime = Field(
+ default_factory=lambda: datetime.now(timezone.utc),
+ description="When the memory unit was first created",
+ )
updated_at: Optional[datetime] = Field(
None, description="When the memory was last updated/refreshed"
)
@@ -20,18 +20,13 @@ class MemoryTimestamp(BaseModel):
)
-class MemoryContext(BaseModel):
- """Context metadata for a memory unit."""
+class MemorySource(BaseModel):
+ """Source information for a memory unit."""
- memory_id: str = Field(
- default_factory=lambda: str(uuid.uuid4()),
- description="Unique identifier for this memory unit",
- )
- source: str = Field(
+ source_name: str = Field(
..., description="Memory source, e.g., agent name, model name, dataset name, or file path"
)
run_id: str = Field(..., description="Unique id for agent run or trajectory dataset")
- timestamp: MemoryTimestamp = Field(..., description="Timestamp of the memory")
metadata: Dict[str, Any] = Field(
default_factory=dict, description="Optional extension field for custom metadata"
)
@@ -103,148 +98,22 @@ class MemoryUnit(BaseModel):
Core memory unit capturing one action of agent execution.
This includes:
- - The memory context (memory_id, source, run_id, timestamps, metadata)
+ - The memory id (memory_id)
+ - The memory timestamp (timestamp)
+ - The memory source (source_name, run_id, metadata)
- The task being worked on (issue and repository details)
- The current state (what's done, what's todo)
- The action taken by the agent
- The result of the action
"""
- context: MemoryContext
+ memory_id: str = Field(
+ default_factory=lambda: str(uuid.uuid4()),
+ description="Unique identifier for this memory unit",
+ )
+ timestamp: MemoryTimestamp = Field(..., description="Timestamp of the memory")
+ source: MemorySource
task: Task
state: State
action: Action
result: Result
-
-
-# Database models for persistent storage
-class MemoryUnitDB(SQLModel, table=True):
- """Database model for persistent storage of memory units."""
-
- __tablename__ = "memory_units"
- __table_args__ = UniqueConstraint("memory_id", name="uq_memory_id")
-
- id: Optional[int] = Field(default=None, primary_key=True)
-
- # Context information
- memory_id: str
- memory_source: str
- memory_run_id: str
- memory_created_at: datetime = Field(sa_column=Column(DateTime(timezone=True)))
- memory_updated_at: Optional[datetime] = Field(
- default=None, sa_column=Column(DateTime(timezone=True))
- )
- memory_invalid_at: Optional[datetime] = Field(
- default=None, sa_column=Column(DateTime(timezone=True))
- )
- memory_metadata: str = Field(
- default="{}", sa_column=Column(Text)
- ) # JSON string for Dict[str, Any]
-
- # Task information
- task_issue_title: str = Field(sa_column=Column(Text))
- task_issue_body: str = Field(sa_column=Column(Text))
- task_issue_comments: str = Field(sa_column=Column(Text))
- task_issue_type: str
- task_repository: str
-
- # State information
- state_done: str = Field(sa_column=Column(Text))
- state_todo: str = Field(sa_column=Column(Text))
- state_open_file: Optional[str] = None
- state_working_dir: Optional[str] = None
- state_extra_environment: str = Field(
- default="{}", sa_column=Column(Text)
- ) # JSON string for Dict[str, str]
-
- # Action information
- action_name: str
- action_description: str = Field(sa_column=Column(Text))
- action_target: str
- action_tool: str
-
- # Result information
- result_type: str
- result_description: str = Field(sa_column=Column(Text))
- result_exit_code: Optional[int] = None
-
- @classmethod
- def from_memory_unit(cls, memory_unit: MemoryUnit) -> "MemoryUnitDB":
- """Create a database model from a MemoryUnit."""
- return cls(
- memory_id=memory_unit.context.memory_id,
- memory_source=memory_unit.context.source,
- memory_run_id=memory_unit.context.run_id,
- memory_created_at=memory_unit.context.timestamp.created_at,
- memory_updated_at=memory_unit.context.timestamp.updated_at,
- memory_invalid_at=memory_unit.context.timestamp.invalid_at,
- memory_metadata=json.dumps(memory_unit.context.metadata)
- if memory_unit.context.metadata
- else "{}",
- task_issue_title=memory_unit.task.issue_title,
- task_issue_body=memory_unit.task.issue_body,
- task_issue_comments=memory_unit.task.issue_comments,
- task_issue_type=memory_unit.task.issue_type,
- task_repository=memory_unit.task.repository,
- state_done=memory_unit.state.done,
- state_todo=memory_unit.state.todo,
- state_open_file=memory_unit.state.open_file,
- state_working_dir=memory_unit.state.working_dir,
- state_extra_environment=json.dumps(memory_unit.state.extra_environment)
- if memory_unit.state.extra_environment
- else "{}",
- action_name=memory_unit.action.name,
- action_description=memory_unit.action.description,
- action_target=memory_unit.action.target,
- action_tool=memory_unit.action.tool,
- result_type=memory_unit.result.type,
- result_description=memory_unit.result.description,
- result_exit_code=memory_unit.result.exit_code,
- )
-
- def to_memory_unit(self) -> MemoryUnit:
- """Convert database model back to MemoryUnit."""
- return MemoryUnit(
- context=MemoryContext(
- memory_id=self.memory_id,
- source=self.memory_source,
- run_id=self.memory_run_id,
- timestamp=MemoryTimestamp(
- created_at=self.memory_created_at,
- updated_at=self.memory_updated_at,
- invalid_at=self.memory_invalid_at,
- ),
- metadata=json.loads(self.memory_metadata)
- if self.memory_metadata not in (None, "", "null")
- else {},
- ),
- task=Task(
- issue_title=self.task_issue_title,
- issue_body=self.task_issue_body,
- issue_comments=self.task_issue_comments,
- issue_type=self.task_issue_type,
- repository=self.task_repository,
- ),
- state=State(
- done=self.state_done,
- todo=self.state_todo,
- open_file=self.state_open_file,
- working_dir=self.state_working_dir,
- extra_environment=json.loads(self.state_extra_environment)
- if self.state_extra_environment not in (None, "", "null")
- else {},
- ),
- action=Action(
- name=self.action_name,
- description=self.action_description,
- target=self.action_target,
- tool=self.action_tool,
- ),
- result=Result(
- type=self.result_type
- if self.result_type in ["success", "failure", "partial", "unknown"]
- else "unknown",
- description=self.result_description,
- exit_code=self.result_exit_code,
- ),
- )
diff --git a/athena/models/message.py b/athena/models/message.py
index 5e4f2c4..4b66896 100644
--- a/athena/models/message.py
+++ b/athena/models/message.py
@@ -12,6 +12,7 @@
"sys": "system",
"system_prompt": "system",
"sysmsg": "system",
+ "human": "user",
}
diff --git a/athena/prompts/__init__.py b/athena/prompts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/init_memory_base/prompt.py b/athena/prompts/memory_extraction.py
similarity index 84%
rename from init_memory_base/prompt.py
rename to athena/prompts/memory_extraction.py
index 3ce2a80..1238991 100644
--- a/init_memory_base/prompt.py
+++ b/athena/prompts/memory_extraction.py
@@ -120,6 +120,41 @@
"""
+ACTION_JUDGE_PROMPT = """
+You are a STRICT binary classifier for assistant messages. Decide whether the message initiates or executes a concrete ACTION in an agent environment.
+
+## Decision Target
+Return **True** if the message triggers, schedules, or immediately proceeds to perform an operation — even when the operation is written in natural language or embedded in a code/markdown block.
+
+### ACTION indicators (non-exhaustive; if ambiguous, prefer **True**)
+- File/FS ops: create/open/read/write/edit/delete/rename/move/copy, save, apply patch, replace/insert, scroll_up/scroll_down, navigate.
+- Search/explore ops: find_file/grep/search/locate/list (files/dirs/symbols), project-wide search.
+- Execution/invocation: run/execute/launch/call/invoke a tool or command (pytest/python/bash/node/npm/pip/curl/sql/git), function_call objects, start/stop processes, HTTP requests.
+- Editor/tool calls: str_replace_editor.*, execute_bash, tool:{...}, "Action: <name>" / "Action Input: <args>" blocks, OpenAI function_call objects, XML/JSON tool blocks.
+- Repair/fix/change steps: "Let's fix...", "Edit the constructor/method...", "Implement changes", "Apply the patch".
+- Committed immediate steps in first person or imperative: "Now, let's ...", "First, I'll create ...", "Next, run ...", "Proceed to ...".
+- Fenced code/markdown blocks that contain actionable commands (even if preceded or followed by explanations), e.g.:
+ - ```find_file "dense_ndim_array.py"```
+ - ```scroll_down```
+ - ```create reproduce_bug.py```
+ - ```pytest -q```
+ - ```python reproduce_error.py```
+
+### NOT actions (return **False**) — only if clearly none of the above:
+- Pure analysis, planning, or summarization with no concrete operation.
+- Pure thought process or reasoning steps without executing/triggering any operation.
+- Questions, confirmations, or instructions to a human with no tool/command to be executed.
+- Code shown purely for illustration (no intent to execute/apply/edit), e.g., "Here is an example snippet:" followed by code not framed as an immediate step.
+
+### Ambiguity policy
+If the case is ambiguous, **prefer True** (maximize recall for action messages).
+
+## Output
+Output EXACTLY one token on a single line: True or False
+No punctuation, quotes, or explanations.
+"""
+
+
ACTION_EXTRACTION_PROMPT = """
Extract ONLY the `action` object from ONE assistant message that starts a tool/action call.
diff --git a/init_memory_base/dataset_pre.py b/init_memory_base/dataset_pre.py
deleted file mode 100644
index 858608c..0000000
--- a/init_memory_base/dataset_pre.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from datasets import load_dataset
-
-dataset = load_dataset("SWE-Gym/OpenHands-SFT-Trajectories", split="train.success.oss")
-dataset = load_dataset("ZhengyanShi/Poutine-OpenHands-SFT-Trajectories", split="train")
-dataset = load_dataset("SWE-Gym/OpenHands-Sampled-Trajectories", split="train.raw")
-dataset = load_dataset("SWE-bench/SWE-smith-trajectories")
-dataset = load_dataset("nebius/SWE-agent-trajectories", split="train")
-dataset = load_dataset("zai-org/SWE-Dev-train")
-dataset = load_dataset("JetBrains-Research/swe-traj-complete", split="train")
-# git clone https://huggingface.co/datasets/ByteDance-Seed/Multi-SWE-bench_trajs
diff --git a/init_memory_base/prompts/code_skeleton.txt b/init_memory_base/prompts/code_skeleton.txt
deleted file mode 100644
index 1b9b25a..0000000
--- a/init_memory_base/prompts/code_skeleton.txt
+++ /dev/null
@@ -1,31 +0,0 @@
-You are an expert software engineering assistant.
-Given an **observation** about a coding task and the **source code** of a function/class,
-your goal is to extract a **code skeleton** that highlights only the focused regions relevant to the observation.
-
-### Instructions
-1. Keep only the **main import statements** at the top of the file:
- - Keep imports that are directly used in the focused regions.
- - Keep imports that define critical dependencies (e.g., standard libraries like `os`, `sys`, or key frameworks like `torch`, `pandas`).
- - Collapse all other imports into a single line: `... # other imports omitted`.
-2. Keep class and function headers **with full signatures**:
- - Include parameter names and types (if available).
- - Include return type or return value description (infer if not explicit; keep concise, e.g., "-> bool").
-3. Preserve necessary surrounding structure for readability (class/def blocks, braces/indentation).
-4. Inside classes/functions:
- - Keep only lines relevant to the current observation (**focused logic** and **key invocations**).
- - Replace non-relevant lines with a single line containing `...`.
-5. Provide ±3 lines of context around each focused region, if available.
-6. Do **not** add region markers or any textual summary. Output only the code skeleton.
-7. Do not introduce new identifiers, logic, control flow, or API calls that are not present in the source.
-8. Do not infer types/return types that are not explicit; if unknown, use -> Unknown or omit.
-9. Output one fenced code block with correct language tag. No prose outside.
-
-### Input
-Observation:
-<observation>
-
-Source Code:
-<source_code>
-
-### Output (Code Skeleton Only)
-
diff --git a/init_memory_base/prompts/text_summary.txt b/init_memory_base/prompts/text_summary.txt
deleted file mode 100644
index f37456a..0000000
--- a/init_memory_base/prompts/text_summary.txt
+++ /dev/null
@@ -1,31 +0,0 @@
-You are an expert software engineering assistant.
-Given an **observation** about a coding task and the **source code** of a function/class,
-your goal is to extract a **concise textual summary** that captures the essential intent, behavior, and role of the code in relation to the observation.
-
-### Instructions
-1. **Focus on relevance**: Summaries must highlight aspects of the code most related to the observation (e.g., functionality, data flow, critical logic).
-2. **Content to include**:
- - The **purpose** of the function/class (what it does, high-level intent).
- - The **inputs/outputs** (parameters, return values, side effects).
- - The **key operations or algorithms** performed (control flow, API/library usage).
- - Any **dependencies** or notable relationships (calls to other functions/classes).
-3. **Conciseness**:
- - Write in **2–4 sentences**.
- - Avoid unnecessary details (variable names, trivial operations).
- - Use plain, precise technical language.
-4. **No speculation**:
- - If the code’s intent is unclear, state it as *“uncertain”* rather than guessing.
- - Do not introduce functionality that is not evident in the code.
-5. **Output format**:
- - One concise **paragraph of plain text only** (no code, no bullet points, no extra markup).
- - Do not repeat the observation verbatim.
-
-### Input
-Observation:
-<observation>
-
-Source Code:
-<source_code>
-
-### Output (Textual Summary Only)
-
diff --git a/init_memory_base/trajectory_memory_extractor_deepseek.py b/init_memory_base/trajectory_memory_extractor_deepseek.py
deleted file mode 100644
index 5f084dd..0000000
--- a/init_memory_base/trajectory_memory_extractor_deepseek.py
+++ /dev/null
@@ -1,692 +0,0 @@
-import hashlib
-import json
-import os
-import random
-import re
-import time
-from typing import Dict, Iterable, List, Literal, Optional, Protocol
-
-import requests
-from datasets import load_dataset
-from prompt import (
- ACTION_EXTRACTION_PROMPT,
- RESULT_EXTRACTION_PROMPT,
- STATE_DONE_SUMMARY_PROMPT,
- STATE_TODO_SYNTHESIS_PROMPT,
- TASK_EXTRACTION_PROMPT,
-)
-from pydantic import BaseModel, Field
-from tqdm import tqdm
-
-# =============================
-# 1) Pydantic data structures
-# =============================
-
-# memory_unit_schema = {
-# "trajectory": {
-# "source": str,
-# "id": str,
-# },
-# "task": {
-# "issue_title": str,
-# "issue_body": str,
-# "issue_comments": str,
-# "issue_type": str, # e.g., bug, feature, documentation, question
-# "repository": str,
-# },
-# "state": {
-# "done": str, # summaries of the completed work
-# "todo": str, # the work to be done, decided on the next action
-# # "accessed_code": str,
-# },
-# "action": {
-# "name": str, # e.g., read_file, edit_file, run_test, invoke_tool
-# "description": str,
-# "target": str, # e.g., utils/math.py, tests/test_math.py, pytest, git
-# "tool": str, # e.g., pytest, git, search
-# },
-# "result": {
-# "type": str, # success | failure | partial | unknown
-# "description": str,
-# "exit_code": str,
-# },
-# }
-
-
-class TrajectoryMessage(BaseModel):
- content: str
- role: str # "system" | "user" | "assistant" | tool etc.
-    # TODO: check the expected style of trajectory messages
-
-
-class TrajectoryMeta(BaseModel):
- source: str = Field(..., description="e.g., dataset name, run id, or file path")
- id: str = Field(..., description="unique id for the trajectory")
-
-
-class Task(BaseModel):
- issue_title: str
- issue_body: str
- issue_comments: str
- issue_type: str # bug | feature | documentation | question | other
- repository: str
-
-
-class State(BaseModel):
- done: str
- todo: str
-
-
-class Action(BaseModel):
- name: str # read_file | edit_file | run_test | invoke_tool | etc.
- description: str
- target: str # utils/math.py | tests/test_math.py | pytest | git
- tool: str # pytest | git | search | editor | bash
-
-
-ResultType = Literal["success", "failure", "partial", "unknown"]
-
-
-class Result(BaseModel):
- type: ResultType = "unknown" # success | failure | partial | unknown
- description: str = "" # one-sentence result summary (for human)
- exit_code: Optional[int] = None # if extractable from logs
-
-
-class MemoryUnit(BaseModel):
- trajectory: TrajectoryMeta
- task: Task
- state: State
- action: Action
- result: Result
-
- def key(self) -> str:
- # Canonical key for dedup (task+state) as requested
- payload = json.dumps(
- {
- "task": self.task.model_dump(),
- "state": self.state.model_dump(),
- },
- sort_keys=True,
- )
- return hashlib.sha256(payload.encode("utf-8")).hexdigest()
-
-
-# Optional: In-memory base and protocol you can replace with your own DB/Vector store
-class MemoryBaseProtocol(Protocol):
- def upsert(self, units: Iterable[MemoryUnit]) -> None: ...
-
-
-class InMemoryMemoryBase:
- def __init__(self):
- self._store: Dict[str, MemoryUnit] = {}
-
- def upsert(self, units: Iterable[MemoryUnit]) -> None:
- for u in units:
- self._store[u.key()] = u
-
- def all(self) -> List[MemoryUnit]:
- return list(self._store.values())
-
-
-# =============================================
-# 2) DeepSeek (OpenAI-compatible) API utilities
-# =============================================
-
-DEEPSEEK_API_BASE = os.environ.get("DEEPSEEK_API_BASE", "https://api.deepseek.com")
-DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "")
-DEEPSEEK_MODEL = os.environ.get("DEEPSEEK_MODEL", "deepseek-chat")
-
-HEADERS = {
- "Authorization": f"Bearer {DEEPSEEK_API_KEY}",
- "Content-Type": "application/json",
-}
-
-
-class RateLimit(Exception):
- pass
-
-
-_SESSION: Optional[requests.Session] = None
-
-
-def _get_session() -> requests.Session:
- global _SESSION
- if _SESSION is None:
- s = requests.Session()
- adapter = requests.adapters.HTTPAdapter(
- pool_connections=10, pool_maxsize=50, max_retries=0, pool_block=True
- )
- s.mount("https://", adapter)
- s.mount("http://", adapter)
- _SESSION = s
- return _SESSION
-
-
-def _post_with_retries(
- session: requests.Session,
- url: str,
- payload: dict,
- *,
- max_retries: int = 5,
- retry_base: float = 1.25,
- timeout=(10, 60),
-) -> str:
- if not DEEPSEEK_API_KEY:
- raise RuntimeError("DEEPSEEK_API_KEY is not set.")
-
- for attempt in range(max_retries):
- try:
- headers = {
- "Authorization": f"Bearer {os.environ.get('DEEPSEEK_API_KEY', '')}",
- "Content-Type": "application/json",
- }
- resp = session.post(url, headers=headers, json=payload, timeout=timeout)
- if resp.status_code == 429:
- raise RateLimit("Rate limited")
- resp.raise_for_status()
- data = resp.json()
- content = data.get("choices", [{}])[0].get("message", {}).get("content", "")
- if not content:
- raise ValueError("Empty content from API")
- return content
-
- except (requests.HTTPError, requests.ConnectionError, requests.Timeout, RateLimit) as e:
- if attempt == max_retries - 1:
- raise
- # Optional: Rebuild Session on connection errors
- if isinstance(e, (requests.ConnectionError,)):
- global _SESSION
- _SESSION = None
- session = _get_session()
- time.sleep((retry_base**attempt) + random.uniform(0, 0.25))
-
-
-def call_deepseek(
- messages: List[Dict[str, str]],
- *,
- temperature: float = 0.0,
- max_retries: int = 5,
- retry_base: float = 1.25,
- timeout: int = 60,
-) -> str:
- """Call DeepSeek Chat Completions endpoint and return assistant text."""
- url = f"{DEEPSEEK_API_BASE.rstrip('/')}/v1/chat/completions"
- payload = {
- "model": DEEPSEEK_MODEL,
- "messages": messages,
- "temperature": temperature,
- "response_format": {"type": "json_object"},
- # Optional: Limit return length to further reduce latency
- # "max_tokens": 512,
- # "n": 1, # number of completions to generate
- }
- session = _get_session()
- return _post_with_retries(
- session, url, payload, max_retries=max_retries, retry_base=retry_base, timeout=timeout
- )
-
-
-# =====================================
-# 3) Orchestration over a trajectory
-# =====================================
-
-
-def extract_result_from_last_message(
- window_msgs: List[TrajectoryMessage], current_action: Action
-) -> Result:
- """
- Use DeepSeek to extract the `result` from the LAST message of a window.
- - Evidence-bound to the last message ONLY.
- - Output must be ONE plain-text sentence capturing the definitive outcome.
- - Minimal heuristics; a tiny fallback is used only if the API fails.
-
- Assumes available:
- - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str
- - json imported
- """
- if not window_msgs:
- return Result(type="unknown", description="(none)")
- last_msg = window_msgs[-1]
-
- system_prompt = RESULT_EXTRACTION_PROMPT
-    user_prompt = (
-        "CURRENT_ACTION (reference for alignment; do NOT invent beyond LAST_MESSAGE):\n"
-        "<CURRENT_ACTION>\n" + current_action.model_dump_json() + "\n</CURRENT_ACTION>\n\n"
-        "LAST_MESSAGE (extract the definitive outcome from here ONLY):\n"
-        "<LAST_MESSAGE>\n" + last_msg.model_dump_json() + "\n</LAST_MESSAGE>\n\n"
-        "Return ONLY the JSON object."
-    )
-
- try:
- raw = call_deepseek(
- [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": user_prompt},
- ],
- temperature=0.0,
- )
- # Parse JSON; if failed, try to fix it to a valid JSON
- try:
- payload = json.loads(raw)
- except json.JSONDecodeError:
- fixed = call_deepseek(
- [
- {"role": "system", "content": "Return a valid JSON object only. No comments."},
- {"role": "user", "content": raw},
- ],
- temperature=0.0,
- )
- payload = json.loads(fixed)
-
- r = (payload or {}).get("result", {}) or {}
-
- # Utility: convert to string
- def s(x):
- return x if isinstance(x, str) else ("" if x is None else str(x))
-
- # Result type correction
- r_type = s(r.get("type")).lower()
- if r_type not in {"success", "failure", "partial", "unknown"}:
- r_type = "unknown"
-
- # Text: one sentence, remove wrapping quotes
- description = s(r.get("description")).strip().strip('"')
- if not description:
- description = "(none)"
-
-        # exit_code: accept an integer, an integral float, or an integer-like string
-        exit_raw = r.get("exit_code")
-        exit_code = None
-        if isinstance(exit_raw, (int, float)) and float(exit_raw).is_integer():
-            exit_code = int(exit_raw)
-        elif isinstance(exit_raw, str) and re.fullmatch(r"-?\d+", exit_raw.strip()):
-            exit_code = int(exit_raw.strip())
-
- return Result(type=r_type, description=description, exit_code=exit_code)
-
- except Exception:
- # Minimal fallback: only based on the last message (no LLM call)
- src = (getattr(last_msg, "content", "") or "").strip()
- lt = src.lower()
-
- if re.search(r"(error|exception|traceback|failed)", lt):
- rtype = "failure"
- elif re.search(r"(exit code\s*0|succeeded|completed|all tests passed|passed\b)", lt):
- rtype = "success"
- elif re.search(r"(some tests failed|failures?:\s*[1-9]|errors?:\s*[1-9])", lt):
- rtype = "partial"
- else:
- rtype = "unknown"
-
- # Extract the last non-empty line as one sentence result, and truncate to ~60 words
- lines = [ln.strip() for ln in src.splitlines() if ln.strip()]
- summary = lines[-1] if lines else "(none)"
- words = summary.split()
- if len(words) > 60:
- summary = " ".join(words[:60])
-
- m = re.search(r"exit code\s+(-?\d+)", src, flags=re.I)
- code = int(m.group(1)) if m else None
-
- return Result(type=rtype, description=summary, exit_code=code)
-
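- # Usage sketch (hypothetical messages and action):
- #
- #     msgs = [
- #         TrajectoryMessage(role="assistant", content="<execute_bash>pytest</execute_bash>"),
- #         TrajectoryMessage(role="user", content="All 42 tests passed. Exit code 0."),
- #     ]
- #     res = extract_result_from_last_message(msgs, action)  # `action` from the window's first message
- #     # -> roughly Result(type="success", description="All 42 tests passed. Exit code 0.", exit_code=0)
-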
-
-def synthesize_state_todo_paragraph(
- window_msgs: List[TrajectoryMessage],
- task: Task,
- prior_done_text: str = "",
- current_action: Optional[Action] = None,
-) -> str:
- """
- Call DeepSeek to synthesize `state.todo` as ONE highly condensed paragraph
- capturing the intent of the CURRENT message window.
-
- Dependencies assumed available:
- - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str
- - json imported
- """
-
- system_prompt = STATE_TODO_SYNTHESIS_PROMPT
-    action_json = current_action.model_dump_json() if current_action is not None else "{}"
-    user_prompt = (
-        "TASK (overall purpose):\n"
-        "<task>\n" + task.model_dump_json() + "\n</task>\n\n"
-        "PRIOR_DONE (what has already been completed; avoid repeating it):\n"
-        "<prior_done>\n" + (prior_done_text or "") + "\n</prior_done>\n\n"
-        "WINDOW_MSGS (current window to derive the immediate intent of the next step; do NOT output tools/commands/paths):\n"
-        "<window_msgs>\n"
-        + json.dumps([msg.model_dump() for msg in window_msgs], ensure_ascii=False)
-        + "\n</window_msgs>\n\n"
-        "CURRENT_ACTION (use ONLY to ensure the intent refers to the same component/area; do NOT include or paraphrase any of its details):\n"
-        "<current_action>\n" + action_json + "\n</current_action>\n\n"
-        "Return ONLY the final paragraph."
-    )
-
- try:
- text = call_deepseek(
- [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": user_prompt},
- ],
- temperature=0.0,
- )
- return text.strip().strip('"').replace("\n", " ").strip()
-    except Exception:
-        # Minimal deterministic fallback using available CURRENT_ACTION fields (no speculation)
-        parts = []
-        task_short = (task.issue_title or task.repository or task.issue_type or "the task").strip()
-        parts.append(f"To progress on {task_short},")
-        if current_action is not None:
-            # Build an imperative clause from the action fields
-            verb = (current_action.name or "execute action").replace("_", " ")
-            tgt = f" {current_action.target}" if current_action.target else ""
-            via = f" via {current_action.tool}" if current_action.tool else ""
-            desc = f" to {current_action.description}" if current_action.description else ""
-            parts.append(f" {verb}{tgt}{via}{desc}.")
-        else:
-            parts.append(" take the next step toward completing it.")
-        # Collapse the double space introduced by joining "," with a leading-space clause
-        return " ".join(parts).replace("  ", " ").strip()
-
-
-def extract_action_from_first_action_msg(first_action_msg: TrajectoryMessage) -> Action:
- """
- Use DeepSeek to extract the `action` object from a SINGLE assistant message
- that starts an action call. Minimal heuristics; evidence-bound to the message.
- Dependencies assumed available:
- - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str
- - Action Pydantic model
- - json imported
- """
- system_prompt = ACTION_EXTRACTION_PROMPT
-    user_prompt = (
-        "Extract the `action` from the SINGLE assistant message below.\n"
-        "Use ONLY the content between <message> and </message>.\n\n"
-        "<message>\n" + first_action_msg.model_dump_json() + "\n</message>"
-    )
-
- messages = [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": user_prompt},
- ]
-
- raw = call_deepseek(messages, temperature=0.0)
-
- # Robust JSON parse with one repair attempt
- try:
- payload = json.loads(raw)
- except json.JSONDecodeError:
- fixed = call_deepseek(
- [
- {"role": "system", "content": "Return a valid JSON object only. No comments."},
- {"role": "user", "content": raw},
- ],
- temperature=0.0,
- )
- payload = json.loads(fixed)
-
- action_obj = payload.get("action", {}) or {}
-
- def s(x):
- return x if isinstance(x, str) else ("" if x is None else str(x))
-
- return Action(
- name=s(action_obj.get("name")),
- description=s(action_obj.get("description")),
- target=s(action_obj.get("target")),
- tool=s(action_obj.get("tool")),
- )
-
-
-def finalize_window_with_llm(
- window_msgs: List[TrajectoryMessage],
- first_action_msg: TrajectoryMessage,
- prior_units: List[MemoryUnit],
- meta: TrajectoryMeta,
- task: Task,
-) -> Optional[MemoryUnit]:
- """
- Extract a single memory unit from a window of messages.
- - Synthesize state.done from prior actions
- - Synthesize state.todo by extracting intents from window_msgs
- - Extract action from first_action_msg
- - Extract result from the last message in window_msgs
- """
- action = extract_action_from_first_action_msg(first_action_msg)
- state_done = summarize_context_from_units(prior_units, task)
- state_todo = synthesize_state_todo_paragraph(window_msgs, task, state_done, action)
- result = extract_result_from_last_message(window_msgs, action)
- return MemoryUnit(
- trajectory=meta,
- task=task,
- state=State(
- done=state_done,
- todo=state_todo,
- ),
- action=action,
- result=result,
- )
-
-
-def is_action_start(msg: TrajectoryMessage) -> bool:
-    """Return True if this message starts an action call, else False."""
-    if getattr(msg, "role", "") != "assistant":
-        return False
-    content = getattr(msg, "content", "") or ""
-    # Tag pattern assumed for OpenHands-style trajectories; the concrete tags
-    # vary by trajectory style (see TODO below).
-    return bool(re.search(r"<execute_\w+>.*?</execute_\w+>", content, flags=re.DOTALL))
-    # TODO: content check, different trajectory styles
-
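- # Example (assuming the OpenHands-style execute tags above):
- #
- #     is_action_start(TrajectoryMessage(
- #         role="assistant", content="Let me run it.\n<execute_bash>\npytest\n</execute_bash>"
- #     ))  # -> True, since the message contains an execute tag pair
-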
-
-def extract_memory_units_by_action_windows(
- meta: TrajectoryMeta,
- task: Task,
- trajectory: List[TrajectoryMessage],
- *,
- memory_base: Optional[MemoryBaseProtocol] = None,
-) -> List[MemoryUnit]:
-    """Split the trajectory into action-delimited windows and extract one MemoryUnit per window."""
-    ordered_memory_units: List[MemoryUnit] = []
- window_msgs: List[TrajectoryMessage] = []
- window_first_action: Optional[TrajectoryMessage] = None
-
- for msg in tqdm(trajectory, desc="Extracting memory units"):
- if is_action_start(msg):
- if window_msgs and window_first_action is not None:
- mu = finalize_window_with_llm(
- window_msgs,
- window_first_action,
- prior_units=ordered_memory_units,
- meta=meta,
- task=task,
- )
- if mu:
- ordered_memory_units.append(mu)
-
- window_msgs = [msg]
- window_first_action = msg
-        elif window_msgs:
-            window_msgs.append(msg)
-        # messages arriving before the first action start are skipped
-
- if window_msgs and window_first_action is not None:
- mu = finalize_window_with_llm(
- window_msgs, window_first_action, prior_units=ordered_memory_units, meta=meta, task=task
- )
- if mu:
- ordered_memory_units.append(mu)
-
- if memory_base is not None and ordered_memory_units:
- memory_base.upsert(ordered_memory_units)
-
- return ordered_memory_units
-
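- # Windowing sketch: given roles [user, assistant(action A1), tool, assistant(action A2), tool],
- # the loop produces the windows [A1, tool] and [A2, tool]; the leading user message arrives
- # before any action start and is skipped, and each closed window is finalized into one MemoryUnit.
-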
-
-def summarize_context_from_units(units: List[MemoryUnit], task: Task) -> str:
- """
- Summarize previous context into a concise `state.done` string by calling DeepSeek.
- Uses only the last 10 memory units and asks the model to produce a single, plain-text
- summary of what has ALREADY BEEN COMPLETED (no plans).
-
- Dependencies assumed available:
- - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str
- - MemoryUnit Pydantic model
- - json imported
- """
- if not units:
- return "(none)"
-
- # Keep the window modest to control prompt size
- window = units[-10:]
-
- # Pack minimal, evidence-bound context for the LLM (no heavy heuristics)
- history = []
- for u in window:
- history.append(
- {
- "state": {
- "done": u.state.done,
- "todo": u.state.todo,
- },
- "action": {
- "name": u.action.name,
- "description": u.action.description,
- "target": u.action.target,
- "tool": u.action.tool,
- },
- "result": {
- "type": u.result.type,
- "description": u.result.description,
- "exit_code": u.result.exit_code,
- },
- }
- )
-
- system_prompt = STATE_DONE_SUMMARY_PROMPT
-    user_prompt = (
-        "Summarize ONLY what has ALREADY BEEN COMPLETED into `state.done`.\n"
-        "TASK (for scoping the summary only):\n"
-        "<task>\n" + task.model_dump_json() + "\n</task>\n\n"
-        "PRIOR_UNITS (used ONLY as evidence of completed work):\n"
-        "<prior_units>\n" + json.dumps(history, ensure_ascii=False) + "\n</prior_units>\n\n"
-        "Return the final summary paragraph ONLY (no explanations)."
-    )
-
- try:
- summary = call_deepseek(
- [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": user_prompt},
- ],
- temperature=0.0,
- )
- # Be lenient if the model wraps the text in quotes/newlines
- return summary.strip().strip('"').strip()
- except Exception:
- # Minimal, defensive fallback (kept short); avoids complex heuristics
- return window[-1].state.done
-
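- # Usage sketch (hypothetical units):
- #
- #     done_so_far = summarize_context_from_units(ordered_memory_units, task)
- #     # -> one plain-text paragraph covering only completed work, e.g.
- #     #    "Reproduced the bug with a failing test and located the faulty parser branch."
-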
-
-def extract_task_from_first_user_message(
- trajectory: List[TrajectoryMessage], *, default_repository: str = ""
-) -> Task:
- """
- From a trajectory (list of TrajectoryMessage), take the first user message,
- call DeepSeek to extract the Task fields, and return a Task.
-
- Requirements:
- - Uses the LLM to infer fields (minimal heuristics).
- - Returns empty strings when unknown.
- - If `default_repository` is provided and the model leaves repository empty,
- fill it with `default_repository`.
-
- Dependencies assumed available in your file:
- - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str
- - Task Pydantic model
- - json imported
- """
- # 1) find first user message
- first_user_msg = next((m for m in trajectory if getattr(m, "role", "") == "user"), None)
- if first_user_msg is None:
- return Task(
- issue_title="",
- issue_body="",
- issue_comments="",
- issue_type="",
- repository=default_repository,
- )
-
- # 2) build a schema-constrained prompt for ONLY Task fields
- system_prompt = TASK_EXTRACTION_PROMPT
-    user_prompt = (
-        "Extract the `task` object from the SINGLE user message below.\n"
-        "Use ONLY the text between <message> and </message>.\n\n"
-        "<message>\n" + first_user_msg.model_dump_json() + "\n</message>"
-    )
-
- messages = [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": user_prompt},
- ]
-
- # 3) call DeepSeek
- raw = call_deepseek(messages, temperature=0.0)
-
- # 4) robust JSON parse, with a single repair attempt if needed
- try:
- payload = json.loads(raw)
- except json.JSONDecodeError:
- fix_messages = [
- {"role": "system", "content": "Fix to valid JSON object. No comments."},
- {"role": "user", "content": raw},
- ]
- fixed = call_deepseek(fix_messages, temperature=0.0)
- payload = json.loads(fixed)
-
-    # 5) normalize fields to strings (no extra heuristics)
-    task_obj = payload.get("task", {}) or {}
-
- def s(x):
- return x if isinstance(x, str) else ("" if x is None else str(x))
-
- repo = s(task_obj.get("repository"))
- if not repo and default_repository:
- repo = default_repository # optional fallback; minimal deviation
-
- return Task(
- issue_title=s(task_obj.get("issue_title")),
- issue_body=s(task_obj.get("issue_body")),
- issue_comments=s(task_obj.get("issue_comments")),
- issue_type=s(task_obj.get("issue_type")),
- repository=repo,
- )
-
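- # Usage sketch (the repository name is a hypothetical example):
- #
- #     task = extract_task_from_first_user_message(
- #         trajectory, default_repository="astropy/astropy"
- #     )
- #     print(task.issue_title, task.repository)
-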
-
-def row_to_trajectory_messages(row: dict) -> List[TrajectoryMessage]:
-    """Convert one dataset row into a list of TrajectoryMessage objects."""
-    traj = row["messages"]
-    msgs = []
-    for m in traj:
-        content = m.get("content", "") if isinstance(m, dict) else str(m)
-        role = m.get("role", "user") if isinstance(m, dict) else "user"
-        msgs.append(TrajectoryMessage(content=content, role=role))
-    return msgs
-
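- # Expected row shape (assumed from the loader above):
- #
- #     {"messages": [{"role": "user", "content": "..."},
- #                   {"role": "assistant", "content": "..."}]}
-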
-
-def main():
- trajectory_repo_name = "SWE-Gym/OpenHands-SFT-Trajectories"
- trajectory_split = "train.success.oss"
- dataset = load_dataset(trajectory_repo_name, split=trajectory_split)
- trajectories = [row_to_trajectory_messages(row) for row in dataset]
- memory_base = InMemoryMemoryBase()
- print(f"\nExtracting memory units for {len(trajectories)} trajectories...\n")
-
- for idx, trajectory in enumerate(trajectories):
- print(f"\nProcessing trajectory {idx + 1} of {len(trajectories)}...\n")
- meta = TrajectoryMeta(source=trajectory_repo_name, id=f"{idx}")
- task = extract_task_from_first_user_message(trajectory)
- units = extract_memory_units_by_action_windows(
- meta, task, trajectory, memory_base=memory_base
- )
- print(f"\nExtracted {len(units)} memory units. Showing JSON:\n")
- print(json.dumps([u.model_dump() for u in units], ensure_ascii=False, indent=2))
-        break  # NOTE: debug stop; only the first trajectory is processed for now
-
-
-if __name__ == "__main__":
- main()
diff --git a/pyproject.toml b/pyproject.toml
index 9634060..89610b0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,6 +15,7 @@ dependencies = [
"langchain-openai==0.2.8",
"langchain-google-genai==2.0.4",
"langchain_community==0.3.2",
+ "langchain_google_vertexai==2.1.0"
]
requires-python = ">= 3.11"