Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions gcm/health_checks/check_utils/runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import logging
import socket
import sys
import types
from collections.abc import Collection
from contextlib import ExitStack
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, ContextManager, Literal, NoReturn, Optional, Type

import gni_lib
from gcm.health_checks.check_utils.output_context_manager import OutputContext
from gcm.health_checks.check_utils.telem import TelemetryContext
from gcm.health_checks.types import CHECK_TYPE, ExitCode, LOG_LEVEL
from gcm.monitoring.slurm.derived_cluster import get_derived_cluster
from gcm.monitoring.utils.monitor import init_logger
from gcm.schemas.health_check.health_check_name import HealthCheckName


@dataclass
class HealthCheckRuntime(ContextManager["HealthCheckRuntime"]):
cluster: str
type: CHECK_TYPE
log_level: LOG_LEVEL
log_folder: str
sink: str
sink_opts: Collection[str]
verbose_out: bool
heterogeneous_cluster_v1: bool
health_check_name: HealthCheckName
killswitch_getter: Callable[[], bool]

logger: logging.Logger = field(init=False)
node: str = field(init=False)
gpu_node_id: Optional[str] = field(init=False)
derived_cluster: str = field(init=False)
exit_code: ExitCode = field(init=False, default=ExitCode.UNKNOWN)
msg: str = field(init=False, default="")
_stack: ExitStack = field(init=False)

def __enter__(self) -> "HealthCheckRuntime":
self.node = socket.gethostname()
self.logger, _ = init_logger(
logger_name=self.type,
log_dir=str(Path(self.log_folder) / self.type / "_logs"),
log_name=self.node + ".log",
log_level=getattr(logging, self.log_level),
)
self.logger.info(
"%s: cluster: %s, node: %s, type: %s",
self.health_check_name.value,
self.cluster,
self.node,
self.type,
)
try:
self.gpu_node_id = gni_lib.get_gpu_node_id()
except Exception as e:
self.gpu_node_id = None
self.logger.warning(
f"Could not get gpu_node_id, likely not a GPU host: {e}"
)

self.derived_cluster = get_derived_cluster(
cluster=self.cluster,
heterogeneous_cluster_v1=self.heterogeneous_cluster_v1,
data={"Node": self.node},
)

self._stack = ExitStack()
self._stack.__enter__()
self._stack.enter_context(
TelemetryContext(
sink=self.sink,
sink_opts=self.sink_opts,
logger=self.logger,
cluster=self.cluster,
derived_cluster=self.derived_cluster,
type=self.type,
name=self.health_check_name.value,
node=self.node,
get_exit_code_msg=lambda: (self.exit_code, self.msg),
gpu_node_id=self.gpu_node_id,
)
)
self._stack.enter_context(
OutputContext(
self.type,
self.health_check_name,
lambda: (self.exit_code, self.msg),
self.verbose_out,
)
)

if self.killswitch_getter():
self.exit_code = ExitCode.OK
self.logger.info(
"%s is disabled by killswitch",
self.health_check_name.value,
)
sys.exit(0)

return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[types.TracebackType],
) -> Literal[False]:
self._stack.__exit__(exc_type, exc_val, exc_tb)
return False

def finish(self, exit_code: ExitCode, msg: str) -> NoReturn:
self.exit_code = exit_code
self.msg = msg
sys.exit(exit_code.value)
195 changes: 195 additions & 0 deletions gcm/tests/health_checks_tests/test_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
"""Tests for the HealthCheckRuntime context manager."""

import logging
from typing import Callable
from unittest.mock import MagicMock, patch

import pytest
from gcm.health_checks.check_utils.runtime import HealthCheckRuntime
from gcm.health_checks.types import ExitCode
from gcm.schemas.health_check.health_check_name import HealthCheckName


def _make_runtime(
killswitch_getter: Callable[[], bool] = lambda: False,
) -> HealthCheckRuntime:
return HealthCheckRuntime(
cluster="test_cluster",
type="prolog",
log_level="INFO",
log_folder="/tmp",
sink="do_nothing",
sink_opts=(),
verbose_out=False,
heterogeneous_cluster_v1=False,
health_check_name=HealthCheckName.CHECK_SENSORS,
killswitch_getter=killswitch_getter,
)


@patch(
"gcm.health_checks.check_utils.runtime.get_derived_cluster",
return_value="derived_test",
)
@patch("gcm.health_checks.check_utils.runtime.gni_lib")
@patch("gcm.health_checks.check_utils.runtime.init_logger")
@patch("gcm.health_checks.check_utils.runtime.socket")
def test_enter_initializes_fields(
mock_socket: MagicMock,
mock_init_logger: MagicMock,
mock_gni: MagicMock,
mock_derived: MagicMock,
) -> None:
"""Verify __enter__ populates logger, node, gpu_node_id, and derived_cluster."""
mock_socket.gethostname.return_value = "testnode01"
test_logger = logging.getLogger("test")
mock_init_logger.return_value = (test_logger, MagicMock())
mock_gni.get_gpu_node_id.return_value = "gpu-0"

rt = _make_runtime()
with rt as runtime:
assert runtime.node == "testnode01"
assert runtime.logger is test_logger
assert runtime.gpu_node_id == "gpu-0"
assert runtime.derived_cluster == "derived_test"


@patch(
"gcm.health_checks.check_utils.runtime.get_derived_cluster",
return_value="test_cluster",
)
@patch("gcm.health_checks.check_utils.runtime.gni_lib")
@patch("gcm.health_checks.check_utils.runtime.init_logger")
@patch("gcm.health_checks.check_utils.runtime.socket")
def test_killswitch_enabled_exits_ok(
mock_socket: MagicMock,
mock_init_logger: MagicMock,
mock_gni: MagicMock,
mock_derived: MagicMock,
) -> None:
"""When killswitch_getter returns True, sys.exit should be called with 0."""
mock_socket.gethostname.return_value = "testnode01"
mock_init_logger.return_value = (logging.getLogger("test"), MagicMock())
mock_gni.get_gpu_node_id.return_value = "gpu-0"

with pytest.raises(SystemExit) as exc_info:
with _make_runtime(killswitch_getter=lambda: True):
pytest.fail(
"With block body should not be reached when killswitch is enabled"
)

assert exc_info.value.code == ExitCode.OK.value


@patch(
"gcm.health_checks.check_utils.runtime.get_derived_cluster",
return_value="test_cluster",
)
@patch("gcm.health_checks.check_utils.runtime.gni_lib")
@patch("gcm.health_checks.check_utils.runtime.init_logger")
@patch("gcm.health_checks.check_utils.runtime.socket")
def test_killswitch_disabled_continues(
mock_socket: MagicMock,
mock_init_logger: MagicMock,
mock_gni: MagicMock,
mock_derived: MagicMock,
) -> None:
"""When killswitch_getter returns False, the with block body should execute normally."""
mock_socket.gethostname.return_value = "testnode01"
mock_init_logger.return_value = (logging.getLogger("test"), MagicMock())
mock_gni.get_gpu_node_id.return_value = "gpu-0"

body_executed = False
with _make_runtime(killswitch_getter=lambda: False) as rt:
body_executed = True
rt.exit_code = ExitCode.OK
rt.msg = "all good"

assert body_executed


@patch(
"gcm.health_checks.check_utils.runtime.get_derived_cluster",
return_value="test_cluster",
)
@patch("gcm.health_checks.check_utils.runtime.gni_lib")
@patch("gcm.health_checks.check_utils.runtime.init_logger")
@patch("gcm.health_checks.check_utils.runtime.socket")
def test_finish_sets_code_and_exits(
mock_socket: MagicMock,
mock_init_logger: MagicMock,
mock_gni: MagicMock,
mock_derived: MagicMock,
) -> None:
"""finish() should set exit_code and msg, then call sys.exit with the code value."""
mock_socket.gethostname.return_value = "testnode01"
mock_init_logger.return_value = (logging.getLogger("test"), MagicMock())
mock_gni.get_gpu_node_id.return_value = "gpu-0"

with pytest.raises(SystemExit) as exc_info:
with _make_runtime() as rt:
rt.finish(ExitCode.CRITICAL, "something broke")

assert exc_info.value.code == ExitCode.CRITICAL.value
assert rt.exit_code == ExitCode.CRITICAL
assert rt.msg == "something broke"


@patch("gcm.health_checks.check_utils.runtime.OutputContext")
@patch("gcm.health_checks.check_utils.runtime.TelemetryContext")
@patch(
"gcm.health_checks.check_utils.runtime.get_derived_cluster",
return_value="test_cluster",
)
@patch("gcm.health_checks.check_utils.runtime.gni_lib")
@patch("gcm.health_checks.check_utils.runtime.init_logger")
@patch("gcm.health_checks.check_utils.runtime.socket")
def test_telemetry_and_output_contexts_entered(
mock_socket: MagicMock,
mock_init_logger: MagicMock,
mock_gni: MagicMock,
mock_derived: MagicMock,
mock_telemetry_cls: MagicMock,
mock_output_cls: MagicMock,
) -> None:
"""Both TelemetryContext and OutputContext should be entered during __enter__."""
mock_socket.gethostname.return_value = "testnode01"
mock_init_logger.return_value = (logging.getLogger("test"), MagicMock())
mock_gni.get_gpu_node_id.return_value = "gpu-0"

mock_telem_instance = MagicMock()
mock_telemetry_cls.return_value = mock_telem_instance
mock_output_instance = MagicMock()
mock_output_cls.return_value = mock_output_instance

with _make_runtime() as rt:
rt.exit_code = ExitCode.OK

mock_telem_instance.__enter__.assert_called_once()
mock_output_instance.__enter__.assert_called_once()


@patch(
"gcm.health_checks.check_utils.runtime.get_derived_cluster",
return_value="test_cluster",
)
@patch("gcm.health_checks.check_utils.runtime.gni_lib")
@patch("gcm.health_checks.check_utils.runtime.init_logger")
@patch("gcm.health_checks.check_utils.runtime.socket")
def test_gpu_node_id_failure_handled(
mock_socket: MagicMock,
mock_init_logger: MagicMock,
mock_gni: MagicMock,
mock_derived: MagicMock,
) -> None:
"""When gni_lib.get_gpu_node_id raises, gpu_node_id should be None and a warning logged."""
mock_socket.gethostname.return_value = "testnode01"
test_logger = logging.getLogger("test_gpu_failure")
mock_init_logger.return_value = (test_logger, MagicMock())
mock_gni.get_gpu_node_id.side_effect = RuntimeError("not a GPU host")

with _make_runtime() as rt:
assert rt.gpu_node_id is None
rt.exit_code = ExitCode.OK
Loading