diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index e1487b710..880a15b7a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -622,9 +622,9 @@ def combine(self, *args: int) -> int: return sum(args) def rec(self, expr: ArrayOrNames) -> int: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache.retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, @@ -639,7 +639,7 @@ def rec(self, expr: ArrayOrNames) -> int: else: result = 0 + s - self._cache.add(expr, 0, key=key) + self._cache.add(inputs, 0) return result diff --git a/pytato/codegen.py b/pytato/codegen.py index 86a328929..cb957f076 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -138,8 +138,8 @@ def __init__( self, target: Target, kernels_seen: dict[str, lp.LoopKernel] | None = None, - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.bound_arguments: dict[str, DataInterface] = {} diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 741e36548..a022f8f8e 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -240,9 +240,9 @@ def __init__(self, recvd_ary_to_name: Mapping[Array, str], sptpo_ary_to_name: Mapping[Array, str], name_to_output: Mapping[str, Array], - _cache: TransformMapperCache[ArrayOrNames] | None = None, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition] | None = None, + TransformMapperCache[FunctionDefinition, []] | None = None, ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -261,7 +261,7 @@ def clone_for_callee( return type(self)( {}, {}, {}, _function_cache=cast( - "TransformMapperCache[FunctionDefinition]", self._function_cache)) + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) def map_placeholder(self, expr: Placeholder) -> Placeholder: self.user_input_names.add(expr.name) @@ -294,9 +294,9 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend: return new_send def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache.retrieve(inputs) except KeyError: pass diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index dc1045f25..bc9a45da7 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -93,6 +93,7 @@ __doc__ = """ .. autoclass:: Mapper +.. autoclass:: CacheInputsWithKey .. autoclass:: CachedMapperCache .. autoclass:: CachedMapper .. autoclass:: TransformMapperCache @@ -304,12 +305,45 @@ def __call__( CacheKeyT: TypeAlias = Hashable -class CachedMapperCache(Generic[CacheExprT, CacheResultT]): +class CacheInputsWithKey(Generic[CacheExprT, P]): + """ + Data structure for inputs to :class:`CachedMapperCache`. + + .. attribute:: expr + + The input expression being mapped. + + .. attribute:: args + + A :class:`tuple` of extra positional arguments. + + .. attribute:: kwargs + + A :class:`dict` of extra keyword arguments. + + .. attribute:: key + + The cache key corresponding to *expr* and any additional inputs that were + passed. + + """ + def __init__( + self, + expr: CacheExprT, + key: CacheKeyT, + *args: P.args, + **kwargs: P.kwargs): + self.expr: CacheExprT = expr + self.args: tuple[Any, ...] = args + self.kwargs: dict[str, Any] = kwargs + self.key: CacheKeyT = key + + +class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): """ Cache for mappers. .. automethod:: __init__ - .. method:: get_key Compute the key for an input expression. @@ -317,64 +351,31 @@ class CachedMapperCache(Generic[CacheExprT, CacheResultT]): .. automethod:: retrieve .. automethod:: clear """ - def __init__( - self, - key_func: Callable[..., CacheKeyT]) -> None: - """ - Initialize the cache. - - :arg key_func: Function to compute a hashable cache key from an input - expression and any extra arguments. - """ - self.get_key = key_func - - self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {} + def __init__(self) -> None: + """Initialize the cache.""" + self._input_key_to_result: dict[CacheKeyT, CacheResultT] = {} def add( self, - key_inputs: - CacheExprT - # Currently, Python's type system doesn't have a way to annotate - # containers of args/kwargs (ParamSpec won't work here). So we have - # to fall back to using Any. More details here: - # https://github.com/python/typing/issues/1252 - | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], - result: CacheResultT, - key: CacheKeyT | None = None) -> CacheResultT: + inputs: CacheInputsWithKey[CacheExprT, P], + result: CacheResultT) -> CacheResultT: """Cache a mapping result.""" - if key is None: - if isinstance(key_inputs, tuple): - expr, key_args, key_kwargs = key_inputs - key = self.get_key(expr, *key_args, **key_kwargs) - else: - key = self.get_key(key_inputs) + key = inputs.key - assert key not in self._expr_key_to_result, \ + assert key not in self._input_key_to_result, \ f"Cache entry is already present for key '{key}'." - self._expr_key_to_result[key] = result - + self._input_key_to_result[key] = result return result - def retrieve( - self, - key_inputs: - CacheExprT - | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], - key: CacheKeyT | None = None) -> CacheResultT: + def retrieve(self, inputs: CacheInputsWithKey[CacheExprT, P]) -> CacheResultT: """Retrieve the cached mapping result.""" - if key is None: - if isinstance(key_inputs, tuple): - expr, key_args, key_kwargs = key_inputs - key = self.get_key(expr, *key_args, **key_kwargs) - else: - key = self.get_key(key_inputs) - - return self._expr_key_to_result[key] + key = inputs.key + return self._input_key_to_result[key] def clear(self) -> None: """Reset the cache.""" - self._expr_key_to_result = {} + self._input_key_to_result = {} class CachedMapper(Mapper[ResultT, FunctionResultT, P]): @@ -389,58 +390,79 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]): def __init__( self, _cache: - CachedMapperCache[ArrayOrNames, ResultT] | None = None, + CachedMapperCache[ArrayOrNames, ResultT, P] | None = None, _function_cache: - CachedMapperCache[FunctionDefinition, FunctionResultT] | None = None + CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None ) -> None: super().__init__() - self._cache: CachedMapperCache[ArrayOrNames, ResultT] = ( + self._cache: CachedMapperCache[ArrayOrNames, ResultT, P] = ( _cache if _cache is not None - else CachedMapperCache(self.get_cache_key)) + else CachedMapperCache()) self._function_cache: CachedMapperCache[ - FunctionDefinition, FunctionResultT] = ( + FunctionDefinition, FunctionResultT, P] = ( _function_cache if _function_cache is not None - else CachedMapperCache(self.get_function_definition_cache_key)) + else CachedMapperCache()) def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs - ) -> Hashable: - return (expr, *args, tuple(sorted(kwargs.items()))) + ) -> CacheKeyT: + if args or kwargs: + # Depending on whether extra arguments are passed by position or by + # keyword, they can end up in either args or kwargs; hence key is not + # uniquely defined in the general case + raise NotImplementedError( + "Derived classes must override get_cache_key if using extra inputs.") + return expr def get_function_definition_cache_key( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs - ) -> Hashable: - return (expr, *args, tuple(sorted(kwargs.items()))) + ) -> CacheKeyT: + if args or kwargs: + # Depending on whether extra arguments are passed by position or by + # keyword, they can end up in either args or kwargs; hence key is not + # uniquely defined in the general case + raise NotImplementedError( + "Derived classes must override get_function_definition_cache_key if " + "using extra inputs.") + return expr + + def _make_cache_inputs( + self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs + ) -> CacheInputsWithKey[ArrayOrNames, P]: + return CacheInputsWithKey( + expr, self.get_cache_key(expr, *args, **kwargs), *args, **kwargs) + + def _make_function_definition_cache_inputs( + self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs + ) -> CacheInputsWithKey[FunctionDefinition, P]: + return CacheInputsWithKey( + expr, self.get_function_definition_cache_key(expr, *args, **kwargs), + *args, **kwargs) def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: - key = self._cache.get_key(expr, *args, **kwargs) + inputs = self._make_cache_inputs(expr, *args, **kwargs) try: - return self._cache.retrieve((expr, args, kwargs), key=key) + return self._cache.retrieve(inputs) except KeyError: - return self._cache.add( - (expr, args, kwargs), - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - Mapper.rec(self, expr, *args, **kwargs), - key=key) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + return self._cache.add(inputs, Mapper.rec(self, expr, *args, **kwargs)) def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> FunctionResultT: - key = self._function_cache.get_key(expr, *args, **kwargs) + inputs = self._make_function_definition_cache_inputs(expr, *args, **kwargs) try: - return self._function_cache.retrieve((expr, args, kwargs), key=key) + return self._function_cache.retrieve(inputs) except KeyError: return self._function_cache.add( - (expr, args, kwargs), # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see https://github.com/inducer/pytato/pull/585 - Mapper.rec_function_definition(self, expr, *args, **kwargs), - key=key) + inputs, Mapper.rec_function_definition(self, expr, *args, **kwargs)) def clone_for_callee( self, function: FunctionDefinition) -> Self: @@ -456,7 +478,7 @@ def clone_for_callee( # {{{ TransformMapper -class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT]): +class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]): pass @@ -470,8 +492,8 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -492,9 +514,9 @@ class TransformMapperWithExtraArgs( """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames] | None = None, + _cache: TransformMapperCache[ArrayOrNames, P] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition] | None = None + TransformMapperCache[FunctionDefinition, P] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -1522,8 +1544,8 @@ class CachedMapAndCopyMapper(CopyMapper): def __init__( self, map_fn: Callable[[ArrayOrNames], ArrayOrNames], - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn @@ -1533,18 +1555,17 @@ def clone_for_callee( return type(self)( self.map_fn, _function_cache=cast( - "TransformMapperCache[FunctionDefinition]", self._function_cache)) + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache.retrieve(inputs) except KeyError: - return self._cache.add( - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - expr, Mapper.rec(self, self.map_fn(expr)), key=key) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + return self._cache.add(inputs, Mapper.rec(self, self.map_fn(expr))) # }}} @@ -2069,8 +2090,8 @@ class DataWrapperDeduplicator(CopyMapper): """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.data_wrapper_cache: dict[CacheKeyT, DataWrapper] = {} diff --git a/pytato/transform/einsum_distributive_law.py b/pytato/transform/einsum_distributive_law.py index 8cd635f61..694901b03 100644 --- a/pytato/transform/einsum_distributive_law.py +++ b/pytato/transform/einsum_distributive_law.py @@ -57,6 +57,8 @@ Stack, ) from pytato.transform import ( + ArrayOrNames, + CacheKeyT, MappedT, TransformMapperWithExtraArgs, _verify_is_array, @@ -160,6 +162,13 @@ def __init__(self, super().__init__() self.how_to_distribute = how_to_distribute + def get_cache_key( + self, + expr: ArrayOrNames, + ctx: _EinsumDistributiveLawMapperContext | None + ) -> CacheKeyT: + return (expr, ctx) + def _map_input_base(self, expr: InputArgumentBase, ctx: _EinsumDistributiveLawMapperContext | None, diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index d50da22e0..200aa25b4 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -416,9 +416,9 @@ class AxisTagAttacher(CopyMapper): def __init__(self, axis_to_tags: Mapping[tuple[Array, int | str], Collection[Tag]], tag_corresponding_redn_descr: bool, - _cache: TransformMapperCache[ArrayOrNames] | None = None, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition] | None = None): + TransformMapperCache[FunctionDefinition, []] | None = None): super().__init__(_cache=_cache, _function_cache=_function_cache) self.axis_to_tags: Mapping[tuple[Array, int | str], Collection[Tag]] = axis_to_tags @@ -465,9 +465,9 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: return result def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache.retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, @@ -478,7 +478,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: assert isinstance(expr, Array) # type-ignore reason: passed "ArrayOrNames"; expected "Array" result = self._attach_tags(expr, result) # type: ignore[arg-type] - return self._cache.add(expr, result, key=key) + return self._cache.add(inputs, result) def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError(