diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 5076a8e2..06a3fbbb 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -2931,326 +2931,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 11, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 12, - "endColumn": 20, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 12, - "endColumn": 20, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 22, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 4, - "endColumn": 6, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 14, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 8, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 24, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 43, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 8, - "endColumn": 11, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 13, - "endColumn": 17, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 12, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 12, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 15, - "endColumn": 42, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 11, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 11, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 12, - "endColumn": 20, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 12, - "endColumn": 20, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 22, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 4, - "endColumn": 6, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 14, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 8, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 24, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 43, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 8, - "endColumn": 11, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 13, - "endColumn": 17, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 12, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 12, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 15, - "endColumn": 42, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 11, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 10, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 11, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 11, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 11, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 29, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 10, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 11, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 11, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 11, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 29, - "endColumn": 37, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -4627,70 +4307,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 14, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 14, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 20, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 20, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 27, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 8, - "endColumn": 23, - "lineCount": 4 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 8, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 8, - "endColumn": 23, - "lineCount": 1 - } - }, { "code": "reportUnannotatedClassAttribute", "range": { @@ -5123,94 +4739,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 17, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 17, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 32, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 25, - "endColumn": 32, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 4, - "endColumn": 10, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 18, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 15, - "endColumn": 62, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 31, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 46, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 15, - "endColumn": 22, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 830585e9..1226c756 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,7 +55,7 @@ jobs: python-version: '3.x' - name: "Main Script" run: | - EXTRA_INSTALL="numpy pymbolic orderedsets" + EXTRA_INSTALL="numpy pymbolic orderedsets optype" curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-pylint.sh . ./prepare-and-run-pylint.sh "$(basename $GITHUB_REPOSITORY)" @@ -68,13 +68,15 @@ jobs: python-version: '3.x' - name: "Main Script" run: | + EXTRA_INSTALL="basedpyright numpy attrs orderedsets pytest mpi4py matplotlib optype" + + sudo apt update + sudo apt -y install libopenmpi-dev + curl -L -O https://tiker.net/ci-support-v0 . ./ci-support-v0 + build_py_project_in_venv - sudo apt update - sudo apt -y install libopenmpi-dev - pip install numpy attrs orderedsets pytest mpi4py matplotlib - pip install basedpyright basedpyright pytest: @@ -163,7 +165,7 @@ jobs: python-version: '3.x' - name: "Main Script" run: | - EXTRA_INSTALL="numpy" + EXTRA_INSTALL="numpy optype" curl -L -O https://tiker.net/ci-support-v0 . ci-support-v0 build_py_project_in_venv diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index ece4d8db..fa8b961e 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -42,30 +42,30 @@ Pytest without Numpy: # - tags Ruff: - script: - - pipx install ruff - - ruff check + script: | + pipx install ruff + ruff check tags: - docker-runner except: - tags Pylint: - script: - - EXTRA_INSTALL="numpy pymbolic orderedsets siphash24" - - py_version=3 - - curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-pylint.sh - - . ./prepare-and-run-pylint.sh "$CI_PROJECT_NAME" + script: | + EXTRA_INSTALL="numpy pymbolic orderedsets siphash24 optype" + py_version=3 + curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-pylint.sh + . ./prepare-and-run-pylint.sh "$CI_PROJECT_NAME" tags: - python3 except: - tags Documentation: - script: - - EXTRA_INSTALL="numpy siphash24" - - curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-docs.sh - - ". ./build-docs.sh" + script: | + EXTRA_INSTALL="numpy siphash24 optype" + curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-docs.sh + . ./build-docs.sh tags: - python3 diff --git a/doc/conf.py b/doc/conf.py index e06e9ddc..676da1b9 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -35,6 +35,10 @@ "NodeT": "pytools.graph.NodeT", } +nitpick_ignore_regex = [ + ["py:class", r"optype.*"], +] + sphinxconfig_missing_reference_aliases = { # numpy typing "NDArray": "obj:numpy.typing.NDArray", @@ -42,10 +46,7 @@ "np.ndarray": "class:numpy.ndarray", "np.floating": "class:numpy.floating", # pytools typing - "BoundingBox": "obj:pytools.spatial_btree.BoundingBox", - "Element": "obj:pytools.spatial_btree.Element", "ObjectArray1D": "obj:pytools.obj_array.ObjectArray1D", - "Point": "obj:pytools.spatial_btree.Point", "ReadableBuffer": "data:pytools.ReadableBuffer", } diff --git a/pyproject.toml b/pyproject.toml index cf3afcb5..8823af4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ numpy = [ ] test = [ "basedpyright", + "optype", "pytest", "ruff", ] diff --git a/pytools/__init__.py b/pytools/__init__.py index cdd5a2af..1f3035fb 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -51,10 +51,12 @@ ClassVar, Concatenate, Generic, + Literal, ParamSpec, Protocol, TypeVar, cast, + overload, ) from warnings import warn @@ -69,7 +71,8 @@ import numpy as np from _typeshed import ReadableBuffer - from numpy.typing import NDArray + from numpy.typing import DTypeLike, NDArray + from optype import CanLt from typing_extensions import Self @@ -260,6 +263,10 @@ .. class:: ReadableBuffer Anything that implements the read-write buffer interface. + +.. class:: SupportsLessThanT + + An invariant :class:`typing.TypeVar` bound to :class:`optype.CanLt`. """ # {{{ type variables @@ -273,6 +280,7 @@ K = TypeVar("K") V = TypeVar("V") EmptyT = TypeVar("EmptyT") +SupportsLessThanT = TypeVar("SupportsLessThanT", bound="CanLt[Any]") # }}} @@ -1293,12 +1301,26 @@ def find_max_where(predicate, prec=1e-5, initial_guess=1, fail_bound=1e38): # {{{ argmin, argmax -def argmin2(iterable, return_value=False): +class SupportsLessThan(Protocol[T_contra]): + def __lt__(self, other: T_contra, /) -> bool: ... + + +@overload +def argmin2(iterable: Iterable[tuple[T, SupportsLessThanT]], + return_value: Literal[True]) -> tuple[T, SupportsLessThanT]: ... + +@overload +def argmin2(iterable: Iterable[tuple[T, SupportsLessThanT]], + return_value: Literal[False] = False) -> T: ... + + +def argmin2(iterable: Iterable[tuple[T, SupportsLessThanT]], + return_value: bool = False) -> T | tuple[T, SupportsLessThanT]: it = iter(iterable) try: current_argmin, current_min = next(it) except StopIteration: - raise ValueError("argmin of empty iterable") from None + raise ValueError("argmin() iterable argument is empty") from None for arg, item in it: if item < current_min: @@ -1307,15 +1329,26 @@ def argmin2(iterable, return_value=False): if return_value: return current_argmin, current_min + return current_argmin -def argmax2(iterable, return_value=False): +@overload +def argmax2(iterable: Iterable[tuple[T, SupportsLessThanT]], + return_value: Literal[True]) -> tuple[T, SupportsLessThanT]: ... + +@overload +def argmax2(iterable: Iterable[tuple[T, SupportsLessThanT]], + return_value: Literal[False] = False) -> T: ... + + +def argmax2(iterable: Iterable[tuple[T, SupportsLessThanT]], + return_value: bool = False) -> T | tuple[T, SupportsLessThanT]: it = iter(iterable) try: current_argmax, current_max = next(it) except StopIteration: - raise ValueError("argmax of empty iterable") from None + raise ValueError("argmax() iterable argument is empty") from None for arg, item in it: if item > current_max: @@ -1324,15 +1357,16 @@ def argmax2(iterable, return_value=False): if return_value: return current_argmax, current_max + return current_argmax -def argmin(iterable): - return argmin2(enumerate(iterable)) +def argmin(iterable: Iterable[SupportsLessThanT]) -> int: + return argmin2(enumerate(iterable), return_value=False) -def argmax(iterable): - return argmax2(enumerate(iterable)) +def argmax(iterable: Iterable[SupportsLessThanT]) -> int: + return argmax2(enumerate(iterable), return_value=False) # }}} @@ -2011,12 +2045,12 @@ def format_bar(cnt): # }}} -def word_wrap(text, width, wrap_using="\n"): +def word_wrap(text: str, width: int, wrap_using: str = "\n") -> str: # http://code.activestate.com/recipes/148061-one-liner-word-wrap-function/ r""" A word-wrap function that preserves existing line breaks and most spaces in the text. Expects that existing line - breaks are posix newlines (``\n``). + breaks are POSIX newlines (``\n``). """ space_or_break = [" ", wrap_using] return reduce(lambda line, word: "{}{}{}".format( @@ -2239,14 +2273,18 @@ def add_python_path_relative_to_script(rel_path: str) -> None: # {{{ numpy dtype mangling -def common_dtype(dtypes, default=None): - dtypes = list(dtypes) - if dtypes: - return argmax2((dtype, dtype.num) for dtype in dtypes) +def common_dtype(dtypes: Iterator[DTypeLike], + default: DTypeLike = None) -> np.dtype[Any]: + import numpy as np + + ddtypes = [np.dtype(dtype) for dtype in dtypes] + if ddtypes: + return argmax2((dtype, dtype.num) for dtype in ddtypes) + if default is not None: - return default - raise ValueError( - "cannot find common dtype of empty dtype list") + return np.dtype(default) + + raise ValueError("cannot find common dtype of empty dtype list") def to_uncomplex_dtype(dtype):