diff --git a/unionfind.py b/unionfind.py index f1be4ee..38e5e8c 100644 --- a/unionfind.py +++ b/unionfind.py @@ -9,6 +9,8 @@ ) # Third-party libraries +from collections import defaultdict + import numpy as np @@ -258,11 +260,11 @@ def components(self): A list of sets. """ - elts = np.array(self._elts) - vfind = np.vectorize(self.find) - roots = vfind(elts) - distinct_roots = set(roots) - return [set(elts[roots == root]) for root in distinct_roots] + dict_components = defaultdict(list) + for elem in self._elts: + root = self.find(elem) + dict_components[root].append(elem) + return dict_components # comps = [] # for root in distinct_roots: # mask = (roots == root)