Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions formulaic/transforms/contrasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ class TreatmentContrasts(Contrasts):
FACTOR_FORMAT_REDUCED = "{name}[T.{field}]"

base: Hashable = UNSET
drop: bool = False

@Contrasts.override
def _apply(
Expand All @@ -503,7 +504,7 @@ def _apply(
reduced_rank: bool = True,
sparse: bool = False,
) -> Union[pandas.DataFrame, numpy.ndarray, spsparse.spmatrix]:
if reduced_rank:
if reduced_rank or self.drop:
drop_index = self._find_base_index(levels)
mask = numpy.ones(len(levels), dtype=bool)
mask[drop_index] = False
Expand Down Expand Up @@ -536,7 +537,7 @@ def _get_coding_matrix(
matrix = spsparse.eye(n).tocsc()
else:
matrix = numpy.eye(n)
if reduced_rank:
if reduced_rank or self.drop:
drop_level = self._find_base_index(levels)
matrix = matrix[:, [i for i in range(matrix.shape[1]) if i != drop_level]]
return matrix
Expand All @@ -546,7 +547,7 @@ def get_coding_column_names(
self, levels: Sequence[Hashable], reduced_rank: bool = True
) -> Sequence[Hashable]:
base_index = self._find_base_index(levels)
if reduced_rank:
if reduced_rank or self.drop:
return [level for i, level in enumerate(levels) if i != base_index]
return levels

Expand All @@ -555,7 +556,7 @@ def get_coefficient_row_names(
self, levels: Sequence[Hashable], reduced_rank: bool = True
) -> Sequence[Hashable]:
base = levels[self._find_base_index(levels)]
if reduced_rank:
if reduced_rank or self.drop:
return [base, *(f"{level}-{base}" for level in levels if level != base)]
return levels

Expand Down