From 5fe80a8a33470c4690a8ff6a4d16f7db12af522b Mon Sep 17 00:00:00 2001 From: Arseniy Belkov Date: Sat, 30 Aug 2025 21:16:35 +0300 Subject: [PATCH 1/2] no more np.asarray oin metric monitor --- tests/callbacks/test_metric_monitor.py | 36 ++++++++++++++++++++++++++ thunder/callbacks/metric_monitor.py | 2 +- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/tests/callbacks/test_metric_monitor.py b/tests/callbacks/test_metric_monitor.py index ec1eb36..006f2c6 100644 --- a/tests/callbacks/test_metric_monitor.py +++ b/tests/callbacks/test_metric_monitor.py @@ -1,3 +1,4 @@ +from collections import defaultdict from contextlib import nullcontext from functools import partial, wraps from itertools import chain @@ -509,3 +510,38 @@ def preproc2(y, x): assert list(metric_monitor.group_preprocess.keys()) == [preproc1, preproc2], len( metric_monitor.group_preprocess.keys() ) + + +def test_dict_as_group_preprocessing_result(tmpdir): + def restack_dict(batch_of_dicts: tuple[dict, ...]) -> dict: + new = defaultdict(list) + for dct in batch_of_dicts: + for k in dct: + new[k].append(dct[k]) + + return new + + def metric(dct, _): + dct = restack_dict(dct) + return dct["y"] == dct["x"] + + group_metrics = { + lambda y, x: ({"y": y, "x": x}, x): { + "y_eq_x": metric + } + } + + monitor = MetricMonitor(None, group_metrics, None) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=4, + limit_val_batches=4, + enable_checkpointing=False, + enable_progress_bar=False, + callbacks=[monitor], + logger=CSVLogger(tmpdir), + ) + model = MultiLoaderModule(nn.Linear(2, 1), lambda x, y: x + y, -1) + trainer.fit(model) + trainer.test(model) \ No newline at end of file diff --git a/thunder/callbacks/metric_monitor.py b/thunder/callbacks/metric_monitor.py index a69394e..bb5fe8b 100644 --- a/thunder/callbacks/metric_monitor.py +++ b/thunder/callbacks/metric_monitor.py @@ -203,7 +203,7 @@ def evaluate_epoch(self, trainer: Trainer, pl_module: LightningModule, key: str) for dataloader_idx, all_predictions in self._all_predictions.items(): loader_postfix = f"/{dataloader_idx}" if len(self._all_predictions) > 1 else "" for preprocess, metrics_names in self.group_preprocess.items(): - preprocessed = [np.asarray(p) for p in zip(*all_predictions[preprocess], strict=True)] + preprocessed = list(zip(*all_predictions[preprocess], strict=True)) for name in metrics_names: group_metric_values[f"{name}{loader_postfix}"] = self.group_metrics[name](*preprocessed) From eb2ec49328a6caa0adb1e29de162d0644c58ecc1 Mon Sep 17 00:00:00 2001 From: Arseniy Belkov Date: Fri, 21 Nov 2025 20:38:58 +0400 Subject: [PATCH 2/2] new module for metric monitor --- .../{metric_monitor.py => metric_monitor/__init__.py} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename thunder/callbacks/{metric_monitor.py => metric_monitor/__init__.py} (99%) diff --git a/thunder/callbacks/metric_monitor.py b/thunder/callbacks/metric_monitor/__init__.py similarity index 99% rename from thunder/callbacks/metric_monitor.py rename to thunder/callbacks/metric_monitor/__init__.py index bb5fe8b..dea8ea2 100644 --- a/thunder/callbacks/metric_monitor.py +++ b/thunder/callbacks/metric_monitor/__init__.py @@ -13,8 +13,8 @@ from lightning.pytorch.utilities.types import STEP_OUTPUT from toolz import compose, keymap, valmap -from ..torch.utils import to_np -from ..utils import squeeze_first +from ...torch.utils import to_np +from ...utils import squeeze_first class MetricMonitor(Callback):