diff --git a/src/tensorial/datasets/qm9.py b/src/tensorial/datasets/qm9.py index b69d160..5ec235b 100644 --- a/src/tensorial/datasets/qm9.py +++ b/src/tensorial/datasets/qm9.py @@ -61,6 +61,8 @@ def __init__( download: bool = True, limit: int | None = None, as_graphs: dict | None = None, + shuffle: bool = False, + rng_seed: int | None = None, ): # Params self._data_dir: Final[str] = data_dir @@ -68,11 +70,18 @@ def __init__( self._to_graphs: Final[dict] = as_graphs # State + if rng_seed is not None and not shuffle: + _LOGGER.warning( + "rng_seed is provided but shuffle is False. The seed will have no effect." + ) + + self._rng = np.random.default_rng(seed=rng_seed) + if download: self._do_download("/".join([self.URL, self.FILENAME]), self.FILENAME) archive_path = pathlib.Path(self._data_dir) / self.FILENAME - self._data = self._extract_tarball(archive_path, limit) + self._data = self._extract_tarball(archive_path, limit, shuffle) def __getitem__(self, item): entry = self._data[item] @@ -117,18 +126,37 @@ def _do_download(self, url: str, filename: str): _LOGGER.info("downloaded %s to %s", url, self._data_dir) - def _extract_tarball(self, archive_path, limit=None) -> list[MoleculeDict]: + def _extract_tarball(self, archive_path, limit=None, shuffle=False) -> list[MoleculeDict]: molecules = [] with tarfile.open(archive_path) as file: - members = file.getmembers() - if limit: - members = members[:limit] + all_members = file.getmembers() + n_members = len(all_members) + + if limit is not None: + if shuffle: + # Sort indices for efficient sequential tar access. Accessing a + # compressed tarball out of order causes massive performance + # overhead as it must decompress and seek from the start for each file. + indices = self._rng.choice(n_members, size=limit, replace=False) + indices.sort() + else: + # First N files as they appear in the archive + indices = np.arange(limit) + + members = [all_members[i] for i in indices] + else: + members = all_members + for entry in tqdm.tqdm(members): file_handle = file.extractfile(entry.name) out = read_qm9(io.TextIOWrapper(file_handle, encoding="utf-8")) out["filename"] = entry.name molecules.append(out) + # Final shuffle to ensure labels/data aren't ordered by tarball position + if shuffle: + self._rng.shuffle(molecules) + return molecules def to_graph(self, entry: MoleculeDict) -> jraph.GraphsTuple: diff --git a/test/assets/dsgdb9nsd.xyz.tar.bz2 b/test/assets/dsgdb9nsd.xyz.tar.bz2 new file mode 100644 index 0000000..cded29b Binary files /dev/null and b/test/assets/dsgdb9nsd.xyz.tar.bz2 differ diff --git a/test/test_qm9.py b/test/test_qm9.py new file mode 100644 index 0000000..a28596f --- /dev/null +++ b/test/test_qm9.py @@ -0,0 +1,52 @@ +import pathlib + +import pytest + +from tensorial.datasets.qm9 import Qm9 + + +@pytest.fixture(scope="session") +def qm9_data_dir(pytestconfig) -> pathlib.Path: + """Returns the path to the directory containing the test database.""" + return pytestconfig.rootpath / "test" / "assets" + + +def get_qm9_filenames(data_dir, limit=None, shuffle=False, rng_seed=None): + dataset = Qm9( + data_dir=str(data_dir), download=False, limit=limit, shuffle=shuffle, rng_seed=rng_seed + ) + return [entry["filename"] for entry in dataset] + + +@pytest.mark.parametrize( + "shuffle,rng_seed,expected_reproducible,expected_diff_from_base", + [ + (False, None, True, False), # Case 1: No shuffle + (False, 42, True, False), # Case 2: No shuffle with seed (seed ignored, order unchanged) + (True, 42, True, True), # Case 3: Shuffle with seed (reproducible, different from base) + ( + True, + None, + False, + True, + ), # Case 4: Shuffle without seed (unpredictable, different from base) + ], +) +def test_shuffling_combinations( + qm9_data_dir, shuffle, rng_seed, expected_reproducible, expected_diff_from_base +): + """Verify the four main combinations of shuffle and rng_seed.""" + base_order = get_qm9_filenames(qm9_data_dir, shuffle=False) + + order1 = get_qm9_filenames(qm9_data_dir, shuffle=shuffle, rng_seed=rng_seed) + order2 = get_qm9_filenames(qm9_data_dir, shuffle=shuffle, rng_seed=rng_seed) + + if expected_reproducible: + assert order1 == order2 + else: + assert order1 != order2 + + if expected_diff_from_base: + assert order1 != base_order + else: + assert order1 == base_order