From b092d6eaad43effcd39351ee9a449080399ee543 Mon Sep 17 00:00:00 2001 From: Kavin Satheeskumar Date: Thu, 28 Aug 2025 16:47:00 -0400 Subject: [PATCH] add support for collapse_shape op on memrefs --- src/pydsl/memref.py | 46 ++++++++++++++++++++++++++++++++++++++-- tests/e2e/test_memref.py | 12 ++++++++++- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/src/pydsl/memref.py b/src/pydsl/memref.py index 45f92bb..8d07bfb 100644 --- a/src/pydsl/memref.py +++ b/src/pydsl/memref.py @@ -4,7 +4,7 @@ from collections.abc import Callable, Iterable from ctypes import POINTER, c_void_p from dataclasses import dataclass -from functools import cache +from functools import cache, reduce from typing import TYPE_CHECKING, Final import mlir.ir as mlir @@ -467,7 +467,8 @@ def __init__(self, rep: OpView | Value) -> None: ]): raise TypeError( f"expected shape {'x'.join([str(sh) for sh in self.shape])}" - f"x{lower_single(self.element_type)}, got representation with shape " + f"x{lower_single(self.element_type) + }, got representation with shape " f"{'x'.join([str(sh) for sh in rep.type.shape])}" f"x{rep.type.element_type}" ) @@ -841,3 +842,44 @@ def subtree_to_slices( return [key] case _: raise TypeError(f"{type(key)} cannot be used as a subscript") + + +def calc_shape(memref_shape: tuple, assoc: list[list[int]]): + # We need to make sure that assoc is valid + # grouping of the dimensions 0 to n-1. + # examples, [[0,1], [2,3], [4]] + # notice the order of the elements must be correct + flattened = [e for a in assoc for e in a] + assert flattened == list(range(len(flattened))) + assert len(flattened) == len(memref_shape) + + output = [] + + for group in assoc: + res = 1 + for i in group: + dim = memref_shape[i] + if (dim == DYNAMIC or res == DYNAMIC): + res = DYNAMIC + else: + res *= dim + output.append(res) + + return tuple(output) + + +@CallMacro.generate() +def collapse_shape( + visitor: ToMLIRBase, + mem: Compiled, + assoc: Evaluated +): + shpe = calc_shape(mem.shape, assoc) + result_type = MemRef[mem.element_type, *shpe] + return result_type( + memref.CollapseShapeOp( + lower_single(result_type), + lower_single(mem), + assoc + ) + ) diff --git a/tests/e2e/test_memref.py b/tests/e2e/test_memref.py index f11a90b..f7d2cc4 100644 --- a/tests/e2e/test_memref.py +++ b/tests/e2e/test_memref.py @@ -7,7 +7,7 @@ from pydsl.affine import affine_range as arange from pydsl.frontend import compile import pydsl.linalg as linalg -from pydsl.memref import alloc, alloca, DYNAMIC, Dynamic, MemRef, MemRefFactory +from pydsl.memref import alloc, alloca, DYNAMIC, Dynamic, MemRef, MemRefFactory, collapse_shape from pydsl.type import Bool, F32, F64, Index, SInt16, Tuple, UInt32 from helper import compilation_failed_from, failed_from, multi_arange, run @@ -449,6 +449,15 @@ def f(t1: MemRef[UInt32]) -> Tuple[UInt32, MemRef[UInt32]]: assert res2.shape == () +def test_collapse_shape(): + @compile() + def my_func(a: MemRef[F32, 1, 3]) -> MemRef[F32, 3]: + return collapse_shape(a, [[0, 1]]) + + n1 = np.array([[1.0, 2.0, 3.0]], dtype=np.float32) + assert all([a == b for a, b in zip(my_func(n1), [1.0, 2.0, 3.0])]) + + if __name__ == "__main__": run(test_load_implicit_index_uint32) run(test_load_implicit_index_f64) @@ -472,3 +481,4 @@ def f(t1: MemRef[UInt32]) -> Tuple[UInt32, MemRef[UInt32]]: run(test_link_ndarray) run(test_chain_link_ndarray) run(test_zero_d) + run(test_collapse_shape)