Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 26 additions & 62 deletions formulaic/formula.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import sys
from abc import ABCMeta, abstractmethod
from collections.abc import Generator, Iterable, Mapping, MutableSequence
from collections import Counter
from collections.abc import Generator, Iterable, Mapping
from enum import Enum
from typing import (
Any,
Expand All @@ -11,7 +11,6 @@
TypeVar,
Union,
cast,
overload,
)

from typing_extensions import Self, TypeAlias
Expand All @@ -21,9 +20,10 @@
from .errors import FormulaInvalidError
from .model_matrix import ModelMatrix
from .parser import DefaultFormulaParser
from .parser.types import FormulaParser, OrderedSet, Term
from .parser.types import FormulaParser, Term
from .utils.calculus import differentiate_term
from .utils.deprecations import deprecated
from .utils.ordered_set import OrderedSet
from .utils.structured import Structured
from .utils.variables import Variable, get_expression_variables

Expand Down Expand Up @@ -329,12 +329,12 @@ def differentiate(


class SimpleFormula(
MutableSequence[Term] if sys.version_info >= (3, 9) else MutableSequence, # type: ignore
OrderedSet[Term],
Formula,
):
"""
The atomic component of all formulae represented by Formulaic, which in turn
is a mutable sequence of `Term` instances. `StructuredFormula` uses
is a mutable ordered set of `Term` instances. `StructuredFormula` uses
`SimpleFormula` as its nodes.

Instances of this class can be used directly as a mutable sequence of
Expand All @@ -354,7 +354,9 @@ class SimpleFormula(

def __init__(
self,
root: Union[Iterable[Term], MissingType] = MISSING,
root: Union[
Iterable[Term], Mapping[Term, int], OrderedSet[Term], MissingType
] = MISSING,
*,
_ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE,
_parser: Optional[FormulaParser] = None,
Expand All @@ -375,13 +377,17 @@ def __init__(
"`SimpleFormula` does not support nested structure. To create a "
"structured formula, use `StructuredFormula` instead."
)
self.__terms = list(root)
self.ordering = OrderingMethod(_ordering)

self.__validate_terms(self.__terms)
self.__validate_terms(root)
super().__init__(root)

self._reorder()

def _prepare_item(self, item) -> Term:
self.__validate_terms([item])
return item

@classmethod
def __validate_terms(cls, terms: Any) -> None:
"""
Expand All @@ -392,6 +398,9 @@ def __validate_terms(cls, terms: Any) -> None:
f"All components of a `SimpleFormula` should be `Term` instances. Found: {repr(terms)}. To use formula strings, please use `Formula` or `StructuredFormula` instead."
)

def _post_update(self) -> None:
self._reorder()

def _reorder(self, ordering: Optional[OrderingMethod] = None) -> None:
"""
Reorder the terms in this container in-place according to the specified
Expand All @@ -404,6 +413,8 @@ def _reorder(self, ordering: Optional[OrderingMethod] = None) -> None:
"""
ordering = OrderingMethod(ordering if ordering is not None else self.ordering)
orderer = None

print(ordering)
if ordering is OrderingMethod.DEGREE:
orderer = lambda terms: sorted(terms, key=lambda term: term.degree)
elif ordering is OrderingMethod.SORT:
Expand All @@ -412,56 +423,9 @@ def _reorder(self, ordering: Optional[OrderingMethod] = None) -> None:
)

if orderer is not None:
self.__terms = orderer(self.__terms)

# MutableSequence implementation

@overload
def __getitem__(self, key: int) -> Term: ...

@overload
def __getitem__(self, key: slice) -> SimpleFormula: ...

def __getitem__(self, key: Union[int, slice]) -> Union[Term, SimpleFormula]:
if isinstance(key, slice):
return self.__class__(self.__terms[key], _ordering=self.ordering)
else:
return self.__terms[key]

@overload
def __setitem__(self, key: int, value: Term) -> None: ...

@overload
def __setitem__(self, key: slice, value: Iterable[Term]) -> None: ...

def __setitem__(self, key, value): # type: ignore
self.__validate_terms([value])
self.__terms[key] = value
self._reorder()

@overload
def __delitem__(self, key: int) -> None: ...

@overload
def __delitem__(self, key: slice) -> None: ...

def __delitem__(self, key): # type: ignore
del self.__terms[key]

def __len__(self) -> int:
return len(self.__terms)

def insert(self, index: int, value: Term) -> None:
self.__validate_terms([value])
self.__terms.insert(index, value)
self._reorder()

def __eq__(self, other: Any) -> bool:
if isinstance(other, SimpleFormula):
other = list(other)
if isinstance(other, list):
return self.__terms == other
return NotImplemented
self._values = Counter(
{item: self._values[item] for item in orderer(self._values)}
)

# Transforms

Expand All @@ -486,7 +450,7 @@ def differentiate( # pylint: disable=redefined-builtin
return SimpleFormula(
[
differentiate_term(term, wrt, use_sympy=use_sympy)
for term in self.__terms
for term in self._values
],
# Preserve term ordering even if differentiation modifies degrees/etc.
_ordering=OrderingMethod.NONE,
Expand Down Expand Up @@ -536,7 +500,7 @@ def required_variables(self) -> set[Variable]:

variables: list[Variable] = [
variable
for term in self.__terms
for term in self._values
for factor in term.factors
for variable in get_expression_variables(factor.expr, {})
if "value" in variable.roles
Expand All @@ -554,7 +518,7 @@ def required_variables(self) -> set[Variable]:
)

def __repr__(self) -> str:
return " + ".join([str(t) for t in self.__terms])
return " + ".join([str(t) for t in self])

# Deprecated shims for legacy `Structured`-like behaviour (previously there
# was no distinction between `SimpleFormula` and `StructuredFormula`, and
Expand Down
2 changes: 1 addition & 1 deletion formulaic/materializers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
from formulaic.materializers.types.factor_values import FactorValuesMetadata
from formulaic.model_matrix import ModelMatrices, ModelMatrix
from formulaic.parser.types import Factor, Term
from formulaic.parser.types.ordered_set import OrderedSet
from formulaic.transforms import TRANSFORMS
from formulaic.utils.cast import as_columns
from formulaic.utils.layered_mapping import LayeredMapping
from formulaic.utils.null_handling import find_nulls
from formulaic.utils.ordered_set import OrderedSet
from formulaic.utils.stateful_transforms import stateful_eval
from formulaic.utils.variables import Variable

Expand Down
2 changes: 1 addition & 1 deletion formulaic/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from formulaic.errors import FormulaParsingError
from formulaic.utils.layered_mapping import LayeredMapping
from formulaic.utils.ordered_set import OrderedSet
from formulaic.utils.structured import Structured

from .algos.sanitize_tokens import sanitize_tokens
Expand All @@ -27,7 +28,6 @@
FormulaParser,
Operator,
OperatorResolver,
OrderedSet,
Term,
Token,
)
Expand Down
2 changes: 0 additions & 2 deletions formulaic/parser/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from .formula_parser import FormulaParser
from .operator import Operator
from .operator_resolver import OperatorResolver
from .ordered_set import OrderedSet
from .term import Term
from .token import Token

Expand All @@ -13,7 +12,6 @@
"FormulaParser",
"Operator",
"OperatorResolver",
"OrderedSet",
"Term",
"Token",
]
2 changes: 1 addition & 1 deletion formulaic/parser/types/ast_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
Union,
)

from formulaic.utils.ordered_set import OrderedSet
from formulaic.utils.structured import Structured

from .operator import Operator
from .ordered_set import OrderedSet
from .term import Term

ItemType = TypeVar("ItemType")
Expand Down
3 changes: 2 additions & 1 deletion formulaic/parser/types/factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, Union

from .ordered_set import OrderedSet
from formulaic.utils.ordered_set import OrderedSet

from .term import Term

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion formulaic/parser/types/formula_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
overload,
)

from formulaic.parser.types.ordered_set import OrderedSet
from formulaic.utils.layered_mapping import LayeredMapping
from formulaic.utils.ordered_set import OrderedSet
from formulaic.utils.structured import Structured

from .ast_node import ASTNode
Expand Down
28 changes: 0 additions & 28 deletions formulaic/parser/types/ordered_set.py

This file was deleted.

2 changes: 1 addition & 1 deletion formulaic/parser/types/term.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any, Optional

from .ordered_set import OrderedSet
from formulaic.utils.ordered_set import OrderedSet

if TYPE_CHECKING:
from .factor import Factor # pragma: no cover
Expand Down
2 changes: 1 addition & 1 deletion formulaic/parser/types/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from enum import Enum
from typing import Any, Optional, Union

from formulaic.utils.ordered_set import OrderedSet
from formulaic.utils.variables import Variable, get_expression_variables

from .factor import Factor
from .ordered_set import OrderedSet
from .term import Term


Expand Down
2 changes: 1 addition & 1 deletion formulaic/utils/calculus.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import cast

from formulaic.parser.types import Factor, Term
from formulaic.parser.types.ordered_set import OrderedSet
from formulaic.utils.ordered_set import OrderedSet


def differentiate_term(
Expand Down
Loading