diff --git a/athena/app/services/llm_service.py b/athena/app/services/llm_service.py index 6c28d5a..fe39fbd 100644 --- a/athena/app/services/llm_service.py +++ b/athena/app/services/llm_service.py @@ -3,6 +3,7 @@ from langchain_anthropic import ChatAnthropic from langchain_core.language_models.chat_models import BaseChatModel from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_google_vertexai import ChatVertexAI from athena.app.services.base_service import BaseService from athena.chat_models.custom_chat_openai import CustomChatOpenAI @@ -50,6 +51,17 @@ def get_model( max_tokens_to_sample=max_output_tokens, max_retries=3, ) + elif model_name.startswith("vertex:"): + # example: model_name="vertex:gemini-2.5-pro" + vertex_model = model_name.split("vertex:", 1)[1] + return ChatVertexAI( + model_name=vertex_model, + project="prometheus-code-agent", + location="us-central1", + temperature=temperature, + max_output_tokens=max_output_tokens, + max_retries=3, + ) elif "gemini" in model_name: return ChatGoogleGenerativeAI( model=model_name, diff --git a/athena/app/services/memory_extraction_service.py b/athena/app/services/memory_extraction_service.py new file mode 100644 index 0000000..b6a5071 --- /dev/null +++ b/athena/app/services/memory_extraction_service.py @@ -0,0 +1,882 @@ +import asyncio +import html +import json +import re +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from datasets import load_dataset +from tqdm import tqdm + +from athena.app.services.base_service import BaseService +from athena.app.services.llm_service import LLMService +from athena.app.services.memory_storage_service import MemoryStorageService +from athena.models import ( + Action, + MemorySource, + MemoryTimestamp, + MemoryUnit, + Message, + Result, + State, + Task, +) +from athena.prompts.memory_extraction import ( + ACTION_EXTRACTION_PROMPT, + ACTION_JUDGE_PROMPT, + RESULT_EXTRACTION_PROMPT, + STATE_DONE_SUMMARY_PROMPT, + STATE_TODO_SYNTHESIS_PROMPT, + TASK_EXTRACTION_PROMPT, +) + +# from athena.utils.logger_manager import get_logger + + +class ExtractionError(Exception): + """Exception raised when memory extraction fails.""" + + def __init__(self, source_name: str, run_id: str): + self.source_name = source_name + self.run_id = run_id + super().__init__(f"Failed to extract memory units from {run_id} in {source_name}") + + +class MemoryExtractionService(BaseService): + """ + Service for extracting structured memory units from interaction trajectories. + + This service orchestrates the extraction of memory units from various data + sources including conversation logs, agent execution traces, and user + interaction histories. It provides a unified interface for converting + unstructured interaction data into structured memory units that can be + stored and queried by the Athena memory system. + + Key Features: + - Multi-source trajectory loading (datasets, files, APIs) + - Configurable extraction strategies (LLM-based, rule-based, hybrid) + - Batch processing with progress tracking + - Error handling and retry mechanisms + - Integration with existing memory storage systems + + Architecture: + The service follows a pipeline pattern: + 1. Data Source -> Load trajectories + 2. Extraction Strategy -> Extract components (task, action, result, state) + 3. Memory Unit Assembly -> Combine components into MemoryUnit + 4. Validation -> Ensure data quality and consistency + 5. 
Storage -> Persist to memory system
+
+    Usage:
+        service = MemoryExtractionService(llm_service)
+        memory_units = service.extract_from_huggingface_trajectory_repository(
+            repo_name, split
+        )
+    """
+
+    def __init__(
+        self,
+        llm_service: LLMService,
+        batch_size: int = 100,
+        max_retries: int = 3,
+        memory_store: Optional[MemoryStorageService] = None,
+    ):
+        """
+        Initialize the Memory Extraction Service.
+
+        Args:
+            llm_service: LLM service for AI-powered extraction
+            batch_size: Number of trajectories to process in each batch
+            max_retries: Maximum retry attempts for failed extractions
+            memory_store: Optional storage backend for persisting extracted units
+        """
+        self.llm_service = llm_service
+        self.batch_size = batch_size
+        self.max_retries = max_retries
+        self._extraction_cache: Dict[str, MemoryUnit] = {}
+        self.memory_store = memory_store
+        # self._logger = get_logger(__name__)
+
+    def start(self):
+        """Start the memory extraction service."""
+        # Initialize any required resources
+        pass
+
+    def close(self):
+        """Close the memory extraction service and cleanup resources."""
+        self._extraction_cache.clear()
+
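+    # Illustrative sketch (not part of the runtime path): the loader below expects
+    # dataset rows shaped roughly like the following. The field names are guesses
+    # covered by the _pick() alias lists, not a guaranteed schema:
+    #
+    #   row = {
+    #       "messages": [{"role": "user", "content": "Fix the failing test ..."},
+    #                    {"role": "assistant", "content": "```pytest -q```"}],
+    #       "run_id": "swe-gym-0001",
+    #       "instance_id": "repo__issue-123",
+    #       "model": "gpt-4o",
+    #       "resolved": True,
+    #   }
+    #
+    #   units = service.extract_from_huggingface_trajectory_repository(
+    #       repo_name="SWE-Gym/OpenHands-SFT-Trajectories", split="train.success.oss"
+    #   )
+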
+    def extract_from_huggingface_trajectory_repository(  # TODO: batch extraction
+        self, repo_name: str, split: str
+    ) -> List[MemoryUnit]:
+        """
+        Extract memory units from a HuggingFace trajectory repository.
+
+        Args:
+            repo_name: Name of the HuggingFace trajectory repository
+            split: Split of the HuggingFace trajectory repository
+
+        Returns:
+            List of extracted memory units
+
+        Note:
+            Rows that raise ExtractionError are skipped instead of aborting the run.
+        """
+        dataset = load_dataset(repo_name, split=split)
+
+        traj_keys = ["messages", "history", "trajectory", "chat", "conversation", "dialog"]
+        run_id_keys = ["run_id", "traj_id", "id"]
+        inst_id_keys = ["instance_id"]
+        model_keys = ["model", "model_name"]
+        resolved_keys = [
+            "resolved",
+            "target",
+        ]  # bool value, whether the model solved the issue in this trajectory.
+
+        def _pick(d: Dict[str, Any], keys) -> Optional[Any]:
+            for k in keys:
+                if k in d:
+                    return d[k]
+            return None
+
+        for idx in range(len(dataset)):
+            row = dataset[idx]
+            trajectory = _pick(row, traj_keys)
+            if not trajectory:
+                continue  # no recognizable trajectory field in this row; skip it
+            messages: List[Message] = []
+            for m in trajectory:
+                content = str(
+                    m.get("content")
+                    or m.get("text")
+                    or m.get("message")
+                    or m.get("utterance")
+                    or ""
+                )
+                role = str(m.get("role", "unknown"))
+                metadata = {
+                    k: v
+                    for k, v in m.items()
+                    if k not in {"content", "role", "text", "message", "utterance"}
+                }
+                messages.append(Message(content=content, role=role, metadata=metadata))
+
+            run_id = _pick(row, run_id_keys) or f"{idx}"
+            instance_id = _pick(row, inst_id_keys)
+            model = _pick(row, model_keys)
+            resolved = _pick(
+                row, resolved_keys
+            )  # TODO: use resolved to filter memory units to successful/failed ones.
+            memory_source = self._extract_memory_source(
+                repo_name,
+                run_id,
+                metadata={
+                    **({"instance_id": instance_id} if instance_id is not None else {}),
+                    **({"model": model} if model is not None else {}),
+                    **({"resolved": resolved} if resolved is not None else {}),
+                },
+            )
+            try:
+                memory_units = self._extract_memory_units_by_action_windows(messages, memory_source)
+            except ExtractionError:
+                # self._logger.error(f"Failed to extract memory units from {run_id} in {repo_name}")  # TODO: add logger
+                continue
+
+            self._extraction_cache.update(
+                {mu.memory_id: mu for mu in memory_units}
+            )  # TODO: use PostgreSqlMemoryStore to store memory units.
+        return list(self._extraction_cache.values())
+
+    def _extract_memory_source(
+        self, source: str, run_id: str, metadata: Optional[Dict[str, Any]] = None
+    ) -> MemorySource:
+        """
+        Extract memory source for a trajectory.
+
+        Args:
+            source: Data source identifier
+            run_id: Unique run identifier
+            metadata: Optional additional metadata
+
+        Returns:
+            MemorySource object
+        """
+        return MemorySource(
+            source_name=source,
+            run_id=run_id,
+            metadata=metadata or {},
+        )
+
+    def _extract_memory_units_by_action_windows(
+        self,
+        messages: List[Message],
+        memory_source: MemorySource,
+    ) -> List[MemoryUnit]:
+        ordered_memory_units: List[MemoryUnit] = []
+        window_msgs: List[Message] = []
+        window_first_action: Optional[Message] = None
+
+        task = self._extract_task_from_messages(messages)
+
+        for msg in tqdm(
+            messages,
+            desc=f"Extracting memory units for {memory_source.run_id} in {memory_source.source_name}",
+        ):
+            if self._is_action_message(msg):
+                if window_msgs and window_first_action is not None:
+                    mu = self._create_memory_unit(
+                        memory_source,
+                        task,
+                        window_msgs,
+                        ordered_memory_units,
+                    )
+                    if mu:
+                        ordered_memory_units.append(mu)
+
+                window_msgs = [msg]
+                window_first_action = msg
+            else:
+                if window_msgs:
+                    window_msgs.append(msg)
+                else:
+                    continue
+
+        if window_msgs and window_first_action is not None:
+            mu = self._create_memory_unit(
+                memory_source,
+                task,
+                window_msgs,
+                ordered_memory_units,
+            )
+            if mu:
+                ordered_memory_units.append(mu)
+
+        if self.memory_store is not None and ordered_memory_units:  # TODO: refine memory store
+            if asyncio.iscoroutinefunction(self.memory_store.upsert):
+                try:
+                    # Fire-and-forget is acceptable in most flows; caller may await if desired
+                    loop = asyncio.get_running_loop()
+                    loop.create_task(self.memory_store.upsert(ordered_memory_units))
+                except RuntimeError:
+                    # No running event loop (plain sync caller): run the upsert to completion
+                    asyncio.run(self.memory_store.upsert(ordered_memory_units))
+            else:
+                self.memory_store.upsert(ordered_memory_units)  # type: ignore[arg-type]
+
+        return ordered_memory_units
+
+    def _is_action_message(self, message: Message) -> bool:
+        """
+        Determine if a message represents the start of an action.
+
+        Args:
+            message: Message to analyze
+
+        Returns:
+            True if message starts an action, False otherwise
+        """
+        if getattr(message, "role", "") != "assistant":
+            return False
+        system_prompt = ACTION_JUDGE_PROMPT
+        user_prompt = (
+            "Decide if the following assistant message represents a concrete ACTION.\n"
+            "Use ONLY the content between <message> and </message> as input.\n"
+            "Output exactly one token: True or False.\n\n"
+            "<message>\n" + message.model_dump_json() + "\n</message>"
+        )
+        raw = self.llm_service.model.invoke(
+            [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ]
+        ).content
+        return self._normalize_llm_bool_output(raw)
+
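+    # Minimal sketch of the judge's expected behavior. The True/False outcomes for
+    # assistant messages are assumptions, since the decision is delegated to the
+    # LLM behind ACTION_JUDGE_PROMPT; the user-role case is deterministic:
+    #
+    #   _is_action_message(Message(role="assistant", content="```pytest -q```"))
+    #       -> True   (fenced block contains an executable command)
+    #   _is_action_message(Message(role="assistant", content="The bug is in utils."))
+    #       -> False  (pure analysis, no operation)
+    #   _is_action_message(Message(role="user", content="```ls```"))
+    #       -> False  (short-circuited: only assistant messages qualify)
+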
+    def _create_memory_unit(
+        self,
+        source: MemorySource,
+        task: Task,
+        window_msgs: List[Message],
+        prior_units: List[MemoryUnit],
+    ) -> MemoryUnit:
+        """
+        Extract a single memory unit from a window of messages.
+
+        - Synthesize state.done from prior units
+        - Synthesize state.todo by extracting intents from window_msgs
+        - Extract the action from the first message in window_msgs
+        - Extract the result from the last message in window_msgs
+        """
+        action = self._extract_action_from_message(window_msgs[0])
+        state_done = self._synthesize_state_done_from_context(prior_units, task)
+        state_todo = self._synthesize_state_todo_from_window(window_msgs, task, state_done, action)
+        result = self._extract_result_from_window(window_msgs, action)
+        return MemoryUnit(
+            memory_id=str(uuid.uuid4()),
+            timestamp=MemoryTimestamp(
+                created_at=datetime.now(timezone.utc),
+                updated_at=None,
+                invalid_at=None,
+            ),
+            source=source,
+            task=task,
+            state=State(
+                done=state_done,
+                todo=state_todo,
+            ),
+            action=action,
+            result=result,
+        )
+
+    def _extract_task_from_messages(self, messages: List[Message]) -> Task:
+        """
+        Extract task information from trajectory messages.
+
+        This method identifies the first user message and extracts:
+        - Issue title and description
+        - Issue type classification
+        - Repository information
+        - Related comments or context
+
+        Args:
+            messages: List of messages in the trajectory
+
+        Returns:
+            Task object with extracted information
+        """
+        # 1) find first user message
+        first_user_msg = next((m for m in messages if getattr(m, "role", "") == "user"), None)
+        if first_user_msg is None:
+            return Task(
+                issue_title="",
+                issue_body="",
+                issue_comments="",
+                issue_type="",
+                repository="",
+            )
+
+        # 2) build a schema-constrained prompt for ONLY Task fields
+        system_prompt = TASK_EXTRACTION_PROMPT
+        user_prompt = (
+            "Extract the `task` object from the SINGLE user message below.\n"
+            "Use only the text between <message> and </message>.\n\n"
+            "<message>\n" + first_user_msg.model_dump_json() + "\n</message>"
+        )
+
+        # 3) call LLM service
+        raw = self.llm_service.model.invoke(
+            [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ]
+        ).content
+
+        # 4) robust JSON parse, with a single repair attempt if needed
+        payload = self._normalize_llm_json_output(raw)
+
+        # 5) normalize fields to strings (no extra heuristics)
+        task_obj = payload.get("task", {})
+
+        def s(x):
+            return x if isinstance(x, str) else ("" if x is None else str(x))
+
+        return Task(
+            issue_title=s(task_obj.get("issue_title")),
+            issue_body=s(task_obj.get("issue_body")),
+            issue_comments=s(task_obj.get("issue_comments")),
+            issue_type=s(task_obj.get("issue_type")),
+            repository=s(task_obj.get("repository")),
+        )
+
+    def _synthesize_state_from_context(
+        self,
+        prior_units: List[MemoryUnit],
+        task: Task,
+        window_msgs: List[Message],
+        current_action: Optional[Action] = None,
+    ) -> State:
+        """
+        Synthesize state information from context.
+        """
+        state_done = self._synthesize_state_done_from_context(prior_units, task)
+        state_todo = self._synthesize_state_todo_from_window(
+            window_msgs, task, state_done, current_action
+        )
+        return State(done=state_done, todo=state_todo)  # TODO: open_file, working_dir
+
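+    # Illustrative target output for the two synthesis steps below (hypothetical
+    # text; the actual wording depends on the underlying model, and `todo` must
+    # stay free of tools/commands/paths per STATE_TODO_SYNTHESIS_PROMPT):
+    #
+    #   state.done: "Reproduced the reported ValueError with a minimal script and
+    #                located the failing branch in the array constructor."
+    #   state.todo: "Adjust the constructor's shape validation so empty inputs are
+    #                accepted while keeping behavior for non-empty inputs unchanged."
+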
+ """ + if not prior_units: + return "(none)" + + # Keep the window modest to control prompt size + window = prior_units[-1:] # state typically depends only on the last state in an agent run + + # Pack minimal, evidence-bound context for the LLM (no heavy heuristics) + history = [] + for u in window: + history.append( + { + "state": { + "done": u.state.done, + "todo": u.state.todo, + }, + "action": { + "name": u.action.name, + "description": u.action.description, + "target": u.action.target, + "tool": u.action.tool, + }, + "result": { + "type": u.result.type, + "description": u.result.description, + "exit_code": u.result.exit_code, + }, + } + ) + + system_prompt = STATE_DONE_SUMMARY_PROMPT + user_prompt = ( + "Summarize ONLY what has ALREADY BEEN COMPLETED into `state.done`.\n" + "TASK (for deriving only):\n" + "\n" + task.model_dump_json() + "\n\n\n" + "PRIOR UNITS (used ONLY for completed work evidence):\n" + "\n" + json.dumps(history, ensure_ascii=False) + "\n\n\n" + "Return the final summary paragraph ONLY (no explanations)." + ) + + try: + summary = self.llm_service.model.invoke( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + ).content + # Be lenient if the model wraps the text in quotes/newlines + return summary.strip().strip('"').strip() + except Exception: + # Minimal, defensive fallback (kept short); avoids complex heuristics + return window[-1].state.done + + def _synthesize_state_todo_from_window( + self, + window_msgs: List[Message], + task: Task, + prior_done_text: str = "", + current_action: Optional[Action] = None, + ) -> str: + """ + Synthesize state.todo from the window of messages. + """ + system_prompt = STATE_TODO_SYNTHESIS_PROMPT + user_prompt = ( + "TASK (overall purpose):\n" + "\n" + task.model_dump_json() + "\n\n\n" + "PRIOR_DONE (what has already been completed; avoid repeating it):\n" + "\n" + (prior_done_text or "") + "\n\n\n" + "WINDOW_MSGS (current window to derive the immediate intent of the next step; do NOT output tools/commands/paths):\n" + "\n" + + json.dumps([msg.model_dump() for msg in window_msgs], ensure_ascii=False) + + "\n\n\n" + "CURRENT_ACTION (use ONLY to ensure the intent refers to the same component/area; do NOT include or paraphrase any of its details):\n" + "\n" + current_action.model_dump_json() + "\n\n\n" + "Return ONLY the final paragraph." + ) + + try: + text = self.llm_service.model.invoke( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + ).content + return text.strip().strip('"').replace("\n", " ").strip() + except Exception: + # Minimal deterministic fallback using available CURRENT_ACTION fields (no speculation) + parts = [] + task_short = ( + task.issue_title or task.repository or task.issue_type or "the task" + ).strip() + parts.append(f"To progress on {task_short},") + # Build an imperative clause from action fields + verb = (current_action.name or "execute action").replace("_", " ") + tgt = f" {current_action.target}" if current_action.target else "" + via = f" via {current_action.tool}" if current_action.tool else "" + desc = f" to {current_action.description}" if current_action.description else "" + parts.append(f" {verb}{tgt}{via}{desc}.") + return " ".join(parts).replace(" ", " ").strip() + + def _extract_action_from_message(self, message: Message) -> Action: + """ + Extract action information from an assistant message that starts an action call. 
+    def _extract_action_from_message(self, message: Message) -> Action:
+        """
+        Extract action information from an assistant message that starts an action call.
+
+        This method analyzes assistant messages to identify:
+        - Action type (read_file, edit_file, run_test, etc.)
+        - Target file or resource
+        - Tool being used
+        - Action description
+
+        Args:
+            message: Assistant message containing action information
+
+        Returns:
+            Action object with extracted information
+        """
+        system_prompt = ACTION_EXTRACTION_PROMPT
+        user_prompt = (
+            "Extract the `action` from the SINGLE assistant message below.\n"
+            "Use ONLY the content between <message> and </message>.\n\n"
+            "<message>\n" + message.model_dump_json() + "\n</message>"
+        )
+
+        raw = self.llm_service.model.invoke(
+            [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ]
+        ).content
+
+        # Robust JSON parse with one repair attempt
+        payload = self._normalize_llm_json_output(raw)
+
+        action_obj = payload.get("action", {}) or {}
+
+        def s(x):
+            return x if isinstance(x, str) else ("" if x is None else str(x))
+
+        return Action(
+            name=s(action_obj.get("name")),
+            description=s(action_obj.get("description")),
+            target=s(action_obj.get("target")),
+            tool=s(action_obj.get("tool")),
+        )
+
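+    # Sketch of the result contract enforced below: the LLM is expected to return
+    # JSON like
+    #   {"result": {"type": "failure", "description": "...", "exit_code": 1}}
+    # where any type outside {success, failure, partial, unknown} is coerced to
+    # "unknown", and the regex fallback only ever inspects the window's last message.
+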
+ ) + + try: + raw = self.llm_service.model.invoke( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + ).content + payload = self._normalize_llm_json_output(raw) + + r = (payload or {}).get("result", {}) or {} + + # Utility: convert to string + def s(x): + return x if isinstance(x, str) else ("" if x is None else str(x)) + + # Result type correction + r_type = s(r.get("type")).lower() + if r_type not in {"success", "failure", "partial", "unknown"}: + r_type = "unknown" + + # Text: one sentence, remove wrapping quotes + description = s(r.get("description")).strip().strip('"') + if not description: + description = "(none)" + + # exit_code: allow returning number or number string + exit_raw = r.get("exit_code") + exit_code = None + if ( + isinstance(exit_raw, (int, float)) + and str(int(exit_raw)) == str(exit_raw).split(".")[0] + ): + exit_code = int(exit_raw) + elif isinstance(exit_raw, str) and re.fullmatch(r"-?\d+", exit_raw.strip()): + exit_code = int(exit_raw.strip()) + + return Result(type=r_type, description=description, exit_code=exit_code) + + except Exception: + # Minimal fallback: only based on the last message (no LLM call) + src = (getattr(last_msg, "content", "") or "").strip() + lt = src.lower() + + if re.search(r"(error|exception|traceback|failed)", lt): + rtype = "failure" + elif re.search(r"(exit code\s*0|succeeded|completed|all tests passed|passed\b)", lt): + rtype = "success" + elif re.search(r"(some tests failed|failures?:\s*[1-9]|errors?:\s*[1-9])", lt): + rtype = "partial" + else: + rtype = "unknown" + + # Extract the last non-empty line as one sentence result, and truncate to ~60 words + lines = [ln.strip() for ln in src.splitlines() if ln.strip()] + summary = lines[-1] if lines else "(none)" + words = summary.split() + if len(words) > 60: + summary = " ".join(words[:60]) + + m = re.search(r"exit code\s+(-?\d+)", src, flags=re.I) + code = int(m.group(1)) if m else None + + return Result(type=rtype, description=summary, exit_code=code) + + def _normalize_llm_bool_output(self, raw: str) -> bool: + """ + Normalize raw LLM output into a boolean. + Accepts variations like True/False, yes/no, 1/0, + possibly wrapped in quotes, code fences, or punctuation. + """ + text = raw.strip().lower() + + # remove common wrappers + text = text.strip(" \t\n\r\"'`") + if text.startswith("```"): + text = text.strip("`").strip() + + # direct boolean keywords + if text in {"true", "yes", "y", "1"}: + return True + if text in {"false", "no", "n", "0"}: + return False + + # fallback: fuzzy match + if text.startswith("true"): + return True + if text.startswith("false"): + return False + + return False + + def _normalize_llm_json_output(self, raw: str) -> Dict[str, Any]: + """ + Normalize raw LLM output into a JSON object (dict-like). + Single-argument API. All robustness heuristics are internal. 
+ + Strategy (built-in, no extra args): + - Clean BOM/zero-width chars/HTML entities + - Strip markdown code fences (``` / ```json) + - Try direct json.loads + - If it is a quoted JSON string, unquote then parse + - Extract the first balanced JSON object/array (quote-aware) + - Apply limited trailing comma fixes + - Minify whitespace and retry + - Optionally do ONE LLM repair round if self.llm_service is available + - Always return a dict; non-dict top-level will be wrapped as {"_root": value} + """ + EXPECT_OBJECT = True # always return Dict[str, Any] + EXPECT_KEY: Optional[str] = None # no required top-level key + MAX_SCAN_CHARS = 200_000 # cap scanning for safety + ALLOW_REPAIR_WITH_LLM = True # try one LLM repair if available + + if not isinstance(raw, str): + raise ValueError("LLM output is not a string") + text = raw.strip() + if not text: + raise ValueError("Empty LLM output") + + # Hygiene + text = self._strip_bom_zwsp(text) + text = html.unescape(text) + text = self._strip_code_fences(text).strip() + + # 1) direct parse + obj = self._try_parse_direct(text, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY) + if obj is not None: + return obj + + # 2) quoted JSON -> unquote then parse + obj = self._try_unquote_then_parse(text, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY) + if obj is not None: + return obj + + # 3) extract first balanced JSON block (object/array), quote-aware + candidate = self._extract_first_json_balanced(text[:MAX_SCAN_CHARS]) + if candidate: + obj = self._try_parse_direct( + candidate, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY + ) + if obj is not None: + return obj + fixed = self._fix_trailing_commas(candidate) + obj = self._try_parse_direct(fixed, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY) + if obj is not None: + return obj + + # 4) minify whitespace and retry + compact = self._minify_ws(candidate or text) + obj = self._try_parse_direct(compact, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY) + if obj is not None: + return obj + + # 5) optional single LLM repair (if service available) + if ALLOW_REPAIR_WITH_LLM and getattr(self, "llm_service", None) is not None: + repaired = self._repair_json_with_llm(text, expect_key=EXPECT_KEY) + obj = self._try_parse_direct( + repaired, expect_object=EXPECT_OBJECT, expect_key=EXPECT_KEY + ) + if obj is not None: + return obj + + snippet = (text[:800] + "...") if len(text) > 800 else text + raise ValueError(f"Failed to parse JSON from LLM output. 
Snippet: {snippet}") + + def _try_parse_direct( + self, s: str, *, expect_object: bool, expect_key: Optional[str] + ) -> Optional[Dict[str, Any]]: + """Try json.loads; coerce non-dict to {'_root': value}; enforce expect_key if given.""" + try: + val = json.loads(s) + except Exception: + return None + if expect_object and not isinstance(val, dict): + val = {"_root": val} + elif not isinstance(val, dict): + val = {"_root": val} + if expect_key and expect_key not in val: + return None + return val + + def _try_unquote_then_parse( + self, s: str, *, expect_object: bool, expect_key: Optional[str] + ) -> Optional[Dict[str, Any]]: + """If s is a JSON string containing JSON (e.g., '"{...}"'), unquote then parse.""" + try: + inner = json.loads(s) + except Exception: + return None + if isinstance(inner, str): + return self._try_parse_direct(inner, expect_object=expect_object, expect_key=expect_key) + if expect_object and not isinstance(inner, dict): + inner = {"_root": inner} + elif not isinstance(inner, dict): + inner = {"_root": inner} + if expect_key and expect_key not in inner: + return None + return inner + + def _strip_code_fences(self, s: str) -> str: + """Remove markdown code fences and leading 'json' language tag.""" + t = s.strip() + if t.startswith("```"): + t = re.sub(r"^\s*```(?:json)?\s*", "", t, flags=re.I) + t = re.sub(r"\s*```\s*$", "", t) + t = re.sub(r"```(?:json)?\s*([\s\S]*?)\s*```", r"\1", t, flags=re.I) + return t + + def _strip_bom_zwsp(self, s: str) -> str: + """Remove BOM and zero-width characters that can break parsers.""" + s = s.lstrip("\ufeff") + return re.sub(r"[\u200B-\u200D\u2060\uFEFF]", "", s) + + def _minify_ws(self, s: str) -> str: + """Collapse excessive newlines and surrounding whitespace (keeps spaces inside strings).""" + return re.sub(r"\s*\n\s*", " ", s).strip() + + def _fix_trailing_commas(self, s: str) -> str: + """Limited safe fix: ',}' -> '}' and ',]' -> ']'.""" + t = re.sub(r",\s*}", "}", s) + t = re.sub(r",\s*]", "]", t) + return t + + def _extract_first_json_balanced(self, text: str) -> str: + """Find the first balanced {...} or [...] block (quote/escape-aware).""" + for opener, closer in (("{", "}"), ("[", "]")): + block = self._scan_balanced(text, opener, closer) + if block: + return block + return "" + + def _scan_balanced(self, t: str, opener: str, closer: str) -> str: + """Quote/escape-aware balanced scanner to locate a JSON block.""" + start = t.find(opener) + if start == -1: + return "" + depth = 0 + in_str = False + esc = False + quote_char = "" + for i in range(start, len(t)): + ch = t[i] + if in_str: + if esc: + esc = False + elif ch == "\\": + esc = True + elif ch == quote_char: + in_str = False + continue + if ch in ('"', "'"): + in_str = True + quote_char = ch + esc = False + continue + if ch == opener: + depth += 1 + elif ch == closer: + depth -= 1 + if depth == 0: + return t[start : i + 1] + return "" + + def _repair_json_with_llm(self, raw: str, expect_key: Optional[str]) -> str: + """One-shot LLM repair to return MINIFIED valid JSON only (no prose/fences).""" + constraint = f" It MUST contain the top-level key '{expect_key}'." if expect_key else "" + try: + repaired = self.llm_service.model.invoke( + [ + { + "role": "system", + "content": ( + "You are a JSON normalizer. " + "Return ONLY a valid MINIFIED JSON object. " + "No explanations, no code fences, no comments." 
+ constraint + ), + }, + {"role": "user", "content": raw}, + ] + ) + if isinstance(repaired, str): + return repaired.strip() + content = getattr(repaired, "content", None) + if isinstance(content, str): + return content.strip() + return str(repaired).strip() + except Exception: + return raw + + +if __name__ == "__main__": + service = MemoryExtractionService( + llm_service=LLMService( + model_name="vertex:gemini-2.5-flash", + model_temperature=0.0, + model_max_input_tokens=8192, + model_max_output_tokens=8192, + ) + ) + service.extract_from_huggingface_trajectory_repository( + repo_name="SWE-Gym/OpenHands-SFT-Trajectories", + split="train.success.oss", + ) diff --git a/athena/configuration/config.py b/athena/configuration/config.py index 62366ed..7c169ac 100644 --- a/athena/configuration/config.py +++ b/athena/configuration/config.py @@ -40,5 +40,11 @@ class Settings(BaseSettings): MEMORY_MAX_STORED_UNITS: int = 1000 # Maximum number of memory units to store MEMORY_SEARCH_LIMIT: int = 10 # Default search result limit + # Embeddings (OpenAI-format, works with Codestral embed via a compatible gateway) + EMBEDDINGS_MODEL: Optional[str] = None + EMBEDDINGS_API_KEY: Optional[str] = None + EMBEDDINGS_BASE_URL: Optional[str] = None + EMBEDDINGS_DIM: Optional[int] = None + settings = Settings() diff --git a/athena/entity/__init__.py b/athena/entity/__init__.py new file mode 100644 index 0000000..f5b71d1 --- /dev/null +++ b/athena/entity/__init__.py @@ -0,0 +1,3 @@ +from .memory import MemoryUnitDB + +__all__ = ["MemoryUnitDB"] diff --git a/athena/entity/memory.py b/athena/entity/memory.py new file mode 100644 index 0000000..a816010 --- /dev/null +++ b/athena/entity/memory.py @@ -0,0 +1,147 @@ +import json +from datetime import datetime +from typing import Optional + +from pgvector.sqlalchemy import Vector +from sqlalchemy import Column, DateTime, Text +from sqlmodel import Field, SQLModel, UniqueConstraint + +from athena.models import Action, MemorySource, MemoryTimestamp, MemoryUnit, Result, State, Task + + +# Database models for persistent storage +class MemoryUnitDB(SQLModel, table=True): + """Database model for persistent storage of memory units.""" + + __tablename__ = "memory_units" + __table_args__ = (UniqueConstraint("memory_id", name="uq_memory_id"),) + + id: Optional[int] = Field(default=None, primary_key=True) + + # Source information + memory_id: str + memory_source_name: str + memory_run_id: str + memory_created_at: datetime = Field(sa_column=Column(DateTime(timezone=True))) + memory_updated_at: Optional[datetime] = Field( + default=None, sa_column=Column(DateTime(timezone=True)) + ) + memory_invalid_at: Optional[datetime] = Field( + default=None, sa_column=Column(DateTime(timezone=True)) + ) + memory_metadata: str = Field( + default="{}", sa_column=Column(Text) + ) # JSON string for Dict[str, Any] + + # Task information + task_issue_title: str = Field(sa_column=Column(Text)) + task_issue_body: str = Field(sa_column=Column(Text)) + task_issue_comments: str = Field(sa_column=Column(Text)) + task_issue_type: str + task_repository: str + + # State information + state_done: str = Field(sa_column=Column(Text)) + state_todo: str = Field(sa_column=Column(Text)) + state_open_file: Optional[str] = None + state_working_dir: Optional[str] = None + state_extra_environment: str = Field( + default="{}", sa_column=Column(Text) + ) # JSON string for Dict[str, str] + + # Action information + action_name: str + action_description: str = Field(sa_column=Column(Text)) + action_target: str + action_tool: 
str + + # Result information + result_type: str + result_description: str = Field(sa_column=Column(Text)) + result_exit_code: Optional[int] = None + + # Embeddings for semantic retrieval (optional, pgvector recommended). Stored as float arrays + task_embedding: Optional[list[float]] = Field(default=None, sa_column=Column(Vector())) + state_embedding: Optional[list[float]] = Field(default=None, sa_column=Column(Vector())) + task_state_embedding: Optional[list[float]] = Field(default=None, sa_column=Column(Vector())) + + @classmethod + def from_memory_unit(cls, memory_unit: MemoryUnit) -> "MemoryUnitDB": + """Create a database model from a MemoryUnit.""" + return cls( + memory_id=memory_unit.memory_id, + memory_source_name=memory_unit.source.source_name, + memory_run_id=memory_unit.source.run_id, + memory_created_at=memory_unit.timestamp.created_at, + memory_updated_at=memory_unit.timestamp.updated_at, + memory_invalid_at=memory_unit.timestamp.invalid_at, + memory_metadata=json.dumps(memory_unit.source.metadata) + if memory_unit.source.metadata + else "{}", + task_issue_title=memory_unit.task.issue_title, + task_issue_body=memory_unit.task.issue_body, + task_issue_comments=memory_unit.task.issue_comments, + task_issue_type=memory_unit.task.issue_type, + task_repository=memory_unit.task.repository, + state_done=memory_unit.state.done, + state_todo=memory_unit.state.todo, + state_open_file=memory_unit.state.open_file, + state_working_dir=memory_unit.state.working_dir, + state_extra_environment=json.dumps(memory_unit.state.extra_environment) + if memory_unit.state.extra_environment + else "{}", + action_name=memory_unit.action.name, + action_description=memory_unit.action.description, + action_target=memory_unit.action.target, + action_tool=memory_unit.action.tool, + result_type=memory_unit.result.type, + result_description=memory_unit.result.description, + result_exit_code=memory_unit.result.exit_code, + ) + + def to_memory_unit(self) -> MemoryUnit: + """Convert database model back to MemoryUnit.""" + return MemoryUnit( + memory_id=self.memory_id, + timestamp=MemoryTimestamp( + created_at=self.memory_created_at, + updated_at=self.memory_updated_at, + invalid_at=self.memory_invalid_at, + ), + source=MemorySource( + source_name=self.memory_source_name, + run_id=self.memory_run_id, + metadata=json.loads(self.memory_metadata) + if self.memory_metadata not in (None, "", "null") + else {}, + ), + task=Task( + issue_title=self.task_issue_title, + issue_body=self.task_issue_body, + issue_comments=self.task_issue_comments, + issue_type=self.task_issue_type, + repository=self.task_repository, + ), + state=State( + done=self.state_done, + todo=self.state_todo, + open_file=self.state_open_file, + working_dir=self.state_working_dir, + extra_environment=json.loads(self.state_extra_environment) + if self.state_extra_environment not in (None, "", "null") + else {}, + ), + action=Action( + name=self.action_name, + description=self.action_description, + target=self.action_target, + tool=self.action_tool, + ), + result=Result( + type=self.result_type + if self.result_type in ["success", "failure", "partial", "unknown"] + else "unknown", + description=self.result_description, + exit_code=self.result_exit_code, + ), + ) diff --git a/athena/models/__init__.py b/athena/models/__init__.py index 7d32594..5228135 100644 --- a/athena/models/__init__.py +++ b/athena/models/__init__.py @@ -1,9 +1,8 @@ from .memory import ( Action, - MemoryContext, + MemorySource, MemoryTimestamp, MemoryUnit, - MemoryUnitDB, Result, State, 
Task, @@ -13,11 +12,10 @@ __all__ = [ "Message", "MemoryUnit", - "MemoryContext", + "MemorySource", "MemoryTimestamp", "Task", "State", "Action", "Result", - "MemoryUnitDB", ] diff --git a/athena/models/memory.py b/athena/models/memory.py index 11fe613..4776665 100644 --- a/athena/models/memory.py +++ b/athena/models/memory.py @@ -1,17 +1,17 @@ -import json import uuid -from datetime import datetime +from datetime import datetime, timezone from typing import Any, Dict, Literal, Optional from pydantic import BaseModel, Field -from sqlalchemy import Column, DateTime, Text -from sqlmodel import SQLModel, UniqueConstraint class MemoryTimestamp(BaseModel): """Lifecycle timestamps for a memory unit.""" - created_at: datetime = Field(None, description="When the memory unit was first created") + created_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + description="When the memory unit was first created", + ) updated_at: Optional[datetime] = Field( None, description="When the memory was last updated/refreshed" ) @@ -20,18 +20,13 @@ class MemoryTimestamp(BaseModel): ) -class MemoryContext(BaseModel): - """Context metadata for a memory unit.""" +class MemorySource(BaseModel): + """Source information for a memory unit.""" - memory_id: str = Field( - default_factory=lambda: str(uuid.uuid4()), - description="Unique identifier for this memory unit", - ) - source: str = Field( + source_name: str = Field( ..., description="Memory source, e.g., agent name, model name, dataset name, or file path" ) run_id: str = Field(..., description="Unique id for agent run or trajectory dataset") - timestamp: MemoryTimestamp = Field(..., description="Timestamp of the memory") metadata: Dict[str, Any] = Field( default_factory=dict, description="Optional extension field for custom metadata" ) @@ -103,148 +98,22 @@ class MemoryUnit(BaseModel): Core memory unit capturing one action of agent execution. 
This includes: - - The memory context (memory_id, source, run_id, timestamps, metadata) + - The memory id (memory_id) + - The memory timestamp (timestamp) + - The memory source (source_name, run_id, metadata) - The task being worked on (issue and repository details) - The current state (what's done, what's todo) - The action taken by the agent - The result of the action """ - context: MemoryContext + memory_id: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Unique identifier for this memory unit", + ) + timestamp: MemoryTimestamp = Field(..., description="Timestamp of the memory") + source: MemorySource task: Task state: State action: Action result: Result - - -# Database models for persistent storage -class MemoryUnitDB(SQLModel, table=True): - """Database model for persistent storage of memory units.""" - - __tablename__ = "memory_units" - __table_args__ = UniqueConstraint("memory_id", name="uq_memory_id") - - id: Optional[int] = Field(default=None, primary_key=True) - - # Context information - memory_id: str - memory_source: str - memory_run_id: str - memory_created_at: datetime = Field(sa_column=Column(DateTime(timezone=True))) - memory_updated_at: Optional[datetime] = Field( - default=None, sa_column=Column(DateTime(timezone=True)) - ) - memory_invalid_at: Optional[datetime] = Field( - default=None, sa_column=Column(DateTime(timezone=True)) - ) - memory_metadata: str = Field( - default="{}", sa_column=Column(Text) - ) # JSON string for Dict[str, Any] - - # Task information - task_issue_title: str = Field(sa_column=Column(Text)) - task_issue_body: str = Field(sa_column=Column(Text)) - task_issue_comments: str = Field(sa_column=Column(Text)) - task_issue_type: str - task_repository: str - - # State information - state_done: str = Field(sa_column=Column(Text)) - state_todo: str = Field(sa_column=Column(Text)) - state_open_file: Optional[str] = None - state_working_dir: Optional[str] = None - state_extra_environment: str = Field( - default="{}", sa_column=Column(Text) - ) # JSON string for Dict[str, str] - - # Action information - action_name: str - action_description: str = Field(sa_column=Column(Text)) - action_target: str - action_tool: str - - # Result information - result_type: str - result_description: str = Field(sa_column=Column(Text)) - result_exit_code: Optional[int] = None - - @classmethod - def from_memory_unit(cls, memory_unit: MemoryUnit) -> "MemoryUnitDB": - """Create a database model from a MemoryUnit.""" - return cls( - memory_id=memory_unit.context.memory_id, - memory_source=memory_unit.context.source, - memory_run_id=memory_unit.context.run_id, - memory_created_at=memory_unit.context.timestamp.created_at, - memory_updated_at=memory_unit.context.timestamp.updated_at, - memory_invalid_at=memory_unit.context.timestamp.invalid_at, - memory_metadata=json.dumps(memory_unit.context.metadata) - if memory_unit.context.metadata - else "{}", - task_issue_title=memory_unit.task.issue_title, - task_issue_body=memory_unit.task.issue_body, - task_issue_comments=memory_unit.task.issue_comments, - task_issue_type=memory_unit.task.issue_type, - task_repository=memory_unit.task.repository, - state_done=memory_unit.state.done, - state_todo=memory_unit.state.todo, - state_open_file=memory_unit.state.open_file, - state_working_dir=memory_unit.state.working_dir, - state_extra_environment=json.dumps(memory_unit.state.extra_environment) - if memory_unit.state.extra_environment - else "{}", - action_name=memory_unit.action.name, - 
action_description=memory_unit.action.description, - action_target=memory_unit.action.target, - action_tool=memory_unit.action.tool, - result_type=memory_unit.result.type, - result_description=memory_unit.result.description, - result_exit_code=memory_unit.result.exit_code, - ) - - def to_memory_unit(self) -> MemoryUnit: - """Convert database model back to MemoryUnit.""" - return MemoryUnit( - context=MemoryContext( - memory_id=self.memory_id, - source=self.memory_source, - run_id=self.memory_run_id, - timestamp=MemoryTimestamp( - created_at=self.memory_created_at, - updated_at=self.memory_updated_at, - invalid_at=self.memory_invalid_at, - ), - metadata=json.loads(self.memory_metadata) - if self.memory_metadata not in (None, "", "null") - else {}, - ), - task=Task( - issue_title=self.task_issue_title, - issue_body=self.task_issue_body, - issue_comments=self.task_issue_comments, - issue_type=self.task_issue_type, - repository=self.task_repository, - ), - state=State( - done=self.state_done, - todo=self.state_todo, - open_file=self.state_open_file, - working_dir=self.state_working_dir, - extra_environment=json.loads(self.state_extra_environment) - if self.state_extra_environment not in (None, "", "null") - else {}, - ), - action=Action( - name=self.action_name, - description=self.action_description, - target=self.action_target, - tool=self.action_tool, - ), - result=Result( - type=self.result_type - if self.result_type in ["success", "failure", "partial", "unknown"] - else "unknown", - description=self.result_description, - exit_code=self.result_exit_code, - ), - ) diff --git a/athena/models/message.py b/athena/models/message.py index 5e4f2c4..4b66896 100644 --- a/athena/models/message.py +++ b/athena/models/message.py @@ -12,6 +12,7 @@ "sys": "system", "system_prompt": "system", "sysmsg": "system", + "human": "user", } diff --git a/athena/prompts/__init__.py b/athena/prompts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/init_memory_base/prompt.py b/athena/prompts/memory_extraction.py similarity index 84% rename from init_memory_base/prompt.py rename to athena/prompts/memory_extraction.py index 3ce2a80..1238991 100644 --- a/init_memory_base/prompt.py +++ b/athena/prompts/memory_extraction.py @@ -120,6 +120,41 @@ """ +ACTION_JUDGE_PROMPT = """ +You are a STRICT binary classifier for assistant messages. Decide whether the message initiates or executes a concrete ACTION in an agent environment. + +## Decision Target +Return **True** if the message triggers, schedules, or immediately proceeds to perform an operation — even when the operation is written in natural language or embedded in a code/markdown block. + +### ACTION indicators (non-exhaustive; if ambiguous, prefer **True**) +- File/FS ops: create/open/read/write/edit/delete/rename/move/copy, save, apply patch, replace/insert, scroll_up/scroll_down, navigate. +- Search/explore ops: find_file/grep/search/locate/list (files/dirs/symbols), project-wide search. +- Execution/invocation: run/execute/launch/call/invoke a tool or command (pytest/python/bash/node/npm/pip/curl/sql/git), function_call objects, start/stop processes, HTTP requests. +- Editor/tool calls: str_replace_editor.*, execute_bash, tool:{...}, Action: , ..., OpenAI function_call objects, XML/JSON tool blocks. +- Repair/fix/change steps: "Let's fix...", "Edit the constructor/method...", "Implement changes", "Apply the patch". 
+- Committed immediate steps in first person or imperative: "Now, let's ...", "First, I'll create ...", "Next, run ...", "Proceed to ...". +- Fenced code/markdown blocks that contain actionable commands (even if preceded or followed by explanations), e.g.: + - ```find_file "dense_ndim_array.py"``` + - ```scroll_down``` + - ```create reproduce_bug.py``` + - ```pytest -q``` + - ```python reproduce_error.py``` + +### NOT actions (return **False**) — only if clearly none of the above: +- Pure analysis, planning, or summarization with no concrete operation. +- Pure thought process or reasoning steps without executing/triggering any operation. +- Questions, confirmations, or instructions to a human with no tool/command to be executed. +- Code shown purely for illustration (no intent to execute/apply/edit), e.g., "Here is an example snippet:" followed by code not framed as an immediate step. + +### Ambiguity policy +If the case is ambiguous, **prefer True** (maximize recall for action messages). + +## Output +Output EXACTLY one token on a single line: True or False +No punctuation, quotes, or explanations. +""" + + ACTION_EXTRACTION_PROMPT = """ Extract ONLY the `action` object from ONE assistant message that starts a tool/action call. diff --git a/init_memory_base/dataset_pre.py b/init_memory_base/dataset_pre.py deleted file mode 100644 index 858608c..0000000 --- a/init_memory_base/dataset_pre.py +++ /dev/null @@ -1,10 +0,0 @@ -from datasets import load_dataset - -dataset = load_dataset("SWE-Gym/OpenHands-SFT-Trajectories", split="train.success.oss") -dataset = load_dataset("ZhengyanShi/Poutine-OpenHands-SFT-Trajectories", split="train") -dataset = load_dataset("SWE-Gym/OpenHands-Sampled-Trajectories", split="train.raw") -dataset = load_dataset("SWE-bench/SWE-smith-trajectories") -dataset = load_dataset("nebius/SWE-agent-trajectories", split="train") -dataset = load_dataset("zai-org/SWE-Dev-train") -dataset = load_dataset("JetBrains-Research/swe-traj-complete", split="train") -# git clone https://huggingface.co/datasets/ByteDance-Seed/Multi-SWE-bench_trajs diff --git a/init_memory_base/prompts/code_skeleton.txt b/init_memory_base/prompts/code_skeleton.txt deleted file mode 100644 index 1b9b25a..0000000 --- a/init_memory_base/prompts/code_skeleton.txt +++ /dev/null @@ -1,31 +0,0 @@ -You are an expert software engineering assistant. -Given an **observation** about a coding task and the **source code** of a function/class, -your goal is to extract a **code skeleton** that highlights only the focused regions relevant to the observation. - -### Instructions -1. Keep only the **main import statements** at the top of the file: - - Keep imports that are directly used in the focused regions. - - Keep imports that define critical dependencies (e.g., standard libraries like `os`, `sys`, or key frameworks like `torch`, `pandas`). - - Collapse all other imports into a single line: `... # other imports omitted`. -2. Keep class and function headers **with full signatures**: - - Include parameter names and types (if available). - - Include return type or return value description (infer if not explicit; keep concise, e.g., "-> bool"). -3. Preserve necessary surrounding structure for readability (class/def blocks, braces/indentation). -4. Inside classes/functions: - - Keep only lines relevant to the current observation (**focused logic** and **key invocations**). - - Replace non-relevant lines with a single line containing `...`. -5. Provide ±3 lines of context around each focused region, if available. -6. 
Do **not** add region markers or any textual summary. Output only the code skeleton. -7. Do not introduce new identifiers, logic, control flow, or API calls that are not present in the source. -8. Do not infer types/return types that are not explicit; if unknown, use -> Unknown or omit. -9. Output one fenced code block with correct language tag. No prose outside. - -### Input -Observation: - - -Source Code: - - -### Output (Code Skeleton Only) - diff --git a/init_memory_base/prompts/text_summary.txt b/init_memory_base/prompts/text_summary.txt deleted file mode 100644 index f37456a..0000000 --- a/init_memory_base/prompts/text_summary.txt +++ /dev/null @@ -1,31 +0,0 @@ -You are an expert software engineering assistant. -Given an **observation** about a coding task and the **source code** of a function/class, -your goal is to extract a **concise textual summary** that captures the essential intent, behavior, and role of the code in relation to the observation. - -### Instructions -1. **Focus on relevance**: Summaries must highlight aspects of the code most related to the observation (e.g., functionality, data flow, critical logic). -2. **Content to include**: - - The **purpose** of the function/class (what it does, high-level intent). - - The **inputs/outputs** (parameters, return values, side effects). - - The **key operations or algorithms** performed (control flow, API/library usage). - - Any **dependencies** or notable relationships (calls to other functions/classes). -3. **Conciseness**: - - Write in **2–4 sentences**. - - Avoid unnecessary details (variable names, trivial operations). - - Use plain, precise technical language. -4. **No speculation**: - - If the code’s intent is unclear, state it as *“uncertain”* rather than guessing. - - Do not introduce functionality that is not evident in the code. -5. **Output format**: - - One concise **paragraph of plain text only** (no code, no bullet points, no extra markup). - - Do not repeat the observation verbatim. 
- -### Input -Observation: - - -Source Code: - - -### Output (Textual Summary Only) - diff --git a/init_memory_base/trajectory_memory_extractor_deepseek.py b/init_memory_base/trajectory_memory_extractor_deepseek.py deleted file mode 100644 index 5f084dd..0000000 --- a/init_memory_base/trajectory_memory_extractor_deepseek.py +++ /dev/null @@ -1,692 +0,0 @@ -import hashlib -import json -import os -import random -import re -import time -from typing import Dict, Iterable, List, Literal, Optional, Protocol - -import requests -from datasets import load_dataset -from prompt import ( - ACTION_EXTRACTION_PROMPT, - RESULT_EXTRACTION_PROMPT, - STATE_DONE_SUMMARY_PROMPT, - STATE_TODO_SYNTHESIS_PROMPT, - TASK_EXTRACTION_PROMPT, -) -from pydantic import BaseModel, Field -from tqdm import tqdm - -# ============================= -# 1) Pydantic data structures -# ============================= - -# memory_unit_schema = { -# "trajectory": { -# "source": str, -# "id": str, -# }, -# "task": { -# "issue_title": str, -# "issue_body": str, -# "issue_comments": str, -# "issue_type": str, # e.g., bug, feature, documentation, question -# "repository": str, -# }, -# "state": { -# "done": str, # summaries of the completed work -# "todo": str, # the work to be done, decided on the next action -# # "accessed_code": str, -# }, -# "action": { -# "name": str, # e.g., read_file, edit_file, run_test, invoke_tool -# "description": str, -# "target": str, # e.g., utils/math.py, tests/test_math.py, pytest, git -# "tool": str, # e.g., pytest, git, search -# }, -# "result": { -# "type": str, # success | failure | partial | unknown -# "description": str, -# "exit_code": str, -# }, -# } - - -class TrajectoryMessage(BaseModel): - content: str - role: str # "system" | "user" | "assistant" | tool etc. - # TODO: chech style of trajectory message - - -class TrajectoryMeta(BaseModel): - source: str = Field(..., description="e.g., dataset name, run id, or file path") - id: str = Field(..., description="unique id for the trajectory") - - -class Task(BaseModel): - issue_title: str - issue_body: str - issue_comments: str - issue_type: str # bug | feature | documentation | question | other - repository: str - - -class State(BaseModel): - done: str - todo: str - - -class Action(BaseModel): - name: str # read_file | edit_file | run_test | invoke_tool | etc. - description: str - target: str # utils/math.py | tests/test_math.py | pytest | git - tool: str # pytest | git | search | editor | bash - - -ResultType = Literal["success", "failure", "partial", "unknown"] - - -class Result(BaseModel): - type: ResultType = "unknown" # success | failure | partial | unknown - description: str = "" # one-sentence result summary (for human) - exit_code: Optional[int] = None # if extractable from logs - - -class MemoryUnit(BaseModel): - trajectory: TrajectoryMeta - task: Task - state: State - action: Action - result: Result - - def key(self) -> str: - # Canonical key for dedup (task+state) as requested - payload = json.dumps( - { - "task": self.task.model_dump(), - "state": self.state.model_dump(), - }, - sort_keys=True, - ) - return hashlib.sha256(payload.encode("utf-8")).hexdigest() - - -# Optional: In-memory base and protocol you can replace with your own DB/Vector store -class MemoryBaseProtocol(Protocol): - def upsert(self, units: Iterable[MemoryUnit]) -> None: ... 
- - -class InMemoryMemoryBase: - def __init__(self): - self._store: Dict[str, MemoryUnit] = {} - - def upsert(self, units: Iterable[MemoryUnit]) -> None: - for u in units: - self._store[u.key()] = u - - def all(self) -> List[MemoryUnit]: - return list(self._store.values()) - - -# ============================================= -# 2) DeepSeek (OpenAI-compatible) API utilities -# ============================================= - -DEEPSEEK_API_BASE = os.environ.get("DEEPSEEK_API_BASE", "https://api.deepseek.com") -DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "") -DEEPSEEK_MODEL = os.environ.get("DEEPSEEK_MODEL", "deepseek-chat") - -HEADERS = { - "Authorization": f"Bearer {DEEPSEEK_API_KEY}", - "Content-Type": "application/json", -} - - -class RateLimit(Exception): - pass - - -_SESSION: Optional[requests.Session] = None - - -def _get_session() -> requests.Session: - global _SESSION - if _SESSION is None: - s = requests.Session() - adapter = requests.adapters.HTTPAdapter( - pool_connections=10, pool_maxsize=50, max_retries=0, pool_block=True - ) - s.mount("https://", adapter) - s.mount("http://", adapter) - _SESSION = s - return _SESSION - - -def _post_with_retries( - session: requests.Session, - url: str, - payload: dict, - *, - max_retries: int = 5, - retry_base: float = 1.25, - timeout=(10, 60), -) -> str: - if not DEEPSEEK_API_KEY: - raise RuntimeError("DEEPSEEK_API_KEY is not set.") - - for attempt in range(max_retries): - try: - headers = { - "Authorization": f"Bearer {os.environ.get('DEEPSEEK_API_KEY', '')}", - "Content-Type": "application/json", - } - resp = session.post(url, headers=headers, json=payload, timeout=timeout) - if resp.status_code == 429: - raise RateLimit("Rate limited") - resp.raise_for_status() - data = resp.json() - content = data.get("choices", [{}])[0].get("message", {}).get("content", "") - if not content: - raise ValueError("Empty content from API") - return content - - except (requests.HTTPError, requests.ConnectionError, requests.Timeout, RateLimit) as e: - if attempt == max_retries - 1: - raise - # Optional: Rebuild Session on connection errors - if isinstance(e, (requests.ConnectionError,)): - global _SESSION - _SESSION = None - session = _get_session() - time.sleep((retry_base**attempt) + random.uniform(0, 0.25)) - - -def call_deepseek( - messages: List[Dict[str, str]], - *, - temperature: float = 0.0, - max_retries: int = 5, - retry_base: float = 1.25, - timeout: int = 60, -) -> str: - """Call DeepSeek Chat Completions endpoint and return assistant text.""" - url = f"{DEEPSEEK_API_BASE.rstrip('/')}/v1/chat/completions" - payload = { - "model": DEEPSEEK_MODEL, - "messages": messages, - "temperature": temperature, - "response_format": {"type": "json_object"}, - # Optional: Limit return length to further reduce latency - # "max_tokens": 512, - # "n": 1, # number of completions to generate - } - session = _get_session() - return _post_with_retries( - session, url, payload, max_retries=max_retries, retry_base=retry_base, timeout=timeout - ) - - -# ===================================== -# 3) Orchestration over a trajectory -# ===================================== - - -def extract_result_from_last_message( - window_msgs: List[TrajectoryMessage], current_action: Action -) -> Result: - """ - Use DeepSeek to extract the `result` from the LAST message of a window. - - Evidence-bound to the last message ONLY. - - Output must be ONE plain-text sentence capturing the definitive outcome. - - Minimal heuristics; a tiny fallback is used only if the API fails. 
- - Assumes available: - - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str - - json imported - """ - if not window_msgs: - return Result(type="unknown", description="(none)") - last_msg = window_msgs[-1] - - system_prompt = RESULT_EXTRACTION_PROMPT - user_prompt = ( - "CURRENT_ACTION (reference for alignment; do NOT invent beyond LAST_MESSAGE):\n" - "\n" + current_action.model_dump_json() + "\n\n\n" - "LAST_MESSAGE (extract the definitive outcome from here ONLY):\n" - "\n" + last_msg.model_dump_json() + "\n\n\n" - "Return ONLY the JSON object." - ) - - try: - raw = call_deepseek( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - temperature=0.0, - ) - # Parse JSON; if failed, try to fix it to a valid JSON - try: - payload = json.loads(raw) - except json.JSONDecodeError: - fixed = call_deepseek( - [ - {"role": "system", "content": "Return a valid JSON object only. No comments."}, - {"role": "user", "content": raw}, - ], - temperature=0.0, - ) - payload = json.loads(fixed) - - r = (payload or {}).get("result", {}) or {} - - # Utility: convert to string - def s(x): - return x if isinstance(x, str) else ("" if x is None else str(x)) - - # Result type correction - r_type = s(r.get("type")).lower() - if r_type not in {"success", "failure", "partial", "unknown"}: - r_type = "unknown" - - # Text: one sentence, remove wrapping quotes - description = s(r.get("description")).strip().strip('"') - if not description: - description = "(none)" - - # exit_code: allow returning number or number string - exit_raw = r.get("exit_code") - exit_code = None - if isinstance(exit_raw, (int, float)) and str(int(exit_raw)) == str(exit_raw).split(".")[0]: - exit_code = int(exit_raw) - elif isinstance(exit_raw, str) and re.fullmatch(r"-?\d+", exit_raw.strip()): - exit_code = int(exit_raw.strip()) - - return Result(type=r_type, description=description, exit_code=exit_code) - - except Exception: - # Minimal fallback: only based on the last message (no LLM call) - src = (getattr(last_msg, "content", "") or "").strip() - lt = src.lower() - - if re.search(r"(error|exception|traceback|failed)", lt): - rtype = "failure" - elif re.search(r"(exit code\s*0|succeeded|completed|all tests passed|passed\b)", lt): - rtype = "success" - elif re.search(r"(some tests failed|failures?:\s*[1-9]|errors?:\s*[1-9])", lt): - rtype = "partial" - else: - rtype = "unknown" - - # Extract the last non-empty line as one sentence result, and truncate to ~60 words - lines = [ln.strip() for ln in src.splitlines() if ln.strip()] - summary = lines[-1] if lines else "(none)" - words = summary.split() - if len(words) > 60: - summary = " ".join(words[:60]) - - m = re.search(r"exit code\s+(-?\d+)", src, flags=re.I) - code = int(m.group(1)) if m else None - - return Result(type=rtype, description=summary, exit_code=code) - - -def synthesize_state_todo_paragraph( - window_msgs: List[TrajectoryMessage], - task: Task, - prior_done_text: str = "", - current_action: Optional[Action] = None, -) -> str: - """ - Call DeepSeek to synthesize `state.todo` as ONE highly condensed paragraph - capturing the intent of the CURRENT message window. 
-
-
-def synthesize_state_todo_paragraph(
-    window_msgs: List[TrajectoryMessage],
-    task: Task,
-    prior_done_text: str = "",
-    current_action: Optional[Action] = None,
-) -> str:
-    """
-    Call DeepSeek to synthesize `state.todo` as ONE highly condensed paragraph
-    capturing the intent of the CURRENT message window.
-
-    Dependencies assumed available:
-    - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str
-    - json imported
-    """
-
-    system_prompt = STATE_TODO_SYNTHESIS_PROMPT
-    user_prompt = (
-        "TASK (overall purpose):\n"
-        "\n" + task.model_dump_json() + "\n\n\n"
-        "PRIOR_DONE (what has already been completed; avoid repeating it):\n"
-        "\n" + (prior_done_text or "") + "\n\n\n"
-        "WINDOW_MSGS (current window to derive the immediate intent of the next step; do NOT output tools/commands/paths):\n"
-        "\n"
-        + json.dumps([msg.model_dump() for msg in window_msgs], ensure_ascii=False)
-        + "\n\n\n"
-        "CURRENT_ACTION (use ONLY to ensure the intent refers to the same component/area; do NOT include or paraphrase any of its details):\n"
-        "\n" + current_action.model_dump_json() + "\n\n\n"
-        "Return ONLY the final paragraph."
-    )
-
-    try:
-        text = call_deepseek(
-            [
-                {"role": "system", "content": system_prompt},
-                {"role": "user", "content": user_prompt},
-            ],
-            temperature=0.0,
-        )
-        return text.strip().strip('"').replace("\n", " ").strip()
-    except Exception:
-        # Minimal deterministic fallback using available CURRENT_ACTION fields (no speculation).
-        parts = []
-        task_short = (task.issue_title or task.repository or task.issue_type or "the task").strip()
-        parts.append(f"To progress on {task_short},")
-        # Build an imperative clause from the action fields.
-        verb = (current_action.name or "execute action").replace("_", " ")
-        tgt = f" {current_action.target}" if current_action.target else ""
-        via = f" via {current_action.tool}" if current_action.tool else ""
-        desc = f" to {current_action.description}" if current_action.description else ""
-        parts.append(f" {verb}{tgt}{via}{desc}.")
-        return " ".join(parts).replace("  ", " ").strip()
-
-
-def extract_action_from_first_action_msg(first_action_msg: TrajectoryMessage) -> Action:
-    """
-    Use DeepSeek to extract the `action` object from a SINGLE assistant message
-    that starts an action call. Minimal heuristics; evidence-bound to the message.
-    Dependencies assumed available:
-    - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str
-    - Action Pydantic model
-    - json imported
-    """
-    system_prompt = ACTION_EXTRACTION_PROMPT
-    user_prompt = (
-        "Extract the `action` from the SINGLE assistant message below.\n"
-        "Use ONLY the content between and .\n\n"
-        "\n" + first_action_msg.model_dump_json() + "\n"
-    )
-
-    messages = [
-        {"role": "system", "content": system_prompt},
-        {"role": "user", "content": user_prompt},
-    ]
-
-    raw = call_deepseek(messages, temperature=0.0)
-
-    # Robust JSON parse with one repair attempt.
-    try:
-        payload = json.loads(raw)
-    except json.JSONDecodeError:
-        fixed = call_deepseek(
-            [
-                {"role": "system", "content": "Return a valid JSON object only. No comments."},
-                {"role": "user", "content": raw},
-            ],
-            temperature=0.0,
-        )
-        payload = json.loads(fixed)
-
-    action_obj = payload.get("action", {}) or {}
-
-    def s(x):
-        return x if isinstance(x, str) else ("" if x is None else str(x))
-
-    return Action(
-        name=s(action_obj.get("name")),
-        description=s(action_obj.get("description")),
-        target=s(action_obj.get("target")),
-        tool=s(action_obj.get("tool")),
-    )
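-
-
-# Illustrative payload (hypothetical data) of the `action` object that
-# `extract_action_from_first_action_msg` expects from ACTION_EXTRACTION_PROMPT;
-# every field is normalized to a string before constructing the Action model.
-_EXAMPLE_ACTION_PAYLOAD = {
-    "action": {
-        "name": "run_tests",
-        "description": "run the test suite to reproduce the reported failure",
-        "target": "tests/test_parser.py",
-        "tool": "bash",
-    }
-}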
-
-
-def finalize_window_with_llm(
-    window_msgs: List[TrajectoryMessage],
-    first_action_msg: TrajectoryMessage,
-    prior_units: List[MemoryUnit],
-    meta: TrajectoryMeta,
-    task: Task,
-) -> Optional[MemoryUnit]:
-    """
-    Extract a single memory unit from a window of messages.
-    - Synthesize state.done from prior actions
-    - Synthesize state.todo by extracting intents from window_msgs
-    - Extract action from first_action_msg
-    - Extract result from the last message in window_msgs
-    """
-    action = extract_action_from_first_action_msg(first_action_msg)
-    state_done = summarize_context_from_units(prior_units, task)
-    state_todo = synthesize_state_todo_paragraph(window_msgs, task, state_done, action)
-    result = extract_result_from_last_message(window_msgs, action)
-    return MemoryUnit(
-        trajectory=meta,
-        task=task,
-        state=State(
-            done=state_done,
-            todo=state_todo,
-        ),
-        action=action,
-        result=result,
-    )
-
-
-def is_action_start(msg: TrajectoryMessage) -> bool:
-    """Return True if this message starts an action call, else False."""
-    if getattr(msg, "role", "") != "assistant":
-        return False
-    content = getattr(msg, "content", "") or ""
-    return bool(re.search(r"", content, flags=re.DOTALL))
-    # TODO: content check, different trajectory style
-
-
-def extract_memory_units_by_action_windows(
-    meta: TrajectoryMeta,
-    task: Task,
-    trajectory: List[TrajectoryMessage],
-    *,
-    memory_base: Optional[MemoryBaseProtocol] = None,
-) -> List[MemoryUnit]:
-    ordered_memory_units: List[MemoryUnit] = []
-    window_msgs: List[TrajectoryMessage] = []
-    window_first_action: Optional[TrajectoryMessage] = None
-
-    for msg in tqdm(trajectory, desc="Extracting memory units"):
-        if is_action_start(msg):
-            if window_msgs and window_first_action is not None:
-                mu = finalize_window_with_llm(
-                    window_msgs,
-                    window_first_action,
-                    prior_units=ordered_memory_units,
-                    meta=meta,
-                    task=task,
-                )
-                if mu:
-                    ordered_memory_units.append(mu)
-
-            window_msgs = [msg]
-            window_first_action = msg
-        else:
-            if window_msgs:
-                window_msgs.append(msg)
-            else:
-                continue
-
-    if window_msgs and window_first_action is not None:
-        mu = finalize_window_with_llm(
-            window_msgs, window_first_action, prior_units=ordered_memory_units, meta=meta, task=task
-        )
-        if mu:
-            ordered_memory_units.append(mu)
-
-    if memory_base is not None and ordered_memory_units:
-        memory_base.upsert(ordered_memory_units)
-
-    return ordered_memory_units
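-
-
-# Windowing sketch (hypothetical trajectory) for `extract_memory_units_by_action_windows`:
-# an assistant message that starts an action call opens a new window; every other
-# message joins the currently open window; each window is finalized into one
-# MemoryUnit when the next action starts, or when the trajectory ends.
-#
-#   user:      "Fix the crash in parser.py"   -> skipped (no window open yet)
-#   assistant: starts action A1               -> opens window 1
-#   tool:      "...Traceback..."              -> appended to window 1
-#   assistant: starts action A2               -> finalizes window 1, opens window 2
-#   tool:      "exit code 0"                  -> appended to window 2
-#   (end of trajectory)                       -> finalizes window 2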
-
-
-def summarize_context_from_units(units: List[MemoryUnit], task: Task) -> str:
-    """
-    Summarize previous context into a concise `state.done` string by calling DeepSeek.
-    Uses only the last 10 memory units and asks the model to produce a single, plain-text
-    summary of what has ALREADY BEEN COMPLETED (no plans).
-
-    Dependencies assumed available:
-    - call_deepseek(messages: List[Dict[str, str]], temperature=0.0) -> str
-    - MemoryUnit Pydantic model
-    - json imported
-    """
-    if not units:
-        return "(none)"
-
-    # Keep the window modest to control prompt size.
-    window = units[-10:]
-
-    # Pack minimal, evidence-bound context for the LLM (no heavy heuristics).
-    history = []
-    for u in window:
-        history.append(
-            {
-                "state": {
-                    "done": u.state.done,
-                    "todo": u.state.todo,
-                },
-                "action": {
-                    "name": u.action.name,
-                    "description": u.action.description,
-                    "target": u.action.target,
-                    "tool": u.action.tool,
-                },
-                "result": {
-                    "type": u.result.type,
-                    "description": u.result.description,
-                    "exit_code": u.result.exit_code,
-                },
-            }
-        )
-
-    system_prompt = STATE_DONE_SUMMARY_PROMPT
-    user_prompt = (
-        "Summarize ONLY what has ALREADY BEEN COMPLETED into `state.done`.\n"
-        "TASK (for deriving only):\n"
-        "\n" + task.model_dump_json() + "\n\n\n"
-        "PRIOR UNITS (used ONLY for completed work evidence):\n"
-        "\n" + json.dumps(history, ensure_ascii=False) + "\n\n\n"
-        "Return the final summary paragraph ONLY (no explanations)."
-    )
-
-    try:
-        summary = call_deepseek(
-            [
-                {"role": "system", "content": system_prompt},
-                {"role": "user", "content": user_prompt},
-            ],
-            temperature=0.0,
-        )
-        # Be lenient if the model wraps the text in quotes/newlines.
-        return summary.strip().strip('"').strip()
-    except Exception:
-        # Minimal, defensive fallback (kept short); avoids complex heuristics.
-        return window[-1].state.done
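-
-
-# Illustrative shape (hypothetical data) of one packed history entry that
-# `summarize_context_from_units` sends to the summarizer; only the last 10
-# units are packed to bound prompt size.
-_EXAMPLE_HISTORY_ENTRY = {
-    "state": {"done": "Reproduced the crash locally.", "todo": "Patch the tokenizer."},
-    "action": {
-        "name": "edit_file",
-        "description": "guard against empty input",
-        "target": "parser.py",
-        "tool": "editor",
-    },
-    "result": {"type": "success", "description": "Edit applied cleanly.", "exit_code": 0},
-}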
No comments."}, - {"role": "user", "content": raw}, - ] - fixed = call_deepseek(fix_messages, temperature=0.0) - payload = json.loads(fixed) - - # 5) normalize fields to strings (no extra heuristics) - task_obj = payload.get("task", {}) - - def s(x): - return x if isinstance(x, str) else ("" if x is None else str(x)) - - repo = s(task_obj.get("repository")) - if not repo and default_repository: - repo = default_repository # optional fallback; minimal deviation - - return Task( - issue_title=s(task_obj.get("issue_title")), - issue_body=s(task_obj.get("issue_body")), - issue_comments=s(task_obj.get("issue_comments")), - issue_type=s(task_obj.get("issue_type")), - repository=repo, - ) - - -def row_to_trajectory_messages(row): - traj = row["messages"] - msgs = [] - for m in traj: - content = m.get("content", "") if isinstance(m, dict) else str(m) - role = m.get("role", "user") if isinstance(m, dict) else "user" - msgs.append(TrajectoryMessage(content=content, role=role)) - return msgs - - -def main(): - trajectory_repo_name = "SWE-Gym/OpenHands-SFT-Trajectories" - trajectory_split = "train.success.oss" - dataset = load_dataset(trajectory_repo_name, split=trajectory_split) - trajectories = [row_to_trajectory_messages(row) for row in dataset] - memory_base = InMemoryMemoryBase() - print(f"\nExtracting memory units for {len(trajectories)} trajectories...\n") - - for idx, trajectory in enumerate(trajectories): - print(f"\nProcessing trajectory {idx + 1} of {len(trajectories)}...\n") - meta = TrajectoryMeta(source=trajectory_repo_name, id=f"{idx}") - task = extract_task_from_first_user_message(trajectory) - units = extract_memory_units_by_action_windows( - meta, task, trajectory, memory_base=memory_base - ) - print(f"\nExtracted {len(units)} memory units. Showing JSON:\n") - print(json.dumps([u.model_dump() for u in units], ensure_ascii=False, indent=2)) - assert False - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index 9634060..89610b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "langchain-openai==0.2.8", "langchain-google-genai==2.0.4", "langchain_community==0.3.2", + "langchain_google_vertexai==2.1.0" ] requires-python = ">= 3.11"