From cceb64a6742de3fc2a1d1327c94929b962b369d9 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 19 Sep 2024 19:31:13 -0500 Subject: [PATCH 1/5] add map_dict_of_named_arrays to DirectPredecessorsGetter --- pytato/analysis/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 880a15b7a..e461edf5e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -337,6 +337,10 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]): def _get_preds_from_shape(self, shape: ShapeType) -> FrozenOrderedSet[ArrayOrNames]: return FrozenOrderedSet(dim for dim in shape if isinstance(dim, Array)) + def map_dict_of_named_arrays( + self, expr: DictOfNamedArrays) -> FrozenOrderedSet[ArrayOrNames]: + return FrozenOrderedSet(expr._data.values()) + def map_index_lambda(self, expr: IndexLambda) -> FrozenOrderedSet[ArrayOrNames]: return (FrozenOrderedSet(expr.bindings.values()) | self._get_preds_from_shape(expr.shape)) From 4a74924b049e630e9a35453bceb3ae5f5f2238b2 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 24 Sep 2024 14:42:38 -0500 Subject: [PATCH 2/5] support functions as inputs and outputs in DirectPredecessorsGetter --- pytato/analysis/__init__.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index e461edf5e..65d7c7756 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -323,7 +323,11 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter -class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]): +class DirectPredecessorsGetter( + Mapper[ + FrozenOrderedSet[ArrayOrNames | FunctionDefinition], + FrozenOrderedSet[ArrayOrNames], + []]): """ Mapper to get the `direct predecessors @@ -334,6 +338,10 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]): We only consider the predecessors of a nodes in a data-flow sense. """ + def __init__(self, *, include_functions: bool = False) -> None: + super().__init__() + self.include_functions = include_functions + def _get_preds_from_shape(self, shape: ShapeType) -> FrozenOrderedSet[ArrayOrNames]: return FrozenOrderedSet(dim for dim in shape if isinstance(dim, Array)) @@ -401,8 +409,17 @@ def map_distributed_send_ref_holder(self, ) -> FrozenOrderedSet[ArrayOrNames]: return FrozenOrderedSet([expr.passthrough_data]) - def map_call(self, expr: Call) -> FrozenOrderedSet[ArrayOrNames]: - return FrozenOrderedSet(expr.bindings.values()) + def map_call( + self, expr: Call) -> FrozenOrderedSet[ArrayOrNames | FunctionDefinition]: + result: FrozenOrderedSet[ArrayOrNames | FunctionDefinition] = \ + FrozenOrderedSet(expr.bindings.values()) + if self.include_functions: + result = result | FrozenOrderedSet([expr.function]) + return result + + def map_function_definition( + self, expr: FunctionDefinition) -> FrozenOrderedSet[ArrayOrNames]: + return FrozenOrderedSet(expr.returns.values()) def map_named_call_result( self, expr: NamedCallResult) -> FrozenOrderedSet[ArrayOrNames]: From 07d26de6c2e98285d191e083d684bbce1c47d08d Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 29 Aug 2024 16:57:13 -0500 Subject: [PATCH 3/5] add collision/duplication checks to CachedMapper/TransformMapper/TransformMapperWithExtraArgs fix doc change terminology from 'no-op duplication' to 'mapper-created duplicate' reword explanation of predecessor check in duplication check --- pytato/analysis/__init__.py | 4 +- pytato/distributed/partition.py | 2 +- pytato/transform/__init__.py | 273 ++++++++++++++++++++++++++++++-- pytato/transform/metadata.py | 4 +- 4 files changed, 261 insertions(+), 22 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 65d7c7756..019fbd846 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -645,7 +645,7 @@ def combine(self, *args: int) -> int: def rec(self, expr: ArrayOrNames) -> int: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, @@ -660,7 +660,7 @@ def rec(self, expr: ArrayOrNames) -> int: else: result = 0 + s - self._cache.add(inputs, 0) + self._cache_add(inputs, 0) return result diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index a022f8f8e..73eec2745 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -296,7 +296,7 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend: def rec(self, expr: ArrayOrNames) -> ArrayOrNames: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: pass diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index bc9a45da7..512d8711b 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -188,6 +188,14 @@ class ForeignObjectError(ValueError): pass +class CacheCollisionError(ValueError): + pass + + +class MapperCreatedDuplicateError(ValueError): + pass + + # {{{ mapper base class ResultT = TypeVar("ResultT") @@ -300,7 +308,7 @@ def __call__( # {{{ CachedMapper -CacheExprT = TypeVar("CacheExprT") +CacheExprT = TypeVar("CacheExprT", ArrayOrNames, FunctionDefinition) CacheResultT = TypeVar("CacheResultT") CacheKeyT: TypeAlias = Hashable @@ -351,9 +359,18 @@ class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): .. automethod:: retrieve .. automethod:: clear """ - def __init__(self) -> None: - """Initialize the cache.""" + def __init__(self, err_on_collision: bool) -> None: + """ + Initialize the cache. + + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + """ + self.err_on_collision = err_on_collision + self._input_key_to_result: dict[CacheKeyT, CacheResultT] = {} + if self.err_on_collision: + self._input_key_to_expr: dict[CacheKeyT, CacheExprT] = {} def add( self, @@ -366,16 +383,27 @@ def add( f"Cache entry is already present for key '{key}'." self._input_key_to_result[key] = result + if self.err_on_collision: + self._input_key_to_expr[key] = inputs.expr + return result def retrieve(self, inputs: CacheInputsWithKey[CacheExprT, P]) -> CacheResultT: """Retrieve the cached mapping result.""" key = inputs.key - return self._input_key_to_result[key] + + result = self._input_key_to_result[key] + + if self.err_on_collision and inputs.expr is not self._input_key_to_expr[key]: + raise CacheCollisionError + + return result def clear(self) -> None: """Reset the cache.""" self._input_key_to_result = {} + if self.err_on_collision: + self._input_key_to_expr = {} class CachedMapper(Mapper[ResultT, FunctionResultT, P]): @@ -389,6 +417,7 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]): """ def __init__( self, + err_on_collision: bool = False, _cache: CachedMapperCache[ArrayOrNames, ResultT, P] | None = None, _function_cache: @@ -398,12 +427,12 @@ def __init__( self._cache: CachedMapperCache[ArrayOrNames, ResultT, P] = ( _cache if _cache is not None - else CachedMapperCache()) + else CachedMapperCache(err_on_collision=err_on_collision)) self._function_cache: CachedMapperCache[ FunctionDefinition, FunctionResultT, P] = ( _function_cache if _function_cache is not None - else CachedMapperCache()) + else CachedMapperCache(err_on_collision=err_on_collision)) def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs @@ -441,24 +470,53 @@ def _make_function_definition_cache_inputs( expr, self.get_function_definition_cache_key(expr, *args, **kwargs), *args, **kwargs) + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, P], + result: ResultT) -> ResultT: + return self._cache.add(inputs, result) + + def _function_cache_add( + self, + inputs: CacheInputsWithKey[FunctionDefinition, P], + result: FunctionResultT) -> FunctionResultT: + return self._function_cache.add(inputs, result) + + def _cache_retrieve(self, inputs: CacheInputsWithKey[ArrayOrNames, P]) -> ResultT: + try: + return self._cache.retrieve(inputs) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_retrieve( + self, inputs: CacheInputsWithKey[FunctionDefinition, P]) -> FunctionResultT: + try: + return self._function_cache.retrieve(inputs) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: inputs = self._make_cache_inputs(expr, *args, **kwargs) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see https://github.com/inducer/pytato/pull/585 - return self._cache.add(inputs, Mapper.rec(self, expr, *args, **kwargs)) + return self._cache_add(inputs, Mapper.rec(self, expr, *args, **kwargs)) def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> FunctionResultT: inputs = self._make_function_definition_cache_inputs(expr, *args, **kwargs) try: - return self._function_cache.retrieve(inputs) + return self._function_cache_retrieve(inputs) except KeyError: - return self._function_cache.add( + return self._function_cache_add( # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see https://github.com/inducer/pytato/pull/585 @@ -470,8 +528,10 @@ def clone_for_callee( Called to clone *self* before starting traversal of a :class:`pytato.function.FunctionDefinition`. """ - # Functions are cached globally, but arrays aren't - return type(self)(_function_cache=self._function_cache) + return type(self)( + err_on_collision=self._cache.err_on_collision, + # Functions are cached globally, but arrays aren't + _function_cache=self._function_cache) # }}} @@ -479,7 +539,70 @@ def clone_for_callee( # {{{ TransformMapper class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]): - pass + """ + Cache for :class:`TransformMapper` and :class:`TransformMapperWithExtraArgs`. + + .. automethod:: __init__ + .. automethod:: add + """ + def __init__( + self, + err_on_collision: bool, + err_on_created_duplicate: bool) -> None: + """ + Initialize the cache. + + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + :arg err_on_created_duplicate: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + super().__init__(err_on_collision=err_on_collision) + + self.err_on_created_duplicate = err_on_created_duplicate + + def add( + self, + inputs: CacheInputsWithKey[CacheExprT, P], + result: CacheExprT) -> CacheExprT: + """ + Cache a mapping result. + + Returns *result*. + """ + key = inputs.key + + assert key not in self._input_key_to_result, \ + f"Cache entry is already present for key '{key}'." + + if self.err_on_created_duplicate: + from pytato.analysis import DirectPredecessorsGetter + pred_getter = DirectPredecessorsGetter(include_functions=True) + if ( + hash(result) == hash(inputs.expr) + and result == inputs.expr + and result is not inputs.expr + # Only consider "direct" duplication, not duplication resulting + # from equality-preserving changes to predecessors. Assume that + # such changes are OK, otherwise they would have been detected + # at the point at which they originated. (For example, consider + # a DAG containing pre-existing duplicates. If a subexpression + # of *expr* is a duplicate and is replaced with a previously + # encountered version from the cache, a new instance of *expr* + # must be created. This should not trigger an error.) + and all( + result_pred is pred + for pred, result_pred in zip( + pred_getter(inputs.expr), + pred_getter(result), + strict=True))): + raise MapperCreatedDuplicateError from None + + self._input_key_to_result[key] = result + if self.err_on_collision: + self._input_key_to_expr[key] = inputs.expr + + return result class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): @@ -489,13 +612,71 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): Enables certain operations that can only be done if the mapping results are also arrays (e.g., computing a cache key from them). Does not implement default mapper methods; for that, see :class:`CopyMapper`. + + .. automethod:: __init__ + .. automethod:: clone_for_callee """ def __init__( self, + err_on_collision: bool = False, + err_on_created_duplicate: bool = False, _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) + """ + :arg err_on_collision: Raise an exception if two distinct input array + instances have the same key. + :arg err_on_created_duplicate: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + if _cache is None: + _cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + if _function_cache is None: + _function_cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + super().__init__( + err_on_collision=err_on_collision, + _cache=_cache, + _function_cache=_function_cache) + + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, []], + result: ArrayOrNames) -> ArrayOrNames: + try: + return self._cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"mapper-created duplicate detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_add( + self, + inputs: CacheInputsWithKey[FunctionDefinition, []], + result: FunctionDefinition) -> FunctionDefinition: + try: + return self._function_cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"mapper-created duplicate detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + function_cache = cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache) + return type(self)( + err_on_collision=function_cache.err_on_collision, + err_on_created_duplicate=function_cache.err_on_created_duplicate, + _function_cache=function_cache) # }}} @@ -511,14 +692,72 @@ class TransformMapperWithExtraArgs( The logic in :class:`TransformMapper` purposely does not take the extra arguments to keep the cost of its each call frame low. + + .. automethod:: __init__ + .. automethod:: clone_for_callee """ def __init__( self, + err_on_collision: bool = False, + err_on_created_duplicate: bool = False, _cache: TransformMapperCache[ArrayOrNames, P] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, P] | None = None ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) + """ + :arg err_on_collision: Raise an exception if two distinct input array + instances have the same key. + :arg err_on_created_duplicate: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + if _cache is None: + _cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + if _function_cache is None: + _function_cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + super().__init__( + err_on_collision=err_on_collision, + _cache=_cache, + _function_cache=_function_cache) + + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, P], + result: ArrayOrNames) -> ArrayOrNames: + try: + return self._cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"mapper-created duplicate detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_add( + self, + inputs: CacheInputsWithKey[FunctionDefinition, P], + result: FunctionDefinition) -> FunctionDefinition: + try: + return self._function_cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"mapper-created duplicate detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + function_cache = cast( + "TransformMapperCache[FunctionDefinition, P]", self._function_cache) + return type(self)( + err_on_collision=function_cache.err_on_collision, + err_on_created_duplicate=function_cache.err_on_created_duplicate, + _function_cache=function_cache) # }}} @@ -1560,12 +1799,12 @@ def clone_for_callee( def rec(self, expr: ArrayOrNames) -> ArrayOrNames: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see https://github.com/inducer/pytato/pull/585 - return self._cache.add(inputs, Mapper.rec(self, self.map_fn(expr))) + return self._cache_add(inputs, Mapper.rec(self, self.map_fn(expr))) # }}} diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 200aa25b4..de3625978 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -467,7 +467,7 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: def rec(self, expr: ArrayOrNames) -> ArrayOrNames: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, @@ -478,7 +478,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: assert isinstance(expr, Array) # type-ignore reason: passed "ArrayOrNames"; expected "Array" result = self._attach_tags(expr, result) # type: ignore[arg-type] - return self._cache.add(inputs, result) + return self._cache_add(inputs, result) def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError( From cbfaaec8a7325752d04b33a2f5eb6794c2ba1366 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 11 Mar 2025 17:20:48 -0500 Subject: [PATCH 4/5] add a couple of missing clone_for_callee definitions --- pytato/transform/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 512d8711b..a3adffb32 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1371,6 +1371,9 @@ def map_call(self, expr: Call) -> R: def map_named_call_result(self, expr: NamedCallResult) -> R: return self.rec(expr._container) + def clone_for_callee(self, function: FunctionDefinition) -> Self: + raise AssertionError("Control shouldn't reach this point.") + # }}} @@ -2384,6 +2387,11 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: self.data_wrapper_cache[cache_key] = expr return expr + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) + def deduplicate_data_wrappers(array_or_names: ArrayOrNames) -> ArrayOrNames: """For the expression graph given as *array_or_names*, replace all From 6351b3a6abb2c26a157425e34e21cd9677528570 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 18 Mar 2025 15:31:55 -0500 Subject: [PATCH 5/5] Fix type error in extend_bindings_with_shape_inference --- pytato/loopy.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytato/loopy.py b/pytato/loopy.py index a6c81dd1d..db7bfd22f 100644 --- a/pytato/loopy.py +++ b/pytato/loopy.py @@ -39,7 +39,7 @@ import loopy as lp import pymbolic.primitives as prim -from pymbolic.typing import ArithmeticExpression, Expression, Integer, not_none +from loopy.typing import assert_tuple from pytools import memoize_method from pytato.array import ( @@ -61,6 +61,8 @@ if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping, Sequence + from pymbolic.typing import ArithmeticExpression, Expression, Integer + __doc__ = r""" .. currentmodule:: pytato.loopy @@ -423,7 +425,7 @@ def extend_bindings_with_shape_inference(knl: lp.LoopKernel, get_size_param_deps = SizeParamGatherer() lp_size_params: frozenset[str] = reduce(frozenset.union, - (lpy_get_deps(not_none(arg.shape)) + (lpy_get_deps(assert_tuple(arg.shape)) for arg in knl.args if isinstance(arg, ArrayBase) and is_expression(arg.shape)