I ran into the same CUDA OOM issue and found that it was caused by the corruption(data) function in CCC_gat_split.py.
In the original implementation, the node feature matrix is converted from sparse to dense, which can easily blow up GPU memory when the graph is large. I tried a small modification that keeps data.x in sparse COO format throughout and only shuffles rows by remapping the sparse row indices. This avoids the dense conversion and significantly reduces memory usage on my side.
Hopefully this helps others who are running into the same issue❤.
Here is the modified version I’m using:
```python
def corruption(data):
    """Shuffle node features while keeping the sparse COO format."""
    # data.x: sparse COO tensor, shape [N, F]
    x = data.x.coalesce()
    # generate a random row permutation
    idx = torch.randperm(x.size(0), device=x.device)
    # mapping: old row index -> new row index
    idx_inv = torch.empty_like(idx)
    idx_inv[idx] = torch.arange(idx.numel(), device=x.device)
    # remap the row coordinates of the sparse indices
    indices = x.indices().clone()
    indices[0] = idx_inv[indices[0]]
    x_shuffled = torch.sparse_coo_tensor(
        indices, x.values(), x.size()
    ).coalesce()
    return my_data(x_shuffled, data.edge_index, data.edge_attr)
```
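If you want to sanity-check the index-remapping trick without a GPU or torch installed, here is a small NumPy-only sketch of the same idea (`shuffle_coo_rows` and the toy 4×3 matrix are my own illustration, not part of the repo). It verifies that remapping COO row indices with the inverse permutation produces exactly the same result as shuffling the dense matrix with `dense[idx]`:

```python
import numpy as np

def shuffle_coo_rows(rows, cols, vals, n_rows, seed=0):
    """Shuffle the rows of a COO matrix by remapping its row indices,
    without ever materializing the dense matrix."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_rows)        # new row order (like torch.randperm)
    idx_inv = np.empty_like(idx)
    idx_inv[idx] = np.arange(n_rows)     # old row index -> new row index
    return idx, idx_inv[rows], cols, vals

# Toy 4x3 sparse matrix in COO form.
rows = np.array([0, 1, 3, 3])
cols = np.array([2, 0, 1, 2])
vals = np.array([1.0, 2.0, 3.0, 4.0])
dense = np.zeros((4, 3))
dense[rows, cols] = vals

idx, new_rows, new_cols, new_vals = shuffle_coo_rows(rows, cols, vals, 4)
shuffled = np.zeros((4, 3))
shuffled[new_rows, new_cols] = new_vals

# The sparse remap matches the dense row shuffle.
assert np.array_equal(shuffled, dense[idx])
```

The key point is the inverse permutation: old row `i` moves to new row `idx_inv[i]`, which is exactly what indexing the dense matrix as `dense[idx]` does, so no dense copy is ever needed.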