Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32,890 changes: 11,277 additions & 21,613 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

19 changes: 10 additions & 9 deletions loopy/expression.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from pymbolic import ArithmeticExpression


__copyright__ = "Copyright (C) 2012-15 Andreas Kloeckner"

Expand All @@ -24,11 +22,12 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Literal, TypeAlias, cast

import numpy as np

import pymbolic.primitives as p
from pymbolic import ArithmeticExpression
from pymbolic.mapper import Mapper

from loopy.codegen import UnvectorizableError
Expand All @@ -41,13 +40,15 @@
from loopy.symbolic import LinearSubscript, Reduction


# type_context may be:
# - "i" for integer -
# - "f" for single-precision floating point
# - "d" for double-precision floating point
# or None for 'no known context'.
TypeContext: TypeAlias = Literal[
"f", # single-precision floating point
"d", # double-precision floating point
"i", # integer
"b", # boolean
] | None # "no known context"


def dtype_to_type_context(target, dtype) -> str | None:
def dtype_to_type_context(target, dtype) -> TypeContext:
from loopy.types import NumpyType

if dtype.is_integral():
Expand Down
16 changes: 16 additions & 0 deletions loopy/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,22 @@ def tup_to_exprs(tup):

return tup_to_exprs(grid_size), tup_to_exprs(group_size)

@overload
def get_grid_size_upper_bounds(self,
callables_table: CallablesTable,
*,
ignore_auto: bool = ...,
return_dict: Literal[False] = ...
) -> tuple[tuple[isl.PwAff, ...], tuple[isl.PwAff, ...]]: ...

@overload
def get_grid_size_upper_bounds(self,
callables_table: CallablesTable,
*,
ignore_auto: bool = ...,
return_dict: Literal[True]
) -> tuple[dict[int, isl.PwAff], dict[int, isl.PwAff]]: ...

def get_grid_size_upper_bounds(self,
callables_table: CallablesTable,
ignore_auto: bool = False,
Expand Down
Loading
Loading