From 70ecfb7575efbf8cabfcce4e7c06c0f4429642ab Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Wed, 10 Sep 2025 09:39:47 -0700 Subject: [PATCH 1/8] add BatchManager class to replace batch function calls --- merlin/study/batch.py | 856 +++++++++++++++----------------- merlin/workers/celery_worker.py | 34 +- 2 files changed, 417 insertions(+), 473 deletions(-) diff --git a/merlin/study/batch.py b/merlin/study/batch.py index 3c8b72d5..c20601ea 100644 --- a/merlin/study/batch.py +++ b/merlin/study/batch.py @@ -5,10 +5,7 @@ ############################################################################## """ -This module parses the batch section of the yaml specification. - -Currently only the batch worker launch for slurm, lsf or flux -are implemented. +This module provides the `BatchManager` class for handling batch job scheduling. """ import logging import os @@ -21,468 +18,393 @@ LOG = logging.getLogger(__name__) -def batch_check_parallel(batch: Dict) -> bool: - """ - Check for a parallel batch section in the provided MerlinSpec object. - - This function examines the 'batch' section of the given specification to determine - whether it is configured for parallel execution. It checks the 'type' attribute - within the batch section, defaulting to 'local' if not specified. If the type - is anything other than 'local', the function will return True, indicating that - parallel processing is enabled. - - Args: - batch: The batch section from either the YAML `batch` block or the worker-specific - batch block. - - Returns: - Returns True if the batch type is set to a value other than 'local', - indicating that parallel processing is enabled; otherwise, returns False. - - Raises: - AttributeError: If the 'batch' section is not present in the specification, - an error is logged and an AttributeError is raised. - """ - parallel = False - - btype = get_yaml_var(batch, "type", "local") - if btype != "local": - parallel = True - - return parallel - - -def check_for_scheduler(scheduler: str, scheduler_legend: Dict[str, str]) -> bool: - """ - Check which scheduler (Flux, Slurm, LSF, or PBS) is the main scheduler for the cluster. - - This function verifies if the specified scheduler is the main scheduler by executing - a command associated with it from the provided scheduler legend. It returns a boolean - indicating whether the specified scheduler is active. - - Args: - scheduler: A string representing the scheduler to check for. Options include 'flux', - 'slurm', 'lsf', or 'pbs'. - scheduler_legend: A dictionary containing information related to each scheduler, - including the command to check its status and the expected output. See - [`construct_scheduler_legend`][study.batch.construct_scheduler_legend] - for more information on all the settings this dict contains. - - Returns: - Returns True if the specified scheduler is the main scheduler for the - cluster, otherwise returns False. - - Raises: - FileNotFoundError: If the command associated with the scheduler cannot be found. - PermissionError: If there are insufficient permissions to execute the command. - """ - # Check for invalid scheduler - if scheduler not in ("flux", "slurm", "lsf", "pbs"): - LOG.warning(f"Invalid scheduler {scheduler} given to check_for_scheduler.") - return False - - # Try to run the check command provided via the scheduler legend - try: - process = subprocess.Popen( # pylint: disable=R1732 - scheduler_legend[scheduler]["check cmd"], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - - # If the desired output exists, return True. Otherwise, return False - result = process.stdout.readlines() - if result and len(result) > 0 and scheduler_legend[scheduler]["expected check output"] in result[0]: - return True - return False - except (FileNotFoundError, PermissionError): - return False - - -def get_batch_type(scheduler_legend: Dict[str, str], default: str = None) -> str: - """ - Determine which batch scheduler to use. - - This function checks a predefined list of batch schedulers in a specific order - to determine which one is available for use. If none of the schedulers are found, - it checks the system type environment variable to suggest a default scheduler. - If no suitable scheduler is determined, it returns the specified default value. - - Args: - scheduler_legend: A dictionary storing information related to each - scheduler, including commands and expected outputs for checking their - availability. See [`construct_scheduler_legend`][study.batch.construct_scheduler_legend] - for more information on all the settings this dict contains. - default: The default batch scheduler to use if a scheduler cannot be determined. - - Returns: - The name of the available batch scheduler. Possible options include - 'slurm', 'flux', 'lsf', or 'pbs'. If no scheduler is found, returns - the specified default value. - """ - # These schedulers are listed in order of which should be checked for first - # 1. Flux should be checked first due to slurm emulation scripts - # 2. PBS should be checked before slurm for testing - # 3. LSF should be checked before slurm for testing - # 4. Slurm should be checked last - schedulers_to_check = ["flux", "pbs", "lsf", "slurm"] - for scheduler in schedulers_to_check: - LOG.debug(f"check for {scheduler} = {check_for_scheduler(scheduler, scheduler_legend)}") - if check_for_scheduler(scheduler, scheduler_legend): - return scheduler - - SYS_TYPE = os.environ.get("SYS_TYPE", "") # pylint: disable=C0103 - if "toss_3" in SYS_TYPE: - return "slurm" - - if "blueos" in SYS_TYPE: - return "lsf" - - return default - - -def get_node_count(parsed_batch: Dict, default: int = 1) -> int: - """ - Determine a default node count based on the environment. - - This function checks the environment and the Flux version to determine the - appropriate number of nodes to use for batch processing. It first verifies - the Flux version, then attempts to retrieve the node count from the Flux - allocation or environment variables specific to Slurm or LSF. If no valid - node count can be determined, it returns a specified default value. - - Args: - parsed_batch: A dictionary containing parsed batch configurations. - See [`parse_batch_block`][study.batch.parse_batch_block] for more - information on all the settings in this dictionary. - default: The number of nodes to return if a node count from the - environment cannot be determined. - - Returns: - The number of nodes to use for the batch job. This value is determined - based on the environment and scheduler specifics. - - Raises: - ValueError: If the Flux version is too old (below 0.17.0). - """ - - # Flux version check - flux_ver = get_flux_version(parsed_batch["flux exe"], no_errors=True) - major, minor, _ = map(int, flux_ver.split(".")) - if major < 1 and minor < 17: - raise ValueError("Flux version is too old. Supported versions are 0.17.0+.") - - # If flux is the scheduler, we can get the size of the allocation with this - try: - get_size_proc = subprocess.run("flux getattr size", shell=True, capture_output=True, text=True) - return int(get_size_proc.stdout) - except Exception: - pass - - if "SLURM_JOB_NUM_NODES" in os.environ: - return int(os.environ["SLURM_JOB_NUM_NODES"]) - - # LSB systems reserve one node for launching - if "LSB_HOSTS" in os.environ: - nodes = set(os.environ["LSB_HOSTS"].split()) - n_batch_nodes = len(nodes) - 1 - return n_batch_nodes - if "LSB_MCPU_HOSTS" in os.environ: - nodes = os.environ["LSB_MCPU_HOSTS"].split() - n_batch_nodes = len(nodes) // 2 - 1 - return n_batch_nodes - - return default - - -def parse_batch_block(batch: Dict) -> Dict: - """ - Parse the batch block of a YAML configuration file. - - This function extracts relevant information from the provided batch block - dictionary, including paths, execution options, and defaults. It retrieves - the Flux executable path and allocation details, and populates a dictionary - with the parsed values. - - Args: - batch: A dictionary representing the batch block from the YAML - configuration file. - - Returns: - A dictionary containing parsed information from the batch block, - including:\n - - `btype`: The type of batch job (default is 'local'). - - `nodes`: The number of nodes to use (default is None). - - `shell`: The shell to use (default is 'bash'). - - `bank`: The bank to charge for the job (default is an empty string). - - `queue`: The queue to submit the job to (default is an empty string). - - `walltime`: The maximum wall time for the job (default is an empty string). - - `launch pre`: Any commands to run before launching (default is an empty string). - - `launch args`: Arguments for the launch command (default is an empty string). - - `launch command`: Custom command to launch workers. This will override the - default launch command (default is an empty string). - - `flux path`: Optional path to flux bin. - - `flux exe`: The full path to the Flux executable. - - `flux exec`: Optional flux exec command to launch workers on all nodes if - `flux_exec_workers` is True (default is None). - - `flux alloc`: The Flux allocation retrieved from the executable. - - `flux opts`: Optional flux start options (default is an empty string). - - `flux exec workers`: Optional flux argument to launch workers - on all nodes (default is True). +# TODO should the scheduler logic be offloaded to script adapters? It's a bit of a different +# use case than the script adapters are intended for... +class BatchManager: """ - flux_path: str = get_yaml_var(batch, "flux_path", "") - if "/" in flux_path: - flux_path += "/" - - flux_exe: str = os.path.join(flux_path, "flux") - flux_alloc: str - try: - flux_alloc = get_flux_alloc(flux_exe) - except FileNotFoundError as e: # pylint: disable=C0103 - LOG.debug(e) - flux_alloc = "" - - parsed_batch = { - "btype": get_yaml_var(batch, "type", "local"), - "nodes": get_yaml_var(batch, "nodes", None), - "shell": get_yaml_var(batch, "shell", "bash"), - "bank": get_yaml_var(batch, "bank", ""), - "queue": get_yaml_var(batch, "queue", ""), - "walltime": get_yaml_var(batch, "walltime", ""), - "launch pre": get_yaml_var(batch, "launch_pre", ""), - "launch args": get_yaml_var(batch, "launch_args", ""), - "launch command": get_yaml_var(batch, "worker_launch", ""), - "flux path": flux_path, - "flux exe": flux_exe, - "flux exec": get_yaml_var(batch, "flux_exec", None), - "flux alloc": flux_alloc, - "flux opts": get_yaml_var(batch, "flux_start_opts", ""), - "flux exec workers": get_yaml_var(batch, "flux_exec_workers", True), - } - return parsed_batch - - -def get_flux_launch(parsed_batch: Dict) -> str: + Manages batch job scheduling and worker launching across different schedulers. + + This class provides methods for detecting available schedulers, parsing batch + configurations, and constructing appropriate launch commands for different + batch systems including Slurm, LSF, Flux, and PBS. + + Attributes: + batch_config (Dict): The parsed batch configuration dictionary. + scheduler_legend (Dict): Dictionary containing scheduler-specific information. + detected_scheduler (str): The automatically detected scheduler type. """ - Build the Flux launch command based on the batch section of the YAML configuration. - - This function constructs the command to launch a Flux job using the parameters - specified in the parsed batch configuration. It determines the appropriate - execution command for Flux workers and integrates it with the launch command - provided in the batch configuration. - - Args: - parsed_batch: A dictionary containing batch configuration parameters. - See [`parse_batch_block`][study.batch.parse_batch_block] for more information - on all the settings in this dictionary. - - Returns: - The constructed Flux launch command, ready to be executed. - """ - default_flux_exec = "flux exec" if parsed_batch["launch command"] else f"{parsed_batch['flux exe']} exec" - flux_exec: str = "" - if parsed_batch["flux exec workers"]: - flux_exec = parsed_batch["flux exec"] if parsed_batch["flux exec"] else default_flux_exec - - if parsed_batch["launch command"] and "flux" not in parsed_batch["launch command"]: - launch: str = ( - f"{parsed_batch['launch command']} {parsed_batch['flux exe']}" - f" start {parsed_batch['flux opts']} {flux_exec} `which {parsed_batch['shell']}` -c" - ) - else: - launch: str = f"{parsed_batch['launch command']} {flux_exec} `which {parsed_batch['shell']}` -c" - - return launch - - -def batch_worker_launch( - batch: Dict, - com: str, - nodes: Union[str, int] = None, -) -> str: - """ - Create the worker launch command based on the batch configuration in the - workflow specification. - - This function constructs a command to launch a worker process using the - specified batch configuration. It handles different batch types and - integrates any necessary pre-launch commands, launch arguments, and - node specifications. - - Args: - batch: The batch section from either the YAML `batch` block or the worker-specific - batch block. - com: The command to launch with the batch configuration. - nodes: The number of nodes to use in the batch launch. If not specified, - it will default to the value in the batch configuration. - - Returns: - The constructed worker launch command, ready to be executed. - - Raises: - AttributeError: If the batch section is missing in the specification. - TypeError: If the `nodes` parameter is of an invalid type. - """ - parsed_batch = parse_batch_block(batch) - - # A jsrun submission cannot be run under a parent jsrun so - # all non flux lsf submissions need to be local. - if parsed_batch["btype"] == "local" or "lsf" in parsed_batch["btype"]: - return com - - if nodes is None: - # Use the value in the batch section - nodes = parsed_batch["nodes"] - - # Get the number of nodes from the environment if unset - if nodes is None or nodes == "all": - nodes = get_node_count(parsed_batch, default=1) - elif not isinstance(nodes, int): - raise TypeError("Nodes was passed into batch_worker_launch with an invalid type (likely a string other than 'all').") - - if not parsed_batch["launch command"]: - parsed_batch["launch command"] = construct_worker_launch_command(parsed_batch, nodes) - - if parsed_batch["launch args"]: - parsed_batch["launch command"] += f" {parsed_batch['launch args']}" - - # Allow for any pre launch manipulation, e.g. module load - # hwloc/1.11.10-cuda - if parsed_batch["launch pre"]: - parsed_batch["launch command"] = f"{parsed_batch['launch pre']} {parsed_batch['launch command']}" - - LOG.debug(f"launch command: {parsed_batch['launch command']}") - - worker_cmd: str = "" - if parsed_batch["btype"] == "flux": - launch = get_flux_launch(parsed_batch) - worker_cmd = f'{launch} "{com}"' - else: - worker_cmd = f"{parsed_batch['launch command']} {com}" - - return worker_cmd - - -def construct_scheduler_legend(parsed_batch: Dict, nodes: int) -> Dict: - """ - Constructs a legend of relevant information needed for each scheduler. - - This function generates a dictionary containing configuration details for various - job schedulers based on the provided batch configuration. The returned dictionary - includes flags for bank, queue, and walltime, as well as commands to check the - scheduler and the initial launch command. - - Args: - parsed_batch: A dictionary of batch configurations, which must include `bank`, - `queue`, `walltime`, and `flux alloc`. See - [`parse_batch_block`][study.batch.parse_batch_block] for more information on - all the settings in this dictionary. - nodes: The number of nodes to use in the launch command. - - Returns: - A dictionary containing scheduler-related information, structured as - follows:\n - - For each scheduler (e.g., 'flux', 'lsf', 'pbs', 'slurm'):\n - - `bank` (str): The flag to add a bank to the launch command. - - `check cmd` (List[str]): The command to run to check if this is the main - scheduler for the cluster. - - `expected check output` (bytes): The expected output from running - the check command. - - `launch` (str): The initial launch command for the scheduler. - - `queue` (str): The flag to add a queue to the launch command (if - applicable). - - `walltime` (str): The flag to add a walltime to the launch command - (if applicable). - """ - scheduler_legend = { - "flux": { - "bank": f" --setattr=system.bank={parsed_batch['bank']}", - "check cmd": ["flux", "resource", "info"], - "expected check output": b"Nodes", - "launch": f"{parsed_batch['flux alloc']} -o pty -N {nodes} --exclusive --job-name=merlin", - "queue": f" --setattr=system.queue={parsed_batch['queue']}", - "walltime": f" -t {convert_timestring(parsed_batch['walltime'], format_method='FSD')}", - }, - "lsf": { - "check cmd": ["jsrun", "--help"], - "expected check output": b"jsrun", - "launch": f"jsrun -a 1 -c ALL_CPUS -g ALL_GPUS --bind=none -n {nodes}", - }, - # pbs is mainly a placeholder in case a user wants to try it (we don't have it at the lab so it's mostly untested) - "pbs": { - "bank": f" -A {parsed_batch['bank']}", - "check cmd": ["qsub", "--version"], - "expected check output": b"pbs_version", - "launch": f"qsub -l nodes={nodes}", - "queue": f" -q {parsed_batch['queue']}", - "walltime": f" -l walltime={convert_timestring(parsed_batch['walltime'])}", - }, - "slurm": { - "bank": f" -A {parsed_batch['bank']}", - "check cmd": ["sbatch", "--help"], - "expected check output": b"sbatch", - "launch": f"srun -N {nodes} -n {nodes}", - "queue": f" -p {parsed_batch['queue']}", - "walltime": f" -t {convert_timestring(parsed_batch['walltime'])}", - }, - } - return scheduler_legend - - -def construct_worker_launch_command(parsed_batch: Dict, nodes: int) -> str: - """ - Constructs the worker launch command based on the provided batch configuration. - - This function generates a launch command for a worker process when no - 'worker_launch' command is specified in the batch configuration. It - utilizes the scheduler legend to incorporate necessary flags such as - bank, queue, and walltime, depending on the workload manager. - - Args: - parsed_batch: A dictionary of batch configurations, which must include - `btype`, `bank`, `queue`, and `walltime`. See - [`parse_batch_block`][study.batch.parse_batch_block] for more information - on all the settings in this dictionary. - nodes: The number of nodes to use in the batch launch. - - Returns: - The constructed launch command for the worker process. - - Raises: - TypeError: If the PBS scheduler is enabled for a batch type other than 'flux'. - KeyError: If the workload manager is not found in the scheduler legend. - """ - # Initialize launch_command and get the scheduler_legend and workload_manager - launch_command: str = "" - scheduler_legend: Dict = construct_scheduler_legend(parsed_batch, nodes) - workload_manager: str = get_batch_type(scheduler_legend) - - LOG.debug(f"parsed_batch: {parsed_batch}") - - if parsed_batch["btype"] == "pbs" and workload_manager == parsed_batch["btype"]: - raise TypeError("The PBS scheduler is only enabled for 'batch: flux' type") - - if parsed_batch["btype"] == "slurm" and workload_manager not in ("lsf", "flux", "pbs"): - workload_manager = "slurm" - - LOG.debug(f"workload_manager: {workload_manager}") - - try: - launch_command = scheduler_legend[workload_manager]["launch"] - except KeyError as e: # pylint: disable=C0103 - LOG.debug(e) - - # If lsf is the workload manager we stop here (no need to add bank, queue, walltime) - if workload_manager != "lsf" or not launch_command: - # Add bank, queue, and walltime to the launch command as necessary - for key in ("bank", "queue", "walltime"): - if parsed_batch[key]: - try: - launch_command += scheduler_legend[workload_manager][key] - except KeyError as e: # pylint: disable=C0103 - LOG.error(e) - - # To read from stdin we append this to the launch command for pbs - if workload_manager == "pbs": - launch_command += " --" - - return launch_command + + def __init__(self, batch_config: Dict = None): + """ + Initialize the BatchManager with a batch configuration. + + Args: + batch_config: Dictionary containing batch configuration settings. + If None, an empty dictionary will be used. + """ + self.batch_config = batch_config or {} + self.parsed_batch = self._parse_batch_block() + self.scheduler_legend = {} + self.detected_scheduler = None + + def _parse_batch_block(self) -> Dict: + """ + Parse the batch block configuration. + + Returns: + Dictionary containing parsed batch configuration with defaults applied. + """ + flux_path: str = get_yaml_var(self.batch_config, "flux_path", "") + if "/" in flux_path: + flux_path += "/" + + flux_exe: str = os.path.join(flux_path, "flux") + flux_alloc: str + try: + flux_alloc = get_flux_alloc(flux_exe) + except FileNotFoundError as e: + LOG.debug(e) + flux_alloc = "" + + parsed_batch = { + "btype": get_yaml_var(self.batch_config, "type", "local"), + "nodes": get_yaml_var(self.batch_config, "nodes", None), + "shell": get_yaml_var(self.batch_config, "shell", "bash"), + "bank": get_yaml_var(self.batch_config, "bank", ""), + "queue": get_yaml_var(self.batch_config, "queue", ""), + "walltime": get_yaml_var(self.batch_config, "walltime", ""), + "launch pre": get_yaml_var(self.batch_config, "launch_pre", ""), + "launch args": get_yaml_var(self.batch_config, "launch_args", ""), + "launch command": get_yaml_var(self.batch_config, "worker_launch", ""), + "flux path": flux_path, + "flux exe": flux_exe, + "flux exec": get_yaml_var(self.batch_config, "flux_exec", None), + "flux alloc": flux_alloc, + "flux opts": get_yaml_var(self.batch_config, "flux_start_opts", ""), + "flux exec workers": get_yaml_var(self.batch_config, "flux_exec_workers", True), + } + return parsed_batch + + def is_parallel(self) -> bool: + """ + Check if this batch configuration is set up for parallel execution. + + Returns: + True if batch type is not 'local', indicating parallel processing. + """ + return self.parsed_batch["btype"] != "local" + + def _check_scheduler(self, scheduler: str) -> bool: + """ + Check if a specific scheduler is available on the system. + + Args: + scheduler: Name of the scheduler to check ('flux', 'slurm', 'lsf', 'pbs'). + + Returns: + True if the scheduler is available, False otherwise. + """ + if scheduler not in ("flux", "slurm", "lsf", "pbs"): + LOG.warning(f"Invalid scheduler {scheduler} given to _check_scheduler.") + return False + + # Ensure scheduler legend is populated + if not self.scheduler_legend: + self._build_scheduler_legend() + + try: + process = subprocess.Popen( + self.scheduler_legend[scheduler]["check cmd"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + result = process.stdout.readlines() + expected_output = self.scheduler_legend[scheduler]["expected check output"] + if result and len(result) > 0 and expected_output in result[0]: + return True + return False + except (FileNotFoundError, PermissionError): + return False + + def detect_scheduler(self, default: str = None) -> str: + """ + Automatically detect which batch scheduler is available. + + Args: + default: Default scheduler to return if none are detected. + + Returns: + Name of the detected scheduler or the default value. + """ + if self.detected_scheduler is not None: + return self.detected_scheduler + + # Build scheduler legend if not already done + if not self.scheduler_legend: + self._build_scheduler_legend() + + # Check schedulers in priority order + schedulers_to_check = ["flux", "pbs", "lsf", "slurm"] + for scheduler in schedulers_to_check: + LOG.debug(f"check for {scheduler} = {self._check_scheduler(scheduler)}") + if self._check_scheduler(scheduler): + self.detected_scheduler = scheduler + return scheduler + + # Check environment variables for system type + sys_type = os.environ.get("SYS_TYPE", "") + if "toss_3" in sys_type: + self.detected_scheduler = "slurm" + return "slurm" + + if "blueos" in sys_type: + self.detected_scheduler = "lsf" + return "lsf" + + self.detected_scheduler = default + return default + + def _get_node_count(self, default: int = 1) -> int: + """ + Determine node count based on environment and scheduler. + + Args: + default: Default node count if none can be determined. + + Returns: + Number of nodes to use for the batch job. + + Raises: + ValueError: If Flux version is too old. + """ + # Flux version check + flux_ver = get_flux_version(self.parsed_batch["flux exe"], no_errors=True) + if flux_ver: + major, minor, _ = map(int, flux_ver.split(".")) + if major < 1 and minor < 17: + raise ValueError("Flux version is too old. Supported versions are 0.17.0+.") + + # Try to get node count from Flux + try: + get_size_proc = subprocess.run("flux getattr size", shell=True, capture_output=True, text=True) + return int(get_size_proc.stdout) + except Exception: + pass + + # Check Slurm environment + if "SLURM_JOB_NUM_NODES" in os.environ: + return int(os.environ["SLURM_JOB_NUM_NODES"]) + + # Check LSF environment + if "LSB_HOSTS" in os.environ: + nodes = set(os.environ["LSB_HOSTS"].split()) + return len(nodes) - 1 + if "LSB_MCPU_HOSTS" in os.environ: + nodes = os.environ["LSB_MCPU_HOSTS"].split() + return len(nodes) // 2 - 1 + + return default + + def _build_scheduler_legend(self, nodes: int = None) -> None: + """ + Build the scheduler legend with configuration for all supported schedulers. + + Args: + nodes: Number of nodes for the launch command. If None, will attempt + to determine automatically. + """ + if nodes is None: + nodes = self._get_node_count(default=1) + + self.scheduler_legend = { + "flux": { + "bank": f" --setattr=system.bank={self.parsed_batch['bank']}", + "check cmd": ["flux", "resource", "info"], + "expected check output": b"Nodes", + "launch": f"{self.parsed_batch['flux alloc']} -o pty -N {nodes} --exclusive --job-name=merlin", + "queue": f" --setattr=system.queue={self.parsed_batch['queue']}", + "walltime": f" -t {convert_timestring(self.parsed_batch['walltime'], format_method='FSD')}", + }, + "lsf": { + "check cmd": ["jsrun", "--help"], + "expected check output": b"jsrun", + "launch": f"jsrun -a 1 -c ALL_CPUS -g ALL_GPUS --bind=none -n {nodes}", + }, + "pbs": { + "bank": f" -A {self.parsed_batch['bank']}", + "check cmd": ["qsub", "--version"], + "expected check output": b"pbs_version", + "launch": f"qsub -l nodes={nodes}", + "queue": f" -q {self.parsed_batch['queue']}", + "walltime": f" -l walltime={convert_timestring(self.parsed_batch['walltime'])}", + }, + "slurm": { + "bank": f" -A {self.parsed_batch['bank']}", + "check cmd": ["sbatch", "--help"], + "expected check output": b"sbatch", + "launch": f"srun -N {nodes} -n {nodes}", + "queue": f" -p {self.parsed_batch['queue']}", + "walltime": f" -t {convert_timestring(self.parsed_batch['walltime'])}", + }, + } + + def _get_flux_launch_command(self) -> str: + """ + Build the Flux-specific launch command. + + Returns: + Flux launch command string. + """ + default_flux_exec = "flux exec" if self.parsed_batch["launch command"] else f"{self.parsed_batch['flux exe']} exec" + flux_exec = "" + + if self.parsed_batch["flux exec workers"]: + flux_exec = self.parsed_batch["flux exec"] if self.parsed_batch["flux exec"] else default_flux_exec + + if self.parsed_batch["launch command"] and "flux" not in self.parsed_batch["launch command"]: + launch = ( + f"{self.parsed_batch['launch command']} {self.parsed_batch['flux exe']}" + f" start {self.parsed_batch['flux opts']} {flux_exec} `which {self.parsed_batch['shell']}` -c" + ) + else: + launch = f"{self.parsed_batch['launch command']} {flux_exec} `which {self.parsed_batch['shell']}` -c" + + return launch + + def _construct_launch_command(self, nodes: int) -> str: + """ + Construct the base launch command for the detected scheduler. + + Args: + nodes: Number of nodes to use. + + Returns: + The constructed launch command. + + Raises: + TypeError: If PBS scheduler is used with non-flux batch type. + KeyError: If workload manager is not found in scheduler legend. + """ + # Build scheduler legend with the specified nodes + self._build_scheduler_legend(nodes) + + # Detect the workload manager + workload_manager = self.detect_scheduler() + + LOG.debug(f"parsed_batch: {self.parsed_batch}") + + if self.parsed_batch["btype"] == "pbs" and workload_manager == self.parsed_batch["btype"]: + raise TypeError("The PBS scheduler is only enabled for 'batch: flux' type") + + if self.parsed_batch["btype"] == "slurm" and workload_manager not in ("lsf", "flux", "pbs"): + workload_manager = "slurm" + + LOG.debug(f"workload_manager: {workload_manager}") + + try: + launch_command = self.scheduler_legend[workload_manager]["launch"] + except KeyError as e: + LOG.debug(e) + launch_command = "" + + # If LSF is the workload manager we stop here + if workload_manager != "lsf" and launch_command: + # Add bank, queue, and walltime as necessary + for key in ("bank", "queue", "walltime"): + if self.parsed_batch[key]: + try: + launch_command += self.scheduler_legend[workload_manager][key] + except KeyError as e: + LOG.error(e) + + # PBS-specific modification + if workload_manager == "pbs": + launch_command += " --" + + return launch_command + + def create_worker_launch_command(self, command: str, nodes: Union[str, int] = None) -> str: + """ + Create the complete worker launch command. + + Args: + command: The base command to be launched. + nodes: Number of nodes to use. Can be an integer, "all", or None. + If None, will use the batch configuration value. + + Returns: + Complete launch command ready for execution. + + Raises: + TypeError: If nodes parameter is invalid or PBS scheduler is misconfigured. + """ + # Handle local or LSF batch types + if self.parsed_batch["btype"] == "local" or "lsf" in self.parsed_batch["btype"]: + return command + + # Determine node count + if nodes is None: + nodes = self.parsed_batch["nodes"] + + if nodes is None or nodes == "all": + nodes = self._get_node_count(default=1) + elif not isinstance(nodes, int): + if isinstance(nodes, str) and nodes != "all": + raise TypeError("Nodes was passed with an invalid string value (only 'all' is supported).") + elif not isinstance(nodes, str): + raise TypeError("Nodes parameter must be an integer, 'all', or None.") + + # Build launch command if not provided + if not self.parsed_batch["launch command"]: + self.parsed_batch["launch command"] = self._construct_launch_command(nodes) + + # Add launch arguments + if self.parsed_batch["launch args"]: + self.parsed_batch["launch command"] += f" {self.parsed_batch['launch args']}" + + # Add pre-launch commands + if self.parsed_batch["launch pre"]: + self.parsed_batch["launch command"] = f"{self.parsed_batch['launch pre']} {self.parsed_batch['launch command']}" + + LOG.debug(f"launch command: {self.parsed_batch['launch command']}") + + # Construct final worker command + if self.parsed_batch["btype"] == "flux": + launch = self._get_flux_launch_command() + worker_cmd = f'{launch} "{command}"' + else: + worker_cmd = f"{self.parsed_batch['launch command']} {command}" + + return worker_cmd + + def get_batch_info(self) -> Dict: + """ + Get information about the current batch configuration. + + Returns: + Dictionary containing batch configuration details. + """ + return { + "type": self.parsed_batch["btype"], + "nodes": self.parsed_batch["nodes"], + "shell": self.parsed_batch["shell"], + "bank": self.parsed_batch["bank"], + "queue": self.parsed_batch["queue"], + "walltime": self.parsed_batch["walltime"], + "is_parallel": self.is_parallel(), + "detected_scheduler": self.detect_scheduler(), + } + + def update_config(self, new_config: Dict) -> None: + """ + Update the batch configuration and re-parse. + + Args: + new_config: New batch configuration dictionary. + """ + self.batch_config.update(new_config) + self.parsed_batch = self._parse_batch_block() + # Reset cached values + self.scheduler_legend = {} + self.detected_scheduler = None diff --git a/merlin/workers/celery_worker.py b/merlin/workers/celery_worker.py index 640cb461..272c6b34 100644 --- a/merlin/workers/celery_worker.py +++ b/merlin/workers/celery_worker.py @@ -22,7 +22,7 @@ from merlin.db_scripts.merlin_db import MerlinDatabase from merlin.exceptions import MerlinWorkerLaunchError -from merlin.study.batch import batch_check_parallel, batch_worker_launch +from merlin.study.batch import BatchManager from merlin.utils import check_machines from merlin.workers.worker import MerlinWorker @@ -45,6 +45,7 @@ class CeleryWorker(MerlinWorker): args (str): Additional CLI arguments passed to Celery. queues (List[str]): Queues the worker listens to. batch (dict): Optional batch submission settings. + batch_manager (BatchManager): Manager for batch-related operations. machines (List[str]): List of hostnames the worker is allowed to run on. overlap (bool): Whether this worker can overlap queues with others. @@ -85,6 +86,9 @@ def __init__( self.batch = self.config.get("batch", {}) self.machines = self.config.get("machines", []) self.overlap = overlap + + # Initialize BatchManager for this worker + self.batch_manager = BatchManager(self.batch) # Add this worker to the database merlin_db = MerlinDatabase() @@ -100,7 +104,8 @@ def _verify_args(self, disable_logs: bool = False) -> str: Args: disable_logs: If True, logging level will not be appended. """ - if batch_check_parallel(self.batch): + # Use BatchManager to check for parallel configuration + if self.batch_manager.is_parallel(): if "--concurrency" not in self.args: LOG.warning("Missing --concurrency in worker args for parallel tasks.") if "--prefetch-multiplier" not in self.args: @@ -133,10 +138,12 @@ def get_launch_command(self, override_args: str = "", disable_logs: bool = False # Validate args self._verify_args(disable_logs=disable_logs) - # Construct the launch command + # Construct the base celery command celery_cmd = f"celery -A merlin worker {self.args} -Q {','.join(self.queues)}" - nodes = self.batch.get("nodes", None) - launch_cmd = batch_worker_launch(self.batch, celery_cmd, nodes=nodes) + + # Use BatchManager to create the launch command + launch_cmd = self.batch_manager.create_worker_launch_command(celery_cmd) + return os.path.expandvars(launch_cmd) def should_launch(self) -> bool: @@ -202,10 +209,25 @@ def get_metadata(self) -> Dict: Returns: A dictionary containing key details about this worker. """ - return { + metadata = { "name": self.name, "queues": self.queues, "args": self.args, "machines": self.machines, "batch": self.batch, } + + # Add batch manager information + metadata["batch_info"] = self.batch_manager.get_batch_info() + + return metadata + + def update_batch_config(self, new_batch_config: Dict): + """ + Update the batch configuration for this worker. + + Args: + new_batch_config: New batch configuration to apply. + """ + self.batch.update(new_batch_config) + self.batch_manager.update_config(new_batch_config) From abf0447d6a4c5072af04961bef9f69973ffe8751 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 11 Sep 2025 09:27:18 -0700 Subject: [PATCH 2/8] add WorkerConfig and BatchConfig dataclasses --- merlin/abstracts/factory.py | 10 +- merlin/spec/specification.py | 29 ++- merlin/study/batch.py | 160 ++++++------ merlin/study/configurations.py | 294 ++++++++++++++++++++++ merlin/workers/celery_worker.py | 104 ++------ merlin/workers/handlers/celery_handler.py | 5 +- merlin/workers/worker.py | 28 +-- 7 files changed, 433 insertions(+), 197 deletions(-) create mode 100644 merlin/study/configurations.py diff --git a/merlin/abstracts/factory.py b/merlin/abstracts/factory.py index 22719864..d7703430 100644 --- a/merlin/abstracts/factory.py +++ b/merlin/abstracts/factory.py @@ -249,7 +249,15 @@ def create(self, component_type: str, config: Dict = None) -> Any: # Create and return an instance of the component_class try: - instance = component_class() if config is None else component_class(**config) + if config is None: + # No configuration + instance = component_class() + elif isinstance(config, dict): + # Dict-based config + instance = component_class(**config) + else: + # Try to pass directly (config could be dataclass) + instance = component_class(config) LOG.info(f"Created component '{canonical_name}'") return instance except Exception as e: diff --git a/merlin/spec/specification.py b/merlin/spec/specification.py index 66c58da4..8dafde3f 100644 --- a/merlin/spec/specification.py +++ b/merlin/spec/specification.py @@ -24,6 +24,7 @@ from maestrowf.specification import YAMLSpecification from merlin.spec import all_keys, defaults +from merlin.study.configurations import BatchConfig, WorkerConfig from merlin.utils import find_vlaunch_var, get_yaml_var, load_array_file, needs_merlin_expansion, repr_timedelta from merlin.workers.worker import MerlinWorker from merlin.workers.worker_factory import worker_factory @@ -1262,23 +1263,21 @@ def build_worker_list(self, workers_to_start: Set[str]) -> List[MerlinWorker]: for worker_name in workers_to_start: settings = all_workers[worker_name] - config = { - "args": settings.get("args", ""), - "machines": settings.get("machines", []), - "queues": set(self.get_queue_list(settings["steps"])), - "batch": settings["batch"] if settings["batch"] is not None else self.batch.copy(), - } - - if "nodes" in settings and settings["nodes"] is not None: - if config["batch"]: - config["batch"]["nodes"] = settings["nodes"] - else: - config["batch"] = {"nodes": settings["nodes"]} - LOG.debug(f"config for worker '{worker_name}': {config}") + batch_settings = settings["batch"] if settings["batch"] is not None else self.batch.copy() + + worker_config = WorkerConfig( + name=worker_name, + args=settings.get("args", ""), + queues=set(self.get_queue_list(settings["steps"])), + machines=settings.get("machines", []), + env=full_env, + overlap=overlap, + nodes=settings.get("nodes", None), + batch=BatchConfig.from_dict(batch_settings) + ) - worker_params = {"name": worker_name, "config": config, "env": full_env, "overlap": overlap} - worker_instance = worker_factory.create(self.merlin["resources"]["task_server"], worker_params) + worker_instance = worker_factory.create(self.merlin["resources"]["task_server"], worker_config) workers.append(worker_instance) LOG.debug(f"Created CeleryWorker object for worker '{worker_name}'.") diff --git a/merlin/study/batch.py b/merlin/study/batch.py index c20601ea..f5aaa9be 100644 --- a/merlin/study/batch.py +++ b/merlin/study/batch.py @@ -12,7 +12,8 @@ import subprocess from typing import Dict, Union -from merlin.utils import convert_timestring, get_flux_alloc, get_flux_version, get_yaml_var +from merlin.study.configurations import BatchConfig +from merlin.utils import convert_timestring, get_flux_alloc, get_flux_version LOG = logging.getLogger(__name__) @@ -29,61 +30,41 @@ class BatchManager: batch systems including Slurm, LSF, Flux, and PBS. Attributes: - batch_config (Dict): The parsed batch configuration dictionary. + batch_config (BatchConfig): The batch configuration object. scheduler_legend (Dict): Dictionary containing scheduler-specific information. detected_scheduler (str): The automatically detected scheduler type. """ - def __init__(self, batch_config: Dict = None): + def __init__(self, batch_config: BatchConfig = None): """ Initialize the BatchManager with a batch configuration. Args: - batch_config: Dictionary containing batch configuration settings. - If None, an empty dictionary will be used. + batch_config: BatchConfig object containing batch configuration settings. + If None, a default BatchConfig will be created. """ - self.batch_config = batch_config or {} - self.parsed_batch = self._parse_batch_block() + self.batch_config = batch_config or BatchConfig() self.scheduler_legend = {} self.detected_scheduler = None - def _parse_batch_block(self) -> Dict: - """ - Parse the batch block configuration. + # Initialize Flux-specific attributes + self._flux_exe = None + self._flux_alloc = None + self._init_flux_config() - Returns: - Dictionary containing parsed batch configuration with defaults applied. - """ - flux_path: str = get_yaml_var(self.batch_config, "flux_path", "") - if "/" in flux_path: + def _init_flux_config(self): + """Initialize Flux-specific configuration.""" + flux_path = self.batch_config.flux_path + if flux_path and not flux_path.endswith("/"): flux_path += "/" - flux_exe: str = os.path.join(flux_path, "flux") - flux_alloc: str + self._flux_exe = os.path.join(flux_path, "flux") + try: - flux_alloc = get_flux_alloc(flux_exe) + self._flux_alloc = get_flux_alloc(self._flux_exe) except FileNotFoundError as e: LOG.debug(e) - flux_alloc = "" - - parsed_batch = { - "btype": get_yaml_var(self.batch_config, "type", "local"), - "nodes": get_yaml_var(self.batch_config, "nodes", None), - "shell": get_yaml_var(self.batch_config, "shell", "bash"), - "bank": get_yaml_var(self.batch_config, "bank", ""), - "queue": get_yaml_var(self.batch_config, "queue", ""), - "walltime": get_yaml_var(self.batch_config, "walltime", ""), - "launch pre": get_yaml_var(self.batch_config, "launch_pre", ""), - "launch args": get_yaml_var(self.batch_config, "launch_args", ""), - "launch command": get_yaml_var(self.batch_config, "worker_launch", ""), - "flux path": flux_path, - "flux exe": flux_exe, - "flux exec": get_yaml_var(self.batch_config, "flux_exec", None), - "flux alloc": flux_alloc, - "flux opts": get_yaml_var(self.batch_config, "flux_start_opts", ""), - "flux exec workers": get_yaml_var(self.batch_config, "flux_exec_workers", True), - } - return parsed_batch + self._flux_alloc = "" def is_parallel(self) -> bool: """ @@ -92,7 +73,7 @@ def is_parallel(self) -> bool: Returns: True if batch type is not 'local', indicating parallel processing. """ - return self.parsed_batch["btype"] != "local" + return self.batch_config.is_parallel() def _check_scheduler(self, scheduler: str) -> bool: """ @@ -179,7 +160,7 @@ def _get_node_count(self, default: int = 1) -> int: ValueError: If Flux version is too old. """ # Flux version check - flux_ver = get_flux_version(self.parsed_batch["flux exe"], no_errors=True) + flux_ver = get_flux_version(self._flux_exe, no_errors=True) if flux_ver: major, minor, _ = map(int, flux_ver.split(".")) if major < 1 and minor < 17: @@ -206,7 +187,7 @@ def _get_node_count(self, default: int = 1) -> int: return default - def _build_scheduler_legend(self, nodes: int = None) -> None: + def _build_scheduler_legend(self, nodes: int = None): """ Build the scheduler legend with configuration for all supported schedulers. @@ -219,12 +200,12 @@ def _build_scheduler_legend(self, nodes: int = None) -> None: self.scheduler_legend = { "flux": { - "bank": f" --setattr=system.bank={self.parsed_batch['bank']}", + "bank": f" --setattr=system.bank={self.batch_config.bank}", "check cmd": ["flux", "resource", "info"], "expected check output": b"Nodes", - "launch": f"{self.parsed_batch['flux alloc']} -o pty -N {nodes} --exclusive --job-name=merlin", - "queue": f" --setattr=system.queue={self.parsed_batch['queue']}", - "walltime": f" -t {convert_timestring(self.parsed_batch['walltime'], format_method='FSD')}", + "launch": f"{self._flux_alloc} -o pty -N {nodes} --exclusive --job-name=merlin", + "queue": f" --setattr=system.queue={self.batch_config.queue}", + "walltime": f" -t {convert_timestring(self.batch_config.walltime, format_method='FSD')}", }, "lsf": { "check cmd": ["jsrun", "--help"], @@ -232,20 +213,20 @@ def _build_scheduler_legend(self, nodes: int = None) -> None: "launch": f"jsrun -a 1 -c ALL_CPUS -g ALL_GPUS --bind=none -n {nodes}", }, "pbs": { - "bank": f" -A {self.parsed_batch['bank']}", + "bank": f" -A {self.batch_config.bank}", "check cmd": ["qsub", "--version"], "expected check output": b"pbs_version", "launch": f"qsub -l nodes={nodes}", - "queue": f" -q {self.parsed_batch['queue']}", - "walltime": f" -l walltime={convert_timestring(self.parsed_batch['walltime'])}", + "queue": f" -q {self.batch_config.queue}", + "walltime": f" -l walltime={convert_timestring(self.batch_config.walltime)}", }, "slurm": { - "bank": f" -A {self.parsed_batch['bank']}", + "bank": f" -A {self.batch_config.bank}", "check cmd": ["sbatch", "--help"], "expected check output": b"sbatch", "launch": f"srun -N {nodes} -n {nodes}", - "queue": f" -p {self.parsed_batch['queue']}", - "walltime": f" -t {convert_timestring(self.parsed_batch['walltime'])}", + "queue": f" -p {self.batch_config.queue}", + "walltime": f" -t {convert_timestring(self.batch_config.walltime)}", }, } @@ -256,19 +237,19 @@ def _get_flux_launch_command(self) -> str: Returns: Flux launch command string. """ - default_flux_exec = "flux exec" if self.parsed_batch["launch command"] else f"{self.parsed_batch['flux exe']} exec" + default_flux_exec = "flux exec" if self.batch_config.worker_launch else f"{self._flux_exe} exec" flux_exec = "" - if self.parsed_batch["flux exec workers"]: - flux_exec = self.parsed_batch["flux exec"] if self.parsed_batch["flux exec"] else default_flux_exec + if self.batch_config.flux_exec_workers: + flux_exec = self.batch_config.flux_exec if self.batch_config.flux_exec else default_flux_exec - if self.parsed_batch["launch command"] and "flux" not in self.parsed_batch["launch command"]: + if self.batch_config.worker_launch and "flux" not in self.batch_config.worker_launch: launch = ( - f"{self.parsed_batch['launch command']} {self.parsed_batch['flux exe']}" - f" start {self.parsed_batch['flux opts']} {flux_exec} `which {self.parsed_batch['shell']}` -c" + f"{self.batch_config.worker_launch} {self._flux_exe}" + f" start {self.batch_config.flux_start_opts} {flux_exec} `which {self.batch_config.shell}` -c" ) else: - launch = f"{self.parsed_batch['launch command']} {flux_exec} `which {self.parsed_batch['shell']}` -c" + launch = f"{self.batch_config.worker_launch} {flux_exec} `which {self.batch_config.shell}` -c" return launch @@ -292,12 +273,12 @@ def _construct_launch_command(self, nodes: int) -> str: # Detect the workload manager workload_manager = self.detect_scheduler() - LOG.debug(f"parsed_batch: {self.parsed_batch}") + LOG.debug(f"batch_config: {self.batch_config}") - if self.parsed_batch["btype"] == "pbs" and workload_manager == self.parsed_batch["btype"]: + if self.batch_config.type == "pbs" and workload_manager == self.batch_config.type: raise TypeError("The PBS scheduler is only enabled for 'batch: flux' type") - if self.parsed_batch["btype"] == "slurm" and workload_manager not in ("lsf", "flux", "pbs"): + if self.batch_config.type == "slurm" and workload_manager not in ("lsf", "flux", "pbs"): workload_manager = "slurm" LOG.debug(f"workload_manager: {workload_manager}") @@ -312,7 +293,8 @@ def _construct_launch_command(self, nodes: int) -> str: if workload_manager != "lsf" and launch_command: # Add bank, queue, and walltime as necessary for key in ("bank", "queue", "walltime"): - if self.parsed_batch[key]: + config_value = getattr(self.batch_config, key) + if config_value: try: launch_command += self.scheduler_legend[workload_manager][key] except KeyError as e: @@ -340,12 +322,12 @@ def create_worker_launch_command(self, command: str, nodes: Union[str, int] = No TypeError: If nodes parameter is invalid or PBS scheduler is misconfigured. """ # Handle local or LSF batch types - if self.parsed_batch["btype"] == "local" or "lsf" in self.parsed_batch["btype"]: + if self.batch_config.type == "local" or "lsf" in self.batch_config.type: return command # Determine node count if nodes is None: - nodes = self.parsed_batch["nodes"] + nodes = self.batch_config.nodes if nodes is None or nodes == "all": nodes = self._get_node_count(default=1) @@ -356,25 +338,26 @@ def create_worker_launch_command(self, command: str, nodes: Union[str, int] = No raise TypeError("Nodes parameter must be an integer, 'all', or None.") # Build launch command if not provided - if not self.parsed_batch["launch command"]: - self.parsed_batch["launch command"] = self._construct_launch_command(nodes) + launch_command = self.batch_config.worker_launch + if not launch_command: + launch_command = self._construct_launch_command(nodes) # Add launch arguments - if self.parsed_batch["launch args"]: - self.parsed_batch["launch command"] += f" {self.parsed_batch['launch args']}" + if self.batch_config.launch_args: + launch_command += f" {self.batch_config.launch_args}" # Add pre-launch commands - if self.parsed_batch["launch pre"]: - self.parsed_batch["launch command"] = f"{self.parsed_batch['launch pre']} {self.parsed_batch['launch command']}" + if self.batch_config.launch_pre: + launch_command = f"{self.batch_config.launch_pre} {launch_command}" - LOG.debug(f"launch command: {self.parsed_batch['launch command']}") + LOG.debug(f"launch command: {launch_command}") # Construct final worker command - if self.parsed_batch["btype"] == "flux": + if self.batch_config.type == "flux": launch = self._get_flux_launch_command() worker_cmd = f'{launch} "{command}"' else: - worker_cmd = f"{self.parsed_batch['launch command']} {command}" + worker_cmd = f"{launch_command} {command}" return worker_cmd @@ -385,26 +368,27 @@ def get_batch_info(self) -> Dict: Returns: Dictionary containing batch configuration details. """ - return { - "type": self.parsed_batch["btype"], - "nodes": self.parsed_batch["nodes"], - "shell": self.parsed_batch["shell"], - "bank": self.parsed_batch["bank"], - "queue": self.parsed_batch["queue"], - "walltime": self.parsed_batch["walltime"], - "is_parallel": self.is_parallel(), - "detected_scheduler": self.detect_scheduler(), - } + batch_info = self.batch_config.to_dict() + batch_info["is_parallel"] = self.is_parallel() + batch_info["detected_scheduler"] = self.detect_scheduler() + return batch_info - def update_config(self, new_config: Dict) -> None: + def update_config(self, new_config: Union[Dict, BatchConfig]): """ - Update the batch configuration and re-parse. + Update the batch configuration and reset cached values. Args: - new_config: New batch configuration dictionary. + new_config: New batch configuration (Dict or BatchConfig). """ - self.batch_config.update(new_config) - self.parsed_batch = self._parse_batch_block() + if isinstance(new_config, dict): + new_config = BatchConfig.from_dict(new_config) + self.batch_config = self.batch_config.merge(new_config) + elif isinstance(new_config, BatchConfig): + self.batch_config = new_config + else: + raise TypeError("new_config must be a Dict or BatchConfig instance") + # Reset cached values self.scheduler_legend = {} self.detected_scheduler = None + self._init_flux_config() diff --git a/merlin/study/configurations.py b/merlin/study/configurations.py new file mode 100644 index 00000000..49d196a8 --- /dev/null +++ b/merlin/study/configurations.py @@ -0,0 +1,294 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +Configuration dataclasses for Merlin workers and batch systems. + +This module provides strongly-typed configuration objects that replace +dictionary-based configurations. +""" + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Set, Union + + +@dataclass +class BatchConfig: + """ + Configuration for batch job submission and execution. + + This dataclass encapsulates all batch-related configuration options + that can be specified in a Merlin workflow specification. + + Attributes: + type: The type of batch system to use ('local', 'slurm', 'lsf', 'flux', 'pbs'). + nodes: Number of nodes to request for the batch job. + shell: Shell to use for command execution. + bank: Account/bank to charge for the job. + queue: Queue/partition to submit the job to. + walltime: Maximum wall time for the job. + launch_pre: Commands to run before launching the main command. + launch_args: Additional arguments for the launch command. + worker_launch: Custom launch command for workers. + flux_path: Path to the Flux executable directory. + flux_exec: Custom Flux exec command. + flux_start_opts: Additional options for flux start command. + flux_exec_workers: Whether to use flux exec to launch workers on all nodes. + """ + type: str = "local" + nodes: Optional[Union[int, str]] = None + shell: str = "bash" + bank: str = "" + queue: str = "" + walltime: str = "" + dry_run: bool = False + launch_pre: str = "" + launch_args: str = "" + worker_launch: str = "" + flux_path: str = "" + flux_exec: Optional[str] = None + flux_start_opts: str = "" + flux_exec_workers: bool = True + + def __post_init__(self): + """Validate configuration after initialization.""" + valid_types = {"local", "slurm", "lsf", "flux", "pbs"} + if self.type not in valid_types: + raise ValueError(f"Invalid batch type '{self.type}'. Must be one of: {valid_types}") + + if self.nodes is not None: + if isinstance(self.nodes, str) and self.nodes != "all": + try: + self.nodes = int(self.nodes) + except ValueError: + raise ValueError(f"Invalid nodes value '{self.nodes}'. Must be an integer, 'all', or None.") + + # Normalize flux_path + if self.flux_path and not self.flux_path.endswith("/"): + self.flux_path += "/" + + @classmethod + def from_dict(cls, config_dict: Dict) -> 'BatchConfig': + """ + Create a BatchConfig from a dictionary. + + Args: + config_dict: Dictionary containing batch configuration. + + Returns: + BatchConfig instance with values from the dictionary. + """ + return cls(**config_dict) + + def to_dict(self) -> Dict: + """ + Convert BatchConfig to dictionary for backward compatibility. + + Returns: + Dictionary representation of the configuration. + """ + return { + "type": self.type, + "nodes": self.nodes, + "shell": self.shell, + "bank": self.bank, + "queue": self.queue, + "walltime": self.walltime, + "dry_run": self.dry_run, + "launch_pre": self.launch_pre, + "launch_args": self.launch_args, + "worker_launch": self.worker_launch, + "flux_path": self.flux_path, + "flux_exec": self.flux_exec, + "flux_start_opts": self.flux_start_opts, + "flux_exec_workers": self.flux_exec_workers, + } + + def is_parallel(self) -> bool: + """ + Check if this configuration enables parallel execution. + + Returns: + True if batch type is not 'local'. + """ + return self.type != "local" + + def merge(self, other: 'BatchConfig') -> 'BatchConfig': + """ + Merge this configuration with another, with other taking precedence. + + Args: + other: BatchConfig to merge with this one. + + Returns: + New BatchConfig with merged values. + """ + merged_dict = self.to_dict() + other_dict = other.to_dict() + + # Only override non-empty/non-None values + for key, value in other_dict.items(): + if value not in (None, "", []): + merged_dict[key] = value + + return BatchConfig.from_dict(merged_dict) + + +@dataclass +class WorkerConfig: + """ + Configuration for Merlin workers. + + This dataclass encapsulates all worker-related configuration options + including queues, machines, batch settings, and launch arguments. + + Attributes: + name: Name of the worker. + args: Command-line arguments for the worker process. + queues: Set of queue names this worker should process. + batch: Batch configuration for this worker. + machines: List of machine names where this worker can run. + nodes: Number of nodes to use (can override batch.nodes). + overlap: Whether this worker can overlap queues with other workers. + env: Environment variables for the worker process. + """ + name: str + args: str = "" + queues: Set[str] = field(default_factory=lambda: {"[merlin]_merlin"}) + batch: BatchConfig = field(default_factory=BatchConfig) + machines: List[str] = field(default_factory=list) + nodes: Optional[Union[int, str]] = None + overlap: bool = False + env: Dict[str, str] = field(default_factory=dict) + + def __post_init__(self): + """Validate configuration after initialization.""" + if not self.name: + raise ValueError("Worker name cannot be empty") + + if not isinstance(self.queues, set): + if isinstance(self.queues, (list, tuple)): + self.queues = set(self.queues) + else: + raise ValueError("queues must be a set, list, or tuple") + + if self.nodes is not None: + if isinstance(self.nodes, str) and self.nodes != "all": + try: + self.nodes = int(self.nodes) + except ValueError: + raise ValueError(f"Invalid nodes value '{self.nodes}'. Must be an integer, 'all', or None.") + + @classmethod + def from_dict(cls, name: str, config_dict: Dict, env: Dict[str, str] = None) -> 'WorkerConfig': + """ + Create a WorkerConfig from a dictionary. + + Args: + name: Name of the worker. + config_dict: Dictionary containing worker configuration. + env: Environment variables dictionary. + + Returns: + WorkerConfig instance with values from the dictionary. + """ + # Extract batch configuration if present + batch_dict = config_dict.get("batch", {}) + batch_config = BatchConfig.from_dict(batch_dict) if batch_dict else BatchConfig() + + # Convert queues to set if needed + queues = config_dict.get("queues", {"[merlin]_merlin"}) + if isinstance(queues, (list, tuple)): + queues = set(queues) + elif not isinstance(queues, set): + queues = {queues} if isinstance(queues, str) else {"[merlin]_merlin"} + + return cls( + name=name, + args=config_dict.get("args", ""), + queues=queues, + batch=batch_config, + machines=config_dict.get("machines", []), + nodes=config_dict.get("nodes"), + overlap=config_dict.get("overlap", False), + env=env or {}, + ) + + def to_dict(self) -> Dict: + """ + Convert WorkerConfig to dictionary for backward compatibility. + + Returns: + Dictionary representation of the configuration. + """ + return { + "name": self.name, + "args": self.args, + "queues": list(self.queues), + "batch": self.batch.to_dict(), + "machines": self.machines, + "nodes": self.nodes, + "overlap": self.overlap, + } + + def get_effective_nodes(self) -> Optional[Union[int, str]]: + """ + Get the effective node count, preferring worker-specific over batch config. + + Returns: + Node count to use, or None if not specified. + """ + return self.nodes if self.nodes is not None else self.batch.nodes + + def get_effective_batch_config(self) -> BatchConfig: + """ + Get the effective batch configuration with worker-specific overrides. + + Returns: + BatchConfig with worker-specific values applied. + """ + if self.nodes is not None: + # Create a copy of batch config with worker's node override + batch_dict = self.batch.to_dict() + batch_dict["nodes"] = self.nodes + return BatchConfig.from_dict(batch_dict) + return self.batch + + def has_machine_restrictions(self) -> bool: + """ + Check if this worker has machine restrictions. + + Returns: + True if machines list is not empty. + """ + return bool(self.machines) + + def add_queue(self, queue_name: str): + """ + Add a queue to this worker's queue set. + + Args: + queue_name: Name of the queue to add. + """ + self.queues.add(queue_name) + + def remove_queue(self, queue_name: str): + """ + Remove a queue from this worker's queue set. + + Args: + queue_name: Name of the queue to remove. + """ + self.queues.discard(queue_name) + + def update_env(self, env_updates: Dict[str, str]): + """ + Update environment variables for this worker. + + Args: + env_updates: Dictionary of environment variable updates. + """ + self.env.update(env_updates) diff --git a/merlin/workers/celery_worker.py b/merlin/workers/celery_worker.py index 272c6b34..f69b25f0 100644 --- a/merlin/workers/celery_worker.py +++ b/merlin/workers/celery_worker.py @@ -23,6 +23,7 @@ from merlin.db_scripts.merlin_db import MerlinDatabase from merlin.exceptions import MerlinWorkerLaunchError from merlin.study.batch import BatchManager +from merlin.study.configurations import WorkerConfig from merlin.utils import check_machines from merlin.workers.worker import MerlinWorker @@ -39,15 +40,8 @@ class CeleryWorker(MerlinWorker): jobs from specific task queues. Attributes: - name (str): The name of the worker. - config (dict): Configuration settings for the worker. - env (dict): Environment variables used by the worker process. - args (str): Additional CLI arguments passed to Celery. - queues (List[str]): Queues the worker listens to. - batch (dict): Optional batch submission settings. - batch_manager (BatchManager): Manager for batch-related operations. - machines (List[str]): List of hostnames the worker is allowed to run on. - overlap (bool): Whether this worker can overlap queues with others. + worker_config (study.configurations.WorkerConfig): The worker configuration object. + batch_manager (study.batch.BatchManager): A manager object for batch settings. Methods: _verify_args: Validate and adjust CLI args based on worker setup. @@ -57,42 +51,22 @@ class CeleryWorker(MerlinWorker): get_metadata: Return identifying metadata about the worker. """ - def __init__( - self, - name: str, - config: Dict, - env: Dict[str, str] = None, - overlap: bool = False, - ): + def __init__(self, worker_config: WorkerConfig): """ Constructor for Celery workers. Sets up attributes used throughout this worker object and saves this worker to the database. Args: - name: The name of the worker. - config: A dictionary containing optional configuration settings for this worker including:\n - - `args`: A string of arguments to pass to the launch command - - `queues`: A set of task queues for this worker to watch - - `batch`: A dictionary of specific batch configuration settings to use for this worker - - `nodes`: The number of nodes to launch this worker on - - `machines`: A list of machines that this worker is allowed to run on - env: A dictionary of environment variables set by the user. - overlap: If True multiple workers can pull tasks from overlapping queues. + worker_config (study.configurations.WorkerConfig): The worker configuration object. """ - super().__init__(name, config, env) - self.args = self.config.get("args", "") - self.queues = self.config.get("queues", {"[merlin]_merlin"}) - self.batch = self.config.get("batch", {}) - self.machines = self.config.get("machines", []) - self.overlap = overlap - - # Initialize BatchManager for this worker - self.batch_manager = BatchManager(self.batch) + super().__init__(worker_config) + # TODO might want to move the below line to the base class? With other task servers + # we need to see what would be important to store and maybe refactor logical worker entries # Add this worker to the database merlin_db = MerlinDatabase() - merlin_db.create("logical_worker", self.name, self.queues) + merlin_db.create("logical_worker", self.worker_config.name, self.worker_config.queues) def _verify_args(self, disable_logs: bool = False) -> str: """ @@ -106,19 +80,19 @@ def _verify_args(self, disable_logs: bool = False) -> str: """ # Use BatchManager to check for parallel configuration if self.batch_manager.is_parallel(): - if "--concurrency" not in self.args: + if "--concurrency" not in self.worker_config.args: LOG.warning("Missing --concurrency in worker args for parallel tasks.") - if "--prefetch-multiplier" not in self.args: + if "--prefetch-multiplier" not in self.worker_config.args: LOG.warning("Missing --prefetch-multiplier in worker args for parallel tasks.") - if "fair" not in self.args: + if "fair" not in self.worker_config.args: LOG.warning("Missing -O fair in worker args for parallel tasks.") - if "-n" not in self.args: - nhash = time.strftime("%Y%m%d-%H%M%S") if self.overlap else "" - self.args += f" -n {self.name}{nhash}.%%h" + if "-n" not in self.worker_config.args: + nhash = time.strftime("%Y%m%d-%H%M%S") if self.worker_config.overlap else "" + self.worker_config.args += f" -n {self.worker_config.name}{nhash}.%%h" - if not disable_logs and "-l" not in self.args: - self.args += f" -l {logging.getLevelName(LOG.getEffectiveLevel())}" + if not disable_logs and "-l" not in self.worker_config.args: + self.worker_config.args += f" -l {logging.getLevelName(LOG.getEffectiveLevel())}" def get_launch_command(self, override_args: str = "", disable_logs: bool = False) -> str: """ @@ -133,13 +107,13 @@ def get_launch_command(self, override_args: str = "", disable_logs: bool = False """ # Override existing arguments if necessary if override_args != "": - self.args = override_args + self.worker_config.args = override_args # Validate args self._verify_args(disable_logs=disable_logs) # Construct the base celery command - celery_cmd = f"celery -A merlin worker {self.args} -Q {','.join(self.queues)}" + celery_cmd = f"celery -A merlin worker {self.worker_config.args} -Q {','.join(self.worker_config.queues)}" # Use BatchManager to create the launch command launch_cmd = self.batch_manager.create_worker_launch_command(celery_cmd) @@ -155,27 +129,24 @@ def should_launch(self) -> bool: Returns: True if the worker should be launched, False otherwise. """ - machines = self.config.get("machines", None) - queues = self.config.get("queues", ["[merlin]_merlin"]) - - if machines: - if not check_machines(machines): + if self.worker_config.machines: + if not check_machines(self.worker_config.machines): LOG.error( - f"The following machines were provided for worker '{self.name}': {machines}. " + f"The following machines were provided for worker '{self.worker_config.name}': {self.worker_config.machines}. " f"However, the current machine '{socket.gethostname()}' is not in this list." ) return False - output_path = self.env.get("OUTPUT_PATH") + output_path = self.worker_config.env.get("OUTPUT_PATH") if output_path and not os.path.exists(output_path): LOG.error(f"{output_path} not accessible on host {socket.gethostname()}") return False - if not self.overlap: + if not self.worker_config.overlap: from merlin.study.celeryadapter import get_running_queues # pylint: disable=import-outside-toplevel running_queues = get_running_queues("merlin") - for queue in queues: + for queue in self.worker_config.queues: if queue in running_queues: LOG.warning(f"Queue {queue} is already being processed by another worker.") return False @@ -196,31 +167,11 @@ def start(self, override_args: str = "", disable_logs: bool = False): if self.should_launch(): launch_cmd = self.get_launch_command(override_args=override_args, disable_logs=disable_logs) try: - subprocess.Popen(launch_cmd, env=self.env, shell=True, universal_newlines=True) # pylint: disable=R1732 - LOG.debug(f"Launched worker '{self.name}' with command: {launch_cmd}.") + subprocess.Popen(launch_cmd, env=self.worker_config.env, shell=True, universal_newlines=True) # pylint: disable=R1732 + LOG.debug(f"Launched worker '{self.worker_config.name}' with command: {launch_cmd}.") except Exception as e: # pylint: disable=C0103 LOG.error(f"Cannot start celery workers, {e}") raise MerlinWorkerLaunchError from e - - def get_metadata(self) -> Dict: - """ - Return metadata about this worker instance. - - Returns: - A dictionary containing key details about this worker. - """ - metadata = { - "name": self.name, - "queues": self.queues, - "args": self.args, - "machines": self.machines, - "batch": self.batch, - } - - # Add batch manager information - metadata["batch_info"] = self.batch_manager.get_batch_info() - - return metadata def update_batch_config(self, new_batch_config: Dict): """ @@ -229,5 +180,4 @@ def update_batch_config(self, new_batch_config: Dict): Args: new_batch_config: New batch configuration to apply. """ - self.batch.update(new_batch_config) self.batch_manager.update_config(new_batch_config) diff --git a/merlin/workers/handlers/celery_handler.py b/merlin/workers/handlers/celery_handler.py index 2fd7b494..fd405f2c 100644 --- a/merlin/workers/handlers/celery_handler.py +++ b/merlin/workers/handlers/celery_handler.py @@ -55,12 +55,13 @@ def start_workers(self, workers: List[CeleryWorker], **kwargs): # Launch the workers or echo out the command that will be used to launch the workers for worker in workers: + worker_name = worker.worker_config.name if echo_only: - LOG.debug(f"Not launching worker '{worker.name}', just echoing command.") + LOG.debug(f"Not launching worker '{worker_name}', just echoing command.") launch_cmd = worker.get_launch_command(override_args=override_args, disable_logs=disable_logs) print(launch_cmd) else: - LOG.debug(f"Launching worker '{worker.name}'.") + LOG.debug(f"Launching worker '{worker_name}'.") worker.start(override_args=override_args, disable_logs=disable_logs) def stop_workers(self): diff --git a/merlin/workers/worker.py b/merlin/workers/worker.py index 8bdabe10..91515745 100644 --- a/merlin/workers/worker.py +++ b/merlin/workers/worker.py @@ -20,6 +20,9 @@ from abc import ABC, abstractmethod from typing import Dict +from merlin.study.batch import BatchManager +from merlin.study.configurations import WorkerConfig + class MerlinWorker(ABC): """ @@ -29,9 +32,8 @@ class MerlinWorker(ABC): an individual worker based on its configuration. Attributes: - name: The name of the worker. - config: The dictionary configuration for the worker. - env: A dictionary representing the full environment for the current context. + worker_config (study.configurations.WorkerConfig): The worker configuration object. + batch_manager (study.batch.BatchManager): A manager object for batch settings. Methods: get_launch_command: Build the shell command to launch the worker. @@ -39,19 +41,17 @@ class MerlinWorker(ABC): get_metadata: Return identifying metadata about the worker. """ - def __init__(self, name: str, config: Dict, env: Dict[str, str] = None): + def __init__(self, worker_config: WorkerConfig): """ Initialize a `MerlinWorker` instance. Args: - name: The name of the worker. - config: A dictionary containing the worker configuration. - env: Optional dictionary of environment variables to use; if not provided, - a copy of the current OS environment is used. + worker_config (study.configurations.WorkerConfig): The worker configuration object. """ - self.name = name - self.config = config - self.env = env or os.environ.copy() + self.worker_config = worker_config + + # Initialize BatchManager for this worker + self.batch_manager: BatchManager = BatchManager(self.worker_config.get_effective_batch_config()) @abstractmethod def get_launch_command(self, override_args: str = "") -> str: @@ -71,11 +71,11 @@ def start(self): Launch this worker. """ - @abstractmethod def get_metadata(self) -> Dict: """ - Return a dictionary of metadata about this worker (for logging/debugging). + Return metadata about this worker instance. Returns: - A metadata dictionary (e.g., name, queues, machines). + A dictionary containing key details about this worker. """ + return self.worker_config.to_dict() From e48ef846ab7129dbb6f459c7f7243ed11b39bff4 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 11 Sep 2025 11:39:31 -0700 Subject: [PATCH 3/8] fix broken unit tests --- merlin/study/configurations.py | 14 +- merlin/workers/celery_worker.py | 1 - .../workers/handlers/test_celery_handler.py | 15 +- .../workers/handlers/test_worker_handler.py | 5 +- tests/unit/workers/test_celery_worker.py | 184 ++++++------------ tests/unit/workers/test_worker.py | 57 +++--- 6 files changed, 114 insertions(+), 162 deletions(-) diff --git a/merlin/study/configurations.py b/merlin/study/configurations.py index 49d196a8..13c5eef7 100644 --- a/merlin/study/configurations.py +++ b/merlin/study/configurations.py @@ -11,6 +11,7 @@ dictionary-based configurations. """ +import os from dataclasses import dataclass, field from typing import Dict, List, Optional, Set, Union @@ -181,16 +182,18 @@ def __post_init__(self): self.nodes = int(self.nodes) except ValueError: raise ValueError(f"Invalid nodes value '{self.nodes}'. Must be an integer, 'all', or None.") + + if not self.env: + self.env = os.environ.copy() @classmethod - def from_dict(cls, name: str, config_dict: Dict, env: Dict[str, str] = None) -> 'WorkerConfig': + # def from_dict(cls, name: str, config_dict: Dict, env: Dict[str, str] = None) -> 'WorkerConfig': + def from_dict(cls, config_dict: Dict) -> 'WorkerConfig': """ Create a WorkerConfig from a dictionary. Args: - name: Name of the worker. config_dict: Dictionary containing worker configuration. - env: Environment variables dictionary. Returns: WorkerConfig instance with values from the dictionary. @@ -207,14 +210,14 @@ def from_dict(cls, name: str, config_dict: Dict, env: Dict[str, str] = None) -> queues = {queues} if isinstance(queues, str) else {"[merlin]_merlin"} return cls( - name=name, + name=config_dict["name"], # Not using `get` since this should fail if 'name' is missing args=config_dict.get("args", ""), queues=queues, batch=batch_config, machines=config_dict.get("machines", []), nodes=config_dict.get("nodes"), overlap=config_dict.get("overlap", False), - env=env or {}, + env=config_dict.get("env", {}), ) def to_dict(self) -> Dict: @@ -232,6 +235,7 @@ def to_dict(self) -> Dict: "machines": self.machines, "nodes": self.nodes, "overlap": self.overlap, + "env": self.env, } def get_effective_nodes(self) -> Optional[Union[int, str]]: diff --git a/merlin/workers/celery_worker.py b/merlin/workers/celery_worker.py index f69b25f0..d631bf95 100644 --- a/merlin/workers/celery_worker.py +++ b/merlin/workers/celery_worker.py @@ -22,7 +22,6 @@ from merlin.db_scripts.merlin_db import MerlinDatabase from merlin.exceptions import MerlinWorkerLaunchError -from merlin.study.batch import BatchManager from merlin.study.configurations import WorkerConfig from merlin.utils import check_machines from merlin.workers.worker import MerlinWorker diff --git a/tests/unit/workers/handlers/test_celery_handler.py b/tests/unit/workers/handlers/test_celery_handler.py index 62340f9d..8dcdf3ba 100644 --- a/tests/unit/workers/handlers/test_celery_handler.py +++ b/tests/unit/workers/handlers/test_celery_handler.py @@ -14,15 +14,16 @@ import pytest from pytest_mock import MockerFixture +from merlin.study.configurations import WorkerConfig from merlin.workers.celery_worker import CeleryWorker from merlin.workers.handlers import CeleryWorkerHandler class DummyCeleryWorker(CeleryWorker): - def __init__(self, name: str, config: Dict = None, env: Dict = None): - super().__init__(name, config or {}, env or {}) + def __init__(self, worker_config: WorkerConfig): + super().__init__(worker_config) self.launched_with = None - self.launch_command = f"celery --worker-name={name}" + self.launch_command = f"celery --worker-name={self.worker_config.name}" def get_launch_command(self, override_args: str = "", disable_logs: bool = False) -> str: parts = [self.launch_command] @@ -34,7 +35,7 @@ def get_launch_command(self, override_args: str = "", disable_logs: bool = False def start(self, override_args: str = "", disable_logs: bool = False): self.launched_with = (override_args, disable_logs) - return f"Launching {self.name} with {override_args} and logs {'off' if disable_logs else 'on'}" + return f"Launching {self.worker_config.name} with {override_args} and logs {'off' if disable_logs else 'on'}" class TestCeleryWorkerHandler: @@ -52,9 +53,11 @@ def mock_db(self, mocker: MockerFixture) -> MagicMock: @pytest.fixture def workers(self, mock_db: MagicMock) -> List[DummyCeleryWorker]: + worker_1_config = WorkerConfig(name="worker1") + worker_2_config = WorkerConfig(name="worker2") return [ - DummyCeleryWorker("worker1"), - DummyCeleryWorker("worker2"), + DummyCeleryWorker(worker_1_config), + DummyCeleryWorker(worker_2_config), ] def test_echo_only_prints_commands( diff --git a/tests/unit/workers/handlers/test_worker_handler.py b/tests/unit/workers/handlers/test_worker_handler.py index b89b2bc6..3e5733b8 100644 --- a/tests/unit/workers/handlers/test_worker_handler.py +++ b/tests/unit/workers/handlers/test_worker_handler.py @@ -12,6 +12,7 @@ import pytest +from merlin.study.configurations import WorkerConfig from merlin.workers.handlers.worker_handler import MerlinWorkerHandler from merlin.workers.worker import MerlinWorker @@ -74,7 +75,7 @@ def test_launch_workers_calls_worker_launch(): Test that `start_workers` calls each worker's `start` method. """ handler = DummyWorkerHandler() - workers = [DummyWorker("w1", {}, {}), DummyWorker("w2", {}, {})] + workers = [DummyWorker(WorkerConfig(name="w1")), DummyWorker(WorkerConfig(name="w2"))] result = handler.start_workers(workers) @@ -98,7 +99,7 @@ def test_query_workers_returns_summary(): Test that `query_workers` returns a valid summary of current worker state. """ handler = DummyWorkerHandler() - workers = [DummyWorker("a", {}, {}), DummyWorker("b", {}, {})] + workers = [DummyWorker(WorkerConfig(name="a")), DummyWorker(WorkerConfig(name="b"))] handler.start_workers(workers) summary = handler.query_workers() diff --git a/tests/unit/workers/test_celery_worker.py b/tests/unit/workers/test_celery_worker.py index a2ba4504..c94d03be 100644 --- a/tests/unit/workers/test_celery_worker.py +++ b/tests/unit/workers/test_celery_worker.py @@ -15,6 +15,7 @@ from pytest_mock import MockerFixture from merlin.exceptions import MerlinWorkerLaunchError +from merlin.study.configurations import WorkerConfig from merlin.workers import CeleryWorker from tests.fixture_types import FixtureCallable, FixtureDict, FixtureStr @@ -35,33 +36,33 @@ def workers_testing_dir(create_testing_dir: FixtureCallable, temp_output_dir: Fi @pytest.fixture -def basic_config() -> FixtureDict[str, Any]: +def worker_config() -> WorkerConfig: """ - Fixture that provides a basic CeleryWorker configuration dictionary. + Fixture that provides a basic WorkerConfig instance. Returns: - A dictionary representing a minimal valid CeleryWorker config. + A WorkerConfig object with default settings. """ - return { - "args": "", - "queues": ["queue1", "queue2"], - "batch": {"nodes": 1}, - "machines": [], - } + return WorkerConfig( + name="test_worker", + queues={"queue1", "queue2"}, + nodes=1, + env={"OUTPUT_PATH": workers_testing_dir}, + ) @pytest.fixture -def dummy_env(workers_testing_dir: FixtureStr) -> FixtureDict[str, str]: +def celery_worker(worker_config: WorkerConfig) -> CeleryWorker: """ - Fixture that provides a mock environment dictionary with OUTPUT_PATH set. + Fixture that provides a CeleryWorker instance. Args: - workers_testing_dir: The path to the temporary testing directory for workers tests. + worker_config: A WorkerConfig object to initialize the CeleryWorker. Returns: - A dictionary simulating environment variables, including OUTPUT_PATH. + A CeleryWorker object for testing. """ - return {"OUTPUT_PATH": workers_testing_dir} + return CeleryWorker(worker_config) @pytest.fixture @@ -81,41 +82,26 @@ def mock_db(mocker: MockerFixture) -> MagicMock: return mocker.patch("merlin.workers.celery_worker.MerlinDatabase") -def test_constructor_sets_fields_and_calls_db_create( - basic_config: FixtureDict[str, Any], - dummy_env: FixtureDict[str, str], - mock_db: MagicMock, -): +def test_constructor_sets_fields_and_calls_db_create(worker_config: WorkerConfig, mock_db: MagicMock): """ Test that CeleryWorker constructor sets all fields correctly and triggers database creation. This test verifies that: - - The worker fields (name, args, queues, batch, machines, overlap) are set from config. + - The worker configuration is stored in the `worker_config` attribute. - The MerlinDatabase.create method is called with the correct arguments. Args: - basic_config: A minimal configuration dictionary for the worker. - dummy_env: A dictionary simulating the environment variables. + worker_config: A WorkerConfig instance for initializing the CeleryWorker. mock_db: A mocked MerlinDatabase to prevent real database interaction. """ - worker = CeleryWorker("worker1", basic_config, dummy_env, overlap=True) + worker = CeleryWorker(worker_config) - assert worker.name == "worker1" - assert worker.args == "" - assert worker.queues == ["queue1", "queue2"] - assert worker.batch == {"nodes": 1} - assert worker.machines == [] - assert worker.overlap is True + assert worker.worker_config == worker_config - mock_db.return_value.create.assert_called_once_with("logical_worker", "worker1", ["queue1", "queue2"]) + mock_db.return_value.create.assert_called_once_with("logical_worker", "test_worker", {"queue1", "queue2"}) -def test_verify_args_adds_name_and_logging_flags( - mocker: MockerFixture, - basic_config: FixtureDict[str, Any], - dummy_env: FixtureDict[str, str], - mock_db: MagicMock, -): +def test_verify_args_adds_name_and_logging_flags(mocker: MockerFixture, celery_worker: CeleryWorker, mock_db: MagicMock): """ Test that `_verify_args()` appends required flags to the Celery args string. @@ -130,27 +116,21 @@ def test_verify_args_adds_name_and_logging_flags( Args: mocker: Pytest mocker fixture. - basic_config: Fixture providing a basic CeleryWorker configuration. - dummy_env: Fixture providing a mock environment dictionary. + celery_worker: A CeleryWorker object for testing. mock_db: Mocked MerlinDatabase to avoid real database writes. """ - mocker.patch("merlin.workers.celery_worker.batch_check_parallel", return_value=True) + celery_worker.batch_manager = MagicMock() + celery_worker.batch_manager.is_parallel.return_value = True mock_logger = mocker.patch("merlin.workers.celery_worker.LOG") - worker = CeleryWorker("w1", basic_config, dummy_env) - worker._verify_args() + celery_worker._verify_args() - assert "-n w1" in worker.args - assert "-l" in worker.args + assert "-n test_worker" in celery_worker.worker_config.args + assert "-l" in celery_worker.worker_config.args assert mock_logger.warning.called -def test_get_launch_command_returns_expanded_command( - mocker: MockerFixture, - basic_config: FixtureDict[str, Any], - dummy_env: FixtureDict[str, str], - mock_db: MagicMock, -): +def test_get_launch_command_returns_expanded_command(mocker: MockerFixture, celery_worker: CeleryWorker, mock_db: MagicMock): """ Test that `get_launch_command()` constructs a valid Celery command. @@ -165,25 +145,19 @@ def test_get_launch_command_returns_expanded_command( Args: mocker: Pytest mocker fixture. - basic_config: Fixture providing a basic CeleryWorker configuration. - dummy_env: Fixture providing a mock environment dictionary. + celery_worker: A CeleryWorker object for testing. mock_db: Mocked MerlinDatabase to avoid real database writes. """ - mocker.patch("merlin.workers.celery_worker.batch_worker_launch", return_value="celery -A ...") - worker = CeleryWorker("w2", basic_config, dummy_env) + celery_worker.batch_manager = MagicMock() + celery_worker.batch_manager.create_worker_launch_command.return_value = "celery -A ..." - cmd = worker.get_launch_command("--override", disable_logs=True) + cmd = celery_worker.get_launch_command("--override", disable_logs=True) assert isinstance(cmd, str) assert "celery" in cmd -def test_should_launch_rejects_if_machine_check_fails( - mocker: MockerFixture, - basic_config: FixtureDict[str, Any], - dummy_env: FixtureDict[str, str], - mock_db: MagicMock, -): +def test_should_launch_rejects_if_machine_check_fails(mocker: MockerFixture, worker_config: WorkerConfig, mock_db: MagicMock): """ Test that `should_launch` returns False if the machine check fails. @@ -197,25 +171,19 @@ def test_should_launch_rejects_if_machine_check_fails( Args: mocker: Pytest mocker fixture. - basic_config: Configuration dictionary containing the list of valid machines. - dummy_env: Environment variable dictionary (unused in this test). + worker_config: A WorkerConfig instance for initializing the CeleryWorker. mock_db: Mocked MerlinDatabase to avoid real database writes. """ - basic_config["machines"] = ["host1"] + worker_config.machines = ["host1"] mocker.patch("merlin.workers.celery_worker.check_machines", return_value=False) - worker = CeleryWorker("w3", basic_config, dummy_env) + worker = CeleryWorker(worker_config) result = worker.should_launch() assert result is False -def test_should_launch_rejects_if_output_path_missing( - mocker: MockerFixture, - basic_config: FixtureDict[str, Any], - dummy_env: FixtureDict[str, str], - mock_db: MagicMock, -): +def test_should_launch_rejects_if_output_path_missing(mocker: MockerFixture, worker_config: WorkerConfig, mock_db: MagicMock): """ Test that `should_launch` returns False if the output path does not exist. @@ -228,27 +196,21 @@ def test_should_launch_rejects_if_output_path_missing( Args: mocker: Pytest mocker fixture. - basic_config: Configuration dictionary including machine constraints. - dummy_env: Environment variable dictionary containing an invalid output path. + worker_config: A WorkerConfig instance for initializing the CeleryWorker. mock_db: Mocked MerlinDatabase to avoid real database writes. """ - basic_config["machines"] = ["host1"] - dummy_env["OUTPUT_PATH"] = "/nonexistent" + worker_config.machines = ["host1"] + worker_config.env = {"OUTPUT_PATH": "/nonexistent"} mocker.patch("merlin.workers.celery_worker.check_machines", return_value=True) mocker.patch("os.path.exists", return_value=False) - worker = CeleryWorker("w4", basic_config, dummy_env) + worker = CeleryWorker(worker_config) result = worker.should_launch() assert result is False -def test_should_launch_rejects_due_to_running_queues( - mocker: MockerFixture, - basic_config: FixtureDict[str, Any], - dummy_env: FixtureDict[str, str], - mock_db: MagicMock, -): +def test_should_launch_rejects_due_to_running_queues(mocker: MockerFixture, celery_worker: CeleryWorker, mock_db: MagicMock): """ Test that `should_launch` returns False when a conflicting queue is already running. @@ -262,24 +224,17 @@ def test_should_launch_rejects_due_to_running_queues( Args: mocker: Pytest mocker fixture. - basic_config: Fixture providing base worker config. - dummy_env: Fixture providing environment variables. + celery_worker: A CeleryWorker object for testing. mock_db: Fixture for the Merlin database mock. """ mocker.patch("merlin.study.celeryadapter.get_running_queues", return_value=["queue1"]) - worker = CeleryWorker("w5", basic_config, dummy_env) - result = worker.should_launch() + result = celery_worker.should_launch() assert result is False -def test_launch_worker_runs_if_should_launch( - mocker: MockerFixture, - basic_config: FixtureDict[str, Any], - dummy_env: FixtureDict[str, str], - mock_db: MagicMock, -): +def test_launch_worker_runs_if_should_launch(mocker: MockerFixture, celery_worker: CeleryWorker, mock_db: MagicMock): """ Test that `start` executes the launch command if `should_launch` returns True. @@ -294,28 +249,21 @@ def test_launch_worker_runs_if_should_launch( Args: mocker: Pytest mocker fixture. - basic_config: Fixture providing base worker config. - dummy_env: Fixture providing environment variables. + celery_worker: A CeleryWorker object for testing. mock_db: Fixture for the Merlin database mock. """ - mocker.patch.object(CeleryWorker, "should_launch", return_value=True) - mocker.patch.object(CeleryWorker, "get_launch_command", return_value="echo hello") + mocker.patch.object(celery_worker, "should_launch", return_value=True) + mocker.patch.object(celery_worker, "get_launch_command", return_value="echo hello") mock_popen = mocker.patch("merlin.workers.celery_worker.subprocess.Popen") mock_logger = mocker.patch("merlin.workers.celery_worker.LOG") - worker = CeleryWorker("w6", basic_config, dummy_env) - worker.start() + celery_worker.start() mock_popen.assert_called_once() assert mock_logger.debug.called -def test_launch_worker_raises_if_popen_fails( - mocker: MockerFixture, - basic_config: FixtureDict[str, Any], - dummy_env: FixtureDict[str, str], - mock_db: MagicMock, -): +def test_launch_worker_raises_if_popen_fails(mocker: MockerFixture, celery_worker: CeleryWorker, mock_db: MagicMock): """ Test that `start` raises `MerlinWorkerLaunchError` when `subprocess.Popen` fails. @@ -328,26 +276,19 @@ def test_launch_worker_raises_if_popen_fails( Args: mocker: Pytest mocker fixture. - basic_config: Basic configuration dictionary fixture. - dummy_env: Dummy environment dictionary fixture. + celery_worker: A CeleryWorker object for testing. mock_db: Mocked MerlinDatabase object. """ - mocker.patch.object(CeleryWorker, "should_launch", return_value=True) - mocker.patch.object(CeleryWorker, "get_launch_command", return_value="fail") + mocker.patch.object(celery_worker, "should_launch", return_value=True) + mocker.patch.object(celery_worker, "get_launch_command", return_value="fail") mocker.patch("merlin.workers.celery_worker.subprocess.Popen", side_effect=OSError("boom")) mocker.patch("merlin.workers.celery_worker.LOG") - worker = CeleryWorker("w7", basic_config, dummy_env) - with pytest.raises(MerlinWorkerLaunchError): - worker.start() + celery_worker.start() -def test_get_metadata_returns_expected_dict( - basic_config: FixtureDict[str, Any], - dummy_env: FixtureDict[str, str], - mock_db: MagicMock, -): +def test_get_metadata_returns_expected_dict(celery_worker: CeleryWorker, mock_db: MagicMock): """ Test that `get_metadata` returns the expected dictionary with worker configuration. @@ -359,16 +300,9 @@ def test_get_metadata_returns_expected_dict( instantiation. Args: - basic_config: Basic configuration dictionary fixture. - dummy_env: Dummy environment dictionary fixture. + celery_worker: A CeleryWorker object for testing. mock_db: Mocked MerlinDatabase object. """ - worker = CeleryWorker("meta_worker", basic_config, dummy_env) - - metadata = worker.get_metadata() + metadata = celery_worker.get_metadata() - assert metadata["name"] == "meta_worker" - assert metadata["queues"] == ["queue1", "queue2"] - assert metadata["args"] == "" - assert metadata["machines"] == [] - assert metadata["batch"] == {"nodes": 1} + assert WorkerConfig.from_dict(metadata) == celery_worker.worker_config diff --git a/tests/unit/workers/test_worker.py b/tests/unit/workers/test_worker.py index 20c75e50..52c77137 100644 --- a/tests/unit/workers/test_worker.py +++ b/tests/unit/workers/test_worker.py @@ -12,33 +12,33 @@ from pytest_mock import MockerFixture +from merlin.study.configurations import WorkerConfig from merlin.workers.worker import MerlinWorker class DummyMerlinWorker(MerlinWorker): def get_launch_command(self, override_args: str = "") -> str: - return f"run_worker --name {self.name} {override_args}" + return f"run_worker --name {self.worker_config.name} {override_args}" def start(self): - return f"Launching {self.name}" - - def get_metadata(self) -> dict: - return {"name": self.name, "config": self.config} + return f"Launching {self.worker_config.name}" def test_init_sets_attributes(): """ Test that the constructor sets name, config, and env correctly. """ - name = "test_worker" - config = {"foo": "bar"} - env = {"TEST_ENV": "123"} + worker_config = WorkerConfig( + name="test_worker", + queues={"queue1", "queue2"}, + env={"TEST_ENV": "123"} + ) - worker = DummyMerlinWorker(name, config, env) + worker = DummyMerlinWorker(worker_config) - assert worker.name == name - assert worker.config == config - assert worker.env == env + assert worker.worker_config.name == worker_config.name + assert worker.worker_config.queues == worker_config.queues + assert worker.worker_config.env == worker_config.env def test_init_uses_os_environ_when_env_none(mocker: MockerFixture): @@ -51,39 +51,50 @@ def test_init_uses_os_environ_when_env_none(mocker: MockerFixture): mock_environ = {"MY_VAR": "xyz"} mocker.patch.dict("os.environ", mock_environ, clear=True) - worker = DummyMerlinWorker("w", {}, None) + worker_config = WorkerConfig(name="test_worker") + + worker = DummyMerlinWorker(worker_config) - assert "MY_VAR" in worker.env - assert worker.env["MY_VAR"] == "xyz" - assert worker.env is not os.environ # ensure it's a copy + assert "MY_VAR" in worker.worker_config.env + assert worker.worker_config.env["MY_VAR"] == "xyz" + assert worker.worker_config.env is not os.environ # ensure it's a copy def test_get_launch_command_returns_expected_string(): """ Test that get_launch_command builds the correct shell string. """ - worker = DummyMerlinWorker("dummy", {}, {}) + worker_name = "test_worker" + worker_config = WorkerConfig(name=worker_name) + worker = DummyMerlinWorker(worker_config) cmd = worker.get_launch_command("--debug") assert "--debug" in cmd - assert "dummy" in cmd + assert worker_name in cmd def test_launch_worker_returns_expected_string(): """ Test that start returns a string indicating launch. """ - worker = DummyMerlinWorker("dummy", {}, {}) + worker_name = "test_worker" + worker_config = WorkerConfig(name=worker_name) + worker = DummyMerlinWorker(worker_config) result = worker.start() - assert result == "Launching dummy" + assert result == f"Launching {worker_name}" def test_get_metadata_returns_expected_dict(): """ Test that get_metadata returns the correct metadata dictionary. """ - config = {"foo": "bar"} - worker = DummyMerlinWorker("dummy", config, {}) + worker_config = WorkerConfig( + name="test_worker", + args="-l INFO --concurrency 3", + queues={"queue1"}, + nodes=10, + ) + worker = DummyMerlinWorker(worker_config) meta = worker.get_metadata() - assert meta == {"name": "dummy", "config": config} + assert WorkerConfig.from_dict(meta) == worker_config From fd1a566f993f4f2e36bb88a62d702ef5cfe92e33 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 11 Sep 2025 14:01:40 -0700 Subject: [PATCH 4/8] fix broken integration tests --- merlin/study/batch.py | 32 ++++++++++++++++++++++---------- tests/integration/definitions.py | 8 +++++--- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/merlin/study/batch.py b/merlin/study/batch.py index f5aaa9be..583d1a3f 100644 --- a/merlin/study/batch.py +++ b/merlin/study/batch.py @@ -197,6 +197,10 @@ def _build_scheduler_legend(self, nodes: int = None): """ if nodes is None: nodes = self._get_node_count(default=1) + + # print(f"type(walltime): {type(self.batch_config.walltime)}") + # print(f"walltime: {self.batch_config.walltime}") + # print(f"convert timestring: {convert_timestring(self.batch_config.walltime, format_method='FSD')}") self.scheduler_legend = { "flux": { @@ -230,26 +234,29 @@ def _build_scheduler_legend(self, nodes: int = None): }, } - def _get_flux_launch_command(self) -> str: + def _get_flux_launch_command(self, existing_launch_cmd: str) -> str: """ Build the Flux-specific launch command. + + Args: + existing_launch_cmd: The existing launch command or an empty string. Returns: Flux launch command string. """ - default_flux_exec = "flux exec" if self.batch_config.worker_launch else f"{self._flux_exe} exec" + default_flux_exec = "flux exec" if existing_launch_cmd else f"{self._flux_exe} exec" flux_exec = "" if self.batch_config.flux_exec_workers: flux_exec = self.batch_config.flux_exec if self.batch_config.flux_exec else default_flux_exec - if self.batch_config.worker_launch and "flux" not in self.batch_config.worker_launch: + if existing_launch_cmd and "flux" not in existing_launch_cmd: launch = ( - f"{self.batch_config.worker_launch} {self._flux_exe}" + f"{existing_launch_cmd} {self._flux_exe}" f" start {self.batch_config.flux_start_opts} {flux_exec} `which {self.batch_config.shell}` -c" ) else: - launch = f"{self.batch_config.worker_launch} {flux_exec} `which {self.batch_config.shell}` -c" + launch = f"{existing_launch_cmd} {flux_exec} `which {self.batch_config.shell}` -c" return launch @@ -289,13 +296,17 @@ def _construct_launch_command(self, nodes: int) -> str: LOG.debug(e) launch_command = "" + print(f"launch_command before: {launch_command}") # If LSF is the workload manager we stop here if workload_manager != "lsf" and launch_command: # Add bank, queue, and walltime as necessary for key in ("bank", "queue", "walltime"): config_value = getattr(self.batch_config, key) + print(f"workload_manager: {workload_manager}") + print(f"config_value for {key}: {config_value}") if config_value: try: + print(f"scheduler_legend entry: {self.scheduler_legend[workload_manager][key]}") launch_command += self.scheduler_legend[workload_manager][key] except KeyError as e: LOG.error(e) @@ -304,6 +315,7 @@ def _construct_launch_command(self, nodes: int) -> str: if workload_manager == "pbs": launch_command += " --" + print(f"launch_command after: {launch_command}") return launch_command def create_worker_launch_command(self, command: str, nodes: Union[str, int] = None) -> str: @@ -352,12 +364,12 @@ def create_worker_launch_command(self, command: str, nodes: Union[str, int] = No LOG.debug(f"launch command: {launch_command}") - # Construct final worker command + # Add Flux-specific launch settings if self.batch_config.type == "flux": - launch = self._get_flux_launch_command() - worker_cmd = f'{launch} "{command}"' - else: - worker_cmd = f"{launch_command} {command}" + launch_command = self._get_flux_launch_command(launch_command) + + # Construct final worker command + worker_cmd = f"{launch_command} {command}" return worker_cmd diff --git a/tests/integration/definitions.py b/tests/integration/definitions.py index 7d326484..6ebeb92c 100644 --- a/tests/integration/definitions.py +++ b/tests/integration/definitions.py @@ -32,7 +32,8 @@ StepFileHasRegex, ) -from merlin.study.batch import check_for_scheduler +# from merlin.study.batch import check_for_scheduler +from merlin.study.batch import BatchManager from merlin.utils import get_flux_alloc, get_flux_cmd @@ -52,13 +53,14 @@ def get_worker_by_cmd(cmd: str, default: str) -> str: workers_cmd = default fake_cmds_path = f"tests/integration/fake_commands/{cmd}_fake_command" bogus_cmd = f"""PATH="{fake_cmds_path}:$PATH";{default}""" - scheduler_legend = {"flux": {"check cmd": ["flux", "resource", "list"], "expected check output": b"Nodes"}} + # scheduler_legend = {"flux": {"check cmd": ["flux", "resource", "list"], "expected check output": b"Nodes"}} + batch_manager = BatchManager() # Use bogus flux/qsub to test if no flux/qsub is present if not shutil.which(cmd): workers_cmd = bogus_cmd # Use bogus flux if flux is present but slurm is the main scheduler - elif cmd == "flux" and not check_for_scheduler(cmd, scheduler_legend): + elif cmd == "flux" and not batch_manager._check_scheduler(cmd): workers_cmd = bogus_cmd return workers_cmd From cf5b773b1002fb85eab2f37a018d949325697fcd Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 11 Sep 2025 14:20:49 -0700 Subject: [PATCH 5/8] update CHANGELOG --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 421e3a5d..5399e07f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,11 +17,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `MerlinWorkerHandler`: base class for managing launching, stopping, and querying multiple workers - `CeleryWorkerHandler`: implementation of `MerlinWorkerHandler` specifically for manager Celery workers - `WorkerHandlerFactory`: to help determine which task server handler to use +- New configuration dataclasses: + - `BatchConfig`: To load in batch configuration settings + - `WorkerConfig`: To define worker settings ### Changed - Maestro version requirement is now at minimum 1.1.10 for status renderer changes - The `BackendFactory`, `MonitorFactory`, and `StatusRendererFactory` classes all now inherit from `MerlinBaseFactory` - Launching workers is now handled through worker classes rather than functions in the `celeryadapter.py` file +- `batch.py` is now a class rather than standalone functions ## [1.13.0b2] ### Added From a28c4a94a30d4a26d0235d2d9e64a13f61bfdde0 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Mon, 15 Sep 2025 08:54:47 -0700 Subject: [PATCH 6/8] run fix-style --- merlin/spec/specification.py | 2 +- merlin/study/batch.py | 120 ++++++++---------- merlin/study/configurations.py | 103 ++++++++------- merlin/workers/celery_worker.py | 21 +-- merlin/workers/worker.py | 1 - tests/integration/definitions.py | 2 +- .../workers/handlers/test_celery_handler.py | 2 +- tests/unit/workers/test_celery_worker.py | 3 +- tests/unit/workers/test_worker.py | 6 +- 9 files changed, 126 insertions(+), 134 deletions(-) diff --git a/merlin/spec/specification.py b/merlin/spec/specification.py index 8dafde3f..cc40c842 100644 --- a/merlin/spec/specification.py +++ b/merlin/spec/specification.py @@ -1274,7 +1274,7 @@ def build_worker_list(self, workers_to_start: Set[str]) -> List[MerlinWorker]: env=full_env, overlap=overlap, nodes=settings.get("nodes", None), - batch=BatchConfig.from_dict(batch_settings) + batch=BatchConfig.from_dict(batch_settings), ) worker_instance = worker_factory.create(self.merlin["resources"]["task_server"], worker_config) diff --git a/merlin/study/batch.py b/merlin/study/batch.py index 583d1a3f..32718859 100644 --- a/merlin/study/batch.py +++ b/merlin/study/batch.py @@ -24,21 +24,21 @@ class BatchManager: """ Manages batch job scheduling and worker launching across different schedulers. - + This class provides methods for detecting available schedulers, parsing batch configurations, and constructing appropriate launch commands for different batch systems including Slurm, LSF, Flux, and PBS. - + Attributes: batch_config (BatchConfig): The batch configuration object. scheduler_legend (Dict): Dictionary containing scheduler-specific information. detected_scheduler (str): The automatically detected scheduler type. """ - + def __init__(self, batch_config: BatchConfig = None): """ Initialize the BatchManager with a batch configuration. - + Args: batch_config: BatchConfig object containing batch configuration settings. If None, a default BatchConfig will be created. @@ -46,12 +46,12 @@ def __init__(self, batch_config: BatchConfig = None): self.batch_config = batch_config or BatchConfig() self.scheduler_legend = {} self.detected_scheduler = None - + # Initialize Flux-specific attributes self._flux_exe = None self._flux_alloc = None self._init_flux_config() - + def _init_flux_config(self): """Initialize Flux-specific configuration.""" flux_path = self.batch_config.flux_path @@ -59,34 +59,34 @@ def _init_flux_config(self): flux_path += "/" self._flux_exe = os.path.join(flux_path, "flux") - + try: self._flux_alloc = get_flux_alloc(self._flux_exe) except FileNotFoundError as e: LOG.debug(e) self._flux_alloc = "" - + def is_parallel(self) -> bool: """ Check if this batch configuration is set up for parallel execution. - + Returns: True if batch type is not 'local', indicating parallel processing. """ return self.batch_config.is_parallel() - - def _check_scheduler(self, scheduler: str) -> bool: + + def check_scheduler(self, scheduler: str) -> bool: """ Check if a specific scheduler is available on the system. - + Args: scheduler: Name of the scheduler to check ('flux', 'slurm', 'lsf', 'pbs'). - + Returns: True if the scheduler is available, False otherwise. """ if scheduler not in ("flux", "slurm", "lsf", "pbs"): - LOG.warning(f"Invalid scheduler {scheduler} given to _check_scheduler.") + LOG.warning(f"Invalid scheduler {scheduler} given to check_scheduler.") return False # Ensure scheduler legend is populated @@ -94,42 +94,39 @@ def _check_scheduler(self, scheduler: str) -> bool: self._build_scheduler_legend() try: - process = subprocess.Popen( - self.scheduler_legend[scheduler]["check cmd"], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) + process = subprocess.run(self.scheduler_legend[scheduler]["check cmd"], capture_output=True) - result = process.stdout.readlines() - expected_output = self.scheduler_legend[scheduler]["expected check output"] - if result and len(result) > 0 and expected_output in result[0]: - return True + if process.stdout: + expected_output = self.scheduler_legend[scheduler]["expected check output"] + lines = process.stdout.splitlines() + if lines and expected_output in lines[0]: + return True return False except (FileNotFoundError, PermissionError): return False - + def detect_scheduler(self, default: str = None) -> str: """ Automatically detect which batch scheduler is available. - + Args: default: Default scheduler to return if none are detected. - + Returns: Name of the detected scheduler or the default value. """ if self.detected_scheduler is not None: return self.detected_scheduler - + # Build scheduler legend if not already done if not self.scheduler_legend: self._build_scheduler_legend() - + # Check schedulers in priority order schedulers_to_check = ["flux", "pbs", "lsf", "slurm"] for scheduler in schedulers_to_check: - LOG.debug(f"check for {scheduler} = {self._check_scheduler(scheduler)}") - if self._check_scheduler(scheduler): + LOG.debug(f"check for {scheduler} = {self.check_scheduler(scheduler)}") + if self.check_scheduler(scheduler): self.detected_scheduler = scheduler return scheduler @@ -145,17 +142,17 @@ def detect_scheduler(self, default: str = None) -> str: self.detected_scheduler = default return default - + def _get_node_count(self, default: int = 1) -> int: """ Determine node count based on environment and scheduler. - + Args: default: Default node count if none can be determined. - + Returns: Number of nodes to use for the batch job. - + Raises: ValueError: If Flux version is too old. """ @@ -170,7 +167,7 @@ def _get_node_count(self, default: int = 1) -> int: try: get_size_proc = subprocess.run("flux getattr size", shell=True, capture_output=True, text=True) return int(get_size_proc.stdout) - except Exception: + except (FileNotFoundError, PermissionError, ValueError): pass # Check Slurm environment @@ -186,11 +183,11 @@ def _get_node_count(self, default: int = 1) -> int: return len(nodes) // 2 - 1 return default - + def _build_scheduler_legend(self, nodes: int = None): """ Build the scheduler legend with configuration for all supported schedulers. - + Args: nodes: Number of nodes for the launch command. If None, will attempt to determine automatically. @@ -198,10 +195,6 @@ def _build_scheduler_legend(self, nodes: int = None): if nodes is None: nodes = self._get_node_count(default=1) - # print(f"type(walltime): {type(self.batch_config.walltime)}") - # print(f"walltime: {self.batch_config.walltime}") - # print(f"convert timestring: {convert_timestring(self.batch_config.walltime, format_method='FSD')}") - self.scheduler_legend = { "flux": { "bank": f" --setattr=system.bank={self.batch_config.bank}", @@ -233,20 +226,20 @@ def _build_scheduler_legend(self, nodes: int = None): "walltime": f" -t {convert_timestring(self.batch_config.walltime)}", }, } - + def _get_flux_launch_command(self, existing_launch_cmd: str) -> str: """ Build the Flux-specific launch command. Args: existing_launch_cmd: The existing launch command or an empty string. - + Returns: Flux launch command string. """ default_flux_exec = "flux exec" if existing_launch_cmd else f"{self._flux_exe} exec" flux_exec = "" - + if self.batch_config.flux_exec_workers: flux_exec = self.batch_config.flux_exec if self.batch_config.flux_exec else default_flux_exec @@ -259,27 +252,27 @@ def _get_flux_launch_command(self, existing_launch_cmd: str) -> str: launch = f"{existing_launch_cmd} {flux_exec} `which {self.batch_config.shell}` -c" return launch - + def _construct_launch_command(self, nodes: int) -> str: """ Construct the base launch command for the detected scheduler. - + Args: nodes: Number of nodes to use. - + Returns: The constructed launch command. - + Raises: TypeError: If PBS scheduler is used with non-flux batch type. KeyError: If workload manager is not found in scheduler legend. """ # Build scheduler legend with the specified nodes self._build_scheduler_legend(nodes) - + # Detect the workload manager workload_manager = self.detect_scheduler() - + LOG.debug(f"batch_config: {self.batch_config}") if self.batch_config.type == "pbs" and workload_manager == self.batch_config.type: @@ -296,17 +289,13 @@ def _construct_launch_command(self, nodes: int) -> str: LOG.debug(e) launch_command = "" - print(f"launch_command before: {launch_command}") # If LSF is the workload manager we stop here if workload_manager != "lsf" and launch_command: # Add bank, queue, and walltime as necessary for key in ("bank", "queue", "walltime"): config_value = getattr(self.batch_config, key) - print(f"workload_manager: {workload_manager}") - print(f"config_value for {key}: {config_value}") if config_value: try: - print(f"scheduler_legend entry: {self.scheduler_legend[workload_manager][key]}") launch_command += self.scheduler_legend[workload_manager][key] except KeyError as e: LOG.error(e) @@ -315,21 +304,20 @@ def _construct_launch_command(self, nodes: int) -> str: if workload_manager == "pbs": launch_command += " --" - print(f"launch_command after: {launch_command}") return launch_command - + def create_worker_launch_command(self, command: str, nodes: Union[str, int] = None) -> str: """ Create the complete worker launch command. - + Args: command: The base command to be launched. nodes: Number of nodes to use. Can be an integer, "all", or None. If None, will use the batch configuration value. - + Returns: Complete launch command ready for execution. - + Raises: TypeError: If nodes parameter is invalid or PBS scheduler is misconfigured. """ @@ -346,7 +334,7 @@ def create_worker_launch_command(self, command: str, nodes: Union[str, int] = No elif not isinstance(nodes, int): if isinstance(nodes, str) and nodes != "all": raise TypeError("Nodes was passed with an invalid string value (only 'all' is supported).") - elif not isinstance(nodes, str): + if not isinstance(nodes, str): raise TypeError("Nodes parameter must be an integer, 'all', or None.") # Build launch command if not provided @@ -367,16 +355,16 @@ def create_worker_launch_command(self, command: str, nodes: Union[str, int] = No # Add Flux-specific launch settings if self.batch_config.type == "flux": launch_command = self._get_flux_launch_command(launch_command) - + # Construct final worker command worker_cmd = f"{launch_command} {command}" return worker_cmd - + def get_batch_info(self) -> Dict: """ Get information about the current batch configuration. - + Returns: Dictionary containing batch configuration details. """ @@ -384,11 +372,11 @@ def get_batch_info(self) -> Dict: batch_info["is_parallel"] = self.is_parallel() batch_info["detected_scheduler"] = self.detect_scheduler() return batch_info - + def update_config(self, new_config: Union[Dict, BatchConfig]): """ Update the batch configuration and reset cached values. - + Args: new_config: New batch configuration (Dict or BatchConfig). """ @@ -399,7 +387,7 @@ def update_config(self, new_config: Union[Dict, BatchConfig]): self.batch_config = new_config else: raise TypeError("new_config must be a Dict or BatchConfig instance") - + # Reset cached values self.scheduler_legend = {} self.detected_scheduler = None diff --git a/merlin/study/configurations.py b/merlin/study/configurations.py index 13c5eef7..698e4f71 100644 --- a/merlin/study/configurations.py +++ b/merlin/study/configurations.py @@ -16,14 +16,17 @@ from typing import Dict, List, Optional, Set, Union +# pylint: disable=too-many-instance-attributes + + @dataclass class BatchConfig: """ Configuration for batch job submission and execution. - + This dataclass encapsulates all batch-related configuration options that can be specified in a Merlin workflow specification. - + Attributes: type: The type of batch system to use ('local', 'slurm', 'lsf', 'flux', 'pbs'). nodes: Number of nodes to request for the batch job. @@ -39,6 +42,7 @@ class BatchConfig: flux_start_opts: Additional options for flux start command. flux_exec_workers: Whether to use flux exec to launch workers on all nodes. """ + type: str = "local" nodes: Optional[Union[int, str]] = None shell: str = "bash" @@ -53,41 +57,41 @@ class BatchConfig: flux_exec: Optional[str] = None flux_start_opts: str = "" flux_exec_workers: bool = True - + def __post_init__(self): """Validate configuration after initialization.""" valid_types = {"local", "slurm", "lsf", "flux", "pbs"} if self.type not in valid_types: raise ValueError(f"Invalid batch type '{self.type}'. Must be one of: {valid_types}") - + if self.nodes is not None: if isinstance(self.nodes, str) and self.nodes != "all": try: self.nodes = int(self.nodes) - except ValueError: - raise ValueError(f"Invalid nodes value '{self.nodes}'. Must be an integer, 'all', or None.") - + except ValueError as exc: + raise ValueError(f"Invalid nodes value '{self.nodes}'. Must be an integer, 'all', or None.") from exc + # Normalize flux_path if self.flux_path and not self.flux_path.endswith("/"): self.flux_path += "/" - + @classmethod - def from_dict(cls, config_dict: Dict) -> 'BatchConfig': + def from_dict(cls, config_dict: Dict) -> "BatchConfig": """ Create a BatchConfig from a dictionary. - + Args: config_dict: Dictionary containing batch configuration. - + Returns: BatchConfig instance with values from the dictionary. """ return cls(**config_dict) - + def to_dict(self) -> Dict: """ Convert BatchConfig to dictionary for backward compatibility. - + Returns: Dictionary representation of the configuration. """ @@ -107,34 +111,34 @@ def to_dict(self) -> Dict: "flux_start_opts": self.flux_start_opts, "flux_exec_workers": self.flux_exec_workers, } - + def is_parallel(self) -> bool: """ Check if this configuration enables parallel execution. - + Returns: True if batch type is not 'local'. """ return self.type != "local" - - def merge(self, other: 'BatchConfig') -> 'BatchConfig': + + def merge(self, other: "BatchConfig") -> "BatchConfig": """ Merge this configuration with another, with other taking precedence. - + Args: other: BatchConfig to merge with this one. - + Returns: New BatchConfig with merged values. """ merged_dict = self.to_dict() other_dict = other.to_dict() - + # Only override non-empty/non-None values for key, value in other_dict.items(): if value not in (None, "", []): merged_dict[key] = value - + return BatchConfig.from_dict(merged_dict) @@ -142,10 +146,10 @@ def merge(self, other: 'BatchConfig') -> 'BatchConfig': class WorkerConfig: """ Configuration for Merlin workers. - + This dataclass encapsulates all worker-related configuration options including queues, machines, batch settings, and launch arguments. - + Attributes: name: Name of the worker. args: Command-line arguments for the worker process. @@ -156,6 +160,7 @@ class WorkerConfig: overlap: Whether this worker can overlap queues with other workers. env: Environment variables for the worker process. """ + name: str args: str = "" queues: Set[str] = field(default_factory=lambda: {"[merlin]_merlin"}) @@ -164,51 +169,51 @@ class WorkerConfig: nodes: Optional[Union[int, str]] = None overlap: bool = False env: Dict[str, str] = field(default_factory=dict) - + def __post_init__(self): """Validate configuration after initialization.""" if not self.name: raise ValueError("Worker name cannot be empty") - + if not isinstance(self.queues, set): if isinstance(self.queues, (list, tuple)): self.queues = set(self.queues) else: raise ValueError("queues must be a set, list, or tuple") - + if self.nodes is not None: if isinstance(self.nodes, str) and self.nodes != "all": try: self.nodes = int(self.nodes) - except ValueError: - raise ValueError(f"Invalid nodes value '{self.nodes}'. Must be an integer, 'all', or None.") - + except ValueError as exc: + raise ValueError(f"Invalid nodes value '{self.nodes}'. Must be an integer, 'all', or None.") from exc + if not self.env: self.env = os.environ.copy() - + @classmethod # def from_dict(cls, name: str, config_dict: Dict, env: Dict[str, str] = None) -> 'WorkerConfig': - def from_dict(cls, config_dict: Dict) -> 'WorkerConfig': + def from_dict(cls, config_dict: Dict) -> "WorkerConfig": """ Create a WorkerConfig from a dictionary. - + Args: config_dict: Dictionary containing worker configuration. - + Returns: WorkerConfig instance with values from the dictionary. """ # Extract batch configuration if present batch_dict = config_dict.get("batch", {}) batch_config = BatchConfig.from_dict(batch_dict) if batch_dict else BatchConfig() - + # Convert queues to set if needed queues = config_dict.get("queues", {"[merlin]_merlin"}) if isinstance(queues, (list, tuple)): queues = set(queues) elif not isinstance(queues, set): queues = {queues} if isinstance(queues, str) else {"[merlin]_merlin"} - + return cls( name=config_dict["name"], # Not using `get` since this should fail if 'name' is missing args=config_dict.get("args", ""), @@ -219,11 +224,11 @@ def from_dict(cls, config_dict: Dict) -> 'WorkerConfig': overlap=config_dict.get("overlap", False), env=config_dict.get("env", {}), ) - + def to_dict(self) -> Dict: """ Convert WorkerConfig to dictionary for backward compatibility. - + Returns: Dictionary representation of the configuration. """ @@ -237,20 +242,20 @@ def to_dict(self) -> Dict: "overlap": self.overlap, "env": self.env, } - + def get_effective_nodes(self) -> Optional[Union[int, str]]: """ Get the effective node count, preferring worker-specific over batch config. - + Returns: Node count to use, or None if not specified. """ return self.nodes if self.nodes is not None else self.batch.nodes - + def get_effective_batch_config(self) -> BatchConfig: """ Get the effective batch configuration with worker-specific overrides. - + Returns: BatchConfig with worker-specific values applied. """ @@ -260,38 +265,38 @@ def get_effective_batch_config(self) -> BatchConfig: batch_dict["nodes"] = self.nodes return BatchConfig.from_dict(batch_dict) return self.batch - + def has_machine_restrictions(self) -> bool: """ Check if this worker has machine restrictions. - + Returns: True if machines list is not empty. """ return bool(self.machines) - + def add_queue(self, queue_name: str): """ Add a queue to this worker's queue set. - + Args: queue_name: Name of the queue to add. """ self.queues.add(queue_name) - + def remove_queue(self, queue_name: str): """ Remove a queue from this worker's queue set. - + Args: queue_name: Name of the queue to remove. """ self.queues.discard(queue_name) - + def update_env(self, env_updates: Dict[str, str]): """ Update environment variables for this worker. - + Args: env_updates: Dictionary of environment variable updates. """ diff --git a/merlin/workers/celery_worker.py b/merlin/workers/celery_worker.py index d631bf95..51a1d582 100644 --- a/merlin/workers/celery_worker.py +++ b/merlin/workers/celery_worker.py @@ -113,10 +113,10 @@ def get_launch_command(self, override_args: str = "", disable_logs: bool = False # Construct the base celery command celery_cmd = f"celery -A merlin worker {self.worker_config.args} -Q {','.join(self.worker_config.queues)}" - + # Use BatchManager to create the launch command launch_cmd = self.batch_manager.create_worker_launch_command(celery_cmd) - + return os.path.expandvars(launch_cmd) def should_launch(self) -> bool: @@ -131,8 +131,9 @@ def should_launch(self) -> bool: if self.worker_config.machines: if not check_machines(self.worker_config.machines): LOG.error( - f"The following machines were provided for worker '{self.worker_config.name}': {self.worker_config.machines}. " - f"However, the current machine '{socket.gethostname()}' is not in this list." + f"The following machines were provided for worker '{self.worker_config.name}': " + f"{self.worker_config.machines}. However, the current machine '{socket.gethostname()}' " + "is not in this list." ) return False @@ -166,16 +167,20 @@ def start(self, override_args: str = "", disable_logs: bool = False): if self.should_launch(): launch_cmd = self.get_launch_command(override_args=override_args, disable_logs=disable_logs) try: - subprocess.Popen(launch_cmd, env=self.worker_config.env, shell=True, universal_newlines=True) # pylint: disable=R1732 + # We intentionally do not use "with" here because we want the worker + # to run as a detached background process and not block execution. + subprocess.Popen( # pylint: disable=consider-using-with + launch_cmd, env=self.worker_config.env, shell=True, universal_newlines=True + ) LOG.debug(f"Launched worker '{self.worker_config.name}' with command: {launch_cmd}.") - except Exception as e: # pylint: disable=C0103 + except Exception as e: # pylint: disable=broad-exception-caught LOG.error(f"Cannot start celery workers, {e}") raise MerlinWorkerLaunchError from e - + def update_batch_config(self, new_batch_config: Dict): """ Update the batch configuration for this worker. - + Args: new_batch_config: New batch configuration to apply. """ diff --git a/merlin/workers/worker.py b/merlin/workers/worker.py index 91515745..ba1c44d7 100644 --- a/merlin/workers/worker.py +++ b/merlin/workers/worker.py @@ -16,7 +16,6 @@ a consistent interface for launching and managing worker processes. """ -import os from abc import ABC, abstractmethod from typing import Dict diff --git a/tests/integration/definitions.py b/tests/integration/definitions.py index 6ebeb92c..142d83b0 100644 --- a/tests/integration/definitions.py +++ b/tests/integration/definitions.py @@ -60,7 +60,7 @@ def get_worker_by_cmd(cmd: str, default: str) -> str: if not shutil.which(cmd): workers_cmd = bogus_cmd # Use bogus flux if flux is present but slurm is the main scheduler - elif cmd == "flux" and not batch_manager._check_scheduler(cmd): + elif cmd == "flux" and not batch_manager.check_scheduler(cmd): workers_cmd = bogus_cmd return workers_cmd diff --git a/tests/unit/workers/handlers/test_celery_handler.py b/tests/unit/workers/handlers/test_celery_handler.py index 8dcdf3ba..2aab7645 100644 --- a/tests/unit/workers/handlers/test_celery_handler.py +++ b/tests/unit/workers/handlers/test_celery_handler.py @@ -8,7 +8,7 @@ Tests for the `merlin/workers/handlers/celery_handler.py` module. """ -from typing import Dict, List +from typing import List from unittest.mock import MagicMock import pytest diff --git a/tests/unit/workers/test_celery_worker.py b/tests/unit/workers/test_celery_worker.py index c94d03be..5e3fdb12 100644 --- a/tests/unit/workers/test_celery_worker.py +++ b/tests/unit/workers/test_celery_worker.py @@ -8,7 +8,6 @@ Tests for the `merlin/workers/celery_worker.py` module. """ -from typing import Any from unittest.mock import MagicMock import pytest @@ -17,7 +16,7 @@ from merlin.exceptions import MerlinWorkerLaunchError from merlin.study.configurations import WorkerConfig from merlin.workers import CeleryWorker -from tests.fixture_types import FixtureCallable, FixtureDict, FixtureStr +from tests.fixture_types import FixtureCallable, FixtureStr @pytest.fixture diff --git a/tests/unit/workers/test_worker.py b/tests/unit/workers/test_worker.py index 52c77137..e139481f 100644 --- a/tests/unit/workers/test_worker.py +++ b/tests/unit/workers/test_worker.py @@ -28,11 +28,7 @@ def test_init_sets_attributes(): """ Test that the constructor sets name, config, and env correctly. """ - worker_config = WorkerConfig( - name="test_worker", - queues={"queue1", "queue2"}, - env={"TEST_ENV": "123"} - ) + worker_config = WorkerConfig(name="test_worker", queues={"queue1", "queue2"}, env={"TEST_ENV": "123"}) worker = DummyMerlinWorker(worker_config) From fdb306b5037c6f10048c39a9e0c7b95852a1b1e2 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Mon, 15 Sep 2025 15:35:22 -0700 Subject: [PATCH 7/8] fixed broken tests --- merlin/config/broker.py | 1 - merlin/config/configfile.py | 1 + tests/unit/config/test_configfile.py | 1 + tests/unit/workers/test_celery_worker.py | 38 ++++++++++++------------ 4 files changed, 21 insertions(+), 20 deletions(-) diff --git a/merlin/config/broker.py b/merlin/config/broker.py index 864dbef8..eeddd72f 100644 --- a/merlin/config/broker.py +++ b/merlin/config/broker.py @@ -19,7 +19,6 @@ import logging import ssl from typing import Dict, List, Optional, Union -from urllib.parse import quote from merlin.config.configfile import CONFIG, get_ssl_entries from merlin.config.utils import resolve_password diff --git a/merlin/config/configfile.py b/merlin/config/configfile.py index 7ff6a239..ece97bbd 100644 --- a/merlin/config/configfile.py +++ b/merlin/config/configfile.py @@ -41,6 +41,7 @@ def set_local_mode(enable: bool = True): if enable: LOG.info("Running Merlin in local mode (no configuration file required)") + def is_local_mode() -> bool: """ Checks if Merlin is running in local mode. diff --git a/tests/unit/config/test_configfile.py b/tests/unit/config/test_configfile.py index 10bfecb8..41af703d 100644 --- a/tests/unit/config/test_configfile.py +++ b/tests/unit/config/test_configfile.py @@ -261,6 +261,7 @@ def test_find_config_file_merlin_home_app_yaml_exists(mocker: MockerFixture, con result = find_config_file() assert result == merlin_home_app_yaml + def test_find_config_file_no_app_yaml_found(mocker: MockerFixture): """ Test that `find_config_file` returns `None` when no `app.yaml` file is found in any location. diff --git a/tests/unit/workers/test_celery_worker.py b/tests/unit/workers/test_celery_worker.py index 5e3fdb12..171bde0a 100644 --- a/tests/unit/workers/test_celery_worker.py +++ b/tests/unit/workers/test_celery_worker.py @@ -51,34 +51,34 @@ def worker_config() -> WorkerConfig: @pytest.fixture -def celery_worker(worker_config: WorkerConfig) -> CeleryWorker: +def mock_db(mocker: MockerFixture) -> MagicMock: """ - Fixture that provides a CeleryWorker instance. + Fixture that patches the MerlinDatabase constructor. + + This prevents CeleryWorker from writing to the real Merlin database during + unit tests. Returns a mock instance of MerlinDatabase. Args: - worker_config: A WorkerConfig object to initialize the CeleryWorker. + mocker: Pytest mocker fixture. Returns: - A CeleryWorker object for testing. + A mocked MerlinDatabase instance. """ - return CeleryWorker(worker_config) + return mocker.patch("merlin.workers.celery_worker.MerlinDatabase", autospec=True) @pytest.fixture -def mock_db(mocker: MockerFixture) -> MagicMock: +def celery_worker(worker_config: WorkerConfig, mock_db: MagicMock) -> CeleryWorker: """ - Fixture that patches the MerlinDatabase constructor. - - This prevents CeleryWorker from writing to the real Merlin database during - unit tests. Returns a mock instance of MerlinDatabase. + Fixture that provides a CeleryWorker instance. Args: - mocker: Pytest mocker fixture. + worker_config: A WorkerConfig object to initialize the CeleryWorker. Returns: - A mocked MerlinDatabase instance. + A CeleryWorker object for testing. """ - return mocker.patch("merlin.workers.celery_worker.MerlinDatabase") + return CeleryWorker(worker_config) def test_constructor_sets_fields_and_calls_db_create(worker_config: WorkerConfig, mock_db: MagicMock): @@ -100,7 +100,7 @@ def test_constructor_sets_fields_and_calls_db_create(worker_config: WorkerConfig mock_db.return_value.create.assert_called_once_with("logical_worker", "test_worker", {"queue1", "queue2"}) -def test_verify_args_adds_name_and_logging_flags(mocker: MockerFixture, celery_worker: CeleryWorker, mock_db: MagicMock): +def test_verify_args_adds_name_and_logging_flags(mocker: MockerFixture, celery_worker: CeleryWorker): """ Test that `_verify_args()` appends required flags to the Celery args string. @@ -129,7 +129,7 @@ def test_verify_args_adds_name_and_logging_flags(mocker: MockerFixture, celery_w assert mock_logger.warning.called -def test_get_launch_command_returns_expanded_command(mocker: MockerFixture, celery_worker: CeleryWorker, mock_db: MagicMock): +def test_get_launch_command_returns_expanded_command(mocker: MockerFixture, celery_worker: CeleryWorker): """ Test that `get_launch_command()` constructs a valid Celery command. @@ -209,7 +209,7 @@ def test_should_launch_rejects_if_output_path_missing(mocker: MockerFixture, wor assert result is False -def test_should_launch_rejects_due_to_running_queues(mocker: MockerFixture, celery_worker: CeleryWorker, mock_db: MagicMock): +def test_should_launch_rejects_due_to_running_queues(mocker: MockerFixture, celery_worker: CeleryWorker): """ Test that `should_launch` returns False when a conflicting queue is already running. @@ -233,7 +233,7 @@ def test_should_launch_rejects_due_to_running_queues(mocker: MockerFixture, cele assert result is False -def test_launch_worker_runs_if_should_launch(mocker: MockerFixture, celery_worker: CeleryWorker, mock_db: MagicMock): +def test_launch_worker_runs_if_should_launch(mocker: MockerFixture, celery_worker: CeleryWorker): """ Test that `start` executes the launch command if `should_launch` returns True. @@ -262,7 +262,7 @@ def test_launch_worker_runs_if_should_launch(mocker: MockerFixture, celery_worke assert mock_logger.debug.called -def test_launch_worker_raises_if_popen_fails(mocker: MockerFixture, celery_worker: CeleryWorker, mock_db: MagicMock): +def test_launch_worker_raises_if_popen_fails(mocker: MockerFixture, celery_worker: CeleryWorker): """ Test that `start` raises `MerlinWorkerLaunchError` when `subprocess.Popen` fails. @@ -287,7 +287,7 @@ def test_launch_worker_raises_if_popen_fails(mocker: MockerFixture, celery_worke celery_worker.start() -def test_get_metadata_returns_expected_dict(celery_worker: CeleryWorker, mock_db: MagicMock): +def test_get_metadata_returns_expected_dict(celery_worker: CeleryWorker): """ Test that `get_metadata` returns the expected dictionary with worker configuration. From b4c7317b1e0ae420108be97116b876d3f7a783fc Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Mon, 15 Sep 2025 16:09:06 -0700 Subject: [PATCH 8/8] remove missed commented code --- tests/integration/definitions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/integration/definitions.py b/tests/integration/definitions.py index 5d44a095..29482664 100644 --- a/tests/integration/definitions.py +++ b/tests/integration/definitions.py @@ -32,7 +32,6 @@ StepFileHasRegex, ) -# from merlin.study.batch import check_for_scheduler from merlin.study.batch import BatchManager from merlin.utils import get_flux_alloc, get_flux_cmd @@ -53,7 +52,6 @@ def get_worker_by_cmd(cmd: str, default: str) -> str: workers_cmd = default fake_cmds_path = f"tests/integration/fake_commands/{cmd}_fake_command" bogus_cmd = f"""PATH="{fake_cmds_path}:$PATH";{default}""" - # scheduler_legend = {"flux": {"check cmd": ["flux", "resource", "list"], "expected check output": b"Nodes"}} batch_manager = BatchManager() # Use bogus flux/qsub to test if no flux/qsub is present