diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 230365f..0e21ca7 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -42,6 +42,10 @@ jobs:
           python-version: ${{ matrix.python-version }}
           cache: pip
 
+      - name: Install OpenMP (macOS)
+        run: brew install libomp
+        if: runner.os == 'macOS'
+
       - name: Install tox
         run: python -m pip install --upgrade pip tox
 
@@ -76,4 +80,4 @@ jobs:
           path: |
             coverage.xml
             htmlcov/
-          retention-days: 30
+          retention-days: 30
\ No newline at end of file
diff --git a/README.md b/README.md
index c692122..adfbab4 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,21 @@
 # Optimal Counterfactual Explanations in Tree Ensembles
 
+[![Maintained](https://img.shields.io/badge/Maintained-YES-14b8a6?style=for-the-badge&logo=github)](https://github.com/vidalt/OCEAN/graphs/commit-activity)
+[![License](https://img.shields.io/github/license/vidalt/OCEAN?style=for-the-badge&color=0ea5e9&logo=unlicense&logoColor=white)](https://github.com/vidalt/OCEAN/blob/main/LICENSE)
+[![Contributors](https://img.shields.io/github/contributors/vidalt/OCEAN?style=for-the-badge&color=38bdf8&logo=github)](https://github.com/vidalt/OCEAN/graphs/contributors)
+[![Stars](https://img.shields.io/github/stars/vidalt/OCEAN?style=for-the-badge&color=0284c7&logo=github)](https://github.com/vidalt/OCEAN/stargazers)
+[![Watchers](https://img.shields.io/github/watchers/vidalt/OCEAN?style=for-the-badge&color=2563eb&logo=github)](https://github.com/vidalt/OCEAN/watchers)
+[![Forks](https://img.shields.io/github/forks/vidalt/OCEAN?style=for-the-badge&color=1d4ed8&logo=github)](https://github.com/vidalt/OCEAN/network/members)
+[![PRs](https://img.shields.io/github/issues-pr/vidalt/OCEAN?style=for-the-badge&color=22c55e&logo=github)](https://github.com/vidalt/OCEAN/pulls)
+
+
+
 ![Logo](https://github.com/eminyous/ocean/blob/main/logo.svg?raw=True)
 
 **ocean** is a full package dedicated to counterfactual explanations for **tree ensembles**. It builds on the paper *Optimal Counterfactual Explanations in Tree Ensemble* by Axel Parmentier and Thibaut Vidal in the *Proceedings of the thirty-eighth International Conference on Machine Learning*, 2021, in press. The article is [available here](http://proceedings.mlr.press/v139/parmentier21a/parmentier21a.pdf). Beyond the original MIP approach, ocean includes a new **constraint programming (CP)** method and will grow to cover additional formulations and heuristics.
 
-
 ## Installation
 
 You can install the package with the following command:
@@ -88,7 +97,7 @@
 WorkClass : 6
 ```
 
-
+See the [examples folder](https://github.com/vidalt/OCEAN/tree/main/examples) for more usage examples.
 
 ## Feature Preview & Roadmap
 
@@ -101,6 +110,10 @@
 | **Heuristics** | ⏳ Upcoming | Fast approximate methods. |
 | **Other methods** | ⏳ Upcoming | Additional formulations under exploration. |
 | **Random Forest support** | ✅ Ready | Fully supported in ocean. |
-| **XGBoost support** | ⏳ Upcoming | Implementation planned. |
+| **XGBoost support** | ✅ Ready | Fully supported in ocean. |
+
+> Legend: ✅ available · ⏳ upcoming
+
+## Stargazers over time
 
-> Legend: ✅ available · ⏳ upcoming
\ No newline at end of file
+[![Stargazers over time](https://starchart.cc/vidalt/OCEAN.svg)](https://starchart.cc/vidalt/OCEAN)
\ No newline at end of file
diff --git a/examples/query.py b/examples/query.py
new file mode 100644
index 0000000..53284b5
--- /dev/null
+++ b/examples/query.py
@@ -0,0 +1,294 @@
+import time
+from argparse import ArgumentParser
+from dataclasses import dataclass
+
+import gurobipy as gp
+import pandas as pd
+from rich.console import Console
+from rich.progress import track
+from rich.table import Table
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.model_selection import train_test_split
+from xgboost import XGBClassifier
+
+from ocean import (
+    ConstraintProgrammingExplainer,
+    MixedIntegerProgramExplainer,
+)
+from ocean.abc import Mapper
+from ocean.datasets import load_adult, load_compas, load_credit
+from ocean.feature import Feature
+from ocean.typing import Array1D, BaseExplainer
+
+# Global constants
+ENV = gp.Env(empty=True)
+ENV.setParam("OutputFlag", 0)
+ENV.start()
+CONSOLE = Console()
+EXPLAINERS = {
+    "mip": MixedIntegerProgramExplainer,
+    "cp": ConstraintProgrammingExplainer,
+}
+MODELS = {
+    "rf": RandomForestClassifier,
+    "xgb": XGBClassifier,
+}
+
+
+@dataclass
+class Args:
+    seed: int
+    n_estimators: int
+    max_depth: int
+    n_examples: int
+    dataset: str
+    explainers: list[str]
+    models: list[str]
+
+
+def create_argument_parser() -> ArgumentParser:
+    parser = ArgumentParser()
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument(
+        "--n-estimators",
+        type=int,
+        default=100,
+        dest="n_estimators",
+    )
+    parser.add_argument("--max-depth", type=int, default=5, dest="max_depth")
+    parser.add_argument(
+        "--n-examples",
+        type=int,
+        default=100,
+        dest="n_examples",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        choices=["adult", "compas", "credit"],
+        default="compas",
+    )
+    parser.add_argument(
+        "-e",
+        "--exp",
+        "--explainer",
+        help="List of explainers to use",
+        type=str,
+        nargs="+",
+        choices=["mip", "cp"],
+        default=["mip", "cp"],
+    )
+    parser.add_argument(
+        "-m",
+        "--model",
+        help="List of models to use",
+        type=str,
+        nargs="+",
+        choices=["rf", "xgb"],
+        default=["rf"],
+    )
+    return parser
+
+
+def parse_args() -> Args:
+    parser = create_argument_parser()
+    args = parser.parse_args()
+    return Args(
+        seed=args.seed,
+        n_estimators=args.n_estimators,
+        max_depth=args.max_depth,
+        n_examples=args.n_examples,
+        dataset=args.dataset,
+        explainers=args.exp,
+        models=args.model,
+    )
+
+
+def load_dataset(
+    dataset: str,
+) -> tuple[tuple[pd.DataFrame, pd.Series], Mapper[Feature]]:
+    if dataset == "credit":
+        return load_credit()
+    if dataset == "adult":
+        return load_adult()
+    if dataset == "compas":
+        return load_compas()
+    msg = f"Unknown dataset: {dataset}"
+    raise ValueError(msg)
+
+
+def load_data(args: Args) -> tuple[pd.DataFrame, pd.Series, Mapper[Feature]]:
+    with CONSOLE.status("[bold blue]Loading the data[/bold blue]"):
+        (data, target), mapper = load_dataset(args.dataset)
+    CONSOLE.print("[bold green]Data loaded[/bold green]")
+    return data, target, mapper
+
+
+def fit_model_with_console(
+    args: Args,
+    data: pd.DataFrame,
+    target: pd.Series,
+    model_class: type[RandomForestClassifier] | type[XGBClassifier],
+    model_name: str,
+    **model_kwargs: str | float | bool | None,
+) -> RandomForestClassifier | XGBClassifier:
+    X_train, _, y_train, _ = train_test_split(
+        data,
+        target,
+        test_size=0.2,
+        random_state=args.seed,
+    )
+    with CONSOLE.status(f"[bold blue]Fitting a {model_name} model[/bold blue]"):
+        model = model_class(
+            n_estimators=args.n_estimators,
+            random_state=args.seed,
+            max_depth=args.max_depth,
+            **model_kwargs,
+        )
+        model.fit(X_train, y_train)
+    CONSOLE.print("[bold green]Model fitted[/bold green]")
+    return model
+
+
+def build_explainer(
+    explainer_name: str,
+    explainer_class: type[MixedIntegerProgramExplainer]
+    | type[ConstraintProgrammingExplainer],
+    args: Args,
+    model: RandomForestClassifier | XGBClassifier,
+    mapper: Mapper[Feature],
+) -> BaseExplainer:
+    with CONSOLE.status("[bold blue]Building the Explainer[/bold blue]"):
+        start = time.time()
+        if explainer_class is MixedIntegerProgramExplainer:
+            ENV.setParam("Seed", args.seed)
+            exp = explainer_class(model, mapper=mapper, env=ENV)
+        elif explainer_class is ConstraintProgrammingExplainer:
+            exp = explainer_class(model, mapper=mapper)
+        else:
+            msg = f"Unknown explainer type: {explainer_class}"
+            raise ValueError(msg)
+        end = time.time()
+    CONSOLE.print(
+        f"[bold green]{explainer_name.upper()} Explainer built[/bold green]"
+    )
+    msg = f"Build time: {end - start:.2f} seconds"
+    CONSOLE.print(f"\t[bold yellow]{msg}[/bold yellow]")
+    return exp
+
+
+def generate_queries(
+    args: Args,
+    model: RandomForestClassifier | XGBClassifier,
+    data: pd.DataFrame,
+) -> list[tuple[Array1D, int]]:
+    _, X_test = train_test_split(
+        data,
+        test_size=0.2,
+        random_state=args.seed,
+    )
+    X_test = pd.DataFrame(X_test)
+    y_pred = model.predict(X_test)
+    with CONSOLE.status("[bold blue]Generating queries[/bold blue]"):
+        queries: list[tuple[Array1D, int]] = [
+            (X_test.iloc[i].to_numpy().flatten(), 1 - y_pred[i])
+            for i in range(min(args.n_examples, len(X_test)))
+        ]
+    CONSOLE.print("[bold green]Queries generated[/bold green]")
+    return queries
+
+
+def run_queries_verbose(
+    explainer: BaseExplainer, queries: list[tuple[Array1D, int]]
+) -> "pd.Series[float]":
+    times: pd.Series[float] = pd.Series()
+    for i, (x, y) in track(
+        enumerate(queries),
+        total=len(queries),
+        description="[bold blue]Running queries[/bold blue]",
+    ):
+        start = time.time()
+        explainer.explain(x, y=y, norm=1)
+        explainer.cleanup()
+        end = time.time()
+        times[i] = end - start
+    return times
+
+
+def create_table_row(
+    metric: str, times: dict[str, "pd.Series[float]"]
+) -> list[str]:
+    row = [metric]
+    for t in times.values():
+        if metric == "Number of queries":
+            row.append(str(len(t)))
+        elif metric == "Total time (seconds)":
+            row.append(f"{t.sum():.2f}")
+        elif metric == "Mean time per query (seconds)":
+            row.append(f"{t.mean():.2f}")
+        elif metric == "Std of time per query (seconds)":
+            row.append(f"{t.std():.2f}")
+        elif metric == "Maximum time per query (seconds)":
+            row.append(f"{t.max():.2f}")
+        elif metric == "Minimum time per query (seconds)":
+            row.append(f"{t.min():.2f}")
+        else:
+            row.append("N/A")
+    return row
+
+
+def display_statistics(times: dict[str, "pd.Series[float]"]) -> None:
+    """Display timing statistics in a table."""
+    CONSOLE.print("[bold blue]Statistics:[/bold blue]")
+    table = Table(show_header=True, header_style="bold magenta")
+    table.add_column("Metric", style="dim", width=30)
+    names = list(times.keys())
+    for name in names:
+        table.add_column(name.upper())
+    metrics = [
+        "Number of queries",
+        "Total time (seconds)",
+        "Mean time per query (seconds)",
+        "Std of time per query (seconds)",
+        "Maximum time per query (seconds)",
+        "Minimum time per query (seconds)",
+    ]
+    for metric in metrics:
+        row = create_table_row(metric, times)
+        table.add_row(*row)
+    CONSOLE.print(table)
+    CONSOLE.print("[bold green]Done[/bold green]")
+
+
+def main() -> None:
+    args = parse_args()
+    data, target, mapper = load_data(args)
+    explainers = {
+        name: explainer
+        for name, explainer in EXPLAINERS.items()
+        if name in args.explainers
+    }
+    models = {
+        name: model for name, model in MODELS.items() if name in args.models
+    }
+    for model_name, model_class in models.items():
+        CONSOLE.print(
+            f"[bold blue]Running experiment with {model_name}: [/bold blue]"
+        )
+        model = fit_model_with_console(
+            args, data, target, model_class, model_name
+        )
+        for explainer_name, explainer_class in explainers.items():
+            CONSOLE.print(
+                f"[bold blue]Running for {explainer_name}[/bold blue]"
+            )
+            exp = build_explainer(
+                explainer_name, explainer_class, args, model, mapper
+            )
+            queries = generate_queries(args, model, data)
+            times = run_queries_verbose(exp, queries)
+            display_statistics({explainer_name: times})
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/random_forest.py b/examples/random_forest.py
deleted file mode 100644
index e817bed..0000000
--- a/examples/random_forest.py
+++ /dev/null
@@ -1,191 +0,0 @@
-import time
-from argparse import ArgumentParser
-from dataclasses import dataclass
-
-import gurobipy as gp
-import pandas as pd
-from rich.console import Console
-from rich.progress import track
-from rich.table import Table
-from sklearn.ensemble import RandomForestClassifier
-from sklearn.model_selection import train_test_split
-
-from ocean import MixedIntegerProgramExplainer
-from ocean.abc import Mapper
-from ocean.datasets import load_adult, load_compas, load_credit
-from ocean.feature import Feature
-from ocean.typing import Array1D
-
-Loaded = tuple[tuple[pd.DataFrame, "pd.Series[int]"], Mapper[Feature]]
-
-
-@dataclass
-class Args:
-    seed: int
-    n_estimators: int
-    max_depth: int
-    n_examples: int
-    dataset: str
-
-
-def parse_args() -> Args:
-    parser = ArgumentParser()
-    parser.add_argument("--seed", type=int, default=42)
-    parser.add_argument(
-        "--n-estimators",
-        type=int,
-        default=100,
-        dest="n_estimators",
-    )
-    parser.add_argument("--max-depth", type=int, default=5, dest="max_depth")
-    parser.add_argument(
-        "--n-examples",
-        type=int,
-        default=100,
-        dest="n_examples",
-    )
-    parser.add_argument(
-        "--dataset",
-        type=str,
-        choices=["adult", "compas", "credit"],
-        default="compas",
-    )
-    args = parser.parse_args()
-    return Args(
-        seed=args.seed,
-        n_estimators=args.n_estimators,
-        max_depth=args.max_depth,
-        n_examples=args.n_examples,
-        dataset=args.dataset,
-    )
-
-
-def load(dataset: str) -> Loaded:
-    if dataset == "credit":
-        return load_credit()
-    if dataset == "adult":
-        return load_adult()
-    if dataset == "compas":
-        return load_compas()
-    msg = f"Unknown dataset: {dataset}"
-    raise ValueError(msg)
-
-
-ENV = gp.Env(empty=True)
-ENV.setParam("OutputFlag", 0)
-ENV.start()
-CONSOLE = Console()
-
-
-def main() -> None:
-    args = parse_args()
-    data, target, mapper = load_data(args)
-    rf = fit_model(args, data, target)
-    mip = build_explainer(args, rf, mapper)
-    queries = generate_queries(args, rf, data)
-    times = run_queries(mip, queries)
-    display_statistics(times)
-
-
-def load_data(
-    args: Args,
-) -> tuple[pd.DataFrame, "pd.Series[int]", Mapper[Feature]]:
-    with CONSOLE.status("[bold blue]Loading the data[/bold blue]"):
-        (data, target), mapper = load(args.dataset)
-    CONSOLE.print("[bold green]Data loaded[/bold green]")
-    return data, target, mapper
-
-
-def fit_model(
-    args: Args,
-    data: pd.DataFrame,
-    target: "pd.Series[int]",
-) -> RandomForestClassifier:
-    X_train, _, y_train, _ = train_test_split(
-        data,
-        target,
-        test_size=0.2,
-        random_state=args.seed,
-    )
-    with CONSOLE.status("[bold blue]Fitting a Random Forest model[/bold blue]"):
-        rf = RandomForestClassifier(
-            n_estimators=args.n_estimators,
-            random_state=args.seed,
-            max_depth=args.max_depth,
-        )
-        rf.fit(X_train, y_train)
-    CONSOLE.print("[bold green]Model fitted[/bold green]")
-    return rf
-
-
-def build_explainer(
-    args: Args,
-    rf: RandomForestClassifier,
-    mapper: Mapper[Feature],
-) -> MixedIntegerProgramExplainer:
-    with CONSOLE.status("[bold blue]Building the Explainer[/bold blue]"):
-        ENV.setParam("Seed", args.seed)
-        start = time.time()
-        mip = MixedIntegerProgramExplainer(rf, mapper=mapper, env=ENV)
-        end = time.time()
-    CONSOLE.print("[bold green]Explainer built[/bold green]")
-    msg = f"Build time: {end - start:.2f} seconds"
-    CONSOLE.print(f"[bold yellow]{msg}[/bold yellow]")
-    return mip
-
-
-def generate_queries(
-    args: Args,
-    rf: RandomForestClassifier,
-    data: pd.DataFrame,
-) -> list[tuple[Array1D, int]]:
-    _, X_test = train_test_split(
-        data,
-        test_size=0.2,
-        random_state=args.seed,
-    )
-    X_test = pd.DataFrame(X_test)
-    y_pred = rf.predict(X_test)
-    with CONSOLE.status("[bold blue]Generating queries[/bold blue]"):
-        queries: list[tuple[Array1D, int]] = [
-            (X_test.iloc[i].to_numpy().flatten(), 1 - y_pred[i])
-            for i in range(min(args.n_examples, len(X_test)))
-        ]
-    CONSOLE.print("[bold green]Queries generated[/bold green]")
-    return queries
-
-
-def run_queries(
-    mip: MixedIntegerProgramExplainer, queries: list[tuple[Array1D, int]]
-) -> "pd.Series[float]":
-    times: pd.Series[float] = pd.Series()
-    for i, (x, y) in track(
-        enumerate(queries),
-        total=len(queries),
-        description="[bold blue]Running queries[/bold blue]",
-    ):
-        start = time.time()
-        mip.explain(x, y=y, norm=1)
-        mip.cleanup()
-        end = time.time()
-        times[i] = end - start
-    return times
-
-
-def display_statistics(times: "pd.Series[int]") -> None:
-    CONSOLE.print("[bold blue]Statistics:[/bold blue]")
-    table = Table(show_header=True, header_style="bold magenta")
-    table.add_column("Metric", style="dim", width=30)
-    table.add_column("Value")
-    table.add_row("Number of queries", str(len(times)))
-    table.add_row("Total time (seconds)", f"{times.sum():.2f}")
-    table.add_row("Mean time per query (seconds)", f"{times.mean():.2f}")
-    table.add_row("Std of time per query (seconds)", f"{times.std():.2f}")
-    table.add_row("Maximum time per query (seconds)", f"{times.max():.2f}")
-    table.add_row("Minimum time per query (seconds)", f"{times.min():.2f}")
-    CONSOLE.print(table)
-    CONSOLE.print("[bold green]Done[/bold green]")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/random_forest_cp.py b/examples/random_forest_cp.py
deleted file mode 100644
index 7477ca6..0000000
--- a/examples/random_forest_cp.py
+++ /dev/null
@@ -1,186 +0,0 @@
-import time
-from argparse import ArgumentParser
-from dataclasses import dataclass
-
-import pandas as pd
-from rich.console import Console
-from rich.progress import track
-from rich.table import Table
-from sklearn.ensemble import RandomForestClassifier
-from sklearn.model_selection import train_test_split
-
-from ocean import ConstraintProgrammingExplainer
-from ocean.abc import Mapper
-from ocean.datasets import load_adult, load_compas, load_credit
-from ocean.feature import Feature
-from ocean.typing import Array1D
-
-Loaded = tuple[tuple[pd.DataFrame, "pd.Series[int]"], Mapper[Feature]]
-
-
-@dataclass
-class Args:
-    """Command line arguments."""
-
-    seed: int
-    n_estimators: int
-    max_depth: int
-    n_examples: int
-    dataset: str
-
-
-def parse_args() -> Args:
-    parser = ArgumentParser()
-    parser.add_argument("--seed", type=int, default=42)
-    parser.add_argument(
-        "--n-estimators",
-        type=int,
-        default=100,
-        dest="n_estimators",
-    )
-    parser.add_argument("--max-depth", type=int, default=5, dest="max_depth")
-    parser.add_argument(
-        "--n-examples",
-        type=int,
-        default=100,
-        dest="n_examples",
-    )
-    parser.add_argument(
-        "--dataset",
-        type=str,
-        choices=["adult", "compas", "credit"],
-        default="compas",
-    )
-    args = parser.parse_args()
-    return Args(
-        seed=args.seed,
-        n_estimators=args.n_estimators,
-        max_depth=args.max_depth,
-        n_examples=args.n_examples,
-        dataset=args.dataset,
-    )
-
-
-def load(dataset: str) -> Loaded:
-    if dataset == "credit":
-        return load_credit()
-    if dataset == "adult":
-        return load_adult()
-    if dataset == "compas":
-        return load_compas()
-    msg = f"Unknown dataset: {dataset}"
-    raise ValueError(msg)
-
-
-CONSOLE = Console()
-
-
-def main() -> None:
-    args = parse_args()
-    data, target, mapper = load_data(args)
-    rf = fit_model(args, data, target)
-    cp = build_explainer(rf, mapper)
-    queries = generate_queries(args, rf, data)
-    times = run_queries(cp, queries)
-    display_statistics(times)
-
-
-def load_data(
-    args: Args,
-) -> tuple[pd.DataFrame, "pd.Series[int]", Mapper[Feature]]:
-    with CONSOLE.status("[bold blue]Loading the data[/bold blue]"):
-        (data, target), mapper = load(args.dataset)
-    CONSOLE.print("[bold green]Data loaded[/bold green]")
-    return data, target, mapper
-
-
-def fit_model(
-    args: Args,
-    data: pd.DataFrame,
-    target: "pd.Series[int]",
-) -> RandomForestClassifier:
-    X_train, _, y_train, _ = train_test_split(
-        data,
-        target,
-        test_size=0.2,
-        random_state=args.seed,
-    )
-    with CONSOLE.status("[bold blue]Fitting a Random Forest model[/bold blue]"):
-        rf = RandomForestClassifier(
-            n_estimators=args.n_estimators,
-            random_state=args.seed,
-            max_depth=args.max_depth,
-        )
-        rf.fit(X_train, y_train)
-    CONSOLE.print("[bold green]Model fitted[/bold green]")
-    return rf
-
-
-def build_explainer(
-    rf: RandomForestClassifier, mapper: Mapper[Feature]
-) -> ConstraintProgrammingExplainer:
-    with CONSOLE.status("[bold blue]Building the Explainer[/bold blue]"):
-        start = time.time()
-        cp = ConstraintProgrammingExplainer(rf, mapper=mapper)
-        end = time.time()
-    CONSOLE.print("[bold green]Explainer built[/bold green]")
-    msg = f"Build time: {end - start:.2f} seconds"
-    CONSOLE.print(f"[bold yellow]{msg}[/bold yellow]")
-    return cp
-
-
-def generate_queries(
-    args: Args,
-    rf: RandomForestClassifier,
-    data: pd.DataFrame,
-) -> list[tuple[Array1D, int]]:
-    _, X_test = train_test_split(
-        data,
-        test_size=0.2,
-        random_state=args.seed,
-    )
-    X_test = pd.DataFrame(X_test)
-    y_pred = rf.predict(X_test)
-    with CONSOLE.status("[bold blue]Generating queries[/bold blue]"):
-        queries: list[tuple[Array1D, int]] = [
-            (X_test.iloc[i].to_numpy().flatten(), 1 - y_pred[i])
-            for i in range(min(args.n_examples, len(X_test)))
-        ]
-    CONSOLE.print("[bold green]Queries generated[/bold green]")
-    return queries
-
-
-def run_queries(
-    cp: ConstraintProgrammingExplainer, queries: list[tuple[Array1D, int]]
-) -> "pd.Series[float]":
-    times: pd.Series[float] = pd.Series()
-    for i, (x, y) in track(
-        enumerate(queries),
-        total=len(queries),
-        description="[bold blue]Running queries[/bold blue]",
blue]", - ): - start = time.time() - cp.explain(x, y=y, norm=1) - cp.cleanup() - end = time.time() - times[i] = end - start - return times - - -def display_statistics(times: "pd.Series[int]") -> None: - CONSOLE.print("[bold blue]Statistics:[/bold blue]") - table = Table(show_header=True, header_style="bold magenta") - table.add_column("Metric", style="dim", width=30) - table.add_column("Value") - table.add_row("Number of queries", str(len(times))) - table.add_row("Total time (seconds)", f"{times.sum():.2f}") - table.add_row("Mean time per query (seconds)", f"{times.mean():.2f}") - table.add_row("Std of time per query (seconds)", f"{times.std():.2f}") - table.add_row("Maximum time per query (seconds)", f"{times.max():.2f}") - table.add_row("Minimum time per query (seconds)", f"{times.min():.2f}") - CONSOLE.print(table) - CONSOLE.print("[bold green]Done[/bold green]") - - -if __name__ == "__main__": - main() diff --git a/examples/simple_example_both.py b/examples/simple_example_both.py index b7c3f74..4d2a316 100644 --- a/examples/simple_example_both.py +++ b/examples/simple_example_both.py @@ -29,8 +29,9 @@ from sklearn.tree import plot_tree import matplotlib.pyplot as plt + # Plot the first tree of the forest -plt.figure(figsize=(20,10)) +plt.figure(figsize=(20, 10)) plot_tree(rf.estimators_[0], filled=True) plt.title("First tree of the Random Forest") plt.savefig("./first_tree_rf.png") @@ -39,7 +40,7 @@ liste_thresholds = [] for tree in rf.estimators_: liste_thresholds.extend(tree.tree_.threshold[tree.tree_.feature == 25]) -print("Tree thresholds for feature 25:", sorted(liste_thresholds) ) +print("Tree thresholds for feature 25:", sorted(liste_thresholds)) print("RF train acc= ", rf.score(X_train, y_train)) print("RF test acc= ", rf.score(X_test, y_test)) @@ -50,9 +51,10 @@ ) # Define a CF query using the qid-th element of the test set -#qid = 1 -#query = X_test.iloc[qid] -import numpy as np +# qid = 1 +# query = X_test.iloc[qid] +import numpy as np + qid = 10 query = X_test.iloc[qid] query_pred = rf.predict([np.asarray(query)])[0] @@ -60,7 +62,7 @@ # Use the MILP formulation to generate a CF milp_model = MixedIntegerProgramExplainer(rf, mapper=mapper) -#print("milp_model._num_epsilon", milp_model._num_epsilon) +# print("milp_model._num_epsilon", milp_model._num_epsilon) start_ = time.time() explanation_ocean = milp_model.explain( query, @@ -73,7 +75,7 @@ ) milp_time = time.time() - start_ cf = explanation_ocean -#cf[4] += 0.0001 +# cf[4] += 0.0001 if explanation_ocean is not None: print( "MILP : ", @@ -82,7 +84,7 @@ rf.predict([explanation_ocean.to_numpy()])[0], ")", ) - #print("MILP Sollist = ", milp_model.get_anytime_solutions()) + # print("MILP Sollist = ", milp_model.get_anytime_solutions()) else: print("MILP: No CF found.") @@ -107,10 +109,16 @@ sample_id = 0 # obtain ids of the nodes `sample_id` goes through, i.e., row `sample_id` node_index = node_indicator.indices[ - node_indicator.indptr[sample_id] : node_indicator.indptr[sample_id + 1] + node_indicator.indptr[ + sample_id + ] : node_indicator.indptr[sample_id + 1] ] - print("[Tree {i}] Rules used to predict sample {id} with features values close to threshold:\n".format(i=i, id=sample_id)) + print( + "[Tree {i}] Rules used to predict sample {id} with features values close to threshold:\n".format( + i=i, id=sample_id + ) + ) for node_id in node_index: # continue to the next node if it is a leaf node if leaf_id[sample_id] == node_id: @@ -121,7 +129,10 @@ threshold_sign = "<=" else: threshold_sign = ">" - if 
+        if (
+            np.abs(cf[feature[node_id]] - threshold[node_id])
+            < 1e-3
+        ):
             print(
                 "decision node {node} : (cf[{feature}] = {value}) "
                 "{inequality} {threshold})".format(
@@ -161,7 +172,7 @@
         rf.predict([explanation_oceancp.to_numpy()])[0],
         ")",
     )
-    #print("CP Sollist = ", cp_model.get_anytime_solutions())
+    # print("CP Sollist = ", cp_model.get_anytime_solutions())
 else:
     print("CP: No CF found.")
 
@@ -185,10 +196,16 @@
     sample_id = 0
     # obtain ids of the nodes `sample_id` goes through, i.e., row `sample_id`
     node_index = node_indicator.indices[
-        node_indicator.indptr[sample_id] : node_indicator.indptr[sample_id + 1]
+        node_indicator.indptr[
+            sample_id
+        ] : node_indicator.indptr[sample_id + 1]
     ]
     print(node_index)
-    print("[Tree {i}] Rules used to predict sample {id} with features values close to threshold:\n".format(i=i, id=sample_id))
+    print(
+        "[Tree {i}] Rules used to predict sample {id} with features values close to threshold:\n".format(
+            i=i, id=sample_id
+        )
+    )
     for node_id in node_index:
         # continue to the next node if it is a leaf node
         if leaf_id[sample_id] == node_id:
@@ -199,7 +216,10 @@
             threshold_sign = "<="
         else:
             threshold_sign = ">"
-        if np.abs(cf[feature[node_id]] - threshold[node_id]) < 1e-3:
+        if (
+            np.abs(cf[feature[node_id]] - threshold[node_id])
+            < 1e-3
+        ):
             print(
                 "decision node {node} : (cf[{feature}] = {value}) "
                 "{inequality} {threshold})".format(
diff --git a/ocean/tree/_parse.py b/ocean/tree/_parse.py
index c42b0dc..3d3b090 100644
--- a/ocean/tree/_parse.py
+++ b/ocean/tree/_parse.py
@@ -3,16 +3,21 @@
 from functools import partial
 from itertools import chain
 
+import xgboost as xgb
 from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
 
 from ..abc import Mapper
 from ..feature import Feature
-from ..typing import NonNegativeInt, ParsableEnsemble
+from ..typing import NonNegativeInt, ParsableEnsemble, SKLearnTree
 from ._node import Node
-from ._protocol import SKLearnTree, SKLearnTreeProtocol, TreeProtocol
+from ._parse_xgb import parse_xgb_ensemble
+from ._protocol import (
+    SKLearnTreeProtocol,
+    TreeProtocol,
+)
 from ._tree import Tree
 
-type DecisionTree = DecisionTreeClassifier | DecisionTreeRegressor
+type SKLearnDecisionTree = DecisionTreeClassifier | DecisionTreeRegressor
 
 
 def _build_leaf(tree: TreeProtocol, node_id: NonNegativeInt) -> Node:
@@ -70,13 +75,13 @@ def _parse_tree(sklearn_tree: SKLearnTree, *, mapper: Mapper[Feature]) -> Tree:
     return Tree(root=root)
 
 
-def parse_tree(tree: DecisionTree, *, mapper: Mapper[Feature]) -> Tree:
+def parse_tree(tree: SKLearnDecisionTree, *, mapper: Mapper[Feature]) -> Tree:
     getter = operator.attrgetter("tree_")
     return _parse_tree(getter(tree), mapper=mapper)
 
 
 def parse_trees(
-    trees: Iterable[DecisionTree],
+    trees: Iterable[SKLearnDecisionTree],
     *,
     mapper: Mapper[Feature],
 ) -> tuple[Tree, ...]:
@@ -84,9 +89,21 @@ def parse_trees(
     return tuple(map(parser, trees))
 
 
+def parse_ensemble(
+    ensemble: ParsableEnsemble,
+    *,
+    mapper: Mapper[Feature],
+) -> tuple[Tree, ...]:
+    if isinstance(ensemble, xgb.Booster):
+        return parse_xgb_ensemble(ensemble, mapper=mapper)
+    if isinstance(ensemble, xgb.XGBClassifier):
+        return parse_xgb_ensemble(ensemble.get_booster(), mapper=mapper)
+    return parse_trees(ensemble, mapper=mapper)
+
+
 def parse_ensembles(
     *ensembles: ParsableEnsemble,
     mapper: Mapper[Feature],
 ) -> tuple[Tree, ...]:
-    parser = partial(parse_trees, mapper=mapper)
+    parser = partial(parse_ensemble, mapper=mapper)
     return tuple(chain.from_iterable(map(parser, ensembles)))
diff --git a/ocean/tree/_parse_xgb.py b/ocean/tree/_parse_xgb.py
new file mode 100644
index 0000000..4761bab
--- /dev/null
+++ b/ocean/tree/_parse_xgb.py
@@ -0,0 +1,211 @@
+from collections.abc import Iterable
+
+import numpy as np
+import xgboost as xgb
+
+from ..abc import Mapper
+from ..feature import Feature
+from ..typing import NonNegativeInt, XGBTree
+from ._node import Node
+from ._tree import Tree
+
+
+def _get_column_value(
+    xgb_tree: XGBTree, node_id: NonNegativeInt, column: str
+) -> str | float | int:
+    mask = xgb_tree["Node"] == node_id
+    return xgb_tree.loc[mask, column].values[0]  # type: ignore[no-any-return]
+
+
+def _build_xgb_leaf(
+    xgb_tree: XGBTree,
+    *,
+    node_id: NonNegativeInt,
+    tree_id: NonNegativeInt,
+    num_trees_per_round: NonNegativeInt,
+) -> Node:
+    weight = float(_get_column_value(xgb_tree, node_id, "Gain"))
+
+    if num_trees_per_round == 1:
+        value = np.array([[weight, -weight]])
+    else:
+        k = int(tree_id % num_trees_per_round)
+        value = np.zeros((1, int(num_trees_per_round)), dtype=float)
+        value[0, k] = weight
+
+    return Node(node_id, n_samples=0, value=value)
+
+
+def _parse_feature_info(
+    feature_name: str, mapper: Mapper[Feature]
+) -> tuple[str, str | None]:
+    words = feature_name.split(" ")
+    name = words[0] if words else feature_name
+    code = words[1] if len(words) > 1 and words[1] else None
+
+    if name not in mapper.names:
+        msg = f"feature '{name}' not found in mapper '{mapper.names}'"
+        raise KeyError(msg)
+
+    return name, code
+
+
+def _validate_feature_format(
+    name: str,
+    code: str | None,
+    mapper: Mapper[Feature],
+    node_id: NonNegativeInt,
+) -> None:
+    if mapper[name].is_numeric and code:
+        msg = f"invalid numeric feature {name} for node {node_id}"
+        raise ValueError(msg)
+
+    if mapper[name].is_one_hot_encoded:
+        if not code:
+            msg = f"invalid one-hot encoded feature {name} for node {node_id}"
+            raise ValueError(msg)
+        if code not in mapper.codes:
+            msg = f"code '{code}' not found in mapper codes '{mapper.codes}'"
+            raise KeyError(msg)
+
+
+def _get_child_id(
+    xgb_tree: XGBTree, node_id: NonNegativeInt, column: str
+) -> int:
+    raw = str(_get_column_value(xgb_tree, node_id, column))
+    return int(raw.rsplit("-", 1)[-1])
+
+
+def _build_xgb_node(
+    xgb_tree: XGBTree,
+    *,
+    node_id: NonNegativeInt,
+    tree_id: NonNegativeInt,
+    num_trees_per_round: NonNegativeInt,
+    mapper: Mapper[Feature],
+) -> Node:
+    feature_name = str(_get_column_value(xgb_tree, node_id, "Feature"))
+    name, code = _parse_feature_info(feature_name, mapper)
+    _validate_feature_format(name, code, mapper, node_id)
+
+    threshold = None
+    if mapper[name].is_numeric:
+        threshold = float(_get_column_value(xgb_tree, node_id, "Split"))
+        mapper[name].add(threshold)
+
+    left_id = _get_child_id(xgb_tree, node_id, "Yes")
+    right_id = _get_child_id(xgb_tree, node_id, "No")
+
+    node = Node(
+        node_id, feature=name, threshold=threshold, code=code, n_samples=0
+    )
+    node.left = _parse_xgb_node(
+        xgb_tree,
+        node_id=left_id,
+        tree_id=tree_id,
+        num_trees_per_round=num_trees_per_round,
+        mapper=mapper,
+    )
+    node.right = _parse_xgb_node(
+        xgb_tree,
+        node_id=right_id,
+        tree_id=tree_id,
+        num_trees_per_round=num_trees_per_round,
+        mapper=mapper,
+    )
+    return node
+
+
+def _parse_xgb_node(
+    xgb_tree: XGBTree,
+    node_id: NonNegativeInt,
+    *,
+    tree_id: NonNegativeInt,
+    num_trees_per_round: NonNegativeInt,
+    mapper: Mapper[Feature],
+) -> Node:
+    mask = xgb_tree["Node"] == node_id
+    feature_val = str(xgb_tree.loc[mask, "Feature"].to_numpy().item())
"Feature"].to_numpy().item()) + + if feature_val == "Leaf": + return _build_xgb_leaf( + xgb_tree, + node_id=node_id, + tree_id=tree_id, + num_trees_per_round=num_trees_per_round, + ) + + return _build_xgb_node( + xgb_tree, + node_id=node_id, + tree_id=tree_id, + num_trees_per_round=num_trees_per_round, + mapper=mapper, + ) + + +def _parse_xgb_tree( + xgb_tree: XGBTree, + *, + tree_id: NonNegativeInt, + num_trees_per_round: NonNegativeInt, + mapper: Mapper[Feature], +) -> Tree: + root = _parse_xgb_node( + xgb_tree, + node_id=0, + tree_id=tree_id, + num_trees_per_round=num_trees_per_round, + mapper=mapper, + ) + return Tree(root=root) + + +def parse_xgb_tree( + xgb_tree: XGBTree, + *, + tree_id: NonNegativeInt, + num_trees_per_round: NonNegativeInt, + mapper: Mapper[Feature], +) -> Tree: + return _parse_xgb_tree( + xgb_tree, + tree_id=tree_id, + num_trees_per_round=num_trees_per_round, + mapper=mapper, + ) + + +def parse_xgb_trees( + trees: Iterable[XGBTree], + *, + num_trees_per_round: NonNegativeInt, + mapper: Mapper[Feature], +) -> tuple[Tree, ...]: + return tuple( + parse_xgb_tree( + tree, + tree_id=tree_id, + num_trees_per_round=num_trees_per_round, + mapper=mapper, + ) + for tree_id, tree in enumerate(trees) + ) + + +def parse_xgb_ensemble( + ensemble: xgb.Booster, *, mapper: Mapper[Feature] +) -> tuple[Tree, ...]: + df = ensemble.trees_to_dataframe() + groups = df.groupby("Tree") + trees = tuple( + groups.get_group(tree_id).reset_index(drop=True) + for tree_id in groups.groups + ) + + num_rounds = ensemble.num_boosted_rounds() or 1 + num_trees_per_round = max(1, len(trees) // num_rounds) + + return parse_xgb_trees( + trees, num_trees_per_round=num_trees_per_round, mapper=mapper + ) diff --git a/ocean/tree/_protocol.py b/ocean/tree/_protocol.py index c664024..38f7e10 100644 --- a/ocean/tree/_protocol.py +++ b/ocean/tree/_protocol.py @@ -10,6 +10,7 @@ NonNegativeIntArray1D, PositiveInt, SKLearnTree, + XGBTree, ) @@ -40,4 +41,5 @@ def __init__(self, tree: SKLearnTree) -> None: "SKLearnTree", "SKLearnTreeProtocol", "TreeProtocol", + "XGBTree", ] diff --git a/ocean/typing/__init__.py b/ocean/typing/__init__.py index e8a2954..b736354 100644 --- a/ocean/typing/__init__.py +++ b/ocean/typing/__init__.py @@ -3,11 +3,12 @@ import numpy as np import pandas as pd +import xgboost as xgb from pydantic import Field from sklearn.ensemble import IsolationForest, RandomForestClassifier -type BaseExplainableEnsemble = RandomForestClassifier -type ParsableEnsemble = BaseExplainableEnsemble | IsolationForest +type BaseExplainableEnsemble = RandomForestClassifier | xgb.XGBClassifier +type ParsableEnsemble = BaseExplainableEnsemble | IsolationForest | xgb.Booster type Number = float type NonNegativeNumber = Annotated[Number, Field(ge=0.0)] @@ -75,6 +76,9 @@ class SKLearnTree(Protocol): value: Array +type XGBTree = pd.DataFrame + + class BaseExplanation(Protocol): @property def x(self) -> Array1D: ... @@ -93,6 +97,8 @@ def explain( norm: PositiveInt, ) -> BaseExplanation | None: ... + def cleanup(self) -> None: ... 
+ __all__ = [ "Array", diff --git a/pyproject.toml b/pyproject.toml index e6d3e4c..a930ea3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "pandas", "pydantic", "scikit-learn", + "xgboost", ] optional-dependencies.dev = [ diff --git a/pyrightconfig.json b/pyrightconfig.json index 587aebb..1b26508 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -6,16 +6,5 @@ "ignore": [ "sklearn", "anytree" - ], - "executionEnvironments": [ - { - "root": "./examples", - "extraPaths": [ - "../ocean" - ], - "reportArgumentType": "none", - "reportUnknownArgumentType": "none", - "reportUnknownVariableType": "none" - } ] } \ No newline at end of file diff --git a/tests/tree/test_parse.py b/tests/tree/test_parse.py index 9a00b39..4b5a9b0 100644 --- a/tests/tree/test_parse.py +++ b/tests/tree/test_parse.py @@ -1,10 +1,11 @@ import pytest +import xgboost as xgb from pydantic import ValidationError from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from ocean.abc import Mapper from ocean.feature import Feature -from ocean.tree import Node, parse_tree +from ocean.tree import Node, parse_ensembles, parse_tree from ocean.typing import SKLearnTree from ..utils import generate_data @@ -46,6 +47,50 @@ def _dfs(node: Node) -> None: _dfs(root) +def _check_xgb_tree( + root: Node, + booster: xgb.Booster, + *, + tree_id: int, + mapper: Mapper[Feature], +) -> None: + df = booster.trees_to_dataframe() + tree_df = df[df["Tree"] == tree_id].reset_index(drop=True) + + def _dfs(node: Node) -> None: + row = tree_df[tree_df["Node"] == node.node_id] + assert not row.empty, f"node {node.node_id} not found in tree {tree_id}" + + if node.is_leaf: + assert row["Feature"].values[0] == "Leaf" + assert (node.value == row["Gain"].values[0]).any() + else: + assert node.feature is not None + assert node.feature in mapper + assert node.left is not None + assert node.right is not None + assert len(node.children) == 2 + + feature = mapper[node.feature] + feature_name = str(row["Feature"].values[0]).strip() + if feature.is_numeric: + assert feature_name == node.feature + assert node.threshold == float(row["Split"].values[0]) + if feature.is_one_hot_encoded: + assert feature_name == f"{node.feature} {node.code}" + assert node.code in feature.codes + + left_id = int(str(row["Yes"].values[0]).split("-")[-1]) + right_id = int(str(row["No"].values[0]).split("-")[-1]) + assert node.left.node_id == left_id + assert node.right.node_id == right_id + + _dfs(node.left) + _dfs(node.right) + + _dfs(root) + + @pytest.mark.parametrize("seed", [42, 43, 44]) @pytest.mark.parametrize("max_depth", [2, 3, 4]) @pytest.mark.parametrize("n_classes", [2, 3, 4]) @@ -87,3 +132,39 @@ def test_parse_regressor(seed: int, n_samples: int, max_depth: int) -> None: assert tree.max_depth == dt.tree_.max_depth # pyright: ignore[reportAttributeAccessIssue] assert tree.shape == (1, 1) _check_tree(tree.root, dt.tree_, mapper=mapper) # pyright: ignore[reportArgumentType, reportUnknownArgumentType] + + +@pytest.mark.parametrize("seed", [42, 43, 44]) +@pytest.mark.parametrize("n_classes", [2, 3, 4]) +@pytest.mark.parametrize("n_samples", [100, 200, 500]) +@pytest.mark.parametrize("n_estimators", [3, 5, 4]) +def test_parse_xgb_classifier( + seed: int, + n_classes: int, + n_samples: int, + n_estimators: int, +) -> None: + data, y, mapper = generate_data(seed, n_samples, n_classes) + model = xgb.XGBClassifier( + n_estimators=n_estimators, + max_depth=3, + eval_metric="logloss", + random_state=seed, + ) + model.fit(data, y) + 
+    assert model is not None
+    booster = model.get_booster()
+    assert booster is not None
+    trees = parse_ensembles(model, mapper=mapper)
+    assert len(trees) == n_estimators * (1 if n_classes == 2 else n_classes)
+    for i, tree in enumerate(trees):
+        assert tree.root is not None
+        assert tree.root.node_id == 0
+        assert tree.max_depth >= 1
+
+        _check_xgb_tree(
+            tree.root,
+            booster,
+            tree_id=i,
+            mapper=mapper,
+        )
diff --git a/tox.ini b/tox.ini
index 7e37c88..0119c38 100644
--- a/tox.ini
+++ b/tox.ini
@@ -40,6 +40,7 @@ deps =
     scikit-learn
     scipy
    scipy-stubs
+    xgboost
 commands =
     pip install --upgrade pyright
     mypy ocean tests
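
The roadmap now marks XGBoost as ready, and this diff exercises the new path only through the benchmark harness. For readers who want the shortest end-to-end call, here is a minimal sketch assembled from the APIs that appear above (load_compas, MixedIntegerProgramExplainer, explain, cleanup); the dataset choice and hyperparameters are illustrative, not prescribed by the PR.

```python
from xgboost import XGBClassifier

from ocean import MixedIntegerProgramExplainer
from ocean.datasets import load_compas

(data, target), mapper = load_compas()
model = XGBClassifier(n_estimators=100, max_depth=5, random_state=42)
model.fit(data, target)

# Pick one observation and ask for the nearest input (L1 norm)
# that the booster classifies as the opposite class.
x = data.iloc[0].to_numpy().flatten()
y_pred = int(model.predict(data.iloc[[0]])[0])

explainer = MixedIntegerProgramExplainer(model, mapper=mapper)
explanation = explainer.explain(x, y=1 - y_pred, norm=1)
explainer.cleanup()
print(explanation)
```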
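Internally the entry point is parse_ensemble in ocean/tree/_parse.py: an xgb.Booster is parsed directly, a fitted XGBClassifier is unwrapped via get_booster(), and anything else falls through to the existing scikit-learn path. Continuing the sketch above, both spellings should yield the same forest:

```python
from ocean.tree import parse_ensembles

trees_from_model = parse_ensembles(model, mapper=mapper)
trees_from_booster = parse_ensembles(model.get_booster(), mapper=mapper)
assert len(trees_from_model) == len(trees_from_booster)
```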
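_build_xgb_leaf encodes a binary model's leaf weight w as the pair [w, -w], so the usual argmax over class values reproduces the sign test on the booster margin; for K-class models it routes tree t's weight to class t % K, relying on XGBoost's round-major tree layout. Below is a standalone sanity check of that layout assumption against XGBoost itself, on synthetic data; it is not part of the PR's test suite, and the printed raw margins include a base_score offset that varies by XGBoost version.

```python
import numpy as np
import xgboost as xgb
from sklearn.datasets import make_classification

X, y = make_classification(
    n_samples=200, n_informative=6, n_classes=3, random_state=0
)
model = xgb.XGBClassifier(n_estimators=4, max_depth=3, random_state=0)
model.fit(X, y)

booster = model.get_booster()
df = booster.trees_to_dataframe()

# Route one sample through every tree and accumulate its leaf weight
# into class tree_id % n_classes, exactly as _build_xgb_leaf assumes.
leaf_per_tree = model.apply(X[:1])[0]
margins = np.zeros(3)
for tree_id, node_id in enumerate(leaf_per_tree):
    mask = (df["Tree"] == tree_id) & (df["Node"] == int(node_id))
    margins[tree_id % 3] += float(df.loc[mask, "Gain"].values[0])

raw = booster.predict(xgb.DMatrix(X[:1]), output_margin=True)[0]
print(margins, raw)  # raw should match margins up to the base_score offset
```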
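parse_xgb_ensemble recovers that layout by dividing the total tree count by num_boosted_rounds(), which is also what the new test asserts: n_estimators trees for binary targets, n_estimators * n_classes otherwise. A quick check with the three-class model from the previous sketch:

```python
n_trees = df["Tree"].nunique()
num_rounds = booster.num_boosted_rounds()
print(n_trees, num_rounds, n_trees // num_rounds)  # expected: 12 4 3
```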
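Finally, examples/query.py merges the two deleted single-purpose scripts into one harness that crosses every requested model with every requested explainer, and its timing loop is what motivates the cleanup() hook added to the explainer protocol in ocean/typing: one explainer instance is reused across queries, with solver state reset in between. Condensed from run_queries_verbose above:

```python
import time

times = []
for x, y in queries:  # (input, flipped-class) pairs from generate_queries
    start = time.time()
    explainer.explain(x, y=y, norm=1)
    explainer.cleanup()  # reset solver state before the next query
    times.append(time.time() - start)
```

The harness is driven from the command line, e.g. `python examples/query.py --dataset compas -m rf xgb -e mip cp --n-examples 50`.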