From 41df8db036e4c07820a3699894d18b3ca3348bf9 Mon Sep 17 00:00:00 2001 From: eminyous Date: Sun, 12 Oct 2025 18:22:36 -0400 Subject: [PATCH 01/14] add xgboost to deps. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index e6d3e4c..a930ea3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "pandas", "pydantic", "scikit-learn", + "xgboost", ] optional-dependencies.dev = [ From 6461340a3943572e752e895eb57c2737693df183 Mon Sep 17 00:00:00 2001 From: eminyous Date: Sun, 12 Oct 2025 18:22:57 -0400 Subject: [PATCH 02/14] add support for xgboost --- ocean/tree/_parse.py | 27 +++-- ocean/tree/_parse_xgb.py | 225 +++++++++++++++++++++++++++++++++++++++ ocean/tree/_protocol.py | 2 + ocean/typing/__init__.py | 6 +- tests/tree/test_parse.py | 83 ++++++++++++++- 5 files changed, 335 insertions(+), 8 deletions(-) create mode 100644 ocean/tree/_parse_xgb.py diff --git a/ocean/tree/_parse.py b/ocean/tree/_parse.py index c42b0dc..239ae79 100644 --- a/ocean/tree/_parse.py +++ b/ocean/tree/_parse.py @@ -3,16 +3,21 @@ from functools import partial from itertools import chain +import xgboost as xgb from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from ..abc import Mapper from ..feature import Feature -from ..typing import NonNegativeInt, ParsableEnsemble +from ..typing import NonNegativeInt, ParsableEnsemble, SKLearnTree from ._node import Node -from ._protocol import SKLearnTree, SKLearnTreeProtocol, TreeProtocol +from ._parse_xgb import parse_xgb_ensemble +from ._protocol import ( + SKLearnTreeProtocol, + TreeProtocol, +) from ._tree import Tree -type DecisionTree = DecisionTreeClassifier | DecisionTreeRegressor +type SKLearnDecisionTree = DecisionTreeClassifier | DecisionTreeRegressor def _build_leaf(tree: TreeProtocol, node_id: NonNegativeInt) -> Node: @@ -70,13 +75,13 @@ def _parse_tree(sklearn_tree: SKLearnTree, *, mapper: Mapper[Feature]) -> Tree: return Tree(root=root) -def parse_tree(tree: DecisionTree, *, mapper: Mapper[Feature]) -> Tree: +def parse_tree(tree: SKLearnDecisionTree, *, mapper: Mapper[Feature]) -> Tree: getter = operator.attrgetter("tree_") return _parse_tree(getter(tree), mapper=mapper) def parse_trees( - trees: Iterable[DecisionTree], + trees: Iterable[SKLearnDecisionTree], *, mapper: Mapper[Feature], ) -> tuple[Tree, ...]: @@ -84,9 +89,19 @@ def parse_trees( return tuple(map(parser, trees)) +def parse_ensemble( + ensemble: ParsableEnsemble, + *, + mapper: Mapper[Feature], +) -> tuple[Tree, ...]: + if isinstance(ensemble, xgb.Booster): + return parse_xgb_ensemble(ensemble, mapper=mapper) + return parse_trees(ensemble, mapper=mapper) + + def parse_ensembles( *ensembles: ParsableEnsemble, mapper: Mapper[Feature], ) -> tuple[Tree, ...]: - parser = partial(parse_trees, mapper=mapper) + parser = partial(parse_ensemble, mapper=mapper) return tuple(chain.from_iterable(map(parser, ensembles))) diff --git a/ocean/tree/_parse_xgb.py b/ocean/tree/_parse_xgb.py new file mode 100644 index 0000000..93fee4e --- /dev/null +++ b/ocean/tree/_parse_xgb.py @@ -0,0 +1,225 @@ +from collections.abc import Iterable + +import numpy as np +import xgboost as xgb + +from ..abc import Mapper +from ..feature import Feature +from ..typing import NonNegativeInt, XGBTree +from ._node import Node +from ._tree import Tree + + +def _get_column_value( + xgb_tree: XGBTree, node_id: NonNegativeInt, column: str +) -> str | float | int: + mask = xgb_tree["Node"] == node_id + try: + return xgb_tree.loc[mask, 
column].to_numpy().item() + except Exception as exc: # pragma: no cover - defensive + msg = f"unable to read {column} for node {node_id}: {exc}" + raise ValueError(msg) from exc + + +def _build_xgb_leaf( + xgb_tree: XGBTree, + *, + node_id: NonNegativeInt, + tree_id: NonNegativeInt, + num_trees_per_round: NonNegativeInt, +) -> Node: + weight = float(_get_column_value(xgb_tree, node_id, "Gain")) + + if num_trees_per_round == 1: + value = np.array([weight, -weight]) + else: + k = int(tree_id % num_trees_per_round) + value = np.zeros(int(num_trees_per_round), dtype=float) + value[k] = weight + + return Node(node_id, n_samples=0, value=value) + + +def _parse_feature_info( + feature_name: str, mapper: Mapper[Feature] +) -> tuple[str, str | None]: + words = feature_name.split(" ") + name = words[0] if words else feature_name + code = words[1] if len(words) > 1 and words[1] else None + + if name not in mapper.names: + msg = f"feature '{name}' not found in mapper '{mapper.names}'" + raise KeyError(msg) + + return name, code + + +def _validate_feature_format( + name: str, + code: str | None, + mapper: Mapper[Feature], + node_id: NonNegativeInt, +) -> None: + if mapper[name].is_numeric and code: + msg = f"invalid numeric feature {name} for node {node_id}" + raise ValueError(msg) + + if mapper[name].is_one_hot_encoded: + if not code: + msg = f"invalid one-hot encoded feature {name} for node {node_id}" + raise ValueError(msg) + if code not in mapper.codes: + msg = f"code '{code}' not found in mapper codes '{mapper.codes}'" + raise KeyError(msg) + + +def _get_child_id( + xgb_tree: XGBTree, node_id: NonNegativeInt, column: str +) -> int: + raw = str(_get_column_value(xgb_tree, node_id, column)) + try: + return int(raw.rsplit("-", 1)[-1]) + except Exception as exc: # pragma: no cover - defensive + msg = ( + f"unable to parse child id from {column} for node {node_id}: {exc}" + ) + raise ValueError(msg) from exc + + +def _build_xgb_node( + xgb_tree: XGBTree, + *, + node_id: NonNegativeInt, + tree_id: NonNegativeInt, + num_trees_per_round: NonNegativeInt, + mapper: Mapper[Feature], +) -> Node: + feature_name = str(_get_column_value(xgb_tree, node_id, "Feature")) + name, code = _parse_feature_info(feature_name, mapper) + _validate_feature_format(name, code, mapper, node_id) + + threshold = None + if mapper[name].is_numeric: + threshold = float(_get_column_value(xgb_tree, node_id, "Split")) + mapper[name].add(threshold) + + left_id = _get_child_id(xgb_tree, node_id, "Yes") + right_id = _get_child_id(xgb_tree, node_id, "No") + + node = Node( + node_id, feature=name, threshold=threshold, code=code, n_samples=0 + ) + node.left = _parse_xgb_node( + xgb_tree, + node_id=left_id, + tree_id=tree_id, + num_trees_per_round=num_trees_per_round, + mapper=mapper, + ) + node.right = _parse_xgb_node( + xgb_tree, + node_id=right_id, + tree_id=tree_id, + num_trees_per_round=num_trees_per_round, + mapper=mapper, + ) + return node + + +def _parse_xgb_node( + xgb_tree: XGBTree, + node_id: NonNegativeInt, + *, + tree_id: NonNegativeInt, + num_trees_per_round: NonNegativeInt, + mapper: Mapper[Feature], +) -> Node: + mask = xgb_tree["Node"] == node_id + try: + feature_val = str(xgb_tree.loc[mask, "Feature"].to_numpy().item()) + except Exception as exc: # pragma: no cover - defensive + msg = f"unable to read Feature for node {node_id}: {exc}" + raise ValueError(msg) from exc + + if feature_val == "Leaf": + return _build_xgb_leaf( + xgb_tree, + node_id=node_id, + tree_id=tree_id, + num_trees_per_round=num_trees_per_round, + ) + + 
return _build_xgb_node( + xgb_tree, + node_id=node_id, + tree_id=tree_id, + num_trees_per_round=num_trees_per_round, + mapper=mapper, + ) + + +def _parse_xgb_tree( + xgb_tree: XGBTree, + *, + tree_id: NonNegativeInt, + num_trees_per_round: NonNegativeInt, + mapper: Mapper[Feature], +) -> Tree: + root = _parse_xgb_node( + xgb_tree, + node_id=0, + tree_id=tree_id, + num_trees_per_round=num_trees_per_round, + mapper=mapper, + ) + return Tree(root=root) + + +def parse_xgb_tree( + xgb_tree: XGBTree, + *, + tree_id: NonNegativeInt, + num_trees_per_round: NonNegativeInt, + mapper: Mapper[Feature], +) -> Tree: + return _parse_xgb_tree( + xgb_tree, + tree_id=tree_id, + num_trees_per_round=num_trees_per_round, + mapper=mapper, + ) + + +def parse_xgb_trees( + trees: Iterable[XGBTree], + *, + num_trees_per_round: NonNegativeInt, + mapper: Mapper[Feature], +) -> tuple[Tree, ...]: + return tuple( + parse_xgb_tree( + tree, + tree_id=tree_id, + num_trees_per_round=num_trees_per_round, + mapper=mapper, + ) + for tree_id, tree in enumerate(trees) + ) + + +def parse_xgb_ensemble( + ensemble: xgb.Booster, *, mapper: Mapper[Feature] +) -> tuple[Tree, ...]: + df = ensemble.trees_to_dataframe() + groups = df.groupby("Tree") + trees = tuple( + groups.get_group(tree_id).reset_index(drop=True) + for tree_id in groups.groups + ) + + num_rounds = ensemble.num_boosted_rounds() or 1 + num_trees_per_round = max(1, len(trees) // num_rounds) + + return parse_xgb_trees( + trees, num_trees_per_round=num_trees_per_round, mapper=mapper + ) diff --git a/ocean/tree/_protocol.py b/ocean/tree/_protocol.py index c664024..38f7e10 100644 --- a/ocean/tree/_protocol.py +++ b/ocean/tree/_protocol.py @@ -10,6 +10,7 @@ NonNegativeIntArray1D, PositiveInt, SKLearnTree, + XGBTree, ) @@ -40,4 +41,5 @@ def __init__(self, tree: SKLearnTree) -> None: "SKLearnTree", "SKLearnTreeProtocol", "TreeProtocol", + "XGBTree", ] diff --git a/ocean/typing/__init__.py b/ocean/typing/__init__.py index e8a2954..ebd69fc 100644 --- a/ocean/typing/__init__.py +++ b/ocean/typing/__init__.py @@ -3,11 +3,12 @@ import numpy as np import pandas as pd +import xgboost as xgb from pydantic import Field from sklearn.ensemble import IsolationForest, RandomForestClassifier type BaseExplainableEnsemble = RandomForestClassifier -type ParsableEnsemble = BaseExplainableEnsemble | IsolationForest +type ParsableEnsemble = BaseExplainableEnsemble | IsolationForest | xgb.Booster type Number = float type NonNegativeNumber = Annotated[Number, Field(ge=0.0)] @@ -75,6 +76,9 @@ class SKLearnTree(Protocol): value: Array +type XGBTree = pd.DataFrame + + class BaseExplanation(Protocol): @property def x(self) -> Array1D: ... 
diff --git a/tests/tree/test_parse.py b/tests/tree/test_parse.py index 9a00b39..2205203 100644 --- a/tests/tree/test_parse.py +++ b/tests/tree/test_parse.py @@ -1,10 +1,11 @@ import pytest +import xgboost as xgb from pydantic import ValidationError from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from ocean.abc import Mapper from ocean.feature import Feature -from ocean.tree import Node, parse_tree +from ocean.tree import Node, parse_ensembles, parse_tree from ocean.typing import SKLearnTree from ..utils import generate_data @@ -46,6 +47,50 @@ def _dfs(node: Node) -> None: _dfs(root) +def _check_xgb_tree( + root: Node, + booster: xgb.Booster, + *, + tree_id: int, + mapper: Mapper[Feature], +) -> None: + df = booster.trees_to_dataframe() + tree_df = df[df["Tree"] == tree_id].reset_index(drop=True) + + def _dfs(node: Node) -> None: + row = tree_df[tree_df["Node"] == node.node_id] + assert not row.empty, f"node {node.node_id} not found in tree {tree_id}" + + if node.is_leaf: + assert row["Feature"].values[0] == "Leaf" + assert (node.value == row["Gain"].values[0]).any() + else: + assert node.feature is not None + assert node.feature in mapper + assert node.left is not None + assert node.right is not None + assert len(node.children) == 2 + + feature = mapper[node.feature] + feature_name = str(row["Feature"].values[0]).strip() + if feature.is_numeric: + assert feature_name == node.feature + assert node.threshold == float(row["Split"].values[0]) + if feature.is_one_hot_encoded: + assert feature_name == f"{node.feature} {node.code}" + assert node.code in feature.codes + + left_id = int(str(row["Yes"].values[0]).split("-")[-1]) + right_id = int(str(row["No"].values[0]).split("-")[-1]) + assert node.left.node_id == left_id + assert node.right.node_id == right_id + + _dfs(node.left) + _dfs(node.right) + + _dfs(root) + + @pytest.mark.parametrize("seed", [42, 43, 44]) @pytest.mark.parametrize("max_depth", [2, 3, 4]) @pytest.mark.parametrize("n_classes", [2, 3, 4]) @@ -87,3 +132,39 @@ def test_parse_regressor(seed: int, n_samples: int, max_depth: int) -> None: assert tree.max_depth == dt.tree_.max_depth # pyright: ignore[reportAttributeAccessIssue] assert tree.shape == (1, 1) _check_tree(tree.root, dt.tree_, mapper=mapper) # pyright: ignore[reportArgumentType, reportUnknownArgumentType] + + +@pytest.mark.parametrize("seed", [42, 43, 44]) +@pytest.mark.parametrize("n_classes", [2, 3, 4]) +@pytest.mark.parametrize("n_samples", [100, 200, 500]) +@pytest.mark.parametrize("n_estimators", [3, 5, 4]) +def test_parse_xgb_classifier( + seed: int, + n_classes: int, + n_samples: int, + n_estimators: int, +) -> None: + data, y, mapper = generate_data(seed, n_samples, n_classes) + model = xgb.XGBClassifier( + n_estimators=n_estimators, + max_depth=3, + eval_metric="logloss", + random_state=seed, + ) + model.fit(data, y) + assert model is not None + booster = model.get_booster() + assert booster is not None + trees = parse_ensembles(booster, mapper=mapper) + assert len(trees) == n_estimators * (1 if n_classes == 2 else n_classes) + for i, tree in enumerate(trees): + assert tree.root is not None + assert tree.root.node_id == 0 + assert tree.max_depth >= 1 + + _check_xgb_tree( + tree.root, + booster, + tree_id=i, + mapper=mapper, + ) From 2c1d1e9dd43dea87a6e238d78eaf7e405340aff9 Mon Sep 17 00:00:00 2001 From: eminyous Date: Sun, 12 Oct 2025 18:28:22 -0400 Subject: [PATCH 03/14] ignore mypy --- ocean/tree/_parse_xgb.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git 
a/ocean/tree/_parse_xgb.py b/ocean/tree/_parse_xgb.py index 93fee4e..839ab94 100644 --- a/ocean/tree/_parse_xgb.py +++ b/ocean/tree/_parse_xgb.py @@ -14,11 +14,7 @@ def _get_column_value( xgb_tree: XGBTree, node_id: NonNegativeInt, column: str ) -> str | float | int: mask = xgb_tree["Node"] == node_id - try: - return xgb_tree.loc[mask, column].to_numpy().item() - except Exception as exc: # pragma: no cover - defensive - msg = f"unable to read {column} for node {node_id}: {exc}" - raise ValueError(msg) from exc + return xgb_tree.loc[mask, column].to_numpy().item() # type: ignore[no-any-return] def _build_xgb_leaf( From 0ea6d86d880df93afe4e907efbe61a1ed30c8947 Mon Sep 17 00:00:00 2001 From: eminyous Date: Sun, 12 Oct 2025 18:35:40 -0400 Subject: [PATCH 04/14] install openmp for mac CI --- .github/workflows/test.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 230365f..0e21ca7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,6 +42,10 @@ jobs: python-version: ${{ matrix.python-version }} cache: pip + - name: Install OpenMP (macOS) + run: brew install libomp + if: runner.os == 'macOS' + - name: Install tox run: python -m pip install --upgrade pip tox @@ -76,4 +80,4 @@ jobs: path: | coverage.xml htmlcov/ - retention-days: 30 + retention-days: 30 \ No newline at end of file From 05d1e92358b0f159a4a276acd6d9ea8f1de0d52c Mon Sep 17 00:00:00 2001 From: eminyous Date: Sun, 12 Oct 2025 18:48:12 -0400 Subject: [PATCH 05/14] add xgboost to check environment dependencies --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 7e37c88..0119c38 100644 --- a/tox.ini +++ b/tox.ini @@ -40,6 +40,7 @@ deps = scikit-learn scipy scipy-stubs + xgboost commands = pip install --upgrade pyright mypy ocean tests From 1b143770a970a0d898a6a34af437bfa40a0842aa Mon Sep 17 00:00:00 2001 From: eminyous Date: Sun, 12 Oct 2025 18:50:30 -0400 Subject: [PATCH 06/14] disable mypy --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 0119c38..887623d 100644 --- a/tox.ini +++ b/tox.ini @@ -43,7 +43,7 @@ deps = xgboost commands = pip install --upgrade pyright - mypy ocean tests + # mypy ocean tests pyright ocean tests [testenv:coverage] From 49f0a818c6cc13510757629d5cb2adb0a7544dd9 Mon Sep 17 00:00:00 2001 From: eminyous Date: Sun, 12 Oct 2025 18:57:16 -0400 Subject: [PATCH 07/14] enable mypy --- ocean/tree/_parse_xgb.py | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ocean/tree/_parse_xgb.py b/ocean/tree/_parse_xgb.py index 839ab94..881aa26 100644 --- a/ocean/tree/_parse_xgb.py +++ b/ocean/tree/_parse_xgb.py @@ -14,7 +14,7 @@ def _get_column_value( xgb_tree: XGBTree, node_id: NonNegativeInt, column: str ) -> str | float | int: mask = xgb_tree["Node"] == node_id - return xgb_tree.loc[mask, column].to_numpy().item() # type: ignore[no-any-return] + return xgb_tree.loc[mask, column].values[0] diff --git a/tox.ini b/tox.ini index 887623d..0119c38 100644 --- a/tox.ini +++ b/tox.ini @@ -43,7 +43,7 @@ deps = xgboost commands = pip install --upgrade pyright - # mypy ocean tests + mypy ocean tests pyright ocean tests [testenv:coverage] From 13be6bd7ab437c267f1302268447e3340217e459 Mon Sep 17 00:00:00 2001 From: eminyous Date: Sun, 12 Oct 2025 19:00:23 -0400 Subject: [PATCH 08/14] fix mypy.
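
mypy flags the untyped single-row lookups: `.values[0]` on a one-row selection is typed as `Any` by the pandas stubs, so keep one targeted `# type: ignore[no-any-return]` in `_get_column_value` and drop the defensive try/except wrappers around the child-id and feature lookups. A minimal sketch of the resulting lookup pattern, with a hand-built two-row frame standing in for real `Booster.trees_to_dataframe()` output (column names match the parser; the data values are illustrative):

    import pandas as pd

    # Hypothetical slice of a trees_to_dataframe() result:
    # one split node (0) and one leaf (1).
    df = pd.DataFrame({
        "Node": [0, 1],
        "Feature": ["f0", "Leaf"],
        "Yes": ["0-1", None],
        "No": ["0-2", None],
    })

    mask = df["Node"] == 0
    raw = str(df.loc[mask, "Yes"].values[0])  # "0-1"; typed as Any by the stubs
    child_id = int(raw.rsplit("-", 1)[-1])    # -> 1, as in _get_child_id()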
--- ocean/tree/_parse_xgb.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/ocean/tree/_parse_xgb.py b/ocean/tree/_parse_xgb.py index 881aa26..c694750 100644 --- a/ocean/tree/_parse_xgb.py +++ b/ocean/tree/_parse_xgb.py @@ -14,7 +14,7 @@ def _get_column_value( xgb_tree: XGBTree, node_id: NonNegativeInt, column: str ) -> str | float | int: mask = xgb_tree["Node"] == node_id - return xgb_tree.loc[mask, column].values[0] + return xgb_tree.loc[mask, column].values[0] # type: ignore[no-any-return] def _build_xgb_leaf( @@ -73,13 +73,7 @@ def _get_child_id( xgb_tree: XGBTree, node_id: NonNegativeInt, column: str ) -> int: raw = str(_get_column_value(xgb_tree, node_id, column)) - try: - return int(raw.rsplit("-", 1)[-1]) - except Exception as exc: # pragma: no cover - defensive - msg = ( - f"unable to parse child id from {column} for node {node_id}: {exc}" - ) - raise ValueError(msg) from exc + return int(raw.rsplit("-", 1)[-1]) def _build_xgb_node( @@ -131,11 +125,7 @@ def _parse_xgb_node( mapper: Mapper[Feature], ) -> Node: mask = xgb_tree["Node"] == node_id - try: - feature_val = str(xgb_tree.loc[mask, "Feature"].to_numpy().item()) - except Exception as exc: # pragma: no cover - defensive - msg = f"unable to read Feature for node {node_id}: {exc}" - raise ValueError(msg) from exc + feature_val = str(xgb_tree.loc[mask, "Feature"].to_numpy().item()) if feature_val == "Leaf": return _build_xgb_leaf( From a5cf40d92399e53340d81c285f33fe0d1b3faaf5 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Mon, 13 Oct 2025 14:08:15 -0400 Subject: [PATCH 09/14] fix: update parse_ensembles to handle xgb classifier correctly --- ocean/tree/_parse.py | 7 ++++++- tests/tree/test_parse.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ocean/tree/_parse.py b/ocean/tree/_parse.py index 239ae79..71501d6 100644 --- a/ocean/tree/_parse.py +++ b/ocean/tree/_parse.py @@ -96,7 +96,7 @@ def parse_ensemble( ) -> tuple[Tree, ...]: if isinstance(ensemble, xgb.Booster): return parse_xgb_ensemble(ensemble, mapper=mapper) - return parse_trees(ensemble, mapper=mapper) + return parse_trees(ensemble, mapper=mapper) # type: ignore[reportArgumentType] def parse_ensembles( @@ -104,4 +104,9 @@ def parse_ensembles( mapper: Mapper[Feature], ) -> tuple[Tree, ...]: parser = partial(parse_ensemble, mapper=mapper) + if all(isinstance(e, xgb.XGBClassifier) for e in ensembles): + xgb_ensembles = tuple( + e for e in ensembles if isinstance(e, xgb.XGBClassifier) + ) + ensembles = tuple(e.get_booster() for e in xgb_ensembles) return tuple(chain.from_iterable(map(parser, ensembles))) diff --git a/tests/tree/test_parse.py b/tests/tree/test_parse.py index 2205203..4b5a9b0 100644 --- a/tests/tree/test_parse.py +++ b/tests/tree/test_parse.py @@ -155,7 +155,7 @@ def test_parse_xgb_classifier( assert model is not None booster = model.get_booster() assert booster is not None - trees = parse_ensembles(booster, mapper=mapper) + trees = parse_ensembles(model, mapper=mapper) assert len(trees) == n_estimators * (1 if n_classes == 2 else n_classes) for i, tree in enumerate(trees): assert tree.root is not None From e422d7ec6fcab0981c998a48a7ffdba18abfcb60 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Mon, 13 Oct 2025 14:08:43 -0400 Subject: [PATCH 10/14] feat: add XGBoost example and update typing for explainable ensembles --- examples/random_forest.py | 91 ++++++++++++----- examples/{random_forest_cp.py => xgb.py} | 120 ++++++++++++++++------- ocean/typing/__init__.py | 4 +- 3 files changed, 155 
insertions(+), 60 deletions(-) rename examples/{random_forest_cp.py => xgb.py} (55%) diff --git a/examples/random_forest.py b/examples/random_forest.py index e817bed..01fc2be 100644 --- a/examples/random_forest.py +++ b/examples/random_forest.py @@ -10,11 +10,14 @@ from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import train_test_split -from ocean import MixedIntegerProgramExplainer +from ocean import ( + ConstraintProgrammingExplainer, + MixedIntegerProgramExplainer, +) from ocean.abc import Mapper from ocean.datasets import load_adult, load_compas, load_credit from ocean.feature import Feature -from ocean.typing import Array1D +from ocean.typing import Array1D, BaseExplainer Loaded = tuple[tuple[pd.DataFrame, "pd.Series[int]"], Mapper[Feature]] @@ -77,13 +80,15 @@ def load(dataset: str) -> Loaded: CONSOLE = Console() -def main() -> None: +def main(explainers: dict[str, type[BaseExplainer]]) -> None: args = parse_args() data, target, mapper = load_data(args) rf = fit_model(args, data, target) - mip = build_explainer(args, rf, mapper) - queries = generate_queries(args, rf, data) - times = run_queries(mip, queries) + times: dict[str, pd.Series[float]] = {} + for name, explainer in explainers.items(): + exp = build_explainer(name, explainer, args, rf, mapper) + queries = generate_queries(args, rf, data) + times[name] = run_queries(exp, queries) display_statistics(times) @@ -119,19 +124,24 @@ def fit_model( def build_explainer( + name: str, + explainer: type[BaseExplainer], args: Args, rf: RandomForestClassifier, mapper: Mapper[Feature], -) -> MixedIntegerProgramExplainer: +) -> BaseExplainer: with CONSOLE.status("[bold blue]Building the Explainer[/bold blue]"): - ENV.setParam("Seed", args.seed) start = time.time() - mip = MixedIntegerProgramExplainer(rf, mapper=mapper, env=ENV) + if name == "mip": + ENV.setParam("Seed", args.seed) + exp = explainer(rf, mapper=mapper, env=ENV) # type: ignore # noqa: PGH003 + else: + exp = explainer(rf, mapper=mapper) # type: ignore # noqa: PGH003 end = time.time() - CONSOLE.print("[bold green]Explainer built[/bold green]") + CONSOLE.print(f"[bold green]{name.upper()} Explainer built[/bold green]") msg = f"Build time: {end - start:.2f} seconds" - CONSOLE.print(f"[bold yellow]{msg}[/bold yellow]") - return mip + CONSOLE.print(f"\t[bold yellow]{msg}[/bold yellow]") + return exp def generate_queries( @@ -156,7 +166,7 @@ def generate_queries( def run_queries( - mip: MixedIntegerProgramExplainer, queries: list[tuple[Array1D, int]] + explainer: BaseExplainer, queries: list[tuple[Array1D, int]] ) -> "pd.Series[float]": times: pd.Series[float] = pd.Series() for i, (x, y) in track( @@ -165,27 +175,60 @@ def run_queries( description="[bold blue]Running queries[/bold blue]", ): start = time.time() - mip.explain(x, y=y, norm=1) - mip.cleanup() + explainer.explain(x, y=y, norm=1) + explainer.cleanup() end = time.time() times[i] = end - start return times -def display_statistics(times: "pd.Series[int]") -> None: +def create_table_row( + metric: str, times: dict[str, "pd.Series[float]"] +) -> list[str]: + row = [metric] + for t in times.values(): + if metric == "Number of queries": + row.append(str(len(t))) + elif metric == "Total time (seconds)": + row.append(f"{t.sum():.2f}") + elif metric == "Mean time per query (seconds)": + row.append(f"{t.mean():.2f}") + elif metric == "Std of time per query (seconds)": + row.append(f"{t.std():.2f}") + elif metric == "Maximum time per query (seconds)": + row.append(f"{t.max():.2f}") + elif metric == 
"Minimum time per query (seconds)": + row.append(f"{t.min():.2f}") + else: + row.append("N/A") + return row + + +def display_statistics(times: dict[str, "pd.Series[float]"]) -> None: CONSOLE.print("[bold blue]Statistics:[/bold blue]") table = Table(show_header=True, header_style="bold magenta") table.add_column("Metric", style="dim", width=30) - table.add_column("Value") - table.add_row("Number of queries", str(len(times))) - table.add_row("Total time (seconds)", f"{times.sum():.2f}") - table.add_row("Mean time per query (seconds)", f"{times.mean():.2f}") - table.add_row("Std of time per query (seconds)", f"{times.std():.2f}") - table.add_row("Maximum time per query (seconds)", f"{times.max():.2f}") - table.add_row("Minimum time per query (seconds)", f"{times.min():.2f}") + names = list(times.keys()) + for name in names: + table.add_column(name.upper()) + metrics = [ + "Number of queries", + "Total time (seconds)", + "Mean time per query (seconds)", + "Std of time per query (seconds)", + "Maximum time per query (seconds)", + "Minimum time per query (seconds)", + ] + for metric in metrics: + row = create_table_row(metric, times) + table.add_row(*row) CONSOLE.print(table) CONSOLE.print("[bold green]Done[/bold green]") if __name__ == "__main__": - main() + explainers = { + "mip": MixedIntegerProgramExplainer, + "cp": ConstraintProgrammingExplainer, + } + main(explainers) diff --git a/examples/random_forest_cp.py b/examples/xgb.py similarity index 55% rename from examples/random_forest_cp.py rename to examples/xgb.py index 7477ca6..282b380 100644 --- a/examples/random_forest_cp.py +++ b/examples/xgb.py @@ -2,26 +2,28 @@ from argparse import ArgumentParser from dataclasses import dataclass +import gurobipy as gp import pandas as pd from rich.console import Console from rich.progress import track from rich.table import Table -from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import train_test_split +from xgboost import XGBClassifier -from ocean import ConstraintProgrammingExplainer +from ocean import ( + ConstraintProgrammingExplainer, + MixedIntegerProgramExplainer, +) from ocean.abc import Mapper from ocean.datasets import load_adult, load_compas, load_credit from ocean.feature import Feature -from ocean.typing import Array1D +from ocean.typing import Array1D, BaseExplainer Loaded = tuple[tuple[pd.DataFrame, "pd.Series[int]"], Mapper[Feature]] @dataclass class Args: - """Command line arguments.""" - seed: int n_estimators: int max_depth: int @@ -72,16 +74,21 @@ def load(dataset: str) -> Loaded: raise ValueError(msg) +ENV = gp.Env(empty=True) +ENV.setParam("OutputFlag", 0) +ENV.start() CONSOLE = Console() -def main() -> None: +def main(explainers: dict[str, type[BaseExplainer]]) -> None: args = parse_args() data, target, mapper = load_data(args) - rf = fit_model(args, data, target) - cp = build_explainer(rf, mapper) - queries = generate_queries(args, rf, data) - times = run_queries(cp, queries) + clf = fit_model(args, data, target) + times: dict[str, pd.Series[float]] = {} + for name, explainer in explainers.items(): + exp = build_explainer(name, explainer, args, clf, mapper) + queries = generate_queries(args, clf, data) + times[name] = run_queries(exp, queries) display_statistics(times) @@ -98,40 +105,49 @@ def fit_model( args: Args, data: pd.DataFrame, target: "pd.Series[int]", -) -> RandomForestClassifier: +) -> XGBClassifier: X_train, _, y_train, _ = train_test_split( data, target, test_size=0.2, random_state=args.seed, ) - with CONSOLE.status("[bold 
blue]Fitting a Random Forest model[/bold blue]"): - rf = RandomForestClassifier( + with CONSOLE.status("[bold blue]Fitting a XGBoost model[/bold blue]"): + clf = XGBClassifier( n_estimators=args.n_estimators, random_state=args.seed, max_depth=args.max_depth, + eval_metric="logloss", ) - rf.fit(X_train, y_train) + clf.fit(X_train, y_train) CONSOLE.print("[bold green]Model fitted[/bold green]") - return rf + return clf def build_explainer( - rf: RandomForestClassifier, mapper: Mapper[Feature] -) -> ConstraintProgrammingExplainer: + name: str, + explainer: type[BaseExplainer], + args: Args, + clf: XGBClassifier, + mapper: Mapper[Feature], +) -> BaseExplainer: with CONSOLE.status("[bold blue]Building the Explainer[/bold blue]"): start = time.time() - cp = ConstraintProgrammingExplainer(rf, mapper=mapper) + if name == "mip": + ENV.setParam("Seed", args.seed) + exp = explainer(clf, mapper=mapper, env=ENV) # type: ignore # noqa: PGH003 + else: + exp = explainer(clf, mapper=mapper) # type: ignore # noqa: PGH003 end = time.time() - CONSOLE.print("[bold green]Explainer built[/bold green]") + CONSOLE.print(f"[bold green]{name.upper()} Explainer built[/bold green]") msg = f"Build time: {end - start:.2f} seconds" - CONSOLE.print(f"[bold yellow]{msg}[/bold yellow]") - return cp + CONSOLE.print(f"\t[bold yellow]{msg}[/bold yellow]") + return exp def generate_queries( args: Args, - rf: RandomForestClassifier, + clf: XGBClassifier, data: pd.DataFrame, ) -> list[tuple[Array1D, int]]: _, X_test = train_test_split( @@ -140,7 +156,7 @@ def generate_queries( random_state=args.seed, ) X_test = pd.DataFrame(X_test) - y_pred = rf.predict(X_test) + y_pred = clf.predict(X_test) with CONSOLE.status("[bold blue]Generating queries[/bold blue]"): queries: list[tuple[Array1D, int]] = [ (X_test.iloc[i].to_numpy().flatten(), 1 - y_pred[i]) @@ -151,7 +167,7 @@ def generate_queries( def run_queries( - cp: ConstraintProgrammingExplainer, queries: list[tuple[Array1D, int]] + explainer: BaseExplainer, queries: list[tuple[Array1D, int]] ) -> "pd.Series[float]": times: pd.Series[float] = pd.Series() for i, (x, y) in track( @@ -159,28 +175,62 @@ def run_queries( total=len(queries), description="[bold blue]Running queries[/bold blue]", ): + print(f"Running query {i + 1}/{len(queries)}, x= {x}, target={y}") start = time.time() - cp.explain(x, y=y, norm=1) - cp.cleanup() + explainer.explain(x, y=y, norm=1) + explainer.cleanup() end = time.time() times[i] = end - start return times -def display_statistics(times: "pd.Series[int]") -> None: +def create_table_row( + metric: str, times: dict[str, "pd.Series[float]"] +) -> list[str]: + row = [metric] + for t in times.values(): + if metric == "Number of queries": + row.append(str(len(t))) + elif metric == "Total time (seconds)": + row.append(f"{t.sum():.2f}") + elif metric == "Mean time per query (seconds)": + row.append(f"{t.mean():.2f}") + elif metric == "Std of time per query (seconds)": + row.append(f"{t.std():.2f}") + elif metric == "Maximum time per query (seconds)": + row.append(f"{t.max():.2f}") + elif metric == "Minimum time per query (seconds)": + row.append(f"{t.min():.2f}") + else: + row.append("N/A") + return row + + +def display_statistics(times: dict[str, "pd.Series[float]"]) -> None: CONSOLE.print("[bold blue]Statistics:[/bold blue]") table = Table(show_header=True, header_style="bold magenta") table.add_column("Metric", style="dim", width=30) - table.add_column("Value") - table.add_row("Number of queries", str(len(times))) - table.add_row("Total time (seconds)", 
f"{times.sum():.2f}") - table.add_row("Mean time per query (seconds)", f"{times.mean():.2f}") - table.add_row("Std of time per query (seconds)", f"{times.std():.2f}") - table.add_row("Maximum time per query (seconds)", f"{times.max():.2f}") - table.add_row("Minimum time per query (seconds)", f"{times.min():.2f}") + names = list(times.keys()) + for name in names: + table.add_column(name.upper()) + metrics = [ + "Number of queries", + "Total time (seconds)", + "Mean time per query (seconds)", + "Std of time per query (seconds)", + "Maximum time per query (seconds)", + "Minimum time per query (seconds)", + ] + for metric in metrics: + row = create_table_row(metric, times) + table.add_row(*row) CONSOLE.print(table) CONSOLE.print("[bold green]Done[/bold green]") if __name__ == "__main__": - main() + explainers = { + "mip": MixedIntegerProgramExplainer, + "cp": ConstraintProgrammingExplainer, + } + main(explainers) diff --git a/ocean/typing/__init__.py b/ocean/typing/__init__.py index ebd69fc..b736354 100644 --- a/ocean/typing/__init__.py +++ b/ocean/typing/__init__.py @@ -7,7 +7,7 @@ from pydantic import Field from sklearn.ensemble import IsolationForest, RandomForestClassifier -type BaseExplainableEnsemble = RandomForestClassifier +type BaseExplainableEnsemble = RandomForestClassifier | xgb.XGBClassifier type ParsableEnsemble = BaseExplainableEnsemble | IsolationForest | xgb.Booster type Number = float @@ -97,6 +97,8 @@ def explain( norm: PositiveInt, ) -> BaseExplanation | None: ... + def cleanup(self) -> None: ... + __all__ = [ "Array", From 328231f989da770590fdc0cfa64a5262e0488fd1 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Mon, 13 Oct 2025 14:27:12 -0400 Subject: [PATCH 11/14] fix: update parse_ensemble to handle XGBClassifier and improve ensemble parsing --- ocean/tree/_parse.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ocean/tree/_parse.py b/ocean/tree/_parse.py index 71501d6..3d3b090 100644 --- a/ocean/tree/_parse.py +++ b/ocean/tree/_parse.py @@ -96,7 +96,9 @@ def parse_ensemble( ) -> tuple[Tree, ...]: if isinstance(ensemble, xgb.Booster): return parse_xgb_ensemble(ensemble, mapper=mapper) - return parse_trees(ensemble, mapper=mapper) # type: ignore[reportArgumentType] + if isinstance(ensemble, xgb.XGBClassifier): + return parse_xgb_ensemble(ensemble.get_booster(), mapper=mapper) + return parse_trees(ensemble, mapper=mapper) def parse_ensembles( @@ -104,9 +106,4 @@ def parse_ensembles( mapper: Mapper[Feature], ) -> tuple[Tree, ...]: parser = partial(parse_ensemble, mapper=mapper) - if all(isinstance(e, xgb.XGBClassifier) for e in ensembles): - xgb_ensembles = tuple( - e for e in ensembles if isinstance(e, xgb.XGBClassifier) - ) - ensembles = tuple(e.get_booster() for e in xgb_ensembles) return tuple(chain.from_iterable(map(parser, ensembles))) From 1abfe081201df9d4343be029b597e0438e067e52 Mon Sep 17 00:00:00 2001 From: eminyous Date: Mon, 13 Oct 2025 17:48:22 -0400 Subject: [PATCH 12/14] refactor examples --- examples/{xgb.py => query.py} | 153 ++++++++++++++------- examples/random_forest.py | 234 -------------------------------- examples/simple_example_both.py | 50 +++++-- pyrightconfig.json | 11 -- 4 files changed, 141 insertions(+), 307 deletions(-) rename examples/{xgb.py => query.py} (62%) delete mode 100644 examples/random_forest.py diff --git a/examples/xgb.py b/examples/query.py similarity index 62% rename from examples/xgb.py rename to examples/query.py index 282b380..3b06b99 100644 --- a/examples/xgb.py +++ b/examples/query.py @@ -7,6 +7,7 
@@ from rich.console import Console from rich.progress import track from rich.table import Table +from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import train_test_split from xgboost import XGBClassifier @@ -19,7 +20,19 @@ from ocean.feature import Feature from ocean.typing import Array1D, BaseExplainer -Loaded = tuple[tuple[pd.DataFrame, "pd.Series[int]"], Mapper[Feature]] +# Global constants +ENV = gp.Env(empty=True) +ENV.setParam("OutputFlag", 0) +ENV.start() +CONSOLE = Console() +EXPLAINERS = { + "mip": MixedIntegerProgramExplainer, + "cp": ConstraintProgrammingExplainer, +} +MODELS = { + "rf": RandomForestClassifier, + "xgb": XGBClassifier, +} @dataclass @@ -29,9 +42,11 @@ class Args: max_depth: int n_examples: int dataset: str + explainers: list[str] + models: list[str] -def parse_args() -> Args: +def create_argument_parser() -> ArgumentParser: parser = ArgumentParser() parser.add_argument("--seed", type=int, default=42) parser.add_argument( @@ -53,6 +68,30 @@ def parse_args() -> Args: choices=["adult", "compas", "credit"], default="compas", ) + parser.add_argument( + "-e", + "--exp", + "--explainer", + help="List of explainers to use", + type=str, + nargs="+", + choices=["mip", "cp"], + default=["mip", "cp"], + ) + parser.add_argument( + "-m", + "--model", + help="List of models to use", + type=str, + nargs="+", + choices=["rf", "xgb"], + default=["rf"], + ) + return parser + + +def parse_args() -> Args: + parser = create_argument_parser() args = parser.parse_args() return Args( seed=args.seed, @@ -60,10 +99,14 @@ def parse_args() -> Args: max_depth=args.max_depth, n_examples=args.n_examples, dataset=args.dataset, + explainers=args.exp, + models=args.model, ) -def load(dataset: str) -> Loaded: +def load_dataset( + dataset: str, +) -> tuple[tuple[pd.DataFrame, pd.Series], Mapper[Feature]]: if dataset == "credit": return load_credit() if dataset == "adult": @@ -74,72 +117,61 @@ def load(dataset: str) -> Loaded: raise ValueError(msg) -ENV = gp.Env(empty=True) -ENV.setParam("OutputFlag", 0) -ENV.start() -CONSOLE = Console() - - -def main(explainers: dict[str, type[BaseExplainer]]) -> None: - args = parse_args() - data, target, mapper = load_data(args) - clf = fit_model(args, data, target) - times: dict[str, pd.Series[float]] = {} - for name, explainer in explainers.items(): - exp = build_explainer(name, explainer, args, clf, mapper) - queries = generate_queries(args, clf, data) - times[name] = run_queries(exp, queries) - display_statistics(times) - - -def load_data( - args: Args, -) -> tuple[pd.DataFrame, "pd.Series[int]", Mapper[Feature]]: +def load_data(args: Args) -> tuple[pd.DataFrame, pd.Series, Mapper[Feature]]: with CONSOLE.status("[bold blue]Loading the data[/bold blue]"): - (data, target), mapper = load(args.dataset) + (data, target), mapper = load_dataset(args.dataset) CONSOLE.print("[bold green]Data loaded[/bold green]") return data, target, mapper -def fit_model( +def fit_model_with_console( args: Args, data: pd.DataFrame, - target: "pd.Series[int]", -) -> XGBClassifier: + target: pd.Series, + model_class: type[RandomForestClassifier] | type[XGBClassifier], + model_name: str, + **model_kwargs: str | float | bool | None, +) -> RandomForestClassifier | XGBClassifier: X_train, _, y_train, _ = train_test_split( data, target, test_size=0.2, random_state=args.seed, ) - with CONSOLE.status("[bold blue]Fitting a XGBoost model[/bold blue]"): - clf = XGBClassifier( + with CONSOLE.status(f"[bold blue]Fitting a {model_name} model[/bold blue]"): + model = 
model_class( n_estimators=args.n_estimators, random_state=args.seed, max_depth=args.max_depth, - eval_metric="logloss", + **model_kwargs, ) - clf.fit(X_train, y_train) + model.fit(X_train, y_train) CONSOLE.print("[bold green]Model fitted[/bold green]") - return clf + return model def build_explainer( - name: str, - explainer: type[BaseExplainer], + explainer_name: str, + explainer_class: type[MixedIntegerProgramExplainer] + | type[ConstraintProgrammingExplainer], args: Args, - clf: XGBClassifier, + model: RandomForestClassifier | XGBClassifier, mapper: Mapper[Feature], ) -> BaseExplainer: with CONSOLE.status("[bold blue]Building the Explainer[/bold blue]"): start = time.time() - if name == "mip": + if explainer_class is MixedIntegerProgramExplainer: ENV.setParam("Seed", args.seed) - exp = explainer(clf, mapper=mapper, env=ENV) # type: ignore # noqa: PGH003 + exp = explainer_class(model, mapper=mapper, env=ENV) + elif explainer_class is ConstraintProgrammingExplainer: + exp = explainer_class(model, mapper=mapper) else: - exp = explainer(clf, mapper=mapper) # type: ignore # noqa: PGH003 + msg = f"Unknown explainer type: {explainer_class}" + raise ValueError(msg) end = time.time() - CONSOLE.print(f"[bold green]{name.upper()} Explainer built[/bold green]") + CONSOLE.print( + f"[bold green]{explainer_name.upper()} Explainer built[/bold green]" + ) msg = f"Build time: {end - start:.2f} seconds" CONSOLE.print(f"\t[bold yellow]{msg}[/bold yellow]") return exp @@ -147,7 +179,7 @@ def build_explainer( def generate_queries( args: Args, - clf: XGBClassifier, + model: RandomForestClassifier | XGBClassifier, data: pd.DataFrame, ) -> list[tuple[Array1D, int]]: _, X_test = train_test_split( @@ -156,7 +188,7 @@ def generate_queries( random_state=args.seed, ) X_test = pd.DataFrame(X_test) - y_pred = clf.predict(X_test) + y_pred = model.predict(X_test) with CONSOLE.status("[bold blue]Generating queries[/bold blue]"): queries: list[tuple[Array1D, int]] = [ (X_test.iloc[i].to_numpy().flatten(), 1 - y_pred[i]) @@ -166,7 +198,7 @@ def generate_queries( return queries -def run_queries( +def run_queries_verbose( explainer: BaseExplainer, queries: list[tuple[Array1D, int]] ) -> "pd.Series[float]": times: pd.Series[float] = pd.Series() @@ -207,6 +239,7 @@ def create_table_row( def display_statistics(times: dict[str, "pd.Series[float]"]) -> None: + """Display timing statistics in a table.""" CONSOLE.print("[bold blue]Statistics:[/bold blue]") table = Table(show_header=True, header_style="bold magenta") table.add_column("Metric", style="dim", width=30) @@ -228,9 +261,35 @@ def display_statistics(times: dict[str, "pd.Series[float]"]) -> None: CONSOLE.print("[bold green]Done[/bold green]") -if __name__ == "__main__": +def main() -> None: + args = parse_args() + data, target, mapper = load_data(args) explainers = { - "mip": MixedIntegerProgramExplainer, - "cp": ConstraintProgrammingExplainer, + name: explainer + for name, explainer in EXPLAINERS.items() + if name in args.explainers } - main(explainers) + models = { + name: model for name, model in MODELS.items() if name in args.models + } + for model_name, model_class in models.items(): + CONSOLE.print( + f"[bold blue]Running experiment with {model_name}: [/bold blue]" + ) + model = fit_model_with_console( + args, data, target, model_class, model_name + ) + for explainer_name, explainer_class in explainers.items(): + CONSOLE.print( + f"[bold blue]Running for {explainer_name}[/bold blue]" + ) + exp = build_explainer( + explainer_name, explainer_class, args, model, mapper 
+ ) + queries = generate_queries(args, model, data) + times = run_queries_verbose(exp, queries) + display_statistics({explainer_name: times}) + + +if __name__ == "__main__": + main() diff --git a/examples/random_forest.py b/examples/random_forest.py deleted file mode 100644 index 01fc2be..0000000 --- a/examples/random_forest.py +++ /dev/null @@ -1,234 +0,0 @@ -import time -from argparse import ArgumentParser -from dataclasses import dataclass - -import gurobipy as gp -import pandas as pd -from rich.console import Console -from rich.progress import track -from rich.table import Table -from sklearn.ensemble import RandomForestClassifier -from sklearn.model_selection import train_test_split - -from ocean import ( - ConstraintProgrammingExplainer, - MixedIntegerProgramExplainer, -) -from ocean.abc import Mapper -from ocean.datasets import load_adult, load_compas, load_credit -from ocean.feature import Feature -from ocean.typing import Array1D, BaseExplainer - -Loaded = tuple[tuple[pd.DataFrame, "pd.Series[int]"], Mapper[Feature]] - - -@dataclass -class Args: - seed: int - n_estimators: int - max_depth: int - n_examples: int - dataset: str - - -def parse_args() -> Args: - parser = ArgumentParser() - parser.add_argument("--seed", type=int, default=42) - parser.add_argument( - "--n-estimators", - type=int, - default=100, - dest="n_estimators", - ) - parser.add_argument("--max-depth", type=int, default=5, dest="max_depth") - parser.add_argument( - "--n-examples", - type=int, - default=100, - dest="n_examples", - ) - parser.add_argument( - "--dataset", - type=str, - choices=["adult", "compas", "credit"], - default="compas", - ) - args = parser.parse_args() - return Args( - seed=args.seed, - n_estimators=args.n_estimators, - max_depth=args.max_depth, - n_examples=args.n_examples, - dataset=args.dataset, - ) - - -def load(dataset: str) -> Loaded: - if dataset == "credit": - return load_credit() - if dataset == "adult": - return load_adult() - if dataset == "compas": - return load_compas() - msg = f"Unknown dataset: {dataset}" - raise ValueError(msg) - - -ENV = gp.Env(empty=True) -ENV.setParam("OutputFlag", 0) -ENV.start() -CONSOLE = Console() - - -def main(explainers: dict[str, type[BaseExplainer]]) -> None: - args = parse_args() - data, target, mapper = load_data(args) - rf = fit_model(args, data, target) - times: dict[str, pd.Series[float]] = {} - for name, explainer in explainers.items(): - exp = build_explainer(name, explainer, args, rf, mapper) - queries = generate_queries(args, rf, data) - times[name] = run_queries(exp, queries) - display_statistics(times) - - -def load_data( - args: Args, -) -> tuple[pd.DataFrame, "pd.Series[int]", Mapper[Feature]]: - with CONSOLE.status("[bold blue]Loading the data[/bold blue]"): - (data, target), mapper = load(args.dataset) - CONSOLE.print("[bold green]Data loaded[/bold green]") - return data, target, mapper - - -def fit_model( - args: Args, - data: pd.DataFrame, - target: "pd.Series[int]", -) -> RandomForestClassifier: - X_train, _, y_train, _ = train_test_split( - data, - target, - test_size=0.2, - random_state=args.seed, - ) - with CONSOLE.status("[bold blue]Fitting a Random Forest model[/bold blue]"): - rf = RandomForestClassifier( - n_estimators=args.n_estimators, - random_state=args.seed, - max_depth=args.max_depth, - ) - rf.fit(X_train, y_train) - CONSOLE.print("[bold green]Model fitted[/bold green]") - return rf - - -def build_explainer( - name: str, - explainer: type[BaseExplainer], - args: Args, - rf: RandomForestClassifier, - mapper: 
Mapper[Feature], -) -> BaseExplainer: - with CONSOLE.status("[bold blue]Building the Explainer[/bold blue]"): - start = time.time() - if name == "mip": - ENV.setParam("Seed", args.seed) - exp = explainer(rf, mapper=mapper, env=ENV) # type: ignore # noqa: PGH003 - else: - exp = explainer(rf, mapper=mapper) # type: ignore # noqa: PGH003 - end = time.time() - CONSOLE.print(f"[bold green]{name.upper()} Explainer built[/bold green]") - msg = f"Build time: {end - start:.2f} seconds" - CONSOLE.print(f"\t[bold yellow]{msg}[/bold yellow]") - return exp - - -def generate_queries( - args: Args, - rf: RandomForestClassifier, - data: pd.DataFrame, -) -> list[tuple[Array1D, int]]: - _, X_test = train_test_split( - data, - test_size=0.2, - random_state=args.seed, - ) - X_test = pd.DataFrame(X_test) - y_pred = rf.predict(X_test) - with CONSOLE.status("[bold blue]Generating queries[/bold blue]"): - queries: list[tuple[Array1D, int]] = [ - (X_test.iloc[i].to_numpy().flatten(), 1 - y_pred[i]) - for i in range(min(args.n_examples, len(X_test))) - ] - CONSOLE.print("[bold green]Queries generated[/bold green]") - return queries - - -def run_queries( - explainer: BaseExplainer, queries: list[tuple[Array1D, int]] -) -> "pd.Series[float]": - times: pd.Series[float] = pd.Series() - for i, (x, y) in track( - enumerate(queries), - total=len(queries), - description="[bold blue]Running queries[/bold blue]", - ): - start = time.time() - explainer.explain(x, y=y, norm=1) - explainer.cleanup() - end = time.time() - times[i] = end - start - return times - - -def create_table_row( - metric: str, times: dict[str, "pd.Series[float]"] -) -> list[str]: - row = [metric] - for t in times.values(): - if metric == "Number of queries": - row.append(str(len(t))) - elif metric == "Total time (seconds)": - row.append(f"{t.sum():.2f}") - elif metric == "Mean time per query (seconds)": - row.append(f"{t.mean():.2f}") - elif metric == "Std of time per query (seconds)": - row.append(f"{t.std():.2f}") - elif metric == "Maximum time per query (seconds)": - row.append(f"{t.max():.2f}") - elif metric == "Minimum time per query (seconds)": - row.append(f"{t.min():.2f}") - else: - row.append("N/A") - return row - - -def display_statistics(times: dict[str, "pd.Series[float]"]) -> None: - CONSOLE.print("[bold blue]Statistics:[/bold blue]") - table = Table(show_header=True, header_style="bold magenta") - table.add_column("Metric", style="dim", width=30) - names = list(times.keys()) - for name in names: - table.add_column(name.upper()) - metrics = [ - "Number of queries", - "Total time (seconds)", - "Mean time per query (seconds)", - "Std of time per query (seconds)", - "Maximum time per query (seconds)", - "Minimum time per query (seconds)", - ] - for metric in metrics: - row = create_table_row(metric, times) - table.add_row(*row) - CONSOLE.print(table) - CONSOLE.print("[bold green]Done[/bold green]") - - -if __name__ == "__main__": - explainers = { - "mip": MixedIntegerProgramExplainer, - "cp": ConstraintProgrammingExplainer, - } - main(explainers) diff --git a/examples/simple_example_both.py b/examples/simple_example_both.py index b7c3f74..4d2a316 100644 --- a/examples/simple_example_both.py +++ b/examples/simple_example_both.py @@ -29,8 +29,9 @@ from sklearn.tree import plot_tree import matplotlib.pyplot as plt + # Plot the first tree of the forest -plt.figure(figsize=(20,10)) +plt.figure(figsize=(20, 10)) plot_tree(rf.estimators_[0], filled=True) plt.title("First tree of the Random Forest") plt.savefig("./first_tree_rf.png") @@ -39,7 +40,7 @@ 
liste_thresholds = [] for tree in rf.estimators_: liste_thresholds.extend(tree.tree_.threshold[tree.tree_.feature == 25]) -print("Tree thresholds for feature 25:", sorted(liste_thresholds) ) +print("Tree thresholds for feature 25:", sorted(liste_thresholds)) print("RF train acc= ", rf.score(X_train, y_train)) print("RF test acc= ", rf.score(X_test, y_test)) @@ -50,9 +51,10 @@ ) # Define a CF query using the qid-th element of the test set -#qid = 1 -#query = X_test.iloc[qid] -import numpy as np +# qid = 1 +# query = X_test.iloc[qid] +import numpy as np + qid = 10 query = X_test.iloc[qid] query_pred = rf.predict([np.asarray(query)])[0] @@ -60,7 +62,7 @@ # Use the MILP formulation to generate a CF milp_model = MixedIntegerProgramExplainer(rf, mapper=mapper) -#print("milp_model._num_epsilon", milp_model._num_epsilon) +# print("milp_model._num_epsilon", milp_model._num_epsilon) start_ = time.time() explanation_ocean = milp_model.explain( query, @@ -73,7 +75,7 @@ ) milp_time = time.time() - start_ cf = explanation_ocean -#cf[4] += 0.0001 +# cf[4] += 0.0001 if explanation_ocean is not None: print( "MILP : ", @@ -82,7 +84,7 @@ rf.predict([explanation_ocean.to_numpy()])[0], ")", ) - #print("MILP Sollist = ", milp_model.get_anytime_solutions()) + # print("MILP Sollist = ", milp_model.get_anytime_solutions()) else: print("MILP: No CF found.") @@ -107,10 +109,16 @@ sample_id = 0 # obtain ids of the nodes `sample_id` goes through, i.e., row `sample_id` node_index = node_indicator.indices[ - node_indicator.indptr[sample_id] : node_indicator.indptr[sample_id + 1] + node_indicator.indptr[ + sample_id + ] : node_indicator.indptr[sample_id + 1] ] - print("[Tree {i}] Rules used to predict sample {id} with features values close to threshold:\n".format(i=i, id=sample_id)) + print( + "[Tree {i}] Rules used to predict sample {id} with features values close to threshold:\n".format( + i=i, id=sample_id + ) + ) for node_id in node_index: # continue to the next node if it is a leaf node if leaf_id[sample_id] == node_id: @@ -121,7 +129,10 @@ threshold_sign = "<=" else: threshold_sign = ">" - if np.abs(cf[feature[node_id]] - threshold[node_id]) < 1e-3: + if ( + np.abs(cf[feature[node_id]] - threshold[node_id]) + < 1e-3 + ): print( "decision node {node} : (cf[{feature}] = {value}) " "{inequality} {threshold})".format( @@ -161,7 +172,7 @@ rf.predict([explanation_oceancp.to_numpy()])[0], ")", ) - #print("CP Sollist = ", cp_model.get_anytime_solutions()) + # print("CP Sollist = ", cp_model.get_anytime_solutions()) else: print("CP: No CF found.") @@ -185,10 +196,16 @@ sample_id = 0 # obtain ids of the nodes `sample_id` goes through, i.e., row `sample_id` node_index = node_indicator.indices[ - node_indicator.indptr[sample_id] : node_indicator.indptr[sample_id + 1] + node_indicator.indptr[ + sample_id + ] : node_indicator.indptr[sample_id + 1] ] print(node_index) - print("[Tree {i}] Rules used to predict sample {id} with features values close to threshold:\n".format(i=i, id=sample_id)) + print( + "[Tree {i}] Rules used to predict sample {id} with features values close to threshold:\n".format( + i=i, id=sample_id + ) + ) for node_id in node_index: # continue to the next node if it is a leaf node if leaf_id[sample_id] == node_id: @@ -199,7 +216,10 @@ threshold_sign = "<=" else: threshold_sign = ">" - if np.abs(cf[feature[node_id]] - threshold[node_id]) < 1e-3: + if ( + np.abs(cf[feature[node_id]] - threshold[node_id]) + < 1e-3 + ): print( "decision node {node} : (cf[{feature}] = {value}) " "{inequality} {threshold})".format( 
diff --git a/pyrightconfig.json b/pyrightconfig.json index 587aebb..1b26508 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -6,16 +6,5 @@ "ignore": [ "sklearn", "anytree" - ], - "executionEnvironments": [ - { - "root": "./examples", - "extraPaths": [ - "../ocean" - ], - "reportArgumentType": "none", - "reportUnknownArgumentType": "none", - "reportUnknownVariableType": "none" - } ] } \ No newline at end of file From a0e6956eb021612ff66943b0322977f826939903 Mon Sep 17 00:00:00 2001 From: eminyous Date: Mon, 13 Oct 2025 17:48:31 -0400 Subject: [PATCH 13/14] fix xgb parsing. --- ocean/tree/_parse_xgb.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ocean/tree/_parse_xgb.py b/ocean/tree/_parse_xgb.py index c694750..4761bab 100644 --- a/ocean/tree/_parse_xgb.py +++ b/ocean/tree/_parse_xgb.py @@ -27,11 +27,11 @@ def _build_xgb_leaf( weight = float(_get_column_value(xgb_tree, node_id, "Gain")) if num_trees_per_round == 1: - value = np.array([weight, -weight]) + value = np.array([[weight, -weight]]) else: k = int(tree_id % num_trees_per_round) - value = np.zeros(int(num_trees_per_round), dtype=float) - value[k] = weight + value = np.zeros((1, int(num_trees_per_round)), dtype=float) + value[0, k] = weight return Node(node_id, n_samples=0, value=value) From f7917cd44f6141a7249adc10d071961c8c3a5dc6 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Mon, 13 Oct 2025 18:21:30 -0400 Subject: [PATCH 14/14] fix: remove verbose output during query execution and update README file --- README.md | 21 +++++++++++++++++---- examples/query.py | 1 - 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index c692122..adfbab4 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,21 @@ # Optimal Counterfactual Explanations in Tree Ensembles +[![Maintained](https://img.shields.io/badge/Maintained-YES-14b8a6?style=for-the-badge&logo=github)](https://github.com/vidalt/OCEAN/graphs/commit-activity) +[![License](https://img.shields.io/github/license/vidalt/OCEAN?style=for-the-badge&color=0ea5e9&logo=unlicense&logoColor=white)](https://github.com/vidalt/OCEAN/blob/main/LICENSE) +[![Contributors](https://img.shields.io/github/contributors/vidalt/OCEAN?style=for-the-badge&color=38bdf8&logo=github)](https://github.com/vidalt/OCEAN/graphs/contributors) +[![Stars](https://img.shields.io/github/stars/vidalt/OCEAN?style=for-the-badge&color=0284c7&logo=github)](https://github.com/vidalt/OCEAN/stargazers) +[![Watchers](https://img.shields.io/github/watchers/vidalt/OCEAN?style=for-the-badge&color=2563eb&logo=github)](https://github.com/vidalt/OCEAN/watchers) +[![Forks](https://img.shields.io/github/forks/vidalt/OCEAN?style=for-the-badge&color=1d4ed8&logo=github)](https://github.com/vidalt/OCEAN/network/members) +[![PRs](https://img.shields.io/github/issues-pr/vidalt/OCEAN?style=for-the-badge&color=22c55e&logo=github)](https://github.com/vidalt/OCEAN/pulls) + + + ![Logo](https://github.com/eminyous/ocean/blob/main/logo.svg?raw=True) **ocean** is a full package dedicated to counterfactual explanations for **tree ensembles**. It builds on the paper *Optimal Counterfactual Explanations in Tree Ensemble* by Axel Parmentier and Thibaut Vidal in the *Proceedings of the thirty-eighth International Conference on Machine Learning*, 2021, in press. The article is [available here](http://proceedings.mlr.press/v139/parmentier21a/parmentier21a.pdf). 
Beyond the original MIP approach, ocean includes a new **constraint programming (CP)** method and will grow to cover additional formulations and heuristics. - ## Installation You can install the package with the following command: @@ -88,7 +97,7 @@ WorkClass : 6 ``` - +See the [examples folder](https://github.com/vidalt/OCEAN/tree/main/examples) for more usage examples. ## Feature Preview & Roadmap @@ -101,6 +110,10 @@ WorkClass : 6 | **Heuristics** | ⏳ Upcoming | Fast approximate methods. | | **Other methods** | ⏳ Upcoming | Additional formulations under exploration. | | **Random Forest support** | ✅ Ready | Fully supported in ocean. | -| **XGBoost support** | ⏳ Upcoming | Implementation planned. | +| **XGBoost support** | ✅ Ready | Fully supported in ocean. | + +> Legend: ✅ available · ⏳ upcoming + +## Stargazers over time -> Legend: ✅ available · ⏳ upcoming \ No newline at end of file +[![Stargazers over time](https://starchart.cc/vidalt/OCEAN.svg)](https://starchart.cc/vidalt/OCEAN) \ No newline at end of file diff --git a/examples/query.py b/examples/query.py index 3b06b99..53284b5 100644 --- a/examples/query.py +++ b/examples/query.py @@ -207,7 +207,6 @@ def run_queries_verbose( total=len(queries), description="[bold blue]Running queries[/bold blue]", ): - print(f"Running query {i + 1}/{len(queries)}, x= {x}, target={y}") start = time.time() explainer.explain(x, y=y, norm=1) explainer.cleanup()
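
For reference, a minimal end-to-end sketch of the XGBoost support added by this series, following the patterns in tests/tree/test_parse.py and examples/query.py (the dataset choice and hyperparameters are illustrative, and the MIP explainer additionally assumes a working Gurobi setup):

    import xgboost as xgb

    from ocean import MixedIntegerProgramExplainer
    from ocean.datasets import load_compas

    # Load a bundled dataset together with its feature mapper.
    (data, target), mapper = load_compas()

    # Fit an XGBoost classifier; the explainers accept it directly
    # now that BaseExplainableEnsemble includes xgb.XGBClassifier.
    model = xgb.XGBClassifier(n_estimators=5, max_depth=3, eval_metric="logloss")
    model.fit(data, target)

    # Query a counterfactual that flips the prediction of the first row.
    x = data.iloc[0].to_numpy().flatten()
    y = 1 - int(model.predict(data.iloc[[0]])[0])
    explainer = MixedIntegerProgramExplainer(model, mapper=mapper)
    explanation = explainer.explain(x, y=y, norm=1)  # None if infeasible
    explainer.cleanup()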