From 6fdedbc3f536e2855b4ec6c7b63c618d7c7762c9 Mon Sep 17 00:00:00 2001 From: Giovanni Trezza Date: Wed, 18 Mar 2026 18:27:26 +0100 Subject: [PATCH] Add per-epoch shuffling --- src/tensorial/gcnn/data/_datamodule.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/tensorial/gcnn/data/_datamodule.py b/src/tensorial/gcnn/data/_datamodule.py index 3516a80..038c5a2 100644 --- a/src/tensorial/gcnn/data/_datamodule.py +++ b/src/tensorial/gcnn/data/_datamodule.py @@ -24,6 +24,7 @@ def __init__( train_val_test_split: Sequence[int | float] = (0.85, 0.05, 0.1), batch_size: int = 32, batch_mode: "gcnn.data.BatchMode | str" = _common.BatchMode.IMPLICIT, + shuffle_every_epoch: bool = False, ): """Initialize the module @@ -45,6 +46,7 @@ def __init__( self.data_test: Dataset | None = None self._max_padding: "gcnn.data.GraphPadding | None" = None self._batch_mode = batch_mode + self._shuffle_every_epoch = shuffle_every_epoch @override def setup(self, stage: "reax.Stage", /) -> None: @@ -113,6 +115,7 @@ def train_dataloader(self) -> reax.DataLoader: return _dataloader.GraphLoader( self.data_train, batch_size=self._batch_size, + shuffle=self._shuffle_every_epoch, padding=self._max_padding, pad=True, batch_mode=self._batch_mode,