diff --git a/app_backend/src/metta/app_backend/metta_repo.py b/app_backend/src/metta/app_backend/metta_repo.py
index d448f97bd05..d0eee154c71 100644
--- a/app_backend/src/metta/app_backend/metta_repo.py
+++ b/app_backend/src/metta/app_backend/metta_repo.py
@@ -1,73 +1,17 @@
 import logging
 import uuid
-from collections import defaultdict
 from contextlib import asynccontextmanager
 from datetime import datetime
-from typing import Any, Literal
 
 from psycopg import Connection
 from psycopg.rows import class_row
-from psycopg.types.json import Jsonb
 from psycopg_pool import AsyncConnectionPool, PoolTimeout
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 
 from metta.app_backend.config import settings
 from metta.app_backend.migrations import MIGRATIONS
 from metta.app_backend.schema_manager import run_migrations
 
-TaskStatus = Literal["unprocessed", "running", "canceled", "done", "error", "system_error"]
-FinishedTaskStatus = Literal["done", "error", "canceled", "system_error"]
-
-
-class TaskStatusUpdate(BaseModel):
-    status: TaskStatus
-    clear_assignee: bool = False
-    status_details: dict[str, Any] = Field(default_factory=dict)
-
-
-class EvalTaskRow(BaseModel):
-    """Row model that matches the eval_tasks table with latest attempt data."""
-
-    model_config = {"from_attributes": True}
-
-    id: int
-    command: str
-    data_uri: str | None
-    git_hash: str | None
-    attributes: dict[str, Any]
-    user_id: str
-    created_at: datetime
-    is_finished: bool
-    latest_attempt_id: int | None
-
-    # Fields from the latest attempt (populated via JOIN)
-    # Note: attempt_number will be 0 for new tasks, status will be 'unprocessed'
-    attempt_number: int | None = 0
-    status: TaskStatus = "unprocessed"
-    status_details: dict[str, Any] | None = None
-    assigned_at: datetime | None = None
-    assignee: str | None = None
-    started_at: datetime | None = None
-    finished_at: datetime | None = None
-    output_log_path: str | None = None
-
-
-class TaskAttemptRow(BaseModel):
-    """Row model for task_attempts table."""
-
-    model_config = {"from_attributes": True}
-
-    id: int
-    task_id: int
-    attempt_number: int
-    assigned_at: datetime | None
-    assignee: str | None
-    started_at: datetime | None
-    finished_at: datetime | None
-    output_log_path: str | None
-    status: TaskStatus
-    status_details: dict[str, Any] | None
-
 
 class SweepRow(BaseModel):
     id: uuid.UUID
@@ -122,197 +66,6 @@ async def close(self) -> None:
                 # Event loop might be closed, ignore
                 pass
 
-    async def create_eval_task(
-        self,
-        command: str,
-        user_id: str,
-        attributes: dict[str, Any],
-        git_hash: str | None = None,
-        data_uri: str | None = None,
-    ) -> EvalTaskRow:
-        async with self.connect() as con:
-            # Insert the task
-            result = await con.execute(
-                """
-                INSERT INTO eval_tasks (command, data_uri, git_hash, attributes, user_id)
-                VALUES (%s, %s, %s, %s, %s)
-                RETURNING id
-                """,
-                (command, data_uri, git_hash, Jsonb(attributes), user_id),
-            )
-            row = await result.fetchone()
-            if row is None:
-                raise RuntimeError("Failed to create eval task")
-            task_id = row[0]
-
-            # Create the first attempt
-            result2 = await con.execute(
-                """
-                INSERT INTO task_attempts (task_id, attempt_number, status)
-                VALUES (%s, 0, 'unprocessed')
-                RETURNING id
-                """,
-                (task_id,),
-            )
-            row2 = await result2.fetchone()
-            if row2 is None:
-                raise RuntimeError("Failed to create first attempt")
-            attempt_id = row2[0]
-
-            # Update the task with the latest_attempt_id
-            await con.execute(
-                """
-                UPDATE eval_tasks
-                SET latest_attempt_id = %s
-                WHERE id = %s
-                """,
-                (attempt_id, task_id),
-            )
-
-            # Fetch and return the complete task directly within the same transaction
-            async with con.cursor(row_factory=class_row(EvalTaskRow)) as cur:
-                await cur.execute("SELECT * FROM eval_tasks_view WHERE id = %s", (task_id,))
-                task = await cur.fetchone()
-                if task is None:
-                    raise RuntimeError("Failed to retrieve created task")
-                return task
-
-    async def get_available_tasks(self, limit: int = 200) -> list[EvalTaskRow]:
-        async with self.connect() as con:
-            async with con.cursor(row_factory=class_row(EvalTaskRow)) as cur:
-                await cur.execute(
-                    """
-                    SELECT * FROM eval_tasks_view
-                    WHERE status = 'unprocessed' AND assignee IS NULL AND is_finished = FALSE
-                    ORDER BY created_at ASC
-                    LIMIT %s
-                    """,
-                    (limit,),
-                )
-                return await cur.fetchall()
-
-    async def claim_tasks(
-        self,
-        task_ids: list[int],
-        assignee: str,
-    ) -> list[int]:
-        if not task_ids:
-            return []
-
-        async with self.connect() as con:
-            # Update the latest attempt for each task
-            result = await con.execute(
-                """
-                UPDATE task_attempts
-                SET assignee = %s, assigned_at = NOW()
-                WHERE id IN (
-                    SELECT latest_attempt_id FROM eval_tasks
-                    WHERE id = ANY(%s) AND is_finished = FALSE
-                )
-                AND status = 'unprocessed'
-                AND assignee IS NULL
-                RETURNING task_id
-                """,
-                (assignee, task_ids),
-            )
-            rows = await result.fetchall()
-            return [row[0] for row in rows]
-
-    async def get_claimed_tasks(self, assignee: str | None = None) -> list[EvalTaskRow]:
-        async with self.connect() as con:
-            async with con.cursor(row_factory=class_row(EvalTaskRow)) as cur:
-                if assignee is not None:
-                    await cur.execute(
-                        """
-                        SELECT * FROM eval_tasks_view
-                        WHERE assignee = %s AND is_finished = FALSE
-                        ORDER BY created_at ASC
-                        """,
-                        (assignee,),
-                    )
-                else:
-                    await cur.execute(
-                        """
-                        SELECT * FROM eval_tasks_view
-                        WHERE assignee IS NOT NULL AND is_finished = FALSE
-                        ORDER BY created_at ASC
-                        """
-                    )
-                return await cur.fetchall()
-
-    async def start_task(self, task_id: int) -> None:
-        async with self.connect() as con:
-            await con.execute(
-                """
-                UPDATE task_attempts
-                SET status = 'running', started_at = NOW()
-                WHERE id = (SELECT latest_attempt_id FROM eval_tasks WHERE id = %s)
-                """,
-                (task_id,),
-            )
-
-    async def finish_task(
-        self, task_id: int, status: FinishedTaskStatus, status_details: dict[str, Any], log_path: str | None = None
-    ) -> None:
-        async with self.connect() as con:
-            # Update the current attempt
-            await con.execute(
-                """
-                UPDATE task_attempts
-                SET status = %s, finished_at = NOW(), status_details = %s, output_log_path = %s
-                WHERE id = (SELECT latest_attempt_id FROM eval_tasks WHERE id = %s)
-                """,
-                (status, Jsonb(status_details), log_path, task_id),
-            )
-
-            # Get the current attempt number
-            result = await con.execute(
-                """
-                SELECT attempt_number FROM task_attempts
-                WHERE id = (SELECT latest_attempt_id FROM eval_tasks WHERE id = %s)
-                """,
-                (task_id,),
-            )
-            row = await result.fetchone()
-            if row is None:
-                raise RuntimeError(f"Failed to get attempt number for task {task_id}")
-            current_attempt = row[0]
-
-            # Check if we should mark the task as finished or create a new attempt
-            should_finish = status != "system_error" or current_attempt >= 2  # 0, 1, 2 = 3 attempts
-
-            if should_finish:
-                # Mark the task as finished
-                await con.execute(
-                    """
-                    UPDATE eval_tasks
-                    SET is_finished = TRUE
-                    WHERE id = %s
-                    """,
-                    (task_id,),
-                )
-            else:
-                # Create a new attempt for retry
-                await con.execute(
-                    """
-                    INSERT INTO task_attempts (task_id, attempt_number, status)
-                    VALUES (%s, %s, 'unprocessed')
-                    """,
-                    (task_id, current_attempt + 1),
-                )
-
-                # Update latest_attempt_id
-                await con.execute(
-                    """
-                    UPDATE eval_tasks
-                    SET latest_attempt_id = (
-                        SELECT id FROM task_attempts WHERE task_id = %s ORDER BY attempt_number DESC LIMIT 1
-                    )
-                    WHERE id = %s
-                    """,
-                    (task_id, task_id),
-                )
-
     async def create_sweep(self, name: str, project: str, entity: str, wandb_sweep_id: str, user_id: str) -> uuid.UUID:
         """Create a new sweep."""
         async with self.connect() as con:
@@ -361,164 +114,3 @@ async def get_next_sweep_run_counter(self, sweep_id: uuid.UUID) -> int:
             if row is None:
                 raise ValueError(f"Sweep {sweep_id} not found")
             return row[0]
-
-    async def get_latest_assigned_task_for_worker(self, assignee: str) -> EvalTaskRow | None:
-        async with self.connect() as con:
-            async with con.cursor(row_factory=class_row(EvalTaskRow)) as cur:
-                await cur.execute(
-                    """
-                    SELECT * FROM eval_tasks_view
-                    WHERE assignee = %s AND assigned_at IS NOT NULL
-                    ORDER BY assigned_at DESC
-                    LIMIT 1
-                    """,
-                    (assignee,),
-                )
-                return await cur.fetchone()
-
-    async def get_task_by_id(self, task_id: int) -> EvalTaskRow | None:
-        async with self.connect() as con:
-            async with con.cursor(row_factory=class_row(EvalTaskRow)) as cur:
-                await cur.execute("SELECT * FROM eval_tasks_view WHERE id = %s", (task_id,))
-                return await cur.fetchone()
-
-    async def get_task_attempts(self, task_id: int) -> list[TaskAttemptRow]:
-        """Get all attempts for a task, ordered by attempt_number."""
-        async with self.connect() as con:
-            async with con.cursor(row_factory=class_row(TaskAttemptRow)) as cur:
-                await cur.execute(
-                    """
-                    SELECT * FROM task_attempts
-                    WHERE task_id = %s
-                    ORDER BY attempt_number ASC
-                    """,
-                    (task_id,),
-                )
-                return await cur.fetchall()
-
-    async def get_all_tasks(
-        self,
-        limit: int = 500,
-        statuses: list[TaskStatus] | None = None,
-        git_hash: str | None = None,
-    ) -> list[EvalTaskRow]:
-        async with self.connect() as con:
-            # Build the WHERE clause dynamically
-            where_conditions = []
-            params = []
-
-            if statuses:
-                placeholders = ", ".join(["%s"] * len(statuses))
-                where_conditions.append(f"status IN ({placeholders})")
-                params.extend(statuses)
-
-            if git_hash:
-                where_conditions.append("git_hash = %s")
-                params.append(git_hash)
-
-            where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
-            params.append(limit)
-
-            async with con.cursor(row_factory=class_row(EvalTaskRow)) as cur:
-                await cur.execute(
-                    f"""
-                    SELECT * FROM eval_tasks_view
-                    WHERE {where_clause}
-                    ORDER BY created_at DESC
-                    LIMIT %s
-                    """,
-                    params,
-                )
-                return await cur.fetchall()
-
-    async def get_tasks_paginated(
-        self,
-        page: int = 1,
-        page_size: int = 50,
-        status: str | None = None,
-        assignee: str | None = None,
-        user_id: str | None = None,
-        command: str | None = None,
-        created_at: str | None = None,
-        assigned_at: str | None = None,
-    ) -> tuple[list[EvalTaskRow], int]:
-        async with self.connect() as con:
-            where_conditions = []
-            params = []
-
-            # Add text-based filters using ILIKE for case-insensitive substring search
-            if status:
-                # Use exact match for status since it's an enum-like field
-                where_conditions.append("status = %s")
-                params.append(status)
-
-            if assignee:
-                where_conditions.append("assignee ILIKE %s")
-                params.append(f"%{assignee}%")
-
-            if user_id:
-                where_conditions.append("user_id ILIKE %s")
-                params.append(f"%{user_id}%")
-
-            if command:
-                where_conditions.append("command ILIKE %s")
-                params.append(f"%{command}%")
-
-            if created_at:
-                where_conditions.append("CAST(created_at AS TEXT) ILIKE %s")
-                params.append(f"%{created_at}%")
-
-            if assigned_at:
-                where_conditions.append("CAST(assigned_at AS TEXT) ILIKE %s")
-                params.append(f"%{assigned_at}%")
-
-            where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
-
-            # Get total count
-            count_query = f"""
-                SELECT COUNT(*)
-                FROM eval_tasks_view
-                WHERE {where_clause}
-            """
-            count_result = await con.execute(count_query, params)
-            result_row = await count_result.fetchone()
-            total_count: int = result_row[0] if result_row else 0
-
-            # Get paginated results
-            offset = (page - 1) * page_size
-            params.extend([page_size, offset])
-
-            async with con.cursor(row_factory=class_row(EvalTaskRow)) as cur:
-                await cur.execute(
-                    f"""
-                    SELECT * FROM eval_tasks_view
-                    WHERE {where_clause}
-                    ORDER BY created_at DESC
-                    LIMIT %s OFFSET %s
-                    """,
-                    params,
-                )
-                tasks = await cur.fetchall()
-
-            return tasks, total_count
-
-    async def get_git_hashes_for_workers(self, assignees: list[str]) -> dict[str, list[str]]:
-        async with self.connect() as con:
-            if not assignees:
-                return {}
-
-            # Use ANY() for proper list handling in PostgreSQL
-            queryRes = await con.execute(
-                """
-                SELECT DISTINCT assignee, git_hash
-                FROM eval_tasks_view
-                WHERE assignee = ANY(%s)
-                """,
-                (assignees,),
-            )
-            rows = await queryRes.fetchall()
-            res: dict[str, list[str]] = defaultdict(list)
-            for row in rows:
-                if row[1]:  # Only add non-null git hashes
-                    res[row[0]].append(row[1])
-            return res
diff --git a/app_backend/src/metta/app_backend/models/eval_task.py b/app_backend/src/metta/app_backend/models/eval_task.py
new file mode 100644
index 00000000000..4d8e5300b72
--- /dev/null
+++ b/app_backend/src/metta/app_backend/models/eval_task.py
@@ -0,0 +1,50 @@
+from datetime import UTC, datetime
+from typing import Any, Literal
+
+from sqlalchemy import Column, text
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlmodel import Field, Relationship, SQLModel
+
+TaskStatus = Literal["unprocessed", "running", "canceled", "done", "error", "system_error"]
+FinishedTaskStatus = Literal["done", "error", "canceled", "system_error"]
+
+
+class EvalTask(SQLModel, table=True):
+    __tablename__ = "eval_tasks"  # type: ignore[assignment]
+
+    id: int | None = Field(default=None, primary_key=True)
+    command: str
+    data_uri: str | None = None
+    git_hash: str | None = None
+    attributes: dict[str, Any] | None = Field(default=None, sa_column=Column(JSONB))
+    user_id: str
+    is_finished: bool = Field(default=False)
+    latest_attempt_id: int | None = Field(default=None, foreign_key="task_attempts.id")
+    created_at: datetime = Field(
+        default_factory=lambda: datetime.now(UTC), sa_column_kwargs={"server_default": text("now()")}
+    )
+
+    attempts: list["TaskAttempt"] = Relationship(
+        back_populates="task",
+        sa_relationship_kwargs={"foreign_keys": "[TaskAttempt.task_id]"},
+    )
+
+
+class TaskAttempt(SQLModel, table=True):
+    __tablename__ = "task_attempts"  # type: ignore[assignment]
+
+    id: int | None = Field(default=None, primary_key=True)
+    task_id: int = Field(foreign_key="eval_tasks.id")
+    attempt_number: int = Field(default=0)
+    status: str = Field(default="unprocessed")
+    status_details: dict[str, Any] | None = Field(default=None, sa_column=Column(JSONB))
+    assignee: str | None = None
+    assigned_at: datetime | None = None
+    started_at: datetime | None = None
+    finished_at: datetime | None = None
+    output_log_path: str | None = None
+
+    task: EvalTask = Relationship(
+        back_populates="attempts",
+        sa_relationship_kwargs={"foreign_keys": "[TaskAttempt.task_id]"},
+    )
diff --git a/app_backend/src/metta/app_backend/queries/eval_task_queries.py b/app_backend/src/metta/app_backend/queries/eval_task_queries.py
new file mode 100644
index 00000000000..bf3a9c45b1e
--- /dev/null
+++ b/app_backend/src/metta/app_backend/queries/eval_task_queries.py
@@ -0,0 +1,388 @@
+# pyright: reportArgumentType=false
+
+from collections import defaultdict
+from datetime import datetime
+from typing import Any
+
+from pydantic import BaseModel, Field
+from sqlalchemy import text
+
+from metta.app_backend.database import get_db, with_db
+from metta.app_backend.models.eval_task import EvalTask, FinishedTaskStatus, TaskAttempt, TaskStatus
+
+
+class EvalTaskRow(BaseModel):
+    model_config = {"from_attributes": True}
+
+    id: int
+    command: str
+    data_uri: str | None
+    git_hash: str | None
+    attributes: dict[str, Any]
+    user_id: str
+    created_at: datetime
+    is_finished: bool
+    latest_attempt_id: int | None
+    attempt_number: int | None = 0
+    status: TaskStatus = "unprocessed"
+    status_details: dict[str, Any] | None = None
+    assigned_at: datetime | None = None
+    assignee: str | None = None
+    started_at: datetime | None = None
+    finished_at: datetime | None = None
+    output_log_path: str | None = None
+
+    @classmethod
+    def from_task_and_attempt(cls, task: EvalTask, attempt: TaskAttempt | None) -> "EvalTaskRow":
+        return cls(
+            id=task.id,  # type: ignore[arg-type]
+            command=task.command,
+            data_uri=task.data_uri,
+            git_hash=task.git_hash,
+            attributes=task.attributes or {},
+            user_id=task.user_id,
+            created_at=task.created_at,
+            is_finished=task.is_finished,
+            latest_attempt_id=task.latest_attempt_id,
+            attempt_number=attempt.attempt_number if attempt else 0,
+            status=attempt.status if attempt else "unprocessed",  # type: ignore[arg-type]
+            status_details=attempt.status_details if attempt else None,
+            assigned_at=attempt.assigned_at if attempt else None,
+            assignee=attempt.assignee if attempt else None,
+            started_at=attempt.started_at if attempt else None,
+            finished_at=attempt.finished_at if attempt else None,
+            output_log_path=attempt.output_log_path if attempt else None,
+        )
+
+
+class TaskAttemptRow(BaseModel):
+    model_config = {"from_attributes": True}
+
+    id: int
+    task_id: int
+    attempt_number: int
+    assigned_at: datetime | None
+    assignee: str | None
+    started_at: datetime | None
+    finished_at: datetime | None
+    output_log_path: str | None
+    status: TaskStatus
+    status_details: dict[str, Any] | None
+
+    @classmethod
+    def from_model(cls, attempt: TaskAttempt) -> "TaskAttemptRow":
+        return cls(
+            id=attempt.id,  # type: ignore[arg-type]
+            task_id=attempt.task_id,
+            attempt_number=attempt.attempt_number,
+            assigned_at=attempt.assigned_at,
+            assignee=attempt.assignee,
+            started_at=attempt.started_at,
+            finished_at=attempt.finished_at,
+            output_log_path=attempt.output_log_path,
+            status=attempt.status,  # type: ignore[arg-type]
+            status_details=attempt.status_details,
+        )
+
+
+class TaskStatusUpdate(BaseModel):
+    status: TaskStatus
+    clear_assignee: bool = False
+    status_details: dict[str, Any] = Field(default_factory=dict)
+
+
+@with_db
+async def create_eval_task(
+    command: str,
+    user_id: str,
+    attributes: dict[str, Any],
+    git_hash: str | None = None,
+    data_uri: str | None = None,
+) -> EvalTaskRow:
+    session = get_db()
+
+    task = EvalTask(
+        command=command,
+        data_uri=data_uri,
+        git_hash=git_hash,
+        attributes=attributes,
+        user_id=user_id,
+    )
+    session.add(task)
+    await session.flush()
+
+    attempt = TaskAttempt(task_id=task.id, attempt_number=0, status="unprocessed")  # type: ignore[arg-type]
+    session.add(attempt)
+    await session.flush()
+
+    task.latest_attempt_id = attempt.id
+    await session.flush()
+
+    return EvalTaskRow.from_task_and_attempt(task, attempt)
+
+
+@with_db
+async def get_available_tasks(limit: int = 200) -> list[EvalTaskRow]:
+    session = get_db()
+    query = text("""
+        SELECT * FROM eval_tasks_view
+        WHERE status = 'unprocessed' AND assignee IS NULL AND is_finished = FALSE
+        ORDER BY created_at ASC
+        LIMIT :limit
+    """)
+    result = await session.execute(query, {"limit": limit})
+    rows = result.mappings().all()
+    return [EvalTaskRow(**dict(row)) for row in rows]
+
+
+@with_db
+async def claim_tasks(task_ids: list[int], assignee: str) -> list[int]:
+    if not task_ids:
+        return []
+
+    session = get_db()
+    query = text("""
+        UPDATE task_attempts
+        SET assignee = :assignee, assigned_at = NOW()
+        WHERE id IN (
+            SELECT latest_attempt_id FROM eval_tasks
+            WHERE id = ANY(:task_ids) AND is_finished = FALSE
+        )
+        AND status = 'unprocessed'
+        AND assignee IS NULL
+        RETURNING task_id
+    """)
+    result = await session.execute(query, {"assignee": assignee, "task_ids": task_ids})
+    rows = result.fetchall()
+    return [row[0] for row in rows]
+
+
+@with_db
+async def get_claimed_tasks(assignee: str | None = None) -> list[EvalTaskRow]:
+    session = get_db()
+    if assignee is not None:
+        query = text("""
+            SELECT * FROM eval_tasks_view
+            WHERE assignee = :assignee AND is_finished = FALSE
+            ORDER BY created_at ASC
+        """)
+        result = await session.execute(query, {"assignee": assignee})
+    else:
+        query = text("""
+            SELECT * FROM eval_tasks_view
+            WHERE assignee IS NOT NULL AND is_finished = FALSE
+            ORDER BY created_at ASC
+        """)
+        result = await session.execute(query)
+    rows = result.mappings().all()
+    return [EvalTaskRow(**dict(row)) for row in rows]
+
+
+@with_db
+async def start_task(task_id: int) -> None:
+    session = get_db()
+    query = text("""
+        UPDATE task_attempts
+        SET status = 'running', started_at = NOW()
+        WHERE id = (SELECT latest_attempt_id FROM eval_tasks WHERE id = :task_id)
+    """)
+    await session.execute(query, {"task_id": task_id})
+
+
+@with_db
+async def finish_task(
+    task_id: int, status: FinishedTaskStatus, status_details: dict[str, Any], log_path: str | None = None
+) -> None:
+    session = get_db()
+
+    update_attempt = text("""
+        UPDATE task_attempts
+        SET status = :status, finished_at = NOW(), status_details = :status_details, output_log_path = :log_path
+        WHERE id = (SELECT latest_attempt_id FROM eval_tasks WHERE id = :task_id)
+    """)
+    await session.execute(
+        update_attempt,
+        {"status": status, "status_details": status_details, "log_path": log_path, "task_id": task_id},
+    )
+
+    get_attempt = text("""
+        SELECT attempt_number FROM task_attempts
+        WHERE id = (SELECT latest_attempt_id FROM eval_tasks WHERE id = :task_id)
+    """)
+    result = await session.execute(get_attempt, {"task_id": task_id})
+    row = result.fetchone()
+    if row is None:
+        raise RuntimeError(f"Failed to get attempt number for task {task_id}")
+    current_attempt = row[0]
+
+    should_finish = status != "system_error" or current_attempt >= 2
+
+    if should_finish:
+        finish_task_query = text("UPDATE eval_tasks SET is_finished = TRUE WHERE id = :task_id")
+        await session.execute(finish_task_query, {"task_id": task_id})
+    else:
+        create_attempt = text("""
+            INSERT INTO task_attempts (task_id, attempt_number, status)
+            VALUES (:task_id, :attempt_number, 'unprocessed')
+        """)
+        await session.execute(create_attempt, {"task_id": task_id, "attempt_number": current_attempt + 1})
+
+        update_latest = text("""
+            UPDATE eval_tasks
+            SET latest_attempt_id = (
+                SELECT id FROM task_attempts WHERE task_id = :task_id ORDER BY attempt_number DESC LIMIT 1
+            )
+            WHERE id = :task_id
+        """)
+        await session.execute(update_latest, {"task_id": task_id})
+
+
+@with_db
+async def get_latest_assigned_task_for_worker(assignee: str) -> EvalTaskRow | None:
+    session = get_db()
+    query = text("""
+        SELECT * FROM eval_tasks_view
+        WHERE assignee = :assignee AND assigned_at IS NOT NULL
+        ORDER BY assigned_at DESC
+        LIMIT 1
+    """)
+    result = await session.execute(query, {"assignee": assignee})
+    row = result.mappings().first()
+    return EvalTaskRow(**dict(row)) if row else None
+
+
+@with_db
+async def get_task_by_id(task_id: int) -> EvalTaskRow | None:
+    session = get_db()
+    query = text("SELECT * FROM eval_tasks_view WHERE id = :task_id")
+    result = await session.execute(query, {"task_id": task_id})
+    row = result.mappings().first()
+    return EvalTaskRow(**dict(row)) if row else None
+
+
+@with_db
+async def get_task_attempts(task_id: int) -> list[TaskAttemptRow]:
+    session = get_db()
+    query = text("""
+        SELECT * FROM task_attempts
+        WHERE task_id = :task_id
+        ORDER BY attempt_number ASC
+    """)
+    result = await session.execute(query, {"task_id": task_id})
+    rows = result.mappings().all()
+    return [TaskAttemptRow(**dict(row)) for row in rows]
+
+
+@with_db
+async def get_all_tasks(
+    limit: int = 500,
+    statuses: list[TaskStatus] | None = None,
+    git_hash: str | None = None,
+) -> list[EvalTaskRow]:
+    session = get_db()
+
+    where_parts = []
+    params: dict[str, Any] = {"limit": limit}
+
+    if statuses:
+        where_parts.append("status = ANY(:statuses)")
+        params["statuses"] = statuses
+
+    if git_hash:
+        where_parts.append("git_hash = :git_hash")
+        params["git_hash"] = git_hash
+
+    where_clause = " AND ".join(where_parts) if where_parts else "1=1"
+
+    query = text(f"""
+        SELECT * FROM eval_tasks_view
+        WHERE {where_clause}
+        ORDER BY created_at DESC
+        LIMIT :limit
+    """)
+    result = await session.execute(query, params)
+    rows = result.mappings().all()
+    return [EvalTaskRow(**dict(row)) for row in rows]
+
+
+@with_db
+async def get_tasks_paginated(
+    page: int = 1,
+    page_size: int = 50,
+    status: str | None = None,
+    assignee: str | None = None,
+    user_id: str | None = None,
+    command: str | None = None,
+    created_at: str | None = None,
+    assigned_at: str | None = None,
+) -> tuple[list[EvalTaskRow], int]:
+    session = get_db()
+
+    where_parts = []
+    params: dict[str, Any] = {}
+
+    if status:
+        where_parts.append("status = :status")
+        params["status"] = status
+
+    if assignee:
+        where_parts.append("assignee ILIKE :assignee")
+        params["assignee"] = f"%{assignee}%"
+
+    if user_id:
+        where_parts.append("user_id ILIKE :user_id")
+        params["user_id"] = f"%{user_id}%"
+
+    if command:
+        where_parts.append("command ILIKE :command")
+        params["command"] = f"%{command}%"
+
+    if created_at:
+        where_parts.append("CAST(created_at AS TEXT) ILIKE :created_at")
+        params["created_at"] = f"%{created_at}%"
+
+    if assigned_at:
+        where_parts.append("CAST(assigned_at AS TEXT) ILIKE :assigned_at")
+        params["assigned_at"] = f"%{assigned_at}%"
+
+    where_clause = " AND ".join(where_parts) if where_parts else "1=1"
+
+    count_query = text(f"SELECT COUNT(*) FROM eval_tasks_view WHERE {where_clause}")
+    count_result = await session.execute(count_query, params)
+    total_count: int = count_result.scalar_one()
+
+    offset = (page - 1) * page_size
+    params["limit"] = page_size
+    params["offset"] = offset
+
+    query = text(f"""
+        SELECT * FROM eval_tasks_view
+        WHERE {where_clause}
+        ORDER BY created_at DESC
+        LIMIT :limit OFFSET :offset
+    """)
+    result = await session.execute(query, params)
+    rows = result.mappings().all()
+
+    return [EvalTaskRow(**dict(row)) for row in rows], total_count
+
+
+@with_db
+async def get_git_hashes_for_workers(assignees: list[str]) -> dict[str, list[str]]:
+    if not assignees:
+        return {}
+
+    session = get_db()
+    query = text("""
+        SELECT DISTINCT assignee, git_hash
+        FROM eval_tasks_view
+        WHERE assignee = ANY(:assignees)
+    """)
+    result = await session.execute(query, {"assignees": assignees})
+    rows = result.fetchall()
+
+    res: dict[str, list[str]] = defaultdict(list)
+    for row in rows:
+        if row[1]:
+            res[row[0]].append(row[1])
+    return dict(res)
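With the repository methods replaced by module-level functions decorated with `@with_db`, callers no longer need a `MettaRepo` instance. A short usage sketch of the functions defined above; the worker name, command, and attribute values are illustrative only:

```python
import asyncio
from typing import Any

from metta.app_backend.queries import eval_task_queries


async def run_one_eval() -> None:
    # Enqueue a task; create_eval_task also creates the first attempt (attempt_number=0).
    task = await eval_task_queries.create_eval_task(
        command="evaluate",  # illustrative command
        user_id="example-user",
        attributes={"policy_uri": "s3://example/policy"},  # illustrative attributes
    )

    # A worker claims the latest attempt, runs it, and reports the result.
    claimed = await eval_task_queries.claim_tasks([task.id], assignee="worker-1")
    if claimed:
        await eval_task_queries.start_task(task.id)
        details: dict[str, Any] = {"exit_code": 0}
        # Finishing with "system_error" would instead schedule a retry attempt
        # (up to three attempts in total), per finish_task above.
        await eval_task_queries.finish_task(task.id, status="done", status_details=details)


asyncio.run(run_one_eval())
```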
diff --git a/app_backend/src/metta/app_backend/routes/eval_task_routes.py b/app_backend/src/metta/app_backend/routes/eval_task_routes.py
index c72bc33f66a..63f1ba10b86 100644
--- a/app_backend/src/metta/app_backend/routes/eval_task_routes.py
+++ b/app_backend/src/metta/app_backend/routes/eval_task_routes.py
@@ -13,7 +13,9 @@ import gitta as git
 
 from metta.app_backend.auth import CheckSoftmaxUser
-from metta.app_backend.metta_repo import EvalTaskRow, FinishedTaskStatus, MettaRepo, TaskAttemptRow, TaskStatus
+from metta.app_backend.models.eval_task import FinishedTaskStatus, TaskStatus
+from metta.app_backend.queries import eval_task_queries
+from metta.app_backend.queries.eval_task_queries import EvalTaskRow, TaskAttemptRow
 from metta.app_backend.route_logger import timed_http_handler
 from metta.common.util.git_repo import REPO_SLUG
@@ -83,7 +85,7 @@ class TaskAttemptsResponse(BaseModel):
     attempts: list[TaskAttemptRow]
 
 
-def create_eval_task_router(stats_repo: MettaRepo) -> APIRouter:
+def create_eval_task_router() -> APIRouter:
     router = APIRouter(prefix="/tasks", tags=["eval_tasks"])
 
     # Cache for latest commit
@@ -121,7 +123,7 @@ async def create_task(request: TaskCreateRequest, user: CheckSoftmaxUser) -> Eva
             os.remove(file_path)
 
-        task = await stats_repo.create_eval_task(
+        task = await eval_task_queries.create_eval_task(
            command=request.command,
            git_hash=request.git_hash or await get_cached_latest_commit(),
            attributes=request.attributes,
@@ -133,7 +135,7 @@ async def create_task(request: TaskCreateRequest, user: CheckSoftmaxUser) -> Eva
     @router.get("/latest", response_model=EvalTaskRow)
     @timed_http_handler
     async def get_latest_assigned_task_for_worker(assignee: str) -> EvalTaskRow | None:
-        task = await stats_repo.get_latest_assigned_task_for_worker(assignee=assignee)
+        task = await eval_task_queries.get_latest_assigned_task_for_worker(assignee=assignee)
         return task
 
     @router.get("/available")
@@ -141,13 +143,13 @@ async def get_latest_assigned_task_for_worker(assignee: str) -> EvalTaskRow | No
     async def get_available_tasks(
         limit: int = Query(default=200, ge=1, le=1000),
     ) -> TasksResponse:
-        tasks = await stats_repo.get_available_tasks(limit=limit)
+        tasks = await eval_task_queries.get_available_tasks(limit=limit)
         return TasksResponse(tasks=tasks)
 
     @router.post("/claim")
     @timed_http_handler
     async def claim_tasks(request: TaskClaimRequest) -> TaskClaimResponse:
-        claimed_ids = await stats_repo.claim_tasks(
+        claimed_ids = await eval_task_queries.claim_tasks(
             task_ids=request.tasks,
             assignee=request.assignee,
         )
@@ -156,13 +158,13 @@ async def claim_tasks(request: TaskClaimRequest) -> TaskClaimResponse:
     @router.get("/claimed")
     @timed_http_handler
     async def get_claimed_tasks(assignee: str | None = Query(None)) -> TasksResponse:
-        tasks = await stats_repo.get_claimed_tasks(assignee=assignee)
+        tasks = await eval_task_queries.get_claimed_tasks(assignee=assignee)
         return TasksResponse(tasks=tasks)
 
     @router.post("/git-hashes")
     @timed_http_handler
     async def get_git_hashes_for_workers(request: GitHashesRequest) -> GitHashesResponse:
-        git_hashes = await stats_repo.get_git_hashes_for_workers(assignees=request.assignees)
+        git_hashes = await eval_task_queries.get_git_hashes_for_workers(assignees=request.assignees)
         return GitHashesResponse(git_hashes=git_hashes)
 
     @router.get("/all")
@@ -172,7 +174,7 @@ async def get_all_tasks(
         statuses: list[TaskStatus] | None = Query(default=None),
         git_hash: str | None = Query(default=None),
     ) -> TasksResponse:
-        tasks = await stats_repo.get_all_tasks(
+        tasks = await eval_task_queries.get_all_tasks(
             limit=limit,
             statuses=statuses,
             git_hash=git_hash,
@@ -191,7 +193,7 @@ async def get_tasks_paginated(
         created_at: str | None = Query(default=None),
         assigned_at: str | None = Query(default=None),
     ) -> PaginatedTasksResponse:
-        tasks, total_count = await stats_repo.get_tasks_paginated(
+        tasks, total_count = await eval_task_queries.get_tasks_paginated(
             page=page,
             page_size=page_size,
             status=status,
@@ -213,13 +215,13 @@ async def get_tasks_paginated(
     @router.post("/{task_id}/start")
     @timed_http_handler
     async def start_task(task_id: int) -> TaskIdResponse:
-        await stats_repo.start_task(task_id=task_id)
+        await eval_task_queries.start_task(task_id=task_id)
         return TaskIdResponse(task_id=task_id)
 
     @router.post("/{task_id}/finish")
     @timed_http_handler
     async def finish_task(task_id: int, request: TaskFinishRequest) -> TaskIdResponse:
-        await stats_repo.finish_task(
+        await eval_task_queries.finish_task(
             task_id=task_id, status=request.status, status_details=request.status_details, log_path=request.log_path
         )
         return TaskIdResponse(task_id=task_id)
@@ -228,7 +230,7 @@ async def finish_task(task_id: int, request: TaskFinishRequest) -> TaskIdRespons
     @timed_http_handler
     async def get_task(task_id: int) -> EvalTaskRow:
         """Get a single task by ID with full details including attributes."""
-        task = await stats_repo.get_task_by_id(task_id)
+        task = await eval_task_queries.get_task_by_id(task_id)
         if not task:
             raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
         return task
@@ -237,7 +239,7 @@ async def get_task(task_id: int) -> EvalTaskRow:
     @timed_http_handler
     async def get_task_attempts(task_id: int) -> TaskAttemptsResponse:
         """Get all attempts for a specific task."""
-        attempts = await stats_repo.get_task_attempts(task_id)
+        attempts = await eval_task_queries.get_task_attempts(task_id)
         return TaskAttemptsResponse(attempts=attempts)
 
     @router.get("/{task_id}/logs/{log_type}")
@@ -256,7 +258,7 @@ async def get_task_logs(task_id: int, log_type: str):
             raise HTTPException(status_code=400, detail="log_type must be 'stdout' or 'stderr' or 'output'")
 
         # Get the task to retrieve the log path from attributes
-        task = await stats_repo.get_task_by_id(task_id)
+        task = await eval_task_queries.get_task_by_id(task_id)
         if not task:
             raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
 
diff --git a/app_backend/src/metta/app_backend/server.py b/app_backend/src/metta/app_backend/server.py
index 3db887f2c1d..f9d3effbf83 100755
--- a/app_backend/src/metta/app_backend/server.py
+++ b/app_backend/src/metta/app_backend/server.py
@@ -121,7 +121,7 @@ def create_app(stats_repo: MettaRepo) -> fastapi.FastAPI:
     )
 
     # Create routers with the provided StatsRepo
-    eval_task_router = eval_task_routes.create_eval_task_router(stats_repo)
+    eval_task_router = eval_task_routes.create_eval_task_router()
     sql_router = sql_routes.create_sql_router(stats_repo)
     stats_router = stats_routes.create_stats_router()
     sweep_router = sweep_routes.create_sweep_router(stats_repo)
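The rewritten queries still read from `eval_tasks_view`, whose definition is not part of this diff. Based on the columns consumed by `EvalTaskRow`, it is presumably a join of each task to its latest attempt, along these lines; this is a hypothetical reconstruction, kept here as a migration-style Python constant, and the real view may differ:

```python
# Hypothetical reconstruction of eval_tasks_view; the actual migration may differ.
EVAL_TASKS_VIEW_SQL = """
CREATE OR REPLACE VIEW eval_tasks_view AS
SELECT
    t.id,
    t.command,
    t.data_uri,
    t.git_hash,
    t.attributes,
    t.user_id,
    t.created_at,
    t.is_finished,
    t.latest_attempt_id,
    a.attempt_number,
    a.status,
    a.status_details,
    a.assigned_at,
    a.assignee,
    a.started_at,
    a.finished_at,
    a.output_log_path
FROM eval_tasks t
LEFT JOIN task_attempts a ON a.id = t.latest_attempt_id;
"""
```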