Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/test_webui_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import chess
import pytest
from fastapi.testclient import TestClient

from webui import server
from webui.server import HTTPException, NewGameRequest
Expand All @@ -9,6 +10,9 @@
def stub_dependencies(monkeypatch):
    """Stub out heavy server dependencies and reset cached module state.

    NOTE(review): presumably decorated as a pytest fixture — the decorator is
    outside this view; confirm it is autouse so every test gets a clean slate.

    Replaces model loading and JSONL logging with no-ops, then clears the
    server's cached model/config globals and the in-memory games table so
    tests do not leak state into one another.
    """
    monkeypatch.setattr(server, "_load_matrix0", lambda *args, **kwargs: None)
    monkeypatch.setattr(server, "_jsonl_write", lambda *args, **kwargs: None)
    server._matrix0_model = None
    server._matrix0_model_params = None
    server._cfg = None
    server.GAMES.clear()


Expand Down Expand Up @@ -56,3 +60,30 @@ def test_system_metrics_endpoint(monkeypatch):
assert isinstance(metrics["timestamp"], float)

server.GAMES.clear()


def test_health_caches_model_parameter_count(monkeypatch):
    """Hitting /health twice should construct the model once and reuse the cached count."""
    constructions = 0

    class FakeModel:
        def count_parameters(self):
            return 123

    def tracked_from_config(cfg):
        nonlocal constructions
        constructions += 1
        return FakeModel()

    monkeypatch.setattr(
        server.PolicyValueNet, "from_config", staticmethod(tracked_from_config)
    )

    client = TestClient(server.app)

    # Both requests must report the same parameter count ...
    for _ in range(2):
        response = client.get("/health")
        assert response.status_code == 200
        assert response.json()["model_params"] == 123

    # ... while the model factory ran exactly once (second call hit the cache).
    assert constructions == 1
47 changes: 39 additions & 8 deletions webui/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ async def favicon():

# Lazy-loaded engines
_matrix0_model = None
_matrix0_model_params: Optional[int] = None
_matrix0_mcts = None
_device = None
_cfg: Config | None = None
Expand Down Expand Up @@ -736,7 +737,7 @@ def orchestrator_workers():


def _load_matrix0(cfg_path: str = "config.yaml", ckpt: Optional[str] = None, device_pref: str = "cpu"):
global _matrix0_model, _matrix0_mcts, _device, _cfg
global _matrix0_model, _matrix0_mcts, _device, _cfg, _matrix0_model_params
if _matrix0_model is not None and _matrix0_mcts is not None:
return
_cfg = Config.load(cfg_path)
Expand All @@ -753,6 +754,10 @@ def _load_matrix0(cfg_path: str = "config.yaml", ckpt: Optional[str] = None, dev
state = torch.load(ckpt_path, map_location=_device, weights_only=False)
model.load_state_dict(state.get("model_ema", state.get("model", {})))
_matrix0_model = model.to(_device)
try:
_matrix0_model_params = _matrix0_model.count_parameters()
except Exception:
_matrix0_model_params = None
e = cfg.eval()
mcfg_dict = dict(cfg.mcts())
mcfg_dict.update(
Expand All @@ -769,6 +774,38 @@ def _load_matrix0(cfg_path: str = "config.yaml", ckpt: Optional[str] = None, dev
_matrix0_mcts = MCTS(_matrix0_model, MCTSConfig.from_dict(mcfg_dict), _device)


def _get_matrix0_param_count() -> Optional[int]:
    """Return the cached Matrix0 parameter count, computing it lazily if needed.

    Resolution order:
      1. The module-level cache ``_matrix0_model_params`` (set by ``_load_matrix0``
         or a previous call here).
      2. The already-loaded model, if any.
      3. A throwaway model built from config purely to count parameters.

    Returns ``None`` when the count cannot be determined; this endpoint data is
    advisory, so failures are logged at debug level rather than raised.
    """

    # _matrix0_model is only *read* below, so it does not need a ``global``
    # declaration — only the names assigned in this function do.
    global _matrix0_model_params, _cfg

    if _matrix0_model_params is not None:
        return _matrix0_model_params

    if _matrix0_model is not None:
        try:
            _matrix0_model_params = _matrix0_model.count_parameters()
            return _matrix0_model_params
        # Broad catch is deliberate: /health must never fail because a model
        # implementation lacks or breaks count_parameters().
        except Exception as exc:  # pragma: no cover - defensive fallback
            logger.debug("Failed to count parameters on loaded model: %s", exc)
            return None

    try:
        cfg = _cfg or Config.load("config.yaml")
        if _cfg is None:
            _cfg = cfg
        model_cfg = cfg.model() if cfg else {}
        model = PolicyValueNet.from_config(model_cfg)
        try:
            _matrix0_model_params = model.count_parameters()
        finally:
            # Free the throwaway model promptly; it was built only for counting.
            del model
        return _matrix0_model_params
    # Broad catch is deliberate: config may be missing/invalid or construction
    # may fail for many reasons, and /health should degrade to params=None.
    except Exception as exc:  # pragma: no cover - defensive fallback
        logger.debug("Unable to compute Matrix0 parameter count: %s", exc)
        return None


def _load_stockfish() -> Optional["chess.engine.SimpleEngine"]:
global _stockfish
if not HAVE_ENGINE:
Expand Down Expand Up @@ -1271,13 +1308,7 @@ def get_pgn(name: str):

@app.get("/health")
def health():
    """Lightweight health probe.

    Reports Stockfish availability, the (cached) Matrix0 parameter count, and
    the active device. Uses the lazy cached helper instead of rebuilding the
    model from config on every request.
    """
    params = _get_matrix0_param_count()
    sf_available = _load_stockfish() is not None
    return {"stockfish": sf_available, "model_params": params, "device": _device}

Expand Down
Loading