diff --git a/merlin/dag/__init__.py b/merlin/dag/__init__.py new file mode 100644 index 00000000..3232b50b --- /dev/null +++ b/merlin/dag/__init__.py @@ -0,0 +1,5 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## diff --git a/merlin/dag/dag.py b/merlin/dag/dag.py new file mode 100644 index 00000000..428ae70b --- /dev/null +++ b/merlin/dag/dag.py @@ -0,0 +1,183 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +Holds the Merlin Directed Acyclic Graph (DAG) class. +""" +from collections import OrderedDict +from typing import Dict, List + +from merlin.dag.models import ExecutionLevel, ExecutionPlan, TaskChain +from merlin.study.step import Step + + +# TODO make this an interface, separate from Maestro. +class DAG: + """ + Refactored DAG class with cleaner data structures. + + This class provides methods on a task graph that Merlin needs for staging + tasks in Celery. It is initialized from a Maestro `ExecutionGraph`, and the + major entry point is the group_tasks method, which provides an ExecutionPlan + with clear structure instead of nested lists. + + Attributes: + backwards_adjacency (Dict): A dictionary mapping each task to its parent tasks. + column_labels (List[str]): A list of column labels provided in the spec file. + maestro_adjacency_table (OrderedDict): An ordered dict showing adjacency of nodes. + maestro_values (OrderedDict): An ordered dict of the values at each node. + parameter_info (Dict): A dict containing information about parameters in the study. + study_name (str): The name of the study. + """ + + def __init__( + self, + maestro_adjacency_table: OrderedDict, + maestro_values: OrderedDict, + column_labels: List[str], + study_name: str, + parameter_info: Dict, + ): + self.maestro_adjacency_table: OrderedDict = maestro_adjacency_table + self.maestro_values: OrderedDict = maestro_values + self.column_labels: List[str] = column_labels + self.study_name: str = study_name + self.parameter_info: Dict = parameter_info + self.backwards_adjacency: Dict = {} + self.calc_backwards_adjacency() + + def step(self, task_name: str): + """Return a Step object for the given task name.""" + return Step(self.maestro_values[task_name], self.study_name, self.parameter_info) + + def calc_depth(self, node: str, depths: Dict, current_depth: int = 0): + """Calculate the depth of the given node and its children.""" + if node not in depths: + depths[node] = current_depth + else: + depths[node] = max(depths[node], current_depth) + + for child in self.children(node): + self.calc_depth(child, depths, current_depth=depths[node] + 1) + + def group_by_depth(self, depths: Dict) -> List[ExecutionLevel]: + """ + Group DAG tasks by depth, returning ExecutionLevels instead of nested lists. + + Each task starts as its own single-task chain. 
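+
+        For example, depths ``{"_source": 0, "build": 1, "test": 1}`` (using
+        hypothetical task names) produce a depth-0 level holding one
+        single-task chain and a depth-1 level holding two single-task chains.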
The find_independent_chains
+        method will later coalesce compatible tasks into longer chains.
+        """
+        # Group tasks by depth
+        depth_groups = {}
+        for node, depth in depths.items():
+            if depth not in depth_groups:
+                depth_groups[depth] = []
+            depth_groups[depth].append(node)
+
+        # Create ExecutionLevels with each task as its own chain initially
+        levels = []
+        for depth in sorted(depth_groups.keys()):
+            chains = [TaskChain([task], depth) for task in depth_groups[depth]]
+            level = ExecutionLevel(depth, chains)
+            levels.append(level)
+
+        return levels
+
+    def children(self, task_name: str) -> List:
+        """Return the children of the task."""
+        return self.maestro_adjacency_table[task_name]
+
+    def num_children(self, task_name: str) -> int:
+        """Find the number of children for the given task."""
+        return len(self.children(task_name))
+
+    def parents(self, task_name: str) -> List:
+        """Return the parents of the task."""
+        return self.backwards_adjacency.get(task_name, [])
+
+    def num_parents(self, task_name: str) -> int:
+        """Find the number of parents for the given task."""
+        return len(self.parents(task_name))
+
+    def find_chain_containing_task(self, task_name: str, levels: List[ExecutionLevel]) -> TaskChain:
+        """Find the chain containing the given task, returning None if it is not found."""
+        for level in levels:
+            chain = level.find_chain_containing_task(task_name)
+            if chain:
+                return chain
+        return None
+
+    def calc_backwards_adjacency(self):
+        """Initialize the backwards adjacency table."""
+        self.backwards_adjacency = {}
+        for parent in self.maestro_adjacency_table:
+            for task_name in self.maestro_adjacency_table[parent]:
+                if task_name in self.backwards_adjacency:
+                    self.backwards_adjacency[task_name].append(parent)
+                else:
+                    self.backwards_adjacency[task_name] = [parent]
+
+    def compatible_merlin_expansion(self, task1: str, task2: str) -> bool:
+        """Check if two tasks are compatible for Merlin expansion."""
+        step1 = self.step(task1)
+        step2 = self.step(task2)
+        return step1.check_if_expansion_needed(self.column_labels) == step2.check_if_expansion_needed(self.column_labels)
+
+    def find_independent_chains(self, levels: List[ExecutionLevel]) -> List[ExecutionLevel]:
+        """
+        Find independent chains and coalesce them to maximize parallelism.
+
+        A task is merged with its child when the task has exactly one child,
+        the child has exactly one parent, and both share the same Merlin
+        expansion behavior. This replaces the nested-list manipulation used
+        by the original implementation with explicit chain operations.
+        """
+        for level in levels:
+            for chain in level.parallel_chains:
+                # Iterate over a copy of the task list since tasks may move between chains
+                tasks_to_process = chain.tasks.copy()
+                for task_name in tasks_to_process:
+                    # Skip if the task is no longer in this chain (it may have been moved)
+                    if task_name not in chain.tasks:
+                        continue
+
+                    # Check if this task can be coalesced with its child
+                    if self.num_children(task_name) == 1 and task_name != "_source":
+                        child = self.children(task_name)[0]
+
+                        if self.num_parents(child) == 1 and self.compatible_merlin_expansion(child, task_name):
+                            # Find the child's chain and remove it from there
+                            child_chain = self.find_chain_containing_task(child, levels)
+                            if child_chain and child_chain != chain:
+                                child_chain.remove_task(child)
+                                # Add child to current chain
+                                chain.add_task(child)
+
+        # Clean up empty chains
+        for level in levels:
+            level.remove_empty_chains()
+
+        # Remove empty levels
+        non_empty_levels = [level for level in levels if level.parallel_chains]
+
+        return non_empty_levels
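+
+    # Illustrative example (hypothetical step names): if "sim" has "collect"
+    # as its only child, "collect" has "sim" as its only parent, and both
+    # share the same expansion behavior, find_independent_chains merges them:
+    #
+    #     before:  Level 1: ['sim']     Level 2: ['collect']
+    #     after:   Level 1: ['sim -> collect']   (the emptied level is dropped)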
+
+    def group_tasks(self, source_node: str) -> ExecutionPlan:
+        """
+        Group independent tasks in a DAG, returning a clean ExecutionPlan.
+
+        The plan organizes tasks into depth levels of parallel chains,
+        replacing the nested-list structure returned by the original
+        implementation.
+        """
+        # Calculate depths
+        depths = {}
+        self.calc_depth(source_node, depths)
+
+        # Group by depth into ExecutionLevels
+        levels = self.group_by_depth(depths)
+
+        # Find and coalesce independent chains
+        optimized_levels = self.find_independent_chains(levels)
+
+        # Return a clean ExecutionPlan
+        return ExecutionPlan(optimized_levels)
diff --git a/merlin/dag/models.py b/merlin/dag/models.py
new file mode 100644
index 00000000..8d4fcb38
--- /dev/null
+++ b/merlin/dag/models.py
@@ -0,0 +1,511 @@
+##############################################################################
+# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin
+# Project developers. See top-level LICENSE and COPYRIGHT files for dates and
+# other details. No copyright assignment is required to contribute to Merlin.
+##############################################################################
+
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass
+class TaskChain:
+    """
+    Represents a sequence of tasks that must execute sequentially within a workflow.
+
+    A `TaskChain` defines a linear dependency chain where each task must complete
+    before the next task can begin. All tasks within a chain execute on the same
+    execution level (depth) but maintain strict ordering among themselves.
+
+    This is a lightweight data structure that stores task names as strings rather
+    than full Step objects for efficient serialization in distributed environments
+    like Celery.
+
+    Note:
+        Tasks within a chain have sequential dependencies, but different chains
+        at the same depth can execute in parallel. Use [`ExecutionLevel`][dag.models.ExecutionLevel]
+        to group parallel chains.
+
+    Attributes:
+        tasks (List[str]): Ordered list of task names that comprise this chain.
+            Tasks execute in the order they appear in this list.
+        depth (int): The execution depth/level of this chain within the overall
+            workflow. All tasks in this chain execute at the same depth.
+
+    Example:
+        ```python
+        # Create a chain of preprocessing tasks
+        preprocess_chain = TaskChain(
+            tasks=["validate_input", "clean_data", "normalize"],
+            depth=1
+        )
+
+        print(preprocess_chain)  # "validate_input -> clean_data -> normalize"
+        print(len(preprocess_chain))  # 3
+
+        # Check if a specific task is in this chain
+        if preprocess_chain.contains_task("clean_data"):
+            print("Data cleaning is part of preprocessing")
+        ```
+    """
+
+    tasks: List[str]
+    depth: int
+
+    def __str__(self) -> str:
+        """
+        Return a human-readable representation of the task chain.
+
+        Returns:
+            Tasks joined by arrows (->) to show execution flow.
+        """
+        return " -> ".join(self.tasks)
+
+    def __len__(self) -> int:
+        """
+        Return the number of tasks in this chain.
+
+        Returns:
+            Count of tasks in the chain.
+        """
+        return len(self.tasks)
+
+    def add_task(self, task: str):
+        """
+        Add a task to the end of this chain.
+
+        The new task will execute after all existing tasks in the chain
+        have completed successfully.
+
+        Args:
+            task (str): Name of the task to append to the chain.
+
+        Example:
+            ```python
+            chain = TaskChain(tasks=["setup"], depth=0)
+            chain.add_task("configure")
+            print(chain)  # "setup -> configure"
+            ```
+        """
+        self.tasks.append(task)
+
+    def remove_task(self, task: str):
+        """
+        Remove the first occurrence of a task from this chain.
+
+        If the task appears multiple times in the chain, only the first
+        occurrence is removed. If the task is not found, no action is taken.
+ + Warning: + Removing tasks from the middle of a chain may break dependencies. + Ensure that removing a task doesn't create gaps in the workflow logic. + + Args: + task (str): Name of the task to remove from the chain. + + Example: + ```python + chain = TaskChain(tasks=["a", "b", "c"], depth=1) + chain.remove_task("b") + print(chain) # "a -> c" + ``` + """ + if task in self.tasks: + self.tasks.remove(task) + + def contains_task(self, task: str) -> bool: + """ + Check if this chain contains the specified task. + + Args: + task (str): Name of the task to search for. + + Returns: + bool: True if the task is found in this chain, False otherwise. + + Example: + ```python + chain = TaskChain(tasks=["init", "process", "cleanup"], depth=2) + assert chain.contains_task("process") == True + assert chain.contains_task("missing") == False + ``` + """ + return task in self.tasks + + +@dataclass +class ExecutionLevel: + """ + Represents all task chains that can execute in parallel at a given workflow depth. + + An `ExecutionLevel` groups [`TaskChains`][dag.models.TaskChain] that have no + dependencies between them, allowing them to execute concurrently. All chains within + a level must complete before the workflow can proceed to the next depth level. + + This structure enables efficient parallel execution while maintaining proper + dependency ordering across the overall workflow. + + Note: + While chains within a level execute in parallel, tasks within each + individual chain still execute sequentially according to their `TaskChain` ordering. + + Attributes: + depth (int): The execution depth of this level within the workflow. + Lower depths execute before higher depths. + parallel_chains (List[TaskChain]): List of task chains that can execute + concurrently at this depth level. + + Example: + ```python + # Create parallel analysis chains at depth 2 + analysis_level = ExecutionLevel( + depth=2, + parallel_chains=[ + TaskChain(tasks=["analyze_cpu"], depth=2), + TaskChain(tasks=["analyze_memory"], depth=2), + TaskChain(tasks=["analyze_disk"], depth=2) + ] + ) + + print(analysis_level) + # Level 2: ['analyze_cpu', 'analyze_memory', 'analyze_disk'] + + # All three analysis tasks can run simultaneously + all_tasks = analysis_level.get_all_tasks() + print(f"Parallel tasks: {all_tasks}") + ``` + """ + + depth: int + parallel_chains: List[TaskChain] + + def __str__(self) -> str: + """ + Return a human-readable representation of this execution level. + + Returns: + Formatted string showing depth and all parallel chains. + """ + chain_strs = [str(chain) for chain in self.parallel_chains] + return f"Level {self.depth}: {chain_strs}" + + def add_chain(self, chain: TaskChain): + """ + Add a parallel task chain to this execution level. + + The added chain will execute in parallel with all other chains + at this level, but its internal tasks will still execute sequentially. + + Warning: + The chain's depth should match this level's depth for consistency, + though this is not enforced by this method. + + Args: + chain (TaskChain): The task chain to add to this level's parallel execution. 
+ + Example: + ```python + level = ExecutionLevel(depth=1, parallel_chains=[]) + + # Add independent processing chains + level.add_chain(TaskChain(tasks=["process_images"], depth=1)) + level.add_chain(TaskChain(tasks=["process_text"], depth=1)) + + # Both chains will now execute in parallel + ``` + """ + self.parallel_chains.append(chain) + + def get_all_tasks(self) -> List[str]: + """ + Retrieve all task names from all chains in this execution level. + + This flattens the parallel chain structure to return a single list + containing every task that will execute at this level. + + Note: + The order of tasks in the returned list reflects the order of chains + and tasks within chains, but does not imply execution order since + chains execute in parallel. + + Returns: + List[str]: Flat list of all task names across all parallel chains. + + Example: + ```python + level = ExecutionLevel( + depth=1, + parallel_chains=[ + TaskChain(tasks=["a", "b"], depth=1), + TaskChain(tasks=["c"], depth=1) + ] + ) + + all_tasks = level.get_all_tasks() + print(all_tasks) # ["a", "b", "c"] + ``` + """ + return [task for chain in self.parallel_chains for task in chain.tasks] + + def find_chain_containing_task(self, task: str) -> TaskChain: + """ + Locate the task chain containing a specific task. + + Searches through all parallel chains in this level to find which + chain contains the specified task. + + Args: + task (str): Name of the task to locate. + + Returns: + The [`TaskChain`][dag.models.TaskChain] containing the task, or + None if the task is not found in any chain at this level. + + Example: + ```python + level = ExecutionLevel( + depth=2, + parallel_chains=[ + TaskChain(tasks=["prep", "analyze"], depth=2), + TaskChain(tasks=["validate", "report"], depth=2) + ] + ) + + chain = level.find_chain_containing_task("analyze") + print(chain) # TaskChain with ["prep", "analyze"] + + missing = level.find_chain_containing_task("missing") + print(missing) # None + ``` + """ + for chain in self.parallel_chains: + if chain.contains_task(task): + return chain + return None + + def remove_empty_chains(self): + """ + Remove any task chains that contain no tasks. + + This cleanup method filters out chains that have become empty, + which can occur after task removal operations or during workflow + construction. + + Example: + ```python + level = ExecutionLevel( + depth=1, + parallel_chains=[ + TaskChain(tasks=["active_task"], depth=1), + TaskChain(tasks=[], depth=1), # Empty chain + TaskChain(tasks=["another_task"], depth=1) + ] + ) + + level.remove_empty_chains() + print(len(level.parallel_chains)) # 2 (empty chain removed) + ``` + + Note: + This operation modifies the parallel_chains list in-place. + """ + self.parallel_chains = [chain for chain in self.parallel_chains if len(chain.tasks) > 0] + + +class ExecutionPlan: + """ + Container for a complete workflow execution plan with multiple depth levels. + + An `ExecutionPlan` represents the full structure of a workflow, organizing + tasks into levels that execute sequentially, with parallel chains within + each level. This provides a clear, queryable representation of complex + workflow dependencies and execution ordering. 
+ + The plan enforces that: + + - All tasks at depth N complete before any tasks at depth N+1 begin + - Tasks within the same [`ExecutionLevel`][dag.models.ExecutionLevel] can + execute in parallel + - Tasks within the same [`TaskChain`][dag.models.TaskChain] execute sequentially + + Note: + `ExecutionPlan` is designed for planning and analysis, not execution. + Use appropriate [`TaskExecutor`][execution.base.TaskExecutor] implementations + to actually run the tasks. + + Attributes: + levels (List[ExecutionLevel]): Ordered list of execution levels, + typically sorted by depth for sequential execution. + + Example: + ```python + # Create a multi-level execution plan + plan = ExecutionPlan([ + ExecutionLevel(depth=0, parallel_chains=[ + TaskChain(tasks=["initialize"], depth=0) + ]), + ExecutionLevel(depth=1, parallel_chains=[ + TaskChain(tasks=["process_a", "analyze_a"], depth=1), + TaskChain(tasks=["process_b"], depth=1) + ]), + ExecutionLevel(depth=2, parallel_chains=[ + TaskChain(tasks=["finalize"], depth=2) + ]) + ]) + + print(f"Plan has {len(plan.levels)} levels") + print(f"Total tasks: {len(plan.get_all_tasks())}") + print(f"Max execution depth: {plan.get_max_depth()}") + + # Find where a specific task will execute + location = plan.find_task_location("process_a") + print(f"Task 'process_a' at depth {location[0]}, chain {location[1]}") + ``` + """ + + def __init__(self, levels: List[ExecutionLevel] = None): + """ + Constructor for `ExecutionPlan`. + + Args: + levels (Optional[List[ExecutionLevel]]): An optional list of execution levels + to add to the plan. + """ + self.levels = levels or [] + + def __str__(self) -> str: + """ + Return a human-readable representation of the entire execution plan. + + Returns: + Multi-line string showing all levels and their chains. + Each level is displayed on a separate line. + """ + return "\n".join(str(level) for level in self.levels) + + def get_level(self, depth: int) -> ExecutionLevel: + """ + Retrieve the execution level at a specific depth. + + Args: + depth (int): The depth level to retrieve. + + Returns: + The [`ExecutionLevel`][dag.models.ExecutionLevel] at the + specified depth, or None if no level exists at that depth. + + Example: + ```python + plan = ExecutionPlan([ + ExecutionLevel(depth=0, parallel_chains=[...]), + ExecutionLevel(depth=2, parallel_chains=[...]) # Note: no depth 1 + ]) + + level_0 = plan.get_level(0) # Returns ExecutionLevel + level_1 = plan.get_level(1) # Returns None + level_2 = plan.get_level(2) # Returns ExecutionLevel + ``` + """ + for level in self.levels: + if level.depth == depth: + return level + return None + + def get_max_depth(self) -> int: + """ + Determine the maximum depth level in this execution plan. + + Note: + This represents the total number of sequential execution phases + in the workflow minus one (since depths are 0-indexed). + + Returns: + The highest depth value among all levels, or 0 if no levels exist. + + Example: + ```python + plan = ExecutionPlan([ + ExecutionLevel(depth=0, parallel_chains=[...]), + ExecutionLevel(depth=3, parallel_chains=[...]), + ExecutionLevel(depth=1, parallel_chains=[...]) + ]) + + max_depth = plan.get_max_depth() # Returns 3 + ``` + """ + return max(level.depth for level in self.levels) if self.levels else 0 + + def get_all_tasks(self) -> List[str]: + """ + Retrieve all task names from the entire execution plan. + + This flattens the complete plan structure (levels -> chains -> tasks) + into a single list containing every task in the workflow. 
+
+        Note:
+            The order reflects the structure (level order, then chain order within
+            levels, then task order within chains) but does not imply execution
+            order due to parallelism.
+
+        Returns:
+            Complete list of all task names across all levels and chains.
+
+        Example:
+            ```python
+            plan = ExecutionPlan([
+                ExecutionLevel(depth=0, parallel_chains=[
+                    TaskChain(tasks=["init"], depth=0)
+                ]),
+                ExecutionLevel(depth=1, parallel_chains=[
+                    TaskChain(tasks=["a", "b"], depth=1),
+                    TaskChain(tasks=["c"], depth=1)
+                ])
+            ])
+
+            all_tasks = plan.get_all_tasks()
+            print(all_tasks)  # ["init", "a", "b", "c"]
+            ```
+        """
+        return [task for level in self.levels for task in level.get_all_tasks()]
+
+    def find_task_location(self, task: str) -> tuple:
+        """
+        Locate a specific task within the execution plan structure.
+
+        Searches through all levels and chains to find the exact location
+        of a task within the plan hierarchy.
+
+        Args:
+            task (str): Name of the task to locate.
+
+        Returns:
+            Optional[Tuple[int, int]]: Tuple of (depth, chain_index) if found, where:
+
+            - depth: The execution level depth containing the task
+            - chain_index: Index of the chain within that level's parallel_chains
+
+            Returns None if the task is not found.
+
+        Example:
+            ```python
+            plan = ExecutionPlan([
+                ExecutionLevel(depth=0, parallel_chains=[
+                    TaskChain(tasks=["init"], depth=0)  # chain_index=0
+                ]),
+                ExecutionLevel(depth=1, parallel_chains=[
+                    TaskChain(tasks=["a"], depth=1),  # chain_index=0
+                    TaskChain(tasks=["b"], depth=1)   # chain_index=1
+                ])
+            ])
+
+            location = plan.find_task_location("b")
+            print(location)  # (1, 1) - depth 1, chain index 1
+
+            missing = plan.find_task_location("missing")
+            print(missing)  # None
+            ```
+        """
+        for level in self.levels:
+            for chain_idx, chain in enumerate(level.parallel_chains):
+                if chain.contains_task(task):
+                    return (level.depth, chain_idx)
+        return None
diff --git a/merlin/execution/__init__.py b/merlin/execution/__init__.py
new file mode 100644
index 00000000..3232b50b
--- /dev/null
+++ b/merlin/execution/__init__.py
@@ -0,0 +1,5 @@
+##############################################################################
+# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin
+# Project developers. See top-level LICENSE and COPYRIGHT files for dates and
+# other details. No copyright assignment is required to contribute to Merlin.
+##############################################################################
diff --git a/merlin/execution/base.py b/merlin/execution/base.py
new file mode 100644
index 00000000..8648cb49
--- /dev/null
+++ b/merlin/execution/base.py
@@ -0,0 +1,43 @@
+##############################################################################
+# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin
+# Project developers. See top-level LICENSE and COPYRIGHT files for dates and
+# other details. No copyright assignment is required to contribute to Merlin.
+##############################################################################
+
+"""Defines the abstract TaskExecutor interface that execution backends implement."""
+
+from abc import ABC, abstractmethod
+from typing import Dict, List
+
+from merlin.dag.models import ExecutionPlan, TaskChain
+from merlin.execution.models import ExecutionContext, TaskResult
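+
+# Illustrative usage (assumes `plan` and `context` were built elsewhere, e.g.
+# via DAG.group_tasks and the surrounding study setup):
+#
+#     executor = LocalExecutor(max_workers=4)   # or CeleryExecutor()
+#     results = executor.execute_plan(plan, context)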
+
+
+class TaskExecutor(ABC):
+    """Abstract base class for different execution strategies."""
+
+    @abstractmethod
+    def execute_plan(self, plan: ExecutionPlan, context: ExecutionContext, wait: bool = False, timeout: int = 7200) -> Dict:
+        """
+        Execute the entire plan and return results.
+
+        Args:
+            plan: Execution plan to execute
+            context: Execution context
+            wait: If True, block until execution completes. Default: False
+            timeout: Timeout in seconds when using wait=True. Default: 7200
+
+        Returns:
+            Dictionary containing results and execution information
+        """
+        pass
+
+    @abstractmethod
+    def execute_chain(self, chain: TaskChain, context: ExecutionContext) -> List[TaskResult]:
+        """Execute a single chain of tasks."""
+        pass
+
+    @abstractmethod
+    def execute_task(self, task_name: str, context: ExecutionContext) -> TaskResult:
+        """Execute a single task."""
+        pass
diff --git a/merlin/execution/celery.py b/merlin/execution/celery.py
new file mode 100644
index 00000000..d1f42ab3
--- /dev/null
+++ b/merlin/execution/celery.py
@@ -0,0 +1,515 @@
+##############################################################################
+# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin
+# Project developers. See top-level LICENSE and COPYRIGHT files for dates and
+# other details. No copyright assignment is required to contribute to Merlin.
+##############################################################################
+
+"""Celery-based implementation of the TaskExecutor interface."""
+
+import time
+import uuid
+from typing import Dict, List
+
+from merlin.dag.models import ExecutionLevel, ExecutionPlan, TaskChain
+from merlin.execution.base import TaskExecutor
+from merlin.execution.models import ExecutionContext, TaskResult, TaskStatus
+
+
+class CeleryExecutor(TaskExecutor):
+    """Celery-based task executor."""
+
+    def __init__(self, default_queue: str = "default"):
+        from merlin.celery import app
+        from merlin.execution.sample_expander import SampleExpander
+
+        self.celery_app = app
+        self.default_queue = default_queue
+        self.active_tasks = {}  # Track running tasks
+        self.sample_expander = SampleExpander()
+
+    def execute_plan(self, plan: ExecutionPlan, context: ExecutionContext, wait: bool = False, timeout: int = 7200) -> Dict:
+        """
+        Execute the plan using Celery's chain(group(...), group(...), ...) pattern.
+
+        Args:
+            plan: Execution plan to execute
+            context: Execution context
+            wait: If True, block until workflow completes. Default: False (non-blocking)
+            timeout: Timeout in seconds when using wait=True.
Default: 7200 (2 hours) + + Returns: + Dictionary with 'results', 'async_result', and 'workflow_id' + """ + from celery import chain, group + + all_results = {} + all_groups = [] + + # Build all groups upfront + for level in plan.levels: + chain_sigs, level_results = self._process_level(level, context) + all_results.update(level_results) + batches = self._create_batches(chain_sigs, batch_size=100) + for batch in batches: + if batch: + all_groups.append(group(*batch)) + + # Submit and handle workflow + if not all_groups: + return {"results": all_results, "async_result": None, "workflow_id": None} + + async_result, workflow_id = self._submit_workflow(all_groups, chain) + workflow_info = self._write_workflow_info(context, plan, all_results, workflow_id) + + if wait: + self._wait_for_completion(async_result, workflow_info, all_results, context, timeout) + + return {"results": all_results, "async_result": async_result, "workflow_id": workflow_id} + + def _process_level(self, level: ExecutionLevel, context: ExecutionContext) -> tuple: + """Process a single level and return chain signatures and results.""" + print(f"Preparing depth {level.depth} with {len(level.parallel_chains)} parallel chains...") + + all_chain_sigs = [] + level_results = {} + + for chain_obj in level.parallel_chains: + if not self._chain_has_real_tasks(chain_obj, context): + for task in chain_obj.tasks: + level_results[task] = TaskResult( + task_name=task, status=TaskStatus.SKIPPED, error="Virtual node, no execution needed" + ) + continue + + chain_sigs, task_results = self._expand_and_build_chain(chain_obj, context) + all_chain_sigs.extend(chain_sigs) + level_results.update(task_results) + + print(f"Level {level.depth}: {len(all_chain_sigs)} chain signatures") + return all_chain_sigs, level_results + + def _chain_has_real_tasks(self, chain_obj: TaskChain, context: ExecutionContext) -> bool: + """Check if a chain has any real (non-virtual) tasks.""" + for task in chain_obj.tasks: + try: + step = context.study.dag.step(task) + if step is not None: + return True + except (AttributeError, KeyError, TypeError): + pass + return False + + def _expand_and_build_chain(self, chain_obj: TaskChain, context: ExecutionContext) -> tuple: + """Expand a chain and build Celery signatures.""" + expanded_positions = self.sample_expander.expand_chain(chain_obj, context) + + total_expanded = sum(len(pos) for pos in expanded_positions) + chain_name = chain_obj.tasks[0] if chain_obj.tasks else "unknown" + print(f" Chain '{chain_name}' expanded to {total_expanded} tasks across {len(expanded_positions)} positions") + + chain_sigs = self._build_chain_with_dependencies(expanded_positions, context) + print(f" Created {len(chain_sigs)} Celery chain signatures") + + task_results = {} + for position_tasks in expanded_positions: + for task_info in position_tasks: + task_name = task_info["step"].name() + task_results[task_name] = TaskResult(task_name=task_name, status=TaskStatus.COMPLETED, result=None) + + return chain_sigs, task_results + + def _submit_workflow(self, all_groups: List, chain_func) -> tuple: + """Submit the workflow chain and return async_result and workflow_id.""" + print(f"\nSubmitting workflow chain with {len(all_groups)} batch groups...") + workflow_chain = chain_func(*all_groups) + async_result = workflow_chain.apply_async() + workflow_id = async_result.id + return async_result, workflow_id + + def _write_workflow_info(self, context: ExecutionContext, plan: ExecutionPlan, results: Dict, workflow_id: str) -> Dict: + """Write workflow info 
to JSON file and return the info dict.""" + import json + import os + + workspace = context.study.workspace + workflow_info_file = os.path.join(workspace, "WORKFLOW_INFO.json") + workflow_info = { + "workflow_id": workflow_id, + "submitted_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "study_name": context.study.expanded_spec.name, + "num_levels": len(plan.levels), + "num_tasks": len(results), + "status": "SUBMITTED", + "_file_path": workflow_info_file, + } + + try: + with open(workflow_info_file, "w") as f: + json.dump({k: v for k, v in workflow_info.items() if not k.startswith("_")}, f, indent=2) + print(f"\nWorkflow submitted! ID: {workflow_id}") + print(f"Workflow info saved to: {workflow_info_file}") + except Exception as e: + print(f"Warning: Could not save workflow info: {e}") + + return workflow_info + + def _wait_for_completion(self, async_result, workflow_info: Dict, results: Dict, context: ExecutionContext, timeout: int): + """Wait for workflow completion and update status.""" + print(f"\nWaiting for workflow to complete (timeout: {timeout}s)...") + print("Press Ctrl+C to stop waiting (workflow will continue in background)") + + try: + async_result.get(timeout=timeout) + print("Workflow completed successfully") + self._update_workflow_status(workflow_info, "COMPLETED") + except KeyboardInterrupt: + print("\n\nStopped waiting. Workflow continues in background.") + print(f"Check status with: merlin status {context.study.expanded_spec.name}") + print(f"Workflow ID: {workflow_info['workflow_id']}") + except Exception as e: + print(f"Workflow failed with error: {e}") + for task_name in results.keys(): + results[task_name] = TaskResult(task_name=task_name, status=TaskStatus.FAILED, error=str(e)) + self._update_workflow_status(workflow_info, "FAILED", error=str(e)) + + def _update_workflow_status(self, workflow_info: Dict, status: str, error: str = None): + """Update workflow info file with new status.""" + import json + + workflow_info["status"] = status + timestamp_key = "completed_at" if status == "COMPLETED" else "failed_at" + workflow_info[timestamp_key] = time.strftime("%Y-%m-%d %H:%M:%S") + if error: + workflow_info["error"] = error + + try: + file_path = workflow_info.get("_file_path") + if file_path: + with open(file_path, "w") as f: + json.dump({k: v for k, v in workflow_info.items() if not k.startswith("_")}, f, indent=2) + except Exception as e: + print(f"Warning: Could not update workflow info: {e}") + + def _execute_level_parallel(self, level: ExecutionLevel, context: ExecutionContext) -> Dict[str, TaskResult]: + """Execute all chains in a level in parallel, with sample expansion and dependencies.""" + from celery import group + + level_results = {} + all_chain_sigs = [] # Collect signatures for all chains at this level + + # Expand and build each chain + for chain in level.parallel_chains: + # Check if chain has real tasks (not virtual nodes) + has_real_tasks = False + for task in chain.tasks: + try: + step = context.study.dag.step(task) + if step is not None: + has_real_tasks = True + break + except (AttributeError, KeyError, TypeError): + pass + + if not has_real_tasks: + # This is a virtual chain (e.g., _source only), mark as skipped + for task in chain.tasks: + level_results[task] = TaskResult( + task_name=task, status=TaskStatus.SKIPPED, error="Virtual node, no execution needed" + ) + continue + + # Expand chain into 2D structure: [[pos0_tasks], [pos1_tasks], ...] 
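+            # e.g. a 2-step chain over 3 samples yields (illustratively):
+            #     [[s0_step0, s1_step0, s2_step0],   # position 0: one task per sample
+            #      [s0_step1, s1_step1, s2_step1]]   # position 1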
+ expanded_positions = self.sample_expander.expand_chain(chain, context) + + # Log expansion details + total_expanded = sum(len(pos) for pos in expanded_positions) + chain_name = chain.tasks[0] if chain.tasks else "unknown" + print( + f" Chain '{chain_name}' expanded to {total_expanded} tasks " + f"across {len(expanded_positions)} positions" + ) + + # Build chain with dependencies + chain_sigs = self._build_chain_with_dependencies(expanded_positions, context) + all_chain_sigs.extend(chain_sigs) + print(f" Created {len(chain_sigs)} Celery signatures for this chain") + + # Mark tasks as queued + for position_tasks in expanded_positions: + for task_info in position_tasks: + task_name = task_info["step"].name() + level_results[task_name] = TaskResult( + task_name=task_name, status=TaskStatus.COMPLETED, result=None # Indicates successfully queued + ) + + # Count tasks for reporting + total_tasks = len(level_results) + print(f"Expanded {len(level.parallel_chains)} chains into {total_tasks} tasks with dependencies") + + # Execute all chains as group (parallel chains, but each chain maintains internal dependencies) + if all_chain_sigs: + print(f"Submitting {len(all_chain_sigs)} chain signatures to Celery...") + task_group = group(all_chain_sigs) + async_result = task_group.apply_async() + + # CRITICAL FIX: Wait for this level to complete before proceeding to next level + # This ensures dependencies between levels are properly enforced + print(f"Waiting for level {level.depth} to complete...") + try: + # Wait for all tasks in this level to complete + # Timeout set to 1 hour per level (can be adjusted) + _ = async_result.get(timeout=3600) + print(f"Level {level.depth} completed successfully") + except Exception as e: + print(f"Level {level.depth} failed with error: {e}") + # Mark tasks as failed + for task_name in level_results.keys(): + level_results[task_name] = TaskResult(task_name=task_name, status=TaskStatus.FAILED, error=str(e)) + + return level_results + + def _submit_chain_to_celery(self, chain: TaskChain, context: ExecutionContext): + """Submit a chain to Celery as a chain of tasks.""" + # This would use Celery's chain primitive + # Simplified example: + celery_chain = self._build_celery_chain(chain, context) + return celery_chain.apply_async(queue=self.default_queue) + + def _build_celery_chain(self, chain: TaskChain, context: ExecutionContext): + """Build a Celery chain from a TaskChain.""" + from celery import chain as celery_chain + + from merlin.common.tasks import merlin_step + + # Get adapter config for tasks + adapter_config = context.study.get_adapter_config(override_type="celery") + + # Build signatures for each task in the chain (skip virtual nodes) + celery_tasks = [] + for task_name in chain.tasks: + try: + # Get Step object from DAG + step = context.study.dag.step(task_name) + + # Skip virtual nodes (like _source) + if step is None: + continue + + # Create signature for merlin_step task + sig = merlin_step.s(step, adapter_config=adapter_config) + sig.set(queue=step.get_task_queue()) + celery_tasks.append(sig) + except (AttributeError, KeyError, TypeError): + # This is a virtual node, skip it + continue + + return celery_chain(*celery_tasks) + + def _mark_dependent_tasks_skipped(self, plan: ExecutionPlan, failed_depth: int, results: Dict[str, TaskResult]): + """Mark tasks that depend on failed tasks as skipped.""" + for level in plan.levels: + if level.depth > failed_depth: + for task in level.get_all_tasks(): + results[task] = TaskResult(task_name=task, 
status=TaskStatus.SKIPPED, error="Dependency failed") + + def _create_batches(self, task_infos: List[Dict], batch_size: int = 100) -> List[List[Dict]]: + """ + Split task infos into batches. + + Args: + task_infos: List of task info dicts + batch_size: Max tasks per batch (default: 100) + + Returns: + List of batches + """ + if not task_infos: + return [] + + batches = [] + for i in range(0, len(task_infos), batch_size): + batch = task_infos[i : i + batch_size] + batches.append(batch) + + print(f"Created {len(batches)} batches from {len(task_infos)} tasks (batch_size={batch_size})") + return batches + + def _create_task_signature(self, task_info: Dict, adapter_config: Dict): + """ + Create a Celery signature for a task. + + Args: + task_info: Dictionary containing step and metadata + adapter_config: Adapter configuration + + Returns: + Celery signature + """ + from merlin.common.tasks import merlin_step + + step = task_info["step"] + sig = merlin_step.s(step, adapter_config=adapter_config) + sig.set(queue=step.get_task_queue()) + return sig + + def _link_chain_positions(self, all_chains: List[List]) -> List: + """ + Link tasks at different chain positions with dependencies using Celery chains. + + Args: + all_chains: 2D list [[pos0_tasks], [pos1_tasks], ...] + + Returns: + List of Celery chain() primitives, one per parallel sample + """ + from celery import chain + + if len(all_chains) == 0: + return [] + + if len(all_chains) == 1: + # Single position - no linking needed + return all_chains[0] + + # Multi-position chain: use Celery's chain() primitive + # Build one chain per parallel sample/task + chains = [] + num_parallel = len(all_chains[0]) # Number of parallel tasks + + for i in range(num_parallel): + # Collect tasks at position i across all chain positions + task_sequence = [all_chains[j][i] for j in range(len(all_chains))] + # Create a Celery chain + chains.append(chain(*task_sequence)) + + return chains + + def _build_chain_with_dependencies(self, expanded_positions: List[List[Dict]], context: ExecutionContext) -> List: + """ + Build a chain with dependencies from expanded positions. 
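+
+        For each parallel sample, its tasks across positions are linked into
+        one Celery chain, so position N+1 starts only after position N
+        completes for that sample (see _link_chain_positions).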
+ + Args: + expanded_positions: 2D structure from SampleExpander + context: Execution context + + Returns: + List of signatures with dependencies properly linked + """ + adapter_config = context.study.get_adapter_config(override_type="celery") + + # Convert each position's tasks to signatures + all_sig_chains = [] + for position_tasks in expanded_positions: + position_sigs = [self._create_task_signature(task_info, adapter_config) for task_info in position_tasks] + all_sig_chains.append(position_sigs) + + # Link positions with dependencies + if len(all_sig_chains) > 1: + # Multi-position chain: use linking logic + linked_sigs = self._link_chain_positions(all_sig_chains) + else: + # Single position: just return the signatures + linked_sigs = all_sig_chains[0] if all_sig_chains else [] + + return linked_sigs + + def execute_chain(self, chain: TaskChain, context: ExecutionContext) -> List[TaskResult]: + """Execute a single chain via Celery chain primitive.""" + from celery import chain as celery_chain + + from merlin.common.tasks import merlin_step + + results = [] + adapter_config = context.study.get_adapter_config(override_type="celery") + + # Build Celery chain (skip virtual nodes) + sigs = [] + real_task_names = [] + for task_name in chain.tasks: + try: + step = context.study.dag.step(task_name) + + # Skip virtual nodes + if step is None: + results.append( + TaskResult(task_name=task_name, status=TaskStatus.SKIPPED, error="Virtual node, no execution needed") + ) + continue + + sig = merlin_step.s(step, adapter_config=adapter_config) + sig.set(queue=step.get_task_queue()) + sigs.append(sig) + real_task_names.append(task_name) + except (AttributeError, KeyError, TypeError): + # This is a virtual node, skip it + results.append( + TaskResult(task_name=task_name, status=TaskStatus.SKIPPED, error="Virtual node, no execution needed") + ) + + # Execute chain (only if there are real tasks) + if sigs: + try: + start_time = time.time() + async_result = celery_chain(*sigs).apply_async() + _ = async_result.get(timeout=3600) # 1 hour timeout + end_time = time.time() + + # Create success results for all real tasks + for task_name in real_task_names: + results.append( + TaskResult(task_name=task_name, status=TaskStatus.COMPLETED, start_time=start_time, end_time=end_time) + ) + except Exception as e: + # Mark all real tasks in chain as failed + for task_name in real_task_names: + results.append(TaskResult(task_name=task_name, status=TaskStatus.FAILED, error=str(e))) + + return results + + def execute_task(self, task_name: str, context: ExecutionContext) -> TaskResult: + """Execute a single task via Celery.""" + from merlin.common.tasks import merlin_step + + try: + # Get Step object from DAG + step = context.study.dag.step(task_name) + + # Skip virtual nodes + if step is None: + return TaskResult(task_name=task_name, status=TaskStatus.SKIPPED, error="Virtual node, no execution needed") + + celery_id = str(uuid.uuid4()) + + # Get adapter config + adapter_config = context.study.get_adapter_config(override_type="celery") + + # Create Celery signature + sig = merlin_step.s(step, adapter_config=adapter_config) + sig.set(queue=step.get_task_queue()) + + # Submit to Celery + start_time = time.time() + async_result = sig.apply_async(task_id=celery_id) + + # Wait for result + result = async_result.get(timeout=1800) # 30 min timeout + end_time = time.time() + + return TaskResult( + task_name=task_name, + status=TaskStatus.COMPLETED, + start_time=start_time, + end_time=end_time, + result=result, + 
celery_id=celery_id, + ) + except (AttributeError, KeyError, TypeError): + # This is a virtual node + return TaskResult(task_name=task_name, status=TaskStatus.SKIPPED, error="Virtual node, no execution needed") + except Exception as e: + return TaskResult( + task_name=task_name, + status=TaskStatus.FAILED, + error=str(e), + celery_id=celery_id if "celery_id" in locals() else None, + ) diff --git a/merlin/execution/local.py b/merlin/execution/local.py new file mode 100644 index 00000000..901989ee --- /dev/null +++ b/merlin/execution/local.py @@ -0,0 +1,318 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +LocalExecutor with sample expansion support. +""" + +import json +import time +from concurrent.futures import ProcessPoolExecutor, as_completed +from typing import Dict, List + +from merlin.common.enums import ReturnCode +from merlin.dag.models import ExecutionPlan, TaskChain +from merlin.execution.base import TaskExecutor +from merlin.execution.models import ExecutionContext, TaskResult, TaskStatus +from merlin.execution.sample_expander import SampleExpander + + +# Success return codes that indicate task completed successfully +# SOFT_FAIL is included because it allows dependent tasks to continue +SUCCESS_CODES = {ReturnCode.OK, ReturnCode.DRY_OK, ReturnCode.SOFT_FAIL} + + +def write_status(status_file: str, status: str, return_code=None, elapsed_time=None): + """Write status information to a JSON file.""" + status_data = {"status": status, "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")} + if return_code is not None: + status_data["return_code"] = return_code + if elapsed_time is not None: + status_data["elapsed_time"] = elapsed_time + + with open(status_file, "w") as f: + json.dump(status_data, f, indent=2) + + +class LocalExecutor(TaskExecutor): + """Local process pool executor with sample expansion support.""" + + def __init__(self, max_workers: int = 4): + """ + Initialize LocalExecutor. + + Args: + max_workers: Maximum number of parallel worker processes (default: 4) + """ + self.sample_expander = SampleExpander() + self.max_workers = max_workers + + def execute_plan( + self, plan: ExecutionPlan, context: ExecutionContext, wait: bool = True, timeout: int = 7200 + ) -> Dict[str, TaskResult]: + """ + Execute plan level-by-level using local process pool. + + Args: + plan: Execution plan to execute + context: Execution context + wait: Ignored for LocalExecutor (always blocks). Included for API compatibility. + timeout: Ignored for LocalExecutor. Included for API compatibility. 
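+
+        Note:
+            Levels execute sequentially; if any task in a level fails,
+            execution stops before the next level is submitted.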
+ + Returns: + Dictionary mapping task names to TaskResults + """ + all_results = {} + + # Create process pool + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + for level in plan.levels: + print(f"Executing depth {level.depth} with {len(level.parallel_chains)} parallel chains...") + + # Execute level and wait for completion + level_results = self._execute_level_parallel(level, context, executor) + all_results.update(level_results) + + # Check for failures + failed = [k for k, v in level_results.items() if v.status == TaskStatus.FAILED] + if failed: + print(f"Level {level.depth} had failures: {failed}") + print("Stopping execution due to failures") + break + + return all_results + + def _execute_level_parallel( + self, level, context: ExecutionContext, executor: ProcessPoolExecutor + ) -> Dict[str, TaskResult]: + """ + Execute all chains in a level using process pool. + + Args: + level: ExecutionLevel to execute + context: Execution context + executor: ProcessPoolExecutor to use + + Returns: + Dictionary mapping task names to TaskResults + """ + level_results = {} + + for chain in level.parallel_chains: + # Skip virtual nodes + has_real_tasks = self._has_real_tasks(chain, context) + if not has_real_tasks: + for task in chain.tasks: + level_results[task] = TaskResult(task_name=task, status=TaskStatus.SKIPPED, error="Virtual node") + continue + + # Expand chain with samples + expanded_positions = self.sample_expander.expand_chain(chain, context) + + # Log expansion details + total_expanded = sum(len(pos) for pos in expanded_positions) + chain_name = chain.tasks[0] if chain.tasks else "unknown" + print( + f" Chain '{chain_name}' expanded to {total_expanded} tasks " + f"across {len(expanded_positions)} positions" + ) + + # Execute chain with dependencies (sequential positions, parallel samples) + position_results = self._execute_chain_with_dependencies(expanded_positions, context, executor) + level_results.update(position_results) + + return level_results + + def _execute_chain_with_dependencies( + self, expanded_positions: List[List[Dict]], context: ExecutionContext, executor: ProcessPoolExecutor + ) -> Dict[str, TaskResult]: + """ + Execute a chain with multiple positions sequentially. + + Each position must complete before the next position starts. + Within each position, tasks execute in parallel. + + Args: + expanded_positions: 2D structure [[pos0_tasks], [pos1_tasks], ...] + context: Execution context + executor: ProcessPoolExecutor to use + + Returns: + Dictionary mapping task names to TaskResults + """ + all_results = {} + adapter_config = context.study.get_adapter_config(override_type="local") + # Add task_server field so Step._update_status_file knows we're running locally + adapter_config["task_server"] = "local" + + # Execute each position sequentially + for position_idx, position_tasks in enumerate(expanded_positions): + print(f" Executing position {position_idx} with {len(position_tasks)} tasks...") + + # Submit all tasks at this position (parallel) + position_futures = {} + for task_info in position_tasks: + future = executor.submit(self._execute_step_wrapper, task_info["step"], adapter_config) + position_futures[future] = task_info + + # Wait for this position to complete before moving to next + for future in as_completed(position_futures): + task_info = position_futures[future] + task_name = task_info["step"].name() + + try: + return_code = future.result(timeout=3600) # 1 hour timeout per task + + # Check for success codes (OK=0, DRY_OK=103, etc.) 
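+                    # `return_code` may be a ReturnCode enum member or its raw
+                    # int value depending on how the step returned, so both
+                    # forms are checked below.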
+ if return_code in SUCCESS_CODES or return_code in {rc.value for rc in SUCCESS_CODES}: + all_results[task_name] = TaskResult( + task_name=task_name, status=TaskStatus.COMPLETED, result=return_code + ) + else: + all_results[task_name] = TaskResult( + task_name=task_name, + status=TaskStatus.FAILED, + error=f"Task returned non-zero exit code: {return_code}", + ) + except Exception as e: + all_results[task_name] = TaskResult(task_name=task_name, status=TaskStatus.FAILED, error=str(e)) + + return all_results + + @staticmethod + def _execute_step_wrapper(step, adapter_config): + """ + Wrapper for executing a step in a subprocess. + + This is a static method because it needs to be pickleable + for process pool execution. + + Args: + step: Step object to execute + adapter_config: Adapter configuration + + Returns: + Return code (0 for success, non-zero for failure) + """ + import os + import traceback + + from merlin.common.enums import ReturnCode + + try: + # Get workspace + workspace = step.get_workspace() + step_name = step.name() + + # Check if already completed + finished_file = f"{workspace}/MERLIN_FINISHED" + if os.path.exists(finished_file): + import logging + + LOG = logging.getLogger(__name__) + LOG.info(f"Skipping step '{step_name}' in '{workspace}' (already finished).") + return ReturnCode.OK + + # Get max_retries from step (default to 10 if not specified) + try: + max_retries = step.max_retries + except (AttributeError, KeyError): + max_retries = 10 + + # Execute step with retry logic for RESTART + retry_count = 0 + return_code = None + while retry_count <= max_retries: + # Execute step (handles script generation and execution internally) + return_code = step.execute(adapter_config) + + # Check if we need to restart + if return_code == ReturnCode.RESTART or return_code == ReturnCode.RESTART.value: + retry_count += 1 + import logging + + LOG = logging.getLogger(__name__) + LOG.info(f"Step '{step_name}' requested restart (attempt {retry_count}/{max_retries})") + + # Check if max retries exceeded + if retry_count > max_retries: + LOG.info(f"Step '{step_name}' exceeded max retries ({max_retries}), returning SOFT_FAIL") + return_code = ReturnCode.SOFT_FAIL + break + + # Mark step for restart and continue loop + step.restart = True + continue + + # Not a restart - break out of loop + break + + # Touch MERLIN_FINISHED if successful (OK, DRY_OK, or SOFT_FAIL) + if return_code in ( + ReturnCode.OK, + ReturnCode.OK.value, + ReturnCode.DRY_OK, + ReturnCode.DRY_OK.value, + ReturnCode.SOFT_FAIL, + ReturnCode.SOFT_FAIL.value, + ): + open(finished_file, "w").close() + + return return_code + + except Exception as e: + # Log exception + import logging + + LOG = logging.getLogger(__name__) + LOG.error(f"Error executing step {step.name()}: {e}") + LOG.debug(traceback.format_exc()) + + # Re-raise to let the caller handle it + raise + + def _has_real_tasks(self, chain: TaskChain, context: ExecutionContext) -> bool: + """ + Check if a chain has real tasks (not just virtual nodes). 
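+
+        Virtual nodes (such as the "_source" node) have no Step object, so
+        dag.step() returns None or raises for them.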
+
+        Args:
+            chain: TaskChain to check
+            context: Execution context
+
+        Returns:
+            True if chain has real tasks, False if only virtual nodes
+        """
+        for task in chain.tasks:
+            try:
+                step = context.study.dag.step(task)
+                if step is not None:
+                    return True
+            except (AttributeError, KeyError, TypeError):
+                pass
+        return False
+
+    def execute_chain(self, chain: TaskChain, context: ExecutionContext) -> List[TaskResult]:
+        """Execute chain locally (legacy method for compatibility)."""
+        results = []
+        for task_name in chain.tasks:
+            result = self.execute_task(task_name, context)
+            results.append(result)
+        return results
+
+    def execute_task(self, task_name: str, context: ExecutionContext) -> TaskResult:
+        """Execute task locally (legacy method for compatibility)."""
+        try:
+            start_time = time.time()
+            step_to_execute = context.study.dag.step(task_name)
+            adapter_config = context.study.get_adapter_config(override_type="local")
+            result = step_to_execute.execute(adapter_config)
+            end_time = time.time()
+
+            return TaskResult(
+                task_name=task_name, status=TaskStatus.COMPLETED, start_time=start_time, end_time=end_time, result=result
+            )
+        except Exception as e:
+            return TaskResult(task_name=task_name, status=TaskStatus.FAILED, error=str(e))
diff --git a/merlin/execution/models.py b/merlin/execution/models.py
new file mode 100644
index 00000000..171b8b0d
--- /dev/null
+++ b/merlin/execution/models.py
@@ -0,0 +1,165 @@
+##############################################################################
+# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin
+# Project developers. See top-level LICENSE and COPYRIGHT files for dates and
+# other details. No copyright assignment is required to contribute to Merlin.
+##############################################################################
+
+"""Data models shared by the task executors: statuses, results, and execution context."""
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Dict, Optional
+
+from merlin.study.study import MerlinStudy
+
+
+class TaskStatus(Enum):
+    """
+    Enumeration of possible states for a task during workflow execution.
+
+    This enum provides a standardized way to track and report the lifecycle
+    state of individual tasks as they progress through the execution pipeline.
+
+    Values:
+        PENDING: Task is queued and waiting to be executed.
+        RUNNING: Task is currently being executed.
+        COMPLETED: Task has finished successfully.
+        FAILED: Task encountered an error and could not complete.
+        SKIPPED: Task was intentionally skipped due to conditions or dependencies.
+
+    Example:
+        ```python
+        result = TaskResult(
+            task_name="analyze_data",
+            status=TaskStatus.RUNNING,
+            start_time=time.time()
+        )
+
+        # Later, after task completion
+        result.status = TaskStatus.COMPLETED
+        result.end_time = time.time()
+        ```
+    """
+
+    PENDING = "pending"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+    SKIPPED = "skipped"
+
+
+# TODO in Merlin 2.0 we can probably convert this and ExecutionContext into data models for MerlinDatabase to ingest
+@dataclass
+class TaskResult:
+    """
+    Represents the execution result and metadata for a single task.
+
+    This class captures comprehensive information about a task's execution,
+    including its final state, timing information, outputs, and any errors
+    encountered. It serves as the primary data structure for tracking task
+    outcomes and enabling workflow monitoring and debugging.
+
+    Attributes:
+        task_name (str): Unique identifier/name of the task that was executed.
+        status (TaskStatus): Current execution state of the task.
+ start_time (Optional[float]): Unix timestamp when task execution began, + None if not started yet. + end_time (Optional[float]): Unix timestamp when task execution completed, + None if still running or not started. + result (Any): The return value or output produced by the task, + None if no result or task failed. + error (Optional[str]): Error message if the task failed, + None if no error occurred. + celery_id (Optional[str]): Celery task ID for distributed execution tracking, + None for local execution or non-Celery backends. + + Example: + ```python + # Successful task result + success_result = TaskResult( + task_name="data_preprocessing", + status=TaskStatus.COMPLETED, + start_time=1693843200.0, + end_time=1693843260.0, + result={"processed_rows": 1000, "output_file": "processed_data.csv"}, + celery_id="abc123-def456-ghi789" + ) + + # Failed task result + failed_result = TaskResult( + task_name="model_training", + status=TaskStatus.FAILED, + start_time=1693843300.0, + end_time=1693843400.0, + error="Insufficient memory: required 8GB, available 4GB" + ) + + # Calculate execution duration + if success_result.start_time and success_result.end_time: + duration = success_result.end_time - success_result.start_time + print(f"Task completed in {duration:.1f} seconds") + ``` + """ + + task_name: str + status: TaskStatus + start_time: Optional[float] = None + end_time: Optional[float] = None + result: Any = None + error: Optional[str] = None + celery_id: Optional[str] = None + + +@dataclass +class ExecutionContext: # TODO entry(ies) for samples? + """ + Execution context and configuration passed to task executors. + + This class encapsulates all the contextual information needed by task + executors to properly run workflows, including study configuration, + parameter information, and execution metadata. It serves as a data + container that travels with the execution plan through the execution pipeline. + + Attributes: + study (study.study.MerlinStudy): The complete Merlin study configuration containing + DAG structure, samples, and workflow specifications. + parameter_info (Dict): Parameter definitions and configurations used + for task parameterization and sample expansion. + execution_id (str): Unique identifier for this specific workflow execution + run, useful for tracking and logging. + metadata (Dict): Additional arbitrary metadata that may be needed + during execution. Initialized as empty dict if not provided. 
+
+    Example:
+        ```python
+        from merlin.study.study import MerlinStudy
+
+        # Create execution context for a study run
+        context = ExecutionContext(
+            study=my_merlin_study,
+            parameter_info={
+                "temperature": {"type": "float", "range": [100, 500]},
+                "pressure": {"type": "int", "values": [1, 5, 10]}
+            },
+            execution_id="study_run_20240904_143022",
+            metadata={
+                "user": "researcher",
+                "cluster": "quartz",
+                "submit_time": "2024-09-04T14:30:22Z"
+            }
+        )
+
+        # Pass context to executor
+        executor = CeleryExecutor(app=celery_app)
+        results = executor.execute_plan(execution_plan, context)
+        ```
+    """
+
+    study: MerlinStudy
+    parameter_info: Dict
+    execution_id: str
+    metadata: Optional[Dict] = None
+
+    def __post_init__(self):
+        if self.metadata is None:
+            self.metadata = {}
diff --git a/merlin/execution/sample_expander.py b/merlin/execution/sample_expander.py
new file mode 100644
index 00000000..de7acb93
--- /dev/null
+++ b/merlin/execution/sample_expander.py
@@ -0,0 +1,203 @@
+##############################################################################
+# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin
+# Project developers. See top-level LICENSE and COPYRIGHT files for dates and
+# other details. No copyright assignment is required to contribute to Merlin.
+##############################################################################
+
+import logging
+from typing import Dict, List
+
+from merlin.common.sample_index import uniform_directories
+from merlin.common.sample_index_factory import create_hierarchy
+from merlin.dag.models import TaskChain
+from merlin.execution.models import ExecutionContext
+from merlin.spec.expansion import parameter_substitutions_for_cmd, parameter_substitutions_for_sample
+
+
+LOG = logging.getLogger(__name__)
+
+
+class SampleExpander:
+    """
+    Handles expansion of task chains based on samples.
+
+    This class takes a TaskChain and expands it into multiple chains,
+    one for each sample, with parameter substitutions applied.
+    """
+
+    def __init__(self, level_max_dirs: int = 25):
+        self.level_max_dirs = level_max_dirs
+
+    def needs_expansion(self, chain: TaskChain, context: ExecutionContext) -> bool:
+        """
+        Check if a chain needs sample expansion.
+
+        Args:
+            chain: The TaskChain to check
+            context: Execution context with study info
+
+        Returns:
+            True if expansion needed, False otherwise
+        """
+        # Check if any step in chain needs expansion
+        dag = context.study.dag
+        labels = context.study.sample_labels
+
+        if not labels:
+            return False
+
+        for task_name in chain.tasks:
+            step = dag.step(task_name)
+            if step is not None and step.check_if_expansion_needed(labels):
+                return True
+
+        return False
+
+    def expand_chain(self, chain: TaskChain, context: ExecutionContext) -> List[List[Dict]]:
+        """
+        Expand a chain into sample-specific chains, preserving chain structure.
+
+        Args:
+            chain: The TaskChain to expand
+            context: Execution context
+
+        Returns:
+            2D list structure: [[step0_samples...], [step1_samples...], ...]
+ Each inner list contains dictionaries with: + - 'step': Step object + - 'sample_id': int (or None if no expansion) + - 'sample_values': Dict (or None if no expansion) + - 'workspace': str + - 'chain_position': int (position in original chain) + - 'original_chain': TaskChain (the original chain) + """ + samples = context.study.samples + labels = context.study.sample_labels + dag = context.study.dag + + # STEP 1: Calculate glob_path and sample_paths for ALL steps (even those not expanded) + # This is needed for steps like 'collect' that reference $(MERLIN_GLOB_PATH) + glob_path = "" + sample_paths = "" + + if samples is not None and len(samples) > 0: + # Create sample hierarchy to get glob_path and sample_paths + directory_sizes = uniform_directories(len(samples), bundle_size=1, level_max_dirs=self.level_max_dirs) + + # Build glob_path (e.g., "*/*/*/*/*") + # CRITICAL FIX: Add one extra level for Maestro execution directories (samples0-1.ext, etc.) + # Files end up in: workspace/sample_dir/execution_dir/file.json + # So we need: */* to match sample_dir/execution_dir + # Note: No trailing slash since the command adds one + glob_path = "/".join(["*"] * (len(directory_sizes) + 1)) + LOG.info( + f"Calculated glob_path: '{glob_path}' from directory_sizes={directory_sizes} " + f"(+1 for execution dir) for {len(samples)} samples" + ) + + # Create sample index to get all sample paths + sample_index = create_hierarchy( + len(samples), bundle_size=1, directory_sizes=directory_sizes, root="", n_digits=len(str(self.level_max_dirs)) + ) + + # Build sample_paths string (e.g., "00/00/00:00/00/01:...") + sample_paths = sample_index.make_directory_string() + + # STEP 2: Apply MERLIN_GLOB_PATH and MERLIN_PATHS_ALL to ALL steps in chain + # This must happen BEFORE checking if expansion is needed + steps_with_glob = [] + for task_name in chain.tasks: + step = dag.step(task_name) + if step is not None: + # Clone with glob substitutions + step_with_glob = step.clone_changing_workspace_and_cmd( + cmd_replacement_pairs=parameter_substitutions_for_cmd(glob_path, sample_paths) + ) + steps_with_glob.append((task_name, step_with_glob)) + + # STEP 3: Check if expansion is needed (after glob substitution) + needs_expansion = False + if labels and len(labels) > 0: + for task_name, step in steps_with_glob: + if step.check_if_expansion_needed(labels): + needs_expansion = True + break + + num_samples = len(samples) if samples is not None else 0 + LOG.info( + f"Sample expansion check: needs_expansion={needs_expansion}, " + f"num_samples={num_samples}, labels={labels}" + ) + + if not needs_expansion: + # No expansion needed - return steps with glob substitutions applied + # Group by chain position (2D structure) + result = [] + for position, (task_name, step) in enumerate(steps_with_glob): + result.append( + [ + { + "step": step, + "sample_id": None, + "sample_values": None, + "workspace": step.get_workspace(), + "chain_position": position, + "original_chain": chain, + } + ] + ) + return result + + # STEP 4: Expand for each sample + # Group by chain position (2D structure: position -> samples) + + # Recreate sample_index for iteration + directory_sizes = uniform_directories(len(samples), bundle_size=1, level_max_dirs=self.level_max_dirs) + + sample_index = create_hierarchy( + len(samples), bundle_size=1, directory_sizes=directory_sizes, root="", n_digits=len(str(self.level_max_dirs)) + ) + + # Build 2D structure: iterate over chain positions first, then samples + result = [] + for position, (task_name, step_with_glob) in 
enumerate(steps_with_glob):
+            position_tasks = []  # All samples for this position
+
+            for sample_id, sample in enumerate(samples):
+                # Get relative path for this sample
+                relative_path = sample_index.get_path_to_sample(sample_id)
+
+                # Get parameter substitutions for this sample
+                substitutions = parameter_substitutions_for_sample(sample, labels, sample_id, relative_path)
+
+                # CRITICAL FIX: Append sample path to workspace for sample-specific directories
+                # Each sample should run in its own subdirectory
+                sample_workspace = f"{step_with_glob.get_workspace()}/{relative_path}".rstrip("/")
+                LOG.info(f"Expanded sample {sample_id} workspace: {sample_workspace}")
+
+                # Clone step (already has glob substitutions) with sample substitutions and workspace
+                expanded_step = step_with_glob.clone_changing_workspace_and_cmd(
+                    cmd_replacement_pairs=substitutions, new_workspace=sample_workspace
+                )
+
+                position_tasks.append(
+                    {
+                        "step": expanded_step,
+                        "sample_id": sample_id,
+                        "sample_values": dict(zip(labels, sample)),
+                        "workspace": expanded_step.get_workspace(),
+                        "chain_position": position,
+                        "original_chain": chain,
+                    }
+                )
+
+            result.append(position_tasks)
+
+        # Log expansion results
+        total_tasks = sum(len(pos_tasks) for pos_tasks in result)
+        LOG.info(
+            f"Sample expansion complete: created {total_tasks} tasks across {len(result)} positions "
+            f"for chain with {len(chain.tasks)} original tasks"
+        )
+
+        return result
diff --git a/merlin/execution/workflow_manager.py b/merlin/execution/workflow_manager.py
new file mode 100644
index 00000000..25420847
--- /dev/null
+++ b/merlin/execution/workflow_manager.py
@@ -0,0 +1,87 @@
+##############################################################################
+# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin
+# Project developers. See top-level LICENSE and COPYRIGHT files for dates and
+# other details. No copyright assignment is required to contribute to Merlin.
+##############################################################################
+
+"""High-level workflow management: builds an execution plan from the DAG and runs it with a TaskExecutor."""
+
+import time
+import uuid
+from typing import Dict
+
+from merlin.dag.dag import DAG as ExecutionDAG
+from merlin.execution.base import TaskExecutor
+from merlin.execution.models import ExecutionContext, TaskStatus
+from merlin.study.study import MerlinStudy
+
+
+class WorkflowManager:
+    """High-level workflow manager that ties everything together."""
+
+    def __init__(self, study: MerlinStudy, executor: TaskExecutor):
+        self.study = study
+        # Create a new DAG using the execution framework's DAG class
+        # This allows the old study.dag to remain unchanged for backwards compatibility
+        old_dag = study.dag
+        self.dag = ExecutionDAG(
+            old_dag.maestro_adjacency_table,
+            old_dag.maestro_values,
+            old_dag.column_labels,
+            old_dag.study_name,
+            old_dag.parameter_info,
+        )
+        self.executor = executor
+
+    def run_workflow(self, source_node: str = "_source", wait: bool = False, timeout: int = 7200) -> Dict:
+        """
+        Run the complete workflow.
+
+        Args:
+            source_node: Starting node for workflow execution. Default: "_source"
+            wait: If True, block until workflow completes. Default: False (non-blocking)
+            timeout: Timeout in seconds when using wait=True. Default: 7200 (2 hours)
+
+        Returns:
+            Dictionary containing execution results and workflow information.
+            For CeleryExecutor, includes: 'results', 'async_result', 'workflow_id'
+        """
+        print("=== WORKFLOW EXECUTION ===")
+
+        # 1.
Generate execution plan + print("Generating execution plan...") + execution_plan = self.dag.group_tasks(source_node) + + print(f"Plan generated: {len(execution_plan.levels)} levels") + for level in execution_plan.levels: + print(f" Depth {level.depth}: {len(level.parallel_chains)} chains") + for chain in level.parallel_chains: + print(f" {chain}") + + # 2. Create execution context + context = ExecutionContext( + study=self.study, + parameter_info=self.dag.parameter_info, + execution_id=str(uuid.uuid4()), + metadata={"started_at": time.time()}, # TODO not sure what to do with metadata yet + ) + + # 3. Execute the plan + print("\nExecuting plan...") + result = self.executor.execute_plan(execution_plan, context, wait=wait, timeout=timeout) + + # 4. Report results (handle both old and new return formats) + if isinstance(result, dict) and "results" in result: + # New format from CeleryExecutor + results = result["results"] + else: + # Old format (just results dict) + results = result + + print("\n=== EXECUTION RESULTS ===") + completed = sum(1 for r in results.values() if r.status == TaskStatus.COMPLETED) + failed = sum(1 for r in results.values() if r.status == TaskStatus.FAILED) + + print(f"Completed: {completed}, Failed: {failed}, Total: {len(results)}") + + return result if isinstance(result, dict) and "results" in result else {"results": result} diff --git a/tests/unit/execution/__init__.py b/tests/unit/execution/__init__.py new file mode 100644 index 00000000..3232b50b --- /dev/null +++ b/tests/unit/execution/__init__.py @@ -0,0 +1,5 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## diff --git a/tests/unit/execution/test_celery_executor.py b/tests/unit/execution/test_celery_executor.py new file mode 100644 index 00000000..cb462605 --- /dev/null +++ b/tests/unit/execution/test_celery_executor.py @@ -0,0 +1,880 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +Tests for the CeleryExecutor class. 
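+
+These tests patch ``merlin.celery`` and the ``celery`` primitives into
+``sys.modules`` via the autouse ``mock_celery_imports`` fixture below, so the
+executor can be exercised without a running broker or worker. The core
+pattern, as a minimal sketch:
+
+    with patch.dict(sys.modules, {"celery": Mock(), "merlin.celery": Mock()}):
+        from merlin.execution.celery import CeleryExecutor  # import resolves against the mocks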
+""" + +import json +import os +import sys +import tempfile +from unittest.mock import Mock, patch + +import pytest + +from merlin.dag.models import ExecutionLevel, ExecutionPlan, TaskChain +from merlin.execution.models import TaskStatus + + +# Create mock modules for Celery dependencies +@pytest.fixture(autouse=True) +def mock_celery_imports(): + """Mock the Celery-related imports that happen inside CeleryExecutor methods.""" + # Create mock app + mock_app = Mock() + + # Create mock celery module + mock_celery_module = Mock() + mock_celery_module.app = mock_app + + # Create mock merlin.celery module + mock_merlin_celery = Mock() + mock_merlin_celery.app = mock_app + + # Create mock celery primitives module + mock_celery = Mock() + mock_celery.chain = Mock() + mock_celery.group = Mock() + + # Patch sys.modules for imports that happen inside methods + with patch.dict( + sys.modules, + { + "merlin.celery": mock_merlin_celery, + "celery": mock_celery, + }, + ): + yield { + "app": mock_app, + "chain": mock_celery.chain, + "group": mock_celery.group, + } + + +class TestCeleryExecutorInit: + """Tests for CeleryExecutor.__init__()""" + + def test_init_default_queue(self, mock_celery_imports): + """Test that default queue is 'default'""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + assert executor.default_queue == "default" + assert executor.sample_expander is not None + assert executor.active_tasks == {} + + def test_init_custom_queue(self, mock_celery_imports): + """Test that custom queue is set correctly""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor(default_queue="my_queue") + assert executor.default_queue == "my_queue" + + def test_init_creates_sample_expander(self, mock_celery_imports): + """Test that __init__ creates a SampleExpander instance""" + from merlin.execution.celery import CeleryExecutor + from merlin.execution.sample_expander import SampleExpander + + executor = CeleryExecutor() + assert isinstance(executor.sample_expander, SampleExpander) + + +class TestCreateBatches: + """Tests for CeleryExecutor._create_batches()""" + + def test_create_batches_empty_list(self, mock_celery_imports): + """Test _create_batches with empty list returns empty list""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + result = executor._create_batches([], batch_size=100) + assert result == [] + + def test_create_batches_50_tasks_one_batch(self, mock_celery_imports): + """Test _create_batches with 50 tasks creates 1 batch""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + task_infos = [{"task": f"task_{i}"} for i in range(50)] + result = executor._create_batches(task_infos, batch_size=100) + assert len(result) == 1 + assert len(result[0]) == 50 + + def test_create_batches_100_tasks_one_batch(self, mock_celery_imports): + """Test _create_batches with exactly 100 tasks creates 1 batch""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + task_infos = [{"task": f"task_{i}"} for i in range(100)] + result = executor._create_batches(task_infos, batch_size=100) + assert len(result) == 1 + assert len(result[0]) == 100 + + def test_create_batches_250_tasks_three_batches(self, mock_celery_imports): + """Test _create_batches with 250 tasks creates 3 batches""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + task_infos = [{"task": f"task_{i}"} for i in range(250)] + result = 
executor._create_batches(task_infos, batch_size=100) + assert len(result) == 3 + assert len(result[0]) == 100 + assert len(result[1]) == 100 + assert len(result[2]) == 50 + + def test_create_batches_custom_batch_size(self, mock_celery_imports): + """Test _create_batches with custom batch_size""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + task_infos = [{"task": f"task_{i}"} for i in range(25)] + result = executor._create_batches(task_infos, batch_size=10) + assert len(result) == 3 + assert len(result[0]) == 10 + assert len(result[1]) == 10 + assert len(result[2]) == 5 + + +class TestCreateTaskSignature: + """Tests for CeleryExecutor._create_task_signature()""" + + def test_create_task_signature_basic(self, mock_celery_imports): + """Test _create_task_signature creates a signature""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + mock_step = Mock() + mock_step.get_task_queue.return_value = "test_queue" + + mock_sig = Mock() + + task_info = {"step": mock_step} + adapter_config = {"type": "celery"} + + # Patch merlin_step where it's imported + with patch("merlin.common.tasks.merlin_step") as mock_merlin_step: + mock_merlin_step.s.return_value = mock_sig + result = executor._create_task_signature(task_info, adapter_config) + + # Should create signature with step and adapter_config + mock_merlin_step.s.assert_called_once_with(mock_step, adapter_config=adapter_config) + # Should set queue + mock_sig.set.assert_called_once_with(queue="test_queue") + assert result == mock_sig + + +class TestLinkChainPositions: + """Tests for CeleryExecutor._link_chain_positions()""" + + def test_link_chain_positions_empty_list(self, mock_celery_imports): + """Test _link_chain_positions with empty list""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + result = executor._link_chain_positions([]) + assert result == [] + + def test_link_chain_positions_single_position(self, mock_celery_imports): + """Test _link_chain_positions with single position returns tasks unchanged""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + mock_sigs = [Mock(), Mock(), Mock()] # 3 parallel tasks + all_chains = [mock_sigs] + + result = executor._link_chain_positions(all_chains) + + # Single position - no chaining needed, return as-is + assert result == mock_sigs + + def test_link_chain_positions_multi_position(self, mock_celery_imports): + """Test _link_chain_positions with multiple positions creates chains""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + # 2 positions with 2 parallel tasks each + pos0_sigs = [Mock(), Mock()] + pos1_sigs = [Mock(), Mock()] + all_chains = [pos0_sigs, pos1_sigs] + + mock_chain_result = Mock() + mock_celery_imports["chain"].return_value = mock_chain_result + + result = executor._link_chain_positions(all_chains) + + # Should create 2 chains (one per parallel sample) + assert len(result) == 2 + # chain() should be called twice + assert mock_celery_imports["chain"].call_count == 2 + + def test_link_chain_positions_three_positions(self, mock_celery_imports): + """Test _link_chain_positions with 3 positions creates proper chains""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + # 3 positions with 1 parallel task each + pos0_sigs = [Mock()] + pos1_sigs = [Mock()] + pos2_sigs = [Mock()] + all_chains = [pos0_sigs, pos1_sigs, pos2_sigs] + + mock_chain_result = Mock() + 
mock_celery_imports["chain"].return_value = mock_chain_result + + result = executor._link_chain_positions(all_chains) + + # Should create 1 chain + assert len(result) == 1 + # chain() should be called once with all 3 positions + assert mock_celery_imports["chain"].call_count == 1 + + +class TestBuildChainWithDependencies: + """Tests for CeleryExecutor._build_chain_with_dependencies()""" + + def test_build_chain_with_dependencies_single_position(self, mock_celery_imports): + """Test _build_chain_with_dependencies with single position""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + mock_step = Mock() + mock_step.get_task_queue.return_value = "test_queue" + mock_sig = Mock() + + expanded_positions = [[{"step": mock_step, "sample_id": 0}]] + + context = Mock() + context.study.get_adapter_config.return_value = {"type": "celery"} + + # Mock _create_task_signature to avoid internal import issues + with patch.object(executor, "_create_task_signature") as mock_create_sig: + mock_create_sig.return_value = mock_sig + result = executor._build_chain_with_dependencies(expanded_positions, context) + + # Should create signature + mock_create_sig.assert_called_once() + # Should return list of signatures (no linking for single position) + assert result == [mock_sig] + + def test_build_chain_with_dependencies_multi_position(self, mock_celery_imports): + """Test _build_chain_with_dependencies with multiple positions""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + mock_step1 = Mock() + mock_step1.get_task_queue.return_value = "test_queue" + mock_step2 = Mock() + mock_step2.get_task_queue.return_value = "test_queue" + mock_sig1 = Mock() + mock_sig2 = Mock() + + mock_chain = Mock() + mock_celery_imports["chain"].return_value = mock_chain + + expanded_positions = [ + [{"step": mock_step1, "sample_id": 0}], + [{"step": mock_step2, "sample_id": 0}], + ] + + context = Mock() + context.study.get_adapter_config.return_value = {"type": "celery"} + + # Mock _create_task_signature to avoid internal import issues + with patch.object(executor, "_create_task_signature") as mock_create_sig: + mock_create_sig.side_effect = [mock_sig1, mock_sig2] + result = executor._build_chain_with_dependencies(expanded_positions, context) + + # Should create 2 signatures + assert mock_create_sig.call_count == 2 + # Result should be from link_chain_positions (which calls chain) + assert result == [mock_chain] + + +class TestExecutePlanVirtualNodes: + """Tests for virtual node handling in execute_plan()""" + + def test_execute_plan_skips_virtual_nodes(self, mock_celery_imports): + """Test execute_plan skips virtual nodes like _source""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + # Mock context with _source returning None (virtual node) + context = Mock() + context.study.dag.step.return_value = None + context.study.workspace = "/workspace" + + plan = ExecutionPlan([ExecutionLevel(depth=0, parallel_chains=[TaskChain(tasks=["_source"], depth=0)])]) + + result = executor.execute_plan(plan, context, wait=False) + + # _source should be marked as SKIPPED + assert "_source" in result["results"] + assert result["results"]["_source"].status == TaskStatus.SKIPPED + assert "Virtual node" in result["results"]["_source"].error + + def test_execute_plan_mixed_virtual_and_real(self, mock_celery_imports): + """Test execute_plan handles mix of virtual and real tasks""" + from merlin.execution.celery import CeleryExecutor + + 
executor = CeleryExecutor() + + mock_step = Mock() + mock_step.get_task_queue.return_value = "test_queue" + mock_step.name.return_value = "real_task" + + # _source returns None, real_task returns mock_step + def step_side_effect(task_name): + if task_name == "_source": + return None + return mock_step + + context = Mock() + context.study.dag.step.side_effect = step_side_effect + context.study.workspace = "/workspace" + context.study.sample_labels = [] + context.study.samples = [] + + # Mock sample_expander + executor.sample_expander = Mock() + executor.sample_expander.expand_chain.return_value = [[{"step": mock_step, "sample_id": None}]] + + # Mock async result + mock_async = Mock() + mock_async.id = "test-workflow-id" + mock_chain_workflow = Mock() + mock_chain_workflow.apply_async.return_value = mock_async + mock_celery_imports["chain"].return_value = mock_chain_workflow + + plan = ExecutionPlan( + [ + ExecutionLevel(depth=0, parallel_chains=[TaskChain(tasks=["_source"], depth=0)]), + ExecutionLevel(depth=1, parallel_chains=[TaskChain(tasks=["real_task"], depth=1)]), + ] + ) + + with tempfile.TemporaryDirectory() as tmpdir: + context.study.workspace = tmpdir + context.study.expanded_spec.name = "test_study" + context.study.get_adapter_config.return_value = {"type": "celery"} + + with patch("merlin.common.tasks.merlin_step") as mock_merlin_step: + mock_sig = Mock() + mock_merlin_step.s.return_value = mock_sig + result = executor.execute_plan(plan, context, wait=False) + + # _source should be SKIPPED + assert result["results"]["_source"].status == TaskStatus.SKIPPED + # real_task should be tracked + assert "real_task" in result["results"] + + +class TestExecutePlanWorkflow: + """Tests for execute_plan workflow building""" + + def test_execute_plan_creates_workflow_info(self, mock_celery_imports): + """Test execute_plan creates WORKFLOW_INFO.json""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + mock_step = Mock() + mock_step.get_task_queue.return_value = "test_queue" + mock_step.name.return_value = "task1" + + # Mock sample_expander + executor.sample_expander = Mock() + executor.sample_expander.expand_chain.return_value = [[{"step": mock_step, "sample_id": None}]] + + # Mock async result + mock_async = Mock() + mock_async.id = "workflow-123" + mock_chain_workflow = Mock() + mock_chain_workflow.apply_async.return_value = mock_async + mock_celery_imports["chain"].return_value = mock_chain_workflow + + context = Mock() + context.study.dag.step.return_value = mock_step + context.study.get_adapter_config.return_value = {"type": "celery"} + context.study.sample_labels = [] + context.study.samples = [] + + plan = ExecutionPlan([ExecutionLevel(depth=0, parallel_chains=[TaskChain(tasks=["task1"], depth=0)])]) + + with tempfile.TemporaryDirectory() as tmpdir: + context.study.workspace = tmpdir + context.study.expanded_spec.name = "test_study" + + with patch("merlin.common.tasks.merlin_step") as mock_merlin_step: + mock_sig = Mock() + mock_merlin_step.s.return_value = mock_sig + executor.execute_plan(plan, context, wait=False) + + # Check WORKFLOW_INFO.json was created + workflow_info_path = os.path.join(tmpdir, "WORKFLOW_INFO.json") + assert os.path.exists(workflow_info_path) + + with open(workflow_info_path) as f: + info = json.load(f) + + assert info["workflow_id"] == "workflow-123" + assert info["study_name"] == "test_study" + assert info["status"] == "SUBMITTED" + + def test_execute_plan_returns_workflow_id(self, mock_celery_imports): + """Test 
execute_plan returns workflow_id and async_result""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + mock_step = Mock() + mock_step.get_task_queue.return_value = "test_queue" + mock_step.name.return_value = "task1" + + executor.sample_expander = Mock() + executor.sample_expander.expand_chain.return_value = [[{"step": mock_step, "sample_id": None}]] + + mock_async = Mock() + mock_async.id = "workflow-456" + mock_chain_workflow = Mock() + mock_chain_workflow.apply_async.return_value = mock_async + mock_celery_imports["chain"].return_value = mock_chain_workflow + + context = Mock() + context.study.dag.step.return_value = mock_step + context.study.get_adapter_config.return_value = {"type": "celery"} + context.study.sample_labels = [] + context.study.samples = [] + + plan = ExecutionPlan([ExecutionLevel(depth=0, parallel_chains=[TaskChain(tasks=["task1"], depth=0)])]) + + with tempfile.TemporaryDirectory() as tmpdir: + context.study.workspace = tmpdir + context.study.expanded_spec.name = "test_study" + + with patch("merlin.common.tasks.merlin_step") as mock_merlin_step: + mock_sig = Mock() + mock_merlin_step.s.return_value = mock_sig + result = executor.execute_plan(plan, context, wait=False) + + assert result["workflow_id"] == "workflow-456" + assert result["async_result"] == mock_async + + def test_execute_plan_no_tasks_returns_none(self, mock_celery_imports): + """Test execute_plan with no real tasks returns None workflow_id""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + context = Mock() + context.study.dag.step.return_value = None # All virtual nodes + context.study.workspace = "/workspace" + + plan = ExecutionPlan([ExecutionLevel(depth=0, parallel_chains=[TaskChain(tasks=["_source"], depth=0)])]) + + result = executor.execute_plan(plan, context, wait=False) + + assert result["workflow_id"] is None + assert result["async_result"] is None + + +class TestExecutePlanWaitBehavior: + """Tests for execute_plan wait parameter behavior""" + + def test_execute_plan_wait_false_returns_immediately(self, mock_celery_imports): + """Test execute_plan with wait=False returns immediately""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + mock_step = Mock() + mock_step.get_task_queue.return_value = "test_queue" + mock_step.name.return_value = "task1" + + executor.sample_expander = Mock() + executor.sample_expander.expand_chain.return_value = [[{"step": mock_step, "sample_id": None}]] + + mock_async = Mock() + mock_async.id = "workflow-789" + mock_chain_workflow = Mock() + mock_chain_workflow.apply_async.return_value = mock_async + mock_celery_imports["chain"].return_value = mock_chain_workflow + + context = Mock() + context.study.dag.step.return_value = mock_step + context.study.get_adapter_config.return_value = {"type": "celery"} + context.study.sample_labels = [] + context.study.samples = [] + + plan = ExecutionPlan([ExecutionLevel(depth=0, parallel_chains=[TaskChain(tasks=["task1"], depth=0)])]) + + with tempfile.TemporaryDirectory() as tmpdir: + context.study.workspace = tmpdir + context.study.expanded_spec.name = "test_study" + + with patch("merlin.common.tasks.merlin_step") as mock_merlin_step: + mock_sig = Mock() + mock_merlin_step.s.return_value = mock_sig + executor.execute_plan(plan, context, wait=False) + + # async_result.get() should NOT be called when wait=False + mock_async.get.assert_not_called() + + def test_execute_plan_wait_true_blocks(self, mock_celery_imports): + 
"""Test execute_plan with wait=True blocks until completion""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + mock_step = Mock() + mock_step.get_task_queue.return_value = "test_queue" + mock_step.name.return_value = "task1" + + executor.sample_expander = Mock() + executor.sample_expander.expand_chain.return_value = [[{"step": mock_step, "sample_id": None}]] + + mock_async = Mock() + mock_async.id = "workflow-wait" + mock_async.get.return_value = None # Simulate completion + mock_chain_workflow = Mock() + mock_chain_workflow.apply_async.return_value = mock_async + mock_celery_imports["chain"].return_value = mock_chain_workflow + + context = Mock() + context.study.dag.step.return_value = mock_step + context.study.get_adapter_config.return_value = {"type": "celery"} + context.study.sample_labels = [] + context.study.samples = [] + + plan = ExecutionPlan([ExecutionLevel(depth=0, parallel_chains=[TaskChain(tasks=["task1"], depth=0)])]) + + with tempfile.TemporaryDirectory() as tmpdir: + context.study.workspace = tmpdir + context.study.expanded_spec.name = "test_study" + + with patch("merlin.common.tasks.merlin_step") as mock_merlin_step: + mock_sig = Mock() + mock_merlin_step.s.return_value = mock_sig + executor.execute_plan(plan, context, wait=True, timeout=60) + + # async_result.get() should be called with timeout + mock_async.get.assert_called_once_with(timeout=60) + + def test_execute_plan_wait_true_updates_workflow_info_on_completion(self, mock_celery_imports): + """Test execute_plan with wait=True updates WORKFLOW_INFO.json on completion""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + mock_step = Mock() + mock_step.get_task_queue.return_value = "test_queue" + mock_step.name.return_value = "task1" + + executor.sample_expander = Mock() + executor.sample_expander.expand_chain.return_value = [[{"step": mock_step, "sample_id": None}]] + + mock_async = Mock() + mock_async.id = "workflow-complete" + mock_async.get.return_value = None + mock_chain_workflow = Mock() + mock_chain_workflow.apply_async.return_value = mock_async + mock_celery_imports["chain"].return_value = mock_chain_workflow + + context = Mock() + context.study.dag.step.return_value = mock_step + context.study.get_adapter_config.return_value = {"type": "celery"} + context.study.sample_labels = [] + context.study.samples = [] + + plan = ExecutionPlan([ExecutionLevel(depth=0, parallel_chains=[TaskChain(tasks=["task1"], depth=0)])]) + + with tempfile.TemporaryDirectory() as tmpdir: + context.study.workspace = tmpdir + context.study.expanded_spec.name = "test_study" + + with patch("merlin.common.tasks.merlin_step") as mock_merlin_step: + mock_sig = Mock() + mock_merlin_step.s.return_value = mock_sig + executor.execute_plan(plan, context, wait=True) + + # Check WORKFLOW_INFO.json was updated to COMPLETED + workflow_info_path = os.path.join(tmpdir, "WORKFLOW_INFO.json") + with open(workflow_info_path) as f: + info = json.load(f) + + assert info["status"] == "COMPLETED" + assert "completed_at" in info + + def test_execute_plan_wait_true_handles_failure(self, mock_celery_imports): + """Test execute_plan with wait=True handles workflow failure""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + mock_step = Mock() + mock_step.get_task_queue.return_value = "test_queue" + mock_step.name.return_value = "task1" + + executor.sample_expander = Mock() + executor.sample_expander.expand_chain.return_value = [[{"step": 
mock_step, "sample_id": None}]] + + mock_async = Mock() + mock_async.id = "workflow-fail" + mock_async.get.side_effect = Exception("Task execution failed") + mock_chain_workflow = Mock() + mock_chain_workflow.apply_async.return_value = mock_async + mock_celery_imports["chain"].return_value = mock_chain_workflow + + context = Mock() + context.study.dag.step.return_value = mock_step + context.study.get_adapter_config.return_value = {"type": "celery"} + context.study.sample_labels = [] + context.study.samples = [] + + plan = ExecutionPlan([ExecutionLevel(depth=0, parallel_chains=[TaskChain(tasks=["task1"], depth=0)])]) + + with tempfile.TemporaryDirectory() as tmpdir: + context.study.workspace = tmpdir + context.study.expanded_spec.name = "test_study" + + with patch("merlin.common.tasks.merlin_step") as mock_merlin_step: + mock_sig = Mock() + mock_merlin_step.s.return_value = mock_sig + result = executor.execute_plan(plan, context, wait=True) + + # Check tasks marked as failed + assert result["results"]["task1"].status == TaskStatus.FAILED + assert "Task execution failed" in result["results"]["task1"].error + + # Check WORKFLOW_INFO.json was updated to FAILED + workflow_info_path = os.path.join(tmpdir, "WORKFLOW_INFO.json") + with open(workflow_info_path) as f: + info = json.load(f) + + assert info["status"] == "FAILED" + assert "error" in info + + +class TestExecutePlanBatching: + """Tests for execute_plan batching and chain/group pattern""" + + def test_execute_plan_uses_chain_of_groups(self, mock_celery_imports): + """Test execute_plan builds chain(group(...), group(...)) pattern""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + mock_step = Mock() + mock_step.get_task_queue.return_value = "test_queue" + mock_step.name.return_value = "task1" + + # Mock sample_expander to return 3 samples + mock_steps = [Mock() for _ in range(3)] + for i, s in enumerate(mock_steps): + s.get_task_queue.return_value = "test_queue" + s.name.return_value = f"task_sample{i}" + + executor.sample_expander = Mock() + executor.sample_expander.expand_chain.return_value = [[{"step": s, "sample_id": i} for i, s in enumerate(mock_steps)]] + + mock_group_result = Mock() + mock_celery_imports["group"].return_value = mock_group_result + + mock_chain_result = Mock() + mock_chain_result.apply_async.return_value = Mock(id="workflow-chain-group") + mock_celery_imports["chain"].return_value = mock_chain_result + + context = Mock() + context.study.dag.step.return_value = mock_step + context.study.get_adapter_config.return_value = {"type": "celery"} + context.study.sample_labels = [] + context.study.samples = [] + + plan = ExecutionPlan([ExecutionLevel(depth=0, parallel_chains=[TaskChain(tasks=["task1"], depth=0)])]) + + with tempfile.TemporaryDirectory() as tmpdir: + context.study.workspace = tmpdir + context.study.expanded_spec.name = "test_study" + + with patch("merlin.common.tasks.merlin_step") as mock_merlin_step: + mock_sig = Mock() + mock_merlin_step.s.return_value = mock_sig + executor.execute_plan(plan, context, wait=False) + + # group() should be called to create batch groups + assert mock_celery_imports["group"].called + # chain() should be called to chain the groups + assert mock_celery_imports["chain"].called + + +class TestExecuteChainMethod: + """Tests for CeleryExecutor.execute_chain()""" + + def test_execute_chain_skips_virtual_nodes(self, mock_celery_imports): + """Test execute_chain skips virtual nodes""" + from merlin.execution.celery import CeleryExecutor + + executor 
= CeleryExecutor() + + context = Mock() + context.study.dag.step.return_value = None # Virtual node + context.study.get_adapter_config.return_value = {"type": "celery"} + + chain_obj = TaskChain(tasks=["_source"], depth=0) + + results = executor.execute_chain(chain_obj, context) + + assert len(results) == 1 + assert results[0].status == TaskStatus.SKIPPED + assert "Virtual node" in results[0].error + + def test_execute_chain_executes_real_tasks(self, mock_celery_imports): + """Test execute_chain executes real tasks""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + mock_step = Mock() + mock_step.get_task_queue.return_value = "test_queue" + + mock_async = Mock() + mock_async.get.return_value = 0 + mock_chain_result = Mock() + mock_chain_result.apply_async.return_value = mock_async + mock_celery_imports["chain"].return_value = mock_chain_result + + context = Mock() + context.study.dag.step.return_value = mock_step + context.study.get_adapter_config.return_value = {"type": "celery"} + + chain_obj = TaskChain(tasks=["task1"], depth=1) + + with patch("merlin.common.tasks.merlin_step") as mock_merlin_step: + mock_sig = Mock() + mock_merlin_step.s.return_value = mock_sig + results = executor.execute_chain(chain_obj, context) + + assert len(results) == 1 + assert results[0].status == TaskStatus.COMPLETED + + +class TestExecuteTaskMethod: + """Tests for CeleryExecutor.execute_task()""" + + def test_execute_task_skips_virtual_node(self, mock_celery_imports): + """Test execute_task skips virtual nodes""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + context = Mock() + context.study.dag.step.return_value = None + + result = executor.execute_task("_source", context) + + assert result.status == TaskStatus.SKIPPED + assert "Virtual node" in result.error + + def test_execute_task_executes_real_task(self, mock_celery_imports): + """Test execute_task executes real task""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + mock_step = Mock() + mock_step.get_task_queue.return_value = "test_queue" + + mock_sig = Mock() + mock_async = Mock() + mock_async.get.return_value = 0 + mock_sig.apply_async.return_value = mock_async + + context = Mock() + context.study.dag.step.return_value = mock_step + context.study.get_adapter_config.return_value = {"type": "celery"} + + with patch("merlin.common.tasks.merlin_step") as mock_merlin_step: + mock_merlin_step.s.return_value = mock_sig + result = executor.execute_task("task1", context) + + assert result.status == TaskStatus.COMPLETED + assert result.celery_id is not None + + def test_execute_task_handles_exception(self, mock_celery_imports): + """Test execute_task handles exceptions""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + mock_step = Mock() + mock_step.get_task_queue.return_value = "test_queue" + + context = Mock() + context.study.dag.step.return_value = mock_step + context.study.get_adapter_config.return_value = {"type": "celery"} + + # Create mock merlin_step module and patch it in sys.modules + mock_merlin_step = Mock() + mock_sig = Mock() + mock_sig.apply_async.side_effect = Exception("Connection failed") + mock_merlin_step.s.return_value = mock_sig + + mock_tasks_module = Mock() + mock_tasks_module.merlin_step = mock_merlin_step + + with patch.dict(sys.modules, {"merlin.common.tasks": mock_tasks_module}): + result = executor.execute_task("task1", context) + + assert result.status == 
TaskStatus.FAILED + assert "Connection failed" in result.error + + +class TestMarkDependentTasksSkipped: + """Tests for CeleryExecutor._mark_dependent_tasks_skipped()""" + + def test_mark_dependent_tasks_skipped(self, mock_celery_imports): + """Test _mark_dependent_tasks_skipped marks downstream tasks""" + from merlin.execution.celery import CeleryExecutor + + executor = CeleryExecutor() + + # Create mock level that returns tasks + mock_level2 = Mock() + mock_level2.depth = 2 + mock_level2.get_all_tasks.return_value = ["task2a", "task2b"] + + mock_level3 = Mock() + mock_level3.depth = 3 + mock_level3.get_all_tasks.return_value = ["task3"] + + plan = Mock() + plan.levels = [ + Mock(depth=0), + Mock(depth=1), + mock_level2, + mock_level3, + ] + + results = {} + + # Mark tasks after depth 1 as skipped + executor._mark_dependent_tasks_skipped(plan, failed_depth=1, results=results) + + # Tasks at depth 2 and 3 should be skipped + assert "task2a" in results + assert "task2b" in results + assert "task3" in results + assert results["task2a"].status == TaskStatus.SKIPPED + assert results["task2b"].status == TaskStatus.SKIPPED + assert results["task3"].status == TaskStatus.SKIPPED + assert "Dependency failed" in results["task2a"].error diff --git a/tests/unit/execution/test_local_executor.py b/tests/unit/execution/test_local_executor.py new file mode 100644 index 00000000..17af7884 --- /dev/null +++ b/tests/unit/execution/test_local_executor.py @@ -0,0 +1,551 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +Tests for the LocalExecutor class. 
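+
+These tests replace ``ProcessPoolExecutor``, ``as_completed``, and the steps
+themselves with mocks, so no worker processes are spawned and no files are
+written during the test run.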
+""" + +from concurrent.futures import Future +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from merlin.dag.models import ExecutionLevel, ExecutionPlan, TaskChain +from merlin.execution.local import LocalExecutor +from merlin.execution.models import TaskResult, TaskStatus + + +class TestLocalExecutorInit: + """Tests for LocalExecutor.__init__()""" + + def test_init_default_max_workers(self): + """Test that default max_workers is 4""" + executor = LocalExecutor() + assert executor.max_workers == 4 + assert executor.sample_expander is not None + + def test_init_custom_max_workers(self): + """Test that custom max_workers is set correctly""" + executor = LocalExecutor(max_workers=8) + assert executor.max_workers == 8 + assert executor.sample_expander is not None + + def test_init_creates_sample_expander(self): + """Test that __init__ creates a SampleExpander instance""" + executor = LocalExecutor() + from merlin.execution.sample_expander import SampleExpander + + assert isinstance(executor.sample_expander, SampleExpander) + + +class TestHasRealTasks: + """Tests for LocalExecutor._has_real_tasks()""" + + def test_has_real_tasks_returns_true_for_real_tasks(self): + """Test _has_real_tasks returns True for chains with real tasks""" + executor = LocalExecutor() + + # Mock step (real task) + mock_step = Mock() + mock_step.name.return_value = "real_task" + + # Mock context + context = Mock() + context.study.dag.step.return_value = mock_step + + chain = TaskChain(tasks=["real_task"], depth=1) + + result = executor._has_real_tasks(chain, context) + assert result is True + + def test_has_real_tasks_returns_false_for_virtual_nodes(self): + """Test _has_real_tasks returns False for virtual nodes""" + executor = LocalExecutor() + + # Mock context with None step (virtual node) + context = Mock() + context.study.dag.step.return_value = None + + chain = TaskChain(tasks=["_source"], depth=0) + + result = executor._has_real_tasks(chain, context) + assert result is False + + def test_has_real_tasks_handles_exceptions(self): + """Test _has_real_tasks handles exceptions gracefully""" + executor = LocalExecutor() + + # Mock context that raises exception + context = Mock() + context.study.dag.step.side_effect = AttributeError("Step not found") + + chain = TaskChain(tasks=["invalid_task"], depth=1) + + result = executor._has_real_tasks(chain, context) + assert result is False + + def test_has_real_tasks_with_mixed_tasks(self): + """Test _has_real_tasks returns True if any task is real""" + executor = LocalExecutor() + + # Mock context that returns None for first task, real step for second + mock_step = Mock() + context = Mock() + context.study.dag.step.side_effect = [None, mock_step] + + chain = TaskChain(tasks=["virtual_task", "real_task"], depth=1) + + result = executor._has_real_tasks(chain, context) + assert result is True + + +class TestExecuteStepWrapper: + """Tests for LocalExecutor._execute_step_wrapper()""" + + def test_execute_step_wrapper_is_static_method(self): + """Test that _execute_step_wrapper is a static method (pickleable)""" + # Static methods can be called without instantiation + import inspect + + # Check that it's decorated as @staticmethod + assert isinstance(inspect.getattr_static(LocalExecutor, "_execute_step_wrapper"), staticmethod) + + @patch("os.path.exists") + def test_execute_step_wrapper_skips_completed_tasks(self, mock_exists): + """Test _execute_step_wrapper skips tasks with MERLIN_FINISHED""" + # Mock that MERLIN_FINISHED exists + mock_exists.return_value = 
True + + mock_step = Mock() + mock_step.get_workspace.return_value = "/workspace/task1" + mock_step.name.return_value = "task1" + + adapter_config = {"type": "local"} + + result = LocalExecutor._execute_step_wrapper(mock_step, adapter_config) + + # Should return 0 without calling execute + assert result == 0 + mock_step.execute.assert_not_called() + + @patch("os.path.exists") + @patch("builtins.open", new_callable=MagicMock) + def test_execute_step_wrapper_executes_and_creates_finished_file(self, mock_open, mock_exists): + """Test _execute_step_wrapper executes step and creates MERLIN_FINISHED""" + # Mock that MERLIN_FINISHED doesn't exist + mock_exists.return_value = False + + mock_step = Mock() + mock_step.get_workspace.return_value = "/workspace/task1" + mock_step.name.return_value = "task1" + mock_step.execute.return_value = 0 # Success + mock_step.max_retries = 10 # Required for retry logic + + adapter_config = {"type": "local"} + + result = LocalExecutor._execute_step_wrapper(mock_step, adapter_config) + + # Should execute and return 0 + assert result == 0 + mock_step.execute.assert_called_once_with(adapter_config) + + # Should create MERLIN_FINISHED + mock_open.assert_called_once_with("/workspace/task1/MERLIN_FINISHED", "w") + + @patch("os.path.exists") + def test_execute_step_wrapper_returns_non_zero_on_failure(self, mock_exists): + """Test _execute_step_wrapper returns non-zero on task failure""" + mock_exists.return_value = False + + mock_step = Mock() + mock_step.get_workspace.return_value = "/workspace/task1" + mock_step.name.return_value = "task1" + mock_step.execute.return_value = 1 # Failure + mock_step.max_retries = 10 # Required for retry logic + + adapter_config = {"type": "local"} + + result = LocalExecutor._execute_step_wrapper(mock_step, adapter_config) + + assert result == 1 + + @patch("os.path.exists") + def test_execute_step_wrapper_raises_exception_on_error(self, mock_exists): + """Test _execute_step_wrapper re-raises exceptions""" + mock_exists.return_value = False + + mock_step = Mock() + mock_step.get_workspace.return_value = "/workspace/task1" + mock_step.name.return_value = "task1" + mock_step.execute.side_effect = RuntimeError("Execution failed") + mock_step.max_retries = 10 # Required for retry logic + + adapter_config = {"type": "local"} + + with pytest.raises(RuntimeError, match="Execution failed"): + LocalExecutor._execute_step_wrapper(mock_step, adapter_config) + + +class TestExecuteChainWithDependencies: + """Tests for LocalExecutor._execute_chain_with_dependencies()""" + + @patch("merlin.execution.local.as_completed") + def test_execute_chain_with_dependencies_sequential_positions(self, mock_as_completed): + """Test _execute_chain_with_dependencies executes positions sequentially""" + executor = LocalExecutor() + + # Mock two positions with one task each + mock_step1 = Mock() + mock_step1.name.return_value = "step1" + + mock_step2 = Mock() + mock_step2.name.return_value = "step2" + + expanded_positions = [ + [{"step": mock_step1, "sample_id": None}], + [{"step": mock_step2, "sample_id": None}], + ] + + # Mock context + context = Mock() + context.study.get_adapter_config.return_value = {"type": "local"} + + # Mock executor + mock_executor = Mock() + + # Mock futures + future1 = Mock(spec=Future) + future1.result.return_value = 0 + + future2 = Mock(spec=Future) + future2.result.return_value = 0 + + mock_executor.submit.side_effect = [future1, future2] + + # Mock as_completed to return futures in order + mock_as_completed.side_effect = [[future1], 
[future2]] + + result = executor._execute_chain_with_dependencies(expanded_positions, context, mock_executor) + + # Should have 2 results + assert len(result) == 2 + assert "step1" in result + assert "step2" in result + assert result["step1"].status == TaskStatus.COMPLETED + assert result["step2"].status == TaskStatus.COMPLETED + + # Submit should be called twice (once per position) + assert mock_executor.submit.call_count == 2 + + @patch("merlin.execution.local.as_completed") + def test_execute_chain_with_dependencies_parallel_samples(self, mock_as_completed): + """Test _execute_chain_with_dependencies executes samples in parallel""" + executor = LocalExecutor() + + # Mock one position with 3 samples + mock_steps = [Mock() for _ in range(3)] + for i, step in enumerate(mock_steps): + step.name.return_value = f"step_sample{i}" + + expanded_positions = [[{"step": step, "sample_id": i} for i, step in enumerate(mock_steps)]] + + context = Mock() + context.study.get_adapter_config.return_value = {"type": "local"} + + mock_executor = Mock() + + # Mock futures for all samples + futures = [Mock(spec=Future) for _ in range(3)] + for future in futures: + future.result.return_value = 0 + + mock_executor.submit.side_effect = futures + mock_as_completed.return_value = futures + + result = executor._execute_chain_with_dependencies(expanded_positions, context, mock_executor) + + # Should have 3 results (all samples executed in parallel) + assert len(result) == 3 + + # Submit should be called 3 times (once per sample) + assert mock_executor.submit.call_count == 3 + + @patch("merlin.execution.local.as_completed") + def test_execute_chain_with_dependencies_handles_failures(self, mock_as_completed): + """Test _execute_chain_with_dependencies handles task failures""" + executor = LocalExecutor() + + mock_step = Mock() + mock_step.name.return_value = "failing_step" + + expanded_positions = [[{"step": mock_step, "sample_id": None}]] + + context = Mock() + context.study.get_adapter_config.return_value = {"type": "local"} + + mock_executor = Mock() + + future = Mock(spec=Future) + future.result.return_value = 1 # Non-zero exit code + + mock_executor.submit.return_value = future + mock_as_completed.return_value = [future] + + result = executor._execute_chain_with_dependencies(expanded_positions, context, mock_executor) + + assert len(result) == 1 + assert result["failing_step"].status == TaskStatus.FAILED + assert "non-zero exit code" in result["failing_step"].error + + @patch("merlin.execution.local.as_completed") + def test_execute_chain_with_dependencies_handles_exceptions(self, mock_as_completed): + """Test _execute_chain_with_dependencies handles exceptions""" + executor = LocalExecutor() + + mock_step = Mock() + mock_step.name.return_value = "error_step" + + expanded_positions = [[{"step": mock_step, "sample_id": None}]] + + context = Mock() + context.study.get_adapter_config.return_value = {"type": "local"} + + mock_executor = Mock() + + future = Mock(spec=Future) + future.result.side_effect = RuntimeError("Execution error") + + mock_executor.submit.return_value = future + mock_as_completed.return_value = [future] + + result = executor._execute_chain_with_dependencies(expanded_positions, context, mock_executor) + + assert len(result) == 1 + assert result["error_step"].status == TaskStatus.FAILED + assert "Execution error" in result["error_step"].error + + +class TestExecuteLevelParallel: + """Tests for LocalExecutor._execute_level_parallel()""" + + @patch.object(LocalExecutor, "_has_real_tasks") + def 
test_execute_level_parallel_skips_virtual_nodes(self, mock_has_real_tasks): + """Test _execute_level_parallel skips virtual nodes""" + executor = LocalExecutor() + + mock_has_real_tasks.return_value = False # All virtual + + context = Mock() + mock_executor = Mock() + + level = ExecutionLevel(depth=0, parallel_chains=[TaskChain(tasks=["_source"], depth=0)]) + + result = executor._execute_level_parallel(level, context, mock_executor) + + # Should have result for _source marked as SKIPPED + assert len(result) == 1 + assert "_source" in result + assert result["_source"].status == TaskStatus.SKIPPED + assert result["_source"].error == "Virtual node" + + @patch.object(LocalExecutor, "_has_real_tasks") + @patch.object(LocalExecutor, "_execute_chain_with_dependencies") + def test_execute_level_parallel_executes_real_chains(self, mock_execute_chain, mock_has_real_tasks): + """Test _execute_level_parallel executes real task chains""" + executor = LocalExecutor() + + mock_has_real_tasks.return_value = True # Real tasks + + # Mock sample expander + executor.sample_expander = Mock() + executor.sample_expander.expand_chain.return_value = [[{"step": Mock(), "sample_id": None}]] + + # Mock chain execution + mock_step = Mock() + mock_step.name.return_value = "task1" + mock_execute_chain.return_value = {"task1": TaskResult(task_name="task1", status=TaskStatus.COMPLETED)} + + context = Mock() + mock_executor = Mock() + + level = ExecutionLevel(depth=1, parallel_chains=[TaskChain(tasks=["task1"], depth=1)]) + + result = executor._execute_level_parallel(level, context, mock_executor) + + # Should have result for task1 + assert len(result) == 1 + assert "task1" in result + assert result["task1"].status == TaskStatus.COMPLETED + + # Should call expand_chain and execute_chain_with_dependencies + executor.sample_expander.expand_chain.assert_called_once() + mock_execute_chain.assert_called_once() + + +class TestExecutePlan: + """Tests for LocalExecutor.execute_plan()""" + + @patch("merlin.execution.local.ProcessPoolExecutor") + @patch.object(LocalExecutor, "_execute_level_parallel") + def test_execute_plan_with_no_samples(self, mock_execute_level, mock_pool_executor): + """Test execute_plan with workflow without samples""" + executor = LocalExecutor(max_workers=4) + + # Mock execution results + mock_execute_level.side_effect = [ + {"_source": TaskResult(task_name="_source", status=TaskStatus.SKIPPED)}, + { + "step1": TaskResult(task_name="step1", status=TaskStatus.COMPLETED), + "step2": TaskResult(task_name="step2", status=TaskStatus.COMPLETED), + }, + {"step3": TaskResult(task_name="step3", status=TaskStatus.COMPLETED)}, + ] + + # Create simple plan + plan = ExecutionPlan( + [ + ExecutionLevel(depth=0, parallel_chains=[TaskChain(tasks=["_source"], depth=0)]), + ExecutionLevel( + depth=1, + parallel_chains=[ + TaskChain(tasks=["step1"], depth=1), + TaskChain(tasks=["step2"], depth=1), + ], + ), + ExecutionLevel(depth=2, parallel_chains=[TaskChain(tasks=["step3"], depth=2)]), + ] + ) + + context = Mock() + + # Mock ProcessPoolExecutor context manager + mock_pool = Mock() + mock_pool_executor.return_value.__enter__.return_value = mock_pool + + result = executor.execute_plan(plan, context) + + # Should have 4 results (all tasks) + assert len(result) == 4 + assert "_source" in result + assert "step1" in result + assert "step2" in result + assert "step3" in result + + # Should execute all 3 levels + assert mock_execute_level.call_count == 3 + + @patch("merlin.execution.local.ProcessPoolExecutor") + 
@patch.object(LocalExecutor, "_execute_level_parallel") + def test_execute_plan_stops_on_failure(self, mock_execute_level, mock_pool_executor): + """Test execute_plan stops execution when a level fails""" + executor = LocalExecutor() + + # Mock execution with failure in second level + mock_execute_level.side_effect = [ + {"step1": TaskResult(task_name="step1", status=TaskStatus.COMPLETED)}, + {"step2": TaskResult(task_name="step2", status=TaskStatus.FAILED, error="Task failed")}, + {"step3": TaskResult(task_name="step3", status=TaskStatus.COMPLETED)}, + ] + + plan = ExecutionPlan( + [ + ExecutionLevel(depth=0, parallel_chains=[TaskChain(tasks=["step1"], depth=0)]), + ExecutionLevel(depth=1, parallel_chains=[TaskChain(tasks=["step2"], depth=1)]), + ExecutionLevel(depth=2, parallel_chains=[TaskChain(tasks=["step3"], depth=2)]), + ] + ) + + context = Mock() + + mock_pool = Mock() + mock_pool_executor.return_value.__enter__.return_value = mock_pool + + result = executor.execute_plan(plan, context) + + # Should only have results from first 2 levels + assert len(result) == 2 + assert "step1" in result + assert "step2" in result + assert "step3" not in result # Third level should not execute + + # Should only execute 2 levels (stops after failure) + assert mock_execute_level.call_count == 2 + + @patch("merlin.execution.local.ProcessPoolExecutor") + @patch.object(LocalExecutor, "_execute_level_parallel") + def test_execute_plan_respects_max_workers(self, mock_execute_level, mock_pool_executor): + """Test execute_plan uses correct max_workers""" + executor = LocalExecutor(max_workers=8) + + mock_execute_level.return_value = {"task1": TaskResult(task_name="task1", status=TaskStatus.COMPLETED)} + + plan = ExecutionPlan([ExecutionLevel(depth=0, parallel_chains=[TaskChain(tasks=["task1"], depth=0)])]) + + context = Mock() + + mock_pool = Mock() + mock_pool_executor.return_value.__enter__.return_value = mock_pool + + executor.execute_plan(plan, context) + + # Should create ProcessPoolExecutor with max_workers=8 + mock_pool_executor.assert_called_once_with(max_workers=8) + + +class TestLegacyMethods: + """Tests for legacy execute_chain and execute_task methods""" + + @patch.object(LocalExecutor, "execute_task") + def test_execute_chain_calls_execute_task_for_each_task(self, mock_execute_task): + """Test execute_chain calls execute_task for each task in chain""" + executor = LocalExecutor() + + mock_execute_task.side_effect = [ + TaskResult(task_name="task1", status=TaskStatus.COMPLETED), + TaskResult(task_name="task2", status=TaskStatus.COMPLETED), + ] + + context = Mock() + chain = TaskChain(tasks=["task1", "task2"], depth=1) + + results = executor.execute_chain(chain, context) + + assert len(results) == 2 + assert mock_execute_task.call_count == 2 + + def test_execute_task_executes_step(self): + """Test execute_task executes a single step""" + executor = LocalExecutor() + + mock_step = Mock() + mock_step.execute.return_value = 0 + + context = Mock() + context.study.dag.step.return_value = mock_step + context.study.get_adapter_config.return_value = {"type": "local"} + + result = executor.execute_task("task1", context) + + assert result.task_name == "task1" + assert result.status == TaskStatus.COMPLETED + mock_step.execute.assert_called_once() + + def test_execute_task_handles_exceptions(self): + """Test execute_task handles exceptions gracefully""" + executor = LocalExecutor() + + mock_step = Mock() + mock_step.execute.side_effect = RuntimeError("Task error") + + context = Mock() + 
context.study.dag.step.return_value = mock_step + context.study.get_adapter_config.return_value = {"type": "local"} + + result = executor.execute_task("task1", context) + + assert result.task_name == "task1" + assert result.status == TaskStatus.FAILED + assert "Task error" in result.error diff --git a/tests/unit/execution/test_sample_expander.py b/tests/unit/execution/test_sample_expander.py new file mode 100644 index 00000000..11a3a2d6 --- /dev/null +++ b/tests/unit/execution/test_sample_expander.py @@ -0,0 +1,560 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +Tests for the SampleExpander class. +""" + +from unittest.mock import Mock, patch + +import numpy as np + +from merlin.dag.models import TaskChain +from merlin.execution.sample_expander import SampleExpander + + +class TestSampleExpanderInit: + """Tests for SampleExpander.__init__()""" + + def test_init_default_level_max_dirs(self): + """Test that default level_max_dirs is 25""" + expander = SampleExpander() + assert expander.level_max_dirs == 25 + + def test_init_custom_level_max_dirs(self): + """Test that custom level_max_dirs is set correctly""" + expander = SampleExpander(level_max_dirs=50) + assert expander.level_max_dirs == 50 + + +class TestNeedsExpansion: + """Tests for SampleExpander.needs_expansion()""" + + def test_needs_expansion_no_labels(self): + """Test needs_expansion returns False when no labels""" + expander = SampleExpander() + + # Mock context with no labels + context = Mock() + context.study.sample_labels = [] + context.study.dag = Mock() + + chain = TaskChain(tasks=["step1"], depth=1) + + result = expander.needs_expansion(chain, context) + assert result is False + + def test_needs_expansion_none_labels(self): + """Test needs_expansion returns False when labels is None""" + expander = SampleExpander() + + context = Mock() + context.study.sample_labels = None + context.study.dag = Mock() + + chain = TaskChain(tasks=["step1"], depth=1) + + result = expander.needs_expansion(chain, context) + assert result is False + + def test_needs_expansion_with_expandable_step(self): + """Test needs_expansion returns True when step needs expansion""" + expander = SampleExpander() + + # Mock step that needs expansion + mock_step = Mock() + mock_step.check_if_expansion_needed.return_value = True + + # Mock context + context = Mock() + context.study.sample_labels = ["X0", "X1"] + context.study.dag.step.return_value = mock_step + + chain = TaskChain(tasks=["step1"], depth=1) + + result = expander.needs_expansion(chain, context) + assert result is True + mock_step.check_if_expansion_needed.assert_called_once_with(["X0", "X1"]) + + def test_needs_expansion_no_expandable_steps(self): + """Test needs_expansion returns False when no steps need expansion""" + expander = SampleExpander() + + # Mock step that doesn't need expansion + mock_step = Mock() + mock_step.check_if_expansion_needed.return_value = False + + context = Mock() + context.study.sample_labels = ["X0", "X1"] + context.study.dag.step.return_value = mock_step + + chain = TaskChain(tasks=["step1", "step2"], depth=1) + + result = expander.needs_expansion(chain, context) + assert result is False + + def 
test_needs_expansion_virtual_node(self): + """Test needs_expansion handles virtual nodes (None steps)""" + expander = SampleExpander() + + context = Mock() + context.study.sample_labels = ["X0", "X1"] + context.study.dag.step.return_value = None # Virtual node + + chain = TaskChain(tasks=["_source"], depth=0) + + result = expander.needs_expansion(chain, context) + assert result is False + + +class TestExpandChainNoSamples: + """Tests for expand_chain() with no samples""" + + def test_expand_chain_no_samples(self): + """Test expand_chain returns single task when no samples""" + expander = SampleExpander() + + # Mock step + mock_step = Mock() + mock_step.get_workspace.return_value = "/workspace/step1" + mock_step.clone_changing_workspace_and_cmd.return_value = mock_step + mock_step.check_if_expansion_needed.return_value = False + + # Mock context with no samples + context = Mock() + context.study.samples = None + context.study.sample_labels = [] + context.study.dag.step.return_value = mock_step + + chain = TaskChain(tasks=["step1"], depth=1) + + result = expander.expand_chain(chain, context) + + # Should return 2D structure with single position, single task + assert len(result) == 1 # One position + assert len(result[0]) == 1 # One task in position + assert result[0][0]["step"] == mock_step + assert result[0][0]["sample_id"] is None + assert result[0][0]["sample_values"] is None + assert result[0][0]["workspace"] == "/workspace/step1" + assert result[0][0]["chain_position"] == 0 + assert result[0][0]["original_chain"] == chain + + def test_expand_chain_empty_samples(self): + """Test expand_chain with empty samples array""" + expander = SampleExpander() + + mock_step = Mock() + mock_step.get_workspace.return_value = "/workspace/step1" + mock_step.clone_changing_workspace_and_cmd.return_value = mock_step + mock_step.check_if_expansion_needed.return_value = False + + context = Mock() + context.study.samples = np.array([]) # Empty array + context.study.sample_labels = [] + context.study.dag.step.return_value = mock_step + + chain = TaskChain(tasks=["step1"], depth=1) + + result = expander.expand_chain(chain, context) + + assert len(result) == 1 + assert len(result[0]) == 1 + assert result[0][0]["sample_id"] is None + + +class TestExpandChainWithSamples: + """Tests for expand_chain() with samples""" + + @patch("merlin.execution.sample_expander.uniform_directories") + @patch("merlin.execution.sample_expander.create_hierarchy") + @patch("merlin.execution.sample_expander.parameter_substitutions_for_cmd") + @patch("merlin.execution.sample_expander.parameter_substitutions_for_sample") + def test_expand_chain_single_step_multiple_samples( + self, mock_param_sub_sample, mock_param_sub_cmd, mock_create_hierarchy, mock_uniform_dirs + ): + """Test expand_chain creates one task per sample""" + expander = SampleExpander() + + # Mock samples (3 samples) + samples = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + labels = ["X0", "X1"] + + # Mock directory structure + mock_uniform_dirs.return_value = [1] # Single directory level + + # Mock sample index + mock_sample_index = Mock() + mock_sample_index.make_directory_string.return_value = "00:01:02" + mock_sample_index.get_path_to_sample.side_effect = lambda i: f"{i:02d}" + mock_create_hierarchy.return_value = mock_sample_index + + # Mock parameter substitutions + mock_param_sub_cmd.return_value = [("$(MERLIN_GLOB_PATH)", "*/")] + mock_param_sub_sample.side_effect = lambda s, labels, i, p: [("$(X0)", str(s[0])), ("$(X1)", str(s[1]))] + + # Mock step + mock_step 
= Mock() + mock_step.get_workspace.return_value = "/workspace/hello" + mock_step.check_if_expansion_needed.return_value = True + + # Create cloned steps for each stage + mock_step_with_glob = Mock() + mock_step_with_glob.get_workspace.return_value = "/workspace/hello" + mock_step_with_glob.check_if_expansion_needed.return_value = True + + mock_expanded_steps = [] + for i in range(3): + expanded = Mock() + expanded.get_workspace.return_value = f"/workspace/hello/{i:02d}" + mock_expanded_steps.append(expanded) + + # Setup clone behavior + mock_step.clone_changing_workspace_and_cmd.return_value = mock_step_with_glob + mock_step_with_glob.clone_changing_workspace_and_cmd.side_effect = mock_expanded_steps + + # Mock context + context = Mock() + context.study.samples = samples + context.study.sample_labels = labels + context.study.dag.step.return_value = mock_step + + chain = TaskChain(tasks=["hello"], depth=1) + + result = expander.expand_chain(chain, context) + + # Should return 2D structure: [[sample0, sample1, sample2]] + assert len(result) == 1 # One position (single step in chain) + assert len(result[0]) == 3 # Three samples + + # Check each sample + for i in range(3): + task_info = result[0][i] + assert task_info["sample_id"] == i + assert task_info["sample_values"] == {"X0": samples[i][0], "X1": samples[i][1]} + assert task_info["workspace"] == f"/workspace/hello/{i:02d}" + assert task_info["chain_position"] == 0 + assert task_info["original_chain"] == chain + + @patch("merlin.execution.sample_expander.uniform_directories") + @patch("merlin.execution.sample_expander.create_hierarchy") + @patch("merlin.execution.sample_expander.parameter_substitutions_for_cmd") + @patch("merlin.execution.sample_expander.parameter_substitutions_for_sample") + def test_expand_chain_multi_position_chain( + self, mock_param_sub_sample, mock_param_sub_cmd, mock_create_hierarchy, mock_uniform_dirs + ): + """Test expand_chain with multi-step chain (collect -> translate)""" + expander = SampleExpander() + + # Mock samples (2 samples) + samples = np.array([[1.0, 2.0], [3.0, 4.0]]) + labels = ["X0", "X1"] + + # Mock directory structure + mock_uniform_dirs.return_value = [1] + + mock_sample_index = Mock() + mock_sample_index.make_directory_string.return_value = "00:01" + mock_sample_index.get_path_to_sample.side_effect = lambda i: f"{i:02d}" + mock_create_hierarchy.return_value = mock_sample_index + + mock_param_sub_cmd.return_value = [("$(MERLIN_GLOB_PATH)", "*/")] + mock_param_sub_sample.side_effect = lambda s, labels, i, p: [("$(X0)", str(s[0])), ("$(X1)", str(s[1]))] + + # Mock two different steps + mock_collect = Mock() + mock_collect.get_workspace.return_value = "/workspace/collect" + mock_collect.check_if_expansion_needed.return_value = False + + mock_translate = Mock() + mock_translate.get_workspace.return_value = "/workspace/translate" + mock_translate.check_if_expansion_needed.return_value = False + + # Setup cloned steps + mock_collect_glob = Mock() + mock_collect_glob.get_workspace.return_value = "/workspace/collect" + mock_collect_glob.check_if_expansion_needed.return_value = False + + mock_translate_glob = Mock() + mock_translate_glob.get_workspace.return_value = "/workspace/translate" + mock_translate_glob.check_if_expansion_needed.return_value = False + + mock_collect.clone_changing_workspace_and_cmd.return_value = mock_collect_glob + mock_translate.clone_changing_workspace_and_cmd.return_value = mock_translate_glob + + # Mock context with two steps + context = Mock() + context.study.samples = 
samples + context.study.sample_labels = labels + + def get_step(task_name): + if task_name == "collect": + return mock_collect + elif task_name == "translate": + return mock_translate + return None + + context.study.dag.step.side_effect = get_step + + chain = TaskChain(tasks=["collect", "translate"], depth=2) + + result = expander.expand_chain(chain, context) + + # Should return 2D structure: [[collect], [translate]] + assert len(result) == 2 # Two positions + assert len(result[0]) == 1 # One task at position 0 (no sample expansion) + assert len(result[1]) == 1 # One task at position 1 (no sample expansion) + + # Check positions + assert result[0][0]["chain_position"] == 0 + assert result[1][0]["chain_position"] == 1 + + +class TestGlobPathCalculation: + """Tests for MERLIN_GLOB_PATH calculation""" + + @patch("merlin.execution.sample_expander.uniform_directories") + @patch("merlin.execution.sample_expander.create_hierarchy") + @patch("merlin.execution.sample_expander.parameter_substitutions_for_cmd") + def test_glob_path_calculation_10_samples(self, mock_param_sub_cmd, mock_create_hierarchy, mock_uniform_dirs): + """Test MERLIN_GLOB_PATH is calculated correctly for 10 samples""" + expander = SampleExpander() + + # Mock 10 samples + samples = np.array([[i, i + 1] for i in range(10)]) + labels = ["X0", "X1"] + + # With 10 samples and level_max_dirs=25, should create [1] directory level + # glob_path should be "*/*" (1 level + 1 for execution dir) + mock_uniform_dirs.return_value = [1] + + mock_sample_index = Mock() + mock_sample_index.make_directory_string.return_value = ":".join([f"{i:02d}" for i in range(10)]) + mock_sample_index.get_path_to_sample.side_effect = lambda i: f"{i:02d}" + mock_create_hierarchy.return_value = mock_sample_index + + # Capture the glob_path passed to parameter_substitutions_for_cmd + captured_glob_path = None + + def capture_glob_path(glob_path, sample_paths): + nonlocal captured_glob_path + captured_glob_path = glob_path + return [("$(MERLIN_GLOB_PATH)", glob_path)] + + mock_param_sub_cmd.side_effect = capture_glob_path + + # Mock step + mock_step = Mock() + mock_step.get_workspace.return_value = "/workspace/hello" + mock_step.check_if_expansion_needed.return_value = False + mock_step.clone_changing_workspace_and_cmd.return_value = mock_step + + context = Mock() + context.study.samples = samples + context.study.sample_labels = labels + context.study.dag.step.return_value = mock_step + + chain = TaskChain(tasks=["hello"], depth=1) + + expander.expand_chain(chain, context) + + # glob_path should be "*/*" (1 directory level + 1 execution dir level) + assert captured_glob_path == "*/*" + + @patch("merlin.execution.sample_expander.uniform_directories") + @patch("merlin.execution.sample_expander.create_hierarchy") + @patch("merlin.execution.sample_expander.parameter_substitutions_for_cmd") + def test_glob_path_calculation_100_samples(self, mock_param_sub_cmd, mock_create_hierarchy, mock_uniform_dirs): + """Test MERLIN_GLOB_PATH is calculated correctly for 100 samples""" + expander = SampleExpander() + + # Mock 100 samples + samples = np.array([[i, i + 1] for i in range(100)]) + labels = ["X0", "X1"] + + # With 100 samples and level_max_dirs=25, should create [4, 25] directory levels + # glob_path should be "*/*/*" (2 levels + 1 for execution dir) + mock_uniform_dirs.return_value = [4, 25] + + mock_sample_index = Mock() + mock_sample_index.make_directory_string.return_value = ":".join( + [f"{i:02d}/{j:02d}" for i in range(4) for j in range(25)] + ) + 
mock_sample_index.get_path_to_sample.side_effect = lambda i: f"{i // 25:02d}/{i % 25:02d}" + mock_create_hierarchy.return_value = mock_sample_index + + captured_glob_path = None + + def capture_glob_path(glob_path, sample_paths): + nonlocal captured_glob_path + captured_glob_path = glob_path + return [("$(MERLIN_GLOB_PATH)", glob_path)] + + mock_param_sub_cmd.side_effect = capture_glob_path + + mock_step = Mock() + mock_step.get_workspace.return_value = "/workspace/hello" + mock_step.check_if_expansion_needed.return_value = False + mock_step.clone_changing_workspace_and_cmd.return_value = mock_step + + context = Mock() + context.study.samples = samples + context.study.sample_labels = labels + context.study.dag.step.return_value = mock_step + + chain = TaskChain(tasks=["hello"], depth=1) + + expander.expand_chain(chain, context) + + # glob_path should be "*/*/*" (2 directory levels + 1 execution dir level) + assert captured_glob_path == "*/*/*" + + +class TestWorkspaceIsolation: + """Tests for sample workspace isolation""" + + @patch("merlin.execution.sample_expander.uniform_directories") + @patch("merlin.execution.sample_expander.create_hierarchy") + @patch("merlin.execution.sample_expander.parameter_substitutions_for_cmd") + @patch("merlin.execution.sample_expander.parameter_substitutions_for_sample") + def test_workspace_isolation_unique_paths( + self, mock_param_sub_sample, mock_param_sub_cmd, mock_create_hierarchy, mock_uniform_dirs + ): + """Test each sample gets a unique workspace path""" + expander = SampleExpander() + + # Mock 5 samples + samples = np.array([[i, i + 1] for i in range(5)]) + labels = ["X0", "X1"] + + mock_uniform_dirs.return_value = [1] + + mock_sample_index = Mock() + mock_sample_index.make_directory_string.return_value = "00:01:02:03:04" + mock_sample_index.get_path_to_sample.side_effect = lambda i: f"{i:02d}" + mock_create_hierarchy.return_value = mock_sample_index + + mock_param_sub_cmd.return_value = [("$(MERLIN_GLOB_PATH)", "*/")] + mock_param_sub_sample.side_effect = lambda s, labels, i, p: [] + + # Mock step + base_workspace = "/workspace/hello" + mock_step = Mock() + mock_step.get_workspace.return_value = base_workspace + mock_step.check_if_expansion_needed.return_value = True + + mock_step_with_glob = Mock() + mock_step_with_glob.get_workspace.return_value = base_workspace + mock_step_with_glob.check_if_expansion_needed.return_value = True + + # Create expanded steps with unique workspaces + mock_expanded_steps = [] + for i in range(5): + expanded = Mock() + expanded.get_workspace.return_value = f"{base_workspace}/{i:02d}" + mock_expanded_steps.append(expanded) + + mock_step.clone_changing_workspace_and_cmd.return_value = mock_step_with_glob + mock_step_with_glob.clone_changing_workspace_and_cmd.side_effect = mock_expanded_steps + + context = Mock() + context.study.samples = samples + context.study.sample_labels = labels + context.study.dag.step.return_value = mock_step + + chain = TaskChain(tasks=["hello"], depth=1) + + result = expander.expand_chain(chain, context) + + # Extract all workspaces + workspaces = [task_info["workspace"] for task_info in result[0]] + + # Check all workspaces are unique + assert len(workspaces) == len(set(workspaces)) + + # Check each workspace follows expected pattern + expected_workspaces = [f"{base_workspace}/{i:02d}" for i in range(5)] + assert workspaces == expected_workspaces + + +class TestParameterSubstitutions: + """Tests for parameter substitutions""" + + @patch("merlin.execution.sample_expander.uniform_directories") 
+ @patch("merlin.execution.sample_expander.create_hierarchy") + @patch("merlin.execution.sample_expander.parameter_substitutions_for_cmd") + @patch("merlin.execution.sample_expander.parameter_substitutions_for_sample") + def test_parameter_substitutions_applied( + self, mock_param_sub_sample, mock_param_sub_cmd, mock_create_hierarchy, mock_uniform_dirs + ): + """Test parameter substitutions are applied correctly""" + expander = SampleExpander() + + samples = np.array([[1.5, 2.5], [3.5, 4.5]]) + labels = ["X0", "X1"] + + mock_uniform_dirs.return_value = [1] + + mock_sample_index = Mock() + mock_sample_index.make_directory_string.return_value = "00:01" + mock_sample_index.get_path_to_sample.side_effect = lambda i: f"{i:02d}" + mock_create_hierarchy.return_value = mock_sample_index + + mock_param_sub_cmd.return_value = [("$(MERLIN_GLOB_PATH)", "*/")] + + # Track calls to parameter_substitutions_for_sample + sample_sub_calls = [] + + def track_sample_subs(sample, labels, sample_id, relative_path): + sample_sub_calls.append( + {"sample": sample.tolist(), "labels": labels, "sample_id": sample_id, "relative_path": relative_path} + ) + return [("$(X0)", str(sample[0])), ("$(X1)", str(sample[1]))] + + mock_param_sub_sample.side_effect = track_sample_subs + + mock_step = Mock() + mock_step.get_workspace.return_value = "/workspace/hello" + mock_step.check_if_expansion_needed.return_value = True + + mock_step_with_glob = Mock() + mock_step_with_glob.get_workspace.return_value = "/workspace/hello" + mock_step_with_glob.check_if_expansion_needed.return_value = True + + mock_expanded_steps = [Mock(), Mock()] + for i, m in enumerate(mock_expanded_steps): + m.get_workspace.return_value = f"/workspace/hello/{i:02d}" + + mock_step.clone_changing_workspace_and_cmd.return_value = mock_step_with_glob + mock_step_with_glob.clone_changing_workspace_and_cmd.side_effect = mock_expanded_steps + + context = Mock() + context.study.samples = samples + context.study.sample_labels = labels + context.study.dag.step.return_value = mock_step + + chain = TaskChain(tasks=["hello"], depth=1) + + result = expander.expand_chain(chain, context) + + # Check parameter_substitutions_for_sample was called for each sample + assert len(sample_sub_calls) == 2 + + # Check first sample + assert sample_sub_calls[0]["sample"] == [1.5, 2.5] + assert sample_sub_calls[0]["labels"] == labels + assert sample_sub_calls[0]["sample_id"] == 0 + assert sample_sub_calls[0]["relative_path"] == "00" + + # Check second sample + assert sample_sub_calls[1]["sample"] == [3.5, 4.5] + assert sample_sub_calls[1]["labels"] == labels + assert sample_sub_calls[1]["sample_id"] == 1 + assert sample_sub_calls[1]["relative_path"] == "01" + + # Check sample_values in result + assert result[0][0]["sample_values"] == {"X0": 1.5, "X1": 2.5} + assert result[0][1]["sample_values"] == {"X0": 3.5, "X1": 4.5} diff --git a/tests/unit/execution/test_workflow_manager.py b/tests/unit/execution/test_workflow_manager.py new file mode 100644 index 00000000..738e6d6d --- /dev/null +++ b/tests/unit/execution/test_workflow_manager.py @@ -0,0 +1,436 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +Tests for the WorkflowManager class. 
+""" + +from collections import OrderedDict +from unittest.mock import Mock + +from merlin.execution.models import TaskResult, TaskStatus +from merlin.execution.workflow_manager import WorkflowManager + + +def create_mock_study_with_dag(): + """ + Create a properly configured mock study with dag attributes. + + WorkflowManager creates its own ExecutionDAG from study.dag's attributes, + so we need to provide proper values for those attributes. + """ + mock_study = Mock() + mock_dag = Mock() + + # Provide proper values for DAG attributes needed by ExecutionDAG constructor + mock_dag.maestro_adjacency_table = OrderedDict() + mock_dag.maestro_values = OrderedDict() + mock_dag.column_labels = [] + mock_dag.study_name = "test_study" + mock_dag.parameter_info = {} + + mock_study.dag = mock_dag + return mock_study, mock_dag + + +class TestWorkflowManagerInit: + """Tests for WorkflowManager.__init__()""" + + def test_init_sets_study_and_executor(self): + """Test that __init__ sets study and executor, and creates its own DAG""" + mock_study, mock_dag = create_mock_study_with_dag() + mock_executor = Mock() + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + + assert manager.study == mock_study + # DAG is created from study.dag's attributes, not the same object + assert manager.dag is not mock_dag + assert manager.dag.study_name == "test_study" + assert manager.executor == mock_executor + + +class TestRunWorkflowBasic: + """Tests for basic run_workflow functionality""" + + def test_run_workflow_generates_execution_plan(self): + """Test that run_workflow generates execution plan from DAG""" + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + mock_executor.execute_plan.return_value = {"results": {}} + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + # Mock the manager's dag.group_tasks after creation + manager.dag.group_tasks = Mock(return_value=mock_plan) + manager.run_workflow() + + # Should call group_tasks with default source_node + manager.dag.group_tasks.assert_called_once_with("_source") + + def test_run_workflow_custom_source_node(self): + """Test that run_workflow uses custom source_node""" + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + mock_executor.execute_plan.return_value = {"results": {}} + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + manager.run_workflow(source_node="custom_start") + + manager.dag.group_tasks.assert_called_once_with("custom_start") + + def test_run_workflow_creates_execution_context(self): + """Test that run_workflow creates ExecutionContext with correct fields""" + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + mock_executor.execute_plan.return_value = {"results": {}} + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + manager.dag.parameter_info = {"param1": "value1"} + manager.run_workflow() + + # Check that execute_plan was called with an ExecutionContext + call_args = mock_executor.execute_plan.call_args + context = call_args[0][1] # Second positional argument + + assert context.study == mock_study + assert context.parameter_info == {"param1": "value1"} + assert context.execution_id is not None + assert "started_at" in context.metadata + + +class 
TestRunWorkflowWaitBehavior: + """Tests for run_workflow wait parameter behavior""" + + def test_run_workflow_wait_false_passes_to_executor(self): + """Test that run_workflow(wait=False) passes wait=False to executor""" + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + mock_executor.execute_plan.return_value = {"results": {}} + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + manager.run_workflow(wait=False) + + # Check wait=False was passed + call_kwargs = mock_executor.execute_plan.call_args[1] + assert call_kwargs["wait"] is False + + def test_run_workflow_wait_true_passes_to_executor(self): + """Test that run_workflow(wait=True) passes wait=True to executor""" + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + mock_executor.execute_plan.return_value = {"results": {}} + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + manager.run_workflow(wait=True) + + # Check wait=True was passed + call_kwargs = mock_executor.execute_plan.call_args[1] + assert call_kwargs["wait"] is True + + def test_run_workflow_default_wait_is_false(self): + """Test that run_workflow defaults to wait=False (non-blocking)""" + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + mock_executor.execute_plan.return_value = {"results": {}} + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + manager.run_workflow() # No wait argument + + # Check wait=False (default) + call_kwargs = mock_executor.execute_plan.call_args[1] + assert call_kwargs["wait"] is False + + def test_run_workflow_timeout_passes_to_executor(self): + """Test that run_workflow passes timeout to executor""" + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + mock_executor.execute_plan.return_value = {"results": {}} + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + manager.run_workflow(wait=True, timeout=60) + + # Check timeout was passed + call_kwargs = mock_executor.execute_plan.call_args[1] + assert call_kwargs["timeout"] == 60 + + def test_run_workflow_default_timeout_is_7200(self): + """Test that run_workflow defaults to timeout=7200 (2 hours)""" + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + mock_executor.execute_plan.return_value = {"results": {}} + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + manager.run_workflow() + + # Check timeout=7200 (default) + call_kwargs = mock_executor.execute_plan.call_args[1] + assert call_kwargs["timeout"] == 7200 + + +class TestRunWorkflowReturnFormats: + """Tests for run_workflow return format handling""" + + def test_run_workflow_handles_new_format_with_results_key(self): + """Test run_workflow handles new format {'results': {...}, 'workflow_id': ...}""" + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + # New format from CeleryExecutor + mock_executor.execute_plan.return_value = { + "results": 
{"task1": TaskResult(task_name="task1", status=TaskStatus.COMPLETED)}, + "workflow_id": "workflow-123", + "async_result": Mock(), + } + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + result = manager.run_workflow() + + # Should return the dict as-is + assert "results" in result + assert "workflow_id" in result + assert result["workflow_id"] == "workflow-123" + + def test_run_workflow_handles_old_format_dict_only(self): + """Test run_workflow handles old format (just results dict)""" + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + # Old format from LocalExecutor (just results dict) + mock_executor.execute_plan.return_value = {"task1": TaskResult(task_name="task1", status=TaskStatus.COMPLETED)} + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + result = manager.run_workflow() + + # Should wrap in {"results": ...} + assert "results" in result + assert "task1" in result["results"] + + +class TestRunWorkflowResultCounting: + """Tests for run_workflow result counting""" + + def test_run_workflow_counts_completed_tasks(self): + """Test that run_workflow correctly counts completed tasks""" + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + mock_executor.execute_plan.return_value = { + "results": { + "task1": TaskResult(task_name="task1", status=TaskStatus.COMPLETED), + "task2": TaskResult(task_name="task2", status=TaskStatus.COMPLETED), + "task3": TaskResult(task_name="task3", status=TaskStatus.FAILED), + } + } + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + result = manager.run_workflow() + + # Should have 3 results + assert len(result["results"]) == 3 + + def test_run_workflow_counts_failed_tasks(self): + """Test that run_workflow correctly counts failed tasks""" + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + mock_executor.execute_plan.return_value = { + "results": { + "task1": TaskResult(task_name="task1", status=TaskStatus.COMPLETED), + "task2": TaskResult(task_name="task2", status=TaskStatus.FAILED, error="Error"), + "task3": TaskResult(task_name="task3", status=TaskStatus.FAILED, error="Error"), + } + } + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + result = manager.run_workflow() + + # Count failed tasks + failed_count = sum(1 for r in result["results"].values() if r.status == TaskStatus.FAILED) + assert failed_count == 2 + + +class TestRunWorkflowWithLevels: + """Tests for run_workflow with execution levels""" + + def test_run_workflow_handles_multi_level_plan(self): + """Test run_workflow handles execution plan with multiple levels""" + mock_study, _ = create_mock_study_with_dag() + + # Create mock plan with multiple levels + mock_level0 = Mock() + mock_level0.depth = 0 + mock_level0.parallel_chains = [Mock()] + + mock_level1 = Mock() + mock_level1.depth = 1 + mock_level1.parallel_chains = [Mock(), Mock()] + + mock_level2 = Mock() + mock_level2.depth = 2 + mock_level2.parallel_chains = [Mock()] + + mock_plan = Mock() + mock_plan.levels = [mock_level0, mock_level1, mock_level2] + + mock_executor = Mock() + mock_executor.execute_plan.return_value = 
{"results": {}} + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + manager.run_workflow() + + # Should have called execute_plan with the plan + call_args = mock_executor.execute_plan.call_args[0] + assert call_args[0] == mock_plan + + def test_run_workflow_handles_empty_plan(self): + """Test run_workflow handles empty execution plan""" + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + mock_executor.execute_plan.return_value = {"results": {}} + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + result = manager.run_workflow() + + assert "results" in result + assert len(result["results"]) == 0 + + +class TestRunWorkflowExecutionId: + """Tests for execution ID generation""" + + def test_run_workflow_generates_unique_execution_id(self): + """Test that run_workflow generates unique execution IDs""" + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + mock_executor.execute_plan.return_value = {"results": {}} + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + + # Run workflow twice and collect execution IDs + manager.run_workflow() + context1 = mock_executor.execute_plan.call_args_list[0][0][1] + + manager.run_workflow() + context2 = mock_executor.execute_plan.call_args_list[1][0][1] + + # Execution IDs should be different + assert context1.execution_id != context2.execution_id + + def test_run_workflow_execution_id_is_uuid_format(self): + """Test that execution_id is in UUID format""" + import uuid + + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + mock_executor.execute_plan.return_value = {"results": {}} + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + manager.run_workflow() + + context = mock_executor.execute_plan.call_args[0][1] + + # Should be valid UUID + try: + uuid.UUID(context.execution_id) + is_valid_uuid = True + except ValueError: + is_valid_uuid = False + + assert is_valid_uuid + + +class TestRunWorkflowMetadata: + """Tests for workflow metadata""" + + def test_run_workflow_includes_started_at_in_metadata(self): + """Test that run_workflow includes started_at timestamp in metadata""" + mock_study, _ = create_mock_study_with_dag() + + mock_plan = Mock() + mock_plan.levels = [] + + mock_executor = Mock() + mock_executor.execute_plan.return_value = {"results": {}} + + manager = WorkflowManager(study=mock_study, executor=mock_executor) + manager.dag.group_tasks = Mock(return_value=mock_plan) + manager.run_workflow() + + context = mock_executor.execute_plan.call_args[0][1] + + assert "started_at" in context.metadata + assert isinstance(context.metadata["started_at"], float) diff --git a/tests/unit/study/test_step_status.py b/tests/unit/study/test_step_status.py new file mode 100644 index 00000000..88acfb1b --- /dev/null +++ b/tests/unit/study/test_step_status.py @@ -0,0 +1,379 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. 
No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +Tests for MerlinStepRecord._update_status_file() auto-detection of task_server. +""" + +import tempfile +from unittest.mock import Mock, patch + +from maestrowf.abstracts.enums import State + +from merlin.study.step import MerlinStepRecord + + +class TestUpdateStatusFileAutoDetection: + """Tests for MerlinStepRecord._update_status_file() task_server auto-detection""" + + def _create_mock_record(self, workspace_dir): + """Helper to create a mock MerlinStepRecord object with required attributes.""" + mock_record = Mock(spec=MerlinStepRecord) + mock_record.name = "test_step" + mock_record.status = State.RUNNING + mock_record.elapsed_time = 1.5 + mock_record.run_time = 1.0 + mock_record.restarts = 0 + mock_record.condensed_workspace = "workspace_0" + + # Mock workspace + mock_workspace = Mock() + mock_workspace.value = workspace_dir + mock_record.workspace = mock_workspace + + # Mock merlin_step params + mock_record.merlin_step = Mock() + mock_record.merlin_step.params = { + "cmd": {"param1": "value1"}, + "restart_cmd": None, + } + + return mock_record + + @patch("merlin.study.step.write_status") + @patch("merlin.study.step.read_status") + @patch("merlin.study.step.os.path.exists") + @patch("merlin.config.configfile.is_local_mode") + def test_auto_detect_local_mode(self, mock_is_local_mode, mock_exists, mock_read_status, mock_write_status): + """Test task_server auto-detection when is_local_mode() returns True""" + mock_is_local_mode.return_value = True + mock_exists.return_value = False # Status file doesn't exist + + with tempfile.TemporaryDirectory() as tmpdir: + mock_record = self._create_mock_record(tmpdir) + + # Call the method with task_server=None (auto-detect) + MerlinStepRecord._update_status_file(mock_record, result=None, task_server=None) + + # Should call write_status + mock_write_status.assert_called_once() + + # Check the status_info dict - should NOT have celery-specific info + call_args = mock_write_status.call_args[0] + status_info = call_args[0] + + # Should not have task_queue (celery-specific) + assert "task_queue" not in status_info.get("test_step", {}) + + @patch("merlin.study.step.write_status") + @patch("merlin.study.step.read_status") + @patch("merlin.study.step.os.path.exists") + @patch("merlin.config.configfile.is_local_mode") + def test_auto_detect_celery_mode(self, mock_is_local_mode, mock_exists, mock_read_status, mock_write_status): + """Test task_server auto-detection when is_local_mode() returns False""" + mock_is_local_mode.return_value = False + mock_exists.return_value = False + + with tempfile.TemporaryDirectory() as tmpdir: + mock_record = self._create_mock_record(tmpdir) + + # Mock celery app and helpers + mock_app = Mock() + mock_app.conf.task_always_eager = True # Avoid worker lookup + + with patch.dict("sys.modules", {"merlin.celery": Mock(app=mock_app)}): + MerlinStepRecord._update_status_file(mock_record, result=None, task_server=None) + + mock_write_status.assert_called_once() + + @patch("merlin.study.step.write_status") + @patch("merlin.study.step.read_status") + @patch("merlin.study.step.os.path.exists") + @patch("merlin.config.configfile.is_local_mode") + def test_explicit_local_override(self, mock_is_local_mode, mock_exists, mock_read_status, mock_write_status): + """Test explicit task_server='local' overrides auto-detection""" + # Even if is_local_mode returns False, explicit override should work 
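+            # (an explicit task_server argument takes precedence over whatever the config reports)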
+ mock_is_local_mode.return_value = False + mock_exists.return_value = False + + with tempfile.TemporaryDirectory() as tmpdir: + mock_record = self._create_mock_record(tmpdir) + + # Explicitly set task_server="local" + MerlinStepRecord._update_status_file(mock_record, result=None, task_server="local") + + mock_write_status.assert_called_once() + + # Check the status_info - should NOT have celery-specific info + call_args = mock_write_status.call_args[0] + status_info = call_args[0] + assert "task_queue" not in status_info.get("test_step", {}) + + @patch("merlin.study.step.write_status") + @patch("merlin.study.step.read_status") + @patch("merlin.study.step.os.path.exists") + @patch("merlin.config.configfile.is_local_mode") + def test_explicit_celery_override(self, mock_is_local_mode, mock_exists, mock_read_status, mock_write_status): + """Test explicit task_server='celery' overrides auto-detection""" + # Even if is_local_mode returns True, explicit override should work + mock_is_local_mode.return_value = True + mock_exists.return_value = False + + with tempfile.TemporaryDirectory() as tmpdir: + mock_record = self._create_mock_record(tmpdir) + + # Mock celery app + mock_app = Mock() + mock_app.conf.task_always_eager = True + + with patch.dict("sys.modules", {"merlin.celery": Mock(app=mock_app)}): + # Explicitly set task_server="celery" + MerlinStepRecord._update_status_file(mock_record, result=None, task_server="celery") + + mock_write_status.assert_called_once() + + +class TestUpdateStatusFileCeleryBehavior: + """Tests for Celery-specific behavior in _update_status_file()""" + + def _create_mock_record(self, workspace_dir): + """Helper to create a mock MerlinStepRecord object with required attributes.""" + mock_record = Mock(spec=MerlinStepRecord) + mock_record.name = "test_step" + mock_record.status = State.RUNNING + mock_record.elapsed_time = 1.5 + mock_record.run_time = 1.0 + mock_record.restarts = 0 + mock_record.condensed_workspace = "workspace_0" + + mock_workspace = Mock() + mock_workspace.value = workspace_dir + mock_record.workspace = mock_workspace + + mock_record.merlin_step = Mock() + mock_record.merlin_step.params = { + "cmd": {"param1": "value1"}, + "restart_cmd": None, + } + + return mock_record + + @patch("merlin.study.step.write_status") + @patch("merlin.study.step.read_status") + @patch("merlin.study.step.os.path.exists") + @patch("merlin.config.configfile.is_local_mode") + def test_celery_mode_skips_worker_info_when_eager( + self, mock_is_local_mode, mock_exists, mock_read_status, mock_write_status + ): + """Test that Celery mode skips worker info when task_always_eager=True""" + mock_is_local_mode.return_value = False + mock_exists.return_value = False + + with tempfile.TemporaryDirectory() as tmpdir: + mock_record = self._create_mock_record(tmpdir) + + # Mock celery app with task_always_eager=True + mock_app = Mock() + mock_app.conf.task_always_eager = True + + with patch.dict("sys.modules", {"merlin.celery": Mock(app=mock_app)}): + MerlinStepRecord._update_status_file(mock_record, result=None, task_server="celery") + + mock_write_status.assert_called_once() + + # Check status_info doesn't have worker info (task_always_eager=True) + call_args = mock_write_status.call_args[0] + status_info = call_args[0] + assert "workers" not in status_info.get("test_step", {}) + + @patch("merlin.study.step.get_current_worker") + @patch("merlin.study.step.get_current_queue") + @patch("merlin.study.step.write_status") + @patch("merlin.study.step.read_status") + 
@patch("merlin.study.step.os.path.exists") + @patch("merlin.config.configfile.is_local_mode") + def test_celery_mode_adds_worker_info_when_not_eager( + self, + mock_is_local_mode, + mock_exists, + mock_read_status, + mock_write_status, + mock_get_queue, + mock_get_worker, + ): + """Test that Celery mode adds worker info when task_always_eager=False""" + mock_is_local_mode.return_value = False + mock_exists.return_value = False + mock_get_queue.return_value = "test_queue" + mock_get_worker.return_value = "worker-1" + + with tempfile.TemporaryDirectory() as tmpdir: + mock_record = self._create_mock_record(tmpdir) + + # Mock celery app with task_always_eager=False (real workers) + mock_app = Mock() + mock_app.conf.task_always_eager = False + + with patch.dict("sys.modules", {"merlin.celery": Mock(app=mock_app)}): + MerlinStepRecord._update_status_file(mock_record, result=None, task_server="celery") + + mock_write_status.assert_called_once() + + # Check status_info has worker info + call_args = mock_write_status.call_args[0] + status_info = call_args[0] + assert status_info["test_step"]["task_queue"] == "test_queue" + assert "workers" in status_info["test_step"] + + @patch("merlin.study.step.write_status") + @patch("merlin.study.step.read_status") + @patch("merlin.study.step.os.path.exists") + @patch("merlin.config.configfile.is_local_mode") + def test_local_mode_never_imports_celery(self, mock_is_local_mode, mock_exists, mock_read_status, mock_write_status): + """Test that local mode never tries to import celery""" + mock_is_local_mode.return_value = True + mock_exists.return_value = False + + with tempfile.TemporaryDirectory() as tmpdir: + mock_record = self._create_mock_record(tmpdir) + + # Don't mock celery - if it's imported, test will fail or behave unexpectedly + # The point is that with task_server="local", celery import is skipped + MerlinStepRecord._update_status_file(mock_record, result=None, task_server="local") + + mock_write_status.assert_called_once() + + +class TestUpdateStatusFileStatusInfo: + """Tests for status info dictionary construction""" + + def _create_mock_record(self, workspace_dir, state=State.RUNNING): + """Helper to create a mock MerlinStepRecord object with required attributes.""" + mock_record = Mock(spec=MerlinStepRecord) + mock_record.name = "test_step" + mock_record.status = state + mock_record.elapsed_time = 2.5 + mock_record.run_time = 2.0 + mock_record.restarts = 1 + mock_record.condensed_workspace = "workspace_0" + + mock_workspace = Mock() + mock_workspace.value = workspace_dir + mock_record.workspace = mock_workspace + + mock_record.merlin_step = Mock() + mock_record.merlin_step.params = { + "cmd": {"param1": "value1", "param2": "value2"}, + "restart_cmd": {"restart_param": "restart_value"}, + } + + return mock_record + + @patch("merlin.study.step.write_status") + @patch("merlin.study.step.read_status") + @patch("merlin.study.step.os.path.exists") + @patch("merlin.config.configfile.is_local_mode") + def test_status_info_contains_correct_fields(self, mock_is_local_mode, mock_exists, mock_read_status, mock_write_status): + """Test that status_info dict contains all required fields""" + mock_is_local_mode.return_value = True + mock_exists.return_value = False + + with tempfile.TemporaryDirectory() as tmpdir: + mock_record = self._create_mock_record(tmpdir) + + MerlinStepRecord._update_status_file(mock_record, result="SUCCESS", task_server="local") + + call_args = mock_write_status.call_args[0] + status_info = call_args[0] + + # Check structure + assert 
"test_step" in status_info + assert "parameters" in status_info["test_step"] + assert "workspace_0" in status_info["test_step"] + + # Check workspace-specific info + ws_info = status_info["test_step"]["workspace_0"] + assert ws_info["status"] == "RUNNING" + assert ws_info["return_code"] == "SUCCESS" + assert ws_info["elapsed_time"] == 2.5 + assert ws_info["run_time"] == 2.0 + assert ws_info["restarts"] == 1 + + @patch("merlin.study.step.write_status") + @patch("merlin.study.step.read_status") + @patch("merlin.study.step.os.path.exists") + @patch("merlin.config.configfile.is_local_mode") + def test_status_info_includes_parameters(self, mock_is_local_mode, mock_exists, mock_read_status, mock_write_status): + """Test that status_info includes cmd and restart parameters""" + mock_is_local_mode.return_value = True + mock_exists.return_value = False + + with tempfile.TemporaryDirectory() as tmpdir: + mock_record = self._create_mock_record(tmpdir) + + MerlinStepRecord._update_status_file(mock_record, result=None, task_server="local") + + call_args = mock_write_status.call_args[0] + status_info = call_args[0] + + params = status_info["test_step"]["parameters"] + assert params["cmd"] == {"param1": "value1", "param2": "value2"} + assert params["restart"] == {"restart_param": "restart_value"} + + @patch("merlin.study.step.write_status") + @patch("merlin.study.step.read_status") + @patch("merlin.study.step.os.path.exists") + @patch("merlin.config.configfile.is_local_mode") + def test_status_translates_state_enum(self, mock_is_local_mode, mock_exists, mock_read_status, mock_write_status): + """Test that State enum is translated to string""" + mock_is_local_mode.return_value = True + mock_exists.return_value = False + + with tempfile.TemporaryDirectory() as tmpdir: + mock_record = self._create_mock_record(tmpdir, state=State.FINISHED) + + MerlinStepRecord._update_status_file(mock_record, result=None, task_server="local") + + call_args = mock_write_status.call_args[0] + status_info = call_args[0] + + assert status_info["test_step"]["workspace_0"]["status"] == "FINISHED" + + @patch("merlin.study.step.write_status") + @patch("merlin.study.step.read_status") + @patch("merlin.study.step.os.path.exists") + @patch("merlin.config.configfile.is_local_mode") + def test_updates_existing_status_file(self, mock_is_local_mode, mock_exists, mock_read_status, mock_write_status): + """Test that existing status file is read and updated""" + mock_is_local_mode.return_value = True + mock_exists.return_value = True # Status file exists + + existing_status = { + "test_step": { + "parameters": {"cmd": None, "restart": None}, + "old_workspace": {"status": "FINISHED"}, + } + } + mock_read_status.return_value = existing_status + + with tempfile.TemporaryDirectory() as tmpdir: + mock_record = self._create_mock_record(tmpdir) + + MerlinStepRecord._update_status_file(mock_record, result=None, task_server="local") + + # Should read existing status + mock_read_status.assert_called_once() + + # Should write updated status + mock_write_status.assert_called_once() + + call_args = mock_write_status.call_args[0] + status_info = call_args[0] + + # Should preserve old workspace info + assert "old_workspace" in status_info["test_step"] + # Should add new workspace info + assert "workspace_0" in status_info["test_step"]