From 113183b462b4098cd0c9818c29a74634f80c4b26 Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Wed, 28 Jan 2026 17:04:13 +0100 Subject: [PATCH 01/13] feat: removed gdown - pull from dataset repository - revised test --- README.md | 4 +- hyperbench/data/dataset.py | 105 ++++++++++------ hyperbench/tests/data/dataset_test.py | 168 +++++++++++++++++++++++--- pyproject.toml | 1 - 4 files changed, 220 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 379d231..b53a434 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,12 @@ # HyperBench -[![Contributors][contributors-shield]][contributors-url] [![Forks][forks-shield]][forks-url] [![Stargazers][stars-shield]][stars-url] +[![Contributors][contributors-shield]][contributors-url] + [![Issues][issues-shield]][issues-url] [![project_license][license-shield]][license-url] + [![codecov](https://codecov.io/github/hypernetwork-research-group/hyperbench/graph/badge.svg?token=XE0TB5JMOS)](https://codecov.io/github/hypernetwork-research-group/hyperbench) diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index c4ba015..f31c122 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -1,9 +1,9 @@ import json import os -import gdown import tempfile import torch import zstandard as zstd +import requests from enum import Enum from typing import Any, Dict, List, Tuple @@ -19,11 +19,24 @@ class DatasetNames(Enum): Enumeration of available datasets. 
""" - ALGEBRA = "1" - EMAIL_ENRON = "2" - ARXIV = "3" - DBLP = "4" - THREADSMATHSX = "5" + ALGEBRA = "algebra" + AMAZON = "amazon" + CONTACT_HIGH_SCHOOL = "contact-high-school" + CONTACT_PRIMARY_SCHOOL = "contact-primary-school" + DBLP = "dblp" + EMAIL_ENRON = "email-Enron" + EMAIL_W3C = "email-W3C" + GEOMETRY = "geometry" + GOT = "got" + MUSIC_BLUES_REVIEWS = "music-blues-reviews" + NBA = "nba" + NDC_CLASSES = "NDC-classes" + NDC_SUBSTANCES = "NDC-substances" + RESTAURANT_REVIEWS = "restaurant-reviews" + THREADS_ASK_UBUNTU = "threads-ask-ubuntu" + THREADS_MATH_SX = "threads-math-sx" + TWITTER = "twitter" + VEGAS_BARS_REVIEWS = "vegas-bars-reviews" class HIFConverter: @@ -33,38 +46,51 @@ class HIFConverter: """ @staticmethod - def load_from_hif(dataset_name: str | None, file_id: str | None) -> HIFHypergraph: - if dataset_name is None or file_id is None: + def load_from_hif( + dataset_name: str | None, save_on_disk: bool = False + ) -> HIFHypergraph: + if dataset_name is None: raise ValueError( - f"Dataset name (provided: {dataset_name}) and file ID (provided: {file_id}) must be provided." + f"Dataset name (provided: {dataset_name}) must be provided." 
) if dataset_name not in DatasetNames.__members__: raise ValueError(f"Dataset '{dataset_name}' not found.") - dataset_name_lower = dataset_name.lower() + dataset_name = DatasetNames[dataset_name].value current_dir = os.path.dirname(os.path.abspath(__file__)) - zst_filename = os.path.join( - current_dir, "datasets", f"{dataset_name_lower}.json.zst" - ) + zst_filename = os.path.join(current_dir, "datasets", f"{dataset_name}.json.zst") - if os.path.exists(zst_filename): - dctx = zstd.ZstdDecompressor() - with ( - open(zst_filename, "rb") as input_f, - tempfile.NamedTemporaryFile( - mode="wb", suffix=".json", delete=False - ) as tmp_file, - ): - dctx.copy_stream(input_f, tmp_file) - output = tmp_file.name - else: - url = f"https://drive.google.com/uc?id={file_id}" + if not os.path.exists(zst_filename): + github_dataset_repo = f"https://github.com/hypernetwork-research-group/datasets/blob/main/{dataset_name}.json.zst?raw=true" - with tempfile.NamedTemporaryFile( - mode="w+", suffix=".json", delete=False - ) as tmp_file: - output = tmp_file.name - gdown.download(url=url, output=output, quiet=False, fuzzy=True) + response = requests.get(github_dataset_repo) + if response.status_code != 200: + raise ValueError( + f"Failed to download dataset '{dataset_name}' from GitHub. 
Status code: {response.status_code}" + ) + + if save_on_disk: + os.makedirs(os.path.join(current_dir, "datasets"), exist_ok=True) + with open(zst_filename, "wb") as f: + f.write(response.content) + else: + # Create temporary file for downloaded zst content + with tempfile.NamedTemporaryFile( + mode="wb", suffix=".json.zst", delete=False + ) as tmp_zst_file: + tmp_zst_file.write(response.content) + zst_filename = tmp_zst_file.name + + # Decompress the downloaded zst file + dctx = zstd.ZstdDecompressor() + with ( + open(zst_filename, "rb") as input_f, + tempfile.NamedTemporaryFile( + mode="wb", suffix=".json", delete=False + ) as tmp_file, + ): + dctx.copy_stream(input_f, tmp_file) + output = tmp_file.name with open(output, "r") as f: hiftext = json.load(f) @@ -79,7 +105,6 @@ class Dataset(TorchDataset): """ Base Dataset class for hypergraph datasets, extending PyTorch's Dataset. Attributes: - GDRIVE_FILE_ID (str): Google Drive file ID for the dataset. DATASET_NAME (str): Name of the dataset. hypergraph (HIFHypergraph): Loaded hypergraph instance. Methods: @@ -87,7 +112,6 @@ class Dataset(TorchDataset): process(): Processes the hypergraph into HData format. 
""" - GDRIVE_FILE_ID = None DATASET_NAME = None def __init__(self) -> None: @@ -129,7 +153,7 @@ def download(self) -> HIFHypergraph: """ if hasattr(self, "hypergraph") and self.hypergraph is not None: return self.hypergraph - hypergraph = HIFConverter.load_from_hif(self.DATASET_NAME, self.GDRIVE_FILE_ID) + hypergraph = HIFConverter.load_from_hif(self.DATASET_NAME) return hypergraph def process(self) -> HData: @@ -386,14 +410,23 @@ def __to_0based_ids( class AlgebraDataset(Dataset): DATASET_NAME = "ALGEBRA" - GDRIVE_FILE_ID = "1-H21_mZTcbbae4U_yM3xzXX19VhbCZ9C" class DBLPDataset(Dataset): DATASET_NAME = "DBLP" - GDRIVE_FILE_ID = "1oiXQWdybAAUvhiYbFY1R9Qd0jliMSSQh" class ThreadsMathsxDataset(Dataset): DATASET_NAME = "THREADSMATHSX" - GDRIVE_FILE_ID = "1jS4FDs7ME-mENV6AJwCOb_glXKMT7YLQ" + + +if __name__ == "__main__": + for dataset in DatasetNames: + print(f"Processing dataset: {dataset.value}") + if dataset == DatasetNames.EMAIL_ENRON: + load_hif = HIFConverter.load_from_hif(dataset.name, save_on_disk=True) + continue + load_hif = HIFConverter.load_from_hif(dataset.name) + print( + f"Loaded HIF hypergraph with {len(load_hif.nodes)} nodes and {len(load_hif.edges)} edges." 
+ ) diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index c41b040..fad6368 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -1,3 +1,4 @@ +import requests import torch import pytest from unittest.mock import patch, mock_open @@ -154,9 +155,8 @@ def test_fixture(sample_hypergraph): def test_HIFConverter(): """Test loading a known HIF dataset using HIFConverter.""" dataset_name = "ALGEBRA" - file_id = "1-H21_mZTcbbae4U_yM3xzXX19VhbCZ9C" - hypergraph = HIFConverter.load_from_hif(dataset_name, file_id) + hypergraph = HIFConverter.load_from_hif(dataset_name) assert hypergraph is not None assert hasattr(hypergraph, "nodes") @@ -172,35 +172,171 @@ def test_HIFConverter(): def test_HIFConverter_invalid_dataset(): """Test loading an invalid dataset""" dataset_name = "INVALID_DATASET" - file_id = "invalid_file_id" with pytest.raises(ValueError, match="Dataset 'INVALID_DATASET' not found"): - HIFConverter.load_from_hif(dataset_name, file_id) + HIFConverter.load_from_hif(dataset_name) def test_HIFConverter_invalid_hif_format(): """Test loading an invalid HIF format dataset.""" - dataset_name = "EMAIL_ENRON" - file_id = "test_file_id" + dataset_name = "ALGEBRA" invalid_hif_json = '{"network-type": "undirected", "nodes": []}' with ( - patch("hyperbench.data.dataset.gdown.download") as mock_download, - patch("builtins.open", mock_open(read_data=invalid_hif_json)), + patch("hyperbench.data.dataset.requests.get") as mock_get, patch("hyperbench.data.dataset.validate_hif_json", return_value=False), + patch("builtins.open", mock_open(read_data=invalid_hif_json)), + patch("hyperbench.data.dataset.zstd.ZstdDecompressor"), + ): + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.content = b"mock_zst_content" + + with pytest.raises(ValueError, match="Dataset 'algebra' is not HIF-compliant"): + HIFConverter.load_from_hif(dataset_name) + + +def 
test_HIFConverter_save_on_disk(): + """Test downloading dataset with save_on_disk=True.""" + dataset_name = "ALGEBRA" + + mock_hypergraph = HIFHypergraph( + network_type="undirected", + nodes=[{"node": "0"}, {"node": "1"}], + edges=[{"edge": "0"}], + incidences=[{"node": "0", "edge": "0"}], + ) + + mock_hif_json = { + "network-type": "undirected", + "nodes": [{"node": "0"}, {"node": "1"}], + "edges": [{"edge": "0"}], + "incidences": [{"node": "0", "edge": "0"}], + } + + with ( + patch("hyperbench.data.dataset.requests.get") as mock_get, + patch("hyperbench.data.dataset.os.path.exists", return_value=False), + patch("hyperbench.data.dataset.os.makedirs"), + patch("builtins.open", mock_open()) as mock_file, + patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, + patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), + patch("hyperbench.data.dataset.validate_hif_json", return_value=True), + patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), + ): + # Mock successful download + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.content = b"mock_zst_content" + + # Mock decompressor + mock_stream = mock_decomp.return_value.stream_reader.return_value + mock_stream.__enter__ = lambda self: mock_stream + mock_stream.__exit__ = lambda self, *args: None + + hypergraph = HIFConverter.load_from_hif(dataset_name, save_on_disk=True) + + assert hypergraph is not None + assert hypergraph.network_type == "undirected" + mock_get.assert_called_once() + # Verify file was written to disk (not temp file) + assert mock_file.call_count >= 2 # Once for write, once for read + + +def test_HIFConverter_temp_file(): + """Test downloading dataset with save_on_disk=False (uses temp file).""" + dataset_name = "ALGEBRA" + + mock_hypergraph = HIFHypergraph( + network_type="undirected", + nodes=[{"node": "0"}, {"node": "1"}], + edges=[{"edge": "0"}], + incidences=[{"node": "0", "edge": "0"}], + ) + + 
mock_hif_json = { + "network-type": "undirected", + "nodes": [{"node": "0"}, {"node": "1"}], + "edges": [{"edge": "0"}], + "incidences": [{"node": "0", "edge": "0"}], + } + + with ( + patch("hyperbench.data.dataset.requests.get") as mock_get, + patch("hyperbench.data.dataset.os.path.exists", return_value=False), + patch("hyperbench.data.dataset.tempfile.NamedTemporaryFile") as mock_temp, + patch("builtins.open", mock_open()), + patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, + patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), + patch("hyperbench.data.dataset.validate_hif_json", return_value=True), + patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), ): + # Mock successful download + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.content = b"mock_zst_content" + + # Mock temp file + mock_temp_file = mock_temp.return_value.__enter__.return_value + mock_temp_file.name = "/tmp/fake_temp.json.zst" + + # Mock decompressor + mock_stream = mock_decomp.return_value.stream_reader.return_value + mock_stream.__enter__ = lambda self: mock_stream + mock_stream.__exit__ = lambda self, *args: None + + hypergraph = HIFConverter.load_from_hif(dataset_name, save_on_disk=False) + + assert hypergraph is not None + assert hypergraph.network_type == "undirected" + mock_get.assert_called_once() + # Verify temp file was used + assert mock_temp.call_count >= 1 + + +def test_HIFConverter_download_failure(): + """Test handling of failed download from GitHub.""" + dataset_name = "ALGEBRA" + + with ( + patch("hyperbench.data.dataset.requests.get") as mock_get, + patch("hyperbench.data.dataset.os.path.exists", return_value=False), + ): + # Mock failed download + mock_response = mock_get.return_value + mock_response.status_code = 404 + with pytest.raises( - ValueError, match="Dataset 'EMAIL_ENRON' is not HIF-compliant" + ValueError, + match=r"Failed to download dataset 'algebra' from GitHub\. 
Status code: 404", ): - HIFConverter.load_from_hif(dataset_name, file_id) + HIFConverter.load_from_hif(dataset_name) + + mock_get.assert_called_once_with( + "https://github.com/hypernetwork-research-group/datasets/blob/main/algebra.json.zst?raw=true" + ) + + +def test_HIFConverter_network_error(): + """Test handling of network errors during download.""" + dataset_name = "ALGEBRA" + + with ( + patch("hyperbench.data.dataset.requests.get") as mock_get, + patch("hyperbench.data.dataset.os.path.exists", return_value=False), + ): + # Mock network error + mock_get.side_effect = requests.RequestException("Network error") + + with pytest.raises(requests.RequestException, match="Network error"): + HIFConverter.load_from_hif(dataset_name) def test_dataset_not_available(): """Test loading an unavailable dataset.""" class FakeMockDataset(Dataset): - GDRIVE_FILE_ID = "fake_id" DATASET_NAME = "FAKE" with pytest.raises(ValueError, match=r"Dataset 'FAKE' not found"): @@ -220,7 +356,6 @@ def test_AlgebraDataset_available(): with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): dataset = AlgebraDataset() - assert dataset.GDRIVE_FILE_ID == "1-H21_mZTcbbae4U_yM3xzXX19VhbCZ9C" assert dataset.DATASET_NAME == "ALGEBRA" assert dataset.hypergraph is not None assert dataset.__len__() == dataset.hypergraph.num_nodes @@ -246,12 +381,11 @@ def test_dataset_name_none(): """Test that ValueError is raised if DATASET_NAME is None.""" class FakeMockDataset(Dataset): - GDRIVE_FILE_ID = "fake_id" DATASET_NAME = None with pytest.raises( ValueError, - match=r"Dataset name \(provided: None\) and file ID \(provided: fake_id\) must be provided\.", + match=r"Dataset name \(provided: None\) must be provided\.", ): FakeMockDataset() @@ -553,7 +687,6 @@ def test_transform_attrs_empty_attrs(): class TestDataset(Dataset): DATASET_NAME = "TEST" - GDRIVE_FILE_ID = "test_id" dataset = TestDataset() @@ -586,7 +719,6 @@ def test_process_with_inconsistent_node_attributes(): class 
TestDataset(Dataset): DATASET_NAME = "TEST" - GDRIVE_FILE_ID = "test_id" dataset = TestDataset() @@ -621,7 +753,6 @@ def test_process_with_no_node_attributes_fallback(): class TestDataset(Dataset): DATASET_NAME = "TEST" - GDRIVE_FILE_ID = "test_id" dataset = TestDataset() @@ -650,7 +781,6 @@ def test_process_with_single_node_attribute(): class TestDataset(Dataset): DATASET_NAME = "TEST" - GDRIVE_FILE_ID = "test_id" dataset = TestDataset() @@ -683,7 +813,6 @@ def test_getitem_preserves_node_attributes(): class TestDataset(Dataset): DATASET_NAME = "TEST" - GDRIVE_FILE_ID = "test_id" dataset = TestDataset() @@ -713,7 +842,6 @@ def test_transform_attrs_with_attr_keys_padding(): class TestDataset(Dataset): DATASET_NAME = "TEST" - GDRIVE_FILE_ID = "test_id" dataset = TestDataset() diff --git a/pyproject.toml b/pyproject.toml index 39fc8d1..7672397 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,6 @@ readme = "README.md" requires-python = ">=3.10" dependencies = [ "fastjsonschema>=2.21.2", - "gdown>=5.2.1", "numpy>=1.240", "requests>=2.32.5", "torch>=2.9.1", From 300bc10d4502ac43be9f2fd2100c17675b3c031b Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Wed, 28 Jan 2026 17:30:16 +0100 Subject: [PATCH 02/13] feat: revised test names --- README.md | 43 +------------- hyperbench/data/dataset.py | 12 ++-- hyperbench/tests/data/dataset_test.py | 81 +++++++++++---------------- 3 files changed, 42 insertions(+), 94 deletions(-) diff --git a/README.md b/README.md index b53a434..7bda100 100644 --- a/README.md +++ b/README.md @@ -9,47 +9,7 @@ [![codecov](https://codecov.io/github/hypernetwork-research-group/hyperbench/graph/badge.svg?token=XE0TB5JMOS)](https://codecov.io/github/hypernetwork-research-group/hyperbench) - -
- Table of Contents -
    -
  1. - About the project -
  2. -
  3. - Getting started - -
  4. -
  5. Usage
  6. -
  7. - Contributing - -
  8. -
  9. License
  10. -
  11. Contact
  12. -
  13. Acknowledgments
  14. -
-
- -## About The Project - -## Getting Started - -### Prerequisites - -WIP +For documentation, please visit [here][docs]. ### Installation @@ -144,3 +104,4 @@ WIP [issues-url]: https://github.com/hypernetwork-research-group/hyperbench/issues [license-shield]: https://img.shields.io/github/license/hypernetwork-research-group/hyperbench.svg?style=for-the-badge [license-url]: https://github.com/hypernetwork-research-group/hyperbench/blob/master/LICENSE.txt +[docs]: https://hypernetwork-research-group.github.io/hyperbench/ diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index f31c122..3a92d15 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -6,7 +6,7 @@ import requests from enum import Enum -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from torch import Tensor from torch.utils.data import Dataset as TorchDataset from hyperbench.types.hypergraph import HIFHypergraph @@ -47,7 +47,7 @@ class HIFConverter: @staticmethod def load_from_hif( - dataset_name: str | None, save_on_disk: bool = False + dataset_name: Optional[str], save_on_disk: bool = False ) -> HIFHypergraph: if dataset_name is None: raise ValueError( @@ -235,17 +235,17 @@ def process(self) -> HData: return HData(x, edge_index, edge_attr, num_nodes, num_edges) def transform_node_attrs( - self, attrs: Dict[str, Any], attr_keys: List[str] | None = None + self, attrs: Dict[str, Any], attr_keys: Optional[List[str]] = None ) -> Tensor: return self.transform_attrs(attrs, attr_keys) def transform_edge_attrs( - self, attrs: Dict[str, Any], attr_keys: List[str] | None = None + self, attrs: Dict[str, Any], attr_keys: Optional[List[str]] = None ) -> Tensor: return self.transform_attrs(attrs, attr_keys) def transform_attrs( - self, attrs: Dict[str, Any], attr_keys: List[str] | None = None + self, attrs: Dict[str, Any], attr_keys: Optional[List[str]] = None ) -> Tensor: """ Extract and encode numeric node attributes to 
tensor. @@ -291,7 +291,7 @@ def __collect_attr_keys(self, attr_keys: List[Dict[str, Any]]) -> List[str]: return unique_keys - def __get_node_ids_to_sample(self, id: int | List[int]) -> List[int]: + def __get_node_ids_to_sample(self, id: Union[int, List[int]]) -> List[int]: if isinstance(id, int): return [id] diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index fad6368..0cead09 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -11,7 +11,7 @@ # Reusable fixture for hypergraph instances used in multiple tests @pytest.fixture -def sample_hypergraph(): +def mock_sample_hypergraph(): return HIFHypergraph( network_type="undirected", nodes=[{"node": "0"}, {"node": "1"}], @@ -21,7 +21,7 @@ def sample_hypergraph(): @pytest.fixture -def simple_mock_hypergraph(): +def mock_simple_hypergraph(): """Simple hypergraph with 2 nodes for basic tests.""" return HIFHypergraph( network_type="undirected", @@ -32,22 +32,7 @@ def simple_mock_hypergraph(): @pytest.fixture -def three_node_mock_hypergraph(): - """Hypergraph with 3 nodes for validation tests.""" - return HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - ], - edges=[{"edge": "0", "attrs": {}}], - incidences=[{"node": "0", "edge": "0"}], - ) - - -@pytest.fixture -def three_node_mock_weighted_hypergraph(): +def mock_three_node_weighted_hypergraph(): return HIFHypergraph( network_type="undirected", nodes=[ @@ -68,7 +53,7 @@ def three_node_mock_weighted_hypergraph(): @pytest.fixture -def four_node_mock_hypergraph(): +def mock_four_node_hypergraph(): """Hypergraph with 4 nodes and 2 edges for sampling tests.""" return HIFHypergraph( network_type="undirected", @@ -89,7 +74,7 @@ def four_node_mock_hypergraph(): @pytest.fixture -def five_node_mock_hypergraph(): +def mock_five_node_hypergraph(): """Hypergraph with 5 nodes for duplicate testing.""" 
return HIFHypergraph( network_type="undirected", @@ -106,7 +91,7 @@ def five_node_mock_hypergraph(): @pytest.fixture -def no_edge_attr_mock_hypergraph(): +def mock_no_edge_attr_hypergraph(): return HIFHypergraph( network_type="undirected", nodes=[ @@ -122,7 +107,7 @@ def no_edge_attr_mock_hypergraph(): @pytest.fixture -def multiple_edges_attr_mock_hypergraph(): +def mock_multiple_edges_attr_hypergraph(): return HIFHypergraph( network_type="undirected", nodes=[ @@ -145,11 +130,11 @@ def multiple_edges_attr_mock_hypergraph(): ) -def test_fixture(sample_hypergraph): - assert sample_hypergraph.network_type == "undirected" - assert len(sample_hypergraph.nodes) == 2 - assert len(sample_hypergraph.edges) == 1 - assert len(sample_hypergraph.incidences) == 1 +def test_fixture(mock_sample_hypergraph): + assert mock_sample_hypergraph.network_type == "undirected" + assert len(mock_sample_hypergraph.nodes) == 2 + assert len(mock_sample_hypergraph.edges) == 1 + assert len(mock_sample_hypergraph.incidences) == 1 def test_HIFConverter(): @@ -446,11 +431,11 @@ def test_dataset_process_with_edge_attributes(): ) # weight, type -def test_dataset_process_without_edge_attributes(no_edge_attr_mock_hypergraph): +def test_dataset_process_without_edge_attributes(mock_no_edge_attr_hypergraph): """Test that process handles edges without attributes.""" with patch.object( - HIFConverter, "load_from_hif", return_value=no_edge_attr_mock_hypergraph + HIFConverter, "load_from_hif", return_value=mock_no_edge_attr_hypergraph ): dataset = AlgebraDataset() @@ -460,11 +445,11 @@ def test_dataset_process_without_edge_attributes(no_edge_attr_mock_hypergraph): assert dataset.hdata.edge_attr is None -def test_dataset_process_edge_index_format(four_node_mock_hypergraph): +def test_dataset_process_edge_index_format(mock_four_node_hypergraph): """Test that edge_index has correct format [node_ids, edge_ids].""" with patch.object( - HIFConverter, "load_from_hif", return_value=four_node_mock_hypergraph + 
HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph ): dataset = AlgebraDataset() @@ -499,10 +484,10 @@ def test_dataset_process_random_ids(): assert dataset.hdata.edge_attr.shape == (2, 0) # 2 edges, 0 attributes each -def test_getitem_index_list_empty(simple_mock_hypergraph): +def test_getitem_index_list_empty(mock_simple_hypergraph): """Test __getitem__ with empty index list raises ValueError.""" with patch.object( - HIFConverter, "load_from_hif", return_value=simple_mock_hypergraph + HIFConverter, "load_from_hif", return_value=mock_simple_hypergraph ): dataset = AlgebraDataset() @@ -510,10 +495,10 @@ def test_getitem_index_list_empty(simple_mock_hypergraph): dataset[[]] -def test_getitem_index_list_too_large(five_node_mock_hypergraph): +def test_getitem_index_list_too_large(mock_five_node_hypergraph): """Test __getitem__ with index list larger than number of nodes raises ValueError.""" with patch.object( - HIFConverter, "load_from_hif", return_value=five_node_mock_hypergraph + HIFConverter, "load_from_hif", return_value=mock_five_node_hypergraph ): dataset = AlgebraDataset() @@ -524,10 +509,10 @@ def test_getitem_index_list_too_large(five_node_mock_hypergraph): dataset[[0, 1, 2, 3, 4, 5]] -def test_getitem_index_out_of_bounds(four_node_mock_hypergraph): +def test_getitem_index_out_of_bounds(mock_four_node_hypergraph): """Test __getitem__ with out-of-bounds index raises IndexError.""" with patch.object( - HIFConverter, "load_from_hif", return_value=four_node_mock_hypergraph + HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph ): dataset = AlgebraDataset() @@ -535,10 +520,12 @@ def test_getitem_index_out_of_bounds(four_node_mock_hypergraph): dataset[4] -def test_getitem_single_index(sample_hypergraph): +def test_getitem_single_index(mock_sample_hypergraph): """Test __getitem__ with a single index.""" - with patch.object(HIFConverter, "load_from_hif", return_value=sample_hypergraph): + with patch.object( + HIFConverter, 
"load_from_hif", return_value=mock_sample_hypergraph + ): dataset = AlgebraDataset() node_data = dataset[1] @@ -546,11 +533,11 @@ def test_getitem_single_index(sample_hypergraph): assert node_data.edge_index.shape == (2, 0) -def test_getitem_list_index(four_node_mock_hypergraph): +def test_getitem_list_index(mock_four_node_hypergraph): """Test __getitem__ with a list of indices.""" with patch.object( - HIFConverter, "load_from_hif", return_value=four_node_mock_hypergraph + HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph ): dataset = AlgebraDataset() @@ -559,11 +546,11 @@ def test_getitem_list_index(four_node_mock_hypergraph): assert node_data_list.edge_index.shape == (2, 3) -def test_getitem_with_edge_attr(three_node_mock_weighted_hypergraph): +def test_getitem_with_edge_attr(mock_three_node_weighted_hypergraph): """Test __getitem__ returns correct edge_attr when present.""" with patch.object( - HIFConverter, "load_from_hif", return_value=three_node_mock_weighted_hypergraph + HIFConverter, "load_from_hif", return_value=mock_three_node_weighted_hypergraph ): dataset = AlgebraDataset() @@ -575,11 +562,11 @@ def test_getitem_with_edge_attr(three_node_mock_weighted_hypergraph): assert node_data.edge_attr[0].item() == 1 -def test_getitem_without_edge_attr(no_edge_attr_mock_hypergraph): +def test_getitem_without_edge_attr(mock_no_edge_attr_hypergraph): """Test __getitem__ returns None for edge_attr when not present.""" with patch.object( - HIFConverter, "load_from_hif", return_value=no_edge_attr_mock_hypergraph + HIFConverter, "load_from_hif", return_value=mock_no_edge_attr_hypergraph ): dataset = AlgebraDataset() @@ -587,11 +574,11 @@ def test_getitem_without_edge_attr(no_edge_attr_mock_hypergraph): assert node_data.edge_attr is None -def test_getitem_with_multiple_edges_attr(multiple_edges_attr_mock_hypergraph): +def test_getitem_with_multiple_edges_attr(mock_multiple_edges_attr_hypergraph): """Test __getitem__ correctly filters edge_attr for 
sampled edges.""" with patch.object( - HIFConverter, "load_from_hif", return_value=multiple_edges_attr_mock_hypergraph + HIFConverter, "load_from_hif", return_value=mock_multiple_edges_attr_hypergraph ): dataset = AlgebraDataset() From 429a5746f8826207cc8d839b6ff762d692181bc5 Mon Sep 17 00:00:00 2001 From: Tiziano Citro <56075735+tizianocitro@users.noreply.github.com> Date: Thu, 29 Jan 2026 10:33:40 +0100 Subject: [PATCH 03/13] chore: add Makefile for common actions (#21) --- Makefile | 44 ++++++++++++++++++++++++++++++++++++++++++++ utest.sh | 12 ------------ 2 files changed, 44 insertions(+), 12 deletions(-) create mode 100644 Makefile delete mode 100755 utest.sh diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..8a8d71f --- /dev/null +++ b/Makefile @@ -0,0 +1,44 @@ +.PHONY: all setup check lint typecheck test clean help + +UV=uv +UVX=uvx +PYTEST=pytest +LINTER=ruff +TYPECHECKER=ty + +all: clean setup check test + +setup: + @echo '=== Setup ===' + $(UV) pip uninstall . + $(UV) sync + $(UV) pip install -e . + +check: lint typecheck + +lint: + @echo '=== Linter ===' + $(UV) run $(LINTER) format + +typecheck: + @echo '=== Type checker ===' + $(UVX) $(TYPECHECKER) check + +test: + @echo '=== Tests ===' + $(UV) run $(PYTEST) + +clean: + @echo '=== Cleaning up ===' + rm -rf **/__pycache__ **/*.pyc hyperbench.egg-info .pytest_cache .coverage + +help: + @echo "Usage: make [target]" + @echo "Targets:" + @echo " all - Setup, lint, typecheck, test" + @echo " setup - Install dependencies" + @echo " lint - Run linter" + @echo " typecheck - Run type checker" + @echo " test - Run tests" + @echo " check - Run lint and typecheck" + @echo " clean - Remove build/test artifacts" diff --git a/utest.sh b/utest.sh deleted file mode 100755 index 50aee46..0000000 --- a/utest.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env bash - -command -v uv >/dev/null 2>&1 || { echo "uv command not found; please install uv" >&2; exit 1; } - -uv pip uninstall . 
-uv sync -uv pip install -e . - -uv run ruff format -uvx ty check - -uv run pytest --cov=hyperbench --cov-report=term-missing From 55d445bf92b843698e2af6d504c97deb1cbcfae4 Mon Sep 17 00:00:00 2001 From: Tiziano Citro <56075735+tizianocitro@users.noreply.github.com> Date: Thu, 29 Jan 2026 10:46:29 +0100 Subject: [PATCH 04/13] chore: format README (#22) --- README.md | 69 +++++++++++++++++++++++++------------------------------ 1 file changed, 31 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 7bda100..eb9da44 100644 --- a/README.md +++ b/README.md @@ -9,76 +9,73 @@ [![codecov](https://codecov.io/github/hypernetwork-research-group/hyperbench/graph/badge.svg?token=XE0TB5JMOS)](https://codecov.io/github/hypernetwork-research-group/hyperbench) -For documentation, please visit [here][docs]. +## About the project -### Installation +WIP -#### Sync dependencies +## Getting started -Use [uv](https://docs.astral.sh/uv/reference/cli/) to sync dependencies: +### Prerequisites -```bash -uv sync -``` +WIP + +### Installation + +WIP ## Usage ## Contributing -See [CONTRIBUTING.md](CONTRIBUTING.md) for details. +See [CONTRIBUTING.md](CONTRIBUTING.md) for details on contributing to the project. 
-### Pre-commit hooks +### Build -Run the following command to install the pre-commit hook: +To build the project, run: ```bash -uv sync - -pre-commit install --config .github/hooks/.pre-commit-config.yaml --hook-type pre-commit --install-hooks --overwrite +make ``` -### Linter +### Linter and type checker Use [Ruff](https://github.com/charliermarsh/ruff) for linting and formatting: ```bash -uvx ruff check - -uvx ruff format +make lint ``` -### Type checker - Use [Ty](https://docs.astral.sh/ty/) for type checking: ```bash -uvx ty check +make typecheck +``` -# In watch mode -uvx ty check --watch +Use the `check` target to run both linter and type checker: + +```bash +make check ``` ### Tests -Run tests with [pytest](https://docs.pytest.org/en/latest/): +Use [pytest](https://docs.pytest.org/en/latest/) to run the test suite: ```bash -uv run pytest --cov=hyperbench --cov-report=term-missing -# html report +make test + +# Run tests with HTML report uv run pytest --cov=hyperbench --cov-report=html ``` -### Utilities +### Pre-commit hooks -Before committing code, run the following command to ensure code quality: +Run the following command to install the pre-commit hook: ```bash -uv pip uninstall . && \ -uv sync && \ -uv pip install -e . 
&& \ -uv run ruff format && \ -uvx ty check && \ -uv run pytest --cov=hyperbench --cov-report=term-missing +make setup + +pre-commit install --config .github/hooks/.pre-commit-config.yaml --hook-type pre-commit --install-hooks --overwrite ``` ## License @@ -89,11 +86,7 @@ WIP WIP -## Acknowledgments - - - - + [contributors-shield]: https://img.shields.io/github/contributors/hypernetwork-research-group/hyperbench.svg?style=for-the-badge [contributors-url]: https://github.com/hypernetwork-research-group/hyperbench/graphs/contributors [forks-shield]: https://img.shields.io/github/forks/hypernetwork-research-group/hyperbench.svg?style=for-the-badge From 1c482d4ed109c62b1659d16659779f3b8aa4f381 Mon Sep 17 00:00:00 2001 From: Tiziano Citro <56075735+tizianocitro@users.noreply.github.com> Date: Thu, 29 Jan 2026 12:52:10 +0100 Subject: [PATCH 05/13] feat: Add DataLoader and model config structure and makefile * feat: add dataloader and model config structure * chore: add Makefile for common actions (#21) * chore: format README (#22) * fix: change assert to if for trainer check * chore: add coverage report to Makefile test target --- Makefile | 2 +- hyperbench/data/loader.py | 10 +- hyperbench/tests/mock/mock.py | 31 +--- hyperbench/tests/train/trainer_test.py | 223 ++++++++++++++++++++++++ hyperbench/tests/types/model_test.py | 47 +++++ hyperbench/train/__init__.py | 5 + hyperbench/train/trainer.py | 232 +++++++++++++++++++++++++ hyperbench/types/__init__.py | 4 + hyperbench/types/hdata.py | 12 +- hyperbench/types/model.py | 34 ++++ hyperbench/utils/__init__.py | 5 +- hyperbench/utils/data_utils.py | 10 +- pyproject.toml | 2 + 13 files changed, 577 insertions(+), 40 deletions(-) create mode 100644 hyperbench/tests/train/trainer_test.py create mode 100644 hyperbench/tests/types/model_test.py create mode 100644 hyperbench/train/__init__.py create mode 100644 hyperbench/train/trainer.py create mode 100644 hyperbench/types/model.py diff --git a/Makefile b/Makefile index 
8a8d71f..bb282b7 100644 --- a/Makefile +++ b/Makefile @@ -26,7 +26,7 @@ typecheck: test: @echo '=== Tests ===' - $(UV) run $(PYTEST) + $(UV) run $(PYTEST) --cov=hyperbench --cov-report=term-missing clean: @echo '=== Cleaning up ===' diff --git a/hyperbench/data/loader.py b/hyperbench/data/loader.py index f51f8d9..1fd8d58 100644 --- a/hyperbench/data/loader.py +++ b/hyperbench/data/loader.py @@ -1,6 +1,6 @@ import torch -from typing import List, Tuple +from typing import List, Optional, Tuple from torch import Tensor from torch.utils.data import DataLoader as TorchDataLoader from hyperbench.data import Dataset @@ -9,7 +9,11 @@ class DataLoader(TorchDataLoader): def __init__( - self, dataset: Dataset, batch_size: int = 1, shuffle: bool = False, **kwargs + self, + dataset: Dataset, + batch_size: int = 1, + shuffle: Optional[bool] = False, + **kwargs, ) -> None: super().__init__( dataset=dataset, @@ -101,7 +105,7 @@ def __batch_node_features(self, batch: List[HData]) -> Tuple[Tensor, int]: return batched_node_features, total_nodes - def __batch_edges(self, batch: List[HData]) -> Tuple[Tensor, Tensor | None, int]: + def __batch_edges(self, batch: List[HData]) -> Tuple[Tensor, Optional[Tensor], int]: """Batches hyperedge indices and attributes, adjusting indices for concatenated nodes. Hyperedge indices must be offset so they point to the correct nodes in the batched node tensor. 
diff --git a/hyperbench/tests/mock/mock.py b/hyperbench/tests/mock/mock.py index 3ce8d26..7bb7fd6 100644 --- a/hyperbench/tests/mock/mock.py +++ b/hyperbench/tests/mock/mock.py @@ -1,30 +1,11 @@ -from typing import Any, List -from hyperbench.data import Dataset -from hyperbench import utils +from unittest.mock import MagicMock MOCK_BASE_PATH = "hyperbench/tests/mock" -class MockDataset(Dataset): - def __init__(self, data_list: list[Any]): - super().__init__() - self.data_list = data_list - self.hypergraph = utils.empty_hifhypergraph() # Not used in this mock - self.hdata = utils.empty_hdata() # Not used in this mock - - def __len__(self): - return len(self.data_list) - - def __getitem__(self, index: int | List[int]) -> Any: - if isinstance(index, list): - return [self.data_list[i] for i in index] - return self.data_list[index] - - def download(self): - # Not implemented for mock as we don't need it - pass - - def process(self): - # Not implemented for mock as we don't need it - pass +def new_mock_trainer() -> MagicMock: + trainer = MagicMock() + trainer.fit = MagicMock() + trainer.test = MagicMock(return_value=[{"acc": 0.9}]) + return trainer diff --git a/hyperbench/tests/train/trainer_test.py b/hyperbench/tests/train/trainer_test.py new file mode 100644 index 0000000..64ae672 --- /dev/null +++ b/hyperbench/tests/train/trainer_test.py @@ -0,0 +1,223 @@ +import pytest + +from unittest.mock import MagicMock, patch +from hyperbench.train import MultiModelTrainer +from hyperbench.types import ModelConfig +from hyperbench.tests import new_mock_trainer + + +@pytest.fixture +def mock_model_configs(): + model_configs = [] + + for i in range(2): + model = MagicMock() + model.name = f"model{i}" + model.version = f"{i}" + + model_config = MagicMock(spec=ModelConfig) + model_config.name = f"model{i}" + model_config.version = f"{i}" + model_config.model = model + model_config.trainer = None + model_config.full_model_name = ( + lambda self=model_config: 
f"{self.name}:{self.version}" + ) + + model_configs.append(model_config) + + return model_configs + + +@patch("hyperbench.train.trainer.L.Trainer") +def test_trainer_initialization(_, mock_model_configs): + multi_model_trainer = MultiModelTrainer(mock_model_configs) + + assert len(multi_model_trainer.model_configs) == len(mock_model_configs) + for config in multi_model_trainer.model_configs: + assert config.trainer is not None + + +@patch("hyperbench.train.trainer.L.Trainer") +def test_trainer_initialization_with_initialized_trainer( + mock_trainer, mock_model_configs +): + mock_model_configs[0].trainer = mock_trainer + + multi_model_trainer = MultiModelTrainer(mock_model_configs) + + assert len(multi_model_trainer.model_configs) == len(mock_model_configs) + for config in multi_model_trainer.model_configs: + assert config.trainer is not None + + +@patch("hyperbench.train.trainer.L.Trainer") +def test_models_property_returns_models(_, mock_model_configs): + multi_model_trainer = MultiModelTrainer(mock_model_configs) + models = multi_model_trainer.models + + assert len(models) == len(mock_model_configs) + + +@patch("hyperbench.train.trainer.L.Trainer") +def test_models_property_returns_empty_when_no_models(_): + multi_model_trainer = MultiModelTrainer([]) + models = multi_model_trainer.models + + assert len(models) == 0 + + +@patch("hyperbench.train.trainer.L.Trainer") +def test_model_returns_model_when_correct_name_and_no_version(_, mock_model_configs): + mock_model_configs[0].version = "default" + mock_model_configs[0].model.version = "default" + + multi_model_trainer = MultiModelTrainer(mock_model_configs) + found = multi_model_trainer.model(name="model0") + + assert found is not None + assert found.name == "model0" + assert found.version == "default" + + +@patch("hyperbench.train.trainer.L.Trainer") +def test_model_returns_None_when_incorrect_name_and_no_version(_, mock_model_configs): + multi_model_trainer = MultiModelTrainer(mock_model_configs) + found = 
multi_model_trainer.model(name="nonexistent") + + assert found is None + + +@patch("hyperbench.train.trainer.L.Trainer") +def test_model_returns_model_when_correct_name_and_version(_, mock_model_configs): + multi_model_trainer = MultiModelTrainer(mock_model_configs) + found = multi_model_trainer.model(name="model0", version="0") + + assert found is not None + assert found.name == "model0" + assert found.version == "0" + + +@patch("hyperbench.train.trainer.L.Trainer") +def test_model_returns_None_when_incorrect_name_and_version(_, mock_model_configs): + multi_model_trainer = MultiModelTrainer(mock_model_configs) + not_found = multi_model_trainer.model(name="nonexistent", version="100") + + assert not_found is None + + +@patch("hyperbench.train.trainer.L.Trainer") +def test_model_returns_None_when_incorrect_name_and_correct_version( + _, mock_model_configs +): + multi_model_trainer = MultiModelTrainer(mock_model_configs) + not_found = multi_model_trainer.model(name="nonexistent", version="0") + + assert not_found is None + + +@patch( + "hyperbench.train.trainer.L.Trainer", + side_effect=lambda *args, **kwargs: new_mock_trainer(), +) +def test_fit_all_calls_fit(_, mock_model_configs): + multi_model_trainer = MultiModelTrainer(mock_model_configs) + + multi_model_trainer.fit_all(verbose=False) + for config in mock_model_configs: + config.trainer.fit.assert_called_once() + + +@patch("hyperbench.train.trainer.L.Trainer") +def test_fit_all_with_no_models(_): + multi_model_trainer = MultiModelTrainer([]) + + with pytest.raises(ValueError, match="No models to fit."): + multi_model_trainer.fit_all(verbose=False) + + +@patch("hyperbench.train.trainer.L.Trainer", return_value=None) +def test_fit_all_raises_when_None_trainer(_, mock_model_configs): + multi_model_trainer = MultiModelTrainer(mock_model_configs) + + with pytest.raises( + ValueError, + match=f"Trainer not defined for model {mock_model_configs[0].full_model_name()}.", + ): + multi_model_trainer.fit_all(verbose=False) 
+ + +@patch( + "hyperbench.train.trainer.L.Trainer", + side_effect=lambda *args, **kwargs: new_mock_trainer(), +) +def test_fit_all_with_verbose_true_prints(_, mock_model_configs, caplog): + multi_model_trainer = MultiModelTrainer(mock_model_configs) + + with caplog.at_level("INFO"): + multi_model_trainer.fit_all(verbose=True) + + for config in mock_model_configs: + config.trainer.fit.assert_called_once() + + logs = [ + record.message for record in caplog.records if "Fit model" in record.message + ] + assert len(logs) == len(mock_model_configs) + + +@patch( + "hyperbench.train.trainer.L.Trainer", + side_effect=lambda *args, **kwargs: new_mock_trainer(), +) +def test_test_all_calls_test_and_returns_results(_, mock_model_configs): + multi_model_trainer = MultiModelTrainer(mock_model_configs) + + results = multi_model_trainer.test_all(verbose=False) + + assert all("acc" in v for v in results.values()) + + for config in mock_model_configs: + config.trainer.test.assert_called_once() + + +@patch("hyperbench.train.trainer.L.Trainer") +def test_test_all_with_no_models(_): + multi_model_trainer = MultiModelTrainer([]) + + with pytest.raises(ValueError, match="No models to test."): + multi_model_trainer.test_all(verbose=False) + + +@patch("hyperbench.train.trainer.L.Trainer", return_value=None) +def test_test_all_raises_when_None_trainer(_, mock_model_configs): + multi_model_trainer = MultiModelTrainer(mock_model_configs) + + with pytest.raises( + ValueError, + match=f"Trainer not defined for model {mock_model_configs[0].full_model_name()}.", + ): + multi_model_trainer.test_all(verbose=False) + + +@patch( + "hyperbench.train.trainer.L.Trainer", + side_effect=lambda *args, **kwargs: new_mock_trainer(), +) +def test_test_all_with_verbose_true_prints(_, mock_model_configs, caplog): + multi_model_trainer = MultiModelTrainer(mock_model_configs) + + with caplog.at_level("INFO"): + multi_model_trainer.test_all(verbose=True) + + for config in mock_model_configs: + 
config.trainer.test.assert_called_once() + + logs = [ + record.message for record in caplog.records if "Test model" in record.message + ] + assert len(logs) == len(mock_model_configs) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/hyperbench/tests/types/model_test.py b/hyperbench/tests/types/model_test.py new file mode 100644 index 0000000..e14e54a --- /dev/null +++ b/hyperbench/tests/types/model_test.py @@ -0,0 +1,47 @@ +import pytest + +from hyperbench.types import ModelConfig +from unittest.mock import MagicMock +from hyperbench.types.model import ModelConfig + + +@pytest.fixture +def mock_model(): + return MagicMock() + + +@pytest.fixture +def mock_trainer(): + return MagicMock() + + +def test_model_config_initialization_with_trainer(mock_model, mock_trainer): + model_config = ModelConfig( + name="model", model=mock_model, version="0", trainer=mock_trainer + ) + + assert model_config.name == "model" + assert model_config.version == "0" + assert model_config.model is mock_model + assert model_config.trainer is mock_trainer + + +def test_model_config_initialization_without_trainer(mock_model): + mock_config = ModelConfig(name="test_model", model=mock_model) + + assert mock_config.name == "test_model" + assert mock_config.version == "default" + assert mock_config.model is mock_model + assert mock_config.trainer is None + + +def test_full_model_name(mock_model): + mock_config = ModelConfig(name="foo", model=mock_model, version="bar") + + assert mock_config.full_model_name() == "foo:bar" + + +def test_full_model_name_default_version(mock_model): + mock_config = ModelConfig(name="foo", model=mock_model) + + assert mock_config.full_model_name() == "foo:default" diff --git a/hyperbench/train/__init__.py b/hyperbench/train/__init__.py new file mode 100644 index 0000000..bf2d229 --- /dev/null +++ b/hyperbench/train/__init__.py @@ -0,0 +1,5 @@ +from .trainer import MultiModelTrainer + +__all__ = [ + "MultiModelTrainer", +] diff --git 
a/hyperbench/train/trainer.py b/hyperbench/train/trainer.py new file mode 100644 index 0000000..bdba4ef --- /dev/null +++ b/hyperbench/train/trainer.py @@ -0,0 +1,232 @@ +import lightning as L +import logging as log + +from collections.abc import Iterable +from lightning.pytorch.accelerators import Accelerator +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.loggers import Logger +from lightning.pytorch.profilers import Profiler +from lightning.pytorch.strategies import Strategy +from pathlib import Path +from typing import Any, Dict, List, Mapping, Optional +from hyperbench.data import DataLoader +from hyperbench.types import CkptStrategy, ModelConfig, TestResult + + +class MultiModelTrainer: + """ + A trainer class to handle training multiple models with individual trainers. + + Args: + accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "hpu", "mps", "auto") + as well as custom accelerator instances. + + devices: The devices to use. Can be set to a positive number (int or str), a sequence of device indices + (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for + automatic selection based on the chosen accelerator. Default: ``"auto"``. + + strategy: Supports different training strategies with aliases as well as custom strategies. + Default: ``"auto"``. + + num_nodes: Number of GPU nodes for distributed training. + Default: ``1``. + + precision: Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'), + 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed'). + Can be used on CPU, GPU, TPUs, or HPUs. + Default: ``'32-true'``. + + max_epochs: Stop training once this number of epochs is reached. Disabled by default (None). + If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``. + To enable infinite training, set ``max_epochs = -1``. 
+ + min_epochs: Force training for at least this many epochs. Disabled by default (None). + + max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1`` + and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set + ``max_epochs`` to ``-1``. + + min_steps: Force training for at least this number of steps. Disabled by default (``None``). + + check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``, + validation will be done solely based on the number of training batches, requiring ``val_check_interval`` + to be an integer value. When used together with a time-based ``val_check_interval`` and + ``check_val_every_n_epoch`` > 1, validation is aligned to epoch multiples: if the interval elapses + before the next multiple-N epoch, validation runs at the start of that epoch (after the first batch) + and the timer resets; if it elapses during a multiple-N epoch, validation runs after the current batch. + For ``None`` or ``1`` cases, the time-based behavior of ``val_check_interval`` applies without + additional alignment. + Default: ``1``. + + logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses + the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``. + ``False`` will disable logging. If multiple loggers are provided, local files + (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger. + Default: ``True``. + + default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed. + Default: ``os.getcwd()``. + Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' + + enable_autolog_hparams: Whether to log hyperparameters at the start of a run. + Default: ``True``. + + log_every_n_steps: How often to log within steps. + Default: ``50``. 
+ + profiler: To profile individual steps during training and assist in identifying bottlenecks. + Default: ``None``. + + fast_dev_run: Runs ``n`` batch(es) of train, val, and test if set to ``n`` (int), or 1 batch if set to ``True``, + to find any bugs (i.e., a sort of unit test). + Default: ``False``. + + enable_checkpointing: If ``True``, enable checkpointing. + It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.callbacks`. + Default: ``True``. + + enable_progress_bar: Whether to enable the progress bar by default. + Default: ``True``. + + enable_model_summary: Whether to enable model summarization by default. + Default: ``True``. + + callbacks: Add a callback or list of callbacks. + Default: ``None``. + """ + + def __init__( + self, + model_configs: List[ModelConfig], + # args to pass to each Trainer + accelerator: str | Accelerator = "auto", + devices: list[int] | str | int = "auto", + strategy: str | Strategy = "auto", + num_nodes: int = 1, + precision: Optional[ + Any # Any as Lightning accepts multiple types (int, str, Literal, etc.) 
+ ] = None, + max_epochs: Optional[int] = None, + min_epochs: Optional[int] = None, + max_steps: int = -1, + min_steps: Optional[int] = None, + check_val_every_n_epoch: Optional[int] = 1, + logger: Optional[Logger | Iterable[Logger] | bool] = None, + default_root_dir: Optional[str | Path] = None, + enable_autolog_hparams: bool = True, + log_every_n_steps: Optional[int] = None, + profiler: Optional[Profiler | str] = None, + fast_dev_run: int | bool = False, + enable_checkpointing: bool = True, + enable_progress_bar: bool = True, + enable_model_summary: Optional[bool] = None, + callbacks: Optional[List[Callback] | Callback] = None, + **kwargs, + ) -> None: + self.model_configs: List[ModelConfig] = model_configs + + for model_config in model_configs: + if model_config.trainer is None: + model_config.trainer = L.Trainer( + accelerator=accelerator, + devices=devices, + strategy=strategy, + num_nodes=num_nodes, + precision=precision, + max_epochs=max_epochs, + min_epochs=min_epochs, + max_steps=max_steps, + min_steps=min_steps, + check_val_every_n_epoch=check_val_every_n_epoch, + logger=logger, + default_root_dir=default_root_dir, + enable_autolog_hparams=enable_autolog_hparams, + log_every_n_steps=log_every_n_steps, + profiler=profiler, + fast_dev_run=fast_dev_run, + enable_checkpointing=enable_checkpointing, + enable_progress_bar=enable_progress_bar, + enable_model_summary=enable_model_summary, + callbacks=callbacks, + **kwargs, + ) + + @property + def models(self) -> List[L.LightningModule]: + return [config.model for config in self.model_configs] + + def model(self, name: str, version: str = "default") -> Optional[L.LightningModule]: + for config in self.model_configs: + if config.name == name and config.version == version: + return config.model + return None + + def fit_all( + self, + train_dataloader: Optional[DataLoader] = None, + val_dataloader: Optional[DataLoader] = None, + datamodule: Optional[L.LightningDataModule] = None, + ckpt_path: Optional[CkptStrategy] 
= None, + verbose: bool = True, + ) -> None: + if len(self.model_configs) < 1: + raise ValueError("No models to fit.") + + for i, config in enumerate(self.model_configs): + if config.trainer is None: + raise ValueError( + f"Trainer not defined for model {config.full_model_name()}." + ) + + if verbose: + log.info( + f"Fit model {config.full_model_name()} [{i + 1}/{len(self.model_configs)}]\n" + ) + + config.trainer.fit( + model=config.model, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, + datamodule=datamodule, + ckpt_path=ckpt_path, + ) + + def test_all( + self, + dataloader: Optional[DataLoader] = None, + datamodule: Optional[L.LightningDataModule] = None, + ckpt_path: Optional[CkptStrategy] = None, + verbose: bool = True, + verbose_loop: bool = True, + ) -> Mapping[str, TestResult]: + if len(self.model_configs) < 1: + raise ValueError("No models to test.") + + test_results: Dict[str, TestResult] = {} + + for i, config in enumerate(self.model_configs): + if config.trainer is None: + raise ValueError( + f"Trainer not defined for model {config.full_model_name()}." 
+ ) + + if verbose: + log.info( + f"Test model {config.full_model_name()} [{i + 1}/{len(self.model_configs)}]\n" + ) + + trainer_test_results: List[TestResult] = config.trainer.test( + model=config.model, + dataloaders=dataloader, + datamodule=datamodule, + ckpt_path=ckpt_path, + verbose=verbose_loop, + ) + + # In Lightning, test() returns a list of dicts, one per dataloader, but we use a single dataloader + test_results[config.full_model_name()] = ( + trainer_test_results[0] if len(trainer_test_results) > 0 else {} + ) + + return test_results diff --git a/hyperbench/types/__init__.py b/hyperbench/types/__init__.py index ab83609..d183b2b 100644 --- a/hyperbench/types/__init__.py +++ b/hyperbench/types/__init__.py @@ -1,7 +1,11 @@ from .hypergraph import HIFHypergraph from .hdata import HData +from .model import CkptStrategy, ModelConfig, TestResult __all__ = [ + "CkptStrategy", "HIFHypergraph", "HData", + "ModelConfig", + "TestResult", ] diff --git a/hyperbench/types/hdata.py b/hyperbench/types/hdata.py index a4f9c09..076240b 100644 --- a/hyperbench/types/hdata.py +++ b/hyperbench/types/hdata.py @@ -1,8 +1,10 @@ from torch import Tensor +from typing import Optional class HData: - """Container for hypergraph data. + """ + Container for hypergraph data. Attributes: x (Tensor): Node feature matrix of shape [num_nodes, num_features]. 
@@ -28,15 +30,15 @@ def __init__( self, x: Tensor, edge_index: Tensor, - edge_attr: Tensor | None = None, - num_nodes: int | None = None, - num_edges: int | None = None, + edge_attr: Optional[Tensor] = None, + num_nodes: Optional[int] = None, + num_edges: Optional[int] = None, ): self.x: Tensor = x self.edge_index: Tensor = edge_index - self.edge_attr: Tensor | None = edge_attr + self.edge_attr: Optional[Tensor] = edge_attr self.num_nodes: int = num_nodes if num_nodes is not None else x.size(0) diff --git a/hyperbench/types/model.py b/hyperbench/types/model.py new file mode 100644 index 0000000..af20f81 --- /dev/null +++ b/hyperbench/types/model.py @@ -0,0 +1,34 @@ +import lightning as L + +from typing import Literal, Mapping, Optional, TypeAlias + + +CkptStrategy: TypeAlias = Literal["best", "last"] +TestResult: TypeAlias = Mapping[str, float] + + +class ModelConfig: + """ + A class representing the configuration of a model for the MultiModelTrainer trainer. + + Args: + name: The name of the model. + version: The version of the model. + model: a LightningModule instance. + trainer: a Trainer instance. 
+ """ + + def __init__( + self, + name: str, + model: L.LightningModule, + version: str = "default", + trainer: Optional[L.Trainer] = None, + ) -> None: + self.name = name + self.version = version + self.model = model + self.trainer = trainer + + def full_model_name(self) -> str: + return f"{self.name}:{self.version}" diff --git a/hyperbench/utils/__init__.py b/hyperbench/utils/__init__.py index 6c42d9f..ad4010e 100644 --- a/hyperbench/utils/__init__.py +++ b/hyperbench/utils/__init__.py @@ -11,6 +11,9 @@ __all__ = [ "empty_edgeattr", "empty_edgeindex", - "to_non_empty_edgeattr", + "empty_hdata", + "empty_hifhypergraph", + "empty_nodefeatures", "validate_hif_json", + "to_non_empty_edgeattr", ] diff --git a/hyperbench/utils/data_utils.py b/hyperbench/utils/data_utils.py index 2c9ca09..4cdd136 100644 --- a/hyperbench/utils/data_utils.py +++ b/hyperbench/utils/data_utils.py @@ -16,11 +16,6 @@ def empty_edgeattr(num_edges: int) -> Tensor: return torch.empty((num_edges, 0)) -def to_non_empty_edgeattr(edge_attr: Tensor | None) -> Tensor: - num_edges = edge_attr.size(0) if edge_attr is not None else 0 - return empty_edgeattr(num_edges) if edge_attr is None else edge_attr - - def empty_hdata() -> HData: return HData( x=empty_nodefeatures(), @@ -35,3 +30,8 @@ def empty_hifhypergraph() -> HIFHypergraph: return HIFHypergraph( network_type="undirected", nodes=[], edges=[], incidences=[], metadata=None ) + + +def to_non_empty_edgeattr(edge_attr: Tensor | None) -> Tensor: + num_edges = edge_attr.size(0) if edge_attr is not None else 0 + return empty_edgeattr(num_edges) if edge_attr is None else edge_attr diff --git a/pyproject.toml b/pyproject.toml index 7672397..d9a8777 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,8 @@ readme = "README.md" requires-python = ">=3.10" dependencies = [ "fastjsonschema>=2.21.2", + "gdown>=5.2.1", + "lightning>=2.5.5", "numpy>=1.240", "requests>=2.32.5", "torch>=2.9.1", From 7bfbcc9d9bb2e9b1900052e4438c6cd190f9d9f6 Mon Sep 17 00:00:00 
2001 From: ddevin96 Date: Wed, 28 Jan 2026 17:04:13 +0100 Subject: [PATCH 06/13] feat: removed gdown - pull from dataset repository - revised test --- hyperbench/data/dataset.py | 19 +++++++++++++++++++ pyproject.toml | 1 - 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index 3a92d15..b362fea 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -4,6 +4,7 @@ import torch import zstandard as zstd import requests +import requests from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union @@ -37,6 +38,24 @@ class DatasetNames(Enum): THREADS_MATH_SX = "threads-math-sx" TWITTER = "twitter" VEGAS_BARS_REVIEWS = "vegas-bars-reviews" + ALGEBRA = "algebra" + AMAZON = "amazon" + CONTACT_HIGH_SCHOOL = "contact-high-school" + CONTACT_PRIMARY_SCHOOL = "contact-primary-school" + DBLP = "dblp" + EMAIL_ENRON = "email-Enron" + EMAIL_W3C = "email-W3C" + GEOMETRY = "geometry" + GOT = "got" + MUSIC_BLUES_REVIEWS = "music-blues-reviews" + NBA = "nba" + NDC_CLASSES = "NDC-classes" + NDC_SUBSTANCES = "NDC-substances" + RESTAURANT_REVIEWS = "restaurant-reviews" + THREADS_ASK_UBUNTU = "threads-ask-ubuntu" + THREADS_MATH_SX = "threads-math-sx" + TWITTER = "twitter" + VEGAS_BARS_REVIEWS = "vegas-bars-reviews" class HIFConverter: diff --git a/pyproject.toml b/pyproject.toml index d9a8777..f472788 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,6 @@ readme = "README.md" requires-python = ">=3.10" dependencies = [ "fastjsonschema>=2.21.2", - "gdown>=5.2.1", "lightning>=2.5.5", "numpy>=1.240", "requests>=2.32.5", From d1eaf9ab82e6b64636b1cc4ac1535db1f4cc3e50 Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Wed, 28 Jan 2026 17:30:16 +0100 Subject: [PATCH 07/13] feat: revised test names --- README.md | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/README.md b/README.md index eb9da44..11ca5cf 100644 --- a/README.md +++ b/README.md @@ -9,15 +9,7 @@ 
[![codecov](https://codecov.io/github/hypernetwork-research-group/hyperbench/graph/badge.svg?token=XE0TB5JMOS)](https://codecov.io/github/hypernetwork-research-group/hyperbench) -## About the project - -WIP - -## Getting started - -### Prerequisites - -WIP +For documentation, please visit [here][docs]. ### Installation From 02f1869f0c47789f39d63fe7343d5eaac386b06e Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Thu, 29 Jan 2026 13:12:43 +0100 Subject: [PATCH 08/13] chore: duplicate names --- hyperbench/data/dataset.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index b362fea..633c1f5 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -38,24 +38,6 @@ class DatasetNames(Enum): THREADS_MATH_SX = "threads-math-sx" TWITTER = "twitter" VEGAS_BARS_REVIEWS = "vegas-bars-reviews" - ALGEBRA = "algebra" - AMAZON = "amazon" - CONTACT_HIGH_SCHOOL = "contact-high-school" - CONTACT_PRIMARY_SCHOOL = "contact-primary-school" - DBLP = "dblp" - EMAIL_ENRON = "email-Enron" - EMAIL_W3C = "email-W3C" - GEOMETRY = "geometry" - GOT = "got" - MUSIC_BLUES_REVIEWS = "music-blues-reviews" - NBA = "nba" - NDC_CLASSES = "NDC-classes" - NDC_SUBSTANCES = "NDC-substances" - RESTAURANT_REVIEWS = "restaurant-reviews" - THREADS_ASK_UBUNTU = "threads-ask-ubuntu" - THREADS_MATH_SX = "threads-math-sx" - TWITTER = "twitter" - VEGAS_BARS_REVIEWS = "vegas-bars-reviews" class HIFConverter: From fe0f52a36ed4ca6d527ed40077f5b112e5983752 Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Thu, 29 Jan 2026 13:27:54 +0100 Subject: [PATCH 09/13] fix resolve merge conflict --- README.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 11ca5cf..eb9da44 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,15 @@ 
[![codecov](https://codecov.io/github/hypernetwork-research-group/hyperbench/graph/badge.svg?token=XE0TB5JMOS)](https://codecov.io/github/hypernetwork-research-group/hyperbench) -For documentation, please visit [here][docs]. +## About the project + +WIP + +## Getting started + +### Prerequisites + +WIP ### Installation From 0396338ea43fa5f46ca14e62f686a1c017a3ecef Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Thu, 29 Jan 2026 13:30:14 +0100 Subject: [PATCH 10/13] fix: resolve conflicts in dataset --- hyperbench/data/dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index 633c1f5..3a92d15 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -4,7 +4,6 @@ import torch import zstandard as zstd import requests -import requests from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union From 0bf9022e6df988706b6607fc16f5a7787192b330 Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Wed, 28 Jan 2026 17:30:16 +0100 Subject: [PATCH 11/13] feat: revised test names --- README.md | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/README.md b/README.md index eb9da44..11ca5cf 100644 --- a/README.md +++ b/README.md @@ -9,15 +9,7 @@ [![codecov](https://codecov.io/github/hypernetwork-research-group/hyperbench/graph/badge.svg?token=XE0TB5JMOS)](https://codecov.io/github/hypernetwork-research-group/hyperbench) -## About the project - -WIP - -## Getting started - -### Prerequisites - -WIP +For documentation, please visit [here][docs]. 
### Installation From cc711b3192699aa541a58929e06e3d9cd1bf5d2d Mon Sep 17 00:00:00 2001 From: Tiziano Citro <56075735+tizianocitro@users.noreply.github.com> Date: Thu, 29 Jan 2026 10:46:29 +0100 Subject: [PATCH 12/13] chore: format README (#22) --- README.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 11ca5cf..eb9da44 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,15 @@ [![codecov](https://codecov.io/github/hypernetwork-research-group/hyperbench/graph/badge.svg?token=XE0TB5JMOS)](https://codecov.io/github/hypernetwork-research-group/hyperbench) -For documentation, please visit [here][docs]. +## About the project + +WIP + +## Getting started + +### Prerequisites + +WIP ### Installation From 1a495096e3136f655d414c03461d90e9d20f664e Mon Sep 17 00:00:00 2001 From: Tiziano Citro <56075735+tizianocitro@users.noreply.github.com> Date: Thu, 29 Jan 2026 13:36:47 +0100 Subject: [PATCH 13/13] refactor: move default case to be int in __get_node_ids_to_sample (#23) --- hyperbench/data/dataset.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index 3a92d15..e2a2172 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -291,10 +291,7 @@ def __collect_attr_keys(self, attr_keys: List[Dict[str, Any]]) -> List[str]: return unique_keys - def __get_node_ids_to_sample(self, id: Union[int, List[int]]) -> List[int]: - if isinstance(id, int): - return [id] - + def __get_node_ids_to_sample(self, id: int | List[int]) -> List[int]: if isinstance(id, list): if len(id) < 1: raise ValueError("Index list cannot be empty.") @@ -304,6 +301,8 @@ def __get_node_ids_to_sample(self, id: Union[int, List[int]]) -> List[int]: ) return list(set(id)) + return [id] + def __validate_node_ids(self, node_ids: List[int]) -> None: for id in node_ids: if id < 0 or id >= self.__len__():