From 8a7b778eb426cfed5ae5edd8208f652b7f584582 Mon Sep 17 00:00:00 2001 From: Martin Carrasco Date: Thu, 11 Dec 2025 16:11:20 +0100 Subject: [PATCH 1/2] Fix device conversion --- mantra/representations/simplicial_connectivity.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/mantra/representations/simplicial_connectivity.py b/mantra/representations/simplicial_connectivity.py index faa5c40..02d60bc 100644 --- a/mantra/representations/simplicial_connectivity.py +++ b/mantra/representations/simplicial_connectivity.py @@ -31,7 +31,8 @@ def forward(self, data: Data): return data -class AbstractSimplicialComplexConnectivity(BaseTransform, ABCMeta): +class AbstractSimplicialComplexConnectivity(BaseTransform): + __metaclass__ = ABCMeta """Base class for connectivity transforms. Parent class for implementing a transform that adds a @@ -108,7 +109,6 @@ def forward(self, data: Data): self.generate_matrix( data.simplex_trie, rank_idx, max_rank ), - device=data.triangulation.device, ) except ValueError: idx_low_simp = rank_idx - 1 if rank_idx > 0 else rank_idx @@ -119,7 +119,6 @@ def forward(self, data: Data): layout=torch.sparse_coo, ) .coalesce() - .to(data.triangulations.device) ) elif "adjacency" in self.connectivity_name: data[connectivity_name] = ( @@ -128,7 +127,6 @@ def forward(self, data: Data): layout=torch.sparse_coo, ) .coalesce() - .to(data.triangulations.device) ) return data @@ -328,7 +326,7 @@ def _from_sparse(data: scipy.sparse.csc_matrix, device=None) -> torch.Tensor: input data converted to tensor. """ if device is None: - device = data.device("cpu") + device = torch.device('cpu') # cast from csc_matrix to coo format for compatibility coo = data.tocoo() From c9ab74627bc74d64b462185f68d04b6624a26bf5 Mon Sep 17 00:00:00 2001 From: Martin Carrasco Date: Thu, 11 Dec 2025 16:16:56 +0100 Subject: [PATCH 2/2] Ruff'd and Black'd --- .../simplicial_connectivity.py | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/mantra/representations/simplicial_connectivity.py b/mantra/representations/simplicial_connectivity.py index 02d60bc..0e57ecb 100644 --- a/mantra/representations/simplicial_connectivity.py +++ b/mantra/representations/simplicial_connectivity.py @@ -113,21 +113,15 @@ def forward(self, data: Data): except ValueError: idx_low_simp = rank_idx - 1 if rank_idx > 0 else rank_idx if "incidence" in self.connectivity_name: - data[connectivity_name] = ( - torch.zeros( - [shape[idx_low_simp], shape[rank_idx]], - layout=torch.sparse_coo, - ) - .coalesce() - ) + data[connectivity_name] = torch.zeros( + [shape[idx_low_simp], shape[rank_idx]], + layout=torch.sparse_coo, + ).coalesce() elif "adjacency" in self.connectivity_name: - data[connectivity_name] = ( - torch.zeros( - [shape[rank_idx], shape[rank_idx]], - layout=torch.sparse_coo, - ) - .coalesce() - ) + data[connectivity_name] = torch.zeros( + [shape[rank_idx], shape[rank_idx]], + layout=torch.sparse_coo, + ).coalesce() return data @@ -326,7 +320,7 @@ def _from_sparse(data: scipy.sparse.csc_matrix, device=None) -> torch.Tensor: input data converted to tensor. """ if device is None: - device = torch.device('cpu') + device = torch.device("cpu") # cast from csc_matrix to coo format for compatibility coo = data.tocoo()