From 5ab764aee2c4b534a7a7beceb69f866baba3ffe6 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sun, 29 Dec 2024 11:10:12 +0200 Subject: [PATCH] mypy: type gnitsam functions --- pytools/__init__.py | 79 ++++++++++++++++++++++++++++-------- pytools/test/test_pytools.py | 16 ++++++++ 2 files changed, 79 insertions(+), 16 deletions(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index 06887a49..6eb69004 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -32,7 +32,15 @@ import operator import re import sys -from collections.abc import Callable, Collection, Hashable, Iterable, Mapping, Sequence +from collections.abc import ( + Callable, + Collection, + Hashable, + Iterable, + Iterator, + Mapping, + Sequence, +) from functools import reduce, wraps from sys import intern from typing import ( @@ -41,10 +49,11 @@ Concatenate, Generic, ParamSpec, + Protocol, TypeVar, ) -from typing_extensions import dataclass_transform +from typing_extensions import Self, dataclass_transform # These are deprecated and will go away in 2022. @@ -102,6 +111,7 @@ .. autofunction:: generate_all_integer_tuples_below .. autofunction:: generate_permutations .. autofunction:: generate_unique_permutations +.. autoclass:: _ConcatenableSequence Formatting ---------- @@ -1347,7 +1357,9 @@ def std_deviation(iterable, finite_pop): # {{{ permutations, tuples, integer sequences -def wandering_element(length, wanderer=1, landscape=0): +def wandering_element(length: int, + wanderer: int = 1, + landscape: int = 0) -> Iterator[tuple[int, ...]]: for i in range(length): yield i*(landscape,) + (wanderer,) + (length-1-i)*(landscape,) @@ -1372,9 +1384,12 @@ def indices_in_shape(shape): yield (i, *rest) -def generate_nonnegative_integer_tuples_below(n, length=None, least=0): +def generate_nonnegative_integer_tuples_below( + n: tuple[int, ...] | int, length: int | None = None, least: int = 0 + ) -> Iterator[tuple[int, ...]]: """n may be a sequence, in which case length must be None.""" if length is None: + assert not isinstance(n, int) if not n: yield () return @@ -1383,6 +1398,7 @@ def generate_nonnegative_integer_tuples_below(n, length=None, least=0): n = n[1:] next_length = None else: + assert isinstance(n, int) my_n = n assert length >= 0 @@ -1399,12 +1415,12 @@ def generate_nonnegative_integer_tuples_below(n, length=None, least=0): def generate_decreasing_nonnegative_tuples_summing_to( - n, length, min_value=0, max_value=None): + n: int, length: int, min_value: int = 0, max_value: int | None = None + ) -> Iterator[tuple[int, ...]]: if length == 0: yield () elif length == 1: - if n <= max_value: - # print "MX", n, max_value + if max_value is None or n <= max_value: yield (n,) else: return @@ -1412,14 +1428,14 @@ def generate_decreasing_nonnegative_tuples_summing_to( if max_value is None or n < max_value: max_value = n - for i in range(min_value, max_value+1): - # print "SIG", sig, i + for i in range(min_value, max_value + 1): for remainder in generate_decreasing_nonnegative_tuples_summing_to( - n-i, length-1, min_value, i): + n - i, length - 1, min_value=min_value, max_value=i): yield (i, *remainder) -def generate_nonnegative_integer_tuples_summing_to_at_most(n, length): +def generate_nonnegative_integer_tuples_summing_to_at_most( + n: int, length: int) -> Iterator[tuple[int, ...]]: """Enumerate all non-negative integer tuples summing to at most n, exhausting the search space by varying the first entry fastest, and the last entry the slowest. @@ -1438,24 +1454,50 @@ def generate_nonnegative_integer_tuples_summing_to_at_most(n, length): generate_positive_integer_tuples_below = generate_nonnegative_integer_tuples_below -def _pos_and_neg_adaptor(tuple_iter): +def _pos_and_neg_adaptor( + tuple_iter: Iterator[tuple[int, ...]] + ) -> Iterator[tuple[int, ...]]: for tup in tuple_iter: nonzero_indices = [i for i in range(len(tup)) if tup[i] != 0] for do_neg_tup in generate_nonnegative_integer_tuples_below( 2, len(nonzero_indices)): + this_result = list(tup) for index, do_neg in enumerate(do_neg_tup): if do_neg: this_result[nonzero_indices[index]] *= -1 + yield tuple(this_result) -def generate_all_integer_tuples_below(n, length, least_abs=0): +def generate_all_integer_tuples_below( + n: int, length: int, least_abs: int = 0 + ) -> Iterator[tuple[int, ...]]: return _pos_and_neg_adaptor(generate_nonnegative_integer_tuples_below( n, length, least_abs)) -def generate_permutations(original): +class _ConcatenableSequence(Protocol): + """ + A protocol that supports the following: + + .. automethod:: __getitem__ + .. automethod:: __add__ + .. automethod:: __len__ + """ + def __getitem__(self, slice) -> Self: + ... + + def __add__(self, other: Self) -> Self: + ... + + def __len__(self) -> int: + ... + + +def generate_permutations( + original: _ConcatenableSequence + ) -> Iterator[_ConcatenableSequence]: """Generate all permutations of the list *original*. Nicked from http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/252178 @@ -1465,12 +1507,17 @@ def generate_permutations(original): else: for perm_ in generate_permutations(original[1:]): for i in range(len(perm_)+1): - # nb str[0:1] works in both string and list contexts + # NOTE: ary[0:1] works in both string and list contexts yield perm_[:i] + original[0:1] + perm_[i:] -def generate_unique_permutations(original): +def generate_unique_permutations( + original: _ConcatenableSequence + ) -> Iterator[_ConcatenableSequence]: """Generate all unique permutations of the list *original*. + + Note that, unlike for :func:`generate_permutations`, *original* must be a + hashable object. """ had_those = set() diff --git a/pytools/test/test_pytools.py b/pytools/test/test_pytools.py index 81247220..a7e0b4f1 100644 --- a/pytools/test/test_pytools.py +++ b/pytools/test/test_pytools.py @@ -871,6 +871,22 @@ class ImmutableRecordWithUnsetSlots(ImmutableRecord): # }}} +def test_permutations(): + from math import factorial + + from pytools import generate_permutations, generate_unique_permutations + + perm = list(generate_permutations([1, 2, 3, 4])) + assert len(perm) == factorial(4) + perm = list(generate_unique_permutations((1, 3, 3, 4))) + assert len(perm) == 12 + + perms = list(generate_permutations("1234")) + assert len(perms) == factorial(4) + perms = list(generate_unique_permutations("1334")) + assert len(perms) == 12 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])