diff --git a/CHANGELOG.md b/CHANGELOG.md index 7670113b..ca4bc670 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Removed old monitor code from v1.0 +- Moved `stop-workers` functionality to the `CeleryWorkerHandlers` class ## [2.0.0b4] diff --git a/merlin/cli/commands/stop_workers.py b/merlin/cli/commands/stop_workers.py index 3d5ef6d8..07083057 100644 --- a/merlin/cli/commands/stop_workers.py +++ b/merlin/cli/commands/stop_workers.py @@ -19,9 +19,9 @@ from merlin.ascii_art import banner_small from merlin.cli.commands.command_entry_point import CommandEntryPoint -from merlin.router import stop_workers from merlin.spec.specification import MerlinSpec from merlin.utils import verify_filepath +from merlin.workers.handlers.handler_factory import worker_handler_factory LOG = logging.getLogger("merlin") @@ -67,6 +67,12 @@ def add_parser(self, subparsers: ArgumentParser): default=None, help="regex match for specific workers to stop", ) + stop.add_argument( + "-d", + "--dry-run", + action="store_true", + help="Display which workers would be stopped without actually stopping them", + ) def process_command(self, args: Namespace): """ @@ -87,6 +93,7 @@ def process_command(self, args: Namespace): worker_names = [] # Load in the spec if one was provided via the CLI + spec = None if args.spec: spec_path = verify_filepath(args.spec) spec = MerlinSpec.load_specification(spec_path) @@ -94,6 +101,19 @@ def process_command(self, args: Namespace): for worker_name in worker_names: if "$" in worker_name: LOG.warning(f"Worker '{worker_name}' is unexpanded. Target provenance spec instead?") + LOG.debug(f"Searching for the following workers to stop based on the spec {args.spec}: {worker_names}") + + # If we have workers from --workers flag, add them to the list + if args.workers: + worker_names.extend(args.workers) - # Send stop command to router - stop_workers(args.task_server, worker_names, args.queues, args.workers) + # Get the task server from spec or CLI argument + task_server = spec.merlin["resources"]["task_server"] if spec else args.task_server + + # Create the handler and send stop command + worker_handler = worker_handler_factory.create(task_server) + worker_handler.stop_workers( + queues=args.queues, + workers=worker_names if worker_names else None, + dry_run=args.dry_run, + ) diff --git a/merlin/common/tasks.py b/merlin/common/tasks.py index 6db4dbd7..841c6182 100644 --- a/merlin/common/tasks.py +++ b/merlin/common/tasks.py @@ -41,13 +41,13 @@ RestartException, RetryException, ) -from merlin.router import stop_workers from merlin.spec.expansion import parameter_substitutions_for_cmd, parameter_substitutions_for_sample from merlin.study.dag import DAG from merlin.study.status import read_status, status_conflict_handler from merlin.study.step import Step from merlin.study.study import MerlinStudy from merlin.utils import dict_deep_merge +from merlin.workers.handlers.celery_handler import CeleryWorkerHandler retry_exceptions = ( @@ -894,12 +894,13 @@ def expand_tasks_with_samples( # pylint: disable=R0913,R0914 name="merlin:shutdown_workers", priority=get_priority(Priority.HIGH), ) -def shutdown_workers(self: Task, shutdown_queues: List[str]): # pylint: disable=W0613 +def shutdown_workers(self: Task, shutdown_queues: List[str] = None): # pylint: disable=W0613 """ Initiates the shutdown of Celery workers. - This task wraps the [`stop_celery_workers`][study.celeryadapter.stop_celery_workers] - function, allowing for the graceful shutdown of specified Celery worker queues. It is + This task wraps the [`stop_workers`][workers.handlers.celery_handler.CeleryWorkerHandler.stop_workers] + method of the [`CeleryWorkerHandler`][workers.handlers.celery_handler.CeleryWorkerHandler] + class, allowing for the graceful shutdown of specified Celery worker queues. It is acknowledged immediately upon execution, ensuring that it will not be requeued, even if executed by a worker. @@ -908,11 +909,8 @@ def shutdown_workers(self: Task, shutdown_queues: List[str]): # pylint: disable shutdown_queues: A list of specific queues to shut down. If None, all queues will be shut down. """ - if shutdown_queues is not None: - LOG.warning(f"Shutting down workers in queues {shutdown_queues}!") - else: - LOG.warning("Shutting down workers in all queues!") - return stop_workers("celery", None, shutdown_queues, None) + worker_handler = CeleryWorkerHandler() + worker_handler.stop_workers(queues=shutdown_queues) # Pylint complains that these args are unused but celery passes args diff --git a/merlin/router.py b/merlin/router.py index 2af92435..fbf21fb6 100644 --- a/merlin/router.py +++ b/merlin/router.py @@ -21,7 +21,6 @@ purge_celery_tasks, query_celery_queues, run_celery, - stop_celery_workers, ) from merlin.study.study import MerlinStudy @@ -150,24 +149,3 @@ def query_queues( else: LOG.error("Celery is not specified as the task server!") return {} - - -def stop_workers(task_server: str, spec_worker_names: List[str], queues: List[str], workers_regex: str): - """ - This function sends a command to stop workers that match the specified - criteria from the designated task server. - - Args: - task_server: The task server from which to stop workers. - spec_worker_names: A list of worker names to stop, as defined - in a specification. - queues: A list of queues from which to stop associated workers. - workers_regex: A regex pattern used to filter the workers to stop. - """ - LOG.info("Stopping workers...") - - if task_server == "celery": # pylint: disable=R1705 - # Stop workers - stop_celery_workers(queues, spec_worker_names, workers_regex) - else: - LOG.error("Celery is not specified as the task server!") diff --git a/merlin/study/celeryadapter.py b/merlin/study/celeryadapter.py index 789891e0..fb0cca63 100644 --- a/merlin/study/celeryadapter.py +++ b/merlin/study/celeryadapter.py @@ -20,7 +20,7 @@ from merlin.config import Config from merlin.spec.specification import MerlinSpec from merlin.study.study import MerlinStudy -from merlin.utils import apply_list_of_regex, get_procs, is_running +from merlin.utils import get_procs, is_running LOG = logging.getLogger(__name__) @@ -458,82 +458,3 @@ def purge_celery_tasks(queues: str, force: bool) -> int: purge_command = " ".join(["celery -A merlin purge", force_com, "-Q", queues]) LOG.debug(purge_command) return subprocess.run(purge_command, shell=True).returncode - - -def stop_celery_workers( - queues: List[str] = None, spec_worker_names: List[str] = None, worker_regex: List[str] = None -): # pylint: disable=R0912 - """ - Send a stop command to Celery workers. - - This function sends a shutdown command to Celery workers associated with - specified queues. By default, it stops all connected workers, but it can - be configured to target specific workers based on queue names or regular - expression patterns. - - Args: - queues: A list of queue names to which the stop command will be sent. - If None, all connected workers across all queues will be stopped. - spec_worker_names: A list of specific worker names to stop, in addition - to those matching the `worker_regex`. - worker_regex: A regular expression string used to match worker names. - If None, no regex filtering will be applied. - - Side Effects: - - Broadcasts a shutdown signal to Celery workers - - Example: - ```python - stop_celery_workers(queues=['hello'], worker_regex='celery@*my_machine*') - stop_celery_workers() - ``` - """ - from merlin.celery import app # pylint: disable=C0415 - - LOG.debug(f"Sending stop to queues: {queues}, worker_regex: {worker_regex}, spec_worker_names: {spec_worker_names}") - active_queues, _ = get_active_celery_queues(app) - - # If not specified, get all the queues - if queues is None: - queues = [*active_queues] - # Celery adds the queue tag in front of each queue so we add that here - else: - celerize_queues(queues) - - # Find the set of all workers attached to all of those queues - all_workers = set() - for queue in queues: - try: - all_workers.update(active_queues[queue]) - LOG.debug(f"Workers attached to queue {queue}: {active_queues[queue]}") - except KeyError: - LOG.warning(f"No workers are connected to queue {queue}") - - all_workers = list(all_workers) - - LOG.debug(f"Pre-filter worker stop list: {all_workers}") - - # Stop workers with no flags - if (spec_worker_names is None or len(spec_worker_names) == 0) and worker_regex is None: - workers_to_stop = list(all_workers) - # Flag handling - else: - workers_to_stop = [] - # --spec flag - if (spec_worker_names is not None) and len(spec_worker_names) > 0: - apply_list_of_regex(spec_worker_names, all_workers, workers_to_stop) - # --workers flag - if worker_regex is not None: - LOG.debug(f"Searching for workers to stop based on the following regex's: {worker_regex}") - apply_list_of_regex(worker_regex, all_workers, workers_to_stop) - - # Remove duplicates - workers_to_stop = list(set(workers_to_stop)) - LOG.debug(f"Post-filter worker stop list: {workers_to_stop}") - - if workers_to_stop: - LOG.info(f"Sending stop to these workers: {workers_to_stop}") - # Send the shutdown signal - app.control.broadcast("shutdown", destination=workers_to_stop) - else: - LOG.warning("No workers found to stop") diff --git a/merlin/study/manager.py b/merlin/study/manager.py index 3157b278..434d3881 100644 --- a/merlin/study/manager.py +++ b/merlin/study/manager.py @@ -21,7 +21,8 @@ from merlin.db_scripts.merlin_db import MerlinDatabase from merlin.exceptions import RunNotFoundError, StudyNotFoundError from merlin.spec.specification import MerlinSpec -from merlin.study.celeryadapter import purge_celery_tasks, stop_celery_workers +from merlin.study.celeryadapter import purge_celery_tasks +from merlin.workers.handlers.celery_handler import CeleryWorkerHandler LOG = logging.getLogger(__name__) @@ -97,12 +98,13 @@ def cancel( # Step 1: Stop the workers if stop_workers: - # TODO when we refactor `stop-workers`, update this worker_names = spec.get_worker_names() for worker_name in worker_names: if "$" in worker_name: LOG.warning(f"Worker '{worker_name}' is unexpanded. Target provenance spec instead?") - stop_celery_workers(spec_worker_names=worker_names) + + worker_handler = CeleryWorkerHandler() + worker_handler.stop_workers(workers=worker_names) # TODO when we refactor `stop-workers`, may want to do some extra validation here to ensure # all of these workers have actually been stopped diff --git a/merlin/workers/celery_worker.py b/merlin/workers/celery_worker.py index 7bea76bc..0b3fcafd 100644 --- a/merlin/workers/celery_worker.py +++ b/merlin/workers/celery_worker.py @@ -85,6 +85,7 @@ def __init__( self.batch = self.config.get("batch", {}) self.machines = self.config.get("machines", []) self.overlap = overlap + self.pid = None # Set when the worker is launched # Add this worker to the database merlin_db = MerlinDatabase() @@ -189,12 +190,32 @@ def start(self, override_args: str = "", disable_logs: bool = False): if self.should_launch(): launch_cmd = self.get_launch_command(override_args=override_args, disable_logs=disable_logs) try: - subprocess.Popen(launch_cmd, env=self.env, shell=True, universal_newlines=True) # pylint: disable=R1732 + worker_proc = subprocess.Popen( + launch_cmd, env=self.env, shell=True, universal_newlines=True + ) # pylint: disable=R1732 + self.pid = worker_proc.pid LOG.debug(f"Launched worker '{self.name}' with command: {launch_cmd}.") except Exception as e: # pylint: disable=C0103 LOG.error(f"Cannot start celery workers, {e}") raise MerlinWorkerLaunchError from e + def stop(self): + """ + Stop the worker process. + + If the worker has a known PID, sends a SIGTERM to terminate it. + Otherwise, logs a warning that the worker cannot be stopped. + """ + if self.pid: + try: + os.kill(self.pid, 15) # Send SIGTERM + LOG.debug(f"Stopped worker '{self.name}' with PID {self.pid}.") + self.pid = None # Reset PID after stopping + except Exception as e: # pylint: disable=C0103 + LOG.error(f"Cannot stop celery worker '{self.name}', {e}") + else: + LOG.warning(f"Worker '{self.name}' is not running or PID is unknown; cannot stop.") + def get_metadata(self) -> Dict: """ Return metadata about this worker instance. diff --git a/merlin/workers/handlers/celery_handler.py b/merlin/workers/handlers/celery_handler.py index aaba14cb..b5b3a902 100644 --- a/merlin/workers/handlers/celery_handler.py +++ b/merlin/workers/handlers/celery_handler.py @@ -21,6 +21,7 @@ from merlin.common.enums import WorkerStatus from merlin.db_scripts.entities.logical_worker_entity import LogicalWorkerEntity from merlin.db_scripts.merlin_db import MerlinDatabase +from merlin.utils import apply_list_of_regex from merlin.workers import CeleryWorker from merlin.workers.formatters.formatter_factory import worker_formatter_factory from merlin.workers.handlers.worker_handler import MerlinWorkerHandler @@ -78,11 +79,6 @@ def start_workers(self, workers: List[CeleryWorker], **kwargs): LOG.debug(f"Launching worker '{worker.name}'.") worker.start(override_args=override_args, disable_logs=disable_logs) - def stop_workers(self): - """ - Attempt to stop Celery workers. - """ - def get_workers_from_app(self) -> List[str]: """ Retrieve a list of all workers connected to the Celery application. @@ -108,10 +104,8 @@ def get_active_workers(self) -> Dict[str, List[str]]: """ Retrieve a mapping of active workers to their associated queues for a Celery application. - This function serves as the inverse of - [`get_active_celery_queues()`][study.celeryadapter.get_active_celery_queues]. It constructs - a dictionary where each key is a worker's name and the corresponding value is a - list of queues that the worker is connected to. This allows for easy identification + This method constructs a dictionary where each key is a worker's name and the corresponding + value is a list of queues that the worker is connected to. This allows for easy identification of which queues are being handled by each worker. Returns: @@ -206,3 +200,115 @@ def query_workers( # Use formatter to display the results formatter = worker_formatter_factory.create(formatter) formatter.format_and_display(logical_workers, filters, self.merlin_db) + + def normalize_queue_names(self, queues: List[str]) -> List[str]: + """ + Normalize queue names to conform to Celery's naming conventions. + + Args: + queues (List[str]): List of queue names to normalize. + + Returns: + List[str]: Normalized queue names. + """ + from merlin.config.configfile import CONFIG # Importing configuration for queue tag + + return [f"{CONFIG.celery.queue_tag}{queue}" for queue in queues] + + def get_workers_from_queues(self, queues: List[str]) -> List[str]: + """ + Given a list of queue names, retrieve the Celery workers associated with those queues. + + Args: + queues (List[str]): The list of queue names to filter workers by. + + Returns: + List[str]: A list of Celery worker names associated with the specified queues. + """ + live_workers = self.get_active_workers() + return [worker for worker, live_queues in live_workers.items() if set(queues) & set(live_queues)] + + def filter_workers(self, all_workers: List[str], filters: List[str]) -> List[str]: + """ + Filter workers based on regex patterns or specific names. + + Args: + all_workers (List[str]): List of all available workers. + filters (List[str]): List of regex patterns or specific names to filter workers. + + Returns: + List[str]: Filtered list of workers. + """ + filtered_workers = [] + apply_list_of_regex(filters, all_workers, filtered_workers) + return list(set(filtered_workers)) + + def send_shutdown_signal(self, workers_to_stop: List[str]): + """ + Send a shutdown signal to the specified workers. + + Args: + workers_to_stop (List[str]): List of worker names to send the shutdown signal to. + """ + if workers_to_stop: + LOG.info(f"Sending shutdown signal to workers: {workers_to_stop}") + self.app.control.broadcast("shutdown", destination=workers_to_stop) + else: + LOG.warning("No workers found to stop.") + + def stop_workers(self, queues: List[str] = None, workers: List[str] = None, dry_run: bool = False): + """ + Stop worker processes, optionally filtered by queue or worker name. + + This method terminates active worker processes based on the provided filters. + The behavior varies by implementation: + + - If both `queues` and `workers` are None, all active workers are stopped. + - If `queues` is provided, only workers attached to those queues are stopped. + - If `workers` is provided, only workers matching those names/patterns are stopped. + - If both are provided, workers must match both criteria (intersection). + + Args: + queues: Optional list of queue names to filter workers by. + workers: Optional list of worker names or patterns to match. For Celery, + these can be logical worker names from the spec or regex patterns + matching physical worker names (e.g., "celery@worker1.*"). + dry_run: If True, just print out the names of the workers that will be stopped. + + Example: + ```python + handler = CeleryWorkerHandler() + + # Stop all workers + handler.stop_workers() + + # Stop workers on specific queues + handler.stop_workers(queues=['hello_queue', 'world_queue']) + + # Stop specific workers by name + handler.stop_workers(workers=['worker1', 'worker2']) + + # Stop workers matching both criteria + handler.stop_workers(queues=['hello_queue'], workers=['worker1.*']) + ``` + """ + LOG.debug(f"Stopping workers with queues: {queues}, workers: {workers}") + + # Step 1: Normalize queue names + if queues: + queues = self.normalize_queue_names(queues) + + # Step 2: Get workers from queues + all_workers = self.get_workers_from_queues(queues) if queues else self.get_workers_from_app() + + # Step 3: Filter workers + workers_to_stop = self.filter_workers(all_workers, workers) if workers else all_workers + + # Step 4: Send shutdown signal + if len(workers_to_stop) == 0: + LOG.warning("No workers found to stop.") + else: + if dry_run: + print(f"Would send shutdown signal to workers: {workers_to_stop}.") + else: + self.send_shutdown_signal(workers_to_stop) diff --git a/merlin/workers/handlers/worker_handler.py b/merlin/workers/handlers/worker_handler.py index 03df36ad..91799be1 100644 --- a/merlin/workers/handlers/worker_handler.py +++ b/merlin/workers/handlers/worker_handler.py @@ -56,11 +56,44 @@ def start_workers(self, workers: List[MerlinWorker], **kwargs): raise NotImplementedError("Subclasses of `MerlinWorkerHandler` must implement a `start_workers` method.") @abstractmethod - def stop_workers(self): + def stop_workers(self, queues: List[str] = None, workers: List[str] = None): """ - Stop worker processes. + Stop worker processes, optionally filtered by queue or worker name. - This method should terminate any active worker sessions that were previously launched. + This method terminates active worker processes based on the provided filters. + The behavior varies by implementation: + + - If both `queues` and `workers` are None, all active workers are stopped. + - If `queues` is provided, only workers attached to those queues are stopped. + - If `workers` is provided, only workers matching those names/patterns are stopped. + - If both are provided, workers must match both criteria (intersection). + + Args: + queues: Optional list of queue names to filter workers by. Queue names + will be normalized with the appropriate task server prefix if needed. + workers: Optional list of worker names or patterns to match. For Celery, + these can be logical worker names from the spec or regex patterns + matching physical worker names (e.g., "celery@worker1.*"). + + Example: + ```python + handler = CeleryWorkerHandler() + + # Stop all workers + handler.stop_workers() + + # Stop workers on specific queues + handler.stop_workers(queues=['hello_queue', 'world_queue']) + + # Stop specific workers by name + handler.stop_workers(workers=['worker1', 'worker2']) + + # Stop workers matching both criteria + handler.stop_workers(queues=['hello_queue'], workers=['worker1.*']) + ``` + + Raises: + May raise task-server-specific exceptions if connection fails. """ raise NotImplementedError("Subclasses of `MerlinWorkerHandler` must implement a `stop_workers` method.") diff --git a/tests/unit/cli/commands/test_stop_workers.py b/tests/unit/cli/commands/test_stop_workers.py index 970a0219..132d438f 100644 --- a/tests/unit/cli/commands/test_stop_workers.py +++ b/tests/unit/cli/commands/test_stop_workers.py @@ -33,6 +33,7 @@ def test_add_parser_sets_up_stop_workers_command(create_parser: FixtureCallable) assert args.queues == ["queue1", "queue2"] assert args.workers is None assert args.spec is None + assert args.dry_run is False def test_process_command_calls_stop_workers_no_spec(mocker: MockerFixture): @@ -43,12 +44,14 @@ def test_process_command_calls_stop_workers_no_spec(mocker: MockerFixture): mocker: PyTest mocker fixture. """ mocker.patch("merlin.cli.commands.stop_workers.banner_small", "BANNER") - mock_stop = mocker.patch("merlin.cli.commands.stop_workers.stop_workers") + mock_handler_factory = mocker.patch("merlin.cli.commands.stop_workers.worker_handler_factory.create") + mock_handler = mock_handler_factory.return_value + mock_handler.stop_workers = mocker.MagicMock() - args = Namespace(spec=None, task_server="celery", queues=["q1"], workers=["worker1"]) + args = Namespace(spec=None, task_server="celery", queues=["q1"], workers=["worker1"], dry_run=False) StopWorkersCommand().process_command(args) - mock_stop.assert_called_once_with("celery", [], ["q1"], ["worker1"]) + mock_handler.stop_workers.assert_called_once_with(queues=["q1"], workers=["worker1"], dry_run=False) def test_process_command_with_spec_and_worker_names(mocker: MockerFixture): @@ -59,7 +62,9 @@ def test_process_command_with_spec_and_worker_names(mocker: MockerFixture): mocker: PyTest mocker fixture. """ mocker.patch("merlin.cli.commands.stop_workers.banner_small", "BANNER") - mock_stop = mocker.patch("merlin.cli.commands.stop_workers.stop_workers") + mock_handler_factory = mocker.patch("merlin.cli.commands.stop_workers.worker_handler_factory.create") + mock_handler = mock_handler_factory.return_value + mock_handler.stop_workers = mocker.MagicMock() mock_verify = mocker.patch("merlin.cli.commands.stop_workers.verify_filepath", return_value="study.yaml") mock_spec = mocker.patch("merlin.cli.commands.stop_workers.MerlinSpec") @@ -70,12 +75,13 @@ def test_process_command_with_spec_and_worker_names(mocker: MockerFixture): task_server="celery", queues=None, workers=None, + dry_run=False, ) StopWorkersCommand().process_command(args) mock_verify.assert_called_once_with("study.yaml") mock_spec.load_specification.assert_called_once_with("study.yaml") - mock_stop.assert_called_once_with("celery", ["worker.alpha", "worker.beta"], None, None) + mock_handler.stop_workers.assert_called_once_with(queues=None, workers=["worker.alpha", "worker.beta"], dry_run=False) def test_process_command_logs_warning_on_unexpanded_worker(mocker: MockerFixture, caplog: CaptureFixture): @@ -89,14 +95,16 @@ def test_process_command_logs_warning_on_unexpanded_worker(mocker: MockerFixture caplog.set_level("WARNING", logger="merlin") mocker.patch("merlin.cli.commands.stop_workers.banner_small", "BANNER") - mock_stop = mocker.patch("merlin.cli.commands.stop_workers.stop_workers") + mock_handler_factory = mocker.patch("merlin.cli.commands.stop_workers.worker_handler_factory.create") + mock_handler = mock_handler_factory.return_value + mock_handler.stop_workers = mocker.MagicMock() mocker.patch("merlin.cli.commands.stop_workers.verify_filepath", return_value="spec.yaml") mock_spec = mocker.patch("merlin.cli.commands.stop_workers.MerlinSpec") mock_spec.load_specification.return_value.get_worker_names.return_value = ["worker.1", "worker.$step"] - args = Namespace(spec="spec.yaml", task_server="celery", queues=None, workers=None) + args = Namespace(spec="spec.yaml", task_server="celery", queues=None, workers=None, dry_run=False) StopWorkersCommand().process_command(args) assert any("is unexpanded" in record.message for record in caplog.records) - mock_stop.assert_called_once_with("celery", ["worker.1", "worker.$step"], None, None) + mock_handler.stop_workers.assert_called_once_with(queues=None, workers=["worker.1", "worker.$step"], dry_run=False) diff --git a/tests/unit/study/test_study_manager.py b/tests/unit/study/test_study_manager.py index d893231e..623a1415 100644 --- a/tests/unit/study/test_study_manager.py +++ b/tests/unit/study/test_study_manager.py @@ -55,15 +55,18 @@ def mock_spec(mocker: MockerFixture) -> MagicMock: @pytest.fixture def mock_stop_workers(mocker: MockerFixture) -> MagicMock: """ - Fixture that mocks the stop_celery_workers function. + Fixture that mocks the CeleryWorkerHandler class and its stop_workers method. Args: mocker: PyTest mocker fixture. Returns: - The mocked stop_celery_workers function. + A mocked CeleryWorkerHandler instance with the stop_workers method mocked. """ - return mocker.patch("merlin.study.manager.stop_celery_workers") + mock_handler = mocker.MagicMock() + mock_handler.stop_workers = mocker.MagicMock() + mocker.patch("merlin.study.manager.CeleryWorkerHandler", return_value=mock_handler) + return mock_handler.stop_workers @pytest.fixture @@ -129,7 +132,7 @@ def test_cancel_full_cancellation_with_defaults( result = manager.cancel(mock_spec) # Verify workers were stopped - mock_stop_workers.assert_called_once_with(spec_worker_names=["worker1", "worker2"]) + mock_stop_workers.assert_called_once_with(workers=["worker1", "worker2"]) # Verify queues were purged mock_purge_tasks.assert_called_once_with("queue1,queue2,queue3", True) @@ -451,7 +454,7 @@ def test_cancel_with_unexpanded_worker_names( assert "Target provenance spec instead?" in caplog.text # Verify workers were still stopped (including unexpanded one) - mock_stop_workers.assert_called_once_with(spec_worker_names=["worker1", "$(UNEXPANDED_WORKER)", "worker2"]) + mock_stop_workers.assert_called_once_with(workers=["worker1", "$(UNEXPANDED_WORKER)", "worker2"]) def test_cancel_queue_formatting( self, @@ -519,7 +522,7 @@ def test_cancel_with_empty_worker_list( result = manager.cancel(mock_spec) # Verify stop_workers was still called with empty list - mock_stop_workers.assert_called_once_with(spec_worker_names=[]) + mock_stop_workers.assert_called_once_with(workers=[]) # Verify result reflects empty worker list assert result["workers_stopped"] == [] diff --git a/tests/unit/workers/handlers/test_celery_handler.py b/tests/unit/workers/handlers/test_celery_handler.py index b73398c6..2daf2640 100644 --- a/tests/unit/workers/handlers/test_celery_handler.py +++ b/tests/unit/workers/handlers/test_celery_handler.py @@ -634,3 +634,267 @@ def test_get_workers_from_app_preserves_worker_names(self, handler: CeleryWorker assert "celery@worker2" in result assert "worker3@localhost" in result assert len(result) == 3 + + def test_normalize_queue_names_with_valid_queues(self, handler: CeleryWorkerHandler, mocker: MockerFixture): + """ + Test that `normalize_queue_names` correctly normalizes valid queue names. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. + """ + mock_config = mocker.patch("merlin.config.configfile.CONFIG") + mock_config.celery.queue_tag = "[merlin]_" + queues = ["queue1", "queue2"] + + result = handler.normalize_queue_names(queues) + + assert result == ["[merlin]_queue1", "[merlin]_queue2"] + + def test_normalize_queue_names_with_empty_list(self, handler: CeleryWorkerHandler, mocker: MockerFixture): + """ + Test that `normalize_queue_names` handles an empty list of queues. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. + """ + mock_config = mocker.patch("merlin.config.configfile.CONFIG") + mock_config.celery.queue_tag = "[merlin]_" + queues = [] + + result = handler.normalize_queue_names(queues) + + assert result == [] + + def test_normalize_queue_names_with_special_characters(self, handler: CeleryWorkerHandler, mocker: MockerFixture): + """ + Test that `normalize_queue_names` handles queue names with special characters. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. + """ + mock_config = mocker.patch("merlin.config.configfile.CONFIG") + mock_config.celery.queue_tag = "[merlin]_" + queues = ["queue@1", "queue#2"] + + result = handler.normalize_queue_names(queues) + + assert result == ["[merlin]_queue@1", "[merlin]_queue#2"] + + def test_get_workers_from_queues_with_matching_queues(self, handler: CeleryWorkerHandler, mocker: MockerFixture): + """ + Test that `get_workers_from_queues` retrieves workers associated with specified queues. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. + """ + mocker.patch.object( + handler, + "get_active_workers", + return_value={ + "worker1": ["queue1", "queue2"], + "worker2": ["queue2", "queue3"], + }, + ) + queues = ["queue1", "queue3"] + + result = handler.get_workers_from_queues(queues) + + assert result == ["worker1", "worker2"] + + def test_get_workers_from_queues_with_no_matching_queues(self, handler: CeleryWorkerHandler, mocker: MockerFixture): + """ + Test that `get_workers_from_queues` returns an empty list when no queues match. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. + """ + mocker.patch.object( + handler, + "get_active_workers", + return_value={ + "worker1": ["queue1", "queue2"], + "worker2": ["queue2", "queue3"], + }, + ) + queues = ["queue4"] + + result = handler.get_workers_from_queues(queues) + + assert result == [] + + def test_get_workers_from_queues_with_empty_queues(self, handler: CeleryWorkerHandler, mocker: MockerFixture): + """ + Test that `get_workers_from_queues` returns an empty list when the queues list is empty. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. + """ + mocker.patch.object( + handler, + "get_active_workers", + return_value={ + "worker1": ["queue1", "queue2"], + "worker2": ["queue2", "queue3"], + }, + ) + queues = [] + + result = handler.get_workers_from_queues(queues) + + assert result == [] + + def test_filter_workers_with_matching_filters(self, handler: CeleryWorkerHandler, mocker: MockerFixture): + """ + Test that `filter_workers` filters workers based on matching filters. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. + """ + mock_apply_list_of_regex = mocker.patch("merlin.workers.handlers.celery_handler.apply_list_of_regex") + all_workers = ["worker1", "worker2", "worker3"] + filters = ["worker1", "worker3"] + + handler.filter_workers(all_workers, filters) + + mock_apply_list_of_regex.assert_called_once_with(filters, all_workers, []) + + def test_filter_workers_with_no_matching_filters(self, handler: CeleryWorkerHandler, mocker: MockerFixture): + """ + Test that `filter_workers` returns an empty list when no filters match. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. + """ + mock_apply_list_of_regex = mocker.patch("merlin.workers.handlers.celery_handler.apply_list_of_regex") + all_workers = ["worker1", "worker2", "worker3"] + filters = ["worker4"] + + handler.filter_workers(all_workers, filters) + + mock_apply_list_of_regex.assert_called_once_with(filters, all_workers, []) + + def test_filter_workers_with_empty_filters(self, handler: CeleryWorkerHandler, mocker: MockerFixture): + """ + Test that `filter_workers` returns all workers when filters are empty. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. + """ + mock_apply_list_of_regex = mocker.patch("merlin.workers.handlers.celery_handler.apply_list_of_regex") + all_workers = ["worker1", "worker2", "worker3"] + filters = [] + + handler.filter_workers(all_workers, filters) + + mock_apply_list_of_regex.assert_called_once_with(filters, all_workers, []) + + def test_send_shutdown_signal_with_workers(self, handler: CeleryWorkerHandler, mocker: MockerFixture): + """ + Test that `send_shutdown_signal` sends a shutdown signal to specified workers. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. + """ + mock_broadcast = mocker.patch.object(handler.app.control, "broadcast") + workers_to_stop = ["worker1", "worker2"] + + handler.send_shutdown_signal(workers_to_stop) + + mock_broadcast.assert_called_once_with("shutdown", destination=workers_to_stop) + + def test_send_shutdown_signal_with_no_workers(self, handler: CeleryWorkerHandler, mocker: MockerFixture): + """ + Test that `send_shutdown_signal` logs a warning when no workers are provided. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. + """ + mock_broadcast = mocker.patch.object(handler.app.control, "broadcast") + mock_logger = mocker.patch("merlin.workers.handlers.celery_handler.LOG") + workers_to_stop = [] + + handler.send_shutdown_signal(workers_to_stop) + + mock_broadcast.assert_not_called() + mock_logger.warning.assert_called_once_with("No workers found to stop.") + + def test_stop_workers_with_matching_queues_and_workers(self, handler: CeleryWorkerHandler, mocker: MockerFixture): + """ + Test that `stop_workers` stops workers matching both queues and worker names. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. + """ + mock_normalize_queue_names = mocker.patch.object(handler, "normalize_queue_names", return_value=["[merlin]_queue1"]) + mock_get_workers_from_queues = mocker.patch.object( + handler, "get_workers_from_queues", return_value=["worker1", "worker2"] + ) + mock_filter_workers = mocker.patch.object(handler, "filter_workers", return_value=["worker1"]) + mock_send_shutdown_signal = mocker.patch.object(handler, "send_shutdown_signal") + + handler.stop_workers(queues=["queue1"], workers=["worker1"], dry_run=False) + + mock_normalize_queue_names.assert_called_once_with(["queue1"]) + mock_get_workers_from_queues.assert_called_once_with(["[merlin]_queue1"]) + mock_filter_workers.assert_called_once_with(["worker1", "worker2"], ["worker1"]) + mock_send_shutdown_signal.assert_called_once_with(["worker1"]) + + def test_stop_workers_with_dry_run( + self, handler: CeleryWorkerHandler, mocker: MockerFixture, capsys: pytest.CaptureFixture + ): + """ + Test that `stop_workers` performs a dry run and prints the workers to be stopped. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. + capsys: Pytest system output capture fixture. + """ + mock_normalize_queue_names = mocker.patch.object(handler, "normalize_queue_names", return_value=["[merlin]_queue1"]) + mock_get_workers_from_queues = mocker.patch.object( + handler, "get_workers_from_queues", return_value=["worker1", "worker2"] + ) + mock_filter_workers = mocker.patch.object(handler, "filter_workers", return_value=["worker1"]) + mock_send_shutdown_signal = mocker.patch.object(handler, "send_shutdown_signal") + + handler.stop_workers(queues=["queue1"], workers=["worker1"], dry_run=True) + + mock_normalize_queue_names.assert_called_once_with(["queue1"]) + mock_get_workers_from_queues.assert_called_once_with(["[merlin]_queue1"]) + mock_filter_workers.assert_called_once_with(["worker1", "worker2"], ["worker1"]) + mock_send_shutdown_signal.assert_not_called() + + captured = capsys.readouterr() + assert "Would send shutdown signal to workers: ['worker1']." in captured.out + + def test_stop_workers_with_no_workers_found(self, handler: CeleryWorkerHandler, mocker: MockerFixture): + """ + Test that `stop_workers` logs a warning when no workers are found to stop. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. + """ + mock_normalize_queue_names = mocker.patch.object(handler, "normalize_queue_names", return_value=["[merlin]_queue1"]) + mock_get_workers_from_queues = mocker.patch.object(handler, "get_workers_from_queues", return_value=[]) + mock_filter_workers = mocker.patch.object(handler, "filter_workers", return_value=[]) + mock_logger = mocker.patch("merlin.workers.handlers.celery_handler.LOG") + + handler.stop_workers(queues=["queue1"], workers=["worker1"], dry_run=False) + + mock_normalize_queue_names.assert_called_once_with(["queue1"]) + mock_get_workers_from_queues.assert_called_once_with(["[merlin]_queue1"]) + mock_filter_workers.assert_called_once_with([], ["worker1"]) + mock_logger.warning.assert_called_once_with("No workers found to stop.") diff --git a/tests/unit/workers/test_celery_worker.py b/tests/unit/workers/test_celery_worker.py index 00918139..3f49c2ce 100644 --- a/tests/unit/workers/test_celery_worker.py +++ b/tests/unit/workers/test_celery_worker.py @@ -81,6 +81,80 @@ def mock_db(mocker: MockerFixture) -> MagicMock: return mocker.patch("merlin.workers.celery_worker.MerlinDatabase") +def test_stop_worker_with_valid_pid( + mocker: MockerFixture, + basic_config: FixtureDict[str, Any], + dummy_env: FixtureDict[str, str], + mock_db: MagicMock, +): + """ + Test that `stop` successfully terminates a worker with a valid PID. + + Args: + mocker: Pytest mocker fixture. + basic_config: Basic configuration dictionary fixture. + dummy_env: Dummy environment dictionary fixture. + mock_db: Mocked MerlinDatabase object. + """ + mock_kill = mocker.patch("os.kill") + worker = CeleryWorker("worker1", basic_config, dummy_env) + worker.pid = 12345 + + worker.stop() + + mock_kill.assert_called_once_with(12345, 15) + assert worker.pid is None + + +def test_stop_worker_handles_exception( + mocker: MockerFixture, + basic_config: FixtureDict[str, Any], + dummy_env: FixtureDict[str, str], + mock_db: MagicMock, +): + """ + Test that `stop` logs an error if `os.kill` raises an exception. + + Args: + mocker: Pytest mocker fixture. + basic_config: Basic configuration dictionary fixture. + dummy_env: Dummy environment dictionary fixture. + mock_db: Mocked MerlinDatabase object. + """ + mock_kill = mocker.patch("os.kill", side_effect=OSError("Failed to stop process")) + mock_logger = mocker.patch("merlin.workers.celery_worker.LOG") + worker = CeleryWorker("worker2", basic_config, dummy_env) + worker.pid = 12345 + + worker.stop() + + mock_kill.assert_called_once_with(12345, 15) + mock_logger.error.assert_called_once_with("Cannot stop celery worker 'worker2', Failed to stop process") + + +def test_stop_worker_without_pid( + mocker: MockerFixture, + basic_config: FixtureDict[str, Any], + dummy_env: FixtureDict[str, str], + mock_db: MagicMock, +): + """ + Test that `stop` logs a warning if the worker has no PID. + + Args: + mocker: Pytest mocker fixture. + basic_config: Basic configuration dictionary fixture. + dummy_env: Dummy environment dictionary fixture. + mock_db: Mocked MerlinDatabase object. + """ + mock_logger = mocker.patch("merlin.workers.celery_worker.LOG") + worker = CeleryWorker("worker3", basic_config, dummy_env) + + worker.stop() + + mock_logger.warning.assert_called_once_with("Worker 'worker3' is not running or PID is unknown; cannot stop.") + + def test_constructor_sets_fields_and_calls_db_create( basic_config: FixtureDict[str, Any], dummy_env: FixtureDict[str, str],