Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,288 changes: 96 additions & 1,192 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions doc/misc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ Cross-References for Other Documentation
The type of the Python-builtin :data:`Ellipsis` object. (not otherwise
documented)

.. currentmodule:: prim

.. class:: NaN

See :class:`pymbolic.primitives.NaN`.

.. currentmodule:: loopy.kernel

.. class:: LoopKernel
Expand Down
89 changes: 70 additions & 19 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,15 @@
.. autoclass:: DictOfNamedArrays
.. autoclass:: AbstractResultWithNamedArrays

.. currentmodule:: pytato.array

.. autoclass:: ArrayOrScalarT

NumPy-Like Interface
--------------------

.. currentmodule:: pytato

These functions generally follow the interface of the corresponding functions in
:mod:`numpy`, but not all NumPy features may be supported.

Expand Down Expand Up @@ -216,25 +222,34 @@
INT_CLASSES,
SCALAR_CLASSES,
ScalarExpression,
TypeCast,
get_reduction_induction_variables,
)


if TYPE_CHECKING:
from numpy.typing import DTypeLike
_dtype_any = np.dtype[Any]
else:
_dtype_any = np.dtype

# {{{ typing helpers

AxesT = tuple["Axis", ...]
ArrayT = TypeVar("ArrayT", bound="Array")

ArrayOrScalar: TypeAlias = "Array | Scalar"

# {{{ shape
ArrayOrScalarT = TypeVar("ArrayOrScalarT", "Array", Scalar, ArrayOrScalar)

ShapeComponent = Union[Integer, "Array"]
ShapeType = tuple[ShapeComponent, ...]
ConvertibleToShape = ShapeComponent | Sequence[ShapeComponent]
ShapeComponent: TypeAlias = "Integer | Array"
ShapeType: TypeAlias = tuple[ShapeComponent, ...]
ConvertibleToShape: TypeAlias = "ShapeComponent | Sequence[ShapeComponent]"

# }}}


# {{{ shape

def _check_identifier(s: str | None, optional: bool) -> bool:
if s is None:
Expand Down Expand Up @@ -898,7 +913,7 @@ def __xor__(self, other: ArrayOrScalar) -> Array:
def __rxor__(self, other: ArrayOrScalar) -> Array:
return self._binary_op(operator.xor, other, reverse=True)

def conj(self) -> ArrayOrScalar:
def conj(self) -> Array:
import pytato as pt
return pt.conj(self)

Expand All @@ -913,15 +928,32 @@ def __bool__(self) -> None:
raise ValueError("The truth value of an array expression is undefined.")

@property
def real(self) -> ArrayOrScalar:
def real(self) -> Array:
import pytato as pt
return pt.real(self)

@property
def imag(self) -> ArrayOrScalar:
def imag(self) -> Array:
import pytato as pt
return pt.imag(self)

def astype(self, dtype: DTypeLike) -> Array:
dtype = np.dtype(dtype)
if self.dtype.kind in ["f", "c"] and dtype.kind in ["i", "u"]:
raise NotImplementedError("numpy-like overflow behavior in float-to-int")
if self.dtype.kind == "c" and dtype.kind in ["i", "u", "f"]:
raise NotImplementedError("complex-to-real casts fail in loopy")

from pymbolic import var
return make_index_lambda(
TypeCast(dtype, var("in_0")[
tuple(var(f"_{i}") for i in range(self.ndim))
]),
bindings={"in_0": self},
shape=self.shape,
dtype=dtype,
)

def reshape(self, *shape: int | Sequence[int], order: str = "C") -> Array:
import pytato as pt
if len(shape) == 0:
Expand All @@ -934,14 +966,18 @@ def reshape(self, *shape: int | Sequence[int], order: str = "C") -> Array:
# expected "Union[int, Sequence[int]]"
return pt.reshape(self, shape, order=order) # type: ignore[arg-type]

def all(self, axis: int = 0) -> ArrayOrScalar:
def all(self,
axis: int | tuple[int, ...] | None = None,
) -> ArrayOrScalar:
"""
Equivalent to :func:`pytato.all`.
"""
import pytato as pt
return pt.all(self, axis)

def any(self, axis: int = 0) -> ArrayOrScalar:
def any(self,
axis: int | tuple[int, ...] | None = None,
) -> ArrayOrScalar:
"""
Equivalent to :func:`pytato.any`.
"""
Expand All @@ -965,9 +1001,6 @@ def __repr__(self) -> str:
from pytato.stringifier import Reprifier
return Reprifier()(self)


ArrayOrScalar: TypeAlias = Array | Scalar

# }}}


Expand Down Expand Up @@ -2760,7 +2793,17 @@ def greater_equal(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool:

# {{{ logical operations

def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool:
@overload
def logical_or(x1: Scalar, x2: Scalar, /) -> bool: ...

@overload
def logical_or(x1: ArrayOrScalar, x2: Array, /) -> Array: ...

@overload
def logical_or(x1: Array, x2: ArrayOrScalar, /) -> Array: ...


def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar, /) -> Array | bool:
"""
Returns the element-wise logical OR of *x1* and *x2*.
"""
Expand All @@ -2778,6 +2821,16 @@ def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool:
) # type: ignore[return-value]


@overload
def logical_and(x1: Scalar, x2: Scalar, /) -> bool: ...

@overload
def logical_and(x1: ArrayOrScalar, x2: Array, /) -> Array: ...

@overload
def logical_and(x1: Array, x2: ArrayOrScalar, /) -> Array: ...


def logical_and(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool:
"""
Returns the element-wise logical AND of *x1* and *x2*.
Expand Down Expand Up @@ -2890,8 +2943,7 @@ def maximum(x1: ArrayOrScalar, x2: ArrayOrScalar) -> ArrayOrScalar:
or np.issubdtype(common_dtype, np.complexfloating)):
from pytato.cmath import isnan
return where(logical_or(isnan(x1), isnan(x2)),
# I don't know why pylint thinks common_dtype is a tuple.
common_dtype.type(np.nan), # pylint: disable=no-member
common_dtype.type(np.nan),
where(greater(x1, x2), x1, x2))
else:
return where(greater(x1, x2), x1, x2)
Expand All @@ -2909,8 +2961,7 @@ def minimum(x1: ArrayOrScalar, x2: ArrayOrScalar) -> ArrayOrScalar:
or np.issubdtype(common_dtype, np.complexfloating)):
from pytato.cmath import isnan
return where(logical_or(isnan(x1), isnan(x2)),
# I don't know why pylint thinks common_dtype is a tuple.
common_dtype.type(np.nan), # pylint: disable=no-member
common_dtype.type(np.nan),
where(less(x1, x2), x1, x2))
else:
return where(less(x1, x2), x1, x2)
Expand All @@ -2924,7 +2975,7 @@ def make_index_lambda(
expression: str | ScalarExpression,
bindings: Mapping[str, Array],
shape: ShapeType,
dtype: Any,
dtype: DTypeLike,
var_to_reduction_descr: Mapping[str, ReductionDescriptor] | None = None
) -> IndexLambda:
if isinstance(expression, str):
Expand Down Expand Up @@ -2966,7 +3017,7 @@ def make_index_lambda(
return IndexLambda(expr=expression,
bindings=immutabledict(bindings),
shape=shape,
dtype=dtype,
dtype=np.dtype(dtype),
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(),
axes=_get_default_axes(len(shape)),
Expand Down
55 changes: 28 additions & 27 deletions pytato/cmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@
from immutabledict import immutabledict

import pymbolic.primitives as prim
from pymbolic import Scalar, var
from pymbolic import var

from pytato.array import (
Array,
ArrayOrScalar,
ArrayOrScalarT,
IndexLambda,
_dtype_any,
_get_created_at_tag,
Expand All @@ -81,17 +82,17 @@
from pymbolic.typing import Expression


def _apply_elem_wise_func(inputs: tuple[ArrayOrScalar, ...],
def _apply_elem_wise_func(inputs: tuple[ArrayOrScalarT, ...],
func_name: str,
ret_dtype: _dtype_any | None = None,
np_func_name: str | None = None
) -> ArrayOrScalar:
) -> ArrayOrScalarT:
if all(isinstance(x, SCALAR_CLASSES) for x in inputs):
if np_func_name is None:
np_func_name = func_name

np_func = getattr(np, np_func_name)
return cast("ArrayOrScalar", np_func(*inputs))
return cast("ArrayOrScalarT", np_func(*inputs))

if not inputs:
raise ValueError("at least one argument must be present")
Expand Down Expand Up @@ -126,15 +127,15 @@ def _apply_elem_wise_func(inputs: tuple[ArrayOrScalar, ...],
assert shape is not None
assert ret_dtype is not None

return IndexLambda(
return cast("ArrayOrScalarT", IndexLambda(
expr=prim.Call(var(f"pytato.c99.{func_name}"),
tuple(sym_args)),
shape=shape, dtype=ret_dtype, bindings=immutabledict(bindings),
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(stacklevel=2),
axes=_get_default_axes(len(shape)),
var_to_reduction_descr=immutabledict(),
)
))


def _get_dtype(x: ArrayOrScalar) -> _dtype_any:
Expand All @@ -147,7 +148,7 @@ def _get_dtype(x: ArrayOrScalar) -> _dtype_any:

# FIXME: Overload these instead of returning union type?

def abs(x: ArrayOrScalar) -> ArrayOrScalar:
def abs(x: ArrayOrScalarT) -> ArrayOrScalarT:
x_dtype = _get_dtype(x)
if x_dtype.kind == "c":
result_dtype = np.empty(0, dtype=x_dtype).real.dtype
Expand All @@ -157,73 +158,73 @@ def abs(x: ArrayOrScalar) -> ArrayOrScalar:
return _apply_elem_wise_func((x,), "abs", ret_dtype=result_dtype)


def sqrt(x: ArrayOrScalar) -> ArrayOrScalar:
def sqrt(x: ArrayOrScalarT) -> ArrayOrScalarT:
return _apply_elem_wise_func((x,), "sqrt")


def sin(x: ArrayOrScalar) -> ArrayOrScalar:
def sin(x: ArrayOrScalarT) -> ArrayOrScalarT:
return _apply_elem_wise_func((x,), "sin")


def cos(x: ArrayOrScalar) -> ArrayOrScalar:
def cos(x: ArrayOrScalarT) -> ArrayOrScalarT:
return _apply_elem_wise_func((x,), "cos")


def tan(x: ArrayOrScalar) -> ArrayOrScalar:
def tan(x: ArrayOrScalarT) -> ArrayOrScalarT:
return _apply_elem_wise_func((x,), "tan")


def arcsin(x: ArrayOrScalar) -> ArrayOrScalar:
def arcsin(x: ArrayOrScalarT) -> ArrayOrScalarT:
return _apply_elem_wise_func((x,), "asin", np_func_name="arcsin")


def arccos(x: ArrayOrScalar) -> ArrayOrScalar:
def arccos(x: ArrayOrScalarT) -> ArrayOrScalarT:
return _apply_elem_wise_func((x,), "acos", np_func_name="arccos")


def arctan(x: ArrayOrScalar) -> ArrayOrScalar:
def arctan(x: ArrayOrScalarT) -> ArrayOrScalarT:
return _apply_elem_wise_func((x,), "atan", np_func_name="arctan")


def conj(x: ArrayOrScalar) -> ArrayOrScalar:
def conj(x: ArrayOrScalarT) -> ArrayOrScalarT:
if _get_dtype(x).kind != "c":
return x
return _apply_elem_wise_func((x,), "conj")


def arctan2(y: ArrayOrScalar, x: ArrayOrScalar) -> ArrayOrScalar:
def arctan2(y: ArrayOrScalarT, x: ArrayOrScalarT) -> ArrayOrScalarT:
return _apply_elem_wise_func((y, x), "atan2", np_func_name="arctan2")


def sinh(x: ArrayOrScalar) -> ArrayOrScalar:
def sinh(x: ArrayOrScalarT) -> ArrayOrScalarT:
return _apply_elem_wise_func((x,), "sinh")


def cosh(x: ArrayOrScalar) -> ArrayOrScalar:
def cosh(x: ArrayOrScalarT) -> ArrayOrScalarT:
return _apply_elem_wise_func((x,), "cosh")


def tanh(x: ArrayOrScalar) -> ArrayOrScalar:
def tanh(x: ArrayOrScalarT) -> ArrayOrScalarT:
return _apply_elem_wise_func((x,), "tanh")


def exp(x: ArrayOrScalar) -> ArrayOrScalar:
def exp(x: ArrayOrScalarT) -> ArrayOrScalarT:
return _apply_elem_wise_func((x,), "exp")


def log(x: ArrayOrScalar) -> ArrayOrScalar:
def log(x: ArrayOrScalarT) -> ArrayOrScalarT:
return _apply_elem_wise_func((x,), "log")


def log10(x: ArrayOrScalar) -> ArrayOrScalar:
def log10(x: ArrayOrScalarT) -> ArrayOrScalarT:
return _apply_elem_wise_func((x,), "log10")


def isnan(x: ArrayOrScalar) -> ArrayOrScalar:
def isnan(x: ArrayOrScalarT) -> ArrayOrScalarT:
return _apply_elem_wise_func((x,), "isnan", np.dtype(np.int32))


def real(x: ArrayOrScalar) -> ArrayOrScalar:
def real(x: ArrayOrScalarT) -> ArrayOrScalarT:
x_dtype = _get_dtype(x)
if x_dtype.kind == "c":
result_dtype = np.empty(0, dtype=x_dtype).real.dtype
Expand All @@ -232,17 +233,17 @@ def real(x: ArrayOrScalar) -> ArrayOrScalar:
return _apply_elem_wise_func((x,), "real", ret_dtype=result_dtype)


def imag(x: ArrayOrScalar) -> ArrayOrScalar:
def imag(x: ArrayOrScalarT) -> ArrayOrScalarT:
x_dtype = _get_dtype(x)
if x_dtype.kind == "c":
result_dtype = np.empty(0, dtype=x_dtype).real.dtype
else:
if np.isscalar(x):
return cast("Scalar", x_dtype.type(0))
return cast("ArrayOrScalarT", x_dtype.type(0))
else:
assert isinstance(x, Array)
import pytato as pt
return pt.zeros(x.shape, dtype=x_dtype)
return cast("ArrayOrScalarT", pt.zeros(x.shape, dtype=x_dtype))
return _apply_elem_wise_func((x,), "imag", ret_dtype=result_dtype)

# vim: fdm=marker
Loading
Loading