From c31d2f1c4087f467904634397933af8bf6ba049b Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sun, 31 Aug 2025 15:03:27 +0300 Subject: [PATCH 1/3] feat: add ObjectArray.__matmul__ with ndarray --- pytools/obj_array.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytools/obj_array.py b/pytools/obj_array.py index 29069647..85d9bcbd 100644 --- a/pytools/obj_array.py +++ b/pytools/obj_array.py @@ -273,7 +273,11 @@ def __matmul__( self: ObjectArray2D[T], other: ObjectArray2D[T], /) -> ObjectArray2D[T]: ... - + @overload + def __matmul__( + self: ObjectArray2D[T], + other: np.ndarray[tuple[int, int], np.dtype[Any]], + /) -> ObjectArray2D[T]: ... @overload def __matmul__( self: ObjectArrayND[T], From 4bcba4b680e86a208624b16d176f4fb60bde27b9 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sun, 31 Aug 2025 15:03:41 +0300 Subject: [PATCH 2/3] feat: add default axis for obj_array.sum --- pytools/obj_array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytools/obj_array.py b/pytools/obj_array.py index 85d9bcbd..433109d9 100644 --- a/pytools/obj_array.py +++ b/pytools/obj_array.py @@ -443,7 +443,7 @@ def sum( def sum( array: ObjectArrayND[T], - axis: int | None, + axis: int | None = None, ) -> ObjectArrayND[T] | T: import numpy as np return cast("ObjectArrayND[T] | T", np.sum( From a08639e98ce97f14301c325e3a5dd09210ae0944 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sun, 31 Aug 2025 15:40:57 +0300 Subject: [PATCH 3/3] feat: add ObjectArray.flatten --- pytools/obj_array.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytools/obj_array.py b/pytools/obj_array.py index 433109d9..c81e1185 100644 --- a/pytools/obj_array.py +++ b/pytools/obj_array.py @@ -285,7 +285,9 @@ def __matmul__( /) -> ObjectArrayND[T] | T: ... @property - def flat(self) -> ObjectArray1D[T]: ... + def flat(self) -> Iterator[T]: ... + + def flatten(self) -> ObjectArray1D[T]: ... @overload def tolist(self: ObjectArray0D[T]) -> T: ... @@ -455,7 +457,7 @@ def sum( def to_hashable(ary: ObjectArray[ShapeT, T] | Hashable, /) -> Hashable: if isinstance(ary, ObjectArray): ary = cast("ObjectArray[ShapeT, T]", ary) - return tuple(ary.flat.tolist()) + return tuple(ary.flat) return ary