Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions tests/callbacks/test_metric_monitor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from contextlib import nullcontext
from functools import partial, wraps
from itertools import chain
Expand Down Expand Up @@ -509,3 +510,38 @@ def preproc2(y, x):
assert list(metric_monitor.group_preprocess.keys()) == [preproc1, preproc2], len(
metric_monitor.group_preprocess.keys()
)


def test_dict_as_group_preprocessing_result(tmpdir):
def restack_dict(batch_of_dicts: tuple[dict, ...]) -> dict:
new = defaultdict(list)
for dct in batch_of_dicts:
for k in dct:
new[k].append(dct[k])

return new

def metric(dct, _):
dct = restack_dict(dct)
return dct["y"] == dct["x"]

group_metrics = {
lambda y, x: ({"y": y, "x": x}, x): {
"y_eq_x": metric
}
}

monitor = MetricMonitor(None, group_metrics, None)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
limit_train_batches=4,
limit_val_batches=4,
enable_checkpointing=False,
enable_progress_bar=False,
callbacks=[monitor],
logger=CSVLogger(tmpdir),
)
model = MultiLoaderModule(nn.Linear(2, 1), lambda x, y: x + y, -1)
trainer.fit(model)
trainer.test(model)
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from lightning.pytorch.utilities.types import STEP_OUTPUT
from toolz import compose, keymap, valmap

from ..torch.utils import to_np
from ..utils import squeeze_first
from ...torch.utils import to_np
from ...utils import squeeze_first


class MetricMonitor(Callback):
Expand Down Expand Up @@ -203,7 +203,7 @@ def evaluate_epoch(self, trainer: Trainer, pl_module: LightningModule, key: str)
for dataloader_idx, all_predictions in self._all_predictions.items():
loader_postfix = f"/{dataloader_idx}" if len(self._all_predictions) > 1 else ""
for preprocess, metrics_names in self.group_preprocess.items():
preprocessed = [np.asarray(p) for p in zip(*all_predictions[preprocess], strict=True)]
preprocessed = list(zip(*all_predictions[preprocess], strict=True))
for name in metrics_names:
group_metric_values[f"{name}{loader_postfix}"] = self.group_metrics[name](*preprocessed)

Expand Down
Loading