From 2014c174bebe0eb80b33ee72d16b55153a3df2cf Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 7 Feb 2025 16:16:10 -0600 Subject: [PATCH 1/5] disable default implementation of get_cache_key and get_function_definition_cache_key for extra args case ambiguous due to the fact that any arg can be specified with/without keyword --- pytato/transform/__init__.py | 15 +++++++++++---- pytato/transform/einsum_distributive_law.py | 9 +++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index dc1045f25..c8f195d86 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -406,13 +406,20 @@ def __init__( def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs - ) -> Hashable: - return (expr, *args, tuple(sorted(kwargs.items()))) + ) -> CacheKeyT: + if args or kwargs: + raise NotImplementedError( + "Derived classes must override get_cache_key if using extra inputs.") + return expr def get_function_definition_cache_key( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs - ) -> Hashable: - return (expr, *args, tuple(sorted(kwargs.items()))) + ) -> CacheKeyT: + if args or kwargs: + raise NotImplementedError( + "Derived classes must override get_function_definition_cache_key if " + "using extra inputs.") + return expr def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: key = self._cache.get_key(expr, *args, **kwargs) diff --git a/pytato/transform/einsum_distributive_law.py b/pytato/transform/einsum_distributive_law.py index 8cd635f61..694901b03 100644 --- a/pytato/transform/einsum_distributive_law.py +++ b/pytato/transform/einsum_distributive_law.py @@ -57,6 +57,8 @@ Stack, ) from pytato.transform import ( + ArrayOrNames, + CacheKeyT, MappedT, TransformMapperWithExtraArgs, _verify_is_array, @@ -160,6 +162,13 @@ def __init__(self, super().__init__() self.how_to_distribute = how_to_distribute + def get_cache_key( + self, + expr: ArrayOrNames, + ctx: _EinsumDistributiveLawMapperContext | None + ) -> CacheKeyT: + return (expr, ctx) + def _map_input_base(self, expr: InputArgumentBase, ctx: _EinsumDistributiveLawMapperContext | None, From 39866e400aaf4b6e12cb09370e58619ef417ce82 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 7 Feb 2025 16:16:55 -0600 Subject: [PATCH 2/5] add CacheInputs to simplify cache key handling logic --- pytato/analysis/__init__.py | 6 +- pytato/codegen.py | 4 +- pytato/distributed/partition.py | 10 +- pytato/transform/__init__.py | 164 +++++++++++++++++--------------- pytato/transform/metadata.py | 10 +- 5 files changed, 101 insertions(+), 93 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index e1487b710..880a15b7a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -622,9 +622,9 @@ def combine(self, *args: int) -> int: return sum(args) def rec(self, expr: ArrayOrNames) -> int: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache.retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, @@ -639,7 +639,7 @@ def rec(self, expr: ArrayOrNames) -> int: else: result = 0 + s - self._cache.add(expr, 0, key=key) + self._cache.add(inputs, 0) return result diff --git a/pytato/codegen.py b/pytato/codegen.py index 86a328929..cb957f076 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -138,8 +138,8 @@ def __init__( self, target: Target, kernels_seen: dict[str, lp.LoopKernel] | None = None, - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.bound_arguments: dict[str, DataInterface] = {} diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 741e36548..a022f8f8e 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -240,9 +240,9 @@ def __init__(self, recvd_ary_to_name: Mapping[Array, str], sptpo_ary_to_name: Mapping[Array, str], name_to_output: Mapping[str, Array], - _cache: TransformMapperCache[ArrayOrNames] | None = None, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition] | None = None, + TransformMapperCache[FunctionDefinition, []] | None = None, ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -261,7 +261,7 @@ def clone_for_callee( return type(self)( {}, {}, {}, _function_cache=cast( - "TransformMapperCache[FunctionDefinition]", self._function_cache)) + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) def map_placeholder(self, expr: Placeholder) -> Placeholder: self.user_input_names.add(expr.name) @@ -294,9 +294,9 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend: return new_send def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache.retrieve(inputs) except KeyError: pass diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index c8f195d86..af73906bb 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -46,6 +46,7 @@ from typing_extensions import Self from pymbolic.mapper.optimize import optimize_mapper +from pytools import memoize_method from pytato.array import ( AbstractResultWithNamedArrays, @@ -93,6 +94,7 @@ __doc__ = """ .. autoclass:: Mapper +.. autoclass:: CacheInputs .. autoclass:: CachedMapperCache .. autoclass:: CachedMapper .. autoclass:: TransformMapperCache @@ -304,12 +306,45 @@ def __call__( CacheKeyT: TypeAlias = Hashable -class CachedMapperCache(Generic[CacheExprT, CacheResultT]): +class CacheInputs(Generic[CacheExprT, P]): + """ + Data structure for inputs to :class:`CachedMapperCache`. + + .. attribute:: expr + + The input expression being mapped. + + .. attribute:: key + + The cache key corresponding to *expr* and any additional inputs that were + passed. + + """ + def __init__( + self, + expr: CacheExprT, + key_func: Callable[..., CacheKeyT], + *args: P.args, + **kwargs: P.kwargs): + self.expr: CacheExprT = expr + self._args: tuple[Any, ...] = args + self._kwargs: dict[str, Any] = kwargs + self._key_func = key_func + + @memoize_method + def _get_key(self) -> CacheKeyT: + return self._key_func(self.expr, *self._args, **self._kwargs) + + @property + def key(self) -> CacheKeyT: + return self._get_key() + + +class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): """ Cache for mappers. .. automethod:: __init__ - .. method:: get_key Compute the key for an input expression. @@ -317,37 +352,16 @@ class CachedMapperCache(Generic[CacheExprT, CacheResultT]): .. automethod:: retrieve .. automethod:: clear """ - def __init__( - self, - key_func: Callable[..., CacheKeyT]) -> None: - """ - Initialize the cache. - - :arg key_func: Function to compute a hashable cache key from an input - expression and any extra arguments. - """ - self.get_key = key_func - + def __init__(self) -> None: + """Initialize the cache.""" self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {} def add( self, - key_inputs: - CacheExprT - # Currently, Python's type system doesn't have a way to annotate - # containers of args/kwargs (ParamSpec won't work here). So we have - # to fall back to using Any. More details here: - # https://github.com/python/typing/issues/1252 - | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], - result: CacheResultT, - key: CacheKeyT | None = None) -> CacheResultT: + inputs: CacheInputs[CacheExprT, P], + result: CacheResultT) -> CacheResultT: """Cache a mapping result.""" - if key is None: - if isinstance(key_inputs, tuple): - expr, key_args, key_kwargs = key_inputs - key = self.get_key(expr, *key_args, **key_kwargs) - else: - key = self.get_key(key_inputs) + key = inputs.key assert key not in self._expr_key_to_result, \ f"Cache entry is already present for key '{key}'." @@ -356,20 +370,9 @@ def add( return result - def retrieve( - self, - key_inputs: - CacheExprT - | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], - key: CacheKeyT | None = None) -> CacheResultT: + def retrieve(self, inputs: CacheInputs[CacheExprT, P]) -> CacheResultT: """Retrieve the cached mapping result.""" - if key is None: - if isinstance(key_inputs, tuple): - expr, key_args, key_kwargs = key_inputs - key = self.get_key(expr, *key_args, **key_kwargs) - else: - key = self.get_key(key_inputs) - + key = inputs.key return self._expr_key_to_result[key] def clear(self) -> None: @@ -389,20 +392,20 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]): def __init__( self, _cache: - CachedMapperCache[ArrayOrNames, ResultT] | None = None, + CachedMapperCache[ArrayOrNames, ResultT, P] | None = None, _function_cache: - CachedMapperCache[FunctionDefinition, FunctionResultT] | None = None + CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None ) -> None: super().__init__() - self._cache: CachedMapperCache[ArrayOrNames, ResultT] = ( + self._cache: CachedMapperCache[ArrayOrNames, ResultT, P] = ( _cache if _cache is not None - else CachedMapperCache(self.get_cache_key)) + else CachedMapperCache()) self._function_cache: CachedMapperCache[ - FunctionDefinition, FunctionResultT] = ( + FunctionDefinition, FunctionResultT, P] = ( _function_cache if _function_cache is not None - else CachedMapperCache(self.get_function_definition_cache_key)) + else CachedMapperCache()) def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs @@ -421,33 +424,39 @@ def get_function_definition_cache_key( "using extra inputs.") return expr + def _make_cache_inputs( + self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs + ) -> CacheInputs[ArrayOrNames, P]: + return CacheInputs(expr, self.get_cache_key, *args, **kwargs) + + def _make_function_definition_cache_inputs( + self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs + ) -> CacheInputs[FunctionDefinition, P]: + return CacheInputs( + expr, self.get_function_definition_cache_key, *args, **kwargs) + def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: - key = self._cache.get_key(expr, *args, **kwargs) + inputs = self._make_cache_inputs(expr, *args, **kwargs) try: - return self._cache.retrieve((expr, args, kwargs), key=key) + return self._cache.retrieve(inputs) except KeyError: - return self._cache.add( - (expr, args, kwargs), - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - Mapper.rec(self, expr, *args, **kwargs), - key=key) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + return self._cache.add(inputs, Mapper.rec(self, expr, *args, **kwargs)) def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> FunctionResultT: - key = self._function_cache.get_key(expr, *args, **kwargs) + inputs = self._make_function_definition_cache_inputs(expr, *args, **kwargs) try: - return self._function_cache.retrieve((expr, args, kwargs), key=key) + return self._function_cache.retrieve(inputs) except KeyError: return self._function_cache.add( - (expr, args, kwargs), # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see https://github.com/inducer/pytato/pull/585 - Mapper.rec_function_definition(self, expr, *args, **kwargs), - key=key) + inputs, Mapper.rec_function_definition(self, expr, *args, **kwargs)) def clone_for_callee( self, function: FunctionDefinition) -> Self: @@ -463,7 +472,7 @@ def clone_for_callee( # {{{ TransformMapper -class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT]): +class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]): pass @@ -477,8 +486,8 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -499,9 +508,9 @@ class TransformMapperWithExtraArgs( """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames] | None = None, + _cache: TransformMapperCache[ArrayOrNames, P] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition] | None = None + TransformMapperCache[FunctionDefinition, P] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -1529,8 +1538,8 @@ class CachedMapAndCopyMapper(CopyMapper): def __init__( self, map_fn: Callable[[ArrayOrNames], ArrayOrNames], - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn @@ -1540,18 +1549,17 @@ def clone_for_callee( return type(self)( self.map_fn, _function_cache=cast( - "TransformMapperCache[FunctionDefinition]", self._function_cache)) + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache.retrieve(inputs) except KeyError: - return self._cache.add( - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - expr, Mapper.rec(self, self.map_fn(expr)), key=key) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + return self._cache.add(inputs, Mapper.rec(self, self.map_fn(expr))) # }}} @@ -2076,8 +2084,8 @@ class DataWrapperDeduplicator(CopyMapper): """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.data_wrapper_cache: dict[CacheKeyT, DataWrapper] = {} diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index d50da22e0..200aa25b4 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -416,9 +416,9 @@ class AxisTagAttacher(CopyMapper): def __init__(self, axis_to_tags: Mapping[tuple[Array, int | str], Collection[Tag]], tag_corresponding_redn_descr: bool, - _cache: TransformMapperCache[ArrayOrNames] | None = None, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition] | None = None): + TransformMapperCache[FunctionDefinition, []] | None = None): super().__init__(_cache=_cache, _function_cache=_function_cache) self.axis_to_tags: Mapping[tuple[Array, int | str], Collection[Tag]] = axis_to_tags @@ -465,9 +465,9 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: return result def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache.retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, @@ -478,7 +478,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: assert isinstance(expr, Array) # type-ignore reason: passed "ArrayOrNames"; expected "Array" result = self._attach_tags(expr, result) # type: ignore[arg-type] - return self._cache.add(expr, result, key=key) + return self._cache.add(inputs, result) def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError( From 6f6ccbea12493a779cdf38dd1fa838b05c421fea Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 7 Feb 2025 09:25:55 -0600 Subject: [PATCH 3/5] rename expr_key* to input_key* --- pytato/transform/__init__.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index af73906bb..b5dd9f094 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -354,7 +354,7 @@ class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): """ def __init__(self) -> None: """Initialize the cache.""" - self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {} + self._input_key_to_result: dict[CacheKeyT, CacheResultT] = {} def add( self, @@ -363,21 +363,20 @@ def add( """Cache a mapping result.""" key = inputs.key - assert key not in self._expr_key_to_result, \ + assert key not in self._input_key_to_result, \ f"Cache entry is already present for key '{key}'." - self._expr_key_to_result[key] = result - + self._input_key_to_result[key] = result return result def retrieve(self, inputs: CacheInputs[CacheExprT, P]) -> CacheResultT: """Retrieve the cached mapping result.""" key = inputs.key - return self._expr_key_to_result[key] + return self._input_key_to_result[key] def clear(self) -> None: """Reset the cache.""" - self._expr_key_to_result = {} + self._input_key_to_result = {} class CachedMapper(Mapper[ResultT, FunctionResultT, P]): From f7d5c7e4e1a12ed430940356a48ec063cd66a309 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 18 Feb 2025 14:22:24 -0600 Subject: [PATCH 4/5] refactor to avoid performance drop --- pytato/transform/__init__.py | 45 ++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index b5dd9f094..4d2bbd2bd 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -46,7 +46,6 @@ from typing_extensions import Self from pymbolic.mapper.optimize import optimize_mapper -from pytools import memoize_method from pytato.array import ( AbstractResultWithNamedArrays, @@ -94,7 +93,7 @@ __doc__ = """ .. autoclass:: Mapper -.. autoclass:: CacheInputs +.. autoclass:: CacheInputsWithKey .. autoclass:: CachedMapperCache .. autoclass:: CachedMapper .. autoclass:: TransformMapperCache @@ -306,7 +305,7 @@ def __call__( CacheKeyT: TypeAlias = Hashable -class CacheInputs(Generic[CacheExprT, P]): +class CacheInputsWithKey(Generic[CacheExprT, P]): """ Data structure for inputs to :class:`CachedMapperCache`. @@ -314,6 +313,14 @@ class CacheInputs(Generic[CacheExprT, P]): The input expression being mapped. + .. attribute:: args + + A :class:`tuple` of extra positional arguments. + + .. attribute:: kwargs + + A :class:`dict` of extra keyword arguments. + .. attribute:: key The cache key corresponding to *expr* and any additional inputs that were @@ -323,21 +330,13 @@ class CacheInputs(Generic[CacheExprT, P]): def __init__( self, expr: CacheExprT, - key_func: Callable[..., CacheKeyT], + key: CacheKeyT, *args: P.args, **kwargs: P.kwargs): self.expr: CacheExprT = expr - self._args: tuple[Any, ...] = args - self._kwargs: dict[str, Any] = kwargs - self._key_func = key_func - - @memoize_method - def _get_key(self) -> CacheKeyT: - return self._key_func(self.expr, *self._args, **self._kwargs) - - @property - def key(self) -> CacheKeyT: - return self._get_key() + self.args: tuple[Any, ...] = args + self.kwargs: dict[str, Any] = kwargs + self.key: CacheKeyT = key class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): @@ -358,7 +357,7 @@ def __init__(self) -> None: def add( self, - inputs: CacheInputs[CacheExprT, P], + inputs: CacheInputsWithKey[CacheExprT, P], result: CacheResultT) -> CacheResultT: """Cache a mapping result.""" key = inputs.key @@ -369,7 +368,7 @@ def add( self._input_key_to_result[key] = result return result - def retrieve(self, inputs: CacheInputs[CacheExprT, P]) -> CacheResultT: + def retrieve(self, inputs: CacheInputsWithKey[CacheExprT, P]) -> CacheResultT: """Retrieve the cached mapping result.""" key = inputs.key return self._input_key_to_result[key] @@ -425,14 +424,16 @@ def get_function_definition_cache_key( def _make_cache_inputs( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs - ) -> CacheInputs[ArrayOrNames, P]: - return CacheInputs(expr, self.get_cache_key, *args, **kwargs) + ) -> CacheInputsWithKey[ArrayOrNames, P]: + return CacheInputsWithKey( + expr, self.get_cache_key(expr, *args, **kwargs), *args, **kwargs) def _make_function_definition_cache_inputs( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs - ) -> CacheInputs[FunctionDefinition, P]: - return CacheInputs( - expr, self.get_function_definition_cache_key, *args, **kwargs) + ) -> CacheInputsWithKey[FunctionDefinition, P]: + return CacheInputsWithKey( + expr, self.get_function_definition_cache_key(expr, *args, **kwargs), + *args, **kwargs) def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: inputs = self._make_cache_inputs(expr, *args, **kwargs) From 81d300a89fb66b1909301806a4fa316cc51f01c5 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 10 Mar 2025 16:55:49 -0500 Subject: [PATCH 5/5] add comment explaining why CachedMapper.get_cache_key and get_function_definition_cache_key are not defined for general extra args/kwargs --- pytato/transform/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 4d2bbd2bd..bc9a45da7 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -409,6 +409,9 @@ def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs ) -> CacheKeyT: if args or kwargs: + # Depending on whether extra arguments are passed by position or by + # keyword, they can end up in either args or kwargs; hence key is not + # uniquely defined in the general case raise NotImplementedError( "Derived classes must override get_cache_key if using extra inputs.") return expr @@ -417,6 +420,9 @@ def get_function_definition_cache_key( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> CacheKeyT: if args or kwargs: + # Depending on whether extra arguments are passed by position or by + # keyword, they can end up in either args or kwargs; hence key is not + # uniquely defined in the general case raise NotImplementedError( "Derived classes must override get_function_definition_cache_key if " "using extra inputs.")