diff --git a/pytools/obj_array.py b/pytools/obj_array.py index c8b48e5f..29069647 100644 --- a/pytools/obj_array.py +++ b/pytools/obj_array.py @@ -162,19 +162,29 @@ def __getitem__(self: ObjectArrayND[T], x: int, /) -> ObjectArrayND[T] | T: ... @overload def __getitem__( - self: ObjectArrayND[T], - x: tuple[int | slice, ...], - /) -> ObjectArrayND[T] | T: ... + self: ObjectArray1D[T], + x: slice, /) -> ObjectArray1D[T]: ... @overload def __getitem__( - self: ObjectArray1D[T], - x: slice) -> ObjectArray1D[T]: ... + self: ObjectArray2D[T], + x: slice, /) -> ObjectArray2D[T]: ... + + @overload + def __getitem__( + self: ObjectArray2D[T], + x: tuple[slice, int], /) -> ObjectArray1D[T]: ... @overload def __getitem__( self: ObjectArray2D[T], - x: slice) -> ObjectArray2D[T]: ... + x: tuple[int, slice], /) -> ObjectArray1D[T]: ... + + @overload + def __getitem__( + self: ObjectArrayND[T], + x: tuple[int | slice, ...], + /) -> ObjectArrayND[T] | T: ... @overload def __iter__(self: ObjectArray1D[T]) -> Iterator[T]: ... @@ -406,6 +416,38 @@ def trace( return cast("T", np.trace(cast("NDArray[Any]", cast("object", array)))) +@overload +def sum( + array: ObjectArrayND[T], + axis: None, + ) -> T: ... + + +@overload +def sum( + array: ObjectArray1D[T], + axis: int, + ) -> T: ... + + +@overload +def sum( + array: ObjectArray2D[T], + axis: int, + ) -> ObjectArray1D[T]: ... + + +def sum( + array: ObjectArrayND[T], + axis: int | None, + ) -> ObjectArrayND[T] | T: + import numpy as np + return cast("ObjectArrayND[T] | T", np.sum( + cast("NDArray[Any]", cast("object", array)), + axis=axis, + )) + + def to_hashable(ary: ObjectArray[ShapeT, T] | Hashable, /) -> Hashable: if isinstance(ary, ObjectArray): ary = cast("ObjectArray[ShapeT, T]", ary)