diff --git a/ocean/cp/_managers/_tree.py b/ocean/cp/_managers/_tree.py
index 28d4397..322a57b 100644
--- a/ocean/cp/_managers/_tree.py
+++ b/ocean/cp/_managers/_tree.py
@@ -4,11 +4,7 @@
 from ortools.sat.python import cp_model as cp
 
 from ...tree import Tree
-from ...typing import (
-    NonNegativeArray1D,
-    NonNegativeInt,
-    PositiveInt,
-)
+from ...typing import Array1D, NonNegativeArray1D, NonNegativeInt, PositiveInt
 from .._base import BaseModel
 from .._variables import TreeVar
 
@@ -16,6 +12,8 @@ class TreeManager:
     TREE_VAR_FMT: str = "tree[{t}]"
 
     DEFAULT_SCORE_SCALE: int = int(1e10)
+    XGBOOST_DEFAULT_CLASS: int = 1
+    NUM_BINARY_CLASS: int = 2
 
     # Tree variables in the ensemble.
     _trees: tuple[TreeVar, *tuple[TreeVar, ...]]
@@ -29,6 +27,12 @@ class TreeManager:
     # Scale for the scores.
     _score_scale: int = DEFAULT_SCORE_SCALE
 
+    # Base score for the ensemble.
+    _logit: Array1D
+
+    # Flag to indicate if the model is using XGBoost trees.
+    _xgboost: bool = False
+
     def __init__(
         self,
         trees: Iterable[Tree],
@@ -89,6 +93,9 @@ def _set_trees(
     ) -> None:
         def create(item: tuple[int, Tree]) -> TreeVar:
             t, tree = item
+            if tree.xgboost:
+                self._logit = tree.logit
+                self._xgboost = tree.xgboost
             name = self.TREE_VAR_FMT.format(t=t)
             return TreeVar(tree, name=name)
 
@@ -131,6 +138,14 @@ def weighted_function(
             tree_exprs.append(tree_expr)
             tree_weights.append(int(weight))
             expr = cp.LinearExpr.WeightedSum(tree_exprs, tree_weights)
+            if self._xgboost:
+                if (
+                    n_classes == self.NUM_BINARY_CLASS
+                    and c == self.XGBOOST_DEFAULT_CLASS
+                ):
+                    expr += int(self._logit[0] * scale)
+                elif n_classes > self.NUM_BINARY_CLASS:
+                    expr += int(self._logit[c] * scale)
             exprs[op, c] = expr
         return exprs
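Note (reviewer sketch, not part of the patch): the new branch in weighted_function folds the XGBoost base score into the CP expression as a constant margin offset. For a binary model the offset is added only to the class-1 expression (XGBOOST_DEFAULT_CLASS); for a multiclass model each class c receives int(self._logit[c] * scale). A minimal numeric illustration of that arithmetic, assuming the usual probability-space base score of 0.5:

    import numpy as np

    scale = int(1e10)  # DEFAULT_SCORE_SCALE
    base_score = 0.5  # assumed probability-space base score
    logit = float(np.log(base_score / (1 - base_score)))  # 0.0 in margin space
    offset = int(logit * scale)  # the constant added to the class-1 CP expression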
diff --git a/ocean/mip/_managers/_tree.py b/ocean/mip/_managers/_tree.py
index 462c917..c0a5c49 100644
--- a/ocean/mip/_managers/_tree.py
+++ b/ocean/mip/_managers/_tree.py
@@ -6,6 +6,7 @@
 from ...tree import Tree
 from ...tree._utils import average_length
 from ...typing import (
+    Array1D,
     NonNegativeArray1D,
     NonNegativeInt,
     NonNegativeNumber,
@@ -17,6 +18,7 @@ class TreeManager:
     TREE_VAR_FMT: str = "tree[{t}]"
+    NUM_BINARY_CLASS: int = 2
 
     # Tree variables in the ensemble.
     _trees: tuple[TreeVar, *tuple[TreeVar, ...]]
@@ -36,6 +38,12 @@ class TreeManager:
     # Function of the ensemble.
     _function: gp.MLinExpr
 
+    # Base score for the ensemble.
+    _logit: Array1D
+
+    # Flag to indicate if the model is using XGBoost trees.
+    _xgboost: bool = False
+
     def __init__(
         self,
         trees: Iterable[Tree],
@@ -120,6 +128,21 @@ def weighted(tree: TreeVar, weight: float) -> gp.MLinExpr:
         zeros = gp.MLinExpr.zeros(self.shape)
         return sum(map(weighted, self.estimators, weights), zeros)
 
+    def xgb_margin_function(
+        self,
+        weights: NonNegativeArray1D,
+    ) -> gp.MLinExpr:
+        margin_values = gp.MLinExpr.zeros(self.shape) + self._logit
+        if self.n_classes == self.NUM_BINARY_CLASS:
+            margin_values += (
+                weights[0] * self._logit[0] * np.array([[0.0, 1.0]])
+            )
+
+        for tree, weight in zip(self.estimators, weights, strict=True):
+            margin_values += weight * tree.value
+
+        return margin_values
+
     def _set_trees(
         self,
         trees: Iterable[Tree],
@@ -128,6 +151,9 @@
     ) -> None:
         def create(item: tuple[int, Tree]) -> TreeVar:
             t, tree = item
+            if tree.xgboost:
+                self._logit = tree.logit
+                self._xgboost = tree.xgboost
             name = self.TREE_VAR_FMT.format(t=t)
             return TreeVar(tree, name=name, flow_type=flow_type)
 
@@ -152,4 +178,6 @@ def _get_length(self) -> gp.LinExpr:
         return sum((tree.length for tree in self.isolators), gp.LinExpr())
 
     def _get_function(self) -> gp.MLinExpr:
+        if self._xgboost:
+            return self.xgb_margin_function(weights=self.weights)
         return self.weighted_function(weights=self.weights)
diff --git a/ocean/tree/_parse_xgb.py b/ocean/tree/_parse_xgb.py
index 4761bab..8969898 100644
--- a/ocean/tree/_parse_xgb.py
+++ b/ocean/tree/_parse_xgb.py
@@ -1,15 +1,31 @@
+import json
 from collections.abc import Iterable
+from typing import Any
 
 import numpy as np
 import xgboost as xgb
 
 from ..abc import Mapper
 from ..feature import Feature
-from ..typing import NonNegativeInt, XGBTree
+from ..typing import Array1D, NonNegativeInt, XGBTree
 from ._node import Node
 from ._tree import Tree
 
 
+def _parse_base_score(cfg: dict[str, Any]) -> Array1D:
+    base_score_values = json.loads(
+        cfg["learner"]["learner_model_param"]["base_score"]
+    )
+    if isinstance(base_score_values, float):
+        return np.array([float(base_score_values)])
+    return np.array([float(s) for s in base_score_values])
+
+
+def _logit(p: Array1D) -> Array1D:
+    p = np.clip(p, 1e-12, 1 - 1e-12)
+    return np.log(p / (1 - p))
+
+
 def _get_column_value(
     xgb_tree: XGBTree, node_id: NonNegativeInt, column: str
 ) -> str | float | int:
@@ -27,12 +43,11 @@ def _build_xgb_leaf(
     weight = float(_get_column_value(xgb_tree, node_id, "Gain"))
 
     if num_trees_per_round == 1:
-        value = np.array([[weight, -weight]])
+        value = np.array([[0.0, weight]])
     else:
         k = int(tree_id % num_trees_per_round)
         value = np.zeros((1, int(num_trees_per_round)), dtype=float)
-        value[0, k] = weight
-
+        value[0, k] += weight
     return Node(node_id, n_samples=0, value=value)
 
 
@@ -90,7 +105,7 @@ def _build_xgb_node(
     threshold = None
 
     if mapper[name].is_numeric:
-        threshold = float(_get_column_value(xgb_tree, node_id, "Split"))
+        threshold = float(_get_column_value(xgb_tree, node_id, "Split")) - 1e-8
         mapper[name].add(threshold)
 
     left_id = _get_child_id(xgb_tree, node_id, "Yes")
@@ -150,6 +165,7 @@ def _parse_xgb_tree(
     tree_id: NonNegativeInt,
     num_trees_per_round: NonNegativeInt,
     mapper: Mapper[Feature],
+    base_score_margin: Array1D,
 ) -> Tree:
     root = _parse_xgb_node(
         xgb_tree,
@@ -158,7 +174,14 @@
         num_trees_per_round=num_trees_per_round,
         mapper=mapper,
     )
-    return Tree(root=root)
+    tree = Tree(root=root)
+    tree.logit = (
+        base_score_margin
+        if len(base_score_margin) > 1
+        else _logit(base_score_margin)
+    )
+    tree.xgboost = True
+    return tree
 
 
 def parse_xgb_tree(
@@ -167,12 +190,14 @@ def parse_xgb_tree(
     tree_id: NonNegativeInt,
     num_trees_per_round: NonNegativeInt,
     mapper: Mapper[Feature],
+    base_score_margin: Array1D,
 ) -> Tree:
     return _parse_xgb_tree(
         xgb_tree,
         tree_id=tree_id,
         num_trees_per_round=num_trees_per_round,
         mapper=mapper,
+        base_score_margin=base_score_margin,
     )
 
 
@@ -181,6 +206,7 @@ def parse_xgb_trees(
     *,
     num_trees_per_round: NonNegativeInt,
     mapper: Mapper[Feature],
+    base_score_margin: Array1D,
 ) -> tuple[Tree, ...]:
     return tuple(
         parse_xgb_tree(
             tree,
             tree_id=tree_id,
             num_trees_per_round=num_trees_per_round,
             mapper=mapper,
+            base_score_margin=base_score_margin,
         )
         for tree_id, tree in enumerate(trees)
     )
 
 
@@ -197,6 +224,7 @@ def parse_xgb_ensemble(
     ensemble: xgb.Booster, *, mapper: Mapper[Feature]
 ) -> tuple[Tree, ...]:
     df = ensemble.trees_to_dataframe()
+    cfg = json.loads(ensemble.save_config())
     groups = df.groupby("Tree")
     trees = tuple(
         groups.get_group(tree_id).reset_index(drop=True)
@@ -205,7 +233,10 @@
     num_rounds = ensemble.num_boosted_rounds() or 1
     num_trees_per_round = max(1, len(trees) // num_rounds)
-
+    base_score_margin = _parse_base_score(cfg)
     return parse_xgb_trees(
-        trees, num_trees_per_round=num_trees_per_round, mapper=mapper
+        trees,
+        num_trees_per_round=num_trees_per_round,
+        mapper=mapper,
+        base_score_margin=base_score_margin,
     )
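Note (reviewer sketch, not part of the patch): _parse_base_score pulls base_score out of the booster's save_config() JSON, and _logit maps a single probability-space value to margin (log-odds) space; per _parse_xgb_tree above, a multi-element base score is stored on the tree as-is. A small sketch of the single-output conversion:

    import numpy as np

    base_score = np.array([0.5])  # probability-space value read from the learner config
    p = np.clip(base_score, 1e-12, 1 - 1e-12)  # same guard as _logit
    margin = np.log(p / (1 - p))  # array([0.]); this is what ends up in tree.logit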
diff --git a/ocean/tree/_tree.py b/ocean/tree/_tree.py
index 649d2f8..0fa6c35 100644
--- a/ocean/tree/_tree.py
+++ b/ocean/tree/_tree.py
@@ -2,13 +2,14 @@
 
 from pydantic import validate_call
 
-from ..typing import NonNegativeInt, PositiveInt
+from ..typing import Array1D, NonNegativeInt, PositiveInt
 from ._node import Node
 
 
 class Tree:
     root: Node
     _shape: tuple[NonNegativeInt, ...]
+    _xgboost: bool = False
 
     def __init__(self, root: Node) -> None:
         self.root = root
@@ -30,6 +31,22 @@ def leaves(self) -> tuple[Node, *tuple[Node, ...]]:
     def shape(self) -> tuple[NonNegativeInt, ...]:
         return self._shape
 
+    @property
+    def logit(self) -> Array1D:
+        return self._base_score_prob
+
+    @logit.setter
+    def logit(self, value: Array1D) -> None:
+        self._base_score_prob = value
+
+    @property
+    def xgboost(self) -> bool:
+        return self._xgboost
+
+    @xgboost.setter
+    def xgboost(self, value: bool) -> None:
+        self._xgboost = value
+
     @validate_call
     def nodes_at(self, depth: NonNegativeInt) -> Iterator[Node]:
         return self._nodes_at(self.root, depth=depth)
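Note (reviewer sketch, not part of the patch): the two properties are plain get/set pairs, and the parser is expected to drive them as below; root stands for a previously built Node and is hypothetical here:

    import numpy as np

    tree = Tree(root=root)  # root: an already constructed Node (assumed)
    tree.xgboost = True
    tree.logit = np.array([0.0])  # margin-space base score
    assert tree.xgboost and tree.logit.shape == (1,)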
diff --git a/tests/test_explainer.py b/tests/test_explainer.py
index a34bd19..0362c0d 100644
--- a/tests/test_explainer.py
+++ b/tests/test_explainer.py
@@ -1,6 +1,8 @@
 import gurobipy as gp
+import numpy as np
 import pytest
 from sklearn.ensemble import RandomForestClassifier
+from xgboost import XGBClassifier
 
 from ocean import ConstraintProgrammingExplainer, MixedIntegerProgramExplainer
 
@@ -32,21 +34,73 @@ def test_mip_explain(
     x = data.iloc[0, :].to_numpy().astype(float).flatten()
     # pyright: ignore[reportUnknownVariableType]
+    y = clf.predict([x])[0]
+    classes = np.unique(clf.predict(data.to_numpy())).astype(np.int64)  # pyright: ignore[reportUnknownArgumentType]
+    for target in classes[classes != y]:
+        try:
+            exp = model.explain(
+                x,
+                y=target,
+                norm=1,
+                return_callback=True,
+                num_workers=num_workers,
+                random_seed=seed,
+            )
+            assert model.Status == gp.GRB.OPTIMAL
+            assert len(model.callback.sollist) != 0
+            assert exp is not None
+            assert clf.predict([exp.to_numpy()])[0] == target
+            model.cleanup()
-    try:
-        model.explain(x, y=0, norm=1,
-                      num_workers=num_workers,
-                      random_seed=seed)
-        assert model.Status == gp.GRB.OPTIMAL
-        model.cleanup()
-        model.explain(x, y=0, norm=1,
-                      return_callback=True,
-                      num_workers=num_workers,
-                      random_seed=seed)
-        assert len(model.callback.sollist) != 0
+        except gp.GurobiError as e:
+            pytest.skip(f"Skipping test due to {e}")
-    except gp.GurobiError as e:
-        pytest.skip(f"Skipping test due to {e}")
+
+
+@pytest.mark.parametrize("seed", [42, 43, 44])
+@pytest.mark.parametrize("n_estimators", [5])
+@pytest.mark.parametrize("max_depth", [2, 3])
+@pytest.mark.parametrize("n_classes", [2, 3, 4])
+@pytest.mark.parametrize("n_samples", [100, 200, 500])
+@pytest.mark.parametrize("num_workers", [1, 2, 4])
+def test_mip_explain_xgb(
+    seed: int,
+    n_estimators: int,
+    max_depth: int,
+    n_classes: int,
+    n_samples: int,
+    num_workers: int,
+) -> None:
+    data, y, mapper = generate_data(seed, n_samples, n_classes)
+    clf = XGBClassifier(
+        random_state=seed,
+        n_estimators=n_estimators,
+        max_depth=max_depth,
+    )
+    clf.fit(data, y)
+    model = MixedIntegerProgramExplainer(clf, mapper=mapper, env=ENV)
+
+    x = data.iloc[0, :].to_numpy().astype(float).flatten()
+    # pyright: ignore[reportUnknownVariableType]
+    y = clf.predict([x])[0]
+    classes = np.unique(clf.predict(data.to_numpy())).astype(np.int64)  # pyright: ignore[reportUnknownArgumentType]
+    for target in classes[classes != y]:
+        try:
+            exp = model.explain(
+                x,
+                y=target,
+                norm=1,
+                return_callback=True,
+                num_workers=num_workers,
+                random_seed=seed,
+            )
+            model.cleanup()
+            assert model.Status == gp.GRB.OPTIMAL
+            assert len(model.callback.sollist) != 0
+            assert exp is not None
+            assert clf.predict([exp.to_numpy()])[0] == target
+
+        except gp.GurobiError as e:
+            pytest.skip(f"Skipping test due to {e}")
 
 
 @pytest.mark.parametrize("seed", [42, 43, 44])
@@ -74,18 +128,69 @@ def test_cp_explain(
     x = data.iloc[0, :].to_numpy().astype(float).flatten()
     # pyright: ignore[reportUnknownVariableType]
+    y = clf.predict([x])[0]
+    classes = np.unique(clf.predict(data.to_numpy())).astype(np.int64)  # pyright: ignore[reportUnknownArgumentType]
+    for target in classes[classes != y]:
+        try:
+            exp = model.explain(
+                x,
+                y=target,
+                norm=1,
+                return_callback=True,
+                num_workers=num_workers,
+                random_seed=seed,
+            )
+            assert model.get_solving_status() == "OPTIMAL"
+            assert model.callback is None or len(model.callback.sollist) != 0
+            assert exp is not None
+            assert clf.predict([exp.to_numpy()])[0] == target
+            model.cleanup()
+        except gp.GurobiError as e:
+            pytest.skip(f"Skipping test due to {e}")
+
+
+@pytest.mark.parametrize("seed", [42, 43, 44])
+@pytest.mark.parametrize("n_estimators", [5])
+@pytest.mark.parametrize("max_depth", [2, 3])
+@pytest.mark.parametrize("n_classes", [2, 3, 4])
+@pytest.mark.parametrize("n_samples", [100, 200, 500])
+@pytest.mark.parametrize("num_workers", [1, 2, 4])
+def test_cp_explain_xgb(
+    seed: int,
+    n_estimators: int,
+    max_depth: int,
+    n_classes: int,
+    n_samples: int,
+    num_workers: int,
+) -> None:
+    data, y, mapper = generate_data(seed, n_samples, n_classes)
+    clf = XGBClassifier(
+        random_state=seed,
+        n_estimators=n_estimators,
+        max_depth=max_depth,
+    )
+    clf.fit(data, y)
+    model = ConstraintProgrammingExplainer(clf, mapper=mapper)
+
+    x = data.iloc[0, :].to_numpy().astype(float).flatten()
+    # pyright: ignore[reportUnknownVariableType]
+    y = clf.predict([x])[0]
+    classes = np.unique(clf.predict(data.to_numpy())).astype(np.int64)  # pyright: ignore[reportUnknownArgumentType]
+    for target in classes[classes != y]:
+        try:
+            exp = model.explain(
+                x,
+                y=target,
+                norm=1,
+                return_callback=True,
+                num_workers=num_workers,
+                random_seed=seed,
+            )
+            assert model.get_solving_status() == "OPTIMAL"
+            assert model.callback is None or len(model.callback.sollist) != 0
+            assert exp is not None
+            assert clf.predict([exp.to_numpy()])[0] == target
+            model.cleanup()
-    try:
-        _ = model.explain(x, y=0, norm=1,
-                          return_callback=False,
-                          num_workers=num_workers,
-                          random_seed=seed)
-        assert model.callback is None or len(model.callback.sollist) == 0
-        model.cleanup()
-        _ = model.explain(x, y=0, norm=1,
-                          return_callback=True,
-                          num_workers=num_workers,
-                          random_seed=seed)
-        assert model.callback is None or len(model.callback.sollist) != 0
-    except gp.GurobiError as e:
-        pytest.skip(f"Skipping test due to {e}")
+        except gp.GurobiError as e:
+            pytest.skip(f"Skipping test due to {e}")
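Note (reviewer sketch, not part of the patch): outside the test harness, the XGBoost path is driven the same way as the random-forest one. data, y, and mapper below stand in for whatever the tests' generate_data helper returns and are assumptions here, as is calling the explainer without the optional env, num_workers, and random_seed arguments:

    import numpy as np
    from xgboost import XGBClassifier

    from ocean import MixedIntegerProgramExplainer

    clf = XGBClassifier(n_estimators=5, max_depth=3, random_state=42)
    clf.fit(data, y)  # data, y, mapper: assumed to exist, as produced by generate_data
    model = MixedIntegerProgramExplainer(clf, mapper=mapper)

    x = data.iloc[0, :].to_numpy().astype(float).flatten()
    target = int(np.setdiff1d(clf.classes_, clf.predict([x]))[0])  # any non-predicted class
    exp = model.explain(x, y=target, norm=1)  # counterfactual pushed toward `target`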
diff --git a/tests/tree/test_parse.py b/tests/tree/test_parse.py
index 4b5a9b0..ec68d51 100644
--- a/tests/tree/test_parse.py
+++ b/tests/tree/test_parse.py
@@ -75,7 +75,7 @@ def _dfs(node: Node) -> None:
         feature_name = str(row["Feature"].values[0]).strip()
         if feature.is_numeric:
             assert feature_name == node.feature
-            assert node.threshold == float(row["Split"].values[0])
+            assert node.threshold == float(row["Split"].values[0] - 1e-8)
         if feature.is_one_hot_encoded:
             assert feature_name == f"{node.feature} {node.code}"
             assert node.code in feature.codes
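Note (reviewer sketch, not part of the patch): the - 1e-8 shift applied to numeric thresholds in the parser (and mirrored in the test above) bridges a convention gap: XGBoost routes a sample to the "Yes" child when x < Split, while the parsed tree is assumed to test x <= threshold. Nudging the stored threshold just below the split keeps the two routings consistent at the boundary:

    split = 0.5  # XGBoost split value
    threshold = split - 1e-8  # value stored on the parsed node
    x = 0.5  # a sample sitting exactly on the split
    assert (x < split) == (x <= threshold)  # both False: it goes right in both trees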