From 87935f6d39f706bb7438bdab60751b15f6e720a6 Mon Sep 17 00:00:00 2001 From: Giovanni Trezza Date: Mon, 16 Mar 2026 18:33:00 +0100 Subject: [PATCH] Shuffling QM9 before limiting --- src/tensorial/datasets/qm9.py | 38 +++++++++++++++++++--- test/assets/dsgdb9nsd.xyz.tar.bz2 | Bin 0 -> 6156 bytes test/test_qm9.py | 52 ++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 5 deletions(-) create mode 100644 test/assets/dsgdb9nsd.xyz.tar.bz2 create mode 100644 test/test_qm9.py diff --git a/src/tensorial/datasets/qm9.py b/src/tensorial/datasets/qm9.py index b69d160..5ec235b 100644 --- a/src/tensorial/datasets/qm9.py +++ b/src/tensorial/datasets/qm9.py @@ -61,6 +61,8 @@ def __init__( download: bool = True, limit: int | None = None, as_graphs: dict | None = None, + shuffle: bool = False, + rng_seed: int | None = None, ): # Params self._data_dir: Final[str] = data_dir @@ -68,11 +70,18 @@ def __init__( self._to_graphs: Final[dict] = as_graphs # State + if rng_seed is not None and not shuffle: + _LOGGER.warning( + "rng_seed is provided but shuffle is False. The seed will have no effect." + ) + + self._rng = np.random.default_rng(seed=rng_seed) + if download: self._do_download("/".join([self.URL, self.FILENAME]), self.FILENAME) archive_path = pathlib.Path(self._data_dir) / self.FILENAME - self._data = self._extract_tarball(archive_path, limit) + self._data = self._extract_tarball(archive_path, limit, shuffle) def __getitem__(self, item): entry = self._data[item] @@ -117,18 +126,37 @@ def _do_download(self, url: str, filename: str): _LOGGER.info("downloaded %s to %s", url, self._data_dir) - def _extract_tarball(self, archive_path, limit=None) -> list[MoleculeDict]: + def _extract_tarball(self, archive_path, limit=None, shuffle=False) -> list[MoleculeDict]: molecules = [] with tarfile.open(archive_path) as file: - members = file.getmembers() - if limit: - members = members[:limit] + all_members = file.getmembers() + n_members = len(all_members) + + if limit is not None: + if shuffle: + # Sort indices for efficient sequential tar access. Accessing a + # compressed tarball out of order causes massive performance + # overhead as it must decompress and seek from the start for each file. + indices = self._rng.choice(n_members, size=limit, replace=False) + indices.sort() + else: + # First N files as they appear in the archive + indices = np.arange(limit) + + members = [all_members[i] for i in indices] + else: + members = all_members + for entry in tqdm.tqdm(members): file_handle = file.extractfile(entry.name) out = read_qm9(io.TextIOWrapper(file_handle, encoding="utf-8")) out["filename"] = entry.name molecules.append(out) + # Final shuffle to ensure labels/data aren't ordered by tarball position + if shuffle: + self._rng.shuffle(molecules) + return molecules def to_graph(self, entry: MoleculeDict) -> jraph.GraphsTuple: diff --git a/test/assets/dsgdb9nsd.xyz.tar.bz2 b/test/assets/dsgdb9nsd.xyz.tar.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..cded29b047dff42c21ed1772ef21302be3adabef GIT binary patch literal 6156 zcmV+n81v^sT4*^jL0KkKS^R%<4gen~|Htew1OmwC|KbQ`hyeWYpYQ-cKnMU}03ZMe zU?2YZ_dLFvyGmN!%=f^!1HQw&0a!VEYg${+tJMGi-uB=)=bZLq$yTuIt?qjEHr6So zN?U1NmNvna)=3hTuBPSNXI1ZnplW&mB$`x4lMrYC0000rPta+ilB#(>MNI}x00uw? zngAl7lOU3*>SO~5000dQ00UFR1d@>qh5(r|G6tFr83qz*B$}C1^*tncr2(KC14pRQ zkPL!EO$eGMDtKxW^rqAUX*6g64^z+(;q4Xvf*nE`fT}tOS|m{;un0@X*m%g1^7!@f z>yH@UF?uAuv23gWi3Bu}QB2FmT~)Jag7?4~V2eNq?*uUekAs}ZzdI0Lkv;4`oA&3@ znjw&ZXp?)@LjX^YPYx4aAb%HgZrir3ac{BIXsBgYltxHs4uu{!TIUbO^J<&8*0w0s zmJ%^Gm0zoO6@aE})y&9KH+Z07XRC4|)k$tuLH+tIl8F|8MmXe?TBW-b&@Wluv=*7l#V3igQhtDSS-0kJI}G~j}%GNc*k z7^T6FCJLLHS4KF7^;R^Qnr%Rwnt?5E3F0_LSX|;Mc(8R~#ag?gvkjQk5=A2&FAfM# zF^a_b)oDsNVvfCV%hebmVz``}(xPGuX_IKosu0~Hc%Bx>z+&!~N}XA}Q#$QcJFYI@6ed}ah%RCCLA9isjA+c92^+|>Em}zxB1K84f+m(L;PywV6*z<+au4hv zAq6zJg?Gfv%4H--XOy2~r_^}@9Qr0(q3nMB&s|$~Dgy(PDP@&1(9q>-GnSEa{`eC0jr;PbpYv!6e!w41_TtYUUc*F@+vq zDLF)q#}yjxU4Wx@45(|8FoYtP0br|!!DRzvT`oFQYXV1& zt_XKiGj!~=nY)u$Lim7<(M4d#v@E(44X8GeA)72B5Qz;#5D~D#l8;rP0(xC)rL>ec zi4t)l#IXt$N0S!eHA56&jMY~}z)|H?+)_gYU8g2X*EW?RMH;Ck@{F`Lflj1MLm;Kf zB&5>>QV2<0L=&2kjI{U0IBX~i;fUzW!1sNb9qhXLTiNt1>@Ew8DqIC#xWNw?ak2&rszx`7FN6B2NBh=v~OpK%;4liT&fPF zCa<@2#&rye9jbH^O(n3ht*7Q~R_{(@f`T2Fa2Dvrw=P#;$1kvUD_IrDI|%e@)ft%a zbFA9RDrQK7%RM`$p9*@>HghS&>Z4BLmYve3V{z?Un!x~gHf=l<6WtPA4BMedQFp}} zx>Q68k|Pn-yKt!R%e8Kn-ObKjUOb12y44BMQZ^xR&Q>&PbomuWCbtZ1K!asH4+O(F z)dcM$!9!NYbDZd`GJ1EqGe=ZhsbrCi&u-5`fS9FXId^duI#`p^J6aNKR5iwD7@^$E z;@bwh%AP@i!kVW$Gh+{mp$xo@kuXFwjgfQaz;RmND!56$YTO}ur#yt$fF*%auy4T};c}8k9F@-93$>=f!;u!faO4D}Q*0rlQX;o>v zymCGQM}l1%a0HTZOLUMl$9N!e;OvOBcXG8^Z*s33rXKzBoYhU}F$;)2om)H-U=iu# zt*hK;a9L9&kiu|2=v3B^EY+iCVWCSEnk+pZw`~H*MZe7oEP2vDE&RN9LG$c;!e@i; zsL{W}$s>U{Z;nl&-Ptu=e>xk7xF8p7AV> zP|niGJG%z&Inh5rDw75NT1B2PzYfHGW(a z@f32!L@>t)Bn&Z7mOYM&Kr7vC1bp#VPd-P&1M!E#C^8>xhjo15@dnEfkxyDLmmmnQjzvd2!Eg+ zmG>wc*;P^NQ@H1)Y&!f_RKya^pHpCcSY8J%*j4xk1k)d=5Xx^evk(jPwiVy%7oPmu~yHskw?Trg@%b;c**p)D{uh z*kWI9yP0a^wDFUK9C8tU+P!0>#`572SXWGCi6ixW9smc)89$@+dw>!PFV4S}_E5!o zWl~7CV9c>UwRG1Xfl3Ct=Jh8XOf)zRoi3X#@Tr5dn? z_PjZ}D3!XwsjgvfNVnK$vw?`;RTBb5p!$LF0M@sla4|Jwqq*DDj&23AA;heidwxNk zE^lxIK@oRjS2tSf?r(%>c+e6fO0WRUG8sT_q`Da3D!0^4m}yZMT-fKE(F=BpbbHLaoeol7KoNx^nRc^5Ct3=&^{XT)((X}Kpp()>xxvTD@9+E z>H$`dDODuZ{+@s=RHA-T)=M~JWI{eL=Z=7rOu<>usD6|$OQ$UUt%iXpiX)pA7n!)L zD5B|d*OV5~jFMP`%PGSi)3!SCgJr!$tUMOVqBuAoJ2*tn%ZW_dyq#5>fHzYbTA@!6 zIyBXc2zqy7lp&tO;@W`nAF9O(UcAilF?u(3=BUuol{78u$c}>%#R%_fO*{b0h_dLd z8EI=wxUbb@=}FNOyl&xySAMvnS-cA}7N`oBqHTm8LU&i{_oKPo;JCp%Gn&p-_e`?* ztzvSm8|RmEXUjnNddG~$#w}Y~YIXKm6`Fgytpsl(GZJuYum?rrhJ=H&Q7_8{g7T{I z*@?|^l=AtTNJ+vP%hc=%Y)uJL*SHfYY10c{U5E{^P;RLow?RK89VrB6Lp1W zi&u#|7bi;*O1Z%bCt5YIwnhp#`sq(ZP|r6Ca#rMiCPfhY+Fir4vfrVsHjuuYs`hHo zr(ye^kIJ5Jv`syURu%U}y3U$W#lF-JwyYd|f!@=3M6x%$#t6 z`E=~bdtT||fOscUSH*4V448Euh-k`rm}Dk}!6=1is2;m)PuI zXuS`FF>oANKV8-lP%k!=wxB+`|=pMLCWeQgCc$V9LYcp$n^+mCOdb zn4d+Z)7NT9%W+xe3$^Ck#`RXhQ|U0xg7G5}*>880a#E>ehG)oDd5Ba5JMJc}C$+v7S@9HW@7;I5gn4GgI#itI z^oNn4N<%%4IBL-?|HA0lShAF||_z+83?`im8jM2a@xhq(=8ZRp&ZxP4Emd zL-KVX!=Mh>QyW?v9)=rswPg>s_h`Vm``6tO0kx1|d~%8_qT1)UXp&dMLQ_(DbQ`Z` zDx*bEu7I`!wZA}Yb`jOjmkxf=+ZaA$pXk(FS096gsHUa|%$cDH0+5Q-;+kn2eby1O zmu?+4a6}7NvcR*Js%6w>b;_$6hhVLrVsjo@X?Lz?FVienvnHT+Ed~+4Y)cuk-P81U zLA!JwXPv>OW{(?9hO24Zwu79Df5d*eUT;%-P@PnWSq;~$G2~K*PDEpMc`~LGEh%w( zxYCn4VzQFB6ESH!#T4yYg{3K4pI}d}Co2=5Xb%Kh^A+7CO8O{kV#^aAUP9FFh2VBr z4h`P~R@rxn-0}?ugXWE(8?@yp#mQT_hY3G4warsL zwYJnf;8t8an-!?sRi?yc(Vp>TES_WxS3PnCGc(5*lF`P{U98fr3-jpVe|T)62SH5W6u#pg1$sVpw+uFb|8qWq3{)CmDZbGV78NZ5Oj zQK)2;$GG%yL|(wuvf+6JagME4qR=4ISsA%{Acu^?7@3&YI`nKPazRw_Xri*E{ zFLI52oywNH*}mVmm#igI=eZB?uUcYD%1kdbjDTf&-5mQ~&aM1d2~VT1HC8IhjMHCE z=N&hF)m0pPFFq>4TBPF|DwwNqC065CJ-NvCR#a5%iP`GQT|65hiaW$8(}|g&MJ|q< ztB|s2OMBNNT1c z$d@$s@Me39FRK%ykyz%Zb5~5mF_Eh^_bL(K9uHk?q+_{bIEF!pNo1WgvZAtIUnG!i<3Am}v{c~^3fUKg8}G?^7aMu736L$IaN866tC*x?$~ z73!((RTCoW;-+Sr$eV*4-%U5P@&vK5omvz}A=8t2omyu;H|@@3R6aVDGQlZ_KsvZz zb+TsAO2pzy!`t**^`25uqahuoOjm_&cW*fu$jh@ML_swYl?@TpQAwoZtHoPxC|sh{ zE4rj&9V0Y@(*x_4oGx+qhB=i<)3s8v3b~hzRYIB*0L?!91i@b)T6DrJJ~&ro=Am_X z{m1;vXyUuRA#8H`5Xt5+*z7~{vJF5Ee z2z9**%Ld}*Qh?Ii7>!okgdMpeaYuzcSaZa&CsAGt$??dAX9q<}?lD+>HK!*K&DY(? z`A@t;I_#R6N@u)t2NyekVvLdfSk;@V!ER#h!AulFhe>9Eug4o`t~E`UO#&SPYS%tsRypP(v^wJ zFmQ-ebX-`uJlJ4^VHi29+CUH|T@02+i9uQwP2~%a-1zNIX$CGjRbXSfa-t|^&4bsf zY)tLM%B|*UBiZc5xYkI}aFfCK0OKHwyS&`%Ob)ncH4TvL1Hm0e~#~c>U^;+*BRzxL48eaORuR+xV+SrV~wzmqP z<7}XNc2_S9PD036rX~g3 pathlib.Path: + """Returns the path to the directory containing the test database.""" + return pytestconfig.rootpath / "test" / "assets" + + +def get_qm9_filenames(data_dir, limit=None, shuffle=False, rng_seed=None): + dataset = Qm9( + data_dir=str(data_dir), download=False, limit=limit, shuffle=shuffle, rng_seed=rng_seed + ) + return [entry["filename"] for entry in dataset] + + +@pytest.mark.parametrize( + "shuffle,rng_seed,expected_reproducible,expected_diff_from_base", + [ + (False, None, True, False), # Case 1: No shuffle + (False, 42, True, False), # Case 2: No shuffle with seed (seed ignored, order unchanged) + (True, 42, True, True), # Case 3: Shuffle with seed (reproducible, different from base) + ( + True, + None, + False, + True, + ), # Case 4: Shuffle without seed (unpredictable, different from base) + ], +) +def test_shuffling_combinations( + qm9_data_dir, shuffle, rng_seed, expected_reproducible, expected_diff_from_base +): + """Verify the four main combinations of shuffle and rng_seed.""" + base_order = get_qm9_filenames(qm9_data_dir, shuffle=False) + + order1 = get_qm9_filenames(qm9_data_dir, shuffle=shuffle, rng_seed=rng_seed) + order2 = get_qm9_filenames(qm9_data_dir, shuffle=shuffle, rng_seed=rng_seed) + + if expected_reproducible: + assert order1 == order2 + else: + assert order1 != order2 + + if expected_diff_from_base: + assert order1 != base_order + else: + assert order1 == base_order