From 56d302c9e5b5948f0eac47893551f4a29be5b367 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 29 Aug 2025 14:04:23 -0500 Subject: [PATCH 1/2] Add slice/int overloads for ObjectArray2D.__getitem__ --- pytools/obj_array.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/pytools/obj_array.py b/pytools/obj_array.py index c8b48e5f..8f15ad4a 100644 --- a/pytools/obj_array.py +++ b/pytools/obj_array.py @@ -162,19 +162,29 @@ def __getitem__(self: ObjectArrayND[T], x: int, /) -> ObjectArrayND[T] | T: ... @overload def __getitem__( - self: ObjectArrayND[T], - x: tuple[int | slice, ...], - /) -> ObjectArrayND[T] | T: ... + self: ObjectArray1D[T], + x: slice, /) -> ObjectArray1D[T]: ... @overload def __getitem__( - self: ObjectArray1D[T], - x: slice) -> ObjectArray1D[T]: ... + self: ObjectArray2D[T], + x: slice, /) -> ObjectArray2D[T]: ... + + @overload + def __getitem__( + self: ObjectArray2D[T], + x: tuple[slice, int], /) -> ObjectArray1D[T]: ... @overload def __getitem__( self: ObjectArray2D[T], - x: slice) -> ObjectArray2D[T]: ... + x: tuple[int, slice], /) -> ObjectArray1D[T]: ... + + @overload + def __getitem__( + self: ObjectArrayND[T], + x: tuple[int | slice, ...], + /) -> ObjectArrayND[T] | T: ... @overload def __iter__(self: ObjectArray1D[T]) -> Iterator[T]: ... From 4ba8438173d1da14d46fd531bced277f6d18231b Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 29 Aug 2025 14:04:45 -0500 Subject: [PATCH 2/2] Add obj_array.sum --- pytools/obj_array.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/pytools/obj_array.py b/pytools/obj_array.py index 8f15ad4a..29069647 100644 --- a/pytools/obj_array.py +++ b/pytools/obj_array.py @@ -416,6 +416,38 @@ def trace( return cast("T", np.trace(cast("NDArray[Any]", cast("object", array)))) +@overload +def sum( + array: ObjectArrayND[T], + axis: None, + ) -> T: ... + + +@overload +def sum( + array: ObjectArray1D[T], + axis: int, + ) -> T: ... + + +@overload +def sum( + array: ObjectArray2D[T], + axis: int, + ) -> ObjectArray1D[T]: ... + + +def sum( + array: ObjectArrayND[T], + axis: int | None, + ) -> ObjectArrayND[T] | T: + import numpy as np + return cast("ObjectArrayND[T] | T", np.sum( + cast("NDArray[Any]", cast("object", array)), + axis=axis, + )) + + def to_hashable(ary: ObjectArray[ShapeT, T] | Hashable, /) -> Hashable: if isinstance(ary, ObjectArray): ary = cast("ObjectArray[ShapeT, T]", ary)