From 8738a9075845f0d3162f6a27d50b6dafa794e624 Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Thu, 5 Feb 2026 15:38:54 +0100 Subject: [PATCH] feat: complete test --- hyperbench/tests/data/dataset_test.py | 41 +++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index 0cead09..08c97ac 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -1,3 +1,5 @@ +import zstandard as zstd +import json import requests import torch import pytest @@ -852,3 +854,42 @@ class TestDataset(Dataset): assert torch.allclose( result, torch.tensor([1.5, 0.8]) ) # weight, score (insertion order) + + +def test_load_from_hif_file_exists(): + """Test loading dataset when file already exists locally (skip download).""" + dataset_name = "ALGEBRA" + + sample_hif = { + "network-type": "undirected", + "nodes": [{"node": "0"}, {"node": "1"}], + "edges": [{"edge": "0"}], + "incidences": [{"node": "0", "edge": "0"}], + } + + mock_hypergraph = HIFHypergraph( + network_type="undirected", + nodes=[{"node": "0"}, {"node": "1"}], + edges=[{"edge": "0"}], + incidences=[{"node": "0", "edge": "0"}], + ) + + with ( + patch("hyperbench.data.dataset.requests.get") as mock_get, + patch("hyperbench.data.dataset.os.path.exists", return_value=True), + patch("builtins.open", mock_open()) as mock_file, + patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, + patch("hyperbench.data.dataset.tempfile.NamedTemporaryFile") as mock_temp, + patch("hyperbench.data.dataset.json.load", return_value=sample_hif), + patch("hyperbench.data.dataset.validate_hif_json", return_value=True), + patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), + ): + mock_dctx = mock_decomp.return_value + mock_dctx.copy_stream = lambda input_f, tmp_file: None + + mock_temp_instance = mock_temp.return_value.__enter__.return_value + mock_temp_instance.name = "/tmp/decompressed.json" + + result = HIFConverter.load_from_hif(dataset_name, save_on_disk=True) + mock_get.assert_not_called() + assert result == mock_hypergraph