diff --git a/docsite/docs/guides/transforms.ipynb b/docsite/docs/guides/transforms.ipynb index 7b1f6dd..c41b7ba 100644 --- a/docsite/docs/guides/transforms.ipynb +++ b/docsite/docs/guides/transforms.ipynb @@ -235,8 +235,8 @@ "\n", "In Formulaic, a stateful transform is just a regular callable object (typically\n", "a function) that has an attribute `__is_stateful_transform__` that is set to\n", - "`True`. Such callables will be passed up to three additional arguments by\n", - "formulaic if they are present in the callable signature:\n", + "`True`. Such callables will be passed additional arguments by formulaic if they\n", + "are present in the callable signature:\n", "\n", "* `_state`: The existing state or an empty dictionary that should be mutated\n", " to record any additional state.\n", @@ -245,9 +245,12 @@ " populated.\n", "* `_spec`: The current model spec being evaluated (or an empty `ModelSpec` if\n", " being called outside of Formulaic's materialization routines).\n", + "* `_materializer`: The `FormulaMaterializer` instance for which the expression\n", + " is being evaluated.\n", + "* `_context`: A mapping of the name to value for all the variables available\n", + " in the formula evaluation context (including data column names).\n", "\n", - "Only `_state` is required, `_metadata` and `_spec` will only be passed in by \n", - "Formulaic if they are present in the callable signature." + "Typically all stateful transforms will use `_state`, but all are optional." 
] }, { diff --git a/formulaic/materializers/base.py b/formulaic/materializers/base.py index 9b6e4e4..9959c3b 100644 --- a/formulaic/materializers/base.py +++ b/formulaic/materializers/base.py @@ -641,6 +641,7 @@ def _evaluate( {expr: metadata}, spec.transform_state, spec, + materializer=self, variables=variables, ), variables, diff --git a/formulaic/transforms/__init__.py b/formulaic/transforms/__init__.py index 8c59f32..7533a54 100644 --- a/formulaic/transforms/__init__.py +++ b/formulaic/transforms/__init__.py @@ -7,6 +7,7 @@ from .cubic_spline import cyclic_cubic_spline, natural_cubic_spline from .hashed import hashed from .identity import identity +from .interactions import i from .lag import lag from .patsy_compat import PATSY_COMPAT_TRANSFORMS from .poly import poly @@ -21,6 +22,7 @@ "C", "encode_contrasts", "ContrastsRegistry", + "i", "lag", "poly", "center", @@ -51,6 +53,7 @@ "contr": ContrastsRegistry, "I": identity, "hashed": hashed, + "i": i, # Patsy compatibility shims **PATSY_COMPAT_TRANSFORMS, } diff --git a/formulaic/transforms/contrasts.py b/formulaic/transforms/contrasts.py index e3ae0a4..f4a3cbe 100644 --- a/formulaic/transforms/contrasts.py +++ b/formulaic/transforms/contrasts.py @@ -40,6 +40,7 @@ def C( *, levels: Optional[Iterable[str]] = None, spans_intercept: bool = True, + reduce_rank: Optional[bool] = None, ) -> FactorValues: """ Mark data as being categorical, and optionally specify the contrasts to be @@ -78,7 +79,7 @@ def encoder( values, contrasts=contrasts, levels=levels, - reduced_rank=reduced_rank, + reduced_rank=reduce_rank if reduce_rank is not None else reduced_rank, _state=encoder_state, _spec=model_spec, ) @@ -86,7 +87,7 @@ def encoder( return FactorValues( data, kind="categorical", - spans_intercept=spans_intercept, + spans_intercept=not reduce_rank and spans_intercept, encoder=encoder, ) diff --git a/formulaic/transforms/interactions.py b/formulaic/transforms/interactions.py new file mode 100644 index 0000000..b6beeeb --- 
/dev/null +++ b/formulaic/transforms/interactions.py @@ -0,0 +1,105 @@ +from typing import TYPE_CHECKING, Any + +import narwhals.stable.v1 as narwhals + +from formulaic.materializers.types import FactorValues +from formulaic.transforms import stateful_transform +from formulaic.transforms.contrasts import C + +if TYPE_CHECKING: + from formulaic.model_spec import ModelSpec # pragma: no cover + + +@stateful_transform +def i(*args, _spec=None, _materializer=None): + """The 'interaction' transform, which creates interaction terms between non-null + combinations of the input arguments. + + Returns an empty dict when called with no arguments; otherwise raises + `RuntimeError` if no materializer was supplied. + """ + + # TODO: Keep track of encoder state. + + if len(args) == 0: + return {} + + if not _materializer: + raise RuntimeError("The 'i' transform requires a materializer context") + + def encoder( + values: dict[str, Any], + reduced_rank: bool, + drop_rows: list[int], + encoder_state: dict[str, Any], + model_spec: "ModelSpec", + ) -> FactorValues: + required_terms = narwhals.DataFrame.from_dict( + { + arg.name: narwhals.from_native( + arg, + series_only=True, + ) + for arg in args + } + ).unique().sort(list(values.keys())) + + encoded = {} + categorical_factors = set() + for name, arg in values.items(): + if isinstance(arg, FactorValues): + if arg.__formulaic_metadata__.encoder: + encoded[name] = arg.__formulaic_metadata__.encoder( + values=arg, reduced_rank=False, drop_rows=drop_rows, encoder_state={}, model_spec=_spec + ) + if not narwhals.dependencies.is_into_series(encoded[name]): + categorical_factors.add(name) + else: + encoded[name] = arg + elif _materializer._is_categorical(arg): + categorical_factors.add(name) + encoded[name] = dict( + C(arg).__formulaic_metadata__.encoder( + arg, reduced_rank=False, drop_rows=drop_rows, encoder_state={}, model_spec=_spec + ) + ) + else: + encoded[name] = arg + + out = {} + for row in required_terms.iter_rows(named=True): + factors = [] + for name, value in row.items(): + if name in categorical_factors: + if value not in encoded[name]: + break + 
factors.append( + { + getattr(values[name], "format", "{name}[{field}]").format(name=name, field=value): encoded[name][value] + } + ) + else: + factors.append({name: encoded[name]}) + else: + out.update( + _materializer._get_columns_for_term( + factors=factors, + spec=_spec, + ) + ) + + return FactorValues( + out, + format="{field}", + spans_intercept=False, + ) + + return FactorValues( + { + arg.name: arg + for arg in args + }, + kind="categorical", + spans_intercept=False, + encoder=encoder, + ) diff --git a/formulaic/utils/stateful_transforms.py b/formulaic/utils/stateful_transforms.py index bbbe5d2..0022b37 100644 --- a/formulaic/utils/stateful_transforms.py +++ b/formulaic/utils/stateful_transforms.py @@ -15,6 +15,7 @@ from .variables import Variable, get_expression_variables if TYPE_CHECKING: + from formulaic.materializers import FormulaMaterializer # pragma: no cover from formulaic.model_spec import ModelSpec # pragma: no cover @@ -26,6 +27,7 @@ def stateful_transform(func: Callable) -> Callable: - _state: The existing state or an empty dictionary. - _metadata: Any extra metadata passed about the factor being evaluated. - _spec: The `ModelSpec` instance being evaluated (or an empty `ModelSpec`). + - _materializer: The `FormulaMaterializer` instance being evaluated (or `None`). - _context: A mapping of the name to value for all the variables available in the formula evaluation context (including data column names). 
If the callable has any of these in its signature, these will be passed onto @@ -46,7 +48,7 @@ def stateful_transform(func: Callable) -> Callable: @functools.wraps(func) def wrapper( # type: ignore[no-untyped-def] - data, *args, _metadata=None, _state=None, _spec=None, _context=None, **kwargs + data, *args, _metadata=None, _state=None, _spec=None, _materializer=None, _context=None, **kwargs ): from formulaic.model_spec import ModelSpec @@ -56,8 +58,10 @@ def wrapper( # type: ignore[no-untyped-def] extra_params["_metadata"] = _metadata if "_spec" in params: extra_params["_spec"] = _spec or ModelSpec(formula=[]) + if "_materializer" in params: + extra_params["_materializer"] = _materializer if "_context" in params: - extra_params["_context"] = _context + extra_params["_context"] = _context or {} if isinstance(data, dict): results = {} @@ -91,6 +95,7 @@ def stateful_eval( metadata: Optional[Mapping], state: Optional[MutableMapping], spec: Optional["ModelSpec"], + materializer: Optional["FormulaMaterializer"] = None, variables: Optional[set[Variable]] = None, ) -> Any: """ @@ -111,6 +116,8 @@ def stateful_eval( stateful transforms). spec: The current `ModelSpec` instance being evaluated (passed through to stateful transforms). + materializer: The `FormulaMaterializer` instance for which the + expression is being evaluated. variables: A (optional) set of variables to update with the variables used in this stateful evaluation. 
@@ -170,6 +177,9 @@ def stateful_eval( node.keywords.append( ast.keyword("_spec", ast.parse("__FORMULAIC_SPEC__", mode="eval").body) ) + node.keywords.append( + ast.keyword("_materializer", ast.parse("__FORMULAIC_MATERIALIZER__", mode="eval").body) + ) # Compile mutated AST compiled = compile(ast.fix_missing_locations(code), "", "eval") @@ -179,6 +189,7 @@ def stateful_eval( "__FORMULAIC_METADATA__", "__FORMULAIC_STATE__", "__FORMULAIC_SPEC__", + "__FORMULAIC_MATERIALIZER__", }.intersection(env) if used_reserved: raise RuntimeError( @@ -196,6 +207,7 @@ def stateful_eval( "__FORMULAIC_METADATA__": metadata, "__FORMULAIC_SPEC__": spec, "__FORMULAIC_STATE__": state, + "__FORMULAIC_MATERIALIZER__": materializer, }, env, ),