diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index c100e7d31..e1487b710 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -626,7 +626,10 @@ def rec(self, expr: ArrayOrNames) -> int: try: return self._cache.retrieve(expr, key=key) except KeyError: - s = super().rec(expr) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + s = Mapper.rec(self, expr) if ( isinstance(expr, Array) and ( diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 27b1e2cee..741e36548 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -309,6 +309,8 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: if name is not None: return self._get_placeholder_for(name, expr) + # Calling super().rec instead of Mapper.rec is OK here, because we're not + # implementing cache insertion and thus are not double caching return cast("ArrayOrNames", super().rec(expr)) def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 5b1ba02c4..dc1045f25 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -349,6 +349,9 @@ def add( else: key = self.get_key(key_inputs) + assert key not in self._expr_key_to_result, \ + f"Cache entry is already present for key '{key}'." + self._expr_key_to_result[key] = result return result @@ -418,7 +421,10 @@ def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: except KeyError: return self._cache.add( (expr, args, kwargs), - super().rec(expr, *args, **kwargs), + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + Mapper.rec(self, expr, *args, **kwargs), key=key) def rec_function_definition( @@ -430,7 +436,10 @@ def rec_function_definition( except KeyError: return self._function_cache.add( (expr, args, kwargs), - super().rec_function_definition(expr, *args, **kwargs), + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + Mapper.rec_function_definition(self, expr, *args, **kwargs), key=key) def clone_for_callee( @@ -1532,7 +1541,10 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: return self._cache.retrieve(expr, key=key) except KeyError: return self._cache.add( - expr, super().rec(self.map_fn(expr)), key=key) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + expr, Mapper.rec(self, self.map_fn(expr)), key=key) # }}} @@ -2050,44 +2062,67 @@ def rec_get_user_nodes(expr: ArrayOrNames, # {{{ deduplicate_data_wrappers -def _get_data_dedup_cache_key(ary: DataInterface) -> CacheKeyT: - import sys - if "pyopencl" in sys.modules: - from pyopencl import MemoryObjectHolder - from pyopencl.array import Array as CLArray - try: - from pyopencl import SVMPointer - except ImportError: - SVMPointer = None # noqa: N806 - - if isinstance(ary, CLArray): - base_data = ary.base_data - if isinstance(ary.base_data, MemoryObjectHolder): - ptr = base_data.int_ptr - elif SVMPointer is not None and isinstance(base_data, SVMPointer): - ptr = base_data.svm_ptr - elif base_data is None: - # pyopencl represents 0-long arrays' base_data as None - ptr = None - else: - raise ValueError("base_data of array not understood") - +class DataWrapperDeduplicator(CopyMapper): + """ + Mapper to replace all :class:`pytato.array.DataWrapper` instances containing + identical data with a single instance. + """ + def __init__( + self, + _cache: TransformMapperCache[ArrayOrNames] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) + self.data_wrapper_cache: dict[CacheKeyT, DataWrapper] = {} + self.data_wrappers_encountered = 0 + + def _get_data_dedup_cache_key(self, ary: DataInterface) -> CacheKeyT: + import sys + if "pyopencl" in sys.modules: + from pyopencl import MemoryObjectHolder + from pyopencl.array import Array as CLArray + try: + from pyopencl import SVMPointer + except ImportError: + SVMPointer = None # noqa: N806 + + if isinstance(ary, CLArray): + base_data = ary.base_data + if isinstance(ary.base_data, MemoryObjectHolder): + ptr = base_data.int_ptr + elif SVMPointer is not None and isinstance(base_data, SVMPointer): + ptr = base_data.svm_ptr + elif base_data is None: + # pyopencl represents 0-long arrays' base_data as None + ptr = None + else: + raise ValueError("base_data of array not understood") + + return ( + ptr, + ary.offset, + ary.shape, + ary.strides, + ary.dtype, + ) + if isinstance(ary, np.ndarray): return ( - ptr, - ary.offset, + ary.__array_interface__["data"], ary.shape, ary.strides, ary.dtype, ) - if isinstance(ary, np.ndarray): - return ( - ary.__array_interface__["data"], - ary.shape, - ary.strides, - ary.dtype, - ) - else: - raise NotImplementedError(str(type(ary))) + else: + raise NotImplementedError(str(type(ary))) + + def map_data_wrapper(self, expr: DataWrapper) -> Array: + self.data_wrappers_encountered += 1 + cache_key = self._get_data_dedup_cache_key(expr.data) + try: + return self.data_wrapper_cache[cache_key] + except KeyError: + self.data_wrapper_cache[cache_key] = expr + return expr def deduplicate_data_wrappers(array_or_names: ArrayOrNames) -> ArrayOrNames: @@ -2108,34 +2143,17 @@ def deduplicate_data_wrappers(array_or_names: ArrayOrNames) -> ArrayOrNames: this, but it must *also* tolerate this function doing a more thorough job of deduplication. """ + dedup = DataWrapperDeduplicator() + array_or_names = dedup(array_or_names) - data_wrapper_cache: dict[CacheKeyT, DataWrapper] = {} - data_wrappers_encountered = 0 - - def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames: - nonlocal data_wrappers_encountered - - if isinstance(ary, DataWrapper): - data_wrappers_encountered += 1 - cache_key = _get_data_dedup_cache_key(ary.data) - - try: - return data_wrapper_cache[cache_key] - except KeyError: - result = ary - data_wrapper_cache[cache_key] = result - return result - else: - return ary - - array_or_names = map_and_copy(array_or_names, cached_data_wrapper_if_present) - - if data_wrappers_encountered: + if dedup.data_wrappers_encountered: transform_logger.debug("data wrapper de-duplication: " "%d encountered, %d kept, %d eliminated", - data_wrappers_encountered, - len(data_wrapper_cache), - data_wrappers_encountered - len(data_wrapper_cache)) + dedup.data_wrappers_encountered, + len(dedup.data_wrapper_cache), + ( + dedup.data_wrappers_encountered + - len(dedup.data_wrapper_cache))) return array_or_names diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index e654e8b51..d50da22e0 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -469,6 +469,9 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: try: return self._cache.retrieve(expr, key=key) except KeyError: + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 result = Mapper.rec(self, expr) if not isinstance( expr, AbstractResultWithNamedArrays | DistributedSendRefHolder):