-
Notifications
You must be signed in to change notification settings - Fork 0
Cache Matrix0 parameter count for health endpoint #104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -173,6 +173,7 @@ async def favicon(): | |||||
|
|
||||||
| # Lazy-loaded engines | ||||||
| _matrix0_model = None | ||||||
| _matrix0_model_params: Optional[int] = None | ||||||
| _matrix0_mcts = None | ||||||
| _device = None | ||||||
| _cfg: Config | None = None | ||||||
|
|
@@ -736,7 +737,7 @@ def orchestrator_workers(): | |||||
|
|
||||||
|
|
||||||
| def _load_matrix0(cfg_path: str = "config.yaml", ckpt: Optional[str] = None, device_pref: str = "cpu"): | ||||||
| global _matrix0_model, _matrix0_mcts, _device, _cfg | ||||||
| global _matrix0_model, _matrix0_mcts, _device, _cfg, _matrix0_model_params | ||||||
| if _matrix0_model is not None and _matrix0_mcts is not None: | ||||||
| return | ||||||
| _cfg = Config.load(cfg_path) | ||||||
|
|
@@ -753,6 +754,10 @@ def _load_matrix0(cfg_path: str = "config.yaml", ckpt: Optional[str] = None, dev | |||||
| state = torch.load(ckpt_path, map_location=_device, weights_only=False) | ||||||
| model.load_state_dict(state.get("model_ema", state.get("model", {}))) | ||||||
| _matrix0_model = model.to(_device) | ||||||
| try: | ||||||
| _matrix0_model_params = _matrix0_model.count_parameters() | ||||||
| except Exception: | ||||||
| _matrix0_model_params = None | ||||||
| e = cfg.eval() | ||||||
| mcfg_dict = dict(cfg.mcts()) | ||||||
| mcfg_dict.update( | ||||||
|
|
@@ -769,6 +774,38 @@ def _load_matrix0(cfg_path: str = "config.yaml", ckpt: Optional[str] = None, dev | |||||
| _matrix0_mcts = MCTS(_matrix0_model, MCTSConfig.from_dict(mcfg_dict), _device) | ||||||
|
|
||||||
|
|
||||||
| def _get_matrix0_param_count() -> Optional[int]: | ||||||
| """Return the cached parameter count, computing it lazily if needed.""" | ||||||
|
|
||||||
| global _matrix0_model_params, _cfg | ||||||
|
|
||||||
| if _matrix0_model_params is not None: | ||||||
| return _matrix0_model_params | ||||||
|
|
||||||
| if _matrix0_model is not None: | ||||||
| try: | ||||||
| _matrix0_model_params = _matrix0_model.count_parameters() | ||||||
| return _matrix0_model_params | ||||||
| except Exception as exc: # pragma: no cover - defensive fallback | ||||||
| logger.debug("Failed to count parameters on loaded model: %s", exc) | ||||||
|
Comment on lines +789 to +790
|
||||||
| return None | ||||||
|
|
||||||
| try: | ||||||
| cfg = _cfg or Config.load("config.yaml") | ||||||
| if _cfg is None: | ||||||
| _cfg = cfg | ||||||
| model_cfg = cfg.model() if cfg else {} | ||||||
| model = PolicyValueNet.from_config(model_cfg) | ||||||
| try: | ||||||
| _matrix0_model_params = model.count_parameters() | ||||||
| finally: | ||||||
| del model | ||||||
| return _matrix0_model_params | ||||||
| except Exception as exc: # pragma: no cover - defensive fallback | ||||||
|
||||||
| except Exception as exc: # pragma: no cover - defensive fallback | |
| except (FileNotFoundError, OSError, ValueError, AttributeError) as exc: # pragma: no cover - defensive fallback |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing `_matrix0_model` in the global declaration. The function accesses `_matrix0_model` on line 785 but doesn't declare it as global, which could lead to UnboundLocalError if the global variable doesn't exist.