diff --git a/pytools/obj_array.py b/pytools/obj_array.py index 29069647..c81e1185 100644 --- a/pytools/obj_array.py +++ b/pytools/obj_array.py @@ -273,7 +273,11 @@ def __matmul__( self: ObjectArray2D[T], other: ObjectArray2D[T], /) -> ObjectArray2D[T]: ... - + @overload + def __matmul__( + self: ObjectArray2D[T], + other: np.ndarray[tuple[int, int], np.dtype[Any]], + /) -> ObjectArray2D[T]: ... @overload def __matmul__( self: ObjectArrayND[T], @@ -281,7 +285,9 @@ def __matmul__( /) -> ObjectArrayND[T] | T: ... @property - def flat(self) -> ObjectArray1D[T]: ... + def flat(self) -> Iterator[T]: ... + + def flatten(self) -> ObjectArray1D[T]: ... @overload def tolist(self: ObjectArray0D[T]) -> T: ... @@ -439,7 +445,7 @@ def sum( def sum( array: ObjectArrayND[T], - axis: int | None, + axis: int | None = None, ) -> ObjectArrayND[T] | T: import numpy as np return cast("ObjectArrayND[T] | T", np.sum( @@ -451,7 +457,7 @@ def sum( def to_hashable(ary: ObjectArray[ShapeT, T] | Hashable, /) -> Hashable: if isinstance(ary, ObjectArray): ary = cast("ObjectArray[ShapeT, T]", ary) - return tuple(ary.flat.tolist()) + return tuple(ary.flat) return ary