# File: gcm/health_checks/check_utils/runtime.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import logging
import socket
import sys
import types
from collections.abc import Collection
from contextlib import ExitStack
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, ContextManager, Literal, NoReturn, Optional, Type

import gni_lib
from gcm.health_checks.check_utils.output_context_manager import OutputContext
from gcm.health_checks.check_utils.telem import TelemetryContext
from gcm.health_checks.types import CHECK_TYPE, ExitCode, LOG_LEVEL
from gcm.monitoring.slurm.derived_cluster import get_derived_cluster
from gcm.monitoring.utils.monitor import init_logger
from gcm.schemas.health_check.health_check_name import HealthCheckName


@dataclass
class HealthCheckRuntime(ContextManager["HealthCheckRuntime"]):
    """Context manager bundling the setup shared by a health check run.

    On ``__enter__`` it resolves the host name, creates the logger, looks up
    the GPU node id (falling back to ``None``), computes the derived cluster,
    enters a ``TelemetryContext`` and an ``OutputContext``, and finally
    consults ``killswitch_getter``.  Both nested contexts are handed a getter
    for the current ``(exit_code, msg)`` pair, so the check body reports its
    result by assigning those attributes or by calling :meth:`finish`.

    Fields without ``init=False`` are constructor arguments; the remaining
    fields are populated by ``__enter__`` and must not be read before the
    runtime has been entered.
    """

    cluster: str
    type: CHECK_TYPE
    log_level: LOG_LEVEL
    log_folder: str
    sink: str
    sink_opts: Collection[str]
    verbose_out: bool
    heterogeneous_cluster_v1: bool
    health_check_name: HealthCheckName
    # Called once at the end of __enter__; a truthy result disables the check.
    killswitch_getter: Callable[[], bool]

    # --- populated by __enter__ ---
    logger: logging.Logger = field(init=False)
    node: str = field(init=False)
    gpu_node_id: Optional[str] = field(init=False)
    derived_cluster: str = field(init=False)
    # --- result reported through the nested contexts ---
    exit_code: ExitCode = field(init=False, default=ExitCode.UNKNOWN)
    msg: str = field(init=False, default="")
    # Owns the nested telemetry/output contexts between __enter__ and __exit__.
    _stack: ExitStack = field(init=False)

    def __enter__(self) -> "HealthCheckRuntime":
        """Initialise runtime state and enter the telemetry/output contexts.

        Returns:
            This runtime instance.

        Raises:
            SystemExit: with code 0 when ``killswitch_getter()`` is truthy.
                The nested contexts are exited before the exception
                propagates, with ``exit_code`` already set to
                ``ExitCode.OK``.
        """
        self.node = socket.gethostname()
        self.logger, _ = init_logger(
            logger_name=self.type,
            log_dir=str(Path(self.log_folder) / self.type / "_logs"),
            log_name=self.node + ".log",
            log_level=getattr(logging, self.log_level),
        )
        self.logger.info(
            "%s: cluster: %s, node: %s, type: %s",
            self.health_check_name.value,
            self.cluster,
            self.node,
            self.type,
        )
        try:
            self.gpu_node_id = gni_lib.get_gpu_node_id()
        except Exception as e:
            # Deliberately best-effort: a failed lookup is logged and the
            # check carries on without a GPU node id.
            self.gpu_node_id = None
            self.logger.warning(
                f"Could not get gpu_node_id, likely not a GPU host: {e}"
            )

        self.derived_cluster = get_derived_cluster(
            cluster=self.cluster,
            heterogeneous_cluster_v1=self.heterogeneous_cluster_v1,
            data={"Node": self.node},
        )

        # The ``with`` statement does not call our __exit__ when __enter__
        # raises, so contexts entered here would otherwise never be exited
        # (notably on the killswitch ``sys.exit``).  Enter them on a local
        # stack that unwinds on any exception, and hand ownership over to
        # ``self._stack`` only once setup has fully succeeded.
        with ExitStack() as stack:
            stack.enter_context(
                TelemetryContext(
                    sink=self.sink,
                    sink_opts=self.sink_opts,
                    logger=self.logger,
                    cluster=self.cluster,
                    derived_cluster=self.derived_cluster,
                    type=self.type,
                    name=self.health_check_name.value,
                    node=self.node,
                    get_exit_code_msg=lambda: (self.exit_code, self.msg),
                    gpu_node_id=self.gpu_node_id,
                )
            )
            stack.enter_context(
                OutputContext(
                    self.type,
                    self.health_check_name,
                    lambda: (self.exit_code, self.msg),
                    self.verbose_out,
                )
            )

            if self.killswitch_getter():
                self.exit_code = ExitCode.OK
                self.logger.info(
                    "%s is disabled by killswitch",
                    self.health_check_name.value,
                )
                sys.exit(0)

            # Success: detach the callbacks so leaving this ``with`` is a
            # no-op and __exit__ becomes responsible for unwinding them.
            self._stack = stack.pop_all()

        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[types.TracebackType],
    ) -> Literal[False]:
        """Exit the nested contexts; never suppress the in-flight exception."""
        self._stack.__exit__(exc_type, exc_val, exc_tb)
        # Always False so SystemExit raised by finish() (or any other error)
        # keeps propagating to the caller.
        return False

    def finish(self, exit_code: ExitCode, msg: str) -> NoReturn:
        """Record the final result and terminate via ``sys.exit``.

        Args:
            exit_code: Result of the check; its ``value`` is the exit status.
            msg: Message stored alongside the exit code.

        Raises:
            SystemExit: always, with ``exit_code.value``.
        """
        self.exit_code = exit_code
        self.msg = msg
        sys.exit(exit_code.value)


# File: gcm/tests/health_checks_tests/test_runtime.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
"""Tests for the HealthCheckRuntime context manager."""

import logging
from typing import Callable, Optional
from unittest.mock import MagicMock, patch

import pytest
from gcm.health_checks.check_utils.runtime import HealthCheckRuntime
from gcm.health_checks.types import ExitCode
from gcm.schemas.health_check.health_check_name import HealthCheckName

# Module whose collaborators are patched in every test below.
_RUNTIME = "gcm.health_checks.check_utils.runtime"


def _make_runtime(
    killswitch_getter: Callable[[], bool] = lambda: False,
) -> HealthCheckRuntime:
    """Build a runtime with fixed test arguments and the given killswitch."""
    return HealthCheckRuntime(
        cluster="test_cluster",
        type="prolog",
        log_level="INFO",
        log_folder="/tmp",
        sink="do_nothing",
        sink_opts=(),
        verbose_out=False,
        heterogeneous_cluster_v1=False,
        health_check_name=HealthCheckName.CHECK_SENSORS,
        killswitch_getter=killswitch_getter,
    )


def _configure_mocks(
    mock_socket: MagicMock,
    mock_init_logger: MagicMock,
    mock_gni: MagicMock,
    logger: Optional[logging.Logger] = None,
) -> logging.Logger:
    """Apply the host name / logger / GPU id stubs shared by every test.

    Returns the logger that the patched ``init_logger`` will hand out.
    """
    if logger is None:
        logger = logging.getLogger("test")
    mock_socket.gethostname.return_value = "testnode01"
    mock_init_logger.return_value = (logger, MagicMock())
    mock_gni.get_gpu_node_id.return_value = "gpu-0"
    return logger


@patch(f"{_RUNTIME}.get_derived_cluster", return_value="derived_test")
@patch(f"{_RUNTIME}.gni_lib")
@patch(f"{_RUNTIME}.init_logger")
@patch(f"{_RUNTIME}.socket")
def test_enter_initializes_fields(
    mock_socket: MagicMock,
    mock_init_logger: MagicMock,
    mock_gni: MagicMock,
    mock_derived: MagicMock,
) -> None:
    """Verify __enter__ populates logger, node, gpu_node_id, and derived_cluster."""
    test_logger = _configure_mocks(mock_socket, mock_init_logger, mock_gni)

    rt = _make_runtime()
    with rt as runtime:
        assert runtime.node == "testnode01"
        assert runtime.logger is test_logger
        assert runtime.gpu_node_id == "gpu-0"
        assert runtime.derived_cluster == "derived_test"


@patch(f"{_RUNTIME}.get_derived_cluster", return_value="test_cluster")
@patch(f"{_RUNTIME}.gni_lib")
@patch(f"{_RUNTIME}.init_logger")
@patch(f"{_RUNTIME}.socket")
def test_killswitch_enabled_exits_ok(
    mock_socket: MagicMock,
    mock_init_logger: MagicMock,
    mock_gni: MagicMock,
    mock_derived: MagicMock,
) -> None:
    """When killswitch_getter returns True, sys.exit should be called with 0."""
    _configure_mocks(mock_socket, mock_init_logger, mock_gni)

    with pytest.raises(SystemExit) as exc_info:
        with _make_runtime(killswitch_getter=lambda: True):
            pytest.fail(
                "With block body should not be reached when killswitch is enabled"
            )

    assert exc_info.value.code == ExitCode.OK.value


@patch(f"{_RUNTIME}.get_derived_cluster", return_value="test_cluster")
@patch(f"{_RUNTIME}.gni_lib")
@patch(f"{_RUNTIME}.init_logger")
@patch(f"{_RUNTIME}.socket")
def test_killswitch_disabled_continues(
    mock_socket: MagicMock,
    mock_init_logger: MagicMock,
    mock_gni: MagicMock,
    mock_derived: MagicMock,
) -> None:
    """When killswitch_getter returns False, the with block body should execute normally."""
    _configure_mocks(mock_socket, mock_init_logger, mock_gni)

    body_executed = False
    with _make_runtime(killswitch_getter=lambda: False) as rt:
        body_executed = True
        rt.exit_code = ExitCode.OK
        rt.msg = "all good"

    assert body_executed


@patch(f"{_RUNTIME}.get_derived_cluster", return_value="test_cluster")
@patch(f"{_RUNTIME}.gni_lib")
@patch(f"{_RUNTIME}.init_logger")
@patch(f"{_RUNTIME}.socket")
def test_finish_sets_code_and_exits(
    mock_socket: MagicMock,
    mock_init_logger: MagicMock,
    mock_gni: MagicMock,
    mock_derived: MagicMock,
) -> None:
    """finish() should set exit_code and msg, then call sys.exit with the code value."""
    _configure_mocks(mock_socket, mock_init_logger, mock_gni)

    with pytest.raises(SystemExit) as exc_info:
        with _make_runtime() as rt:
            rt.finish(ExitCode.CRITICAL, "something broke")

    assert exc_info.value.code == ExitCode.CRITICAL.value
    assert rt.exit_code == ExitCode.CRITICAL
    assert rt.msg == "something broke"


@patch(f"{_RUNTIME}.OutputContext")
@patch(f"{_RUNTIME}.TelemetryContext")
@patch(f"{_RUNTIME}.get_derived_cluster", return_value="test_cluster")
@patch(f"{_RUNTIME}.gni_lib")
@patch(f"{_RUNTIME}.init_logger")
@patch(f"{_RUNTIME}.socket")
def test_telemetry_and_output_contexts_entered(
    mock_socket: MagicMock,
    mock_init_logger: MagicMock,
    mock_gni: MagicMock,
    mock_derived: MagicMock,
    mock_telemetry_cls: MagicMock,
    mock_output_cls: MagicMock,
) -> None:
    """Both TelemetryContext and OutputContext should be entered and exited once."""
    _configure_mocks(mock_socket, mock_init_logger, mock_gni)

    mock_telem_instance = MagicMock()
    mock_telemetry_cls.return_value = mock_telem_instance
    mock_output_instance = MagicMock()
    mock_output_cls.return_value = mock_output_instance

    with _make_runtime() as rt:
        rt.exit_code = ExitCode.OK

    mock_telem_instance.__enter__.assert_called_once()
    mock_output_instance.__enter__.assert_called_once()
    # Leaving the runtime must unwind both nested contexts as well.
    mock_telem_instance.__exit__.assert_called_once()
    mock_output_instance.__exit__.assert_called_once()


@patch(f"{_RUNTIME}.get_derived_cluster", return_value="test_cluster")
@patch(f"{_RUNTIME}.gni_lib")
@patch(f"{_RUNTIME}.init_logger")
@patch(f"{_RUNTIME}.socket")
def test_gpu_node_id_failure_handled(
    mock_socket: MagicMock,
    mock_init_logger: MagicMock,
    mock_gni: MagicMock,
    mock_derived: MagicMock,
) -> None:
    """When gni_lib.get_gpu_node_id raises, gpu_node_id should be None and a warning logged."""
    test_logger = logging.getLogger("test_gpu_failure")
    _configure_mocks(mock_socket, mock_init_logger, mock_gni, logger=test_logger)
    mock_gni.get_gpu_node_id.side_effect = RuntimeError("not a GPU host")

    # Intercept only ``warning`` so the rest of the logger stays real for the
    # telemetry/output contexts that also receive it.
    with patch.object(test_logger, "warning") as mock_warning:
        with _make_runtime() as rt:
            assert rt.gpu_node_id is None
            rt.exit_code = ExitCode.OK

    # Match on the message text rather than the call count: other components
    # sharing this logger may emit warnings of their own.
    assert any(
        call.args and "gpu_node_id" in str(call.args[0])
        for call in mock_warning.call_args_list
    )