diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 38ed276fe..072ea8e66 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -321,26 +321,26 @@ class DirectPredecessorsGetter(Mapper): We only consider the predecessors of a nodes in a data-flow sense. """ - def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[Array]: + def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[ArrayOrNames]: return frozenset({dim for dim in shape if isinstance(dim, Array)}) - def map_index_lambda(self, expr: IndexLambda) -> frozenset[Array]: + def map_index_lambda(self, expr: IndexLambda) -> frozenset[ArrayOrNames]: return (frozenset(expr.bindings.values()) | self._get_preds_from_shape(expr.shape)) - def map_stack(self, expr: Stack) -> frozenset[Array]: + def map_stack(self, expr: Stack) -> frozenset[ArrayOrNames]: return (frozenset(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_concatenate(self, expr: Concatenate) -> frozenset[Array]: + def map_concatenate(self, expr: Concatenate) -> frozenset[ArrayOrNames]: return (frozenset(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_einsum(self, expr: Einsum) -> frozenset[Array]: + def map_einsum(self, expr: Einsum) -> frozenset[ArrayOrNames]: return (frozenset(expr.args) | self._get_preds_from_shape(expr.shape)) - def map_loopy_call_result(self, expr: NamedArray) -> frozenset[Array]: + def map_loopy_call_result(self, expr: NamedArray) -> frozenset[ArrayOrNames]: from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) @@ -349,7 +349,7 @@ def map_loopy_call_result(self, expr: NamedArray) -> frozenset[Array]: if isinstance(ary, Array)) | self._get_preds_from_shape(expr.shape)) - def _map_index_base(self, expr: IndexBase) -> frozenset[Array]: + def _map_index_base(self, expr: IndexBase) -> frozenset[ArrayOrNames]: return (frozenset([expr.array]) | frozenset(idx for idx in expr.indices if isinstance(idx, Array)) @@ -360,32 +360,34 @@ def _map_index_base(self, expr: IndexBase) -> frozenset[Array]: map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase - ) -> frozenset[Array]: + ) -> frozenset[ArrayOrNames]: return frozenset([expr.array]) map_roll = _map_index_remapping_base map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> frozenset[Array]: + def _map_input_base(self, expr: InputArgumentBase) -> frozenset[ArrayOrNames]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[Array]: + def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[ArrayOrNames]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> frozenset[Array]: + ) -> frozenset[ArrayOrNames]: return frozenset([expr.passthrough_data]) - def map_named_call_result(self, expr: NamedCallResult) -> frozenset[Array]: - raise NotImplementedError( - "DirectPredecessorsGetter does not yet support expressions containing " - "functions.") + def map_call(self, expr: Call) -> frozenset[ArrayOrNames]: + return frozenset(expr.bindings.values()) + + def map_named_call_result( + self, expr: NamedCallResult) -> frozenset[ArrayOrNames]: + return frozenset([expr._container]) # }}} @@ -448,17 +450,16 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @memoize_method - def map_function_definition(self, /, expr: FunctionDefinition, - *args: Any, **kwargs: Any) -> None: + def map_function_definition(self, expr: FunctionDefinition) -> None: if not self.visit(expr): return new_mapper = self.clone_for_callee(expr) for subexpr in expr.returns.values(): - new_mapper(subexpr, *args, **kwargs) + new_mapper(subexpr) self.count += new_mapper.count - self.post_visit(expr, *args, **kwargs) + self.post_visit(expr) def post_visit(self, expr: Any) -> None: if isinstance(expr, Call): diff --git a/pytato/equality.py b/pytato/equality.py index 5750d2b93..79d038d72 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -298,6 +298,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any ) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.parameters == expr2.parameters + and expr1.return_type == expr2.return_type and (set(expr1.returns.keys()) == set(expr2.returns.keys())) and all(self.rec(expr1.returns[k], expr2.returns[k]) for k in expr1.returns) @@ -311,6 +312,7 @@ def map_call(self, expr1: Call, expr2: Any) -> bool: and all(self.rec(bnd, expr2.bindings[name]) for name, bnd in expr1.bindings.items()) + and expr1.tags == expr2.tags ) def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool: diff --git a/pytato/function.py b/pytato/function.py index 5a4202011..c79dfcfe3 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -17,6 +17,17 @@ A type variable corresponding to the return type of the function :func:`pytato.trace_call`. + +Internal stuff that is only here because the documentation tool wants it +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. class:: Tag + + See :class:`pytools.tag.Tag`. + +.. class:: AxesT + + A :class:`tuple` of :class:`pytato.array.Axis` objects. """ __copyright__ = """ @@ -50,7 +61,6 @@ from typing import ( Callable, ClassVar, - Dict, Hashable, Iterable, Iterator, @@ -63,6 +73,7 @@ import attrs from immutabledict import immutabledict +from pytools import memoize_method from pytools.tag import Tag, Taggable from pytato.array import ( @@ -75,7 +86,7 @@ ) -ReturnT = TypeVar("ReturnT", Array, Tuple[Array, ...], Dict[str, Array]) +ReturnT = TypeVar("ReturnT", Array, Tuple[Array, ...], Mapping[str, Array]) # {{{ Call/NamedCallResult @@ -92,14 +103,14 @@ class ReturnType(enum.Enum): # eq=False to avoid equality comparison without EqualityMaper -@attrs.define(frozen=True, eq=False, hash=True) +@attrs.define(frozen=True, eq=False, hash=True, cache_hash=True) class FunctionDefinition(Taggable): r""" A function definition that represents its outputs as instances of :class:`~pytato.Array` with the inputs being :class:`~pytato.array.Placeholder`\ s. The outputs of the function can be a single :class:`pytato.Array`, a tuple of :class:`pytato.Array`\ s or an - instance of ``Dict[str, Array]``. + instance of ``Mapping[str, Array]``. .. attribute:: parameters @@ -184,7 +195,7 @@ def _with_new_tags( return attrs.evolve(self, tags=tags) def __call__(self, **kwargs: Array - ) -> Array | tuple[Array, ...] | dict[str, Array]: + ) -> Array | tuple[Array, ...] | Mapping[str, Array]: from pytato.array import _get_default_tags from pytato.utils import are_shapes_equal @@ -221,11 +232,12 @@ def __call__(self, **kwargs: Array return tuple(call_site[f"_{iarg}"] for iarg in range(len(self.returns))) elif self.return_type == ReturnType.DICT_OF_ARRAYS: - return {kw: call_site[kw] for kw in self.returns} + return immutabledict({kw: call_site[kw] for kw in self.returns}) else: raise NotImplementedError(self.return_type) +@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True) class NamedCallResult(NamedArray): """ One of the arrays that are returned from a call to :class:`FunctionDefinition`. @@ -239,19 +251,8 @@ class NamedCallResult(NamedArray): The name by which the returned array is referred to in :attr:`FunctionDefinition.returns`. """ - call: Call - name: str _mapper_method: ClassVar[str] = "map_named_call_result" - def __init__(self, - call: Call, - name: str) -> None: - super().__init__(call, name, - axes=call.function.returns[name].axes, - tags=call.function.returns[name].tags, - non_equality_tags=( - call.function.returns[name].non_equality_tags)) - def with_tagged_axis(self, iaxis: int, tags: Sequence[Tag] | Tag) -> Array: raise ValueError("Tagging a NamedCallResult's axis is illegal, use" @@ -269,6 +270,11 @@ def without_tags(self, raise ValueError("Untagging a NamedCallResult is illegal, use" " Call.without_tags instead") + @property + def call(self) -> Call: + assert isinstance(self._container, Call) + return self._container + @property def shape(self) -> ShapeType: assert isinstance(self._container, Call) @@ -317,8 +323,13 @@ def __contains__(self, name: object) -> bool: def __iter__(self) -> Iterator[str]: return iter(self.function.returns) + @memoize_method def __getitem__(self, name: str) -> NamedCallResult: - return NamedCallResult(self, name) + return NamedCallResult( + self, name, + axes=self.function.returns[name].axes, + tags=self.function.returns[name].tags, + non_equality_tags=self.function.returns[name].non_equality_tags) def __len__(self) -> int: return len(self.function.returns) diff --git a/pytato/tags.py b/pytato/tags.py index b97125f63..fbe477d06 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -205,7 +205,7 @@ class ExpandedDimsReshape(UniqueTag): class FunctionIdentifier(UniqueTag): """ A tag that can be attached to a :class:`~pytato.function.FunctionDefinition` - node to to describe the function's identifier. One can use this to refer + node to describe the function's identifier. One can use this to refer all instances of :class:`~pytato.function.FunctionDefinition`, for example in transformations.transform.calls.concatenate_calls`. diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index b78c24301..4d389c245 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -465,7 +465,7 @@ def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: def map_named_call_result(self, expr: NamedCallResult) -> Array: call = self.rec(expr._container) assert isinstance(call, Call) - return NamedCallResult(call, expr.name) + return call[expr.name] class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]): @@ -703,7 +703,7 @@ def map_named_call_result(self, expr: NamedCallResult, *args: Any, **kwargs: Any) -> Array: call = self.rec(expr._container, *args, **kwargs) assert isinstance(call, Call) - return NamedCallResult(call, expr.name) + return call[expr.name] # }}} @@ -823,9 +823,9 @@ def map_function_definition(self, expr: FunctionDefinition) -> CombineT: " must override map_function_definition.") def map_call(self, expr: Call) -> CombineT: - return self.combine(self.map_function_definition(expr.function), - *[self.rec(bnd) - for name, bnd in sorted(expr.bindings.items())]) + raise NotImplementedError( + "Mapping calls is context-dependent. Derived classes must override " + "map_call.") def map_named_call_result(self, expr: NamedCallResult) -> CombineT: return self.rec(expr._container) @@ -975,6 +975,12 @@ def map_function_definition(self, expr: FunctionDefinition return frozenset(result) + def map_call(self, expr: Call) -> frozenset[InputArgumentBase]: + return self.combine(self.map_function_definition(expr.function), + *[ + self.rec(bnd) + for name, bnd in sorted(expr.bindings.items())]) + # }}} @@ -999,6 +1005,12 @@ def map_function_definition(self, expr: FunctionDefinition return self.combine(*[self.rec(ret) for ret in expr.returns.values()]) + def map_call(self, expr: Call) -> frozenset[SizeParam]: + return self.combine(self.map_function_definition(expr.function), + *[ + self.rec(bnd) + for name, bnd in sorted(expr.bindings.items())]) + # }}} @@ -1173,7 +1185,7 @@ def map_loopy_call(self, expr: LoopyCall, *args: Any, **kwargs: Any) -> None: def map_function_definition(self, expr: FunctionDefinition, *args: Any, **kwargs: Any) -> None: - if not self.visit(expr): + if not self.visit(expr, *args, **kwargs): return new_mapper = self.clone_for_callee(expr) @@ -1183,14 +1195,14 @@ def map_function_definition(self, expr: FunctionDefinition, self.post_visit(expr, *args, **kwargs) def map_call(self, expr: Call, *args: Any, **kwargs: Any) -> None: - if not self.visit(expr): + if not self.visit(expr, *args, **kwargs): return - self.map_function_definition(expr.function) + self.map_function_definition(expr.function, *args, **kwargs) for bnd in expr.bindings.values(): - self.rec(bnd) + self.rec(bnd, *args, **kwargs) - self.post_visit(expr) + self.post_visit(expr, *args, **kwargs) def map_named_call_result(self, expr: NamedCallResult, *args: Any, **kwargs: Any) -> None: @@ -1229,6 +1241,9 @@ def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any super().rec(expr, *args, **kwargs) self._visited_nodes.add(cache_key) + def clone_for_callee( + self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: + return type(self)() # }}} @@ -1749,7 +1764,7 @@ def map_call(self, expr: Call, *args: Any) -> None: for bnd in expr.bindings.values(): self.rec(bnd) - def map_named_call(self, expr: NamedCallResult, *args: Any) -> None: + def map_named_call_result(self, expr: NamedCallResult, *args: Any) -> None: assert isinstance(expr._container, Call) for bnd in expr._container.bindings.values(): self.node_to_users.setdefault(bnd, set()).add(expr)