Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions src/tensorial/datasets/qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,27 @@ def __init__(
download: bool = True,
limit: int | None = None,
as_graphs: dict | None = None,
shuffle: bool = False,
rng_seed: int | None = None,
):
# Params
self._data_dir: Final[str] = data_dir
self._download: Final[bool] = download
self._to_graphs: Final[dict] = as_graphs

# State
if rng_seed is not None and not shuffle:
_LOGGER.warning(
"rng_seed is provided but shuffle is False. The seed will have no effect."
)

self._rng = np.random.default_rng(seed=rng_seed)

if download:
self._do_download("/".join([self.URL, self.FILENAME]), self.FILENAME)

archive_path = pathlib.Path(self._data_dir) / self.FILENAME
self._data = self._extract_tarball(archive_path, limit)
self._data = self._extract_tarball(archive_path, limit, shuffle)

def __getitem__(self, item):
entry = self._data[item]
Expand Down Expand Up @@ -117,18 +126,37 @@ def _do_download(self, url: str, filename: str):

_LOGGER.info("downloaded %s to %s", url, self._data_dir)

def _extract_tarball(self, archive_path, limit=None, shuffle=False) -> list[MoleculeDict]:
    """Read molecule entries out of the tar archive at ``archive_path``.

    Args:
        archive_path: location of the tar archive to open.
        limit: maximum number of archive members to read. ``None`` reads every
            member. A value larger than the number of members is clamped, so
            asking for "too many" yields the whole archive instead of failing.
        shuffle: when true, the ``limit`` members are drawn at random (without
            replacement) using ``self._rng`` and the resulting list is shuffled;
            when false, the first ``limit`` members in archive order are read.

    Returns:
        One dictionary per member read, as produced by ``read_qm9``, with the
        member's name stored under the ``"filename"`` key.
    """
    molecules = []
    with tarfile.open(archive_path) as file:
        all_members = file.getmembers()
        n_members = len(all_members)

        if limit is not None:
            # Never request more members than the archive holds: sampling
            # without replacement and positional indexing both fail otherwise.
            n_selected = min(limit, n_members)

            if shuffle:
                # Sort indices for efficient sequential tar access. Accessing a
                # compressed tarball out of order causes massive performance
                # overhead as it must decompress and seek from the start for each file.
                indices = self._rng.choice(n_members, size=n_selected, replace=False)
                indices.sort()
            else:
                # First N files as they appear in the archive
                indices = np.arange(n_selected)

            members = [all_members[i] for i in indices]
        else:
            members = all_members

        for entry in tqdm.tqdm(members):
            # Hand over the TarInfo itself: passing the name would make tarfile
            # search its member list again for every single entry.
            file_handle = file.extractfile(entry)
            out = read_qm9(io.TextIOWrapper(file_handle, encoding="utf-8"))
            out["filename"] = entry.name
            molecules.append(out)

    # Final shuffle to ensure labels/data aren't ordered by tarball position
    if shuffle:
        self._rng.shuffle(molecules)

    return molecules

def to_graph(self, entry: MoleculeDict) -> jraph.GraphsTuple:
Expand Down
Binary file added test/assets/dsgdb9nsd.xyz.tar.bz2
Binary file not shown.
52 changes: 52 additions & 0 deletions test/test_qm9.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pathlib

import pytest

from tensorial.datasets.qm9 import Qm9


@pytest.fixture(scope="session")
def qm9_data_dir(pytestconfig) -> pathlib.Path:
    """Provide the ``test/assets`` directory under the pytest root path.

    The fixture is session-scoped, so the path is resolved once per test run.
    """
    root = pytestconfig.rootpath
    return root.joinpath("test", "assets")


def get_qm9_filenames(data_dir, limit=None, shuffle=False, rng_seed=None):
    """Load a ``Qm9`` dataset from ``data_dir`` and list its entries' filenames.

    The dataset is always built with ``download=False``; ``limit``, ``shuffle``
    and ``rng_seed`` are forwarded unchanged. The returned list holds the
    ``"filename"`` value of every entry, in the dataset's iteration order.
    """
    dataset = Qm9(
        data_dir=str(data_dir),
        download=False,
        limit=limit,
        shuffle=shuffle,
        rng_seed=rng_seed,
    )

    filenames = []
    for entry in dataset:
        filenames.append(entry["filename"])
    return filenames


@pytest.mark.parametrize(
    "shuffle,rng_seed,expected_reproducible,expected_diff_from_base",
    [
        # Case 1: No shuffle
        (False, None, True, False),
        # Case 2: No shuffle with seed (seed ignored, order unchanged)
        (False, 42, True, False),
        # Case 3: Shuffle with seed (reproducible, different from base)
        (True, 42, True, True),
        # Case 4: Shuffle without seed (unpredictable, different from base)
        (True, None, False, True),
    ],
)
def test_shuffling_combinations(
    qm9_data_dir, shuffle, rng_seed, expected_reproducible, expected_diff_from_base
):
    """Check ordering behaviour for each combination of ``shuffle`` and ``rng_seed``.

    The dataset is loaded twice with identical settings and compared against
    itself (reproducibility) and against an unshuffled load (difference from
    the baseline order).
    """

    def load_order():
        # Each call builds a fresh dataset with the parametrized settings.
        return get_qm9_filenames(qm9_data_dir, shuffle=shuffle, rng_seed=rng_seed)

    base_order = get_qm9_filenames(qm9_data_dir, shuffle=False)
    first_order = load_order()
    second_order = load_order()

    # Two loads agree exactly when the configuration is expected to be reproducible.
    assert (first_order == second_order) is expected_reproducible

    # The order departs from the unshuffled baseline exactly when expected to.
    assert (first_order != base_order) is expected_diff_from_base
Loading