diff --git a/mantra/representations/simplicial_connectivity.py b/mantra/representations/simplicial_connectivity.py index faa5c40..0e57ecb 100644 --- a/mantra/representations/simplicial_connectivity.py +++ b/mantra/representations/simplicial_connectivity.py @@ -31,7 +31,8 @@ def forward(self, data: Data): return data -class AbstractSimplicialComplexConnectivity(BaseTransform, ABCMeta): +class AbstractSimplicialComplexConnectivity(BaseTransform): + __metaclass__ = ABCMeta """Base class for connectivity transforms. Parent class for implementing a transform that adds a @@ -108,28 +109,19 @@ def forward(self, data: Data): self.generate_matrix( data.simplex_trie, rank_idx, max_rank ), - device=data.triangulation.device, ) except ValueError: idx_low_simp = rank_idx - 1 if rank_idx > 0 else rank_idx if "incidence" in self.connectivity_name: - data[connectivity_name] = ( - torch.zeros( - [shape[idx_low_simp], shape[rank_idx]], - layout=torch.sparse_coo, - ) - .coalesce() - .to(data.triangulations.device) - ) + data[connectivity_name] = torch.zeros( + [shape[idx_low_simp], shape[rank_idx]], + layout=torch.sparse_coo, + ).coalesce() elif "adjacency" in self.connectivity_name: - data[connectivity_name] = ( - torch.zeros( - [shape[rank_idx], shape[rank_idx]], - layout=torch.sparse_coo, - ) - .coalesce() - .to(data.triangulations.device) - ) + data[connectivity_name] = torch.zeros( + [shape[rank_idx], shape[rank_idx]], + layout=torch.sparse_coo, + ).coalesce() return data @@ -328,7 +320,7 @@ def _from_sparse(data: scipy.sparse.csc_matrix, device=None) -> torch.Tensor: input data converted to tensor. """ if device is None: - device = data.device("cpu") + device = torch.device("cpu") # cast from csc_matrix to coo format for compatibility coo = data.tocoo()