From 0a6f128c88b3b020424823a6babac7a8ec56f4a5 Mon Sep 17 00:00:00 2001
From: Awa Khouna
Date: Tue, 11 Nov 2025 11:41:54 -0500
Subject: [PATCH 01/10] feat: Enhance tests for MixedIntegerProgramExplainer
 and ConstraintProgrammingExplainer with XGBoost integration

---
 tests/test_explainer.py | 159 +++++++++++++++++++++++++++++++++-------
 1 file changed, 132 insertions(+), 27 deletions(-)

diff --git a/tests/test_explainer.py b/tests/test_explainer.py
index a34bd19..97f5a3b 100644
--- a/tests/test_explainer.py
+++ b/tests/test_explainer.py
@@ -1,6 +1,8 @@
 import gurobipy as gp
+import numpy as np
 import pytest
 from sklearn.ensemble import RandomForestClassifier
+from xgboost import XGBClassifier
 
 from ocean import ConstraintProgrammingExplainer, MixedIntegerProgramExplainer
 
@@ -32,21 +34,73 @@ def test_mip_explain(
     x = data.iloc[0, :].to_numpy().astype(float).flatten()
     # pyright: ignore[reportUnknownVariableType]
+    y = clf.predict([x])[0]
+    classes = np.unique(clf.predict(data.to_numpy()))
+    for target in classes[classes != y]:
+        try:
+            exp = model.explain(
+                x,
+                y=target,
+                norm=1,
+                return_callback=True,
+                num_workers=num_workers,
+                random_seed=seed,
+            )
+            assert model.Status == gp.GRB.OPTIMAL
+            assert len(model.callback.sollist) != 0
+            assert exp is not None
+            assert clf.predict([exp.to_numpy()])[0] == target
+            model.cleanup()
-    try:
-        model.explain(x, y=0, norm=1,
-                      num_workers=num_workers,
-                      random_seed=seed)
-        assert model.Status == gp.GRB.OPTIMAL
-        model.cleanup()
-        model.explain(x, y=0, norm=1,
-                      return_callback=True,
-                      num_workers=num_workers,
-                      random_seed=seed)
-        assert len(model.callback.sollist) != 0
+        except gp.GurobiError as e:
+            pytest.skip(f"Skipping test due to {e}")
-    except gp.GurobiError as e:
-        pytest.skip(f"Skipping test due to {e}")
+
+
+@pytest.mark.parametrize("seed", [42, 43, 44])
+@pytest.mark.parametrize("n_estimators", [5])
+@pytest.mark.parametrize("max_depth", [2, 3])
+@pytest.mark.parametrize("n_classes", [2, 3, 4])
+@pytest.mark.parametrize("n_samples", [100, 200, 500])
+@pytest.mark.parametrize("num_workers", [1, 2, 4])
+def test_mip_explain_xgb(
+    seed: int,
+    n_estimators: int,
+    max_depth: int,
+    n_classes: int,
+    n_samples: int,
+    num_workers: int,
+) -> None:
+    data, y, mapper = generate_data(seed, n_samples, n_classes)
+    clf = XGBClassifier(
+        random_state=seed,
+        n_estimators=n_estimators,
+        max_depth=max_depth,
+    )
+    clf.fit(data, y)
+    model = MixedIntegerProgramExplainer(clf, mapper=mapper, env=ENV)
+
+    x = data.iloc[0, :].to_numpy().astype(float).flatten()
+    # pyright: ignore[reportUnknownVariableType]
+    y = clf.predict([x])[0]
+    classes = np.unique(clf.predict(data.to_numpy()))
+    for target in classes[classes != y]:
+        try:
+            exp = model.explain(
+                x,
+                y=target,
+                norm=1,
+                return_callback=True,
+                num_workers=num_workers,
+                random_seed=seed,
+            )
+            model.cleanup()
+            assert model.Status == gp.GRB.OPTIMAL
+            assert len(model.callback.sollist) != 0
+            assert exp is not None
+            assert clf.predict([exp.to_numpy()])[0] == target
+
+        except gp.GurobiError as e:
+            pytest.skip(f"Skipping test due to {e}")
 
 
 @pytest.mark.parametrize("seed", [42, 43, 44])
@@ -74,18 +128,69 @@ def test_cp_explain(
     x = data.iloc[0, :].to_numpy().astype(float).flatten()
     # pyright: ignore[reportUnknownVariableType]
+    y = clf.predict([x])[0]
+    classes = np.unique(clf.predict(data.to_numpy()))
+    for target in classes[classes != y]:
+        try:
+            exp = model.explain(
+                x,
+                y=target,
+                norm=1,
+                return_callback=True,
+                num_workers=num_workers,
+                random_seed=seed,
+            )
+            assert model.get_solving_status() == "OPTIMAL"
+            assert model.callback is None or len(model.callback.sollist) != 0
+            assert exp is not None
+            assert clf.predict([exp.to_numpy()])[0] == target
+            model.cleanup()
+        except gp.GurobiError as e:
+            pytest.skip(f"Skipping test due to {e}")
+
+
+@pytest.mark.parametrize("seed", [42, 43, 44])
+@pytest.mark.parametrize("n_estimators", [5])
+@pytest.mark.parametrize("max_depth", [2, 3])
+@pytest.mark.parametrize("n_classes", [2, 3, 4])
+@pytest.mark.parametrize("n_samples", [100, 200, 500])
+@pytest.mark.parametrize("num_workers", [1, 2, 4])
+def test_cp_explain_xgb(
+    seed: int,
+    n_estimators: int,
+    max_depth: int,
+    n_classes: int,
+    n_samples: int,
+    num_workers: int,
+) -> None:
+    data, y, mapper = generate_data(seed, n_samples, n_classes)
+    clf = XGBClassifier(
+        random_state=seed,
+        n_estimators=n_estimators,
+        max_depth=max_depth,
+    )
+    clf.fit(data, y)
+    model = ConstraintProgrammingExplainer(clf, mapper=mapper)
+
+    x = data.iloc[0, :].to_numpy().astype(float).flatten()
+    # pyright: ignore[reportUnknownVariableType]
+    y = clf.predict([x])[0]
+    classes = np.unique(clf.predict(data.to_numpy()))
+    for target in classes[classes != y]:
+        try:
+            exp = model.explain(
+                x,
+                y=target,
+                norm=1,
+                return_callback=True,
+                num_workers=num_workers,
+                random_seed=seed,
+            )
+            assert model.get_solving_status() == "OPTIMAL"
+            assert model.callback is None or len(model.callback.sollist) != 0
+            assert exp is not None
+            assert clf.predict([exp.to_numpy()])[0] == target
+            model.cleanup()
-    try:
-        _ = model.explain(x, y=0, norm=1,
-                          return_callback=False,
-                          num_workers=num_workers,
-                          random_seed=seed)
-        assert model.callback is None or len(model.callback.sollist) == 0
-        model.cleanup()
-        _ = model.explain(x, y=0, norm=1,
-                          return_callback=True,
-                          num_workers=num_workers,
-                          random_seed=seed)
-        assert model.callback is None or len(model.callback.sollist) != 0
-    except gp.GurobiError as e:
-        pytest.skip(f"Skipping test due to {e}")
+
+        except gp.GurobiError as e:
+            pytest.skip(f"Skipping test due to {e}")

From 8edd1a83b2d63288abbaa5b2058b106c908440a2 Mon Sep 17 00:00:00 2001
From: Awa Khouna
Date: Tue, 11 Nov 2025 11:42:47 -0500
Subject: [PATCH 02/10] feat: Add base score probability handling and improve
 XGBoost tree parsing

---
 ocean/tree/_parse_xgb.py | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/ocean/tree/_parse_xgb.py b/ocean/tree/_parse_xgb.py
index 4761bab..65ad63d 100644
--- a/ocean/tree/_parse_xgb.py
+++ b/ocean/tree/_parse_xgb.py
@@ -1,3 +1,4 @@
+import json
 from collections.abc import Iterable
 
 import numpy as np
@@ -27,12 +28,11 @@ def _build_xgb_leaf(
     weight = float(_get_column_value(xgb_tree, node_id, "Gain"))
 
     if num_trees_per_round == 1:
-        value = np.array([[weight, -weight]])
+        value = np.array([[0.0, weight]])
     else:
         k = int(tree_id % num_trees_per_round)
         value = np.zeros((1, int(num_trees_per_round)), dtype=float)
-        value[0, k] = weight
-
+        value[0, k] += weight
     return Node(node_id, n_samples=0, value=value)
 
 
@@ -90,7 +90,7 @@ def _build_xgb_node(
 
     threshold = None
     if mapper[name].is_numeric:
-        threshold = float(_get_column_value(xgb_tree, node_id, "Split"))
+        threshold = float(_get_column_value(xgb_tree, node_id, "Split")) - 1e-8
         mapper[name].add(threshold)
 
     left_id = _get_child_id(xgb_tree, node_id, "Yes")
@@ -150,6 +150,7 @@ def _parse_xgb_tree(
     tree_id: NonNegativeInt,
     num_trees_per_round: NonNegativeInt,
     mapper: Mapper[Feature],
+    base_score_prob: float = 0.0,
 ) -> Tree:
     root = _parse_xgb_node(
         xgb_tree,
@@ -158,7 +159,10 @@ def _parse_xgb_tree(
         num_trees_per_round=num_trees_per_round,
         mapper=mapper,
     )
-    return Tree(root=root)
+    tree = Tree(root=root)
+    tree.logit = np.log(base_score_prob / (1 - base_score_prob))
+    tree.xgboost = True
+    return tree
 
 
 def parse_xgb_tree(
@@ -167,12 +171,14 @@ def parse_xgb_tree(
     tree_id: NonNegativeInt,
     num_trees_per_round: NonNegativeInt,
     mapper: Mapper[Feature],
+    base_score_prob: float = 0.0,
 ) -> Tree:
     return _parse_xgb_tree(
         xgb_tree,
         tree_id=tree_id,
         num_trees_per_round=num_trees_per_round,
         mapper=mapper,
+        base_score_prob=base_score_prob,
     )
 
 
@@ -181,6 +187,7 @@ def parse_xgb_trees(
     *,
     num_trees_per_round: NonNegativeInt,
     mapper: Mapper[Feature],
+    base_score_prob: float = 0.0,
 ) -> tuple[Tree, ...]:
     return tuple(
         parse_xgb_tree(
@@ -188,6 +195,7 @@ def parse_xgb_trees(
             tree_id=tree_id,
             num_trees_per_round=num_trees_per_round,
             mapper=mapper,
+            base_score_prob=base_score_prob,
         )
         for tree_id, tree in enumerate(trees)
     )
@@ -197,6 +205,7 @@ def parse_xgb_ensemble(
     ensemble: xgb.Booster, *, mapper: Mapper[Feature]
 ) -> tuple[Tree, ...]:
     df = ensemble.trees_to_dataframe()
+    cfg = json.loads(ensemble.save_config())
     groups = df.groupby("Tree")
     trees = tuple(
         groups.get_group(tree_id).reset_index(drop=True)
@@ -205,7 +214,10 @@ def parse_xgb_ensemble(
 
     num_rounds = ensemble.num_boosted_rounds() or 1
     num_trees_per_round = max(1, len(trees) // num_rounds)
-
+    base_score_prob = float(cfg["learner"]["learner_model_param"]["base_score"])
     return parse_xgb_trees(
-        trees, num_trees_per_round=num_trees_per_round, mapper=mapper
+        trees,
+        num_trees_per_round=num_trees_per_round,
+        mapper=mapper,
+        base_score_prob=base_score_prob,
     )

From 19c88748891f60fb9d5f7fa530f601b54d4631fa Mon Sep 17 00:00:00 2001
From: Awa Khouna
Date: Tue, 11 Nov 2025 11:43:30 -0500
Subject: [PATCH 03/10] feat: Add logit management and xgboost property to the
 Tree class

---
 ocean/tree/_tree.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/ocean/tree/_tree.py b/ocean/tree/_tree.py
index 649d2f8..ae3d71b 100644
--- a/ocean/tree/_tree.py
+++ b/ocean/tree/_tree.py
@@ -9,6 +9,7 @@ class Tree:
     root: Node
     _shape: tuple[NonNegativeInt, ...]
+    _xgboost: bool = False
 
     def __init__(self, root: Node) -> None:
         self.root = root
@@ -30,6 +31,22 @@ def leaves(self) -> tuple[Node, *tuple[Node, ...]]:
     def shape(self) -> tuple[NonNegativeInt, ...]:
         return self._shape
 
+    @property
+    def logit(self) -> float:
+        return self._base_score_prob
+
+    @logit.setter
+    def logit(self, value: float) -> None:
+        self._base_score_prob = value
+
+    @property
+    def xgboost(self) -> bool:
+        return self._xgboost
+
+    @xgboost.setter
+    def xgboost(self, value: bool) -> None:
+        self._xgboost = value
+
     @validate_call
     def nodes_at(self, depth: NonNegativeInt) -> Iterator[Node]:
         return self._nodes_at(self.root, depth=depth)

From 4cf89189a79df2c9feb63fb6dba57e329cd6bffb Mon Sep 17 00:00:00 2001
From: Awa Khouna
Date: Tue, 11 Nov 2025 11:46:41 -0500
Subject: [PATCH 04/10] feat: Introduce base score management and XGBoost tree
 support in CP TreeManager

---
 ocean/cp/_managers/_tree.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/ocean/cp/_managers/_tree.py b/ocean/cp/_managers/_tree.py
index 28d4397..27bdeea 100644
--- a/ocean/cp/_managers/_tree.py
+++ b/ocean/cp/_managers/_tree.py
@@ -29,6 +29,12 @@ class TreeManager:
     # Scale for the scores.
     _score_scale: int = DEFAULT_SCORE_SCALE
 
+    # Base score for the ensemble.
+    _logit: float = 0.0
+
+    # Flag to indicate if the model is using XGBoost trees.
+    _xgboost: bool = False
+
     def __init__(
         self,
         trees: Iterable[Tree],
@@ -89,6 +95,9 @@ def _set_trees(
     ) -> None:
         def create(item: tuple[int, Tree]) -> TreeVar:
             t, tree = item
+            if tree.xgboost:
+                self._logit = tree.logit
+                self._xgboost = tree.xgboost
             name = self.TREE_VAR_FMT.format(t=t)
             return TreeVar(tree, name=name)
 
@@ -131,6 +140,8 @@ def weighted_function(
             tree_exprs.append(tree_expr)
             tree_weights.append(int(weight))
             expr = cp.LinearExpr.WeightedSum(tree_exprs, tree_weights)
+            if self._xgboost and n_classes == 2 and c == 1:  # noqa: PLR2004
+                expr += int(self._logit * scale)
             exprs[op, c] = expr
         return exprs
 

From 48186565a14797d7c6599d7b042a3aef7525d279 Mon Sep 17 00:00:00 2001
From: Awa Khouna
Date: Tue, 11 Nov 2025 11:47:03 -0500
Subject: [PATCH 05/10] feat: Introduce base score management and XGBoost tree
 support in MIP TreeManager

---
 ocean/mip/_managers/_tree.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/ocean/mip/_managers/_tree.py b/ocean/mip/_managers/_tree.py
index 462c917..fc927a4 100644
--- a/ocean/mip/_managers/_tree.py
+++ b/ocean/mip/_managers/_tree.py
@@ -36,6 +36,12 @@ class TreeManager:
     # Function of the ensemble.
     _function: gp.MLinExpr
 
+    # Base score for the ensemble.
+    _logit: float = 0.0
+
+    # Flag to indicate if the model is using XGBoost trees.
+    _xgboost: bool = False
+
     def __init__(
         self,
         trees: Iterable[Tree],
@@ -120,6 +126,19 @@ def weighted(tree: TreeVar, weight: float) -> gp.MLinExpr:
         zeros = gp.MLinExpr.zeros(self.shape)
         return sum(map(weighted, self.estimators, weights), zeros)
 
+    def xgb_margin_function(
+        self,
+        weights: NonNegativeArray1D,
+    ) -> gp.MLinExpr:
+        margin_values = gp.MLinExpr.zeros(self.shape)
+        if self.n_classes == 2:  # noqa: PLR2004
+            margin_values += weights[0] * self._logit * np.array([[0.0, 1.0]])
+
+        for tree, weight in zip(self.estimators, weights, strict=True):
+            margin_values += weight * tree.value
+
+        return margin_values
+
     def _set_trees(
         self,
         trees: Iterable[Tree],
@@ -128,6 +147,9 @@ def _set_trees(
     ) -> None:
         def create(item: tuple[int, Tree]) -> TreeVar:
             t, tree = item
+            if tree.xgboost:
+                self._logit = tree.logit
+                self._xgboost = tree.xgboost
             name = self.TREE_VAR_FMT.format(t=t)
             return TreeVar(tree, name=name, flow_type=flow_type)
 
@@ -152,4 +174,6 @@ def _get_length(self) -> gp.LinExpr:
         return sum((tree.length for tree in self.isolators), gp.LinExpr())
 
     def _get_function(self) -> gp.MLinExpr:
+        if self._xgboost:
+            return self.xgb_margin_function(weights=self.weights)
         return self.weighted_function(weights=self.weights)

From 271500d53de2b9980eb69077ba3a493a5db72521 Mon Sep 17 00:00:00 2001
From: Awa Khouna
Date: Tue, 11 Nov 2025 11:53:11 -0500
Subject: [PATCH 06/10] [XGBoost] OCEAN returns invalid counterfactuals
 (prediction doesn't flip) Fixes #30
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/tree/test_parse.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tree/test_parse.py b/tests/tree/test_parse.py
index 4b5a9b0..817e467 100644
--- a/tests/tree/test_parse.py
+++ b/tests/tree/test_parse.py
@@ -75,7 +75,7 @@ def _dfs(node: Node) -> None:
         feature_name = str(row["Feature"].values[0]).strip()
         if feature.is_numeric:
             assert feature_name == node.feature
-            assert node.threshold == float(row["Split"].values[0])
+            assert node.threshold == float(row["Split"].values[0]) - 1e-8
         if feature.is_one_hot_encoded:
            assert feature_name == f"{node.feature} {node.code}"
            assert node.code in feature.codes

From 7bfbaab852ca14963941947ef6aee880cae9e89b Mon Sep 17 00:00:00 2001
From: Awa Khouna
Date: Tue, 11 Nov 2025 12:02:09 -0500
Subject: [PATCH 07/10] fix: Correct threshold comparison in XGBoost tree
 parsing test and bypass pyright reportUnknownArgumentType

---
 tests/test_explainer.py  | 8 ++++----
 tests/tree/test_parse.py | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/test_explainer.py b/tests/test_explainer.py
index 97f5a3b..0362c0d 100644
--- a/tests/test_explainer.py
+++ b/tests/test_explainer.py
@@ -35,7 +35,7 @@ def test_mip_explain(
     x = data.iloc[0, :].to_numpy().astype(float).flatten()
     # pyright: ignore[reportUnknownVariableType]
     y = clf.predict([x])[0]
-    classes = np.unique(clf.predict(data.to_numpy()))
+    classes = np.unique(clf.predict(data.to_numpy())).astype(np.int64)  # pyright: ignore[reportUnknownArgumentType]
     for target in classes[classes != y]:
         try:
             exp = model.explain(
@@ -82,7 +82,7 @@ def test_mip_explain_xgb(
     x = data.iloc[0, :].to_numpy().astype(float).flatten()
     # pyright: ignore[reportUnknownVariableType]
     y = clf.predict([x])[0]
-    classes = np.unique(clf.predict(data.to_numpy()))
+    classes = np.unique(clf.predict(data.to_numpy())).astype(np.int64)  # pyright: ignore[reportUnknownArgumentType]
     for target in classes[classes != y]:
         try:
             exp = model.explain(
@@ -129,7 +129,7 @@ def test_cp_explain(
     x = data.iloc[0, :].to_numpy().astype(float).flatten()
     # pyright: ignore[reportUnknownVariableType]
     y = clf.predict([x])[0]
-    classes = np.unique(clf.predict(data.to_numpy()))
+    classes = np.unique(clf.predict(data.to_numpy())).astype(np.int64)  # pyright: ignore[reportUnknownArgumentType]
     for target in classes[classes != y]:
         try:
             exp = model.explain(
@@ -175,7 +175,7 @@ def test_cp_explain_xgb(
     x = data.iloc[0, :].to_numpy().astype(float).flatten()
     # pyright: ignore[reportUnknownVariableType]
     y = clf.predict([x])[0]
-    classes = np.unique(clf.predict(data.to_numpy()))
+    classes = np.unique(clf.predict(data.to_numpy())).astype(np.int64)  # pyright: ignore[reportUnknownArgumentType]
     for target in classes[classes != y]:
         try:
             exp = model.explain(
diff --git a/tests/tree/test_parse.py b/tests/tree/test_parse.py
index 817e467..ec68d51 100644
--- a/tests/tree/test_parse.py
+++ b/tests/tree/test_parse.py
@@ -75,7 +75,7 @@ def _dfs(node: Node) -> None:
         feature_name = str(row["Feature"].values[0]).strip()
         if feature.is_numeric:
             assert feature_name == node.feature
-            assert node.threshold == float(row["Split"].values[0]) - 1e-8
+            assert node.threshold == float(row["Split"].values[0] - 1e-8)
         if feature.is_one_hot_encoded:
            assert feature_name == f"{node.feature} {node.code}"
            assert node.code in feature.codes

From dbaf25443a1baa3322f055f9cf43a44e559896f8 Mon Sep 17 00:00:00 2001
From: Awa Khouna
Date: Tue, 11 Nov 2025 15:28:33 -0500
Subject: [PATCH 08/10] feat: Refactor TreeManager to use Array1D for logit
 and improve XGBoost handling

---
 ocean/cp/_managers/_tree.py  | 20 ++++++++++++--------
 ocean/mip/_managers/_tree.py | 12 ++++++++----
 2 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/ocean/cp/_managers/_tree.py b/ocean/cp/_managers/_tree.py
index 27bdeea..322a57b 100644
--- a/ocean/cp/_managers/_tree.py
+++ b/ocean/cp/_managers/_tree.py
@@ -4,11 +4,7 @@
 from ortools.sat.python import cp_model as cp
 
 from ...tree import Tree
-from ...typing import (
-    NonNegativeArray1D,
-    NonNegativeInt,
-    PositiveInt,
-)
+from ...typing import Array1D, NonNegativeArray1D, NonNegativeInt, PositiveInt
 from .._base import BaseModel
 from .._variables import TreeVar
 
@@ -16,6 +12,8 @@ class TreeManager:
     TREE_VAR_FMT: str = "tree[{t}]"
 
     DEFAULT_SCORE_SCALE: int = int(1e10)
+    XGBOOST_DEFAULT_CLASS: int = 1
+    NUM_BINARY_CLASS: int = 2
 
     # Tree variables in the ensemble.
     _trees: tuple[TreeVar, *tuple[TreeVar, ...]]
@@ -30,7 +28,7 @@ class TreeManager:
     _score_scale: int = DEFAULT_SCORE_SCALE
 
     # Base score for the ensemble.
-    _logit: float = 0.0
+    _logit: Array1D
 
     # Flag to indicate if the model is using XGBoost trees.
     _xgboost: bool = False
@@ -140,8 +138,14 @@ def weighted_function(
             tree_exprs.append(tree_expr)
             tree_weights.append(int(weight))
             expr = cp.LinearExpr.WeightedSum(tree_exprs, tree_weights)
-            if self._xgboost and n_classes == 2 and c == 1:  # noqa: PLR2004
-                expr += int(self._logit * scale)
+            if self._xgboost:
+                if (
+                    n_classes == self.NUM_BINARY_CLASS
+                    and c == self.XGBOOST_DEFAULT_CLASS
+                ):
+                    expr += int(self._logit[0] * scale)
+                elif n_classes > self.NUM_BINARY_CLASS:
+                    expr += int(self._logit[c] * scale)
             exprs[op, c] = expr
         return exprs
 
diff --git a/ocean/mip/_managers/_tree.py b/ocean/mip/_managers/_tree.py
index fc927a4..c0a5c49 100644
--- a/ocean/mip/_managers/_tree.py
+++ b/ocean/mip/_managers/_tree.py
@@ -6,6 +6,7 @@
 from ...tree import Tree
 from ...tree._utils import average_length
 from ...typing import (
+    Array1D,
     NonNegativeArray1D,
     NonNegativeInt,
     NonNegativeNumber,
@@ -17,6 +18,7 @@
 
 class TreeManager:
     TREE_VAR_FMT: str = "tree[{t}]"
+    NUM_BINARY_CLASS: int = 2
 
     # Tree variables in the ensemble.
     _trees: tuple[TreeVar, *tuple[TreeVar, ...]]
@@ -37,7 +39,7 @@ class TreeManager:
     # Function of the ensemble.
     _function: gp.MLinExpr
 
     # Base score for the ensemble.
-    _logit: float = 0.0
+    _logit: Array1D
 
     # Flag to indicate if the model is using XGBoost trees.
     _xgboost: bool = False
@@ -130,9 +132,11 @@ def xgb_margin_function(
         self,
         weights: NonNegativeArray1D,
     ) -> gp.MLinExpr:
-        margin_values = gp.MLinExpr.zeros(self.shape)
-        if self.n_classes == 2:  # noqa: PLR2004
-            margin_values += weights[0] * self._logit * np.array([[0.0, 1.0]])
+        margin_values = gp.MLinExpr.zeros(self.shape) + self._logit
+        if self.n_classes == self.NUM_BINARY_CLASS:
+            margin_values += (
+                weights[0] * self._logit[0] * np.array([[0.0, 1.0]])
+            )
 
         for tree, weight in zip(self.estimators, weights, strict=True):
             margin_values += weight * tree.value

From e6c18917c9da3d7fb6b91c57b0f6eb74b8aea7ce Mon Sep 17 00:00:00 2001
From: Awa Khouna
Date: Tue, 11 Nov 2025 15:29:50 -0500
Subject: [PATCH 09/10] feat: Add base score management and improve logit
 calculation in XGBoost tree parsing to comply with the xgboost version.

---
 ocean/tree/_parse_xgb.py | 35 ++++++++++++++++++++++++++---------
 ocean/tree/_tree.py      |  6 +++---
 2 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/ocean/tree/_parse_xgb.py b/ocean/tree/_parse_xgb.py
index 65ad63d..6f74283 100644
--- a/ocean/tree/_parse_xgb.py
+++ b/ocean/tree/_parse_xgb.py
@@ -1,16 +1,29 @@
 import json
 from collections.abc import Iterable
+from typing import Any
 
 import numpy as np
 import xgboost as xgb
 
 from ..abc import Mapper
 from ..feature import Feature
-from ..typing import NonNegativeInt, XGBTree
+from ..typing import Array1D, NonNegativeInt, XGBTree
 from ._node import Node
 from ._tree import Tree
 
 
+def _parse_base_score(cfg: dict[str, Any]) -> Array1D:
+    base_score_str = cfg["learner"]["learner_model_param"]["base_score"]
+    if isinstance(base_score_str, float):
+        return np.array([float(base_score_str)])
+    return np.array([float(s) for s in json.loads(base_score_str)])
+
+
+def _logit(p: Array1D) -> Array1D:
+    p = np.clip(p, 1e-12, 1 - 1e-12)
+    return np.log(p / (1 - p))
+
+
 def _get_column_value(
     xgb_tree: XGBTree, node_id: NonNegativeInt, column: str
 ) -> str | float | int:
@@ -150,7 +163,7 @@ def _parse_xgb_tree(
     tree_id: NonNegativeInt,
     num_trees_per_round: NonNegativeInt,
     mapper: Mapper[Feature],
-    base_score_prob: float = 0.0,
+    base_score_margin: Array1D,
 ) -> Tree:
     root = _parse_xgb_node(
         xgb_tree,
@@ -160,7 +173,11 @@ def _parse_xgb_tree(
         mapper=mapper,
     )
     tree = Tree(root=root)
-    tree.logit = np.log(base_score_prob / (1 - base_score_prob))
+    tree.logit = (
+        base_score_margin
+        if len(base_score_margin) > 1
+        else _logit(base_score_margin)
+    )
     tree.xgboost = True
     return tree
 
@@ -171,14 +188,14 @@ def parse_xgb_tree(
     tree_id: NonNegativeInt,
     num_trees_per_round: NonNegativeInt,
     mapper: Mapper[Feature],
-    base_score_prob: float = 0.0,
+    base_score_margin: Array1D,
 ) -> Tree:
     return _parse_xgb_tree(
         xgb_tree,
         tree_id=tree_id,
         num_trees_per_round=num_trees_per_round,
         mapper=mapper,
-        base_score_prob=base_score_prob,
+        base_score_margin=base_score_margin,
     )
 
 
@@ -187,7 +204,7 @@ def parse_xgb_trees(
     *,
     num_trees_per_round: NonNegativeInt,
     mapper: Mapper[Feature],
-    base_score_prob: float = 0.0,
+    base_score_margin: Array1D,
 ) -> tuple[Tree, ...]:
     return tuple(
         parse_xgb_tree(
@@ -195,7 +212,7 @@ def parse_xgb_trees(
             tree_id=tree_id,
             num_trees_per_round=num_trees_per_round,
             mapper=mapper,
-            base_score_prob=base_score_prob,
+            base_score_margin=base_score_margin,
         )
         for tree_id, tree in enumerate(trees)
     )
@@ -214,10 +231,10 @@ def parse_xgb_ensemble(
 
     num_rounds = ensemble.num_boosted_rounds() or 1
     num_trees_per_round = max(1, len(trees) // num_rounds)
-    base_score_prob = float(cfg["learner"]["learner_model_param"]["base_score"])
+    base_score_margin = _parse_base_score(cfg)
     return parse_xgb_trees(
         trees,
         num_trees_per_round=num_trees_per_round,
         mapper=mapper,
-        base_score_prob=base_score_prob,
+        base_score_margin=base_score_margin,
     )
diff --git a/ocean/tree/_tree.py b/ocean/tree/_tree.py
index ae3d71b..0fa6c35 100644
--- a/ocean/tree/_tree.py
+++ b/ocean/tree/_tree.py
@@ -2,7 +2,7 @@
 
 from pydantic import validate_call
 
-from ..typing import NonNegativeInt, PositiveInt
+from ..typing import Array1D, NonNegativeInt, PositiveInt
 from ._node import Node
 
 
@@ -32,11 +32,11 @@ def shape(self) -> tuple[NonNegativeInt, ...]:
         return self._shape
 
     @property
-    def logit(self) -> float:
+    def logit(self) -> Array1D:
         return self._base_score_prob
 
     @logit.setter
-    def logit(self, value: float) -> None:
+    def logit(self, value: Array1D) -> None:
         self._base_score_prob = value
 
     @property

From b21aa2233e11997e0e9a319eee2410e9f6ea8838 Mon Sep 17 00:00:00 2001
From: Awa Khouna
Date: Tue, 11 Nov 2025 15:33:57 -0500
Subject: [PATCH 10/10] fix: Correctly parse base score values in XGBoost
 configuration

---
 ocean/tree/_parse_xgb.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/ocean/tree/_parse_xgb.py b/ocean/tree/_parse_xgb.py
index 6f74283..8969898 100644
--- a/ocean/tree/_parse_xgb.py
+++ b/ocean/tree/_parse_xgb.py
@@ -13,10 +13,12 @@
 
 
 def _parse_base_score(cfg: dict[str, Any]) -> Array1D:
-    base_score_str = cfg["learner"]["learner_model_param"]["base_score"]
-    if isinstance(base_score_str, float):
-        return np.array([float(base_score_str)])
-    return np.array([float(s) for s in json.loads(base_score_str)])
+    base_score_values = json.loads(
+        cfg["learner"]["learner_model_param"]["base_score"]
+    )
+    if isinstance(base_score_values, float):
+        return np.array([float(base_score_values)])
+    return np.array([float(s) for s in base_score_values])
 
 
 def _logit(p: Array1D) -> Array1D: