diff --git a/src/pydsl/memref.py b/src/pydsl/memref.py index f264a51..f5c2594 100644 --- a/src/pydsl/memref.py +++ b/src/pydsl/memref.py @@ -6,7 +6,7 @@ from ctypes import POINTER, c_void_p from dataclasses import dataclass from enum import Enum -from functools import cache +from functools import cache, reduce from typing import TYPE_CHECKING, Final import mlir.ir as mlir @@ -1007,6 +1007,45 @@ def subtree_to_slices( raise TypeError(f"{type(key)} cannot be used as a subscript") +def calc_shape(memref_shape: tuple, assoc: list[list[int]]): + # We need to make sure that assoc is valid + # grouping of the dimensions 0 to n-1. + # examples, [[0,1], [2,3], [4]] + # notice the order of the elements must be correct + flattened = [e for a in assoc for e in a] + assert flattened == list(range(len(flattened))) + assert len(flattened) == len(memref_shape) + + output = [] + + for group in assoc: + res = 1 + for i in group: + dim = memref_shape[i] + if (dim == DYNAMIC or res == DYNAMIC): + res = DYNAMIC + else: + res *= dim + output.append(res) + + return tuple(output) + + +@CallMacro.generate() +def collapse_shape( + visitor: ToMLIRBase, + mem: Compiled, + assoc: Evaluated +): + shpe = calc_shape(mem.shape, assoc) + result_type = MemRef[mem.element_type, *shpe] + return result_type( + memref.CollapseShapeOp( + lower_single(result_type), + lower_single(mem), + assoc + ) + ) def split_static_dynamic_dims( shape: Iterable[Number | SupportsIndex], ) -> tuple[list[int], list[Index]]: diff --git a/tests/e2e/test_memref.py b/tests/e2e/test_memref.py index 8b0b94f..7a1f136 100644 --- a/tests/e2e/test_memref.py +++ b/tests/e2e/test_memref.py @@ -9,7 +9,7 @@ from pydsl.gpu import GPU_AddrSpace import pydsl.linalg as linalg import pydsl.memref as memref -from pydsl.memref import alloc, alloca, dealloc, DYNAMIC, MemRef, MemRefFactory +from pydsl.memref import alloc, alloca, dealloc, DYNAMIC, MemRef, MemRefFactory, collapse_shape from pydsl.type import Bool, F32, F64, Index, SInt16, Tuple, UInt32 from helper import compilation_failed_from, failed_from, multi_arange, run @@ -480,6 +480,14 @@ def f(m1: MemRef[UInt32]) -> Tuple[UInt32, MemRef[UInt32]]: assert res2.shape == () +def test_collapse_shape(): + @compile() + def my_func(a: MemRef[F32, 1, 3]) -> MemRef[F32, 3]: + return collapse_shape(a, [[0, 1]]) + + n1 = np.array([[1.0, 2.0, 3.0]], dtype=np.float32) + assert all([a == b for a, b in zip(my_func(n1), [1.0, 2.0, 3.0])]) + def test_cast_basic(): @compile() def f( @@ -612,6 +620,7 @@ def f(m1: MemRef[F32, 5], m2: MemRef[F64, 5]): run(test_link_ndarray) run(test_chain_link_ndarray) run(test_zero_d) + run(test_collapse_shape) run(test_cast_basic) run(test_cast_strided) run(test_cast_strided)