From 2be17bb7e4508e37b23656d4d25086b8289cd009 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 9 Jan 2026 11:14:16 -0600 Subject: [PATCH 1/6] add CSR sparse matrix multiplication array --- .basedpyright/baseline.json | 200 +++------ pytato/__init__.py | 12 + pytato/analysis/__init__.py | 23 + pytato/array.py | 196 +++++++++ pytato/codegen.py | 1 + pytato/equality.py | 10 + pytato/stringifier.py | 1 + pytato/target/loopy/codegen.py | 399 ++++++++++++++---- pytato/target/python/numpy_like.py | 4 + pytato/transform/__init__.py | 114 ++++- pytato/transform/einsum_distributive_law.py | 29 ++ pytato/transform/lower_to_index_lambda.py | 37 ++ pytato/transform/materialize.py | 28 ++ pytato/transform/metadata.py | 10 + pytato/visualization/dot.py | 20 + .../fancy_placeholder_data_flow.py | 24 ++ test/test_codegen.py | 65 +++ test/test_pytato.py | 28 ++ 18 files changed, 970 insertions(+), 231 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index dd72c4312..e0accddc9 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -2627,6 +2627,22 @@ "lineCount": 1 } }, + { + "code": "reportIncompatibleVariableOverride", + "range": { + "startColumn": 6, + "endColumn": 18, + "lineCount": 1 + } + }, + { + "code": "reportIncompatibleVariableOverride", + "range": { + "startColumn": 6, + "endColumn": 18, + "lineCount": 1 + } + }, { "code": "reportConstantRedefinition", "range": { @@ -2707,6 +2723,14 @@ "lineCount": 1 } }, + { + "code": "reportCallInDefaultInitializer", + "range": { + "startColumn": 43, + "endColumn": 54, + "lineCount": 1 + } + }, { "code": "reportAny", "range": { @@ -5841,6 +5865,14 @@ "lineCount": 1 } }, + { + "code": "reportUnannotatedClassAttribute", + "range": { + "startColumn": 4, + "endColumn": 18, + "lineCount": 1 + } + }, { "code": "reportUnannotatedClassAttribute", "range": { @@ -7265,134 +7297,6 @@ "lineCount": 1 } }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 17, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 29, - "lineCount": 1 - } - }, { "code": "reportPrivateUsage", "range": { @@ -7401,14 +7305,6 @@ "lineCount": 1 } }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 24, - "lineCount": 1 - } - }, { "code": "reportUnannotatedClassAttribute", "range": { @@ -11053,6 +10949,38 @@ "lineCount": 1 } }, + { + "code": "reportUnusedExpression", + "range": { + "startColumn": 4, + "endColumn": 9, + "lineCount": 1 + } + }, + { + "code": "reportUnusedExpression", + "range": { + "startColumn": 8, + "endColumn": 13, + "lineCount": 1 + } + }, + { + "code": "reportUnusedExpression", + "range": { + "startColumn": 8, + "endColumn": 13, + "lineCount": 1 + } + }, + { + "code": "reportUnusedExpression", + "range": { + "startColumn": 4, + "endColumn": 9, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { diff --git a/pytato/__init__.py b/pytato/__init__.py index d99432942..0e1906f67 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -57,6 +57,8 @@ def set_debug_enabled(flag: bool) -> None: AxisPermutation, BasicIndex, Concatenate, + CSRMatmul, + CSRMatrix, DataWrapper, DictOfNamedArrays, Einsum, @@ -70,6 +72,8 @@ def set_debug_enabled(flag: bool) -> None: Reshape, Roll, SizeParam, + SparseMatmul, + SparseMatrix, Stack, arange, broadcast_to, @@ -87,6 +91,7 @@ def set_debug_enabled(flag: bool) -> None: logical_and, logical_not, logical_or, + make_csr_matrix, make_data_wrapper, make_dict_of_named_arrays, make_placeholder, @@ -99,6 +104,7 @@ def set_debug_enabled(flag: bool) -> None: reshape, roll, set_traceback_tag_enabled, + sparse_matmul, squeeze, stack, transpose, @@ -179,6 +185,8 @@ def set_debug_enabled(flag: bool) -> None: "Axis", "AxisPermutation", "BasicIndex", + "CSRMatmul", + "CSRMatrix", "Concatenate", "DataWrapper", "DictOfNamedArrays", @@ -200,6 +208,8 @@ def set_debug_enabled(flag: bool) -> None: "Reshape", "Roll", "SizeParam", + "SparseMatmul", + "SparseMatrix", "Stack", "Target", "abs", @@ -247,6 +257,7 @@ def set_debug_enabled(flag: bool) -> None: "logical_and", "logical_not", "logical_or", + "make_csr_matrix", "make_data_wrapper", "make_dict_of_named_arrays", "make_distributed_recv", @@ -273,6 +284,7 @@ def set_debug_enabled(flag: bool) -> None: "show_fancy_placeholder_data_flow", "sin", "sinh", + "sparse_matmul", "sqrt", "squeeze", "stack", diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 00e36ce32..2626856ba 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -38,6 +38,7 @@ from pytato.array import ( Array, Concatenate, + CSRMatmul, DictOfNamedArrays, Einsum, IndexBase, @@ -155,6 +156,20 @@ def map_einsum(self, expr: Einsum) -> None: self.array_to_users[dim].append(expr) self.rec(dim) + def map_csr_matmul(self, expr: CSRMatmul) -> None: + for ary in ( + expr.matrix.elem_values, + expr.matrix.elem_col_indices, + expr.matrix.row_starts, + expr.array): + self.array_to_users[ary].append(expr) + self.rec(ary) + + for dim in expr.shape: + if isinstance(dim, Array): + self.array_to_users[dim].append(expr) + self.rec(dim) + def map_named_array(self, expr: NamedArray) -> None: self.rec(expr._container) @@ -378,6 +393,14 @@ def map_concatenate(self, expr: Concatenate) -> list[ArrayOrNames]: def map_einsum(self, expr: Einsum) -> list[ArrayOrNames]: return self._get_preds_from_shape(expr.shape) + list(expr.args) + def map_csr_matmul(self, expr: CSRMatmul) -> list[ArrayOrNames]: + return [ + *self._get_preds_from_shape(expr.shape), + expr.matrix.elem_values, + expr.matrix.elem_col_indices, + expr.matrix.row_starts, + expr.array] + def map_loopy_call(self, expr: LoopyCall) -> list[ArrayOrNames]: return [ary for ary in expr.bindings.values() if isinstance(ary, Array)] diff --git a/pytato/array.py b/pytato/array.py index 73d9ebc6c..239d88c24 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -61,6 +61,7 @@ :mod:`numpy`, but not all NumPy features may be supported. .. autofunction:: matmul +.. autofunction:: sparse_matmul .. autofunction:: roll .. autofunction:: transpose .. autofunction:: stack @@ -129,6 +130,14 @@ .. autoclass:: Placeholder .. autoclass:: SizeParam +Sparse Matrices +^^^^^^^^^^^^^^^ + +.. autoclass:: SparseMatmul +.. autoclass:: SparseMatrix +.. autoclass:: CSRMatmul +.. autoclass:: CSRMatrix + .. currentmodule:: pytato User-Facing Node Creation @@ -148,6 +157,7 @@ .. autofunction:: make_placeholder .. autofunction:: make_size_param .. autofunction:: make_data_wrapper +.. autofunction:: make_csr_matrix Internal API ------------ @@ -659,6 +669,8 @@ class Array(Taggable): .. method:: __mul__ .. method:: __rmul__ + .. method:: __matmul__ + .. method:: __rmatmul__ .. method:: __add__ .. method:: __radd__ .. method:: __sub__ @@ -2213,6 +2225,127 @@ def dtype(self) -> np.dtype[Any]: # }}} +# {{{ sparse matrix multiply + +@opt_frozen_dataclass(eq=False, repr=False) +class SparseMatrix(_SuppliedAxesAndTagsMixin, _SuppliedShapeAndDtypeMixin, ABC): + """ + Abstract base class for sparse matrices. + + .. automethod:: __matmul__ + """ + if __debug__: + def __post_init__(self) -> None: + pass + + def __matmul__(self, other: Array) -> SparseMatmul: + return sparse_matmul(self, other) + + +@array_dataclass() +class SparseMatmul(_SuppliedAxesAndTagsMixin, Array, ABC): + """ + Abstract base class for sparse matrix multiplies. + + .. attribute:: shape + .. attribute:: dtype + """ + @property + @abstractmethod + def _matrix(self) -> SparseMatrix: + pass + + @property + @abstractmethod + def _array(self) -> Array: + pass + + @memoize_method + def _get_shape(self) -> ShapeType: + return (self._matrix.shape[0], *self._array.shape[1:]) + + @property + @override + def shape(self) -> ShapeType: + return self._get_shape() + + @memoize_method + def _get_dtype(self) -> np.dtype[Any]: + return np.result_type(self._matrix.dtype, self._array.dtype) + + @property + @override + def dtype(self) -> np.dtype[Any]: + return self._get_dtype() + + +@opt_frozen_dataclass(eq=False, repr=False) +class CSRMatrix(SparseMatrix): + """ + A sparse matrix in compressed sparse row (CSR) format. + + .. attribute:: elem_values + + A one-dimensional array containing the values of all of the nonzero entries + of the matrix, grouped by row. + + .. attribute:: elem_col_indices + + A one-dimensional array containing the column index values corresponding to + each entry in *elem_values*. + + .. attribute:: row_starts + + A one-dimensional array of length `nrows+1`, where each entry gives the + starting index in *elem_values* and *elem_col_indices* for the given row, + with the last entry being equal to `nrows`. + """ + elem_values: Array + elem_col_indices: Array + row_starts: Array + + if __debug__: + @override + def __post_init__(self) -> None: + if self.elem_values.ndim != 1: + raise ValueError("elem_values must be a 1D array.") + if self.elem_col_indices.ndim != 1: + raise ValueError("elem_col_indices must be a 1D array.") + if self.row_starts.ndim != 1: + raise ValueError("row_starts must be a 1D array.") + super().__post_init__() + + +@array_dataclass() +class CSRMatmul(SparseMatmul): + """ + A multiplication of a sparse matrix in compressed sparse row (CSR) format with + an array. + + .. attribute:: matrix + + The :class:`CSRMatrix` representing the sparse matrix to be applied. + + .. attribute:: array + + The :class:`Array` to which the sparse matrix is being applied. + """ + matrix: CSRMatrix + array: Array + + @property + @override + def _matrix(self) -> SparseMatrix: + return self.matrix + + @property + @override + def _array(self) -> Array: + return self.array + +# }}} + + # {{{ end-user facing def _get_default_axes(ndim: int) -> AxesT: @@ -2292,6 +2425,31 @@ def matmul(x1: Array, x2: Array) -> Array: return pt.einsum(f"{x1_indices}, {x2_indices} -> {result_indices}", x1, x2) +def sparse_matmul(x1: SparseMatrix, x2: Array) -> SparseMatmul: + """Sparse matrix multiplication. + + :param x1: first argument + :param x2: second argument + """ + if ( + isinstance(x2, SCALAR_CLASSES) + or x2.shape == ()): + raise ValueError("scalars not allowed as arguments to sparse_matmul") + + if x2.shape[0] != x1.shape[1]: + raise ValueError("argument shapes are incompatible") + + if isinstance(x1, CSRMatrix): + return CSRMatmul( + matrix=x1, + array=x2, + axes=_get_default_axes(x2.ndim), + tags=_get_default_tags(), + non_equality_tags=_get_created_at_tag()) + else: + raise ValueError(f"unknown sparse matrix type '{type(x1).__name__}'.") + + def roll(a: Array, shift: int, axis: int | None = None) -> Array: """Roll array elements along a given axis. @@ -2585,6 +2743,44 @@ def make_data_wrapper(data: DataInterface, return DataWrapper(data, shape, axes=axes, tags=(tags | _get_default_tags()), non_equality_tags=_get_created_at_tag(),) + +def make_csr_matrix(shape: ConvertibleToShape, + elem_values: Array, + elem_col_indices: Array, + row_starts: Array, + tags: frozenset[Tag] = frozenset(), + axes: AxesT | None = None) -> CSRMatrix: + """Make a :class:`CSRMatrix` object. + + :param shape: the shape of the matrix + :param elem_values: a one-dimensional array containing the values of all of the + nonzero entries of the matrix, grouped by row. + :param elem_col_indices: a one-dimensional array containing the column index + values corresponding to each entry in *elem_values*. + :param row_starts: a one-dimensional array of length `nrows+1`, where each entry + gives the starting index in *elem_values* and *elem_col_indices* for the + given row, with the last entry being equal to `nrows`. + """ + shape = normalize_shape(shape) + dtype = elem_values.dtype + + if axes is None: + axes = _get_default_axes(len(shape)) + + if len(axes) != len(shape): + raise ValueError("'axes' dimensionality mismatch:" + f" expected {len(shape)}, got {len(axes)}.") + + return CSRMatrix( + shape=shape, + elem_values=elem_values, + elem_col_indices=elem_col_indices, + row_starts=row_starts, + dtype=dtype, + axes=axes, + tags=(tags | _get_default_tags()), + non_equality_tags=_get_created_at_tag(),) + # }}} diff --git a/pytato/codegen.py b/pytato/codegen.py index b0805bc76..3e079953a 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -129,6 +129,7 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc] :class:`~pytato.array.Concatenate` :class:`~pytato.array.IndexLambda` :class:`~pytato.array.Einsum` :class:`~pytato.array.IndexLambda` :class:`~pytato.array.Stack` :class:`~pytato.array.IndexLambda` + :class:`~pytato.array.CSRMatmul` :class:`~pytato.array.IndexLambda` ====================================== ===================================== """ diff --git a/pytato/equality.py b/pytato/equality.py index 837530f02..920eb13fc 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -35,6 +35,7 @@ AxisPermutation, BasicIndex, Concatenate, + CSRMatmul, DataWrapper, DictOfNamedArrays, Einsum, @@ -241,6 +242,15 @@ def map_einsum(self, expr1: Einsum, expr2: Einsum) -> bool: and expr1.redn_axis_to_redn_descr == expr2.redn_axis_to_redn_descr ) + def map_csr_matmul(self, expr1: CSRMatmul, expr2: CSRMatmul) -> bool: + return (self.rec(expr1.matrix.elem_values, expr2.matrix.elem_values) + and self.rec( + expr1.matrix.elem_col_indices, expr2.matrix.elem_col_indices) + and self.rec(expr1.matrix.row_starts, expr2.matrix.row_starts) + and self.rec(expr1.array, expr2.array) + and expr1.tags == expr2.tags + and expr1.axes == expr2.axes) + def map_named_array(self, expr1: NamedArray, expr2: NamedArray) -> bool: return (self.rec(expr1._container, expr2._container) and expr1.tags == expr2.tags diff --git a/pytato/stringifier.py b/pytato/stringifier.py index e13989b91..de08ade1b 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -164,6 +164,7 @@ def _map_generic_array(self, expr: Array, depth: int) -> str: map_non_contiguous_advanced_index = _map_generic_array map_reshape = _map_generic_array map_einsum = _map_generic_array + map_csr_matmul = _map_generic_array map_distributed_recv = _map_generic_array map_distributed_send_ref_holder = _map_generic_array diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 3e155441e..164b87930 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -27,11 +27,14 @@ import re import sys from abc import ABC, abstractmethod -from collections.abc import Mapping -from typing import TYPE_CHECKING +from collections.abc import Iterable, Mapping +from functools import reduce +from typing import TYPE_CHECKING, Any, cast import islpy as isl -from typing_extensions import Never +import numpy as np +from constantdict import constantdict +from typing_extensions import Never, override import loopy as lp import loopy.symbolic as lp_symbolic @@ -43,6 +46,7 @@ from pytato.array import ( AbstractResultWithNamedArrays, Array, + AxesT, DataWrapper, DictOfNamedArrays, IndexLambda, @@ -212,6 +216,7 @@ class LocalExpressionContext: local_namespace: Mapping[str, ImplementedResult] reduction_bounds: ReductionBounds var_to_reduction_descr: Mapping[str, ReductionDescriptor] + var_to_reduction_unique_name: Mapping[str, str] def lookup(self, name: str) -> ImplementedResult: return self.local_namespace[name] @@ -221,6 +226,7 @@ def copy(self, *, num_indices: int | None = None, local_namespace: Mapping[str, ImplementedResult] | None = None, var_to_reduction_descr: Mapping[str, ReductionDescriptor] | None = None, + var_to_reduction_unique_name: Mapping[str, str] | None = None, ) -> LocalExpressionContext: if reduction_bounds is None: reduction_bounds = self.reduction_bounds @@ -230,10 +236,14 @@ def copy(self, *, local_namespace = self.local_namespace if var_to_reduction_descr is None: var_to_reduction_descr = self.var_to_reduction_descr - return LocalExpressionContext(reduction_bounds=reduction_bounds, - num_indices=num_indices, - local_namespace=local_namespace, - var_to_reduction_descr=var_to_reduction_descr) + if var_to_reduction_unique_name is None: + var_to_reduction_unique_name = self.var_to_reduction_unique_name + return LocalExpressionContext( + reduction_bounds=reduction_bounds, + num_indices=num_indices, + local_namespace=local_namespace, + var_to_reduction_descr=var_to_reduction_descr, + var_to_reduction_unique_name=var_to_reduction_unique_name) # }}} @@ -295,7 +305,7 @@ class InlinedResult(ImplementedResult): See also: :class:`pytato.tags.ImplInlined`. """ - def __init__(self, expr: ScalarExpression, + def __init__(self, expr: Expression, num_indices: int, depends_on: frozenset[str]): self.expr = expr @@ -412,7 +422,7 @@ def map_size_param(self, expr: SizeParam, arg = lp.ValueArg(expr.name, dtype=expr.dtype, - tags=_filter_tags_not_of_type(expr, + tags=_filter_tags_not_of_type(expr.tags, self .array_tag_t_to_not_propagate )) @@ -436,7 +446,7 @@ def map_placeholder(self, expr: Placeholder, arg: lp.ArrayArg | lp.ValueArg = lp.ValueArg(expr.name, dtype=expr.dtype, - tags=_filter_tags_not_of_type(expr, + tags=_filter_tags_not_of_type(expr.tags, self .array_tag_t_to_not_propagate)) else: @@ -448,7 +458,7 @@ def map_placeholder(self, expr: Placeholder, offset=lp.auto, is_input=True, is_output=False, - tags=_filter_tags_not_of_type(expr, + tags=_filter_tags_not_of_type(expr.tags, self .array_tag_t_to_not_propagate)) @@ -464,6 +474,25 @@ def map_index_lambda(self, expr: IndexLambda, if expr in state.results: return state.results[expr] + idx_expr = expr.expr + + reductions = ReductionCollector()(idx_expr) + + var_to_reduction = { + var_name: redn + for redn in reductions + for var_name in redn.bounds} + + var_to_reduction_unique_name: Mapping[str, str] = {} + for var_name in expr.var_to_reduction_descr: + redn = var_to_reduction[var_name] + try: + loopy_redn_op = PYTATO_REDUCTION_TO_LOOPY_REDUCTION[type(redn.op)] + except KeyError as err: + raise NotImplementedError(redn.op) from err + var_to_reduction_unique_name[var_name] = \ + state.var_name_gen(f"_pt_{loopy_redn_op}" + var_name) + prstnt_ctx = PersistentExpressionContext(state) local_ctx = LocalExpressionContext( local_namespace={ @@ -471,24 +500,106 @@ def map_index_lambda(self, expr: IndexLambda, for name in sorted(expr.bindings)}, num_indices=expr.ndim, reduction_bounds={}, - var_to_reduction_descr=expr.var_to_reduction_descr) - loopy_expr = self.exprgen_mapper(expr.expr, prstnt_ctx, local_ctx) + var_to_reduction_descr=expr.var_to_reduction_descr, + var_to_reduction_unique_name=var_to_reduction_unique_name) + + loopy_shape = shape_to_scalar_expression(expr.shape, self, state) + + # If the scalar expression contains any reductions with bounds expressions + # that index into a binding, need to store the results of those expressions + # as scalar temporaries + subscript_detector = SubscriptDetector() + + redn_bounds = { + var_name: redn.bounds[var_name] + for var_name, redn in var_to_reduction.items()} + + # FIXME: Forcing storage of expressions containing processed reductions for + # now; attempting to generalize to unmaterialized expressions would require + # handling at least two complications: + # 1) final inames aren't assigned until the expression is stored, so any + # temporary variables defined below would need to be finalized at that + # point, not here + # 2) lp.make_reduction_inames_unique does not rename the temporaries + # created below, so something would need to be done to make them unique + # across all index lambda evaluations. + store_result = expr.tags_of_type(ImplStored) or any( + subscript_detector(bound) + for bounds in redn_bounds.values() + for bound in bounds) + + name: str | None = None + inames: tuple[str, ...] | None = None + if store_result: + name = _generate_name_for_temp(expr, state.var_name_gen) + + inames = tuple( + state.var_name_gen(f"{name}_dim{d}") + for d in range(expr.ndim)) + + from pytato.utils import are_shape_components_equal + result_is_empty = any( + are_shape_components_equal(s_i, 0) for s_i in expr.shape) + if not result_is_empty: + domain = domain_for_shape(inames, loopy_shape, {}) + state.update_kernel( + state.kernel.copy(domains=[*state.kernel.domains, domain])) + + redn_bound_temps: dict[str, ImplementedResult] = {} + new_redn_bounds: dict[ + str, tuple[ArithmeticExpression, ArithmeticExpression]] = {} + bound_prefixes = ("l", "u") + for var_name, bounds in redn_bounds.items(): + new_bounds_list: list[ArithmeticExpression] = [] + for bound_prefix, bound in zip(bound_prefixes, bounds, strict=True): + if subscript_detector(bound): + unique_name = var_to_reduction_unique_name[var_name] + bound_name = f"{unique_name}_{bound_prefix}bound" + loopy_bound = self.exprgen_mapper( + bound, prstnt_ctx, local_ctx) + bound_result: ImplementedResult = InlinedResult( + loopy_bound, expr.ndim, prstnt_ctx.depends_on) + bound_result = StoredResult( + bound_name, 0, frozenset([ + add_store( + bound_name, (), np.dtype(np.int64), bound_result, + state, self, output_to_temporary=True, + store_inames=(), + result_inames=inames, add_domain=False)])) + redn_bound_temps[bound_name] = bound_result + new_bound = prim.Variable(bound_name) + else: + new_bound = bound + new_bounds_list.append(new_bound) + new_redn_bounds[var_name] = cast( + "tuple[ArithmeticExpression, ArithmeticExpression]", + tuple(new_bounds_list)) + + new_namespace = dict(local_ctx.local_namespace) + new_namespace.update(redn_bound_temps) + local_ctx = local_ctx.copy(local_namespace=new_namespace) + + idx_expr = ReductionBoundsReplacer(new_redn_bounds)(idx_expr) + + loopy_expr = self.exprgen_mapper(idx_expr, prstnt_ctx, local_ctx) assert not isinstance(loopy_expr, tuple) result: ImplementedResult = InlinedResult(loopy_expr, expr.ndim, prstnt_ctx.depends_on) - shape_to_scalar_expression(expr.shape, self, state) # walk over size params - # {{{ implementation tag - if expr.tags_of_type(ImplStored): - name = _generate_name_for_temp(expr, state.var_name_gen) - result = StoredResult(name, expr.ndim, - frozenset([add_store(name, expr, - result, state, - self, True)])) + if store_result: + assert name is not None + assert inames is not None + result = StoredResult( + name, expr.ndim, frozenset([ + add_store( + name, expr.shape, expr.dtype, result, state, self, + tags=expr.tags, axes=expr.axes, output_to_temporary=True, + store_inames=inames, result_inames=inames, + add_domain=False)])) elif expr.tags_of_type(ImplInlined): # inlined results are automatically handled pass @@ -496,7 +607,7 @@ def map_index_lambda(self, expr: IndexLambda, subst_name = _generate_name_for_temp(expr, state.var_name_gen, default_prefix="_pt_subst") - add_substitution(subst_name, expr, result, state, self) + add_substitution(subst_name, expr.ndim, result, state, self) result = SubstitutionRuleResult(subst_name, expr.ndim, prstnt_ctx.depends_on) elif expr.tags_of_type(ImplementationStrategy): @@ -517,8 +628,9 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays, for key in sorted(expr.keys()): subexpr = expr[key].expr name = _generate_name_for_temp(subexpr, state.var_name_gen) - insn_id = add_store(name, subexpr, self.rec(subexpr, state), state, - output_to_temporary=True, cgen_mapper=self) + insn_id = add_store( + name, subexpr.shape, subexpr.dtype, self.rec(subexpr, state), state, + self, tags=subexpr.tags, axes=subexpr.axes, output_to_temporary=True) state.results[subexpr] = state.results[expr[key]] = ( StoredResult(name, subexpr.ndim, frozenset([insn_id]))) @@ -581,9 +693,9 @@ def _get_sub_array_ref(array: Array, name: str) -> lp_symbolic.SubArrayRef: # record the result for the corresponding loopy array state.results[named_array] = result - new_tvs[assignee_name] = get_loopy_temporary(assignee_name, - named_array, - self, state) + new_tvs[assignee_name] = get_loopy_temporary( + assignee_name, named_array.shape, named_array.dtype, + self, state, tags=named_array.tags) else: assert arg.is_input pt_arg = expr.bindings[arg.name] @@ -600,17 +712,18 @@ def _get_sub_array_ref(array: Array, name: str) -> lp_symbolic.SubArrayRef: # did not find a stored result for the sub-expression, store # it and then pass it to the call name = _generate_name_for_temp(pt_arg, state.var_name_gen) - store_insn_id = add_store(name, pt_arg, - pt_arg_rec, - state, output_to_temporary=True, - cgen_mapper=self) + store_insn_id = add_store( + name, pt_arg.shape, pt_arg.dtype, pt_arg_rec, state, self, + tags=pt_arg.tags, axes=pt_arg.axes, + output_to_temporary=True) depends_on.add(store_insn_id) # replace "arg" with the created stored variable state.results[pt_arg] = StoredResult(name, pt_arg.ndim, frozenset([store_insn_id])) params.append(_get_sub_array_ref(pt_arg, name)) - new_tvs[name] = get_loopy_temporary(name, pt_arg, - self, state) + new_tvs[name] = get_loopy_temporary( + name, pt_arg.shape, pt_arg.dtype, self, state, + tags=pt_arg.tags) else: assert isinstance(arg, lp.ValueArg) and arg.is_input pt_arg = expr.bindings[arg.name] @@ -625,7 +738,8 @@ def _get_sub_array_ref(array: Array, name: str) -> lp_symbolic.SubArrayRef: local_ctx = LocalExpressionContext(reduction_bounds={}, num_indices=0, local_namespace={}, - var_to_reduction_descr={}) + var_to_reduction_descr={}, + var_to_reduction_unique_name={}) params.append(self.exprgen_mapper(pt_arg, prstnt_ctx, local_ctx)) @@ -679,6 +793,79 @@ def map_call(self, expr: Call, state: CodeGenState) -> None: } +class SubscriptDetector(scalar_expr.CombineMapper[bool, []]): + """Returns *True* if a scalar expression contains any subscripts.""" + @override + def combine(self, values: Iterable[bool]) -> bool: + return any(values) + + @override + def map_algebraic_leaf(self, expr: prim.AlgebraicLeaf) -> bool: + return False + + @override + def map_subscript(self, expr: prim.Subscript) -> bool: + return True + + @override + def map_constant(self, expr: object) -> bool: + return False + + +class ReductionCollector(scalar_expr.CombineMapper[frozenset[scalar_expr.Reduce], []]): + """ + Constructs a :class:`frozenset` containing all instances of + :class:`pytato.scalar_expr.Reduce` found in a scalar expression. + """ + @override + def combine( + self, values: Iterable[frozenset[scalar_expr.Reduce]] + ) -> frozenset[scalar_expr.Reduce]: + return reduce( + lambda x, y: x.union(y), + values, + cast("frozenset[scalar_expr.Reduce]", frozenset())) + + @override + def map_algebraic_leaf( + self, expr: prim.AlgebraicLeaf) -> frozenset[scalar_expr.Reduce]: + return frozenset() + + @override + def map_constant(self, expr: object) -> frozenset[scalar_expr.Reduce]: + return frozenset() + + @override + def map_reduce(self, expr: scalar_expr.Reduce) -> frozenset[scalar_expr.Reduce]: + return self.combine([ + frozenset([expr]), + *( + self.rec(bnd) + for _, bnd in sorted(expr.bounds.items())), + self.rec(expr.inner_expr)]) + + +class ReductionBoundsReplacer(scalar_expr.IdentityMapper[[]]): + """ + Replaces the expressions for the bounds of :class:`pytato.scalar_expr.Reduce` + instances in a scalar expression with those provided in *new_reduction_bounds*. + """ + def __init__(self, new_redn_bounds: ReductionBounds): + super().__init__() + self.new_reduction_bounds: ReductionBounds = new_redn_bounds + + @override + def map_reduce(self, expr: scalar_expr.Reduce) -> scalar_expr.Reduce: + new_bounds: ReductionBounds = constantdict({ + name: ( + self.rec_arith( + cast("ArithmeticExpression", self.new_reduction_bounds[name][0])), + self.rec_arith( + cast("ArithmeticExpression", self.new_reduction_bounds[name][1]))) + for name in expr.bounds}) + return dataclasses.replace(expr, bounds=new_bounds) + + class InlinedExpressionGenMapper( scalar_expr.IdentityMapper[ [PersistentExpressionContext, LocalExpressionContext]]): @@ -754,48 +941,49 @@ def map_reduce(self, expr: scalar_expr.Reduce, state = prstnt_ctx.state try: - loopy_redn = PYTATO_REDUCTION_TO_LOOPY_REDUCTION[type(expr.op)] + loopy_redn_op = PYTATO_REDUCTION_TO_LOOPY_REDUCTION[type(expr.op)] except KeyError as err: raise NotImplementedError(expr.op) from err - unique_names_mapping = { - old_name: state.var_name_gen(f"_pt_{loopy_redn}" + old_name) - for old_name in expr.bounds} + inner_expr = loopy_substitute( + expr.inner_expr, + { + var_name: prim.Variable(new_var_name) + for var_name, new_var_name in + local_ctx.var_to_reduction_unique_name.items()}) - inner_expr = loopy_substitute(expr.inner_expr, - {k: prim.Variable(v) - for k, v in unique_names_mapping.items()}) - new_bounds = {unique_names_mapping[name]: bound_exprs - for name, bound_exprs in expr.bounds.items()} + renamed_bounds = { + local_ctx.var_to_reduction_unique_name[var_name]: bound_exprs + for var_name, bound_exprs in expr.bounds.items()} inner_expr = self.rec(inner_expr, prstnt_ctx, - local_ctx.copy(reduction_bounds=new_bounds)) + local_ctx.copy(reduction_bounds=renamed_bounds)) - inner_expr = LoopyReduction(loopy_redn, - tuple(unique_names_mapping.values()), + loopy_expr = LoopyReduction(loopy_redn_op, + tuple(renamed_bounds.keys()), inner_expr) domain = domain_for_shape((), shape=(), reductions={ - redn_iname: ( + var_name: ( self.rec_arith(lbound, prstnt_ctx, local_ctx), self.rec_arith(ubound, prstnt_ctx, local_ctx), ) - for redn_iname, (lbound, ubound) in new_bounds.items()}) + for var_name, (lbound, ubound) in renamed_bounds.items()}) kernel = state.kernel state.update_kernel(kernel.copy(domains=[*kernel.domains, domain])) # {{{ pytato tags -> loopy tags - for name_in_expr, name_in_kernel in sorted(unique_names_mapping.items()): - for tag in local_ctx.var_to_reduction_descr[name_in_expr].tags: + for old_var_name, var_name in sorted( + local_ctx.var_to_reduction_unique_name.items()): + for tag in local_ctx.var_to_reduction_descr[old_var_name].tags: if all(not isinstance(tag, tag_t) for tag_t in self.axis_tag_t_to_not_propagate): - state.update_kernel(lp.tag_inames(state.kernel, - {name_in_kernel: tag})) + state.update_kernel(lp.tag_inames(state.kernel, {var_name: tag})) # }}} - return inner_expr + return loopy_expr def map_type_cast( self, expr: TypeCast, @@ -895,65 +1083,93 @@ def domain_for_shape(dim_names: tuple[str, ...], return dom -def _filter_tags_not_of_type(expr: Array, +def _filter_tags_not_of_type(tags: frozenset[Tag], ignore_tag_t: frozenset[type[Tag]] ) -> frozenset[Tag]: return frozenset(tag - for tag in expr.tags + for tag in tags if not isinstance(tag, tuple(ignore_tag_t))) -def add_store(name: str, expr: Array, result: ImplementedResult, - state: CodeGenState, cgen_mapper: CodeGenMapper, - output_to_temporary: bool = False) -> str: +def add_store( + name: str, shape: ShapeType, dtype: np.dtype[Any], result: ImplementedResult, + state: CodeGenState, cgen_mapper: CodeGenMapper, *, + tags: frozenset[Tag] | None = None, axes: AxesT | None = None, + output_to_temporary: bool = False, result_inames: tuple[str, ...] | None = None, + store_inames: tuple[str, ...] | None = None, add_domain: bool = True) -> str: """Add an instruction that stores to a variable in the kernel. :param name: name of the output array, which is created - :param expr: the :class:`~pytato.Array` to store + :param shape: the shape of the output array + :param dtype: the data type of the output array :param result: the corresponding :class:`ImplementedResult` :param state: code generation state + :param tags: the tags of the output array + :param axes: the axes of the output array :param output_to_temporary: whether to generate an output argument (default) or a temporary variable + :param store_inames: the index inames of the left hand side of the assignment; + must be a subset of *result_inames* + :param result_inames: the index inames of the right hand side of the assignment + :param add_domain: add a new domain to the kernel for these inames/shape. :returns: the id of the generated instruction """ - # Get expression. - inames = tuple( + if tags is None: + tags = frozenset() + + if store_inames is None and result_inames is None: + result_inames = tuple( state.var_name_gen(f"{name}_dim{d}") - for d in range(expr.ndim)) - indices = tuple(prim.Variable(iname) for iname in inames) + for d in range(len(shape))) + store_inames = result_inames + else: + if store_inames is None or result_inames is None: + raise ValueError( + "must specify either both store_inames and result_inames or neither.") + + if not (frozenset(store_inames) <= frozenset(result_inames)): + raise ValueError("store_inames must be a subset of result_inames") + + store_indices = tuple(prim.Variable(iname) for iname in store_inames) + result_indices = tuple(prim.Variable(iname) for iname in result_inames) + + # Get expression. loopy_expr_context = PersistentExpressionContext(state) - loopy_expr = result.to_loopy_expression(indices, loopy_expr_context) + loopy_expr = result.to_loopy_expression(result_indices, loopy_expr_context) # Make the instruction from loopy.kernel.instruction import make_assignment - assignee = prim.Variable(name)[indices] if indices else prim.Variable(name) + assignee = ( + prim.Variable(name)[store_indices] if store_indices else prim.Variable(name)) insn_id = state.insn_id_gen(f"{name}_store") insn = make_assignment((assignee,), loopy_expr, id=insn_id, - within_inames=frozenset(inames), + within_inames=frozenset(result_inames), depends_on=loopy_expr_context.depends_on) - shape = shape_to_scalar_expression(expr.shape, cgen_mapper, state) - - # Get the domain. - domain = domain_for_shape(inames, shape, {}) + loopy_shape = shape_to_scalar_expression(shape, cgen_mapper, state) from pytato.utils import are_shape_components_equal - result_is_empty = any(are_shape_components_equal(s_i, 0) for s_i in expr.shape) + result_is_empty = any(are_shape_components_equal(s_i, 0) for s_i in shape) if result_is_empty: # empty array, no need to do computation additional_domains = [] additional_insns = [] else: - additional_domains = [domain] + if add_domain: + # Get the domain. + domain = domain_for_shape(result_inames, loopy_shape, {}) + additional_domains = [domain] + else: + additional_domains = [] additional_insns = [insn] # Update the kernel. kernel = state.kernel if output_to_temporary: - tvar = get_loopy_temporary(name, expr, cgen_mapper, state) + tvar = get_loopy_temporary(name, shape, dtype, cgen_mapper, state, tags=tags) temporary_variables = dict(kernel.temporary_variables) temporary_variables[name] = tvar kernel = kernel.copy(temporary_variables=temporary_variables, @@ -961,12 +1177,12 @@ def add_store(name: str, expr: Array, result: ImplementedResult, instructions=[*kernel.instructions, *additional_insns]) else: arg = lp.GlobalArg(name, - shape=shape, - dtype=expr.dtype, + shape=loopy_shape, + dtype=dtype, order="C", is_input=False, is_output=True, - tags=_filter_tags_not_of_type(expr, + tags=_filter_tags_not_of_type(tags, cgen_mapper .array_tag_t_to_not_propagate)) kernel = kernel.copy(args=[*kernel.args, arg], @@ -975,8 +1191,8 @@ def add_store(name: str, expr: Array, result: ImplementedResult, # {{{ axes tags -> iname tags - if not result_is_empty: - for axis, iname in zip(expr.axes, inames, strict=True): + if not result_is_empty and axes is not None: + for axis, iname in zip(axes, result_inames, strict=True): for tag in axis.tags: if all(not isinstance(tag, tag_t) for tag_t in cgen_mapper.axis_tag_t_to_not_propagate): @@ -988,20 +1204,21 @@ def add_store(name: str, expr: Array, result: ImplementedResult, return insn_id -def add_substitution(subst_name: str, expr: Array, result: ImplementedResult, +def add_substitution(subst_name: str, ndim: int, result: ImplementedResult, state: CodeGenState, cgen_mapper: CodeGenMapper) -> None: """Add a :class:`~loopy.kernel.data.SubstitutionRule` to the kernel being built in *state*. The substitution rule that will be introduced with take the indices - of array expression *expr*'s as arguments and return the value for the index. + of an array expression of shape *ndim* as arguments and return the value for the + index. """ # Get expression. - indices = tuple(prim.Variable(f"_{idim}") for idim in range(expr.ndim)) + indices = tuple(prim.Variable(f"_{idim}") for idim in range(ndim)) loopy_expr_context = PersistentExpressionContext(state) loopy_expr = result.to_loopy_expression(indices, loopy_expr_context) # Make the substitution rule subst_rule = lp.SubstitutionRule(subst_name, - tuple(f"_{idim}" for idim in range(expr.ndim)), + tuple(f"_{idim}" for idim in range(ndim)), loopy_expr) # Update the kernel. @@ -1013,15 +1230,19 @@ def add_substitution(subst_name: str, expr: Array, result: ImplementedResult, state.update_kernel(kernel) -def get_loopy_temporary(name: str, expr: Array, cgen_mapper: CodeGenMapper, - state: CodeGenState) -> lp.TemporaryVariable: +def get_loopy_temporary( + name: str, shape: ShapeType, dtype: np.dtype[Any], + cgen_mapper: CodeGenMapper, state: CodeGenState, *, + tags: frozenset[Tag] | None = None) -> lp.TemporaryVariable: + if tags is None: + tags = frozenset() # always allocating to global address space to avoid stack overflow address_space = lp.AddressSpace.GLOBAL return lp.TemporaryVariable(name, - shape=shape_to_scalar_expression(expr.shape, cgen_mapper, state), - dtype=expr.dtype, + shape=shape_to_scalar_expression(shape, cgen_mapper, state), + dtype=dtype, address_space=address_space, - tags=_filter_tags_not_of_type(expr, + tags=_filter_tags_not_of_type(tags, cgen_mapper .array_tag_t_to_not_propagate)) @@ -1149,7 +1370,9 @@ def generate_loopy(result: Array | AbstractResultWithNamedArrays | dict[str, Arr # Generate code for outputs. for name in compute_order: expr = outputs[name].expr - insn_id = add_store(name, expr, cg_mapper(expr, state), state, cg_mapper) + insn_id = add_store( + name, expr.shape, expr.dtype, cg_mapper(expr, state), + state, cg_mapper, tags=expr.tags, axes=expr.axes) # replace "expr" with the created stored variable state.results[expr] = StoredResult(name, expr.ndim, frozenset([insn_id])) diff --git a/pytato/target/python/numpy_like.py b/pytato/target/python/numpy_like.py index 9aa610fda..dcc687eba 100644 --- a/pytato/target/python/numpy_like.py +++ b/pytato/target/python/numpy_like.py @@ -47,6 +47,7 @@ ArrayOrScalar, AxisPermutation, Concatenate, + CSRMatmul, DataInterface, DataWrapper, DictOfNamedArrays, @@ -528,6 +529,9 @@ def map_einsum(self, expr: Einsum) -> str: return self._record_line_and_return_lhs(lhs, rhs) + def map_csr_matmul(self, expr: CSRMatmul) -> str: + raise NotImplementedError("CSRMatmul not yet supported in numpy-like targets.") + def map_reshape(self, expr: Reshape) -> str: lhs = self.vng("_pt_tmp") if not all(isinstance(d, int) for d in expr.shape): diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index a5d6b62e9..6a745f7b4 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -26,6 +26,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import dataclasses import logging from collections.abc import Hashable, Mapping from typing import ( @@ -54,6 +55,7 @@ AxisPermutation, BasicIndex, Concatenate, + CSRMatmul, DataInterface, DataWrapper, DictOfNamedArrays, @@ -900,6 +902,29 @@ def map_einsum(self, expr: Einsum) -> Array: new_args = tuple(_verify_is_array(self.rec(arg)) for arg in expr.args) return expr.replace_if_different(args=new_args) + def map_csr_matmul(self, expr: CSRMatmul) -> Array: + new_matrix_elem_values = _verify_is_array( + self.rec(expr.matrix.elem_values)) + new_matrix_elem_col_indices = _verify_is_array( + self.rec(expr.matrix.elem_col_indices)) + new_matrix_row_starts = _verify_is_array( + self.rec(expr.matrix.row_starts)) + if ( + new_matrix_elem_values is not expr.matrix.elem_values + or new_matrix_elem_col_indices is not expr.matrix.elem_col_indices + or new_matrix_row_starts is not expr.matrix.row_starts): + new_matrix = dataclasses.replace( + expr.matrix, + elem_values=new_matrix_elem_values, + elem_col_indices=new_matrix_elem_col_indices, + row_starts=new_matrix_row_starts) + else: + new_matrix = expr.matrix + new_array = _verify_is_array(self.rec(expr.array)) + return expr.replace_if_different( + matrix=new_matrix, + array=new_array) + def map_named_array(self, expr: NamedArray) -> Array: new_container = self.rec(expr._container) assert isinstance(new_container, AbstractResultWithNamedArrays) @@ -1070,6 +1095,30 @@ def map_einsum(self, expr: Einsum, *args: P.args, **kwargs: P.kwargs) -> Array: _verify_is_array(self.rec(arg, *args, **kwargs)) for arg in expr.args) return expr.replace_if_different(args=new_args) + def map_csr_matmul( + self, expr: CSRMatmul, *args: P.args, **kwargs: P.kwargs) -> Array: + new_matrix_elem_values = _verify_is_array( + self.rec(expr.matrix.elem_values, *args, **kwargs)) + new_matrix_elem_col_indices = _verify_is_array( + self.rec(expr.matrix.elem_col_indices, *args, **kwargs)) + new_matrix_row_starts = _verify_is_array( + self.rec(expr.matrix.row_starts, *args, **kwargs)) + if ( + new_matrix_elem_values is not expr.matrix.elem_values + or new_matrix_elem_col_indices is not expr.matrix.elem_col_indices + or new_matrix_row_starts is not expr.matrix.row_starts): + new_matrix = dataclasses.replace( + expr.matrix, + elem_values=new_matrix_elem_values, + elem_col_indices=new_matrix_elem_col_indices, + row_starts=new_matrix_row_starts) + else: + new_matrix = expr.matrix + new_array = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + return expr.replace_if_different( + matrix=new_matrix, + array=new_array) + def map_named_array(self, expr: NamedArray, *args: P.args, **kwargs: P.kwargs) -> Array: new_container = self.rec(expr._container, *args, **kwargs) @@ -1259,6 +1308,13 @@ def map_einsum(self, expr: Einsum) -> ResultT: return self.combine(*(self.rec(ary) for ary in expr.args)) + def map_csr_matmul(self, expr: CSRMatmul) -> ResultT: + return self.combine( + self.rec(expr.matrix.elem_values), + self.rec(expr.matrix.elem_col_indices), + self.rec(expr.matrix.row_starts), + self.rec(expr.array)) + def map_named_array(self, expr: NamedArray) -> ResultT: return self.combine(self.rec(expr._container)) @@ -1317,61 +1373,82 @@ def combine(self, *args: R) -> R: from functools import reduce return reduce(lambda a, b: a | b, args, frozenset()) + @override def map_index_lambda(self, expr: IndexLambda) -> R: return self.combine(frozenset([expr]), super().map_index_lambda(expr)) + @override def map_placeholder(self, expr: Placeholder) -> R: return self.combine(frozenset([expr]), super().map_placeholder(expr)) + @override def map_data_wrapper(self, expr: DataWrapper) -> R: return self.combine(frozenset([expr]), super().map_data_wrapper(expr)) def map_size_param(self, expr: SizeParam) -> R: return frozenset([expr]) + @override def map_stack(self, expr: Stack) -> R: return self.combine(frozenset([expr]), super().map_stack(expr)) + @override def map_roll(self, expr: Roll) -> R: return self.combine(frozenset([expr]), super().map_roll(expr)) + @override def map_axis_permutation(self, expr: AxisPermutation) -> R: return self.combine(frozenset([expr]), super().map_axis_permutation(expr)) + @override def _map_index_base(self, expr: IndexBase) -> R: return self.combine(frozenset([expr]), super()._map_index_base(expr)) + @override def map_reshape(self, expr: Reshape) -> R: return self.combine(frozenset([expr]), super().map_reshape(expr)) + @override def map_concatenate(self, expr: Concatenate) -> R: return self.combine(frozenset([expr]), super().map_concatenate(expr)) + @override def map_einsum(self, expr: Einsum) -> R: return self.combine(frozenset([expr]), super().map_einsum(expr)) + @override + def map_csr_matmul(self, expr: CSRMatmul) -> R: + return self.combine(frozenset([expr]), super().map_csr_matmul(expr)) + + @override def map_named_array(self, expr: NamedArray) -> R: return self.combine(frozenset([expr]), super().map_named_array(expr)) + @override def map_loopy_call_result(self, expr: LoopyCallResult) -> R: return self.combine(frozenset([expr]), super().map_loopy_call_result(expr)) + @override def map_distributed_send_ref_holder( self, expr: DistributedSendRefHolder) -> R: return self.combine( frozenset([expr]), super().map_distributed_send_ref_holder(expr)) + @override def map_distributed_recv(self, expr: DistributedRecv) -> R: return self.combine(frozenset([expr]), super().map_distributed_recv(expr)) + @override def map_call(self, expr: Call) -> R: # do not include arrays from the function's body as it would involve # putting arrays from different namespaces into the same collection. return self.combine(*[self.rec(bnd) for bnd in expr.bindings.values()]) + @override def map_named_call_result(self, expr: NamedCallResult) -> R: return self.rec(expr._container) + @override def clone_for_callee(self, function: FunctionDefinition) -> Self: raise AssertionError("Control shouldn't reach this point.") @@ -1608,6 +1685,18 @@ def map_einsum(self, expr: Einsum, *args: P.args, **kwargs: P.kwargs) -> None: self.post_visit(expr, *args, **kwargs) + def map_csr_matmul( + self, expr: CSRMatmul, *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + self.rec(expr.matrix.elem_values, *args, **kwargs) + self.rec(expr.matrix.elem_col_indices, *args, **kwargs) + self.rec(expr.matrix.row_starts, *args, **kwargs) + self.rec(expr.array, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): @@ -1917,13 +2006,6 @@ def map_named_array(self, expr: NamedArray) -> None: self.node_to_users.setdefault(expr._container, set()).add(expr) self.rec(expr._container) - def map_einsum(self, expr: Einsum) -> None: - for arg in expr.args: - self.node_to_users.setdefault(arg, set()).add(expr) - self.rec(arg) - - self.rec_idx_or_size_tuple(expr, expr.shape) - def map_reshape(self, expr: Reshape) -> None: self.rec_idx_or_size_tuple(expr, expr.shape) @@ -1986,6 +2068,24 @@ def map_non_contiguous_advanced_index(self, ) -> None: self._map_index_base(expr) + def map_einsum(self, expr: Einsum) -> None: + for arg in expr.args: + self.node_to_users.setdefault(arg, set()).add(expr) + self.rec(arg) + + self.rec_idx_or_size_tuple(expr, expr.shape) + + def map_csr_matmul(self, expr: CSRMatmul) -> None: + for child in ( + expr.matrix.elem_values, + expr.matrix.elem_col_indices, + expr.matrix.row_starts, + expr.array): + self.node_to_users.setdefault(child, set()).add(expr) + self.rec(child) + + self.rec_idx_or_size_tuple(expr, expr.shape) + def map_loopy_call(self, expr: LoopyCall) -> None: for _, child in sorted(expr.bindings.items()): if isinstance(child, Array): diff --git a/pytato/transform/einsum_distributive_law.py b/pytato/transform/einsum_distributive_law.py index 28406bef6..1050d3180 100644 --- a/pytato/transform/einsum_distributive_law.py +++ b/pytato/transform/einsum_distributive_law.py @@ -45,6 +45,7 @@ AxesT, AxisPermutation, Concatenate, + CSRMatmul, Einsum, EinsumAxisDescriptor, EinsumReductionAxis, @@ -149,6 +150,8 @@ def _can_hlo_be_distributed(hlo: HighLevelOp) -> bool: hlo.x2.shape)))) +# FIXME: This mapper still needs to be updated to avoid duplicating arrays (see +# https://github.com/inducer/pytato/pull/515). class EinsumDistributiveLawMapper( TransformMapperWithExtraArgs[ [_EinsumDistributiveLawMapperContext | None]]): @@ -272,6 +275,32 @@ def map_einsum(self, return _wrap_einsum_from_ctx(rec_expr, ctx) + def map_csr_matmul(self, + expr: CSRMatmul, + ctx: _EinsumDistributiveLawMapperContext | None) -> Array: + rec_matrix_elem_values = _verify_is_array( + self.rec(expr.matrix.elem_values, None)) + rec_matrix_elem_col_indices = _verify_is_array( + self.rec(expr.matrix.elem_col_indices, None)) + rec_matrix_row_starts = _verify_is_array( + self.rec(expr.matrix.row_starts, None)) + if ( + rec_matrix_elem_values is not expr.matrix.elem_values + or rec_matrix_elem_col_indices is not expr.matrix.elem_col_indices + or rec_matrix_row_starts is not expr.matrix.row_starts): + rec_matrix = dataclasses.replace( + expr.matrix, + elem_values=rec_matrix_elem_values, + elem_col_indices=rec_matrix_elem_col_indices, + row_starts=rec_matrix_row_starts) + else: + rec_matrix = expr.matrix + rec_array = _verify_is_array(self.rec(expr.array, None)) + rec_expr = expr.replace_if_different( + matrix=rec_matrix, + array=rec_array) + return _wrap_einsum_from_ctx(rec_expr, ctx) + def map_stack(self, expr: Stack, ctx: _EinsumDistributiveLawMapperContext | None) -> Array: diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 43468f309..3537b0486 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -46,10 +46,12 @@ AxisPermutation, BasicIndex, Concatenate, + CSRMatmul, Einsum, IndexExpr, IndexLambda, NormalizedSlice, + ReductionDescriptor, Reshape, Roll, ShapeComponent, @@ -711,6 +713,41 @@ def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda: tags=expr.tags, non_equality_tags=expr.non_equality_tags) + def map_csr_matmul(self, expr: CSRMatmul) -> IndexLambda: + rec_matrix_elem_values = self.rec(expr.matrix.elem_values) + rec_matrix_elem_col_indices = self.rec(expr.matrix.elem_col_indices) + rec_matrix_row_starts = self.rec(expr.matrix.row_starts) + rec_array = self.rec(expr.array) + + from pytato.reductions import SumReductionOperation + from pytato.scalar_expr import Reduce + index_expr = Reduce( + prim.Variable("_in0")[prim.Variable("_r0"),] + * prim.Variable("_in3")[( + prim.Variable("_in1")[prim.Variable("_r0"),], + *( + prim.Variable(f"_{idim}") + for idim in range(1, rec_array.ndim)))], + SumReductionOperation(), + constantdict({ + "_r0": ( + prim.Variable("_in2")[prim.Variable("_0"),], + prim.Variable("_in2")[prim.Variable("_0") + 1,])})) + + return IndexLambda(expr=index_expr, + shape=self.rec_size_tuple(expr.shape), + dtype=expr.dtype, + bindings=constantdict({ + "_in0": rec_matrix_elem_values, + "_in1": rec_matrix_elem_col_indices, + "_in2": rec_matrix_row_starts, + "_in3": rec_array}), + axes=expr.axes, + var_to_reduction_descr=constantdict({ + "_r0": ReductionDescriptor(tags=frozenset())}), + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + class ToIndexLambdaMapper(Mapper[Array, Never, []], ToIndexLambdaMixin): diff --git a/pytato/transform/materialize.py b/pytato/transform/materialize.py index 182079c10..b82ce9656 100644 --- a/pytato/transform/materialize.py +++ b/pytato/transform/materialize.py @@ -43,6 +43,7 @@ AxisPermutation, BasicIndex, Concatenate, + CSRMatmul, DataWrapper, DictOfNamedArrays, Einsum, @@ -368,6 +369,33 @@ def map_einsum(self, expr: Einsum) -> MPMSMaterializerAccumulator: self.successors[expr], rec_args) + def map_csr_matmul(self, expr: CSRMatmul) -> MPMSMaterializerAccumulator: + rec_matrix_elem_values = self.rec(expr.matrix.elem_values) + rec_matrix_elem_col_indices = self.rec(expr.matrix.elem_col_indices) + rec_matrix_row_starts = self.rec(expr.matrix.row_starts) + if ( + rec_matrix_elem_values.expr is not expr.matrix.elem_values + or rec_matrix_elem_col_indices.expr is not expr.matrix.elem_col_indices + or rec_matrix_row_starts.expr is not expr.matrix.row_starts): + new_matrix = dataclasses.replace( + expr.matrix, + elem_values=rec_matrix_elem_values.expr, + elem_col_indices=rec_matrix_elem_col_indices.expr, + row_starts=rec_matrix_row_starts.expr) + else: + new_matrix = expr.matrix + rec_array = self.rec(expr.array) + return _materialize_if_mpms( + expr.replace_if_different( + matrix=new_matrix, + array=rec_array.expr), + self.successors[expr], + ( + rec_matrix_elem_values, + rec_matrix_elem_col_indices, + rec_matrix_row_starts, + rec_array)) + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays ) -> MPMSMaterializerAccumulator: raise NotImplementedError diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index ad2c50b01..4310704ee 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -68,6 +68,7 @@ AxisPermutation, BasicIndex, Concatenate, + CSRMatmul, DictOfNamedArrays, Einsum, EinsumReductionAxis, @@ -391,6 +392,15 @@ def map_einsum(self, expr: Einsum) -> None: self.rec(arg) self.add_equations_using_index_lambda_version_of_expr(expr) + def map_csr_matmul(self, expr: CSRMatmul) -> None: + for ary in ( + expr.matrix.elem_values, + expr.matrix.elem_col_indices, + expr.matrix.row_starts, + expr.array): + self.rec(ary) + self.add_equations_using_index_lambda_version_of_expr(expr) + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: for _, subexpr in sorted(expr._data.items()): self.rec(subexpr) diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 5b39f83bf..72bc213c0 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -41,6 +41,7 @@ from pytato.array import ( AbstractResultWithNamedArrays, Array, + CSRMatmul, DataWrapper, DictOfNamedArrays, Einsum, @@ -296,6 +297,25 @@ def map_einsum(self, expr: Einsum) -> None: self.node_to_dot[expr] = info + def map_csr_matmul(self, expr: CSRMatmul) -> None: + info = self.get_common_dot_info(expr) + + self.rec(expr.matrix.elem_values) + info.edges["matrix.elem_values"] = expr.matrix.elem_values + + self.rec(expr.matrix.elem_col_indices) + info.edges["matrix.elem_col_indices"] = expr.matrix.elem_col_indices + + self.rec(expr.matrix.row_starts) + info.edges["matrix.row_starts"] = expr.matrix.row_starts + + self.rec(expr.array) + info.edges["array"] = expr.array + + info.fields["matrix_shape"] = stringify_shape(expr.matrix.shape) + + self.node_to_dot[expr] = info + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: edges: dict[str, ArrayOrNames | FunctionDefinition] = {} for name, val in expr._data.items(): diff --git a/pytato/visualization/fancy_placeholder_data_flow.py b/pytato/visualization/fancy_placeholder_data_flow.py index a4563407c..f0338b1ba 100644 --- a/pytato/visualization/fancy_placeholder_data_flow.py +++ b/pytato/visualization/fancy_placeholder_data_flow.py @@ -17,6 +17,7 @@ AdvancedIndexInNoncontiguousAxes, Array, Concatenate, + CSRMatmul, DataWrapper, DictOfNamedArrays, Einsum, @@ -40,6 +41,7 @@ EINSUM_COLOR = "crimson" STACK_CONCAT_COLOR = "deepskyblue" INDIRECTION_COLOR = "darkblue" +SPARSE_MATMUL_COLOR = "gold" # }}} @@ -51,6 +53,7 @@ EINSUM_SHAPE = "box3d" STACK_CONCAT_SHAPE = "folder" INDIRECTION_SHAPE = "hexagon" +SPARSE_MATMUL_SHAPE = "star" # }}} @@ -176,6 +179,27 @@ def map_einsum(self, expr: Einsum) -> _FancyDotWriterNode: return ret_node + def map_csr_matmul(self, expr: CSRMatmul) -> _FancyDotWriterNode: + node_id = self.vng("_pt_sparse_matmul") + node_decl = (f'{node_id} [label="",' + f" color={SPARSE_MATMUL_COLOR}," + f" shape={SPARSE_MATMUL_SHAPE}]") + + ret_node, new_edges = _get_dot_node_from_predecessors( + node_id, + [ + self.rec(expr.matrix.elem_values), + self.rec(expr.matrix.elem_col_indices), + self.rec(expr.matrix.row_starts), + self.rec(expr.array)] + ) + + if new_edges: + self.node_decls.append(node_decl) + self.edges.update(new_edges) + + return ret_node + def _map_stack_concat(self, expr: Stack | Concatenate) -> _FancyDotWriterNode: node_id = self.vng("_pt_stack_concat") diff --git a/test/test_codegen.py b/test/test_codegen.py index 4872de7bf..073b80eca 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -966,6 +966,71 @@ def _get_x_shape(_m, n_): np.testing.assert_allclose(np_out, pt_out) +@pytest.mark.parametrize("case", ["single", "multiple", "stacked"]) +def test_csr_matmul(ctx_factory: cl.CtxFactory, case, visualize=False): + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + n = 100 + h = 2/n + + np_x = np.linspace(-1, 1, n) + + # FD Laplacian operator for interior points + diags = [np.ones(n-2)/h**2, -2*np.ones(n-2)/h**2, np.ones(n-2)/h**2] + col_indices = [np.arange(n-2), np.arange(1, n-1), np.arange(2, n)] + np_A = np.zeros((n-2, n)) # noqa: N806 + np_A[np.arange(n-2), col_indices[0]] = diags[0] + np_A[np.arange(n-2), col_indices[1]] = diags[1] + np_A[np.arange(n-2), col_indices[2]] = diags[2] + + pt_A = pt.make_csr_matrix( # noqa: N806 + shape=np_A.shape, + elem_values=pt.make_data_wrapper(np.stack(diags).T.flatten()), + elem_col_indices=pt.make_data_wrapper(np.stack(col_indices).T.flatten()), + row_starts=pt.make_data_wrapper( + np.concatenate((3*np.arange(n-2), np.array([3*(n-2)]))))) + + np_u = np.sin(np.pi*np_x) + pt_u = pt.make_data_wrapper(np_u) + + np_v = -np.sin(np.pi*np_x) + pt_v = pt.make_data_wrapper(np_v) + + if case == "single": + exact_out = -np.pi**2 * np_u + np_out = np_A @ np_u + _, (pt_out,) = pt.generate_loopy(pt_A @ pt_u)(cq) + elif case == "multiple": + exact_out = 0*np_x + np_out = np_A @ np_u + np_A @ np_v + _, (pt_out,) = pt.generate_loopy(pt_A @ pt_u + pt_A @ pt_v)(cq) + elif case == "stacked": + np_w = (np.stack([np_u, np_v]).T).copy() + exact_out = np.stack([-np.pi**2 * np_u, -np.pi**2 * np_v]).T + pt_w = pt.make_data_wrapper(np_w) + np_out = np_A @ np_w + _, (pt_out,) = pt.generate_loopy(pt_A @ pt_w)(cq) + else: + raise ValueError("invalid case.") + + if visualize: + import matplotlib.pyplot as plt + ax = plt.axes() + ax.set_xlim(-1, 1) + ax.set_ylim(-12, 12) + ax.plot(np_x, np_u) + ax.plot(np_x, np_v) + ax.plot(np_x[1:n-1], pt_out) + plt.show() + + assert np_out.shape[0] == exact_out.shape[0]-2 + assert np_out.shape[1:] == exact_out.shape[1:] + assert pt_out.shape == np_out.shape + np.testing.assert_allclose(np_out, exact_out[1:n-1], rtol=1e-1) + np.testing.assert_allclose(pt_out, np_out) + + def test_arguments_passing_to_loopy_kernel_for_non_dependent_vars( ctx_factory: cl.CtxFactory): from numpy.random import default_rng diff --git a/test/test_pytato.py b/test/test_pytato.py index 8746810ae..ba4904854 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -63,6 +63,34 @@ def test_matmul_input_validation(): d @ d +def test_csr_matmul_input_validation(): + a = pt.make_csr_matrix( + shape=(8, 10), + elem_values=pt.make_placeholder(name="a_elem_values", shape=(8,)), + elem_col_indices=pt.make_placeholder(name="a_elem_col_indicesi", shape=(8,)), + row_starts=pt.make_placeholder(name="a_row_starts", shape=(9,))) + + b = pt.make_placeholder(name="b", shape=(10,)) + a @ b + + c = pt.make_placeholder(name="c", shape=(20,)) + with pytest.raises(ValueError): + a @ c + + d = pt.make_placeholder(name="d", shape=()) + with pytest.raises(ValueError): + a @ d + + n = pt.make_size_param("n") + e = pt.make_csr_matrix( + shape=(n, n), + elem_values=pt.make_placeholder(name="e_elem_values", shape=(n,)), + elem_col_indices=pt.make_placeholder(name="e_elem_col_indicesi", shape=(n,)), + row_starts=pt.make_placeholder(name="e_row_starts", shape=(n+1,))) + f = pt.make_placeholder(name="f", shape=(n,)) + e @ f + + def test_roll_input_validation(): a = pt.make_placeholder(name="a", shape=(10, 10)) pt.roll(a, 1, axis=0) From eca03d89078773f551fe40d6a75fe15b0aed9fdb Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 22 Jan 2026 15:00:40 -0600 Subject: [PATCH 2/6] add matplotlib to test conda env --- .test-conda-env-py3.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml index 58b05ade1..29db0c233 100644 --- a/.test-conda-env-py3.yml +++ b/.test-conda-env-py3.yml @@ -15,3 +15,4 @@ dependencies: - jax - openmpi # Force using Open MPI since our pytest infrastructure needs it - graphviz # for visualization tests +- matplotlib-base From de4a431a372d873bc049156402201e295195f687 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 23 Jan 2026 12:31:39 -0600 Subject: [PATCH 3/6] add some checks for shapes of things in make_csr_matrix, and move some other checks over from constructor --- pytato/array.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 239d88c24..586841cae 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -2234,10 +2234,6 @@ class SparseMatrix(_SuppliedAxesAndTagsMixin, _SuppliedShapeAndDtypeMixin, ABC): .. automethod:: __matmul__ """ - if __debug__: - def __post_init__(self) -> None: - pass - def __matmul__(self, other: Array) -> SparseMatmul: return sparse_matmul(self, other) @@ -2304,17 +2300,6 @@ class CSRMatrix(SparseMatrix): elem_col_indices: Array row_starts: Array - if __debug__: - @override - def __post_init__(self) -> None: - if self.elem_values.ndim != 1: - raise ValueError("elem_values must be a 1D array.") - if self.elem_col_indices.ndim != 1: - raise ValueError("elem_col_indices must be a 1D array.") - if self.row_starts.ndim != 1: - raise ValueError("row_starts must be a 1D array.") - super().__post_init__() - @array_dataclass() class CSRMatmul(SparseMatmul): @@ -2752,7 +2737,7 @@ def make_csr_matrix(shape: ConvertibleToShape, axes: AxesT | None = None) -> CSRMatrix: """Make a :class:`CSRMatrix` object. - :param shape: the shape of the matrix + :param shape: the (two-dimensional) shape of the matrix :param elem_values: a one-dimensional array containing the values of all of the nonzero entries of the matrix, grouped by row. :param elem_col_indices: a one-dimensional array containing the column index @@ -2764,6 +2749,9 @@ def make_csr_matrix(shape: ConvertibleToShape, shape = normalize_shape(shape) dtype = elem_values.dtype + if len(shape) != 2: + raise ValueError("matrix must be 2D.") + if axes is None: axes = _get_default_axes(len(shape)) @@ -2771,6 +2759,18 @@ def make_csr_matrix(shape: ConvertibleToShape, raise ValueError("'axes' dimensionality mismatch:" f" expected {len(shape)}, got {len(axes)}.") + if elem_values.ndim != 1: + raise ValueError("'elem_values' must be 1D.") + if elem_col_indices.ndim != 1: + raise ValueError("'elem_col_indices' must be 1D.") + if row_starts.ndim != 1: + raise ValueError("'row_starts' must be 1D.") + + from pytato.utils import are_shapes_equal + if not are_shapes_equal(row_starts.shape, (shape[0] + 1,)): + raise ValueError( + "'row_starts' must have length equal to the number of rows plus one.") + return CSRMatrix( shape=shape, elem_values=elem_values, From e7177fa9d1054f2660e01aef02744dc813e44738 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 26 Jan 2026 13:06:56 -0600 Subject: [PATCH 4/6] minor wording change in comment --- pytato/target/loopy/codegen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 164b87930..9cd0eab03 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -518,8 +518,8 @@ def map_index_lambda(self, expr: IndexLambda, # now; attempting to generalize to unmaterialized expressions would require # handling at least two complications: # 1) final inames aren't assigned until the expression is stored, so any - # temporary variables defined below would need to be finalized at that - # point, not here + # temporary variables defined below would need to be finalized at the + # point of storage, not here # 2) lp.make_reduction_inames_unique does not rename the temporaries # created below, so something would need to be done to make them unique # across all index lambda evaluations. From 922e733db662e149a7c59d4999b4a8b5feb14f89 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 9 Feb 2026 16:10:43 -0600 Subject: [PATCH 5/6] accept tuple[ToTagSetConvertible, ...] for axes in make_csr_matrix --- pytato/array.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 586841cae..0c7262ddb 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -236,7 +236,7 @@ from pymbolic import ArithmeticExpression, var from pymbolic.typing import Integer, Scalar, not_none from pytools import memoize_method, opt_frozen_dataclass -from pytools.tag import Tag, Taggable, ToTagSetConvertible +from pytools.tag import Tag, Taggable, ToTagSetConvertible, normalize_tags from pytato.scalar_expr import ( INT_CLASSES, @@ -2648,6 +2648,7 @@ def make_placeholder(name: str, shape: ConvertibleToShape, dtype: Any = np.float64, tags: frozenset[Tag] = frozenset(), + # FIXME: Accept tuple[ToTagSetConvertible, ...]? axes: AxesT | None = None) -> Placeholder: """Make a :class:`Placeholder` object. @@ -2695,6 +2696,7 @@ def make_data_wrapper(data: DataInterface, name: str | None = None, shape: ConvertibleToShape | None = None, tags: frozenset[Tag] = frozenset(), + # FIXME: Accept tuple[ToTagSetConvertible, ...]? axes: AxesT | None = None) -> DataWrapper: """Make a :class:`DataWrapper`. @@ -2734,7 +2736,8 @@ def make_csr_matrix(shape: ConvertibleToShape, elem_col_indices: Array, row_starts: Array, tags: frozenset[Tag] = frozenset(), - axes: AxesT | None = None) -> CSRMatrix: + axes: AxesT | tuple[ToTagSetConvertible, ...] | None = None + ) -> CSRMatrix: """Make a :class:`CSRMatrix` object. :param shape: the (two-dimensional) shape of the matrix @@ -2752,7 +2755,11 @@ def make_csr_matrix(shape: ConvertibleToShape, if len(shape) != 2: raise ValueError("matrix must be 2D.") - if axes is None: + if axes is not None: + axes = tuple( + Axis(normalize_tags(axis)) if not isinstance(axis, Axis) else axis + for axis in axes) + else: axes = _get_default_axes(len(shape)) if len(axes) != len(shape): From 87f5820fce64d0ebefdc91330b168f1405268fc5 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 13 Feb 2026 16:07:10 -0600 Subject: [PATCH 6/6] check for non-affineness instead of subscripts --- pytato/scalar_expr.py | 17 +++++++++ pytato/target/loopy/codegen.py | 70 +++++++++++++++------------------- 2 files changed, 47 insertions(+), 40 deletions(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index e59cbdce8..f99c2280c 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -42,6 +42,7 @@ THE SOFTWARE. """ + import re from collections.abc import Iterable, Mapping, Set as AbstractSet from typing import ( @@ -55,6 +56,8 @@ from typing_extensions import Never, TypeIs, override import pymbolic.primitives as prim +from loopy.diagnostic import ExpressionToAffineConversionError +from loopy.symbolic import guarded_pwaff_from_expr from pymbolic import ArithmeticExpression, Bool, Expression, expr_dataclass from pymbolic.mapper import ( CombineMapper as CombineMapperBase, @@ -368,4 +371,18 @@ def get_reduction_induction_variables(expr: Expression) -> AbstractSet[str]: """ return InductionVariableCollector()(expr) + +def is_quasi_affine(expr: Expression) -> bool: + import islpy as isl + space = isl.Space.create_from_names( + isl.DEFAULT_CONTEXT, + set=list(get_dependencies(expr)), + ) + try: + guarded_pwaff_from_expr(space, expr) + except ExpressionToAffineConversionError: + return False + return True + + # vim: foldmethod=marker diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 9cd0eab03..5eea5fb52 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -68,6 +68,7 @@ INT_CLASSES, ScalarExpression, TypeCast, + is_quasi_affine, ) from pytato.tags import ( ForceValueArgTag, @@ -505,28 +506,36 @@ def map_index_lambda(self, expr: IndexLambda, loopy_shape = shape_to_scalar_expression(expr.shape, self, state) - # If the scalar expression contains any reductions with bounds expressions - # that index into a binding, need to store the results of those expressions - # as scalar temporaries - subscript_detector = SubscriptDetector() - redn_bounds = { var_name: redn.bounds[var_name] for var_name, redn in var_to_reduction.items()} - # FIXME: Forcing storage of expressions containing processed reductions for - # now; attempting to generalize to unmaterialized expressions would require - # handling at least two complications: - # 1) final inames aren't assigned until the expression is stored, so any - # temporary variables defined below would need to be finalized at the - # point of storage, not here + loopy_redn_bounds: Mapping[str, tuple[Expression, Expression]] = { + var_name: cast( + "tuple[Expression, Expression]", + tuple( + self.exprgen_mapper(bound, prstnt_ctx, local_ctx) + for bound in bounds)) + for var_name, bounds in redn_bounds.items()} + + # If the scalar expression contains any reductions with bounds + # expressions that are non-affine, we need to store the results of those + # expressions as scalar temporaries. + # FIXME: For now, forcing storage of expressions containing such + # processed reductions; attempting to generalize to unmaterialized + # expressions would require handling at least two complications: + # 1) final inames aren't assigned until the expression is stored, so + # any temporary variables defined below would need to be finalized + # at the point of storage, not here # 2) lp.make_reduction_inames_unique does not rename the temporaries - # created below, so something would need to be done to make them unique - # across all index lambda evaluations. - store_result = expr.tags_of_type(ImplStored) or any( - subscript_detector(bound) - for bounds in redn_bounds.values() - for bound in bounds) + # created below, so something would need to be done to make them + # unique across all index lambda evaluations. + store_result = ( + expr.tags_of_type(ImplStored) + or any( + not is_quasi_affine(bound) + for bounds in loopy_redn_bounds.values() + for bound in bounds)) name: str | None = None inames: tuple[str, ...] | None = None @@ -550,13 +559,13 @@ def map_index_lambda(self, expr: IndexLambda, str, tuple[ArithmeticExpression, ArithmeticExpression]] = {} bound_prefixes = ("l", "u") for var_name, bounds in redn_bounds.items(): + loopy_bounds = loopy_redn_bounds[var_name] new_bounds_list: list[ArithmeticExpression] = [] - for bound_prefix, bound in zip(bound_prefixes, bounds, strict=True): - if subscript_detector(bound): + for bound_prefix, bound, loopy_bound in zip( + bound_prefixes, bounds, loopy_bounds, strict=True): + if not is_quasi_affine(loopy_bound): unique_name = var_to_reduction_unique_name[var_name] bound_name = f"{unique_name}_{bound_prefix}bound" - loopy_bound = self.exprgen_mapper( - bound, prstnt_ctx, local_ctx) bound_result: ImplementedResult = InlinedResult( loopy_bound, expr.ndim, prstnt_ctx.depends_on) bound_result = StoredResult( @@ -793,25 +802,6 @@ def map_call(self, expr: Call, state: CodeGenState) -> None: } -class SubscriptDetector(scalar_expr.CombineMapper[bool, []]): - """Returns *True* if a scalar expression contains any subscripts.""" - @override - def combine(self, values: Iterable[bool]) -> bool: - return any(values) - - @override - def map_algebraic_leaf(self, expr: prim.AlgebraicLeaf) -> bool: - return False - - @override - def map_subscript(self, expr: prim.Subscript) -> bool: - return True - - @override - def map_constant(self, expr: object) -> bool: - return False - - class ReductionCollector(scalar_expr.CombineMapper[frozenset[scalar_expr.Reduce], []]): """ Constructs a :class:`frozenset` containing all instances of