Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions docsite/docs/guides/transforms.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@
"\n",
"In Formulaic, a stateful transform is just a regular callable object (typically\n",
"a function) that has an attribute `__is_stateful_transform__` that is set to\n",
"`True`. Such callables will be passed up to three additional arguments by\n",
"formulaic if they are present in the callable signature:\n",
"`True`. Such callables will be passed additional arguments by formulaic if they\n",
"are present in the callable signature:\n",
"\n",
"* `_state`: The existing state or an empty dictionary that should be mutated\n",
" to record any additional state.\n",
Expand All @@ -245,9 +245,12 @@
" populated.\n",
"* `_spec`: The current model spec being evaluated (or an empty `ModelSpec` if\n",
" being called outside of Formulaic's materialization routines).\n",
"* `_materializer`: The `FormulaMaterializer` instance for which the expression\n",
" is being evaluated.\n",
"* `_context`: A mapping of the name to value for all the variables available\n",
" in the formula evaluation context (including data column names).\n",
"\n",
"Only `_state` is required, `_metadata` and `_spec` will only be passed in by \n",
"Formulaic if they are present in the callable signature."
"Typically all stateful transforms will use `_state`, but all are optional."
]
},
{
Expand Down
1 change: 1 addition & 0 deletions formulaic/materializers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ def _evaluate(
{expr: metadata},
spec.transform_state,
spec,
materializer=self,
variables=variables,
),
variables,
Expand Down
3 changes: 3 additions & 0 deletions formulaic/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .cubic_spline import cyclic_cubic_spline, natural_cubic_spline
from .hashed import hashed
from .identity import identity
from .interactions import i
from .lag import lag
from .patsy_compat import PATSY_COMPAT_TRANSFORMS
from .poly import poly
Expand All @@ -21,6 +22,7 @@
"C",
"encode_contrasts",
"ContrastsRegistry",
"i",
"lag",
"poly",
"center",
Expand Down Expand Up @@ -51,6 +53,7 @@
"contr": ContrastsRegistry,
"I": identity,
"hashed": hashed,
"i": i,
# Patsy compatibility shims
**PATSY_COMPAT_TRANSFORMS,
}
5 changes: 3 additions & 2 deletions formulaic/transforms/contrasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def C(
*,
levels: Optional[Iterable[str]] = None,
spans_intercept: bool = True,
reduce_rank: Optional[bool] = None,
) -> FactorValues:
"""
Mark data as being categorical, and optionally specify the contrasts to be
Expand Down Expand Up @@ -78,15 +79,15 @@ def encoder(
values,
contrasts=contrasts,
levels=levels,
reduced_rank=reduced_rank,
reduced_rank=reduce_rank if reduce_rank is not None else reduced_rank,
_state=encoder_state,
_spec=model_spec,
)

return FactorValues(
data,
kind="categorical",
spans_intercept=spans_intercept,
spans_intercept=not reduce_rank and spans_intercept,
encoder=encoder,
)

Expand Down
101 changes: 101 additions & 0 deletions formulaic/transforms/interactions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from typing import TYPE_CHECKING, Any

import narwhals.stable.v1 as narwhals

from formulaic.materializers.types import FactorValues
from formulaic.transforms import stateful_transform
from formulaic.transforms.contrasts import C

if TYPE_CHECKING:
from formulaic.model_spec import ModelSpec # pragma: no cover


@stateful_transform
def i(*args, _spec=None, _materializer=None):
    """
    The 'interaction' transform, which creates interaction terms between the
    combinations of the input arguments that actually occur in the data.

    Args:
        *args: The factors to interact. Each one is keyed by its `.name`
            attribute, so named series are expected.
            NOTE(review): presumably narwhals-compatible series; confirm
            against callers.
        _spec: The model spec; forwarded to nested encoders and to the
            materializer's `_get_columns_for_term`.
        _materializer: The materializer for which the expression is being
            evaluated; used to detect categorical inputs and to build the
            interaction columns.

    Returns:
        An empty dict when called without arguments; otherwise a categorical
        `FactorValues` wrapping a `{name: arg}` mapping whose `encoder`
        lazily produces the interaction columns.

    Raises:
        RuntimeError: If arguments are provided but no materializer is
            available.
    """

    # TODO: Keep track of encoder state.

    if not args:
        return {}

    if _materializer is None:
        raise RuntimeError("The 'i' transform requires a materializer context")

    # NOTE: `ModelSpec` is only imported under `TYPE_CHECKING`, so it must be
    # referenced as a string here; a bare name would be evaluated when this
    # nested `def` executes and raise `NameError` on interpreters that
    # evaluate annotations eagerly.
    def encoder(
        values: dict[str, Any],
        reduced_rank: bool,
        drop_rows: list[int],
        encoder_state: dict[str, Any],
        model_spec: "ModelSpec",
    ) -> FactorValues:
        # Distinct combinations of the raw input values, in a deterministic
        # order. Only these combinations produce output columns.
        # (`reduced_rank`, `encoder_state` and `model_spec` are currently
        # unused; see the TODO above.)
        required_terms = (
            narwhals.DataFrame.from_dict(
                {
                    arg.name: narwhals.from_native(
                        arg,
                        series_only=True,
                    )
                    for arg in args
                }
            )
            .unique()
            .sort(list(values.keys()))
        )

        # Encode each input at full rank, remembering which inputs ended up
        # as a mapping of level -> column (i.e. categorical).
        encoded = {}
        categorical_factors = set()
        for name, arg in values.items():
            if isinstance(arg, FactorValues):
                if arg.__formulaic_metadata__.encoder:
                    encoded[name] = arg.__formulaic_metadata__.encoder(
                        values=arg,
                        reduced_rank=False,
                        drop_rows=drop_rows,
                        encoder_state={},
                        model_spec=_spec,
                    )
                    # Anything that is not a single series is treated as a
                    # collection of per-level columns.
                    if not narwhals.dependencies.is_into_series(encoded[name]):
                        categorical_factors.add(name)
                else:
                    encoded[name] = arg
            elif _materializer._is_categorical(arg):
                categorical_factors.add(name)
                encoded[name] = dict(
                    C(arg).__formulaic_metadata__.encoder(
                        arg,
                        reduced_rank=False,
                        drop_rows=drop_rows,
                        encoder_state={},
                        model_spec=_spec,
                    )
                )
            else:
                encoded[name] = arg

        out = {}
        for row in required_terms.iter_rows(named=True):
            factors = []
            for name, value in row.items():
                if name in categorical_factors:
                    if value not in encoded[name]:
                        # No encoded column for this level: skip the whole
                        # combination (bypasses the `else` clause below).
                        break
                    factors.append(
                        {
                            getattr(values[name], "format", "{name}[{field}]").format(
                                name=name, field=value
                            ): encoded[name][value]
                        }
                    )
                else:
                    factors.append({name: encoded[name]})
            else:
                # Every factor resolved: let the materializer combine them
                # into the interaction column(s) for this combination.
                out.update(
                    _materializer._get_columns_for_term(
                        factors=factors,
                        spec=_spec,
                    )
                )

        return FactorValues(
            out,
            format="{field}",
            spans_intercept=False,
        )

    return FactorValues(
        {arg.name: arg for arg in args},
        kind="categorical",
        spans_intercept=False,
        encoder=encoder,
    )
16 changes: 14 additions & 2 deletions formulaic/utils/stateful_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .variables import Variable, get_expression_variables

if TYPE_CHECKING:
from formulaic.materializers import FormulaMaterializer # pragma: no cover
from formulaic.model_spec import ModelSpec # pragma: no cover


Expand All @@ -26,6 +27,7 @@ def stateful_transform(func: Callable) -> Callable:
- _state: The existing state or an empty dictionary.
- _metadata: Any extra metadata passed about the factor being evaluated.
- _spec: The `ModelSpec` instance being evaluated (or an empty `ModelSpec`).
- _materializer: The `FormulaMaterializer` instance being evaluated (or `None`).
- _context: A mapping of the name to value for all the variables available
in the formula evaluation context (including data column names).
If the callable has any of these in its signature, these will be passed onto
Expand All @@ -46,7 +48,7 @@ def stateful_transform(func: Callable) -> Callable:

@functools.wraps(func)
def wrapper( # type: ignore[no-untyped-def]
data, *args, _metadata=None, _state=None, _spec=None, _context=None, **kwargs
data, *args, _metadata=None, _state=None, _spec=None, _materializer=None, _context=None, **kwargs
):
from formulaic.model_spec import ModelSpec

Expand All @@ -56,8 +58,10 @@ def wrapper( # type: ignore[no-untyped-def]
extra_params["_metadata"] = _metadata
if "_spec" in params:
extra_params["_spec"] = _spec or ModelSpec(formula=[])
if "_materializer" in params:
extra_params["_materializer"] = _materializer
if "_context" in params:
extra_params["_context"] = _context
extra_params["_context"] = _context or {}

if isinstance(data, dict):
results = {}
Expand Down Expand Up @@ -91,6 +95,7 @@ def stateful_eval(
metadata: Optional[Mapping],
state: Optional[MutableMapping],
spec: Optional["ModelSpec"],
materializer: Optional["FormulaMaterializer"] = None,
variables: Optional[set[Variable]] = None,
) -> Any:
"""
Expand All @@ -111,6 +116,8 @@ def stateful_eval(
stateful transforms).
spec: The current `ModelSpec` instance being evaluated (passed through
to stateful transforms).
materializer: The `FormulaMaterializer` instance for which the
expression is being evaluated.
variables: A (optional) set of variables to update with the variables
used in this stateful evaluation.

Expand Down Expand Up @@ -170,6 +177,9 @@ def stateful_eval(
node.keywords.append(
ast.keyword("_spec", ast.parse("__FORMULAIC_SPEC__", mode="eval").body)
)
node.keywords.append(
ast.keyword("_materializer", ast.parse("__FORMULAIC_MATERIALIZER__", mode="eval").body)
)

# Compile mutated AST
compiled = compile(ast.fix_missing_locations(code), "", "eval")
Expand All @@ -179,6 +189,7 @@ def stateful_eval(
"__FORMULAIC_METADATA__",
"__FORMULAIC_STATE__",
"__FORMULAIC_SPEC__",
"__FORMULAIC_MATERIALIZER__",
}.intersection(env)
if used_reserved:
raise RuntimeError(
Expand All @@ -196,6 +207,7 @@ def stateful_eval(
"__FORMULAIC_METADATA__": metadata,
"__FORMULAIC_SPEC__": spec,
"__FORMULAIC_STATE__": state,
"__FORMULAIC_MATERIALIZER__": materializer,
},
env,
),
Expand Down
Loading