README.md (5 changes: 4 additions & 1 deletion)
@@ -1,10 +1,12 @@
 # HyperBench
 
+[![Contributors][contributors-shield]][contributors-url]
 [![Forks][forks-shield]][forks-url]
 [![Stargazers][stars-shield]][stars-url]
-[![Contributors][contributors-shield]][contributors-url]
 
 [![Issues][issues-shield]][issues-url]
 [![project_license][license-shield]][license-url]
 
+[![codecov](https://codecov.io/github/hypernetwork-research-group/hyperbench/graph/badge.svg?token=XE0TB5JMOS)](https://codecov.io/github/hypernetwork-research-group/hyperbench)
+
 ## About the project
@@ -95,3 +97,4 @@ WIP
 [issues-url]: https://github.com/hypernetwork-research-group/hyperbench/issues
 [license-shield]: https://img.shields.io/github/license/hypernetwork-research-group/hyperbench.svg?style=for-the-badge
 [license-url]: https://github.com/hypernetwork-research-group/hyperbench/blob/master/LICENSE.txt
+[docs]: https://hypernetwork-research-group.github.io/hyperbench/
hyperbench/data/dataset.py (118 changes: 75 additions & 43 deletions)
@@ -1,12 +1,12 @@
 import json
 import os
-import gdown
 import tempfile
 import torch
 import zstandard as zstd
+import requests
 
 from enum import Enum
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 from torch import Tensor
 from torch.utils.data import Dataset as TorchDataset
 from hyperbench.types.hypergraph import HIFHypergraph
@@ -19,11 +19,24 @@ class DatasetNames(Enum):
     Enumeration of available datasets.
     """
 
-    ALGEBRA = "1"
-    EMAIL_ENRON = "2"
-    ARXIV = "3"
-    DBLP = "4"
-    THREADSMATHSX = "5"
+    ALGEBRA = "algebra"
+    AMAZON = "amazon"
+    CONTACT_HIGH_SCHOOL = "contact-high-school"
+    CONTACT_PRIMARY_SCHOOL = "contact-primary-school"
+    DBLP = "dblp"
+    EMAIL_ENRON = "email-Enron"
+    EMAIL_W3C = "email-W3C"
+    GEOMETRY = "geometry"
+    GOT = "got"
+    MUSIC_BLUES_REVIEWS = "music-blues-reviews"
+    NBA = "nba"
+    NDC_CLASSES = "NDC-classes"
+    NDC_SUBSTANCES = "NDC-substances"
+    RESTAURANT_REVIEWS = "restaurant-reviews"
+    THREADS_ASK_UBUNTU = "threads-ask-ubuntu"
+    THREADS_MATH_SX = "threads-math-sx"
+    TWITTER = "twitter"
+    VEGAS_BARS_REVIEWS = "vegas-bars-reviews"
 
 
 class HIFConverter:
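The enum now maps member names to the dataset file stems used in the companion datasets repository, and `load_from_hif` is called with the member *name* (e.g. `dataset.name` in the `__main__` block further down), resolving the value internally. A minimal sketch of that lookup, reusing only names from this diff:

```python
# Member name -> file stem, as load_from_hif resolves it below.
name = "EMAIL_ENRON"
if name not in DatasetNames.__members__:
    raise ValueError(f"Dataset '{name}' not found.")
file_stem = DatasetNames[name].value  # "email-Enron"
print(f"{file_stem}.json.zst")        # -> email-Enron.json.zst
```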
@@ -33,38 +33,51 @@ class HIFConverter:
     """
 
     @staticmethod
-    def load_from_hif(dataset_name: str | None, file_id: str | None) -> HIFHypergraph:
-        if dataset_name is None or file_id is None:
+    def load_from_hif(
+        dataset_name: Optional[str], save_on_disk: bool = False
+    ) -> HIFHypergraph:
+        if dataset_name is None:
             raise ValueError(
-                f"Dataset name (provided: {dataset_name}) and file ID (provided: {file_id}) must be provided."
+                f"Dataset name (provided: {dataset_name}) must be provided."
             )
         if dataset_name not in DatasetNames.__members__:
             raise ValueError(f"Dataset '{dataset_name}' not found.")
 
-        dataset_name_lower = dataset_name.lower()
+        dataset_name = DatasetNames[dataset_name].value
         current_dir = os.path.dirname(os.path.abspath(__file__))
-        zst_filename = os.path.join(
-            current_dir, "datasets", f"{dataset_name_lower}.json.zst"
-        )
+        zst_filename = os.path.join(current_dir, "datasets", f"{dataset_name}.json.zst")
 
-        if os.path.exists(zst_filename):
-            dctx = zstd.ZstdDecompressor()
-            with (
-                open(zst_filename, "rb") as input_f,
-                tempfile.NamedTemporaryFile(
-                    mode="wb", suffix=".json", delete=False
-                ) as tmp_file,
-            ):
-                dctx.copy_stream(input_f, tmp_file)
-                output = tmp_file.name
-        else:
-            url = f"https://drive.google.com/uc?id={file_id}"
+        if not os.path.exists(zst_filename):
+            github_dataset_repo = f"https://github.com/hypernetwork-research-group/datasets/blob/main/{dataset_name}.json.zst?raw=true"
 
+            response = requests.get(github_dataset_repo)
+            if response.status_code != 200:
+                raise ValueError(
+                    f"Failed to download dataset '{dataset_name}' from GitHub. Status code: {response.status_code}"
+                )
 
-            with tempfile.NamedTemporaryFile(
-                mode="w+", suffix=".json", delete=False
-            ) as tmp_file:
-                output = tmp_file.name
-            gdown.download(url=url, output=output, quiet=False, fuzzy=True)
+            if save_on_disk:
+                os.makedirs(os.path.join(current_dir, "datasets"), exist_ok=True)
+                with open(zst_filename, "wb") as f:
+                    f.write(response.content)
+            else:
+                # Create temporary file for downloaded zst content
+                with tempfile.NamedTemporaryFile(
+                    mode="wb", suffix=".json.zst", delete=False
+                ) as tmp_zst_file:
+                    tmp_zst_file.write(response.content)
+                zst_filename = tmp_zst_file.name
 
+        # Decompress the downloaded zst file
+        dctx = zstd.ZstdDecompressor()
+        with (
+            open(zst_filename, "rb") as input_f,
+            tempfile.NamedTemporaryFile(
+                mode="wb", suffix=".json", delete=False
+            ) as tmp_file,
+        ):
+            dctx.copy_stream(input_f, tmp_file)
+            output = tmp_file.name
 
         with open(output, "r") as f:
             hiftext = json.load(f)
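Because the loader now fetches over HTTPS from the datasets repository and only optionally caches under `datasets/`, the fetch-and-decompress path can be exercised without the temp files. The following is a condensed, self-contained approximation: the URL pattern and the `copy_stream` call are taken from the diff, while the in-memory buffers and the `fetch_hif_json` name are assumptions for illustration, not the PR's code.

```python
import io
import json

import requests
import zstandard as zstd


def fetch_hif_json(file_stem: str) -> dict:
    """Download <file_stem>.json.zst from the datasets repo and parse the HIF JSON."""
    url = (
        "https://github.com/hypernetwork-research-group/datasets/"
        f"blob/main/{file_stem}.json.zst?raw=true"
    )
    response = requests.get(url)
    if response.status_code != 200:
        raise ValueError(
            f"Failed to download '{file_stem}'. Status code: {response.status_code}"
        )

    # Same zstandard API as in the diff, streaming between in-memory
    # buffers instead of NamedTemporaryFile handles.
    decompressed = io.BytesIO()
    zstd.ZstdDecompressor().copy_stream(io.BytesIO(response.content), decompressed)
    return json.loads(decompressed.getvalue())
```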
@@ -79,15 +105,13 @@ class Dataset(TorchDataset):
     """
     Base Dataset class for hypergraph datasets, extending PyTorch's Dataset.
     Attributes:
-        GDRIVE_FILE_ID (str): Google Drive file ID for the dataset.
         DATASET_NAME (str): Name of the dataset.
         hypergraph (HIFHypergraph): Loaded hypergraph instance.
     Methods:
        download(): Downloads and loads the hypergraph from HIF.
        process(): Processes the hypergraph into HData format.
     """
 
-    GDRIVE_FILE_ID = None
     DATASET_NAME = None
 
     def __init__(self) -> None:
@@ -129,7 +153,7 @@ def download(self) -> HIFHypergraph:
         """
         if hasattr(self, "hypergraph") and self.hypergraph is not None:
             return self.hypergraph
-        hypergraph = HIFConverter.load_from_hif(self.DATASET_NAME, self.GDRIVE_FILE_ID)
+        hypergraph = HIFConverter.load_from_hif(self.DATASET_NAME)
         return hypergraph
 
     def process(self) -> HData:
@@ -211,17 +235,17 @@ def process(self) -> HData:
         return HData(x, edge_index, edge_attr, num_nodes, num_edges)
 
     def transform_node_attrs(
-        self, attrs: Dict[str, Any], attr_keys: List[str] | None = None
+        self, attrs: Dict[str, Any], attr_keys: Optional[List[str]] = None
     ) -> Tensor:
         return self.transform_attrs(attrs, attr_keys)
 
     def transform_edge_attrs(
-        self, attrs: Dict[str, Any], attr_keys: List[str] | None = None
+        self, attrs: Dict[str, Any], attr_keys: Optional[List[str]] = None
     ) -> Tensor:
         return self.transform_attrs(attrs, attr_keys)
 
     def transform_attrs(
-        self, attrs: Dict[str, Any], attr_keys: List[str] | None = None
+        self, attrs: Dict[str, Any], attr_keys: Optional[List[str]] = None
     ) -> Tensor:
         """
         Extract and encode numeric node attributes to tensor.
@@ -268,9 +292,6 @@ def __collect_attr_keys(self, attr_keys: List[Dict[str, Any]]) -> List[str]:
         return unique_keys
 
     def __get_node_ids_to_sample(self, id: int | List[int]) -> List[int]:
-        if isinstance(id, int):
-            return [id]
-
         if isinstance(id, list):
             if len(id) < 1:
                 raise ValueError("Index list cannot be empty.")
@@ -280,6 +301,8 @@ def __get_node_ids_to_sample(self, id: int | List[int]) -> List[int]:
             )
             return list(set(id))
 
+        return [id]
+
     def __validate_node_ids(self, node_ids: List[int]) -> None:
         for id in node_ids:
             if id < 0 or id >= self.__len__():
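The reordering makes the list branch the guarded case and turns the scalar case into the fall-through, so any non-list index is wrapped as a single id. A standalone sketch of the resulting behavior (the element-type validation that the hunk truncates is omitted here):

```python
from typing import List, Union


def get_node_ids_to_sample(id: Union[int, List[int]]) -> List[int]:
    # List indices are validated and de-duplicated; anything else
    # is treated as a single node id (the fall-through case).
    if isinstance(id, list):
        if len(id) < 1:
            raise ValueError("Index list cannot be empty.")
        return list(set(id))
    return [id]


print(get_node_ids_to_sample(3))          # [3]
print(get_node_ids_to_sample([1, 2, 2]))  # [1, 2]
```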
@@ -386,14 +409,23 @@ def __to_0based_ids(
 
 class AlgebraDataset(Dataset):
     DATASET_NAME = "ALGEBRA"
-    GDRIVE_FILE_ID = "1-H21_mZTcbbae4U_yM3xzXX19VhbCZ9C"
 
 
 class DBLPDataset(Dataset):
     DATASET_NAME = "DBLP"
-    GDRIVE_FILE_ID = "1oiXQWdybAAUvhiYbFY1R9Qd0jliMSSQh"
 
 
 class ThreadsMathsxDataset(Dataset):
-    DATASET_NAME = "THREADSMATHSX"
-    GDRIVE_FILE_ID = "1jS4FDs7ME-mENV6AJwCOb_glXKMT7YLQ"
+    DATASET_NAME = "THREADS_MATH_SX"
+
+
+if __name__ == "__main__":
+    for dataset in DatasetNames:
+        print(f"Processing dataset: {dataset.value}")
+        if dataset == DatasetNames.EMAIL_ENRON:
+            load_hif = HIFConverter.load_from_hif(dataset.name, save_on_disk=True)
+            continue
+        load_hif = HIFConverter.load_from_hif(dataset.name)
+        print(
+            f"Loaded HIF hypergraph with {len(load_hif.nodes)} nodes and {len(load_hif.edges)} edges."
+        )
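With the Drive file IDs removed, a dataset class reduces to its enum member name. A hedged sketch of adding one more wrapper (the class name is hypothetical; the pattern and the `download`/`process` calls are the ones defined above):

```python
# Hypothetical subclass: DATASET_NAME must be a DatasetNames member name,
# which load_from_hif resolves to contact-high-school.json.zst.
class ContactHighSchoolDataset(Dataset):
    DATASET_NAME = "CONTACT_HIGH_SCHOOL"


dataset = ContactHighSchoolDataset()
hypergraph = dataset.download()  # fetches and decompresses the HIF file
data = dataset.process()         # converts the hypergraph to HData
```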