Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
Changelog
=========

3.2.1 - unreleased
3.2.1 - 2026-03-16
------------------

**Bug fix:**

- Fixed an error when predicting at a specific ``alpha`` with categorical features.

**Other changes:**

- Downgraded log messages in ``align_df_categories`` and ``add_missing_categories`` from INFO to DEBUG, and deduplicated them so they are emitted only once per column per fitted model.
Expand Down
28 changes: 15 additions & 13 deletions src/glum/_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,20 +889,22 @@ def _compute_linear_predictor(
)

if alpha_index is None:
coef = coef_path
intercept = intercept_path
xb = X @ coef_path + intercept_path
if offset is not None:
xb += offset
elif np.isscalar(alpha_index): # `None` doesn't qualify
xb = X @ coef_path[alpha_index] + intercept_path[alpha_index] # type: ignore
if offset is not None:
xb += offset
else:
scalar = np.isscalar(alpha_index)
alpha_index = np.atleast_1d(alpha_index) # type: ignore[assignment]
coef = coef_path[alpha_index] # type: ignore
intercept = intercept_path[alpha_index] # type: ignore

xb = X @ coef.T + intercept
if offset is not None:
offset = np.asanyarray(offset)
xb += offset if xb.ndim == 1 else offset[:, np.newaxis] # type: ignore[call-overload]

return xb.squeeze() if alpha_index is None or scalar else xb
_xb = []
for idx in alpha_index: # type: ignore
_xb.append(X @ coef_path[idx] + intercept_path[idx]) # type: ignore
xb = np.stack(_xb, axis=1)
if offset is not None:
xb += np.asanyarray(offset)[:, np.newaxis]

return xb

def predict(
self,
Expand Down
18 changes: 18 additions & 0 deletions tests/glm/test_glm_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,6 +1202,24 @@ def test_predict_list(regression_data, alpha, alpha_index):
np.testing.assert_allclose(candidate, target + 1)


def test_predict_list_categorical():

letters = ["a", "b", "c", "d", "e", "f"]
rng = np.random.default_rng(42)

df = pd.DataFrame({"x": rng.choice(letters, size=100)})

df["x"] = df["x"].astype("category")
df["y"] = df["x"].map({v: k + 1 for k, v in enumerate(letters)})

regressor = GeneralizedLinearRegressor(alpha=[0, 2], alpha_search=True)
regressor = regressor.fit(df[["x"]], df["y"])

candidate = regressor.predict(df[["x"]], alpha=0)

np.testing.assert_allclose(candidate, df["y"])


def test_predict_error(regression_data):
X, y = regression_data

Expand Down
Loading