diff --git a/formulaic/formula.py b/formulaic/formula.py index 939ce58..c3f1bbf 100644 --- a/formulaic/formula.py +++ b/formulaic/formula.py @@ -1,8 +1,8 @@ from __future__ import annotations -import sys from abc import ABCMeta, abstractmethod -from collections.abc import Generator, Iterable, Mapping, MutableSequence +from collections import Counter +from collections.abc import Generator, Iterable, Mapping from enum import Enum from typing import ( Any, @@ -11,7 +11,6 @@ TypeVar, Union, cast, - overload, ) from typing_extensions import Self, TypeAlias @@ -21,9 +20,10 @@ from .errors import FormulaInvalidError from .model_matrix import ModelMatrix from .parser import DefaultFormulaParser -from .parser.types import FormulaParser, OrderedSet, Term +from .parser.types import FormulaParser, Term from .utils.calculus import differentiate_term from .utils.deprecations import deprecated +from .utils.ordered_set import OrderedSet from .utils.structured import Structured from .utils.variables import Variable, get_expression_variables @@ -329,12 +329,12 @@ def differentiate( class SimpleFormula( - MutableSequence[Term] if sys.version_info >= (3, 9) else MutableSequence, # type: ignore + OrderedSet[Term], Formula, ): """ The atomic component of all formulae represented by Formulaic, which in turn - is a mutable sequence of `Term` instances. `StructuredFormula` uses + is a mutable ordered set of `Term` instances. `StructuredFormula` uses `SimpleFormula` as its nodes. Instances of this class can be used directly as a mutable sequence of @@ -354,7 +354,9 @@ class SimpleFormula( def __init__( self, - root: Union[Iterable[Term], MissingType] = MISSING, + root: Union[ + Iterable[Term], Mapping[Term, int], OrderedSet[Term], MissingType + ] = MISSING, *, _ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE, _parser: Optional[FormulaParser] = None, @@ -375,13 +377,17 @@ def __init__( "`SimpleFormula` does not support nested structure. To create a " "structured formula, use `StructuredFormula` instead." ) - self.__terms = list(root) self.ordering = OrderingMethod(_ordering) - self.__validate_terms(self.__terms) + self.__validate_terms(root) + super().__init__(root) self._reorder() + def _prepare_item(self, item) -> Term: + self.__validate_terms([item]) + return item + @classmethod def __validate_terms(cls, terms: Any) -> None: """ @@ -392,6 +398,9 @@ def __validate_terms(cls, terms: Any) -> None: f"All components of a `SimpleFormula` should be `Term` instances. Found: {repr(terms)}. To use formula strings, please use `Formula` or `StructuredFormula` instead." ) + def _post_update(self) -> None: + self._reorder() + def _reorder(self, ordering: Optional[OrderingMethod] = None) -> None: """ Reorder the terms in this container in-place according to the specified @@ -404,6 +413,8 @@ def _reorder(self, ordering: Optional[OrderingMethod] = None) -> None: """ ordering = OrderingMethod(ordering if ordering is not None else self.ordering) orderer = None + + print(ordering) if ordering is OrderingMethod.DEGREE: orderer = lambda terms: sorted(terms, key=lambda term: term.degree) elif ordering is OrderingMethod.SORT: @@ -412,56 +423,9 @@ def _reorder(self, ordering: Optional[OrderingMethod] = None) -> None: ) if orderer is not None: - self.__terms = orderer(self.__terms) - - # MutableSequence implementation - - @overload - def __getitem__(self, key: int) -> Term: ... - - @overload - def __getitem__(self, key: slice) -> SimpleFormula: ... - - def __getitem__(self, key: Union[int, slice]) -> Union[Term, SimpleFormula]: - if isinstance(key, slice): - return self.__class__(self.__terms[key], _ordering=self.ordering) - else: - return self.__terms[key] - - @overload - def __setitem__(self, key: int, value: Term) -> None: ... - - @overload - def __setitem__(self, key: slice, value: Iterable[Term]) -> None: ... - - def __setitem__(self, key, value): # type: ignore - self.__validate_terms([value]) - self.__terms[key] = value - self._reorder() - - @overload - def __delitem__(self, key: int) -> None: ... - - @overload - def __delitem__(self, key: slice) -> None: ... - - def __delitem__(self, key): # type: ignore - del self.__terms[key] - - def __len__(self) -> int: - return len(self.__terms) - - def insert(self, index: int, value: Term) -> None: - self.__validate_terms([value]) - self.__terms.insert(index, value) - self._reorder() - - def __eq__(self, other: Any) -> bool: - if isinstance(other, SimpleFormula): - other = list(other) - if isinstance(other, list): - return self.__terms == other - return NotImplemented + self._values = Counter( + {item: self._values[item] for item in orderer(self._values)} + ) # Transforms @@ -486,7 +450,7 @@ def differentiate( # pylint: disable=redefined-builtin return SimpleFormula( [ differentiate_term(term, wrt, use_sympy=use_sympy) - for term in self.__terms + for term in self._values ], # Preserve term ordering even if differentiation modifies degrees/etc. _ordering=OrderingMethod.NONE, @@ -536,7 +500,7 @@ def required_variables(self) -> set[Variable]: variables: list[Variable] = [ variable - for term in self.__terms + for term in self._values for factor in term.factors for variable in get_expression_variables(factor.expr, {}) if "value" in variable.roles @@ -554,7 +518,7 @@ def required_variables(self) -> set[Variable]: ) def __repr__(self) -> str: - return " + ".join([str(t) for t in self.__terms]) + return " + ".join([str(t) for t in self]) # Deprecated shims for legacy `Structured`-like behaviour (previously there # was no distinction between `SimpleFormula` and `StructuredFormula`, and diff --git a/formulaic/materializers/base.py b/formulaic/materializers/base.py index a59b8d4..0f1d6c9 100644 --- a/formulaic/materializers/base.py +++ b/formulaic/materializers/base.py @@ -29,11 +29,11 @@ from formulaic.materializers.types.factor_values import FactorValuesMetadata from formulaic.model_matrix import ModelMatrices, ModelMatrix from formulaic.parser.types import Factor, Term -from formulaic.parser.types.ordered_set import OrderedSet from formulaic.transforms import TRANSFORMS from formulaic.utils.cast import as_columns from formulaic.utils.layered_mapping import LayeredMapping from formulaic.utils.null_handling import find_nulls +from formulaic.utils.ordered_set import OrderedSet from formulaic.utils.stateful_transforms import stateful_eval from formulaic.utils.variables import Variable diff --git a/formulaic/parser/parser.py b/formulaic/parser/parser.py index c87dd4c..0b31c07 100644 --- a/formulaic/parser/parser.py +++ b/formulaic/parser/parser.py @@ -17,6 +17,7 @@ from formulaic.errors import FormulaParsingError from formulaic.utils.layered_mapping import LayeredMapping +from formulaic.utils.ordered_set import OrderedSet from formulaic.utils.structured import Structured from .algos.sanitize_tokens import sanitize_tokens @@ -27,7 +28,6 @@ FormulaParser, Operator, OperatorResolver, - OrderedSet, Term, Token, ) diff --git a/formulaic/parser/types/__init__.py b/formulaic/parser/types/__init__.py index e9747ad..e333f80 100644 --- a/formulaic/parser/types/__init__.py +++ b/formulaic/parser/types/__init__.py @@ -3,7 +3,6 @@ from .formula_parser import FormulaParser from .operator import Operator from .operator_resolver import OperatorResolver -from .ordered_set import OrderedSet from .term import Term from .token import Token @@ -13,7 +12,6 @@ "FormulaParser", "Operator", "OperatorResolver", - "OrderedSet", "Term", "Token", ] diff --git a/formulaic/parser/types/ast_node.py b/formulaic/parser/types/ast_node.py index 6299a0c..a6cf31f 100644 --- a/formulaic/parser/types/ast_node.py +++ b/formulaic/parser/types/ast_node.py @@ -11,10 +11,10 @@ Union, ) +from formulaic.utils.ordered_set import OrderedSet from formulaic.utils.structured import Structured from .operator import Operator -from .ordered_set import OrderedSet from .term import Term ItemType = TypeVar("ItemType") diff --git a/formulaic/parser/types/factor.py b/formulaic/parser/types/factor.py index 1a5e619..d847652 100644 --- a/formulaic/parser/types/factor.py +++ b/formulaic/parser/types/factor.py @@ -4,7 +4,8 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Optional, Union -from .ordered_set import OrderedSet +from formulaic.utils.ordered_set import OrderedSet + from .term import Term if TYPE_CHECKING: diff --git a/formulaic/parser/types/formula_parser.py b/formulaic/parser/types/formula_parser.py index 2197672..07c6799 100644 --- a/formulaic/parser/types/formula_parser.py +++ b/formulaic/parser/types/formula_parser.py @@ -11,8 +11,8 @@ overload, ) -from formulaic.parser.types.ordered_set import OrderedSet from formulaic.utils.layered_mapping import LayeredMapping +from formulaic.utils.ordered_set import OrderedSet from formulaic.utils.structured import Structured from .ast_node import ASTNode diff --git a/formulaic/parser/types/ordered_set.py b/formulaic/parser/types/ordered_set.py deleted file mode 100644 index fe246f1..0000000 --- a/formulaic/parser/types/ordered_set.py +++ /dev/null @@ -1,28 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterable, Iterator, Set -from typing import Any, Generic, TypeVar - -ItemType = TypeVar("ItemType") - - -class OrderedSet(Set, Generic[ItemType]): - """ - A set-like container that retains the order in which item were added to the - set. - """ - - def __init__(self, values: Iterable[ItemType] = ()) -> None: - self.values = dict.fromkeys(values) - - def __contains__(self, item: Any) -> bool: - return item in self.values - - def __iter__(self) -> Iterator[ItemType]: - return iter(self.values) - - def __len__(self) -> int: - return len(self.values) - - def __repr__(self) -> str: - return f"{{{', '.join(repr(v) for v in self.values)}}}" diff --git a/formulaic/parser/types/term.py b/formulaic/parser/types/term.py index 08eed27..6a871d4 100644 --- a/formulaic/parser/types/term.py +++ b/formulaic/parser/types/term.py @@ -4,7 +4,7 @@ from collections.abc import Iterable, Mapping from typing import TYPE_CHECKING, Any, Optional -from .ordered_set import OrderedSet +from formulaic.utils.ordered_set import OrderedSet if TYPE_CHECKING: from .factor import Factor # pragma: no cover diff --git a/formulaic/parser/types/token.py b/formulaic/parser/types/token.py index 4afd6a9..5d30a3e 100644 --- a/formulaic/parser/types/token.py +++ b/formulaic/parser/types/token.py @@ -6,10 +6,10 @@ from enum import Enum from typing import Any, Optional, Union +from formulaic.utils.ordered_set import OrderedSet from formulaic.utils.variables import Variable, get_expression_variables from .factor import Factor -from .ordered_set import OrderedSet from .term import Term diff --git a/formulaic/utils/calculus.py b/formulaic/utils/calculus.py index b245c2d..06eb737 100644 --- a/formulaic/utils/calculus.py +++ b/formulaic/utils/calculus.py @@ -2,7 +2,7 @@ from typing import cast from formulaic.parser.types import Factor, Term -from formulaic.parser.types.ordered_set import OrderedSet +from formulaic.utils.ordered_set import OrderedSet def differentiate_term( diff --git a/formulaic/utils/ordered_set.py b/formulaic/utils/ordered_set.py new file mode 100644 index 0000000..c253f01 --- /dev/null +++ b/formulaic/utils/ordered_set.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +from collections import Counter +from collections.abc import Iterable, Iterator, Mapping, MutableSequence, MutableSet +from itertools import chain, islice +from typing import Any, Generic, Optional, TypeVar, Union, overload + +_ItemType = TypeVar("_ItemType") +_SelfType = TypeVar("_SelfType", bound="OrderedSet") + + +class OrderedSet(MutableSet, MutableSequence, Generic[_ItemType]): + """ + A mutable set-like sequence container that retains the order in which item + were added to the set, keeps track of multiplicities (how many times an item + was added), and provides both set and list-like indexing and mutations. This + container keeps track of how many times an item was added to the set, which + can be checked using the `.get_multiplicity()` method. + + This class is optimised for set-like operations, but also provides O(n) + lookups by index, insertions, deletions, and updates. We may optimise this + in the future based on need by maintaining index tables. + + Note: Indexed mutations like `collection[] = ` do not maintain + order, and are equivalent to `collection.remove(collection[])` + followed by `collection.add()`. + """ + + def __init__( + self, + values: Union[ + Iterable[_ItemType], Mapping[_ItemType, int], OrderedSet[_ItemType] + ] = (), + ) -> None: + self._values: Counter = Counter( + values._values if isinstance(values, OrderedSet) else values + ) + + def get_multiplicity(self, item: _ItemType) -> int: + """ + Identify how many times this item was added to the set. If the item was + never added, return 0. This is mainly useful if you later need to expand + an item into multiple items and need to keep track of the original + interaction order. + """ + return self._values[item] + + def _prepare_item(self, item: Any) -> _ItemType: + """ + Prepare an item for insertion into this ordered set. This method is + called whenever an item is added to the set. It is *not* called for + discard operations. + """ + return item + + def _post_update(self) -> None: + """ + Perform any post-update operations. This is called after every mutation + to the ordered set. + """ + pass + + def __repr__(self) -> str: + return f"{{{', '.join(repr(v) for v in self._values)}}}" + + # MutableSet interface + + def __contains__(self, item: Any) -> bool: + return item in self._values + + def __iter__(self) -> Iterator[_ItemType]: + return iter(self._values) + + def __len__(self) -> int: + return len(self._values) + + def add(self, item: _ItemType, count: int = 1) -> None: + item = self._prepare_item(item) + self._values[item] = self._values.get(item, 0) + count + self._post_update() + + def discard(self, item: _ItemType, count: Optional[int] = None) -> None: + if item in self._values: + final_count = 0 if count is None else self._values.get(item) - count + if final_count <= 0: + del self._values[item] + else: + self._values[item] = final_count + self._post_update() + + # MutableSet order preservation + + def __and__(self, other: Any) -> OrderedSet[_ItemType]: + out = OrderedSet[_ItemType]() + other = OrderedSet(other) + + for value, count in self._values.items(): + if value in other: + out.add(value, count) + for value, count in other._values.items(): + if value in self: + out.add(value, count) + return out + + def __ror__(self, other: Any) -> OrderedSet[_ItemType]: + return OrderedSet(other) | self + + def __rxor__(self, other: Any) -> OrderedSet[_ItemType]: + return OrderedSet(other) ^ self + + def __rand__(self, other: Any) -> OrderedSet[_ItemType]: + return OrderedSet(other) & self + + def __rsub__(self, other: Any) -> OrderedSet[_ItemType]: + return OrderedSet(other) - self + + # Additional methods for MutableSequence interface (O(n) lookups by index) + + @overload + def __getitem__(self, index: int) -> _ItemType: ... + + @overload + def __getitem__(self: _SelfType, index: slice) -> _SelfType: ... + + def __getitem__( + self: _SelfType, index: Union[int, slice] + ) -> Union[_ItemType, _SelfType]: + if isinstance(index, slice): + return self.__class__( + { + item: self._values[item] + for item in islice( + self._values, + index.start % len(self) if index.start is not None else None, + index.stop % len(self) if index.stop is not None else None, + index.step, + ) + } + ) + else: + return next(islice(self._values, index % len(self), None)) + + @overload + def __setitem__(self, key: int, value: _ItemType) -> None: ... + + @overload + def __setitem__(self, key: slice, value: Iterable[_ItemType]) -> None: ... + + def __setitem__(self, key, value): # type: ignore + self.__insert_or_replace( + key, value, replace=True + ) + + @overload + def __delitem__(self, key: int) -> None: ... + + @overload + def __delitem__(self, key: slice) -> None: ... + + def __delitem__(self, key): # type: ignore + if isinstance(key, slice): + for item in self[key]: + del self._values[item] + else: + del self._values[self[key]] + self._post_update() + + def insert(self, index: int, value: _ItemType, count: int = 1) -> None: + self.__insert_or_replace(index, value, count=count, replace=False) + + def __insert_or_replace(self, indices: Union[int, slice], values: Union[_ItemType, Iterable[_ItemType]], count: int = 1, replace=False) -> None: + if isinstance(indices, int): + indices = range(indices, indices+1) + values = [values] + else: + indices = range(len(self._values))[indices] + if len(indices) == 0: + indices = range(len(self._values), len(self._values) + 1) + + values_to_insert = { + v: self._values.get(v, 0) + count + for value in values + if (v := self._prepare_item(value)) or True + } + _values_new = Counter() + for i, (item, count) in enumerate(self._values.items()): + if i in indices: + if i > min(indices): + continue + _values_new.update(values_to_insert) + if item not in values_to_insert and (not replace or i not in indices): + _values_new.update({item: count}) + if min(indices) >= len(self._values): + _values_new.update(values_to_insert) + self._values = _values_new + self._post_update() + + # Other data model methods + + def __eq__(self, other: Any) -> bool: + if isinstance(other, (OrderedSet, list, tuple)): + return tuple(self) == tuple(other) + return NotImplemented + + # Convenience methods + + def update( + self, + items: Union[ + Iterable[_ItemType], Mapping[_ItemType, int], OrderedSet[_ItemType] + ], + ) -> None: + """ + Update this ordered set with the items from another iterable or mapping + from items to observed counts. + + Args: + items: The items to add to this ordered set. If an iterable is + is provided, the items will be added with a count of 1. + Otherwise the counts will be aggregated from the mapping and/or + ordered set instances. + """ + self._values.update(items.values if isinstance(items, OrderedSet) else items) + self._post_update() diff --git a/tests/parser/types/test_ordered_set.py b/tests/parser/types/test_ordered_set.py deleted file mode 100644 index 25d3900..0000000 --- a/tests/parser/types/test_ordered_set.py +++ /dev/null @@ -1,13 +0,0 @@ -from formulaic.parser.types import OrderedSet - - -def test_ordered_set(): - assert OrderedSet() == OrderedSet() - assert len(OrderedSet()) == 0 - - assert list(OrderedSet(["a", "a", "z", "b"])) == ["a", "z", "b"] - assert repr(OrderedSet(["z", "b", "c"])) == "{'z', 'b', 'c'}" - - assert OrderedSet(["z", "k"]) | ["a", "b"] == OrderedSet(["z", "k", "a", "b"]) - assert OrderedSet(("z", "k")) - ("z",) == OrderedSet("k") - assert ["b"] | OrderedSet("a") == OrderedSet("ba") diff --git a/tests/parser/types/test_term.py b/tests/parser/types/test_term.py index 3631be9..f80ccac 100644 --- a/tests/parser/types/test_term.py +++ b/tests/parser/types/test_term.py @@ -1,7 +1,7 @@ import pytest from formulaic.parser.types import Factor, Term -from formulaic.parser.types.ordered_set import OrderedSet +from formulaic.utils.ordered_set import OrderedSet class TestTerm: diff --git a/tests/parser/types/test_token.py b/tests/parser/types/test_token.py index 244d386..73a9a88 100644 --- a/tests/parser/types/test_token.py +++ b/tests/parser/types/test_token.py @@ -63,7 +63,7 @@ def test_to_factor(self, token_a, token_b, token_c): token_c.to_factor() def test_to_terms(self, token_a): - assert token_a.to_terms() == {Term([token_a.to_factor()])} + assert token_a.to_terms() == [Term([token_a.to_factor()])] def test_flatten(self, token_a): assert token_a.flatten(str_args=False) is token_a diff --git a/tests/utils/test_ordered_set.py b/tests/utils/test_ordered_set.py new file mode 100644 index 0000000..ee575cc --- /dev/null +++ b/tests/utils/test_ordered_set.py @@ -0,0 +1,76 @@ +from formulaic.utils.ordered_set import OrderedSet + + +class TestOrderedSet: + def test_constructor(self): + assert OrderedSet() == OrderedSet() + assert len(OrderedSet()) == 0 + assert OrderedSet() != "a" + + def test_multiplicity(self): + assert list(OrderedSet(["a", "a", "z", "b"])) == ["a", "z", "b"] + assert OrderedSet(["a", "a", "z", "b"]).get_multiplicity("a") == 2 + assert OrderedSet(["a", "a", "z", "b"]).get_multiplicity("b") == 1 + assert OrderedSet(["a", "a", "z", "b"]).get_multiplicity("missing") == 0 + + def test_repr(self): + assert repr(OrderedSet(["z", "b", "c"])) == "{'z', 'b', 'c'}" + + def test_set_operations(self): + assert OrderedSet(["z", "k"]) | ["a", "b"] == OrderedSet(["z", "k", "a", "b"]) + assert OrderedSet(("z", "k")) - ("z",) == OrderedSet("k") + assert OrderedSet(["z", "a", "b"]) & OrderedSet(["a", "c", "z"]) == OrderedSet(["z", "a"]) + assert ["b"] | OrderedSet("a") == OrderedSet("ba") + assert ["a", "b"] ^ OrderedSet(["a", "c"]) == OrderedSet(["b", "c"]) + assert ["z", "a", "b"] & OrderedSet(["a", "c", "z"]) == OrderedSet(["z", "a"]) + assert ["z", "a", "b"] - OrderedSet(["a", "c"]) == OrderedSet(["z", "b"]) + + s = OrderedSet(["a"]) + s.add("b") + assert list(s) == ["a", "b"] + assert s.get_multiplicity("a") == 1 + s.discard("a") + assert list(s) == ["b"] + assert s.get_multiplicity("a") == 0 + s.add("b", count=2) + assert s.get_multiplicity("b") == 3 + s.discard("b", count=1) + assert s.get_multiplicity("b") == 2 + s.discard("b") + assert list(s) == [] + + s.update(["c", "d"]) + assert list(s) == ["c", "d"] + s.update({"e": 3}) + assert list(s) == ["c", "d", "e"] + assert s.get_multiplicity("e") == 3 + + def test_sequence_operations(self): + s = OrderedSet(["z", "k"]) + assert s[0] == "z" + assert isinstance(s[1:], OrderedSet) + assert s[1:] == OrderedSet(["k"]) + assert len(s) == 2 + + s[0] = "a" + assert s[0] == "a" + assert s == OrderedSet(["a", "k"]) + + del s[0] + assert s[0] == "k" + assert s == OrderedSet(["k"]) + del s[0:] + assert len(s) == 0 + + s.insert(0, "a") + s.insert(0, "b") + assert s == ["b", "a"] + s.insert(0, "a") + assert s == ["a", "b"] + assert s.get_multiplicity("a") == 2 + s[10:12] = ["c", "d"] + assert s == ["a", "b", "c", "d"] + s[1:3] = ["d", "a"] + assert s == ["d", "a"] + assert s.get_multiplicity("a") == 3 + assert s.get_multiplicity("d") == 2 \ No newline at end of file