diff --git a/pytato/distributed/verify.py b/pytato/distributed/verify.py index e37d8a83a..9bc5d270b 100644 --- a/pytato/distributed/verify.py +++ b/pytato/distributed/verify.py @@ -272,8 +272,8 @@ def add_needed_pid(pid: _DistributedPartId, from pytato.partition import PartitionInducedCycleError try: compute_topological_order(pid_to_needed_pids) - except CycleError: - raise PartitionInducedCycleError + except CycleError as err: + raise PartitionInducedCycleError(err.node) logger.info("verify_distributed_partition completed successfully.") diff --git a/pytato/partition.py b/pytato/partition.py index 8b1596844..73e348ea2 100644 --- a/pytato/partition.py +++ b/pytato/partition.py @@ -32,6 +32,7 @@ logger = logging.getLogger(__name__) from pytools import memoize_method +from pytools.graph import CycleError from pytato.transform import EdgeCachedMapper, CachedWalkMapper from pytato.array import ( Array, AbstractResultWithNamedArrays, Placeholder, @@ -209,13 +210,13 @@ def make_partition(self, outputs: DictOfNamedArrays) -> GraphPartition: pid_to_output_names[pid_dependency].add(var_name) pid_to_input_names[pid_target].add(var_name) - from pytools.graph import compute_topological_order, CycleError + from pytools.graph import compute_topological_order try: toposorted_part_ids = compute_topological_order( pid_to_needing_pids, lambda x: sorted(pid_to_output_names[x])) - except CycleError: - raise PartitionInducedCycleError + except CycleError as err: + raise PartitionInducedCycleError(err.node) return GraphPartition( parts={ @@ -311,7 +312,7 @@ class GraphPartition: # }}} -class PartitionInducedCycleError(Exception): +class PartitionInducedCycleError(CycleError): """Raised by :func:`find_partition` if the partitioning induced a cycle in the graph of partitions. """