diff --git a/hyperbench/tests/train/negative_sampler_test.py b/hyperbench/tests/train/negative_sampler_test.py new file mode 100644 index 0000000..f3962ac --- /dev/null +++ b/hyperbench/tests/train/negative_sampler_test.py @@ -0,0 +1,99 @@ +import pytest +import torch + +from hyperbench.train import NegativeSampler, RandomNegativeSampler +from hyperbench.types import HData + + +@pytest.fixture +def mock_hdata_with_attr(): + return HData( + x=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + edge_index=torch.tensor([[0, 1, 2], [0, 1, 2]]), + edge_attr=torch.tensor([[0.5, 0.6], [0.7, 0.8], [0.9, 1.0]]), + num_nodes=3, + num_edges=3, + ) + + +@pytest.fixture +def mock_hdata_no_attr(): + return HData( + x=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + edge_index=torch.tensor([[0, 1, 2], [0, 0, 1]]), + edge_attr=None, + num_nodes=3, + num_edges=2, + ) + + +def test_negative_sampler_is_abstract(mock_hdata_no_attr): + sampler = NegativeSampler() + with pytest.raises(NotImplementedError): + sampler.sample(mock_hdata_no_attr) + + +def test_random_negative_sampler_invalid_args(): + with pytest.raises( + ValueError, match="num_negative_samples must be positive, got 0" + ): + RandomNegativeSampler(num_negative_samples=0, num_nodes_per_sample=2) + + with pytest.raises( + ValueError, match="num_nodes_per_sample must be positive, got 0" + ): + RandomNegativeSampler(num_negative_samples=2, num_nodes_per_sample=0) + + +def test_random_negative_sampler_sample_too_many_nodes(mock_hdata_with_attr): + sampler = RandomNegativeSampler(num_negative_samples=2, num_nodes_per_sample=10) + with pytest.raises( + ValueError, + match="Asked to create samples with 10 nodes, but only 3 nodes are available", + ): + sampler.sample(mock_hdata_with_attr) + + +def test_random_negative_sampler_with_edge_attr(mock_hdata_with_attr): + sampler = RandomNegativeSampler(num_negative_samples=2, num_nodes_per_sample=2) + result = sampler.sample(mock_hdata_with_attr) + + assert result.num_edges == 2 + assert result.x.shape[0] <= mock_hdata_with_attr.x.shape[0] + assert result.edge_index.shape[0] == 2 + assert ( + result.edge_index.shape[1] == 4 + ) # 2 negative hyperedges * 2 nodes per negative hyperedge + assert result.edge_attr is not None + assert result.edge_attr.shape[0] == 2 + + +def test_random_negative_sampler_sample_no_edge_attr(mock_hdata_no_attr): + sampler = RandomNegativeSampler(num_negative_samples=1, num_nodes_per_sample=2) + result = sampler.sample(mock_hdata_no_attr) + + assert result.num_edges == 1 + assert result.x.shape[0] <= mock_hdata_no_attr.x.shape[0] + assert result.edge_index.shape[0] == 2 + assert ( + result.edge_index.shape[1] == 2 + ) # 1 negative hyperedge * 2 nodes per negative hyperedge + assert result.edge_attr is None + + +def test_random_negative_sampler_sample_unique_nodes(mock_hdata_with_attr): + sampler = RandomNegativeSampler(num_negative_samples=3, num_nodes_per_sample=2) + result = sampler.sample(mock_hdata_with_attr) + + node_ids = result.edge_index[0] + edge_ids = result.edge_index[1] + + # All node indices in edge_index should be valid + assert torch.all(node_ids < mock_hdata_with_attr.num_nodes) + + # No duplicate node indices within a single edge + for edge_id in edge_ids.unique(): + edge_mask = torch.isin(edge_ids, edge_id) + unique_edge_nodes = node_ids[edge_mask].unique() + + assert len(unique_edge_nodes) == sampler.num_nodes_per_sample diff --git a/hyperbench/train/__init__.py b/hyperbench/train/__init__.py index bf2d229..e62896b 100644 --- a/hyperbench/train/__init__.py +++ b/hyperbench/train/__init__.py @@ -1,5 +1,8 @@ +from .negative_sampler import NegativeSampler, RandomNegativeSampler from .trainer import MultiModelTrainer __all__ = [ + "NegativeSampler", + "RandomNegativeSampler", "MultiModelTrainer", ] diff --git a/hyperbench/train/negative_sampler.py b/hyperbench/train/negative_sampler.py new file mode 100644 index 0000000..57e54c1 --- /dev/null +++ b/hyperbench/train/negative_sampler.py @@ -0,0 +1,167 @@ +import torch + +from typing import List, Set +from torch import Tensor +from hyperbench.types import HData + + +class NegativeSampler: + def sample(self, data: HData) -> HData: + """ + Abstract method for negative sampling. + + Args: + data: HData + The input data object containing graph or hypergraph information. + + Returns: + HData: The negative samples as a new HData object. + + Raises: + NotImplementedError: If the method is not implemented in a subclass. + """ + raise NotImplementedError("Subclasses must implement this method.") + + +class RandomNegativeSampler(NegativeSampler): + """ + A random negative sampler. + + Args: + num_negative_samples (int): Number of negative hyperedges to generate. + num_nodes_per_sample (int): Number of nodes per negative hyperedge. + + Raises: + ValueError: If either argument is not positive. + """ + + def __init__(self, num_negative_samples: int, num_nodes_per_sample: int): + if num_negative_samples <= 0: + raise ValueError( + f"num_negative_samples must be positive, got {num_negative_samples}." + ) + if num_nodes_per_sample <= 0: + raise ValueError( + f"num_nodes_per_sample must be positive, got {num_nodes_per_sample}." + ) + + super().__init__() + self.num_negative_samples = num_negative_samples + self.num_nodes_per_sample = num_nodes_per_sample + + def sample(self, data: HData) -> HData: + """ + Generate negative hyperedges by randomly sampling unique node IDs. + + Args: + data (HData): The input data object containing node and hyperedge information. + + Returns: + HData: A new HData object containing the negative samples. + + Raises: + ValueError: If num_nodes_per_sample is greater than the number of available nodes. + """ + if self.num_nodes_per_sample > data.num_nodes: + raise ValueError( + f"Asked to create samples with {self.num_nodes_per_sample} nodes, but only {data.num_nodes} nodes are available." + ) + + negative_node_ids: Set[int] = set() + sampled_edge_indexes: List[Tensor] = [] + sampled_edge_attrs: List[Tensor] = [] + + new_edge_id_offset = data.num_edges + for new_edge_id in range(self.num_negative_samples): + # Sample with multinomial without replacement to ensure unique node ids + # and assign each node id equal probability of being selected by setting all of them to 1 + # Example: num_nodes_per_sample=3, max_node_id=5 + # -> possible output: [2, 0, 4] + equal_probabilities = torch.ones(data.num_nodes) + sampled_node_ids = torch.multinomial( + equal_probabilities, self.num_nodes_per_sample, replacement=False + ) + + # Example: sampled_node_ids = [2, 0, 4], new_edge_id=0, new_edge_id_offset=3 + # -> edge_index = [[2, 0, 4], + # [3, 3, 3]] + sampled_edge_id_tensor = torch.full( + (self.num_nodes_per_sample,), new_edge_id + new_edge_id_offset + ) + sampled_edge_index = torch.stack( + [sampled_node_ids, sampled_edge_id_tensor], dim=0 + ) + sampled_edge_indexes.append(sampled_edge_index) + + # Example: nodes = [0, 1, 2], + # sampled_node_ids_0 = [0, 1], sampled_node_ids_1 = [1, 2], + # -> negative_node_ids = {0, 1, 2} + negative_node_ids.update(sampled_node_ids.tolist()) + + if data.edge_attr is not None: + random_edge_attr = torch.randn_like(data.edge_attr[0]) + sampled_edge_attrs.append(random_edge_attr) + + negative_node_features = data.x[sorted(negative_node_ids)] + negative_edge_index = self.__new_negative_edge_index(sampled_edge_indexes) + negative_edge_attr = ( + torch.stack(sampled_edge_attrs, dim=0) + if data.edge_attr is not None + else None + ) + + return HData( + x=negative_node_features, + edge_index=negative_edge_index, + edge_attr=negative_edge_attr, + num_nodes=len(negative_node_ids), + num_edges=self.num_negative_samples, + ) + + def __new_negative_edge_index(self, sampled_edge_indexes: List[Tensor]) -> Tensor: + """ + Concatenate and sort the sampled edge indexes for negative samples. + + Args: + sampled_edge_indexes (Tensor): List of edge index tensors for each negative sample. + + Returns: + Tensor: The concatenated and sorted edge index tensor. + """ + negative_edge_index = torch.cat(sampled_edge_indexes, dim=1) + node_ids_order = negative_edge_index[0].argsort() + + # Example: negative_edge_index before sorting: [[2, 0, 4, 0, 1, 3], + # [3, 3, 3, 4, 4, 4]] + # -> negative_edge_index after sorting: [[0, 0, 1, 2, 3, 4], + # [3, 4, 4, 3, 4, 3]] + negative_edge_index = negative_edge_index[:, node_ids_order] + return negative_edge_index + + +if __name__ == "__main__": + edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]]) + x = torch.randn(3, 2) + edge_attr = torch.randn(3, 3) + print(f"Original node features:\n{x}") + print(f"Original edge_attr:\n{edge_attr}") + + sampler = RandomNegativeSampler(num_negative_samples=4, num_nodes_per_sample=2) + negative_hdata = sampler.sample( + HData(x=x, edge_index=edge_index, edge_attr=edge_attr) + ) + print(f"HData: {negative_hdata}") + + try: + RandomNegativeSampler(num_negative_samples=-1, num_nodes_per_sample=2) + except ValueError as e: + print(f"Caught expected exception: {e}") + try: + RandomNegativeSampler(num_negative_samples=2, num_nodes_per_sample=-1) + except ValueError as e: + print(f"Caught expected exception: {e}") + try: + s = RandomNegativeSampler(num_negative_samples=2, num_nodes_per_sample=10) + s.sample(HData(x=x, edge_index=edge_index, edge_attr=edge_attr)) + except ValueError as e: + print(f"Caught expected exception: {e}") diff --git a/hyperbench/types/hdata.py b/hyperbench/types/hdata.py index 076240b..a8df730 100644 --- a/hyperbench/types/hdata.py +++ b/hyperbench/types/hdata.py @@ -44,3 +44,14 @@ def __init__( max_edge_id = edge_index[1].max().item() if edge_index.size(1) > 0 else -1 self.num_edges: int = num_edges if num_edges is not None else max_edge_id + 1 + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(\n" + f" num_nodes={self.num_nodes},\n" + f" num_edges={self.num_edges},\n" + f" x_shape={self.x.shape},\n" + f" edge_index_shape={self.edge_index.shape},\n" + f" edge_attr_shape={self.edge_attr.shape if self.edge_attr is not None else None}\n" + f")" + )