Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
## Tooling

- Package manager: `uv`
- Run tests: `uv run pytest`
4 changes: 4 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
## Tooling

- Package manager: `uv`
- Run tests: `uv run pytest`
24 changes: 24 additions & 0 deletions dani/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

import re

TRANSIENT_CAPACITY_PATTERNS: list[re.Pattern[str]] = [
re.compile(r"Selected model is at capacity", re.IGNORECASE),
re.compile(r"model is currently overloaded", re.IGNORECASE),
]


class TransientCapacityError(Exception):
"""Raised when an OMX session fails due to a transient model-capacity issue."""

def __init__(self, message: str, pattern: str) -> None:
super().__init__(message)
self.pattern = pattern


def check_transient_capacity_error(stderr_text: str) -> None:
"""Raise ``TransientCapacityError`` if *stderr_text* contains a known capacity pattern."""
for pattern in TRANSIENT_CAPACITY_PATTERNS:
match = pattern.search(stderr_text)
if match:
raise TransientCapacityError(match.group(0), pattern.pattern)
10 changes: 10 additions & 0 deletions dani/omx_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Protocol, TextIO
from uuid import uuid4

from dani.errors import check_transient_capacity_error
from dani.models import JobRecord, SessionRecord


Expand Down Expand Up @@ -135,6 +136,15 @@ def wait(self, runtime_handle: str, *, poll_interval: float = 0.5, timeout_secon
msg = f"omx exec process did not exit before timeout: {runtime_handle}"
raise TimeoutError(msg) from exc

self._check_stderr_for_transient_error(runtime_handle)

def _check_stderr_for_transient_error(self, runtime_handle: str) -> None:
stderr_path = self.run_dir / runtime_handle / "stderr.log"
if not stderr_path.exists():
return
stderr_text = stderr_path.read_text(encoding="utf-8", errors="replace")
check_transient_capacity_error(stderr_text)

def close_session(self, runtime_handle: str) -> None:
with self._lock:
entry = self._processes.pop(runtime_handle, None)
Expand Down
130 changes: 116 additions & 14 deletions dani/service.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

import contextlib
import logging
import re
import time
from pathlib import Path
from typing import Any

from dani.errors import TransientCapacityError
from dani.git_sync import DevSyncConflictError, GitDevSyncer
from dani.github import GitHubCLI, MergeConflictError
from dani.models import DaniConfig, JobRecord, NormalizedEvent, RepoConfig, SessionRecord, utc_now
Expand All @@ -15,6 +19,10 @@

ISSUE_REF_PATTERN = re.compile(r"#(?P<number>\d+)")

RETRY_BACKOFF_SECONDS: list[int] = [60, 180, 600]

logger = logging.getLogger(__name__)


class DaniService:
def __init__(
Expand Down Expand Up @@ -267,23 +275,117 @@ def _run_job(self, job: JobRecord) -> None:
self._run_dev_sync_job(repo, job)
return

session = None
try:
prompt = self._build_prompt(repo, job)
if job.stage == "issue_followup":
session = self.omx_runner.resume(Path(repo.local_path), job, prompt, self._omx_session_id_for(job))
retry_history: list[dict[str, str]] = list(job.metadata.get("retry_history", []))
max_attempts = len(RETRY_BACKOFF_SECONDS) + 1

for attempt in range(1, max_attempts + 1):
try:
self._run_job_attempt(repo, job)
except TransientCapacityError as exc:
if self._handle_transient_failure(repo, job, exc, attempt, max_attempts, retry_history):
return
except Exception as exc:
self._handle_job_failure(job, exc, attempt, retry_history)
return
else:
session = self.omx_runner.launch(Path(repo.local_path), job, prompt)
self.storage.create_session(session)
self.storage.update_job(job.id, status="launched", session_id=session.id)
self.storage.update_job(
job.id,
status="completed",
metadata={**job.metadata, "retry_attempts": attempt - 1, "retry_history": retry_history},
)
return

def _run_job_attempt(self, repo: RepoConfig, job: JobRecord) -> None:
prompt = self._build_prompt(repo, job)
if job.stage == "issue_followup":
session = self.omx_runner.resume(Path(repo.local_path), job, prompt, self._omx_session_id_for(job))
else:
session = self.omx_runner.launch(Path(repo.local_path), job, prompt)
self.storage.create_session(session)
self.storage.update_job(job.id, status="launched", session_id=session.id)
try:
self.omx_runner.wait(session.runtime_handle)
self._verify_side_effect(repo, job)
self._finalize_session(session, status="completed", termination_reason="completed")
self.storage.update_job(job.id, status="completed")
except Exception as exc:
if session is not None:
self._finalize_session(session, status="failed", termination_reason=type(exc).__name__)
self.storage.update_job(job.id, status="failed", metadata={**job.metadata, "error": str(exc)})
except Exception:
self._finalize_session(session, status="failed", termination_reason="failed")
raise
self._finalize_session(session, status="completed", termination_reason="completed")

def _handle_transient_failure(
self,
repo: RepoConfig,
job: JobRecord,
exc: TransientCapacityError,
attempt: int,
max_attempts: int,
retry_history: list[dict[str, str]],
) -> bool:
"""Handle a transient capacity error. Return True if the job is terminal (done or exhausted)."""
# If the side effect was already posted before the capacity error,
# treat the job as completed to avoid duplicate GitHub comments/PRs.
side_effect_exists = False
with contextlib.suppress(Exception):
self._verify_side_effect(repo, job)
side_effect_exists = True
if side_effect_exists:
self.storage.update_job(
job.id,
status="completed",
metadata={
**job.metadata,
"retry_attempts": attempt - 1,
"retry_history": retry_history,
"note": "side_effect_already_posted",
},
)
return True
retry_history.append({
"attempt": str(attempt),
"reason": "transient_capacity",
"pattern": exc.pattern,
"at": utc_now(),
})
if attempt >= max_attempts:
logger.warning("Job %s exhausted all %d retry attempts", job.id, max_attempts)
self.storage.update_job(
job.id,
status="failed",
metadata={
**job.metadata,
"error": f"retry_exhausted: {exc}",
"retry_exhausted": True,
"retry_attempts": attempt,
"retry_history": retry_history,
},
)
return True
backoff = RETRY_BACKOFF_SECONDS[attempt - 1]
logger.info("Job %s attempt %d hit transient capacity error, retrying in %ds", job.id, attempt, backoff)
self.storage.update_job(
job.id,
status="retrying",
metadata={**job.metadata, "retry_attempts": attempt, "retry_history": retry_history},
)
time.sleep(backoff)
return False

def _handle_job_failure(
self,
job: JobRecord,
exc: Exception,
attempt: int,
retry_history: list[dict[str, str]],
) -> None:
self.storage.update_job(
job.id,
status="failed",
metadata={
**job.metadata,
"error": str(exc),
"retry_attempts": attempt - 1,
"retry_history": retry_history,
},
)

def _run_dev_sync_job(self, repo: RepoConfig, job: JobRecord) -> None:
session = None
Expand Down
70 changes: 39 additions & 31 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
from pathlib import Path
from typing import Any, TypedDict

from dani.errors import TransientCapacityError
from dani.git_sync import DevSyncConflictError, DevSyncContext, DevSyncOutcome
from dani.github import MergeConflictError
from dani.models import JobRecord, SessionRecord
from dani.signatures import build_signature, parse_signature

_CAPACITY_MSG = "capacity"


class FakeGitHubCLI:
def __init__(self) -> None:
Expand Down Expand Up @@ -118,13 +121,36 @@ def __init__(self, github: FakeGitHubCLI) -> None:
self.launches: list[LaunchRecord] = []
self.resumes: list[ResumeRecord] = []
self.closed_sessions: list[str] = []
self._transient_failures_remaining: int = 0

def set_transient_failures(self, count: int) -> None:
"""Configure the runner to raise TransientCapacityError for the next *count* wait() calls."""
self._transient_failures_remaining = count

def launch(self, repo_path: Path, job: JobRecord, prompt: str) -> SessionRecord:
repo_full_name = job.repo_full_name
matches = re.findall(r"<!--\s*dani:([^>]+)\s*-->", prompt)
signature = None
if matches:
signature = parse_signature(f"<!-- dani:{matches[-1]} -->")
signature = parse_signature(f"<!-- dani:{matches[-1]} -->") if matches else None
# If next wait() will raise TransientCapacityError, skip posting side effects
# to simulate the OMX session failing before it could post anything.
if self._transient_failures_remaining == 0:
self._post_side_effect(repo_full_name, job, signature)
self.launches.append({"repo_path": str(repo_path), "job": job, "prompt": prompt})
return SessionRecord(
repo_full_name=repo_full_name,
stage=job.stage,
runtime_handle=f"runtime-{job.id}",
prompt_path=str(repo_path / "prompt.txt"),
script_path=str(repo_path / "run.sh"),
worktree_path=str(repo_path),
job_id=job.id,
issue_number=job.issue_number,
pr_number=job.pr_number,
review_round=job.review_round,
omx_session_id=f"omx-{job.id}",
)

def _post_side_effect(self, repo_full_name: str, job: JobRecord, signature: dict[str, str] | None) -> None:
if job.stage == "issue_request":
issue_number = int((signature or {}).get("issue", job.issue_number or 0))
self.github.add_issue_signature(
Expand All @@ -136,19 +162,15 @@ def launch(self, repo_path: Path, job: JobRecord, prompt: str) -> SessionRecord:
issue_number = int((signature or {}).get("issue", job.issue_number or 0))
pr_number = int((signature or {}).get("pr", job.pr_number or 0))
if pr_number:
signature_fields: dict[str, Any] = {"stage": "implementation", "job": job.id, "pr": pr_number}
fields: dict[str, Any] = {"stage": "implementation", "job": job.id, "pr": pr_number}
if issue_number:
signature_fields["issue"] = issue_number
self.github.add_pr_signature(
repo_full_name,
pr_number,
build_signature(**signature_fields),
)
fields["issue"] = issue_number
self.github.add_pr_signature(repo_full_name, pr_number, build_signature(**fields))
else:
signature_fields: dict[str, Any] = {"stage": "implementation", "job": job.id}
fields = {"stage": "implementation", "job": job.id}
if issue_number:
signature_fields["issue"] = issue_number
self.github.add_pull_request(repo_full_name, 101, build_signature(**signature_fields))
fields["issue"] = str(issue_number)
self.github.add_pull_request(repo_full_name, 101, build_signature(**fields))
elif job.stage == "review_round":
pr_number = int((signature or {}).get("pr", job.pr_number or 0))
self.github.add_pr_signature(
Expand All @@ -163,30 +185,13 @@ def launch(self, repo_path: Path, job: JobRecord, prompt: str) -> SessionRecord:
pr_number,
build_signature(stage="merge_conflict_resolution", job=job.id, pr=pr_number),
)
elif job.stage == "dev_sync":
pass
else:
elif job.stage != "dev_sync":
self.github.add_pr_signature(
repo_full_name,
job.pr_number or 0,
build_signature(stage="final_verdict", job=job.id, pr=job.pr_number or 0, verdict="APPROVE"),
)

self.launches.append({"repo_path": str(repo_path), "job": job, "prompt": prompt})
return SessionRecord(
repo_full_name=repo_full_name,
stage=job.stage,
runtime_handle=f"runtime-{job.id}",
prompt_path=str(repo_path / "prompt.txt"),
script_path=str(repo_path / "run.sh"),
worktree_path=str(repo_path),
job_id=job.id,
issue_number=job.issue_number,
pr_number=job.pr_number,
review_round=job.review_round,
omx_session_id=f"omx-{job.id}",
)

def resume(self, repo_path: Path, job: JobRecord, prompt: str, omx_session_id: str) -> SessionRecord:
issue_number = job.issue_number or 0
self.github.add_issue_signature(
Expand Down Expand Up @@ -215,6 +220,9 @@ def resume(self, repo_path: Path, job: JobRecord, prompt: str, omx_session_id: s
)

def wait(self, runtime_handle: str, *, poll_interval: float = 0.5, timeout_seconds: float = 1800) -> None:
if self._transient_failures_remaining > 0:
self._transient_failures_remaining -= 1
raise TransientCapacityError(_CAPACITY_MSG, _CAPACITY_MSG)
return None

def close_session(self, runtime_handle: str) -> None:
Expand Down
Loading
Loading