Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions hyperbench/tests/train/negative_sampler_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pytest
import torch

from hyperbench.train import NegativeSampler, RandomNegativeSampler
from hyperbench.types import HData


@pytest.fixture
def mock_hdata_with_attr():
return HData(
x=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
edge_index=torch.tensor([[0, 1, 2], [0, 1, 2]]),
edge_attr=torch.tensor([[0.5, 0.6], [0.7, 0.8], [0.9, 1.0]]),
num_nodes=3,
num_edges=3,
)


@pytest.fixture
def mock_hdata_no_attr():
return HData(
x=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
edge_index=torch.tensor([[0, 1, 2], [0, 0, 1]]),
edge_attr=None,
num_nodes=3,
num_edges=2,
)


def test_negative_sampler_is_abstract(mock_hdata_no_attr):
sampler = NegativeSampler()
with pytest.raises(NotImplementedError):
sampler.sample(mock_hdata_no_attr)


def test_random_negative_sampler_invalid_args():
with pytest.raises(
ValueError, match="num_negative_samples must be positive, got 0"
):
RandomNegativeSampler(num_negative_samples=0, num_nodes_per_sample=2)

with pytest.raises(
ValueError, match="num_nodes_per_sample must be positive, got 0"
):
RandomNegativeSampler(num_negative_samples=2, num_nodes_per_sample=0)


def test_random_negative_sampler_sample_too_many_nodes(mock_hdata_with_attr):
sampler = RandomNegativeSampler(num_negative_samples=2, num_nodes_per_sample=10)
with pytest.raises(
ValueError,
match="Asked to create samples with 10 nodes, but only 3 nodes are available",
):
sampler.sample(mock_hdata_with_attr)


def test_random_negative_sampler_with_edge_attr(mock_hdata_with_attr):
sampler = RandomNegativeSampler(num_negative_samples=2, num_nodes_per_sample=2)
result = sampler.sample(mock_hdata_with_attr)

assert result.num_edges == 2
assert result.x.shape[0] <= mock_hdata_with_attr.x.shape[0]
assert result.edge_index.shape[0] == 2
assert (
result.edge_index.shape[1] == 4
) # 2 negative hyperedges * 2 nodes per negative hyperedge
assert result.edge_attr is not None
assert result.edge_attr.shape[0] == 2


def test_random_negative_sampler_sample_no_edge_attr(mock_hdata_no_attr):
sampler = RandomNegativeSampler(num_negative_samples=1, num_nodes_per_sample=2)
result = sampler.sample(mock_hdata_no_attr)

assert result.num_edges == 1
assert result.x.shape[0] <= mock_hdata_no_attr.x.shape[0]
assert result.edge_index.shape[0] == 2
assert (
result.edge_index.shape[1] == 2
) # 1 negative hyperedge * 2 nodes per negative hyperedge
assert result.edge_attr is None


def test_random_negative_sampler_sample_unique_nodes(mock_hdata_with_attr):
sampler = RandomNegativeSampler(num_negative_samples=3, num_nodes_per_sample=2)
result = sampler.sample(mock_hdata_with_attr)

node_ids = result.edge_index[0]
edge_ids = result.edge_index[1]

# All node indices in edge_index should be valid
assert torch.all(node_ids < mock_hdata_with_attr.num_nodes)

# No duplicate node indices within a single edge
for edge_id in edge_ids.unique():
edge_mask = torch.isin(edge_ids, edge_id)
unique_edge_nodes = node_ids[edge_mask].unique()

assert len(unique_edge_nodes) == sampler.num_nodes_per_sample
3 changes: 3 additions & 0 deletions hyperbench/train/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from .negative_sampler import NegativeSampler, RandomNegativeSampler
from .trainer import MultiModelTrainer

__all__ = [
"NegativeSampler",
"RandomNegativeSampler",
"MultiModelTrainer",
]
167 changes: 167 additions & 0 deletions hyperbench/train/negative_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import torch

from typing import List, Set
from torch import Tensor
from hyperbench.types import HData


class NegativeSampler:
def sample(self, data: HData) -> HData:
"""
Abstract method for negative sampling.

Args:
data: HData
The input data object containing graph or hypergraph information.

Returns:
HData: The negative samples as a new HData object.

Raises:
NotImplementedError: If the method is not implemented in a subclass.
"""
raise NotImplementedError("Subclasses must implement this method.")


class RandomNegativeSampler(NegativeSampler):
"""
A random negative sampler.

Args:
num_negative_samples (int): Number of negative hyperedges to generate.
num_nodes_per_sample (int): Number of nodes per negative hyperedge.

Raises:
ValueError: If either argument is not positive.
"""

def __init__(self, num_negative_samples: int, num_nodes_per_sample: int):
if num_negative_samples <= 0:
raise ValueError(
f"num_negative_samples must be positive, got {num_negative_samples}."
)
if num_nodes_per_sample <= 0:
raise ValueError(
f"num_nodes_per_sample must be positive, got {num_nodes_per_sample}."
)

super().__init__()
self.num_negative_samples = num_negative_samples
self.num_nodes_per_sample = num_nodes_per_sample

def sample(self, data: HData) -> HData:
"""
Generate negative hyperedges by randomly sampling unique node IDs.

Args:
data (HData): The input data object containing node and hyperedge information.

Returns:
HData: A new HData object containing the negative samples.

Raises:
ValueError: If num_nodes_per_sample is greater than the number of available nodes.
"""
if self.num_nodes_per_sample > data.num_nodes:
raise ValueError(
f"Asked to create samples with {self.num_nodes_per_sample} nodes, but only {data.num_nodes} nodes are available."
)

negative_node_ids: Set[int] = set()
sampled_edge_indexes: List[Tensor] = []
sampled_edge_attrs: List[Tensor] = []

new_edge_id_offset = data.num_edges
for new_edge_id in range(self.num_negative_samples):
# Sample with multinomial without replacement to ensure unique node ids
# and assign each node id equal probability of being selected by setting all of them to 1
# Example: num_nodes_per_sample=3, max_node_id=5
# -> possible output: [2, 0, 4]
equal_probabilities = torch.ones(data.num_nodes)
sampled_node_ids = torch.multinomial(
equal_probabilities, self.num_nodes_per_sample, replacement=False
)

# Example: sampled_node_ids = [2, 0, 4], new_edge_id=0, new_edge_id_offset=3
# -> edge_index = [[2, 0, 4],
# [3, 3, 3]]
sampled_edge_id_tensor = torch.full(
(self.num_nodes_per_sample,), new_edge_id + new_edge_id_offset
)
sampled_edge_index = torch.stack(
[sampled_node_ids, sampled_edge_id_tensor], dim=0
)
sampled_edge_indexes.append(sampled_edge_index)

# Example: nodes = [0, 1, 2],
# sampled_node_ids_0 = [0, 1], sampled_node_ids_1 = [1, 2],
# -> negative_node_ids = {0, 1, 2}
negative_node_ids.update(sampled_node_ids.tolist())

if data.edge_attr is not None:
random_edge_attr = torch.randn_like(data.edge_attr[0])
sampled_edge_attrs.append(random_edge_attr)

negative_node_features = data.x[sorted(negative_node_ids)]
negative_edge_index = self.__new_negative_edge_index(sampled_edge_indexes)
negative_edge_attr = (
torch.stack(sampled_edge_attrs, dim=0)
if data.edge_attr is not None
else None
)

return HData(
x=negative_node_features,
edge_index=negative_edge_index,
edge_attr=negative_edge_attr,
num_nodes=len(negative_node_ids),
num_edges=self.num_negative_samples,
)

def __new_negative_edge_index(self, sampled_edge_indexes: List[Tensor]) -> Tensor:
"""
Concatenate and sort the sampled edge indexes for negative samples.

Args:
sampled_edge_indexes (Tensor): List of edge index tensors for each negative sample.

Returns:
Tensor: The concatenated and sorted edge index tensor.
"""
negative_edge_index = torch.cat(sampled_edge_indexes, dim=1)
node_ids_order = negative_edge_index[0].argsort()

# Example: negative_edge_index before sorting: [[2, 0, 4, 0, 1, 3],
# [3, 3, 3, 4, 4, 4]]
# -> negative_edge_index after sorting: [[0, 0, 1, 2, 3, 4],
# [3, 4, 4, 3, 4, 3]]
negative_edge_index = negative_edge_index[:, node_ids_order]
return negative_edge_index


if __name__ == "__main__":
edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]])
x = torch.randn(3, 2)
edge_attr = torch.randn(3, 3)
print(f"Original node features:\n{x}")
print(f"Original edge_attr:\n{edge_attr}")

sampler = RandomNegativeSampler(num_negative_samples=4, num_nodes_per_sample=2)
negative_hdata = sampler.sample(
HData(x=x, edge_index=edge_index, edge_attr=edge_attr)
)
print(f"HData: {negative_hdata}")

try:
RandomNegativeSampler(num_negative_samples=-1, num_nodes_per_sample=2)
except ValueError as e:
print(f"Caught expected exception: {e}")
try:
RandomNegativeSampler(num_negative_samples=2, num_nodes_per_sample=-1)
except ValueError as e:
print(f"Caught expected exception: {e}")
try:
s = RandomNegativeSampler(num_negative_samples=2, num_nodes_per_sample=10)
s.sample(HData(x=x, edge_index=edge_index, edge_attr=edge_attr))
except ValueError as e:
print(f"Caught expected exception: {e}")
11 changes: 11 additions & 0 deletions hyperbench/types/hdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,14 @@ def __init__(

max_edge_id = edge_index[1].max().item() if edge_index.size(1) > 0 else -1
self.num_edges: int = num_edges if num_edges is not None else max_edge_id + 1

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(\n"
f" num_nodes={self.num_nodes},\n"
f" num_edges={self.num_edges},\n"
f" x_shape={self.x.shape},\n"
f" edge_index_shape={self.edge_index.shape},\n"
f" edge_attr_shape={self.edge_attr.shape if self.edge_attr is not None else None}\n"
f")"
)