Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 63 additions & 16 deletions pytools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@
import operator
import re
import sys
from collections.abc import Callable, Collection, Hashable, Iterable, Mapping, Sequence
from collections.abc import (
Callable,
Collection,
Hashable,
Iterable,
Iterator,
Mapping,
Sequence,
)
from functools import reduce, wraps
from sys import intern
from typing import (
Expand All @@ -41,10 +49,11 @@
Concatenate,
Generic,
ParamSpec,
Protocol,
TypeVar,
)

from typing_extensions import dataclass_transform
from typing_extensions import Self, dataclass_transform


# These are deprecated and will go away in 2022.
Expand Down Expand Up @@ -102,6 +111,7 @@
.. autofunction:: generate_all_integer_tuples_below
.. autofunction:: generate_permutations
.. autofunction:: generate_unique_permutations
.. autoclass:: _ConcatenableSequence

Formatting
----------
Expand Down Expand Up @@ -1347,7 +1357,9 @@ def std_deviation(iterable, finite_pop):

# {{{ permutations, tuples, integer sequences

def wandering_element(length, wanderer=1, landscape=0):
def wandering_element(length: int,
wanderer: int = 1,
landscape: int = 0) -> Iterator[tuple[int, ...]]:
for i in range(length):
yield i*(landscape,) + (wanderer,) + (length-1-i)*(landscape,)

Expand All @@ -1372,9 +1384,12 @@ def indices_in_shape(shape):
yield (i, *rest)


def generate_nonnegative_integer_tuples_below(n, length=None, least=0):
def generate_nonnegative_integer_tuples_below(
n: tuple[int, ...] | int, length: int | None = None, least: int = 0
) -> Iterator[tuple[int, ...]]:
"""n may be a sequence, in which case length must be None."""
if length is None:
assert not isinstance(n, int)
if not n:
yield ()
return
Expand All @@ -1383,6 +1398,7 @@ def generate_nonnegative_integer_tuples_below(n, length=None, least=0):
n = n[1:]
next_length = None
else:
assert isinstance(n, int)
my_n = n

assert length >= 0
Expand All @@ -1399,27 +1415,27 @@ def generate_nonnegative_integer_tuples_below(n, length=None, least=0):


def generate_decreasing_nonnegative_tuples_summing_to(
n, length, min_value=0, max_value=None):
n: int, length: int, min_value: int = 0, max_value: int | None = None
) -> Iterator[tuple[int, ...]]:
if length == 0:
yield ()
elif length == 1:
if n <= max_value:
# print "MX", n, max_value
if max_value is None or n <= max_value:
yield (n,)
else:
return
else:
if max_value is None or n < max_value:
max_value = n

for i in range(min_value, max_value+1):
# print "SIG", sig, i
for i in range(min_value, max_value + 1):
for remainder in generate_decreasing_nonnegative_tuples_summing_to(
n-i, length-1, min_value, i):
n - i, length - 1, min_value=min_value, max_value=i):
yield (i, *remainder)


def generate_nonnegative_integer_tuples_summing_to_at_most(n, length):
def generate_nonnegative_integer_tuples_summing_to_at_most(
n: int, length: int) -> Iterator[tuple[int, ...]]:
"""Enumerate all non-negative integer tuples summing to at most n,
exhausting the search space by varying the first entry fastest,
and the last entry the slowest.
Expand All @@ -1438,24 +1454,50 @@ def generate_nonnegative_integer_tuples_summing_to_at_most(n, length):
generate_positive_integer_tuples_below = generate_nonnegative_integer_tuples_below


def _pos_and_neg_adaptor(tuple_iter):
def _pos_and_neg_adaptor(
tuple_iter: Iterator[tuple[int, ...]]
) -> Iterator[tuple[int, ...]]:
for tup in tuple_iter:
nonzero_indices = [i for i in range(len(tup)) if tup[i] != 0]
for do_neg_tup in generate_nonnegative_integer_tuples_below(
2, len(nonzero_indices)):

this_result = list(tup)
for index, do_neg in enumerate(do_neg_tup):
if do_neg:
this_result[nonzero_indices[index]] *= -1

yield tuple(this_result)


def generate_all_integer_tuples_below(n, length, least_abs=0):
def generate_all_integer_tuples_below(
n: int, length: int, least_abs: int = 0
) -> Iterator[tuple[int, ...]]:
return _pos_and_neg_adaptor(generate_nonnegative_integer_tuples_below(
n, length, least_abs))


def generate_permutations(original):
class _ConcatenableSequence(Protocol):
"""
A protocol that supports the following:

.. automethod:: __getitem__
.. automethod:: __add__
.. automethod:: __len__
"""
def __getitem__(self, slice) -> Self:
...

def __add__(self, other: Self) -> Self:
...

def __len__(self) -> int:
...


def generate_permutations(
original: _ConcatenableSequence
) -> Iterator[_ConcatenableSequence]:
"""Generate all permutations of the list *original*.

Nicked from http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/252178
Expand All @@ -1465,12 +1507,17 @@ def generate_permutations(original):
else:
for perm_ in generate_permutations(original[1:]):
for i in range(len(perm_)+1):
# nb str[0:1] works in both string and list contexts
# NOTE: ary[0:1] works in both string and list contexts
yield perm_[:i] + original[0:1] + perm_[i:]


def generate_unique_permutations(original):
def generate_unique_permutations(
original: _ConcatenableSequence
) -> Iterator[_ConcatenableSequence]:
"""Generate all unique permutations of the list *original*.

Note that, unlike for :func:`generate_permutations`, *original* must be a
hashable object.
"""

had_those = set()
Expand Down
16 changes: 16 additions & 0 deletions pytools/test/test_pytools.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,22 @@ class ImmutableRecordWithUnsetSlots(ImmutableRecord):
# }}}


def test_permutations():
from math import factorial

from pytools import generate_permutations, generate_unique_permutations

perm = list(generate_permutations([1, 2, 3, 4]))
assert len(perm) == factorial(4)
perm = list(generate_unique_permutations((1, 3, 3, 4)))
assert len(perm) == 12

perms = list(generate_permutations("1234"))
assert len(perms) == factorial(4)
perms = list(generate_unique_permutations("1334"))
assert len(perms) == 12


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down
Loading