diff --git a/pytools/__init__.py b/pytools/__init__.py index 6eb69004..77fca6bf 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -1385,7 +1385,7 @@ def indices_in_shape(shape): def generate_nonnegative_integer_tuples_below( - n: tuple[int, ...] | int, length: int | None = None, least: int = 0 + n: Sequence[int] | int, length: int | None = None, least: int = 0 ) -> Iterator[tuple[int, ...]]: """n may be a sequence, in which case length must be None.""" if length is None: @@ -1477,7 +1477,10 @@ def generate_all_integer_tuples_below( n, length, least_abs)) -class _ConcatenableSequence(Protocol): +T_co = TypeVar("T_co", covariant=True) + + +class _ConcatenableSequence(Generic[T_co], Protocol): """ A protocol that supports the following: @@ -1494,10 +1497,13 @@ def __add__(self, other: Self) -> Self: def __len__(self) -> int: ... + def __iter__(self) -> Iterator[T_co]: + ... + def generate_permutations( - original: _ConcatenableSequence - ) -> Iterator[_ConcatenableSequence]: + original: _ConcatenableSequence[T] + ) -> Iterator[_ConcatenableSequence[T]]: """Generate all permutations of the list *original*. Nicked from http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/252178 @@ -1512,8 +1518,8 @@ def generate_permutations( def generate_unique_permutations( - original: _ConcatenableSequence - ) -> Iterator[_ConcatenableSequence]: + original: _ConcatenableSequence[T] + ) -> Iterator[_ConcatenableSequence[T]]: """Generate all unique permutations of the list *original*. Note that, unlike for :func:`generate_permutations`, *original* must be a