diff --git a/api/database.py b/api/database.py index 90dc49af..49b19fa8 100644 --- a/api/database.py +++ b/api/database.py @@ -8,7 +8,7 @@ import sys from datetime import datetime, timezone from pathlib import Path -from typing import Optional +from typing import Generator, Optional def _utc_now() -> datetime: @@ -26,13 +26,16 @@ def _utc_now() -> datetime: String, Text, create_engine, + event, text, ) -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session, relationship, sessionmaker +from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker from sqlalchemy.types import JSON -Base = declarative_base() + +class Base(DeclarativeBase): + """SQLAlchemy 2.0 style declarative base.""" + pass class Feature(Base): @@ -307,11 +310,11 @@ def _migrate_add_schedules_tables(engine) -> None: # Create schedules table if missing if "schedules" not in existing_tables: - Schedule.__table__.create(bind=engine) + Schedule.__table__.create(bind=engine) # type: ignore[attr-defined] # Create schedule_overrides table if missing if "schedule_overrides" not in existing_tables: - ScheduleOverride.__table__.create(bind=engine) + ScheduleOverride.__table__.create(bind=engine) # type: ignore[attr-defined] # Add crash_count column if missing (for upgrades) if "schedules" in existing_tables: @@ -332,6 +335,41 @@ def _migrate_add_schedules_tables(engine) -> None: conn.commit() +def _configure_sqlite_immediate_transactions(engine) -> None: + """Configure engine for IMMEDIATE transactions via event hooks. + + Per SQLAlchemy docs: https://docs.sqlalchemy.org/en/20/dialects/sqlite.html + + This replaces fragile pysqlite implicit transaction handling with explicit + BEGIN IMMEDIATE at transaction start. Benefits: + - Acquires write lock immediately, preventing stale reads + - Works correctly regardless of prior ORM operations + - Future-proof: won't break when pysqlite legacy mode is removed in Python 3.16 + + Note: We only use IMMEDIATE for user transactions, not for PRAGMA statements. + The do_begin hook only fires when SQLAlchemy starts a transaction, which + doesn't happen for PRAGMA commands when using conn.exec_driver_sql() directly. + """ + @event.listens_for(engine, "connect") + def do_connect(dbapi_connection, connection_record): + # Disable pysqlite's implicit transaction handling + dbapi_connection.isolation_level = None + + # Execute PRAGMAs immediately on raw connection before any transactions + # This is safe because isolation_level=None means no implicit transactions + cursor = dbapi_connection.cursor() + try: + # These PRAGMAs need to run outside of any transaction + cursor.execute("PRAGMA busy_timeout=30000") + finally: + cursor.close() + + @event.listens_for(engine, "begin") + def do_begin(conn): + # Use IMMEDIATE for all transactions to prevent stale reads + conn.exec_driver_sql("BEGIN IMMEDIATE") + + def create_database(project_dir: Path) -> tuple: """ Create database and return engine + session maker. 
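A minimal, self-contained sketch of the event-hook pattern the hunk above introduces, following the SQLAlchemy pysqlite recipe it links to. It assumes SQLAlchemy 2.x; the make_engine helper and the demo database URL are illustrative names, not part of this patch.

    # Sketch only: same connect/begin hooks as _configure_sqlite_immediate_transactions,
    # shown in isolation so the transaction behavior is easy to try out.
    from sqlalchemy import create_engine, event, text

    def make_engine(db_url: str = "sqlite:///demo.db"):
        engine = create_engine(
            db_url,
            connect_args={"check_same_thread": False, "timeout": 30},
        )

        @event.listens_for(engine, "connect")
        def do_connect(dbapi_connection, connection_record):
            # Disable pysqlite's implicit BEGIN so transaction start is fully explicit
            dbapi_connection.isolation_level = None

        @event.listens_for(engine, "begin")
        def do_begin(conn):
            # Take the write lock as soon as SQLAlchemy starts a transaction
            conn.exec_driver_sql("BEGIN IMMEDIATE")

        return engine

    if __name__ == "__main__":
        engine = make_engine()
        with engine.connect() as conn:
            # The first statement triggers the begin hook, so this whole block
            # runs inside a single BEGIN IMMEDIATE transaction.
            conn.execute(text("CREATE TABLE IF NOT EXISTS t (x INTEGER)"))
            conn.execute(text("INSERT INTO t (x) VALUES (1)"))
            conn.commit()

With this in place, any writer grabs SQLite's RESERVED lock at transaction start instead of at its first write, which is what prevents the stale read-modify-write pattern the rest of the patch is defending against.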
@@ -351,21 +389,37 @@ def create_database(project_dir: Path) -> tuple: return _engine_cache[cache_key] db_url = get_database_url(project_dir) - engine = create_engine(db_url, connect_args={ - "check_same_thread": False, - "timeout": 30 # Wait up to 30s for locks - }) - Base.metadata.create_all(bind=engine) # Choose journal mode based on filesystem type # WAL mode doesn't work reliably on network filesystems and can cause corruption is_network = _is_network_path(project_dir) journal_mode = "DELETE" if is_network else "WAL" + engine = create_engine(db_url, connect_args={ + "check_same_thread": False, + "timeout": 30 # Wait up to 30s for locks + }) + + # Set journal mode BEFORE configuring event hooks + # PRAGMA journal_mode must run outside of a transaction, and our event hooks + # start a transaction with BEGIN IMMEDIATE on every operation with engine.connect() as conn: - conn.execute(text(f"PRAGMA journal_mode={journal_mode}")) - conn.execute(text("PRAGMA busy_timeout=30000")) - conn.commit() + # Get raw DBAPI connection to execute PRAGMA outside transaction + raw_conn = conn.connection.dbapi_connection + if raw_conn is None: + raise RuntimeError("Failed to get raw DBAPI connection") + cursor = raw_conn.cursor() + try: + cursor.execute(f"PRAGMA journal_mode={journal_mode}") + cursor.execute("PRAGMA busy_timeout=30000") + finally: + cursor.close() + + # Configure IMMEDIATE transactions via event hooks AFTER setting PRAGMAs + # This must happen before create_all() and migrations run + _configure_sqlite_immediate_transactions(engine) + + Base.metadata.create_all(bind=engine) # Migrate existing databases _migrate_add_in_progress_column(engine) @@ -417,7 +471,7 @@ def set_session_maker(session_maker: sessionmaker) -> None: _session_maker = session_maker -def get_db() -> Session: +def get_db() -> Generator[Session, None, None]: """ Dependency for FastAPI to get database session. @@ -431,3 +485,74 @@ def get_db() -> Session: yield db finally: db.close() + + +# ============================================================================= +# Atomic Transaction Helpers for Parallel Mode +# ============================================================================= +# These helpers prevent database corruption when multiple processes access the +# same SQLite database concurrently. They use IMMEDIATE transactions which +# acquire write locks at the start (preventing stale reads) and atomic +# UPDATE ... WHERE clauses (preventing check-then-modify races). + + +from contextlib import contextmanager + + +@contextmanager +def atomic_transaction(session_maker, isolation_level: str = "IMMEDIATE"): + """Context manager for atomic SQLite transactions. + + Acquires a write lock immediately via BEGIN IMMEDIATE, preventing + stale reads in read-modify-write patterns. This is essential for + preventing race conditions in parallel mode. + + Note: The engine is configured via _configure_sqlite_immediate_transactions() + to use BEGIN IMMEDIATE for all transactions. The isolation_level parameter + is kept for backwards compatibility and for EXCLUSIVE transactions when + blocking readers is required. 
+ + Args: + session_maker: SQLAlchemy sessionmaker + isolation_level: "IMMEDIATE" (default) or "EXCLUSIVE" + - IMMEDIATE: Acquires write lock at transaction start (default via event hooks) + - EXCLUSIVE: Also blocks other readers (requires explicit BEGIN EXCLUSIVE) + + Yields: + SQLAlchemy session with automatic commit/rollback + + Example: + with atomic_transaction(session_maker) as session: + # All reads in this block are protected by write lock + feature = session.query(Feature).filter(...).first() + feature.priority = new_priority + # Commit happens automatically on exit + """ + session = session_maker() + try: + # For EXCLUSIVE mode, override the default IMMEDIATE from event hooks + # For IMMEDIATE mode, the event hooks handle BEGIN IMMEDIATE automatically + if isolation_level == "EXCLUSIVE": + session.execute(text("BEGIN EXCLUSIVE")) + # Note: For IMMEDIATE, we don't issue BEGIN here - the event hook handles it + # This prevents the fragile "BEGIN on already-begun transaction" issue + yield session + session.commit() + except Exception: + try: + session.rollback() + except Exception: + pass # Don't let rollback failure mask original error + raise + finally: + session.close() + + +# Note: The following functions were removed as dead code (never imported/called): +# - atomic_claim_feature() +# - atomic_mark_passing() +# - atomic_update_priority_to_end() +# - atomic_get_next_priority() +# +# The MCP server reimplements this logic inline with proper atomic UPDATE WHERE +# clauses. See mcp_server/feature_mcp.py for the actual implementation. diff --git a/mcp_server/feature_mcp.py b/mcp_server/feature_mcp.py index a394f1e9..3dac260b 100755 --- a/mcp_server/feature_mcp.py +++ b/mcp_server/feature_mcp.py @@ -30,18 +30,18 @@ import json import os import sys -import threading from contextlib import asynccontextmanager from pathlib import Path from typing import Annotated from mcp.server.fastmcp import FastMCP from pydantic import BaseModel, Field +from sqlalchemy import text # Add parent directory to path so we can import from api module sys.path.insert(0, str(Path(__file__).parent.parent)) -from api.database import Feature, create_database +from api.database import Feature, atomic_transaction, create_database from api.dependency_resolver import ( MAX_DEPENDENCIES_PER_FEATURE, compute_scheduling_scores, @@ -96,8 +96,9 @@ class BulkCreateInput(BaseModel): _session_maker = None _engine = None -# Lock for priority assignment to prevent race conditions -_priority_lock = threading.Lock() +# NOTE: The old threading.Lock() was removed because it only worked per-process, +# not cross-process. In parallel mode, multiple MCP servers run in separate +# processes, so the lock was useless. We now use atomic SQL operations instead. @asynccontextmanager @@ -235,6 +236,8 @@ def feature_mark_passing( Updates the feature's passes field to true and clears the in_progress flag. Use this after you have implemented the feature and verified it works correctly. + Uses atomic SQL UPDATE for parallel safety. 
+ Args: feature_id: The ID of the feature to mark as passing @@ -243,16 +246,22 @@ def feature_mark_passing( """ session = get_session() try: + # First get the feature name for the response feature = session.query(Feature).filter(Feature.id == feature_id).first() - if feature is None: return json.dumps({"error": f"Feature with ID {feature_id} not found"}) - feature.passes = True - feature.in_progress = False + name = feature.name + + # Atomic update - prevents race conditions in parallel mode + session.execute(text(""" + UPDATE features + SET passes = 1, in_progress = 0 + WHERE id = :id + """), {"id": feature_id}) session.commit() - return json.dumps({"success": True, "feature_id": feature_id, "name": feature.name}) + return json.dumps({"success": True, "feature_id": feature_id, "name": name}) except Exception as e: session.rollback() return json.dumps({"error": f"Failed to mark feature passing: {str(e)}"}) @@ -270,6 +279,8 @@ def feature_mark_failing( Use this when a testing agent discovers that a previously-passing feature no longer works correctly (regression detected). + Uses atomic SQL UPDATE for parallel safety. + After marking as failing, you should: 1. Investigate the root cause 2. Fix the regression @@ -284,14 +295,20 @@ def feature_mark_failing( """ session = get_session() try: + # Check if feature exists first feature = session.query(Feature).filter(Feature.id == feature_id).first() - if feature is None: return json.dumps({"error": f"Feature with ID {feature_id} not found"}) - feature.passes = False - feature.in_progress = False + # Atomic update - prevents race conditions in parallel mode + session.execute(text(""" + UPDATE features + SET passes = 0, in_progress = 0 + WHERE id = :id + """), {"id": feature_id}) session.commit() + + # Refresh to get updated state session.refresh(feature) return json.dumps({ @@ -320,6 +337,8 @@ def feature_skip( worked on after all other pending features. Also clears the in_progress flag so the feature returns to "pending" status. + Uses atomic SQL UPDATE with subquery for parallel safety. 
+ Args: feature_id: The ID of the feature to skip @@ -337,25 +356,28 @@ def feature_skip( return json.dumps({"error": "Cannot skip a feature that is already passing"}) old_priority = feature.priority + name = feature.name + + # Atomic update: set priority to max+1 in a single statement + # This prevents race conditions where two features get the same priority + session.execute(text(""" + UPDATE features + SET priority = (SELECT COALESCE(MAX(priority), 0) + 1 FROM features), + in_progress = 0 + WHERE id = :id + """), {"id": feature_id}) + session.commit() - # Use lock to prevent race condition in priority assignment - with _priority_lock: - # Get max priority and set this feature to max + 1 - max_priority_result = session.query(Feature.priority).order_by(Feature.priority.desc()).first() - new_priority = (max_priority_result[0] + 1) if max_priority_result else 1 - - feature.priority = new_priority - feature.in_progress = False - session.commit() - + # Refresh to get new priority session.refresh(feature) + new_priority = feature.priority return json.dumps({ - "id": feature.id, - "name": feature.name, + "id": feature_id, + "name": name, "old_priority": old_priority, "new_priority": new_priority, - "message": f"Feature '{feature.name}' moved to end of queue" + "message": f"Feature '{name}' moved to end of queue" }) except Exception as e: session.rollback() @@ -373,6 +395,9 @@ def feature_mark_in_progress( This prevents other agent sessions from working on the same feature. Call this after getting your assigned feature details with feature_get_by_id. + Uses atomic UPDATE WHERE for parallel safety - prevents two agents from + claiming the same feature simultaneously. + Args: feature_id: The ID of the feature to mark as in-progress @@ -381,21 +406,27 @@ def feature_mark_in_progress( """ session = get_session() try: - feature = session.query(Feature).filter(Feature.id == feature_id).first() - - if feature is None: - return json.dumps({"error": f"Feature with ID {feature_id} not found"}) - - if feature.passes: - return json.dumps({"error": f"Feature with ID {feature_id} is already passing"}) - - if feature.in_progress: - return json.dumps({"error": f"Feature with ID {feature_id} is already in-progress"}) - - feature.in_progress = True + # Atomic claim: only succeeds if feature is not already claimed or passing + result = session.execute(text(""" + UPDATE features + SET in_progress = 1 + WHERE id = :id AND passes = 0 AND in_progress = 0 + """), {"id": feature_id}) session.commit() - session.refresh(feature) + if result.rowcount == 0: + # Check why the claim failed + feature = session.query(Feature).filter(Feature.id == feature_id).first() + if feature is None: + return json.dumps({"error": f"Feature with ID {feature_id} not found"}) + if feature.passes: + return json.dumps({"error": f"Feature with ID {feature_id} is already passing"}) + if feature.in_progress: + return json.dumps({"error": f"Feature with ID {feature_id} is already in-progress"}) + return json.dumps({"error": "Failed to mark feature in-progress for unknown reason"}) + + # Fetch the claimed feature + feature = session.query(Feature).filter(Feature.id == feature_id).first() return json.dumps(feature.to_dict()) except Exception as e: session.rollback() @@ -413,6 +444,8 @@ def feature_claim_and_get( Combines feature_mark_in_progress + feature_get_by_id into a single operation. If already in-progress, still returns the feature details (idempotent). + Uses atomic UPDATE WHERE for parallel safety. 
+ Args: feature_id: The ID of the feature to claim and retrieve @@ -421,24 +454,35 @@ def feature_claim_and_get( """ session = get_session() try: + # First check if feature exists and get initial state feature = session.query(Feature).filter(Feature.id == feature_id).first() - if feature is None: return json.dumps({"error": f"Feature with ID {feature_id} not found"}) if feature.passes: return json.dumps({"error": f"Feature with ID {feature_id} is already passing"}) - # Idempotent: if already in-progress, just return details - already_claimed = feature.in_progress - if not already_claimed: - feature.in_progress = True - session.commit() + # Try atomic claim: only succeeds if not already claimed + result = session.execute(text(""" + UPDATE features + SET in_progress = 1 + WHERE id = :id AND passes = 0 AND in_progress = 0 + """), {"id": feature_id}) + session.commit() + + # Determine if we claimed it or it was already claimed + already_claimed = result.rowcount == 0 + if already_claimed: + # Verify it's in_progress (not some other failure condition) session.refresh(feature) + if not feature.in_progress: + return json.dumps({"error": f"Failed to claim feature {feature_id} for unknown reason"}) - result = feature.to_dict() - result["already_claimed"] = already_claimed - return json.dumps(result) + # Refresh to get current state + session.refresh(feature) + result_dict = feature.to_dict() + result_dict["already_claimed"] = already_claimed + return json.dumps(result_dict) except Exception as e: session.rollback() return json.dumps({"error": f"Failed to claim feature: {str(e)}"}) @@ -455,6 +499,8 @@ def feature_clear_in_progress( Use this when abandoning a feature or manually unsticking a stuck feature. The feature will return to the pending queue. + Uses atomic SQL UPDATE for parallel safety. + Args: feature_id: The ID of the feature to clear in-progress status @@ -463,15 +509,20 @@ def feature_clear_in_progress( """ session = get_session() try: + # Check if feature exists feature = session.query(Feature).filter(Feature.id == feature_id).first() - if feature is None: return json.dumps({"error": f"Feature with ID {feature_id} not found"}) - feature.in_progress = False + # Atomic update - idempotent, safe in parallel mode + session.execute(text(""" + UPDATE features + SET in_progress = 0 + WHERE id = :id + """), {"id": feature_id}) session.commit() - session.refresh(feature) + session.refresh(feature) return json.dumps(feature.to_dict()) except Exception as e: session.rollback() @@ -492,6 +543,8 @@ def feature_create_bulk( This is typically used by the initializer agent to set up the initial feature list from the app specification. + Uses EXCLUSIVE transaction to prevent priority collisions in parallel mode. 
+ Args: features: List of features to create, each with: - category (str): Feature category @@ -506,13 +559,14 @@ def feature_create_bulk( Returns: JSON with: created (int) - number of features created, with_dependencies (int) """ - session = get_session() try: - # Use lock to prevent race condition in priority assignment - with _priority_lock: - # Get the starting priority - max_priority_result = session.query(Feature.priority).order_by(Feature.priority.desc()).first() - start_priority = (max_priority_result[0] + 1) if max_priority_result else 1 + # Use EXCLUSIVE transaction for bulk inserts to prevent conflicts + with atomic_transaction(_session_maker, "EXCLUSIVE") as session: + # Get the starting priority atomically within the transaction + result = session.execute(text(""" + SELECT COALESCE(MAX(priority), 0) FROM features + """)).fetchone() + start_priority = (result[0] or 0) + 1 # First pass: validate all features and their index-based dependencies for i, feature_data in enumerate(features): @@ -546,11 +600,11 @@ def feature_create_bulk( "error": f"Feature at index {i} cannot depend on feature at index {idx} (forward reference not allowed)" }) - # Second pass: create all features + # Second pass: create all features with reserved priorities created_features: list[Feature] = [] for i, feature_data in enumerate(features): db_feature = Feature( - priority=start_priority + i, + priority=start_priority + i, # Guaranteed unique within EXCLUSIVE transaction category=feature_data["category"], name=feature_data["name"], description=feature_data["description"], @@ -574,17 +628,13 @@ def feature_create_bulk( created_features[i].dependencies = sorted(dep_ids) deps_count += 1 - session.commit() - - return json.dumps({ - "created": len(created_features), - "with_dependencies": deps_count - }) + # Commit happens automatically on context manager exit + return json.dumps({ + "created": len(created_features), + "with_dependencies": deps_count + }) except Exception as e: - session.rollback() return json.dumps({"error": str(e)}) - finally: - session.close() @mcp.tool() @@ -599,6 +649,8 @@ def feature_create( Use this when the user asks to add a new feature, capability, or test case. The feature will be added with the next available priority number. + Uses IMMEDIATE transaction for parallel safety. 
+ Args: category: Feature category for grouping (e.g., 'Authentication', 'API', 'UI') name: Descriptive name for the feature @@ -608,13 +660,14 @@ def feature_create( Returns: JSON with the created feature details including its ID """ - session = get_session() try: - # Use lock to prevent race condition in priority assignment - with _priority_lock: - # Get the next priority - max_priority_result = session.query(Feature.priority).order_by(Feature.priority.desc()).first() - next_priority = (max_priority_result[0] + 1) if max_priority_result else 1 + # Use IMMEDIATE transaction to prevent priority collisions + with atomic_transaction(_session_maker, "IMMEDIATE") as session: + # Get the next priority atomically within the transaction + result = session.execute(text(""" + SELECT COALESCE(MAX(priority), 0) + 1 FROM features + """)).fetchone() + next_priority = result[0] db_feature = Feature( priority=next_priority, @@ -626,20 +679,18 @@ def feature_create( in_progress=False, ) session.add(db_feature) - session.commit() + session.flush() # Get the ID - session.refresh(db_feature) + feature_dict = db_feature.to_dict() + # Commit happens automatically on context manager exit return json.dumps({ "success": True, "message": f"Created feature: {name}", - "feature": db_feature.to_dict() + "feature": feature_dict }) except Exception as e: - session.rollback() return json.dumps({"error": str(e)}) - finally: - session.close() @mcp.tool() @@ -652,6 +703,8 @@ def feature_add_dependency( The dependency_id feature must be completed before feature_id can be started. Validates: self-reference, existence, circular dependencies, max limit. + Uses IMMEDIATE transaction to prevent stale reads during cycle detection. + Args: feature_id: The ID of the feature that will depend on another feature dependency_id: The ID of the feature that must be completed first @@ -659,52 +712,49 @@ def feature_add_dependency( Returns: JSON with success status and updated dependencies list, or error message """ - session = get_session() try: - # Security: Self-reference check + # Security: Self-reference check (can do before transaction) if feature_id == dependency_id: return json.dumps({"error": "A feature cannot depend on itself"}) - feature = session.query(Feature).filter(Feature.id == feature_id).first() - dependency = session.query(Feature).filter(Feature.id == dependency_id).first() - - if not feature: - return json.dumps({"error": f"Feature {feature_id} not found"}) - if not dependency: - return json.dumps({"error": f"Dependency feature {dependency_id} not found"}) - - current_deps = feature.dependencies or [] - - # Security: Max dependencies limit - if len(current_deps) >= MAX_DEPENDENCIES_PER_FEATURE: - return json.dumps({"error": f"Maximum {MAX_DEPENDENCIES_PER_FEATURE} dependencies allowed per feature"}) - - # Check if already exists - if dependency_id in current_deps: - return json.dumps({"error": "Dependency already exists"}) - - # Security: Circular dependency check - # would_create_circular_dependency(features, source_id, target_id) - # source_id = feature gaining the dependency, target_id = feature being depended upon - all_features = [f.to_dict() for f in session.query(Feature).all()] - if would_create_circular_dependency(all_features, feature_id, dependency_id): - return json.dumps({"error": "Cannot add: would create circular dependency"}) - - # Add dependency - current_deps.append(dependency_id) - feature.dependencies = sorted(current_deps) - session.commit() - - return json.dumps({ - "success": True, - "feature_id": 
feature_id, - "dependencies": feature.dependencies - }) + # Use IMMEDIATE transaction for consistent cycle detection + with atomic_transaction(_session_maker, "IMMEDIATE") as session: + feature = session.query(Feature).filter(Feature.id == feature_id).first() + dependency = session.query(Feature).filter(Feature.id == dependency_id).first() + + if not feature: + return json.dumps({"error": f"Feature {feature_id} not found"}) + if not dependency: + return json.dumps({"error": f"Dependency feature {dependency_id} not found"}) + + current_deps = feature.dependencies or [] + + # Security: Max dependencies limit + if len(current_deps) >= MAX_DEPENDENCIES_PER_FEATURE: + return json.dumps({"error": f"Maximum {MAX_DEPENDENCIES_PER_FEATURE} dependencies allowed per feature"}) + + # Check if already exists + if dependency_id in current_deps: + return json.dumps({"error": "Dependency already exists"}) + + # Security: Circular dependency check + # Within IMMEDIATE transaction, snapshot is protected by write lock + all_features = [f.to_dict() for f in session.query(Feature).all()] + if would_create_circular_dependency(all_features, feature_id, dependency_id): + return json.dumps({"error": "Cannot add: would create circular dependency"}) + + # Add dependency atomically + new_deps = sorted(current_deps + [dependency_id]) + feature.dependencies = new_deps + # Commit happens automatically on context manager exit + + return json.dumps({ + "success": True, + "feature_id": feature_id, + "dependencies": new_deps + }) except Exception as e: - session.rollback() return json.dumps({"error": f"Failed to add dependency: {str(e)}"}) - finally: - session.close() @mcp.tool() @@ -714,6 +764,8 @@ def feature_remove_dependency( ) -> str: """Remove a dependency from a feature. + Uses IMMEDIATE transaction for parallel safety. 
+ Args: feature_id: The ID of the feature to remove a dependency from dependency_id: The ID of the dependency to remove @@ -721,30 +773,29 @@ def feature_remove_dependency( Returns: JSON with success status and updated dependencies list, or error message """ - session = get_session() try: - feature = session.query(Feature).filter(Feature.id == feature_id).first() - if not feature: - return json.dumps({"error": f"Feature {feature_id} not found"}) - - current_deps = feature.dependencies or [] - if dependency_id not in current_deps: - return json.dumps({"error": "Dependency does not exist"}) - - current_deps.remove(dependency_id) - feature.dependencies = current_deps if current_deps else None - session.commit() - - return json.dumps({ - "success": True, - "feature_id": feature_id, - "dependencies": feature.dependencies or [] - }) + # Use IMMEDIATE transaction for consistent read-modify-write + with atomic_transaction(_session_maker, "IMMEDIATE") as session: + feature = session.query(Feature).filter(Feature.id == feature_id).first() + if not feature: + return json.dumps({"error": f"Feature {feature_id} not found"}) + + current_deps = feature.dependencies or [] + if dependency_id not in current_deps: + return json.dumps({"error": "Dependency does not exist"}) + + # Remove dependency atomically + new_deps = [d for d in current_deps if d != dependency_id] + feature.dependencies = new_deps if new_deps else None + # Commit happens automatically on context manager exit + + return json.dumps({ + "success": True, + "feature_id": feature_id, + "dependencies": new_deps + }) except Exception as e: - session.rollback() return json.dumps({"error": f"Failed to remove dependency: {str(e)}"}) - finally: - session.close() @mcp.tool() @@ -890,6 +941,8 @@ def feature_set_dependencies( Validates: self-reference, existence of all dependencies, circular dependencies, max limit. + Uses IMMEDIATE transaction to prevent stale reads during cycle detection. 
+ Args: feature_id: The ID of the feature to set dependencies for dependency_ids: List of feature IDs that must be completed first @@ -897,9 +950,8 @@ def feature_set_dependencies( Returns: JSON with success status and updated dependencies list, or error message """ - session = get_session() try: - # Security: Self-reference check + # Security: Self-reference check (can do before transaction) if feature_id in dependency_ids: return json.dumps({"error": "A feature cannot depend on itself"}) @@ -911,45 +963,45 @@ def feature_set_dependencies( if len(dependency_ids) != len(set(dependency_ids)): return json.dumps({"error": "Duplicate dependencies not allowed"}) - feature = session.query(Feature).filter(Feature.id == feature_id).first() - if not feature: - return json.dumps({"error": f"Feature {feature_id} not found"}) - - # Validate all dependencies exist - all_feature_ids = {f.id for f in session.query(Feature).all()} - missing = [d for d in dependency_ids if d not in all_feature_ids] - if missing: - return json.dumps({"error": f"Dependencies not found: {missing}"}) - - # Check for circular dependencies - all_features = [f.to_dict() for f in session.query(Feature).all()] - # Temporarily update the feature's dependencies for cycle check - test_features = [] - for f in all_features: - if f["id"] == feature_id: - test_features.append({**f, "dependencies": dependency_ids}) - else: - test_features.append(f) - - for dep_id in dependency_ids: - # source_id = feature_id (gaining dep), target_id = dep_id (being depended upon) - if would_create_circular_dependency(test_features, feature_id, dep_id): - return json.dumps({"error": f"Cannot add dependency {dep_id}: would create circular dependency"}) - - # Set dependencies - feature.dependencies = sorted(dependency_ids) if dependency_ids else None - session.commit() - - return json.dumps({ - "success": True, - "feature_id": feature_id, - "dependencies": feature.dependencies or [] - }) + # Use IMMEDIATE transaction for consistent cycle detection + with atomic_transaction(_session_maker, "IMMEDIATE") as session: + feature = session.query(Feature).filter(Feature.id == feature_id).first() + if not feature: + return json.dumps({"error": f"Feature {feature_id} not found"}) + + # Validate all dependencies exist + all_feature_ids = {f.id for f in session.query(Feature).all()} + missing = [d for d in dependency_ids if d not in all_feature_ids] + if missing: + return json.dumps({"error": f"Dependencies not found: {missing}"}) + + # Check for circular dependencies + # Within IMMEDIATE transaction, snapshot is protected by write lock + all_features = [f.to_dict() for f in session.query(Feature).all()] + # Temporarily update the feature's dependencies for cycle check + test_features = [] + for f in all_features: + if f["id"] == feature_id: + test_features.append({**f, "dependencies": dependency_ids}) + else: + test_features.append(f) + + for dep_id in dependency_ids: + if would_create_circular_dependency(test_features, feature_id, dep_id): + return json.dumps({"error": f"Cannot add dependency {dep_id}: would create circular dependency"}) + + # Set dependencies atomically + sorted_deps = sorted(dependency_ids) if dependency_ids else None + feature.dependencies = sorted_deps + # Commit happens automatically on context manager exit + + return json.dumps({ + "success": True, + "feature_id": feature_id, + "dependencies": sorted_deps or [] + }) except Exception as e: - session.rollback() return json.dumps({"error": f"Failed to set dependencies: {str(e)}"}) - finally: - 
session.close() if __name__ == "__main__": diff --git a/parallel_orchestrator.py b/parallel_orchestrator.py index 574cbd2c..5eb55cb0 100644 --- a/parallel_orchestrator.py +++ b/parallel_orchestrator.py @@ -19,7 +19,9 @@ """ import asyncio +import atexit import os +import signal import subprocess import sys import threading @@ -27,6 +29,8 @@ from pathlib import Path from typing import Callable, Literal +from sqlalchemy import text + from api.database import Feature, create_database from api.dependency_resolver import are_dependencies_satisfied, compute_scheduling_scores from progress import has_features @@ -139,11 +143,11 @@ def __init__( self, project_dir: Path, max_concurrency: int = DEFAULT_CONCURRENCY, - model: str = None, + model: str | None = None, yolo_mode: bool = False, testing_agent_ratio: int = 1, - on_output: Callable[[int, str], None] = None, - on_status: Callable[[int, str], None] = None, + on_output: Callable[[int, str], None] | None = None, + on_status: Callable[[int, str], None] | None = None, ): """Initialize the orchestrator. @@ -182,14 +186,18 @@ def __init__( # Track feature failures to prevent infinite retry loops self._failure_counts: dict[int, int] = {} + # Shutdown flag for async-safe signal handling + # Signal handlers only set this flag; cleanup happens in the main loop + self._shutdown_requested = False + # Session tracking for logging/debugging - self.session_start_time: datetime = None + self.session_start_time: datetime | None = None # Event signaled when any agent completes, allowing the main loop to wake # immediately instead of waiting for the full POLL_INTERVAL timeout. # This reduces latency when spawning the next feature after completion. - self._agent_completed_event: asyncio.Event = None # Created in run_loop - self._event_loop: asyncio.AbstractEventLoop = None # Stored for thread-safe signaling + self._agent_completed_event: asyncio.Event | None = None # Created in run_loop + self._event_loop: asyncio.AbstractEventLoop | None = None # Stored for thread-safe signaling # Database session for this orchestrator self._engine, self._session_maker = create_database(project_dir) @@ -375,7 +383,8 @@ def get_passing_count(self) -> int: session = self.get_session() try: session.expire_all() - return session.query(Feature).filter(Feature.passes == True).count() + count: int = session.query(Feature).filter(Feature.passes == True).count() + return count finally: session.close() @@ -511,11 +520,14 @@ def _spawn_coding_agent(self, feature_id: int) -> tuple[bool, str]: try: # CREATE_NO_WINDOW on Windows prevents console window pop-ups # stdin=DEVNULL prevents blocking on stdin reads + # encoding="utf-8" and errors="replace" fix Windows CP1252 issues (#138) popen_kwargs = { "stdin": subprocess.DEVNULL, "stdout": subprocess.PIPE, "stderr": subprocess.STDOUT, "text": True, + "encoding": "utf-8", + "errors": "replace", "cwd": str(AUTOCODER_ROOT), # Run from autocoder root for proper imports "env": {**os.environ, "PYTHONUNBUFFERED": "1"}, } @@ -546,7 +558,7 @@ def _spawn_coding_agent(self, feature_id: int) -> tuple[bool, str]: daemon=True ).start() - if self.on_status: + if self.on_status is not None: self.on_status(feature_id, "running") print(f"Started coding agent for feature #{feature_id}", flush=True) @@ -600,11 +612,14 @@ def _spawn_testing_agent(self) -> tuple[bool, str]: try: # CREATE_NO_WINDOW on Windows prevents console window pop-ups # stdin=DEVNULL prevents blocking on stdin reads + # encoding="utf-8" and errors="replace" fix Windows CP1252 issues (#138) 
popen_kwargs = { "stdin": subprocess.DEVNULL, "stdout": subprocess.PIPE, "stderr": subprocess.STDOUT, "text": True, + "encoding": "utf-8", + "errors": "replace", "cwd": str(AUTOCODER_ROOT), "env": {**os.environ, "PYTHONUNBUFFERED": "1"}, } @@ -658,11 +673,14 @@ async def _run_initializer(self) -> bool: # CREATE_NO_WINDOW on Windows prevents console window pop-ups # stdin=DEVNULL prevents blocking on stdin reads + # encoding="utf-8" and errors="replace" fix Windows CP1252 issues (#138) popen_kwargs = { "stdin": subprocess.DEVNULL, "stdout": subprocess.PIPE, "stderr": subprocess.STDOUT, "text": True, + "encoding": "utf-8", + "errors": "replace", "cwd": str(AUTOCODER_ROOT), "env": {**os.environ, "PYTHONUNBUFFERED": "1"}, } @@ -716,11 +734,14 @@ def _read_output( ): """Read output from subprocess and emit events.""" try: + if proc.stdout is None: + proc.wait() + return for line in proc.stdout: if abort.is_set(): break line = line.rstrip() - if self.on_output: + if self.on_output is not None: self.on_output(feature_id or 0, line) else: # Both coding and testing agents now use [Feature #X] format @@ -815,6 +836,9 @@ def _on_agent_complete( return # Coding agent completion + # feature_id is required for coding agents (always passed from start_feature) + assert feature_id is not None, "feature_id must not be None for coding agents" + debug_log.log("COMPLETE", f"Coding agent for feature #{feature_id} finished", return_code=return_code, status="success" if return_code == 0 else "failed") @@ -855,7 +879,7 @@ def _on_agent_complete( failure_count=failure_count) status = "completed" if return_code == 0 else "failed" - if self.on_status: + if self.on_status is not None: self.on_status(feature_id, status) # CRITICAL: This print triggers the WebSocket to emit agent_update with state='error' or 'success' print(f"Feature #{feature_id} {status}", flush=True) @@ -1014,7 +1038,7 @@ async def run_loop(self): debug_log.section("FEATURE LOOP STARTING") loop_iteration = 0 - while self.is_running: + while self.is_running and not self._shutdown_requested: loop_iteration += 1 if loop_iteration <= 3: print(f"[DEBUG] === Loop iteration {loop_iteration} ===", flush=True) @@ -1163,11 +1187,46 @@ def get_status(self) -> dict: "yolo_mode": self.yolo_mode, } + def cleanup(self) -> None: + """Clean up database resources. Safe to call multiple times. + + CRITICAL: Must be called when orchestrator exits to prevent database corruption. + - Forces WAL checkpoint to flush pending writes to main database file + - Disposes engine to close all connections + + This prevents stale cache issues when the orchestrator restarts. + + Idempotency: Sets _engine=None FIRST to prevent re-entry, then performs + cleanup operations. This is important because cleanup() can be called + from multiple paths (signal handler flag, finally block, atexit handler). 
+ """ + # Atomically grab and clear the engine reference to prevent re-entry + engine = self._engine + self._engine = None + + if engine is None: + return # Already cleaned up + + try: + debug_log.log("CLEANUP", "Forcing WAL checkpoint before dispose") + with engine.connect() as conn: + conn.execute(text("PRAGMA wal_checkpoint(FULL)")) + conn.commit() + debug_log.log("CLEANUP", "WAL checkpoint completed, disposing engine") + except Exception as e: + debug_log.log("CLEANUP", f"WAL checkpoint failed (non-fatal): {e}") + + try: + engine.dispose() + debug_log.log("CLEANUP", "Engine disposed successfully") + except Exception as e: + debug_log.log("CLEANUP", f"Engine dispose failed: {e}") + async def run_parallel_orchestrator( project_dir: Path, max_concurrency: int = DEFAULT_CONCURRENCY, - model: str = None, + model: str | None = None, yolo_mode: bool = False, testing_agent_ratio: int = 1, ) -> None: @@ -1189,11 +1248,42 @@ async def run_parallel_orchestrator( testing_agent_ratio=testing_agent_ratio, ) + # Set up cleanup to run on exit (handles normal exit, exceptions, signals) + def cleanup_handler(): + debug_log.log("CLEANUP", "atexit cleanup handler invoked") + orchestrator.cleanup() + + atexit.register(cleanup_handler) + + # Set up async-safe signal handler for graceful shutdown + # IMPORTANT: Signal handlers run in the main thread's context and must be async-safe. + # Only setting flags is safe; file I/O, locks, and subprocess operations are NOT safe. + def signal_handler(signum, frame): + # Only set flags - everything else is unsafe in signal context + # The main loop checks _shutdown_requested and handles cleanup on safe code path + orchestrator._shutdown_requested = True + orchestrator.is_running = False + # Note: Don't call stop_all(), cleanup(), or sys.exit() here - those are unsafe + # The finally block and atexit handler will perform cleanup + + # Register SIGTERM handler for process termination signals + # Note: On Windows, SIGTERM works but subprocess termination behavior differs. + # Windows uses CTRL_C_EVENT/CTRL_BREAK_EVENT instead of Unix signals. + # The kill_process_tree() in process_utils.py handles this via psutil. + signal.signal(signal.SIGTERM, signal_handler) + + # Note: We intentionally do NOT register SIGINT handler + # Let Python raise KeyboardInterrupt naturally so the except block works + try: await orchestrator.run_loop() except KeyboardInterrupt: print("\n\nInterrupted by user. Stopping agents...", flush=True) orchestrator.stop_all() + finally: + # CRITICAL: Always clean up database resources on exit + # This forces WAL checkpoint and disposes connections + orchestrator.cleanup() def main(): diff --git a/progress.py b/progress.py index 0821c90a..2da4fea0 100644 --- a/progress.py +++ b/progress.py @@ -10,12 +10,34 @@ import os import sqlite3 import urllib.request +from contextlib import closing from datetime import datetime, timezone from pathlib import Path WEBHOOK_URL = os.environ.get("PROGRESS_N8N_WEBHOOK_URL") PROGRESS_CACHE_FILE = ".progress_cache" +# SQLite connection settings for parallel mode safety +SQLITE_TIMEOUT = 30 # seconds to wait for locks +SQLITE_BUSY_TIMEOUT_MS = 30000 # milliseconds for PRAGMA busy_timeout + + +def _get_connection(db_file: Path) -> sqlite3.Connection: + """Get a SQLite connection with proper timeout settings. + + Uses timeout=30s and PRAGMA busy_timeout=30000 for safe operation + in parallel mode where multiple processes access the same database. 
+ + Args: + db_file: Path to the SQLite database file + + Returns: + sqlite3.Connection with proper timeout settings + """ + conn = sqlite3.connect(db_file, timeout=SQLITE_TIMEOUT) + conn.execute(f"PRAGMA busy_timeout={SQLITE_BUSY_TIMEOUT_MS}") + return conn + def has_features(project_dir: Path) -> bool: """ @@ -31,8 +53,6 @@ def has_features(project_dir: Path) -> bool: Returns False if no features exist (initializer needs to run). """ - import sqlite3 - # Check legacy JSON file first json_file = project_dir / "feature_list.json" if json_file.exists(): @@ -44,12 +64,11 @@ def has_features(project_dir: Path) -> bool: return False try: - conn = sqlite3.connect(db_file) - cursor = conn.cursor() - cursor.execute("SELECT COUNT(*) FROM features") - count = cursor.fetchone()[0] - conn.close() - return count > 0 + with closing(_get_connection(db_file)) as conn: + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM features") + count: int = cursor.fetchone()[0] + return bool(count > 0) except Exception: # Database exists but can't be read or has no features table return False @@ -59,6 +78,8 @@ def count_passing_tests(project_dir: Path) -> tuple[int, int, int]: """ Count passing, in_progress, and total tests via direct database access. + Uses connection with proper timeout settings for parallel mode safety. + Args: project_dir: Directory containing the project @@ -70,36 +91,35 @@ def count_passing_tests(project_dir: Path) -> tuple[int, int, int]: return 0, 0, 0 try: - conn = sqlite3.connect(db_file) - cursor = conn.cursor() - # Single aggregate query instead of 3 separate COUNT queries - # Handle case where in_progress column doesn't exist yet (legacy DBs) - try: - cursor.execute(""" - SELECT - COUNT(*) as total, - SUM(CASE WHEN passes = 1 THEN 1 ELSE 0 END) as passing, - SUM(CASE WHEN in_progress = 1 THEN 1 ELSE 0 END) as in_progress - FROM features - """) - row = cursor.fetchone() - total = row[0] or 0 - passing = row[1] or 0 - in_progress = row[2] or 0 - except sqlite3.OperationalError: - # Fallback for databases without in_progress column - cursor.execute(""" - SELECT - COUNT(*) as total, - SUM(CASE WHEN passes = 1 THEN 1 ELSE 0 END) as passing - FROM features - """) - row = cursor.fetchone() - total = row[0] or 0 - passing = row[1] or 0 - in_progress = 0 - conn.close() - return passing, in_progress, total + with closing(_get_connection(db_file)) as conn: + cursor = conn.cursor() + # Single aggregate query instead of 3 separate COUNT queries + # Handle case where in_progress column doesn't exist yet (legacy DBs) + try: + cursor.execute(""" + SELECT + COUNT(*) as total, + SUM(CASE WHEN passes = 1 THEN 1 ELSE 0 END) as passing, + SUM(CASE WHEN in_progress = 1 THEN 1 ELSE 0 END) as in_progress + FROM features + """) + row = cursor.fetchone() + total = row[0] or 0 + passing = row[1] or 0 + in_progress = row[2] or 0 + except sqlite3.OperationalError: + # Fallback for databases without in_progress column + cursor.execute(""" + SELECT + COUNT(*) as total, + SUM(CASE WHEN passes = 1 THEN 1 ELSE 0 END) as passing + FROM features + """) + row = cursor.fetchone() + total = row[0] or 0 + passing = row[1] or 0 + in_progress = 0 + return passing, in_progress, total except Exception as e: print(f"[Database error in count_passing_tests: {e}]") return 0, 0, 0 @@ -109,6 +129,8 @@ def get_all_passing_features(project_dir: Path) -> list[dict]: """ Get all passing features for webhook notifications. + Uses connection with proper timeout settings for parallel mode safety. 
+ Args: project_dir: Directory containing the project @@ -120,17 +142,16 @@ def get_all_passing_features(project_dir: Path) -> list[dict]: return [] try: - conn = sqlite3.connect(db_file) - cursor = conn.cursor() - cursor.execute( - "SELECT id, category, name FROM features WHERE passes = 1 ORDER BY priority ASC" - ) - features = [ - {"id": row[0], "category": row[1], "name": row[2]} - for row in cursor.fetchall() - ] - conn.close() - return features + with closing(_get_connection(db_file)) as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT id, category, name FROM features WHERE passes = 1 ORDER BY priority ASC" + ) + features = [ + {"id": row[0], "category": row[1], "name": row[2]} + for row in cursor.fetchall() + ] + return features except Exception: return [] diff --git a/registry.py b/registry.py index f84803e8..7d0c2afb 100644 --- a/registry.py +++ b/registry.py @@ -17,8 +17,7 @@ from typing import Any from sqlalchemy import Column, DateTime, Integer, String, create_engine, text -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import DeclarativeBase, sessionmaker # Module logger logger = logging.getLogger(__name__) @@ -75,7 +74,9 @@ class RegistryPermissionDenied(RegistryError): # SQLAlchemy Model # ============================================================================= -Base = declarative_base() +class Base(DeclarativeBase): + """SQLAlchemy 2.0 style declarative base.""" + pass class Project(Base): diff --git a/server/services/assistant_database.py b/server/services/assistant_database.py index f2ade75c..21192ff3 100644 --- a/server/services/assistant_database.py +++ b/server/services/assistant_database.py @@ -7,6 +7,7 @@ """ import logging +import threading from datetime import datetime, timezone from pathlib import Path from typing import Optional @@ -22,6 +23,10 @@ # Key: project directory path (as posix string), Value: SQLAlchemy engine _engine_cache: dict[str, object] = {} +# Lock for thread-safe access to the engine cache +# Prevents race conditions when multiple threads create engines simultaneously +_cache_lock = threading.Lock() + def _utc_now() -> datetime: """Return current UTC time. Replacement for deprecated datetime.utcnow().""" @@ -64,17 +69,33 @@ def get_engine(project_dir: Path): Uses a cache to avoid creating new engines for each request, which improves performance by reusing database connections. + + Thread-safe: Uses a lock to prevent race conditions when multiple threads + try to create engines simultaneously for the same project. 
""" cache_key = project_dir.as_posix() - if cache_key not in _engine_cache: - db_path = get_db_path(project_dir) - # Use as_posix() for cross-platform compatibility with SQLite connection strings - db_url = f"sqlite:///{db_path.as_posix()}" - engine = create_engine(db_url, echo=False) - Base.metadata.create_all(engine) - _engine_cache[cache_key] = engine - logger.debug(f"Created new database engine for {cache_key}") + # Double-checked locking for thread safety and performance + if cache_key in _engine_cache: + return _engine_cache[cache_key] + + with _cache_lock: + # Check again inside the lock in case another thread created it + if cache_key not in _engine_cache: + db_path = get_db_path(project_dir) + # Use as_posix() for cross-platform compatibility with SQLite connection strings + db_url = f"sqlite:///{db_path.as_posix()}" + engine = create_engine( + db_url, + echo=False, + connect_args={ + "check_same_thread": False, + "timeout": 30, # Wait up to 30s for locks + } + ) + Base.metadata.create_all(engine) + _engine_cache[cache_key] = engine + logger.debug(f"Created new database engine for {cache_key}") return _engine_cache[cache_key] diff --git a/test_atomic_operations.py b/test_atomic_operations.py new file mode 100644 index 00000000..01efb15c --- /dev/null +++ b/test_atomic_operations.py @@ -0,0 +1,368 @@ +""" +Concurrency Tests for Atomic Operations +======================================== + +Tests to verify that SQLite atomic operations work correctly under concurrent access. +These tests validate the fixes for PR #108 (fix/sqlite-parallel-corruption). + +Test cases: +1. test_atomic_claim_single_winner - verify exactly 1 thread succeeds claiming a feature +2. test_atomic_priority_no_duplicates - verify no duplicate priorities when skipping +3. test_cleanup_idempotent - verify multiple cleanup() calls don't error +4. test_atomic_transaction_isolation - verify IMMEDIATE prevents stale reads +5. 
test_event_hooks_applied - verify event hooks are configured on all connections + +Run with: python -m pytest test_atomic_operations.py -v +""" + +import tempfile +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +import pytest +from sqlalchemy import event, text + +from api.database import ( + Feature, + atomic_transaction, + create_database, +) + + +@pytest.fixture +def temp_db(): + """Create a temporary database for testing.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + engine, session_maker = create_database(project_dir) + + # Create some test features + session = session_maker() + try: + for i in range(5): + feature = Feature( + priority=i + 1, + category="test", + name=f"Feature {i + 1}", + description=f"Test feature {i + 1}", + steps=["step 1", "step 2"], + passes=False, + in_progress=False, + ) + session.add(feature) + session.commit() + finally: + session.close() + + yield engine, session_maker + + # Cleanup + engine.dispose() + + +class TestAtomicClaimSingleWinner: + """Test that only one thread can claim a feature.""" + + def test_concurrent_claim_single_winner(self, temp_db): + """Spawn N threads calling atomic UPDATE WHERE, verify exactly 1 succeeds.""" + engine, session_maker = temp_db + num_threads = 10 + feature_id = 1 + + # Track results + results = {"claimed": 0, "failed": 0} + results_lock = threading.Lock() + barrier = threading.Barrier(num_threads) + + def try_claim(): + # Wait for all threads to be ready + barrier.wait() + + session = session_maker() + try: + # Atomic claim using UPDATE WHERE (same pattern as MCP server) + result = session.execute( + text(""" + UPDATE features + SET in_progress = 1 + WHERE id = :id AND passes = 0 AND in_progress = 0 + """), + {"id": feature_id}, + ) + session.commit() + + with results_lock: + if result.rowcount == 1: + results["claimed"] += 1 + else: + results["failed"] += 1 + except Exception: + with results_lock: + results["failed"] += 1 + finally: + session.close() + + # Run concurrent claims + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(try_claim) for _ in range(num_threads)] + for f in as_completed(futures): + f.result() # Raise any exceptions + + # Verify exactly 1 thread claimed the feature + assert results["claimed"] == 1, f"Expected 1 claim, got {results['claimed']}" + assert ( + results["failed"] == num_threads - 1 + ), f"Expected {num_threads - 1} failures, got {results['failed']}" + + # Verify database state + session = session_maker() + try: + feature = session.query(Feature).filter(Feature.id == feature_id).first() + assert feature.in_progress is True + finally: + session.close() + + +class TestAtomicPriorityNoDuplicates: + """Test that concurrent feature_skip operations don't create duplicate priorities.""" + + def test_concurrent_skip_no_duplicates(self, temp_db): + """Multiple threads skipping features simultaneously, verify no duplicate priorities.""" + engine, session_maker = temp_db + num_threads = 5 + + # Get feature IDs to skip + session = session_maker() + try: + feature_ids = [f.id for f in session.query(Feature).all()] + finally: + session.close() + + barrier = threading.Barrier(num_threads) + errors = [] + errors_lock = threading.Lock() + + def skip_feature(feature_id): + # Wait for all threads to be ready + barrier.wait() + + session = session_maker() + try: + # Atomic skip using MAX subquery (same pattern as MCP server) + session.execute( + text(""" + UPDATE 
features + SET priority = (SELECT COALESCE(MAX(priority), 0) + 1 FROM features), + in_progress = 0 + WHERE id = :id + """), + {"id": feature_id}, + ) + session.commit() + except Exception as e: + with errors_lock: + errors.append(str(e)) + finally: + session.close() + + # Run concurrent skips + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(skip_feature, fid) for fid in feature_ids[:num_threads] + ] + for f in as_completed(futures): + f.result() + + # Verify no errors + assert len(errors) == 0, f"Errors during skip: {errors}" + + # Verify no duplicate priorities + session = session_maker() + try: + priorities = [f.priority for f in session.query(Feature).all()] + unique_priorities = set(priorities) + assert len(priorities) == len( + unique_priorities + ), f"Duplicate priorities found: {priorities}" + finally: + session.close() + + +class TestCleanupIdempotent: + """Test that cleanup() can be called multiple times without errors.""" + + def test_cleanup_multiple_calls(self): + """Call cleanup() multiple times on ParallelOrchestrator, verify no errors.""" + from parallel_orchestrator import ParallelOrchestrator + + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + # Create empty features.db so orchestrator doesn't fail + create_database(project_dir) + + orchestrator = ParallelOrchestrator( + project_dir=project_dir, + max_concurrency=1, + ) + + # Call cleanup multiple times - should not raise + orchestrator.cleanup() + orchestrator.cleanup() + orchestrator.cleanup() + + # Verify engine is None after cleanup + assert orchestrator._engine is None + + +class TestAtomicTransactionIsolation: + """Test that atomic_transaction with IMMEDIATE prevents stale reads.""" + + def test_read_modify_write_isolation(self, temp_db): + """Verify IMMEDIATE transaction prevents stale read in read-modify-write.""" + engine, session_maker = temp_db + + # This test verifies that two concurrent read-modify-write operations + # don't both read the same value and create a conflict + + barrier = threading.Barrier(2) + + def increment_priority(feature_id): + barrier.wait() + + with atomic_transaction(session_maker) as session: + # Read current priority + feature = ( + session.query(Feature).filter(Feature.id == feature_id).first() + ) + old_priority = feature.priority + + # Modify + new_priority = old_priority + 100 + + # Write + session.execute( + text("UPDATE features SET priority = :new WHERE id = :id"), + {"new": new_priority, "id": feature_id}, + ) + + # Run two concurrent increments on the same feature + feature_id = 1 + + # Get initial priority + session = session_maker() + try: + initial_priority = ( + session.query(Feature).filter(Feature.id == feature_id).first().priority + ) + finally: + session.close() + + with ThreadPoolExecutor(max_workers=2) as executor: + futures = [executor.submit(increment_priority, feature_id) for _ in range(2)] + for f in as_completed(futures): + f.result() + + # Verify final priority - with proper isolation, each increment should see + # the other's write, so final = initial + 200 + # Without isolation, both might read initial and we'd get initial + 100 + session = session_maker() + try: + final_priority = ( + session.query(Feature).filter(Feature.id == feature_id).first().priority + ) + # With IMMEDIATE transactions and proper isolation, we expect initial + 200 + expected = initial_priority + 200 + assert ( + final_priority == expected + ), f"Expected {expected}, got {final_priority}. Lost update detected!" 
+ finally: + session.close() + + +class TestEventHooksApplied: + """Test that SQLAlchemy event hooks are properly configured.""" + + def test_begin_immediate_hook_active(self, temp_db): + """Verify that the BEGIN IMMEDIATE event hook is active on connections.""" + engine, session_maker = temp_db + + # Track if our hook fired + hook_fired = {"begin": False, "connect": False} + + # Add test listeners to verify the existing hooks are working + @event.listens_for(engine, "begin") + def track_begin(conn): + hook_fired["begin"] = True + + # Create a new connection and start a transaction + session = session_maker() + try: + # This should trigger the begin hook via our event listener + session.execute(text("SELECT 1")) + session.commit() + finally: + session.close() + + # The begin hook should have fired + assert hook_fired["begin"], "BEGIN event hook did not fire" + + def test_isolation_level_none_on_connect(self, temp_db): + """Verify that pysqlite's implicit transaction is disabled.""" + engine, session_maker = temp_db + + # Get a raw DBAPI connection and check isolation_level + with engine.connect() as conn: + raw_conn = conn.connection.dbapi_connection + # Our do_connect hook sets isolation_level = None + # Note: In some pysqlite versions, None becomes "" (empty string) + # Both indicate autocommit mode (no implicit transactions) + assert raw_conn.isolation_level in ( + None, + "", + ), f"Expected isolation_level=None or '', got {raw_conn.isolation_level!r}" + + +class TestAtomicTransactionRollback: + """Test that atomic_transaction properly handles exceptions.""" + + def test_rollback_on_exception(self, temp_db): + """Verify that changes are rolled back on exception.""" + engine, session_maker = temp_db + feature_id = 1 + + # Get initial priority + session = session_maker() + try: + initial_priority = ( + session.query(Feature).filter(Feature.id == feature_id).first().priority + ) + finally: + session.close() + + # Try to modify and raise exception + try: + with atomic_transaction(session_maker) as session: + session.execute( + text("UPDATE features SET priority = 999 WHERE id = :id"), + {"id": feature_id}, + ) + raise ValueError("Intentional error") + except ValueError: + pass # Expected + + # Verify priority was rolled back + session = session_maker() + try: + final_priority = ( + session.query(Feature).filter(Feature.id == feature_id).first().priority + ) + assert ( + final_priority == initial_priority + ), f"Expected rollback to {initial_priority}, got {final_priority}" + finally: + session.close() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])
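
The claim semantics exercised by TestAtomicClaimSingleWinner above reduce to a conditional UPDATE whose rowcount tells the caller whether it won. A minimal sketch using only the standard-library sqlite3 module, assuming a features table with the same columns as the patch; the try_claim helper and db_path are illustrative, not part of this change set.

    import sqlite3

    def try_claim(db_path: str, feature_id: int) -> bool:
        """Return True if this caller claimed the feature, False otherwise."""
        # isolation_level=None disables implicit transactions, mirroring the engine hooks
        conn = sqlite3.connect(db_path, timeout=30, isolation_level=None)
        try:
            conn.execute("PRAGMA busy_timeout=30000")
            conn.execute("BEGIN IMMEDIATE")  # take the write lock up front
            cur = conn.execute(
                "UPDATE features SET in_progress = 1 "
                "WHERE id = ? AND passes = 0 AND in_progress = 0",
                (feature_id,),
            )
            conn.commit()
            # Exactly one concurrent caller observes rowcount == 1; the WHERE clause
            # makes the check and the claim a single atomic statement.
            return cur.rowcount == 1
        finally:
            conn.close()

This is the same pattern feature_mark_in_progress and feature_claim_and_get implement through SQLAlchemy: the existence/passes/in_progress checks are folded into the UPDATE's WHERE clause, so there is no separate read that another process could invalidate between check and write.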