From 8bc8cb60e5ea97eb755dc8757da713c855a2b018 Mon Sep 17 00:00:00 2001 From: Gustavo Lima Date: Sun, 1 Mar 2026 12:33:50 +0100 Subject: [PATCH 1/2] Add HealthCheckRuntime context manager for shared health check boilerplate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract the ~30 lines of repeated setup code (logger init, GPU node ID detection, derived cluster resolution, TelemetryContext + OutputContext nesting, killswitch check) into a reusable HealthCheckRuntime dataclass context manager. This reduces per-subcommand boilerplate from ~30 lines to ~5 lines. The helper is purely additive — existing checks continue to work unchanged. New checks can use `with HealthCheckRuntime(...) as rt:` instead of manually wiring up the setup ceremony. Includes comprehensive tests covering field initialization, killswitch behavior, context manager nesting, GPU node ID failure handling, and the finish() convenience method. Refs: #75 --- gcm/health_checks/check_utils/runtime.py | 117 ++++++++++++ gcm/tests/health_checks_tests/test_runtime.py | 174 ++++++++++++++++++ 2 files changed, 291 insertions(+) create mode 100644 gcm/health_checks/check_utils/runtime.py create mode 100644 gcm/tests/health_checks_tests/test_runtime.py diff --git a/gcm/health_checks/check_utils/runtime.py b/gcm/health_checks/check_utils/runtime.py new file mode 100644 index 0000000..abdb312 --- /dev/null +++ b/gcm/health_checks/check_utils/runtime.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +import logging +import socket +import sys +import types +from collections.abc import Collection +from contextlib import ExitStack +from dataclasses import dataclass, field +from pathlib import Path +from typing import Callable, ContextManager, Literal, NoReturn, Optional, Type + +import gni_lib +from gcm.health_checks.check_utils.output_context_manager import OutputContext +from gcm.health_checks.check_utils.telem import TelemetryContext +from gcm.health_checks.types import CHECK_TYPE, ExitCode, LOG_LEVEL +from gcm.monitoring.slurm.derived_cluster import get_derived_cluster +from gcm.monitoring.utils.monitor import init_logger +from gcm.schemas.health_check.health_check_name import HealthCheckName + + +@dataclass +class HealthCheckRuntime(ContextManager["HealthCheckRuntime"]): + cluster: str + type: CHECK_TYPE + log_level: LOG_LEVEL + log_folder: str + sink: str + sink_opts: Collection[str] + verbose_out: bool + heterogeneous_cluster_v1: bool + health_check_name: HealthCheckName + killswitch_getter: Callable[[], bool] + + logger: logging.Logger = field(init=False) + node: str = field(init=False) + gpu_node_id: Optional[str] = field(init=False) + derived_cluster: str = field(init=False) + exit_code: ExitCode = field(init=False, default=ExitCode.UNKNOWN) + msg: str = field(init=False, default="") + _stack: ExitStack = field(init=False) + + def __enter__(self) -> "HealthCheckRuntime": + self.node = socket.gethostname() + self.logger, _ = init_logger( + logger_name=self.type, + log_dir=str(Path(self.log_folder) / self.type / "_logs"), + log_name=self.node + ".log", + log_level=getattr(logging, self.log_level), + ) + self.logger.info( + "%s: cluster: %s, node: %s, type: %s", + self.health_check_name.value, + self.cluster, + self.node, + self.type, + ) + try: + self.gpu_node_id = gni_lib.get_gpu_node_id() + except Exception as e: + self.gpu_node_id = None + self.logger.warning(f"Could not get gpu_node_id, likely not a GPU host: {e}") + + self.derived_cluster = get_derived_cluster( + cluster=self.cluster, + heterogeneous_cluster_v1=self.heterogeneous_cluster_v1, + data={"Node": self.node}, + ) + + self._stack = ExitStack() + self._stack.__enter__() + self._stack.enter_context( + TelemetryContext( + sink=self.sink, + sink_opts=self.sink_opts, + logger=self.logger, + cluster=self.cluster, + derived_cluster=self.derived_cluster, + type=self.type, + name=self.health_check_name.value, + node=self.node, + get_exit_code_msg=lambda: (self.exit_code, self.msg), + gpu_node_id=self.gpu_node_id, + ) + ) + self._stack.enter_context( + OutputContext( + self.type, + self.health_check_name, + lambda: (self.exit_code, self.msg), + self.verbose_out, + ) + ) + + if self.killswitch_getter(): + self.exit_code = ExitCode.OK + self.logger.info( + "%s is disabled by killswitch", + self.health_check_name.value, + ) + sys.exit(0) + + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[types.TracebackType], + ) -> Literal[False]: + self._stack.__exit__(exc_type, exc_val, exc_tb) + return False + + def finish(self, exit_code: ExitCode, msg: str) -> NoReturn: + self.exit_code = exit_code + self.msg = msg + sys.exit(exit_code.value) diff --git a/gcm/tests/health_checks_tests/test_runtime.py b/gcm/tests/health_checks_tests/test_runtime.py new file mode 100644 index 0000000..bf5d4af --- /dev/null +++ b/gcm/tests/health_checks_tests/test_runtime.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +"""Tests for the HealthCheckRuntime context manager.""" + +import logging +from unittest.mock import MagicMock, patch + +import pytest +from gcm.health_checks.check_utils.runtime import HealthCheckRuntime +from gcm.health_checks.types import ExitCode +from gcm.schemas.health_check.health_check_name import HealthCheckName + + +def _make_runtime(**kwargs) -> HealthCheckRuntime: + defaults = dict( + cluster="test_cluster", + type="prolog", + log_level="INFO", + log_folder="/tmp", + sink="do_nothing", + sink_opts=(), + verbose_out=False, + heterogeneous_cluster_v1=False, + health_check_name=HealthCheckName.CHECK_SENSORS, + killswitch_getter=lambda: False, + ) + defaults.update(kwargs) + return HealthCheckRuntime(**defaults) + + +@patch("gcm.health_checks.check_utils.runtime.get_derived_cluster", return_value="derived_test") +@patch("gcm.health_checks.check_utils.runtime.gni_lib") +@patch("gcm.health_checks.check_utils.runtime.init_logger") +@patch("gcm.health_checks.check_utils.runtime.socket") +def test_enter_initializes_fields( + mock_socket: MagicMock, + mock_init_logger: MagicMock, + mock_gni: MagicMock, + mock_derived: MagicMock, +) -> None: + """Verify __enter__ populates logger, node, gpu_node_id, and derived_cluster.""" + mock_socket.gethostname.return_value = "testnode01" + test_logger = logging.getLogger("test") + mock_init_logger.return_value = (test_logger, MagicMock()) + mock_gni.get_gpu_node_id.return_value = "gpu-0" + + rt = _make_runtime() + with rt as runtime: + assert runtime.node == "testnode01" + assert runtime.logger is test_logger + assert runtime.gpu_node_id == "gpu-0" + assert runtime.derived_cluster == "derived_test" + + +@patch("gcm.health_checks.check_utils.runtime.get_derived_cluster", return_value="test_cluster") +@patch("gcm.health_checks.check_utils.runtime.gni_lib") +@patch("gcm.health_checks.check_utils.runtime.init_logger") +@patch("gcm.health_checks.check_utils.runtime.socket") +def test_killswitch_enabled_exits_ok( + mock_socket: MagicMock, + mock_init_logger: MagicMock, + mock_gni: MagicMock, + mock_derived: MagicMock, +) -> None: + """When killswitch_getter returns True, sys.exit should be called with 0.""" + mock_socket.gethostname.return_value = "testnode01" + mock_init_logger.return_value = (logging.getLogger("test"), MagicMock()) + mock_gni.get_gpu_node_id.return_value = "gpu-0" + + with pytest.raises(SystemExit) as exc_info: + with _make_runtime(killswitch_getter=lambda: True): + pytest.fail("With block body should not be reached when killswitch is enabled") + + assert exc_info.value.code == ExitCode.OK.value + + +@patch("gcm.health_checks.check_utils.runtime.get_derived_cluster", return_value="test_cluster") +@patch("gcm.health_checks.check_utils.runtime.gni_lib") +@patch("gcm.health_checks.check_utils.runtime.init_logger") +@patch("gcm.health_checks.check_utils.runtime.socket") +def test_killswitch_disabled_continues( + mock_socket: MagicMock, + mock_init_logger: MagicMock, + mock_gni: MagicMock, + mock_derived: MagicMock, +) -> None: + """When killswitch_getter returns False, the with block body should execute normally.""" + mock_socket.gethostname.return_value = "testnode01" + mock_init_logger.return_value = (logging.getLogger("test"), MagicMock()) + mock_gni.get_gpu_node_id.return_value = "gpu-0" + + body_executed = False + with _make_runtime(killswitch_getter=lambda: False) as rt: + body_executed = True + rt.exit_code = ExitCode.OK + rt.msg = "all good" + + assert body_executed + + +@patch("gcm.health_checks.check_utils.runtime.get_derived_cluster", return_value="test_cluster") +@patch("gcm.health_checks.check_utils.runtime.gni_lib") +@patch("gcm.health_checks.check_utils.runtime.init_logger") +@patch("gcm.health_checks.check_utils.runtime.socket") +def test_finish_sets_code_and_exits( + mock_socket: MagicMock, + mock_init_logger: MagicMock, + mock_gni: MagicMock, + mock_derived: MagicMock, +) -> None: + """finish() should set exit_code and msg, then call sys.exit with the code value.""" + mock_socket.gethostname.return_value = "testnode01" + mock_init_logger.return_value = (logging.getLogger("test"), MagicMock()) + mock_gni.get_gpu_node_id.return_value = "gpu-0" + + with pytest.raises(SystemExit) as exc_info: + with _make_runtime() as rt: + rt.finish(ExitCode.CRITICAL, "something broke") + + assert exc_info.value.code == ExitCode.CRITICAL.value + assert rt.exit_code == ExitCode.CRITICAL + assert rt.msg == "something broke" + + +@patch("gcm.health_checks.check_utils.runtime.OutputContext") +@patch("gcm.health_checks.check_utils.runtime.TelemetryContext") +@patch("gcm.health_checks.check_utils.runtime.get_derived_cluster", return_value="test_cluster") +@patch("gcm.health_checks.check_utils.runtime.gni_lib") +@patch("gcm.health_checks.check_utils.runtime.init_logger") +@patch("gcm.health_checks.check_utils.runtime.socket") +def test_telemetry_and_output_contexts_entered( + mock_socket: MagicMock, + mock_init_logger: MagicMock, + mock_gni: MagicMock, + mock_derived: MagicMock, + mock_telemetry_cls: MagicMock, + mock_output_cls: MagicMock, +) -> None: + """Both TelemetryContext and OutputContext should be entered during __enter__.""" + mock_socket.gethostname.return_value = "testnode01" + mock_init_logger.return_value = (logging.getLogger("test"), MagicMock()) + mock_gni.get_gpu_node_id.return_value = "gpu-0" + + mock_telem_instance = MagicMock() + mock_telemetry_cls.return_value = mock_telem_instance + mock_output_instance = MagicMock() + mock_output_cls.return_value = mock_output_instance + + with _make_runtime() as rt: + rt.exit_code = ExitCode.OK + + mock_telem_instance.__enter__.assert_called_once() + mock_output_instance.__enter__.assert_called_once() + + +@patch("gcm.health_checks.check_utils.runtime.get_derived_cluster", return_value="test_cluster") +@patch("gcm.health_checks.check_utils.runtime.gni_lib") +@patch("gcm.health_checks.check_utils.runtime.init_logger") +@patch("gcm.health_checks.check_utils.runtime.socket") +def test_gpu_node_id_failure_handled( + mock_socket: MagicMock, + mock_init_logger: MagicMock, + mock_gni: MagicMock, + mock_derived: MagicMock, +) -> None: + """When gni_lib.get_gpu_node_id raises, gpu_node_id should be None and a warning logged.""" + mock_socket.gethostname.return_value = "testnode01" + test_logger = logging.getLogger("test_gpu_failure") + mock_init_logger.return_value = (test_logger, MagicMock()) + mock_gni.get_gpu_node_id.side_effect = RuntimeError("not a GPU host") + + with _make_runtime() as rt: + assert rt.gpu_node_id is None + rt.exit_code = ExitCode.OK From a076093fed0c4fb99bf7694fee90e155b2f22911 Mon Sep 17 00:00:00 2001 From: Gustavo Lima Date: Sun, 1 Mar 2026 12:51:58 +0100 Subject: [PATCH 2/2] Fix formatting (ufmt) and type annotations (mypy) Apply ufmt formatting and fix mypy errors in test helper by using explicit typed parameters instead of **kwargs dict unpacking. --- gcm/health_checks/check_utils/runtime.py | 4 +- gcm/tests/health_checks_tests/test_runtime.py | 45 ++++++++++++++----- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/gcm/health_checks/check_utils/runtime.py b/gcm/health_checks/check_utils/runtime.py index abdb312..96cd3fa 100644 --- a/gcm/health_checks/check_utils/runtime.py +++ b/gcm/health_checks/check_utils/runtime.py @@ -59,7 +59,9 @@ def __enter__(self) -> "HealthCheckRuntime": self.gpu_node_id = gni_lib.get_gpu_node_id() except Exception as e: self.gpu_node_id = None - self.logger.warning(f"Could not get gpu_node_id, likely not a GPU host: {e}") + self.logger.warning( + f"Could not get gpu_node_id, likely not a GPU host: {e}" + ) self.derived_cluster = get_derived_cluster( cluster=self.cluster, diff --git a/gcm/tests/health_checks_tests/test_runtime.py b/gcm/tests/health_checks_tests/test_runtime.py index bf5d4af..f665d90 100644 --- a/gcm/tests/health_checks_tests/test_runtime.py +++ b/gcm/tests/health_checks_tests/test_runtime.py @@ -3,6 +3,7 @@ """Tests for the HealthCheckRuntime context manager.""" import logging +from typing import Callable from unittest.mock import MagicMock, patch import pytest @@ -11,8 +12,10 @@ from gcm.schemas.health_check.health_check_name import HealthCheckName -def _make_runtime(**kwargs) -> HealthCheckRuntime: - defaults = dict( +def _make_runtime( + killswitch_getter: Callable[[], bool] = lambda: False, +) -> HealthCheckRuntime: + return HealthCheckRuntime( cluster="test_cluster", type="prolog", log_level="INFO", @@ -22,13 +25,14 @@ def _make_runtime(**kwargs) -> HealthCheckRuntime: verbose_out=False, heterogeneous_cluster_v1=False, health_check_name=HealthCheckName.CHECK_SENSORS, - killswitch_getter=lambda: False, + killswitch_getter=killswitch_getter, ) - defaults.update(kwargs) - return HealthCheckRuntime(**defaults) -@patch("gcm.health_checks.check_utils.runtime.get_derived_cluster", return_value="derived_test") +@patch( + "gcm.health_checks.check_utils.runtime.get_derived_cluster", + return_value="derived_test", +) @patch("gcm.health_checks.check_utils.runtime.gni_lib") @patch("gcm.health_checks.check_utils.runtime.init_logger") @patch("gcm.health_checks.check_utils.runtime.socket") @@ -52,7 +56,10 @@ def test_enter_initializes_fields( assert runtime.derived_cluster == "derived_test" -@patch("gcm.health_checks.check_utils.runtime.get_derived_cluster", return_value="test_cluster") +@patch( + "gcm.health_checks.check_utils.runtime.get_derived_cluster", + return_value="test_cluster", +) @patch("gcm.health_checks.check_utils.runtime.gni_lib") @patch("gcm.health_checks.check_utils.runtime.init_logger") @patch("gcm.health_checks.check_utils.runtime.socket") @@ -69,12 +76,17 @@ def test_killswitch_enabled_exits_ok( with pytest.raises(SystemExit) as exc_info: with _make_runtime(killswitch_getter=lambda: True): - pytest.fail("With block body should not be reached when killswitch is enabled") + pytest.fail( + "With block body should not be reached when killswitch is enabled" + ) assert exc_info.value.code == ExitCode.OK.value -@patch("gcm.health_checks.check_utils.runtime.get_derived_cluster", return_value="test_cluster") +@patch( + "gcm.health_checks.check_utils.runtime.get_derived_cluster", + return_value="test_cluster", +) @patch("gcm.health_checks.check_utils.runtime.gni_lib") @patch("gcm.health_checks.check_utils.runtime.init_logger") @patch("gcm.health_checks.check_utils.runtime.socket") @@ -98,7 +110,10 @@ def test_killswitch_disabled_continues( assert body_executed -@patch("gcm.health_checks.check_utils.runtime.get_derived_cluster", return_value="test_cluster") +@patch( + "gcm.health_checks.check_utils.runtime.get_derived_cluster", + return_value="test_cluster", +) @patch("gcm.health_checks.check_utils.runtime.gni_lib") @patch("gcm.health_checks.check_utils.runtime.init_logger") @patch("gcm.health_checks.check_utils.runtime.socket") @@ -124,7 +139,10 @@ def test_finish_sets_code_and_exits( @patch("gcm.health_checks.check_utils.runtime.OutputContext") @patch("gcm.health_checks.check_utils.runtime.TelemetryContext") -@patch("gcm.health_checks.check_utils.runtime.get_derived_cluster", return_value="test_cluster") +@patch( + "gcm.health_checks.check_utils.runtime.get_derived_cluster", + return_value="test_cluster", +) @patch("gcm.health_checks.check_utils.runtime.gni_lib") @patch("gcm.health_checks.check_utils.runtime.init_logger") @patch("gcm.health_checks.check_utils.runtime.socket") @@ -153,7 +171,10 @@ def test_telemetry_and_output_contexts_entered( mock_output_instance.__enter__.assert_called_once() -@patch("gcm.health_checks.check_utils.runtime.get_derived_cluster", return_value="test_cluster") +@patch( + "gcm.health_checks.check_utils.runtime.get_derived_cluster", + return_value="test_cluster", +) @patch("gcm.health_checks.check_utils.runtime.gni_lib") @patch("gcm.health_checks.check_utils.runtime.init_logger") @patch("gcm.health_checks.check_utils.runtime.socket")