From 1156c7e130a8815f3411e0dfb6464daf74addabf Mon Sep 17 00:00:00 2001 From: Lukifer23 <56565060+lukifer23@users.noreply.github.com> Date: Sun, 12 Oct 2025 16:56:33 -0500 Subject: [PATCH] Cache Matrix0 parameter count for health endpoint --- tests/test_webui_server.py | 31 +++++++++++++++++++++++++ webui/server.py | 47 +++++++++++++++++++++++++++++++------- 2 files changed, 70 insertions(+), 8 deletions(-) diff --git a/tests/test_webui_server.py b/tests/test_webui_server.py index bdd3d12..5f8cbec 100644 --- a/tests/test_webui_server.py +++ b/tests/test_webui_server.py @@ -1,5 +1,6 @@ import chess import pytest +from fastapi.testclient import TestClient from webui import server from webui.server import HTTPException, NewGameRequest @@ -9,6 +10,9 @@ def stub_dependencies(monkeypatch): monkeypatch.setattr(server, "_load_matrix0", lambda *args, **kwargs: None) monkeypatch.setattr(server, "_jsonl_write", lambda *args, **kwargs: None) + server._matrix0_model = None + server._matrix0_model_params = None + server._cfg = None server.GAMES.clear() @@ -56,3 +60,30 @@ def test_system_metrics_endpoint(monkeypatch): assert isinstance(metrics["timestamp"], float) server.GAMES.clear() + + +def test_health_caches_model_parameter_count(monkeypatch): + call_count = 0 + + class DummyModel: + def count_parameters(self): + return 123 + + def fake_from_config(cfg): + nonlocal call_count + call_count += 1 + return DummyModel() + + monkeypatch.setattr(server.PolicyValueNet, "from_config", staticmethod(fake_from_config)) + + client = TestClient(server.app) + + resp1 = client.get("/health") + assert resp1.status_code == 200 + assert resp1.json()["model_params"] == 123 + + resp2 = client.get("/health") + assert resp2.status_code == 200 + assert resp2.json()["model_params"] == 123 + + assert call_count == 1 diff --git a/webui/server.py b/webui/server.py index bfc7a65..96a3a7f 100644 --- a/webui/server.py +++ b/webui/server.py @@ -173,6 +173,7 @@ async def favicon(): # Lazy-loaded engines _matrix0_model = None +_matrix0_model_params: Optional[int] = None _matrix0_mcts = None _device = None _cfg: Config | None = None @@ -736,7 +737,7 @@ def orchestrator_workers(): def _load_matrix0(cfg_path: str = "config.yaml", ckpt: Optional[str] = None, device_pref: str = "cpu"): - global _matrix0_model, _matrix0_mcts, _device, _cfg + global _matrix0_model, _matrix0_mcts, _device, _cfg, _matrix0_model_params if _matrix0_model is not None and _matrix0_mcts is not None: return _cfg = Config.load(cfg_path) @@ -753,6 +754,10 @@ def _load_matrix0(cfg_path: str = "config.yaml", ckpt: Optional[str] = None, dev state = torch.load(ckpt_path, map_location=_device, weights_only=False) model.load_state_dict(state.get("model_ema", state.get("model", {}))) _matrix0_model = model.to(_device) + try: + _matrix0_model_params = _matrix0_model.count_parameters() + except Exception: + _matrix0_model_params = None e = cfg.eval() mcfg_dict = dict(cfg.mcts()) mcfg_dict.update( @@ -769,6 +774,38 @@ def _load_matrix0(cfg_path: str = "config.yaml", ckpt: Optional[str] = None, dev _matrix0_mcts = MCTS(_matrix0_model, MCTSConfig.from_dict(mcfg_dict), _device) +def _get_matrix0_param_count() -> Optional[int]: + """Return the cached parameter count, computing it lazily if needed.""" + + global _matrix0_model_params, _cfg + + if _matrix0_model_params is not None: + return _matrix0_model_params + + if _matrix0_model is not None: + try: + _matrix0_model_params = _matrix0_model.count_parameters() + return _matrix0_model_params + except Exception as exc: # pragma: no cover - defensive fallback + logger.debug("Failed to count parameters on loaded model: %s", exc) + return None + + try: + cfg = _cfg or Config.load("config.yaml") + if _cfg is None: + _cfg = cfg + model_cfg = cfg.model() if cfg else {} + model = PolicyValueNet.from_config(model_cfg) + try: + _matrix0_model_params = model.count_parameters() + finally: + del model + return _matrix0_model_params + except Exception as exc: # pragma: no cover - defensive fallback + logger.debug("Unable to compute Matrix0 parameter count: %s", exc) + return None + + def _load_stockfish() -> Optional["chess.engine.SimpleEngine"]: global _stockfish if not HAVE_ENGINE: @@ -1271,13 +1308,7 @@ def get_pgn(name: str): @app.get("/health") def health(): - # lightweight model info - try: - cfg = Config.load("config.yaml") - model = PolicyValueNet.from_config(cfg.model()) - params = model.count_parameters() - except Exception: - params = None + params = _get_matrix0_param_count() sf_available = _load_stockfish() is not None return {"stockfish": sf_available, "model_params": params, "device": _device}