From a82fd1905f66c9d4f7cc27b33eef2707dc8c8888 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sun, 12 Jan 2025 13:56:18 -0600 Subject: [PATCH] Improve typing for permutations and gnitb --- pytools/__init__.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index 6eb69004..77fca6bf 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -1385,7 +1385,7 @@ def indices_in_shape(shape): def generate_nonnegative_integer_tuples_below( - n: tuple[int, ...] | int, length: int | None = None, least: int = 0 + n: Sequence[int] | int, length: int | None = None, least: int = 0 ) -> Iterator[tuple[int, ...]]: """n may be a sequence, in which case length must be None.""" if length is None: @@ -1477,7 +1477,10 @@ def generate_all_integer_tuples_below( n, length, least_abs)) -class _ConcatenableSequence(Protocol): +T_co = TypeVar("T_co", covariant=True) + + +class _ConcatenableSequence(Generic[T_co], Protocol): """ A protocol that supports the following: @@ -1494,10 +1497,13 @@ def __add__(self, other: Self) -> Self: def __len__(self) -> int: ... + def __iter__(self) -> Iterator[T_co]: + ... + def generate_permutations( - original: _ConcatenableSequence - ) -> Iterator[_ConcatenableSequence]: + original: _ConcatenableSequence[T] + ) -> Iterator[_ConcatenableSequence[T]]: """Generate all permutations of the list *original*. Nicked from http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/252178 @@ -1512,8 +1518,8 @@ def generate_permutations( def generate_unique_permutations( - original: _ConcatenableSequence - ) -> Iterator[_ConcatenableSequence]: + original: _ConcatenableSequence[T] + ) -> Iterator[_ConcatenableSequence[T]]: """Generate all unique permutations of the list *original*. Note that, unlike for :func:`generate_permutations`, *original* must be a