diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 880a15b7a..019fbd846 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -323,7 +323,11 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter -class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]): +class DirectPredecessorsGetter( + Mapper[ + FrozenOrderedSet[ArrayOrNames | FunctionDefinition], + FrozenOrderedSet[ArrayOrNames], + []]): """ Mapper to get the `direct predecessors @@ -334,9 +338,17 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]): We only consider the predecessors of a nodes in a data-flow sense. """ + def __init__(self, *, include_functions: bool = False) -> None: + super().__init__() + self.include_functions = include_functions + def _get_preds_from_shape(self, shape: ShapeType) -> FrozenOrderedSet[ArrayOrNames]: return FrozenOrderedSet(dim for dim in shape if isinstance(dim, Array)) + def map_dict_of_named_arrays( + self, expr: DictOfNamedArrays) -> FrozenOrderedSet[ArrayOrNames]: + return FrozenOrderedSet(expr._data.values()) + def map_index_lambda(self, expr: IndexLambda) -> FrozenOrderedSet[ArrayOrNames]: return (FrozenOrderedSet(expr.bindings.values()) | self._get_preds_from_shape(expr.shape)) @@ -397,8 +409,17 @@ def map_distributed_send_ref_holder(self, ) -> FrozenOrderedSet[ArrayOrNames]: return FrozenOrderedSet([expr.passthrough_data]) - def map_call(self, expr: Call) -> FrozenOrderedSet[ArrayOrNames]: - return FrozenOrderedSet(expr.bindings.values()) + def map_call( + self, expr: Call) -> FrozenOrderedSet[ArrayOrNames | FunctionDefinition]: + result: FrozenOrderedSet[ArrayOrNames | FunctionDefinition] = \ + FrozenOrderedSet(expr.bindings.values()) + if self.include_functions: + result = result | FrozenOrderedSet([expr.function]) + return result + + def map_function_definition( + self, expr: FunctionDefinition) -> FrozenOrderedSet[ArrayOrNames]: + return FrozenOrderedSet(expr.returns.values()) def map_named_call_result( self, expr: NamedCallResult) -> FrozenOrderedSet[ArrayOrNames]: @@ -624,7 +645,7 @@ def combine(self, *args: int) -> int: def rec(self, expr: ArrayOrNames) -> int: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, @@ -639,7 +660,7 @@ def rec(self, expr: ArrayOrNames) -> int: else: result = 0 + s - self._cache.add(inputs, 0) + self._cache_add(inputs, 0) return result diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index a022f8f8e..73eec2745 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -296,7 +296,7 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend: def rec(self, expr: ArrayOrNames) -> ArrayOrNames: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: pass diff --git a/pytato/loopy.py b/pytato/loopy.py index a6c81dd1d..db7bfd22f 100644 --- a/pytato/loopy.py +++ b/pytato/loopy.py @@ -39,7 +39,7 @@ import loopy as lp import pymbolic.primitives as prim -from pymbolic.typing import ArithmeticExpression, Expression, Integer, not_none +from loopy.typing import assert_tuple from pytools import memoize_method from pytato.array import ( @@ -61,6 +61,8 @@ if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping, Sequence + from pymbolic.typing import ArithmeticExpression, Expression, Integer + __doc__ = r""" .. currentmodule:: pytato.loopy @@ -423,7 +425,7 @@ def extend_bindings_with_shape_inference(knl: lp.LoopKernel, get_size_param_deps = SizeParamGatherer() lp_size_params: frozenset[str] = reduce(frozenset.union, - (lpy_get_deps(not_none(arg.shape)) + (lpy_get_deps(assert_tuple(arg.shape)) for arg in knl.args if isinstance(arg, ArrayBase) and is_expression(arg.shape) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index bc9a45da7..a3adffb32 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -188,6 +188,14 @@ class ForeignObjectError(ValueError): pass +class CacheCollisionError(ValueError): + pass + + +class MapperCreatedDuplicateError(ValueError): + pass + + # {{{ mapper base class ResultT = TypeVar("ResultT") @@ -300,7 +308,7 @@ def __call__( # {{{ CachedMapper -CacheExprT = TypeVar("CacheExprT") +CacheExprT = TypeVar("CacheExprT", ArrayOrNames, FunctionDefinition) CacheResultT = TypeVar("CacheResultT") CacheKeyT: TypeAlias = Hashable @@ -351,9 +359,18 @@ class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): .. automethod:: retrieve .. automethod:: clear """ - def __init__(self) -> None: - """Initialize the cache.""" + def __init__(self, err_on_collision: bool) -> None: + """ + Initialize the cache. + + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + """ + self.err_on_collision = err_on_collision + self._input_key_to_result: dict[CacheKeyT, CacheResultT] = {} + if self.err_on_collision: + self._input_key_to_expr: dict[CacheKeyT, CacheExprT] = {} def add( self, @@ -366,16 +383,27 @@ def add( f"Cache entry is already present for key '{key}'." self._input_key_to_result[key] = result + if self.err_on_collision: + self._input_key_to_expr[key] = inputs.expr + return result def retrieve(self, inputs: CacheInputsWithKey[CacheExprT, P]) -> CacheResultT: """Retrieve the cached mapping result.""" key = inputs.key - return self._input_key_to_result[key] + + result = self._input_key_to_result[key] + + if self.err_on_collision and inputs.expr is not self._input_key_to_expr[key]: + raise CacheCollisionError + + return result def clear(self) -> None: """Reset the cache.""" self._input_key_to_result = {} + if self.err_on_collision: + self._input_key_to_expr = {} class CachedMapper(Mapper[ResultT, FunctionResultT, P]): @@ -389,6 +417,7 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]): """ def __init__( self, + err_on_collision: bool = False, _cache: CachedMapperCache[ArrayOrNames, ResultT, P] | None = None, _function_cache: @@ -398,12 +427,12 @@ def __init__( self._cache: CachedMapperCache[ArrayOrNames, ResultT, P] = ( _cache if _cache is not None - else CachedMapperCache()) + else CachedMapperCache(err_on_collision=err_on_collision)) self._function_cache: CachedMapperCache[ FunctionDefinition, FunctionResultT, P] = ( _function_cache if _function_cache is not None - else CachedMapperCache()) + else CachedMapperCache(err_on_collision=err_on_collision)) def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs @@ -441,24 +470,53 @@ def _make_function_definition_cache_inputs( expr, self.get_function_definition_cache_key(expr, *args, **kwargs), *args, **kwargs) + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, P], + result: ResultT) -> ResultT: + return self._cache.add(inputs, result) + + def _function_cache_add( + self, + inputs: CacheInputsWithKey[FunctionDefinition, P], + result: FunctionResultT) -> FunctionResultT: + return self._function_cache.add(inputs, result) + + def _cache_retrieve(self, inputs: CacheInputsWithKey[ArrayOrNames, P]) -> ResultT: + try: + return self._cache.retrieve(inputs) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_retrieve( + self, inputs: CacheInputsWithKey[FunctionDefinition, P]) -> FunctionResultT: + try: + return self._function_cache.retrieve(inputs) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: inputs = self._make_cache_inputs(expr, *args, **kwargs) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see https://github.com/inducer/pytato/pull/585 - return self._cache.add(inputs, Mapper.rec(self, expr, *args, **kwargs)) + return self._cache_add(inputs, Mapper.rec(self, expr, *args, **kwargs)) def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> FunctionResultT: inputs = self._make_function_definition_cache_inputs(expr, *args, **kwargs) try: - return self._function_cache.retrieve(inputs) + return self._function_cache_retrieve(inputs) except KeyError: - return self._function_cache.add( + return self._function_cache_add( # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see https://github.com/inducer/pytato/pull/585 @@ -470,8 +528,10 @@ def clone_for_callee( Called to clone *self* before starting traversal of a :class:`pytato.function.FunctionDefinition`. """ - # Functions are cached globally, but arrays aren't - return type(self)(_function_cache=self._function_cache) + return type(self)( + err_on_collision=self._cache.err_on_collision, + # Functions are cached globally, but arrays aren't + _function_cache=self._function_cache) # }}} @@ -479,7 +539,70 @@ def clone_for_callee( # {{{ TransformMapper class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]): - pass + """ + Cache for :class:`TransformMapper` and :class:`TransformMapperWithExtraArgs`. + + .. automethod:: __init__ + .. automethod:: add + """ + def __init__( + self, + err_on_collision: bool, + err_on_created_duplicate: bool) -> None: + """ + Initialize the cache. + + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + :arg err_on_created_duplicate: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + super().__init__(err_on_collision=err_on_collision) + + self.err_on_created_duplicate = err_on_created_duplicate + + def add( + self, + inputs: CacheInputsWithKey[CacheExprT, P], + result: CacheExprT) -> CacheExprT: + """ + Cache a mapping result. + + Returns *result*. + """ + key = inputs.key + + assert key not in self._input_key_to_result, \ + f"Cache entry is already present for key '{key}'." + + if self.err_on_created_duplicate: + from pytato.analysis import DirectPredecessorsGetter + pred_getter = DirectPredecessorsGetter(include_functions=True) + if ( + hash(result) == hash(inputs.expr) + and result == inputs.expr + and result is not inputs.expr + # Only consider "direct" duplication, not duplication resulting + # from equality-preserving changes to predecessors. Assume that + # such changes are OK, otherwise they would have been detected + # at the point at which they originated. (For example, consider + # a DAG containing pre-existing duplicates. If a subexpression + # of *expr* is a duplicate and is replaced with a previously + # encountered version from the cache, a new instance of *expr* + # must be created. This should not trigger an error.) + and all( + result_pred is pred + for pred, result_pred in zip( + pred_getter(inputs.expr), + pred_getter(result), + strict=True))): + raise MapperCreatedDuplicateError from None + + self._input_key_to_result[key] = result + if self.err_on_collision: + self._input_key_to_expr[key] = inputs.expr + + return result class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): @@ -489,13 +612,71 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): Enables certain operations that can only be done if the mapping results are also arrays (e.g., computing a cache key from them). Does not implement default mapper methods; for that, see :class:`CopyMapper`. + + .. automethod:: __init__ + .. automethod:: clone_for_callee """ def __init__( self, + err_on_collision: bool = False, + err_on_created_duplicate: bool = False, _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) + """ + :arg err_on_collision: Raise an exception if two distinct input array + instances have the same key. + :arg err_on_created_duplicate: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + if _cache is None: + _cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + if _function_cache is None: + _function_cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + super().__init__( + err_on_collision=err_on_collision, + _cache=_cache, + _function_cache=_function_cache) + + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, []], + result: ArrayOrNames) -> ArrayOrNames: + try: + return self._cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"mapper-created duplicate detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_add( + self, + inputs: CacheInputsWithKey[FunctionDefinition, []], + result: FunctionDefinition) -> FunctionDefinition: + try: + return self._function_cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"mapper-created duplicate detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + function_cache = cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache) + return type(self)( + err_on_collision=function_cache.err_on_collision, + err_on_created_duplicate=function_cache.err_on_created_duplicate, + _function_cache=function_cache) # }}} @@ -511,14 +692,72 @@ class TransformMapperWithExtraArgs( The logic in :class:`TransformMapper` purposely does not take the extra arguments to keep the cost of its each call frame low. + + .. automethod:: __init__ + .. automethod:: clone_for_callee """ def __init__( self, + err_on_collision: bool = False, + err_on_created_duplicate: bool = False, _cache: TransformMapperCache[ArrayOrNames, P] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, P] | None = None ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) + """ + :arg err_on_collision: Raise an exception if two distinct input array + instances have the same key. + :arg err_on_created_duplicate: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + if _cache is None: + _cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + if _function_cache is None: + _function_cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + super().__init__( + err_on_collision=err_on_collision, + _cache=_cache, + _function_cache=_function_cache) + + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, P], + result: ArrayOrNames) -> ArrayOrNames: + try: + return self._cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"mapper-created duplicate detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_add( + self, + inputs: CacheInputsWithKey[FunctionDefinition, P], + result: FunctionDefinition) -> FunctionDefinition: + try: + return self._function_cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"mapper-created duplicate detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + function_cache = cast( + "TransformMapperCache[FunctionDefinition, P]", self._function_cache) + return type(self)( + err_on_collision=function_cache.err_on_collision, + err_on_created_duplicate=function_cache.err_on_created_duplicate, + _function_cache=function_cache) # }}} @@ -1132,6 +1371,9 @@ def map_call(self, expr: Call) -> R: def map_named_call_result(self, expr: NamedCallResult) -> R: return self.rec(expr._container) + def clone_for_callee(self, function: FunctionDefinition) -> Self: + raise AssertionError("Control shouldn't reach this point.") + # }}} @@ -1560,12 +1802,12 @@ def clone_for_callee( def rec(self, expr: ArrayOrNames) -> ArrayOrNames: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see https://github.com/inducer/pytato/pull/585 - return self._cache.add(inputs, Mapper.rec(self, self.map_fn(expr))) + return self._cache_add(inputs, Mapper.rec(self, self.map_fn(expr))) # }}} @@ -2145,6 +2387,11 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: self.data_wrapper_cache[cache_key] = expr return expr + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) + def deduplicate_data_wrappers(array_or_names: ArrayOrNames) -> ArrayOrNames: """For the expression graph given as *array_or_names*, replace all diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 200aa25b4..de3625978 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -467,7 +467,7 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: def rec(self, expr: ArrayOrNames) -> ArrayOrNames: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, @@ -478,7 +478,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: assert isinstance(expr, Array) # type-ignore reason: passed "ArrayOrNames"; expected "Array" result = self._attach_tags(expr, result) # type: ignore[arg-type] - return self._cache.add(inputs, result) + return self._cache_add(inputs, result) def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError(