diff --git a/README.md b/README.md
index 733906e..7c927ef 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,11 @@ The package provides multiple classes and functions to wrap the tree ensemble mo
 ```python
 from sklearn.ensemble import RandomForestClassifier
-from ocean import MixedIntegerProgramExplainer, ConstraintProgrammingExplainer
+from ocean import (
+    ConstraintProgrammingExplainer,
+    MaxSATExplainer,
+    MixedIntegerProgramExplainer,
+)
 from ocean.datasets import load_adult
 
 # Load the adult dataset
@@ -47,20 +51,28 @@ rf.fit(data, target)
 
 # Predict the class of the random instance
 y = int(rf.predict(x).item())
+x = x.to_numpy().flatten()
 
 # Explain the prediction using MIPEXplainer
 mip_model = MixedIntegerProgramExplainer(rf, mapper=mapper)
-x = x.to_numpy().flatten()
 mip_explanation = mip_model.explain(x, y=1 - y, norm=1)
 
 # Explain the prediction using CPEExplainer
 cp_model = ConstraintProgrammingExplainer(rf, mapper=mapper)
-x = x.to_numpy().flatten()
 cp_explanation = cp_model.explain(x, y=1 - y, norm=1)
 
-# Show the explanation
-print("MIP: ",mip_explanation, "\n")
-print("CP : ",cp_explanation)
+maxsat_model = MaxSATExplainer(rf, mapper=mapper)
+maxsat_explanation = maxsat_model.explain(x, y=1 - y, norm=1)
+
+# Show the explanations and their objective values
+print("MIP objective value:", mip_model.get_objective_value())
+print("MIP", mip_explanation, "\n")
+
+print("CP objective value:", cp_model.get_objective_value())
+print("CP", cp_explanation, "\n")
+
+print("MaxSAT objective value:", maxsat_model.get_objective_value())
+print("MaxSAT", maxsat_explanation, "\n")
 ```
@@ -94,6 +106,20 @@ Occupation : 1
 Relationship : 0
 Sex : 0
 WorkClass : 4
+
+MaxSAT objective value: 3.0
+MaxSAT Explanation:
+Age : 39.0
+CapitalGain : 2174.0
+CapitalLoss : 0.0
+EducationNumber : 13.0
+HoursPerWeek : 40.0
+MaritalStatus : 3
+NativeCountry : 0
+Occupation : 1
+Relationship : 0
+Sex : 0
+WorkClass : 4
 ```
@@ -106,7 +132,7 @@ See the [examples folder](https://github.com/vidalt/OCEAN/tree/main/examples) fo
 | ------------------------------- | ---------- | ------------------------------------------ |
 | **MIP formulation** | ✅ Done | Based on Parmentier & Vidal (2020/2021). |
 | **Constraint Programming (CP)** | ✅ Done | Based on an upcoming paper. |
-| **MaxSAT formulation** | ⏳ Upcoming | Planned addition to the toolbox. |
+| **MaxSAT formulation** | ✅ Done | Now available in the toolbox. |
 | **Heuristics** | ⏳ Upcoming | Fast approximate methods. |
 | **Other methods** | ⏳ Upcoming | Additional formulations under exploration. |
 | **AdaBoost support** | ✅ Ready | Fully supported in ocean. |
diff --git a/examples/maxsat_example.py b/examples/maxsat_example.py
new file mode 100644
index 0000000..23b53a8
--- /dev/null
+++ b/examples/maxsat_example.py
@@ -0,0 +1,78 @@
+from sklearn.ensemble import RandomForestClassifier
+
+from ocean import (
+    ConstraintProgrammingExplainer,
+    MaxSATExplainer,
+    MixedIntegerProgramExplainer,
+)
+from ocean.datasets import load_adult
+
+# Load the adult dataset
+(data, target), mapper = load_adult(scale=True)
+
+# Train a random forest classifier
+rf = RandomForestClassifier(n_estimators=20, max_depth=3, random_state=42)
+rf.fit(data, target)
+
+# Select an instance to explain from the dataset
+x = data.iloc[0].to_frame().T
+x_np = x.to_numpy().flatten()
+
+# Predict the class of the instance
+y_pred = int(rf.predict(x).item())
+target_class = 1 - y_pred  # Binary classification - choose opposite class
+
+print(f"Instance shape: {x_np.shape}")
+print(f"Original prediction: {y_pred}")
+print(f"Target counterfactual class: {target_class}")
+
+# Explain the prediction using MaxSATExplainer
+print("\n--- MaxSAT Explainer ---")
+try:
+    maxsat_model = MaxSATExplainer(rf, mapper=mapper)
+    maxsat_explanation = maxsat_model.explain(x_np, y=target_class, norm=1)
+    if maxsat_explanation is not None:
+        cf_np = maxsat_explanation.to_numpy()
+        print("MaxSAT CF:", cf_np)
+        print("MaxSAT CF prediction:", rf.predict([cf_np])[0])
+        print("Objective value:", maxsat_model.get_objective_value())
+        print("Status:", maxsat_model.get_solving_status())
+    else:
+        print("MaxSAT: No counterfactual found.")
+except (ValueError, RuntimeError, ImportError) as e:
+    import traceback
+
+    print(f"MaxSAT Error: {e}")
+    traceback.print_exc()
+
+# Explain the prediction using MIPExplainer
+print("\n--- MIP Explainer ---")
+try:
+    mip_model = MixedIntegerProgramExplainer(rf, mapper=mapper)
+    mip_explanation = mip_model.explain(x_np, y=target_class, norm=1)
+    if mip_explanation is not None:
+        cf_np = mip_explanation.to_numpy()
+        print("MIP CF:", cf_np)
+        print("MIP CF prediction:", rf.predict([cf_np])[0])
+        print("Objective value:", mip_model.get_objective_value())
+        print("Status:", mip_model.get_solving_status())
+    else:
+        print("MIP: No counterfactual found.")
+except (ValueError, RuntimeError, ImportError) as e:
+    print(f"MIP Error: {e}")
+
+# Explain the prediction using CPExplainer
+print("\n--- CP Explainer ---")
+try:
+    cp_model = ConstraintProgrammingExplainer(rf, mapper=mapper)
+    cp_explanation = cp_model.explain(x_np, y=target_class, norm=1)
+    if cp_explanation is not None:
+        cf_np = cp_explanation.to_numpy()
+        print("CP CF:", cf_np)
+        print("CP CF prediction:", rf.predict([cf_np])[0])
+        print("Objective value:", cp_model.get_objective_value())
+        print("Status:", cp_model.get_solving_status())
+    else:
+        print("CP: No counterfactual found.")
+except (ValueError, RuntimeError, ImportError) as e:
+    print(f"CP Error: {e}")
diff --git a/examples/query.py b/examples/query.py
index 53284b5..5cb6531 100644
--- a/examples/query.py
+++ b/examples/query.py
@@ -13,6 +13,7 @@
 from ocean import (
     ConstraintProgrammingExplainer,
+    MaxSATExplainer,
     MixedIntegerProgramExplainer,
 )
 from ocean.abc import Mapper
@@ -28,6 +29,7 @@
 EXPLAINERS = {
     "mip": MixedIntegerProgramExplainer,
     "cp": ConstraintProgrammingExplainer,
+    "maxsat": MaxSATExplainer,
 }
 MODELS = {
     "rf": RandomForestClassifier,
@@ -75,8 +77,8 @@ def create_argument_parser() -> ArgumentParser:
         help="List of explainers to use",
         type=str,
         nargs="+",
-        choices=["mip", "cp"],
-        default=["mip", "cp"],
+        choices=["mip", "cp", "maxsat"],
+        default=["mip", "cp", "maxsat"],
     )
     parser.add_argument(
         "-m",
@@ -153,7 +155,8 @@ def fit_model_with_console(
 def build_explainer(
     explainer_name: str,
     explainer_class: type[MixedIntegerProgramExplainer]
-    | type[ConstraintProgrammingExplainer],
+    | type[ConstraintProgrammingExplainer]
+    | type[MaxSATExplainer],
     args: Args,
     model: RandomForestClassifier | XGBClassifier,
     mapper: Mapper[Feature],
@@ -163,7 +166,10 @@
     if explainer_class is MixedIntegerProgramExplainer:
         ENV.setParam("Seed", args.seed)
         exp = explainer_class(model, mapper=mapper, env=ENV)
-    elif explainer_class is ConstraintProgrammingExplainer:
+    elif (
+        explainer_class is ConstraintProgrammingExplainer
+        or explainer_class is MaxSATExplainer
+    ):
         exp = explainer_class(model, mapper=mapper)
     else:
         msg = f"Unknown explainer type: {explainer_class}"
diff --git a/examples/readme.py b/examples/readme.py
index 410b8b3..00ad19e 100644
--- a/examples/readme.py
+++ b/examples/readme.py
@@ -1,6 +1,10 @@
 from sklearn.ensemble import RandomForestClassifier
 
-from ocean import ConstraintProgrammingExplainer, MixedIntegerProgramExplainer
+from ocean import (
+    ConstraintProgrammingExplainer,
+    MaxSATExplainer,
+    MixedIntegerProgramExplainer,
+)
 from ocean.datasets import load_adult
 
 # Load the adult dataset
@@ -15,19 +19,25 @@
 
 # Predict the class of the random instance
 y = int(rf.predict(x).item())
+x = x.to_numpy().flatten()
 
 # Explain the prediction using MIPEXplainer
 mip_model = MixedIntegerProgramExplainer(rf, mapper=mapper)
-x = x.to_numpy().flatten()
 mip_explanation = mip_model.explain(x, y=1 - y, norm=1)
 
 # Explain the prediction using CPEExplainer
 cp_model = ConstraintProgrammingExplainer(rf, mapper=mapper)
 cp_explanation = cp_model.explain(x, y=1 - y, norm=1)
 
+maxsat_model = MaxSATExplainer(rf, mapper=mapper)
+maxsat_explanation = maxsat_model.explain(x, y=1 - y, norm=1)
+
 # Show the explanations and their objective values
 print("MIP objective value:", mip_model.get_objective_value())
 print("MIP", mip_explanation, "\n")
 
 print("CP objective value:", cp_model.get_objective_value())
 print("CP", cp_explanation, "\n")
+
+print("MaxSAT objective value:", maxsat_model.get_objective_value())
+print("MaxSAT", maxsat_explanation, "\n")
diff --git a/ocean/cp/_explainer.py b/ocean/cp/_explainer.py
index 4adc628..bd04a16 100644
--- a/ocean/cp/_explainer.py
+++ b/ocean/cp/_explainer.py
@@ -3,6 +3,7 @@
 import warnings
 
 from ortools.sat.python import cp_model as cp
+from sklearn.ensemble import AdaBoostClassifier
 
 from ..abc import Mapper
 from ..feature import Feature
@@ -11,6 +12,7 @@
     Array1D,
     BaseExplainableEnsemble,
     BaseExplainer,
+    NonNegativeArray1D,
     NonNegativeInt,
     PositiveInt,
 )
@@ -25,13 +27,13 @@ def __init__(
         ensemble: BaseExplainableEnsemble,
         *,
         mapper: Mapper[Feature],
-        weights: Array1D | None = None,
+        weights: NonNegativeArray1D | None = None,
         epsilon: int = Model.DEFAULT_EPSILON,
         model_type: Model.Type = Model.Type.CP,
     ) -> None:
         ensembles = (ensemble,)
         trees = parse_ensembles(*ensembles, mapper=mapper)
-        if trees[0].adaboost:
+        if isinstance(ensemble, AdaBoostClassifier):
             weights = ensemble.estimator_weights_
         Model.__init__(
             self,
diff --git a/ocean/cp/_managers/_tree.py b/ocean/cp/_managers/_tree.py
index dbfef77..ed22ac3 100644
--- a/ocean/cp/_managers/_tree.py
+++ b/ocean/cp/_managers/_tree.py
@@ -140,10 +140,7 @@ def weighted_function(
                     if self._adaboost:
                         # no need to scale since values are 0/1
                         # scaling is done later for tree weights
-                        if np.argmax(leaf.value[op, :]) == c:
-                            coeff = 1
-                        else:
-                            coeff = 0
+                        coeff = int(np.argmax(leaf.value[op, :]) == c)
                     else:
                         coeff = int(leaf.value[op, c] * scale)
                     coefs.append(coeff)
@@ -151,7 +148,7 @@
                 tree_expr = cp.LinearExpr.WeightedSum(variables, coefs)
                 tree_exprs.append(tree_expr)
                 if self._adaboost:
-                    tree_weights.append(int(weight * scale)) 
+                    tree_weights.append(int(weight * scale))
                 else:
                     tree_weights.append(int(weight))
             expr = cp.LinearExpr.WeightedSum(tree_exprs, tree_weights)
diff --git a/ocean/maxsat/__init__.py b/ocean/maxsat/__init__.py
index 6915f01..c311836 100644
--- a/ocean/maxsat/__init__.py
+++ b/ocean/maxsat/__init__.py
@@ -1,3 +1,19 @@
+from ._base import BaseModel
+from ._env import ENV
 from ._explainer import Explainer
+from ._explanation import Explanation
+from ._managers import FeatureManager, TreeManager
+from ._model import Model
+from ._variables import FeatureVar, TreeVar
 
-__all__ = ["Explainer"]
+__all__ = [
+    "ENV",
+    "BaseModel",
+    "Explainer",
+    "Explanation",
+    "FeatureManager",
+    "FeatureVar",
+    "Model",
+    "TreeManager",
+    "TreeVar",
+]
diff --git a/ocean/maxsat/_base.py b/ocean/maxsat/_base.py
index 4a17993..d79d4f5 100644
--- a/ocean/maxsat/_base.py
+++ b/ocean/maxsat/_base.py
@@ -1,12 +1,15 @@
 from abc import ABC
 from typing import Any, Protocol
 
-from pysat.formula import WCNF
+from pysat.formula import WCNF, IDPool
 
 
 class BaseModel(ABC, WCNF):
+    vpool: IDPool
+
     def __init__(self) -> None:
         WCNF.__init__(self)
+        self.vpool = IDPool()  # Create new pool for each instance
 
     def __setattr__(self, name: str, value: Any) -> None:  # noqa: ANN401
         object.__setattr__(self, name, value)
@@ -15,6 +18,49 @@ def build_vars(self, *variables: "Var") -> None:
         for variable in variables:
             variable.build(model=self)
 
+    def add_var(self, name: str) -> int:
+        if name in self.vpool.obj2id:  # var has been already created
+            msg = f"Variable with name '{name}' already exists."
+            raise ValueError(msg)
+        return self.vpool.id(f"{name}")  # type: ignore[no-any-return]
+
+    def get_var(self, name: str) -> int:
+        if name not in self.vpool.obj2id:  # var has not been created
+            msg = f"Variable with name '{name}' does not exist."
+            raise ValueError(msg)
+        return self.vpool.obj2id[name]  # type: ignore[no-any-return]
+
+    def add_hard(self, lits: list[int], return_id: bool = False) -> int:  # noqa: FBT001, FBT002
+        """
+        Add a hard clause (must be satisfied).
+
+        Returns:
+            The clause ID if return_id is True, otherwise -1.
+
+        """
+        # weight=None => hard clause in WCNF
+        self.append(lits)
+        if return_id:
+            return len(self.hard) - 1  # pyright: ignore[reportUnknownArgumentType]
+        return -1
+
+    def add_soft(self, lits: list[int], weight: int = 1) -> None:
+        """Add a soft clause with a given weight."""
+        self.append(lits, weight=weight)
+
+    def add_exactly_one(self, lits: list[int]) -> None:
+        """Add a constraint that exactly one of the given literals is true."""
+        self.add_hard(lits)  # at least one
+        for i in range(len(lits)):
+            for j in range(i + 1, len(lits)):
+                self.add_hard([-lits[i], -lits[j]])  # at most one
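+        # NOTE: this pairwise at-most-one encoding adds O(n^2) binary
+        # clauses, which stays small for the literal sets used here (the
+        # leaves of one tree, the codes of one feature, or the intervals
+        # of one feature).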
+
+    def _clean_soft(self) -> None:
+        """Reset the model to only contain hard constraints."""
+        self.soft.clear()
+        self.wght.clear()
+        self.topw = 1
+
 
 class Var(Protocol):
     _name: str
diff --git a/ocean/maxsat/_builder/model.py b/ocean/maxsat/_builder/model.py
index 35100fe..12a66ff 100644
--- a/ocean/maxsat/_builder/model.py
+++ b/ocean/maxsat/_builder/model.py
@@ -1,6 +1,8 @@
 from collections.abc import Iterable
 from typing import Protocol
 
+import numpy as np
+
 from ...abc import Mapper
 from ...tree._node import Node
 from .._base import BaseModel
@@ -69,7 +71,7 @@ def _propagate(
         *,
         node: Node,
         mapper: Mapper[FeatureVar],
-        y: object,
+        y: int,
     ) -> None:
         parent = node.parent
         if parent is None:
@@ -83,7 +85,7 @@ def _expand(
         model: BaseModel,
         *,
         node: Node,
-        y: object,
+        y: int,
         v: FeatureVar,
         sigma: bool,
     ) -> None:
@@ -100,45 +102,95 @@ def _bset(
         model: BaseModel,
         *,
-        y: object,
+        y: int,
         v: FeatureVar,
         sigma: bool,
     ) -> None:
-        msg = "Raise NotImplementedError"
-        raise NotImplementedError(msg)
+        # sigma=True => left child (x <= 0.5, i.e., x=0)
+        # sigma=False => right child (x > 0.5, i.e., x=1)
+        if sigma:
+            model.add_hard([-y, -v.xget()])
+        else:
+            model.add_hard([-y, v.xget()])
 
     @staticmethod
     def _cset(
         model: BaseModel,
         *,
         node: Node,
-        y: object,
+        y: int,
         v: FeatureVar,
         sigma: bool,
     ) -> None:
-        raise NotImplementedError
+        # For continuous features:
+        # j = searchsorted(levels, threshold) gives index where threshold fits
+        # sigma=True => left child (x <= threshold)
+        # sigma=False => right child (x > threshold)
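+        # Example (hypothetical): with levels = [0, 1, 2] and threshold = 1,
+        # j = 1; the left branch forbids interval 1, keeping only (0, 1],
+        # and the right branch forbids interval 0, keeping only (1, 2].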
+        threshold = node.threshold
+        j = int(np.searchsorted(v.levels, threshold, side="left"))
+        n_intervals = len(v.levels) - 1
+
+        if sigma:
+            # Left branch: x <= threshold, so x is in interval 0, 1, ..., j-1
+            # Forbid intervals j, j+1, ..., n-2
+            for i in range(j, n_intervals):
+                mu = v.xget(mu=i)
+                model.add_hard([-y, -mu])
+        else:
+            # Right branch: x > threshold, so x is in interval j, j+1, ..., n-2
+            # Forbid intervals 0, 1, ..., j-1
+            for i in range(j):
+                mu = v.xget(mu=i)
+                model.add_hard([-y, -mu])
 
     @staticmethod
     def _dset(
         model: BaseModel,
         *,
         node: Node,
-        y: object,
+        y: int,
         v: FeatureVar,
         sigma: bool,
     ) -> None:
-        raise NotImplementedError
+        # For discrete features:
+        # sigma=True => left child (x <= threshold)
+        # sigma=False => right child (x > threshold)
+        #
+        # mu[i] => value == levels[i]
+        threshold = node.threshold
+        n_values = len(v.levels)
+
+        if sigma:
+            # Left branch: x <= threshold
+            # Forbid values where levels[i] > threshold
+            for i in range(n_values):
+                if v.levels[i] > threshold:
+                    mu = v.xget(mu=i)
+                    model.add_hard([-y, -mu])
+        else:
+            # Right branch: x > threshold
+            # Forbid values where levels[i] <= threshold
+            for i in range(n_values):
+                if v.levels[i] <= threshold:
+                    mu = v.xget(mu=i)
+                    model.add_hard([-y, -mu])
 
     @staticmethod
     def _eset(
         model: BaseModel,
         *,
         node: Node,
-        y: object,
+        y: int,
         v: FeatureVar,
         sigma: bool,
     ) -> None:
-        raise NotImplementedError
+        # sigma=True (left child): category != code, so u[code] = False
+        # sigma=False (right child): category == code, so u[code] = True
+        x = v.xget(code=node.code)
+        if sigma:
+            model.add_hard([-y, -x])
+        else:
+            model.add_hard([-y, x])
 
 
 class ModelBuilderFactory:
diff --git a/ocean/maxsat/_env.py b/ocean/maxsat/_env.py
new file mode 100644
index 0000000..f3567ac
--- /dev/null
+++ b/ocean/maxsat/_env.py
@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, cast
+
+from pysat.examples.rc2 import RC2
+
+if TYPE_CHECKING:
+    from pysat.formula import WCNF
+
+
+class Env:
+    _solver: MaxSATSolver
+
+    def __init__(self) -> None:
+        self._solver = MaxSATSolver()
+
+    @property
+    def solver(self) -> MaxSATSolver:
+        return self._solver
+
+    @solver.setter
+    def solver(self, solver: MaxSATSolver) -> None:
+        self._solver = solver
+
+
+class MaxSATSolver:
+    """Thin RC2 wrapper to keep a stable interface."""
+
+    _model: list[int] | None = None
+    _cost: float = float("inf")
+
+    def __init__(
+        self,
+        solver_name: str = "cadical195",
+        TimeLimit: int = 60,
+        n_threads: int = 1,
+    ) -> None:
+        self.solver_name = solver_name
+        self.TimeLimit = TimeLimit
+        self.n_threads = n_threads
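+        # NOTE: TimeLimit and n_threads mirror the interface of the other
+        # explainers; they are stored here but not forwarded to RC2 in
+        # solve() below.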
+
+    def solve(self, w: WCNF) -> list[int]:
+        with RC2(
+            w,
+            solver=self.solver_name,
+            adapt=True,
+            exhaust=False,
+            minz=True,
+        ) as rc2:
+            model = cast("list[int] | None", rc2.compute())
+            if model is None:
+                msg = "UNSAT: no counterfactual found."
+                raise RuntimeError(msg)
+            self._model = model
+            self._cost = rc2.cost
+            return model
+
+    def model(self, v: int) -> float:
+        """
+        Return 1.0 if variable v is true in the model, 0.0 otherwise.
+
+        Args:
+            v: Variable to check in the model.
+
+        Returns:
+            1.0 if variable v is true in the model, 0.0 otherwise.
+
+        Raises:
+            ValueError: If no model has been found; solve() must be called first.
+
+        """
+        if self._model is None:
+            msg = "No model found, please run 'solve' first."
+            raise ValueError(msg)
+        # The model is a list of signed literals.
+        # Variable v is true if v is in the model, false if -v is in the model.
+        if v in self._model:
+            return 1.0
+        return 0.0
+
+    @property
+    def cost(self) -> float:
+        return self._cost
+
+
+ENV = Env()
diff --git a/ocean/maxsat/_explainer.py b/ocean/maxsat/_explainer.py
index fcec0b8..e395fec 100644
--- a/ocean/maxsat/_explainer.py
+++ b/ocean/maxsat/_explainer.py
@@ -1,7 +1,10 @@
 from __future__ import annotations
 
+import warnings
 from typing import TYPE_CHECKING
 
+from sklearn.ensemble import AdaBoostClassifier
+
 from ..tree import parse_ensembles
 from ..typing import (
     Array1D,
@@ -10,8 +13,8 @@
     NonNegativeInt,
     PositiveInt,
 )
+from ._env import ENV
 from ._model import Model
-from ._solver import MaxSATSolver
 
 if TYPE_CHECKING:
     from ..abc import Mapper
@@ -20,6 +23,10 @@
 
 
 class Explainer(Model, BaseExplainer):
+    """MaxSAT-based explainer for tree ensemble classifiers."""
+
+    Status: str = "UNKNOWN"
+
     def __init__(
         self,
         ensemble: BaseExplainableEnsemble,
@@ -31,6 +38,8 @@ def __init__(
     ) -> None:
         ensembles = (ensemble,)
         trees = parse_ensembles(*ensembles, mapper=mapper)
+        if isinstance(ensemble, AdaBoostClassifier):
+            weights = ensemble.estimator_weights_
         Model.__init__(
             self,
             trees,
@@ -40,13 +49,13 @@ def __init__(
             model_type=model_type,
         )
         self.build()
-        self.solver = MaxSATSolver
+        self.solver = ENV.solver
 
     def get_objective_value(self) -> float:
-        raise NotImplementedError
+        return self.solver.cost / self._obj_scale
 
     def get_solving_status(self) -> str:
-        raise NotImplementedError
+        return self.Status
 
     def get_anytime_solutions(self) -> list[dict[str, float]] | None:
         raise NotImplementedError
@@ -57,10 +66,35 @@ def explain(
         *,
         y: NonNegativeInt,
         norm: PositiveInt,
-        return_callback: bool = False,
-        verbose: bool = False,
-        max_time: int = 60,
-        num_workers: int | None = None,
-        random_seed: int = 42,
+        verbose: bool = False,  # noqa: ARG002
+        max_time: int = 60,  # noqa: ARG002
+        random_seed: int = 42,  # noqa: ARG002
     ) -> Explanation | None:
-        raise NotImplementedError
+        # Add objective soft clauses
+        self.add_objective(x, norm=norm)
+
+        # Add hard constraints for target class
+        self.set_majority_class(y=y)
+
+        try:
+            # Solve the MaxSAT problem
+            self.solver.solve(self)
+            self.Status = "OPTIMAL"
+        except RuntimeError as e:
+            if "UNSAT" in str(e):
+                self.Status = "INFEASIBLE"
+                msg = "There are no feasible counterfactuals for this query."
+                msg += " If there should be one, please check the model "
+                msg += "constraints or report this issue to the developers."
+                warnings.warn(msg, category=UserWarning, stacklevel=2)
+                self.cleanup()
+                return None
+            raise
+        else:
+            # Store the query in the explanation
+            self.explanation.query = x
+
+            # Clean up for next solve
+            self.cleanup()
+
+            return self.explanation
diff --git a/ocean/maxsat/_explanation.py b/ocean/maxsat/_explanation.py
index b3da47d..48c508f 100644
--- a/ocean/maxsat/_explanation.py
+++ b/ocean/maxsat/_explanation.py
@@ -1,27 +1,86 @@
 from collections.abc import Mapping
-from typing import TYPE_CHECKING
 
 import numpy as np
+import pandas as pd
 
 from ..abc import Mapper
 from ..typing import Array1D, BaseExplanation, Key, Number
+from ._env import ENV
 from ._variables import FeatureVar
 
-if TYPE_CHECKING:
-    import pandas as pd
-
 
 class Explanation(Mapper[FeatureVar], BaseExplanation):
     _epsilon: float = float(np.finfo(np.float32).eps)
     _x: Array1D = np.zeros((0,), dtype=int)
 
     def vget(self, i: int) -> int:
-        msg = "Not implemented."
-        raise NotImplementedError(msg)
+        name = self.names[i]
+        if self[name].is_one_hot_encoded:
+            code = self.codes[i]
+            return self[name].xget(code=code)
+        if self[name].is_numeric:
+            j: int = int(
+                np.searchsorted(self[name].levels, self._x[i], side="left")  # pyright: ignore[reportUnknownArgumentType]
+            )
+            return self[name].xget(mu=j)
+        return self[name].xget()
+
+    def _get_active_mu_index(
+        self,
+        name: Key,
+        for_discrete: bool = False,  # noqa: FBT001, FBT002
+    ) -> int:
+        """
+        Find which mu variable is set to true for a numeric feature.
+
+        Returns:
+            Index of the active mu variable, or 0 if none found.
+
+        """
+        if for_discrete:
+            # For discrete: one mu per level
+            n_vars = len(self[name].levels)
+        else:
+            # For continuous: one mu per interval
+            n_vars = len(self[name].levels) - 1
+        for mu_idx in range(n_vars):
+            var = self[name].xget(mu=mu_idx)
+            if ENV.solver.model(var) > 0:
+                return mu_idx
+        return 0  # Default to first if none found
 
     def to_series(self) -> "pd.Series[float]":
-        msg = "Not implemented."
-        raise NotImplementedError(msg)
+        values: list[float] = []
+        for f in range(self.n_columns):
+            name = self.names[f]
+            if self[name].is_one_hot_encoded:
+                code = self.codes[f]
+                var = self[name].xget(code=code)
+                values.append(ENV.solver.model(var))
+            elif self[name].is_continuous:
+                mu_idx = self._get_active_mu_index(name, for_discrete=False)
+                values.append(
+                    self.format_continuous_value(
+                        f, mu_idx, list(self[name].levels)
+                    )
+                )
+            elif self[name].is_discrete:
+                # For discrete features, mu[i] means value == levels[i]
+                mu_idx = self._get_active_mu_index(name, for_discrete=True)
+                levels = list(self[name].levels)
+                discrete_val = int(levels[mu_idx])
+                values.append(
+                    self.format_discrete_value(
+                        f, discrete_val, self[name].levels
+                    )
+                )
+            elif self[name].is_binary:
+                var = self[name].xget()
+                values.append(ENV.solver.model(var))
+            else:
+                var = self[name].xget()
+                values.append(ENV.solver.model(var))
+        return pd.Series(values, index=self.columns)
 
     def to_numpy(self) -> Array1D:
         return (
@@ -39,10 +98,33 @@ def x(self) -> Array1D:
 
     @property
     def value(self) -> Mapping[Key, Key | Number]:
-        msg = "Not implemented."
-        raise NotImplementedError(msg)
-
-    def format_value(
+        def get(v: FeatureVar) -> Key | Number:
+            if v.is_one_hot_encoded:
+                for code in v.codes:
+                    if ENV.solver.model(v.xget(code)) > 0:
+                        return code
+            if v.is_numeric:
+                f = [val for _, val in self.items()].index(v)
+                if v.is_discrete:
+                    idx = self._get_active_mu_index(
+                        self.names[f], for_discrete=True
+                    )
+                    val = int(v.levels[idx])
+                    return self.format_discrete_value(f, val, v.levels)
+                idx = self._get_active_mu_index(
+                    self.names[f], for_discrete=False
+                )
+                return self.format_continuous_value(
+                    f,
+                    idx,
+                    list(v.levels),
+                )
+            x = v.xget()
+            return int(ENV.solver.model(x))
+
+        return self.reduce(get)
+
+    def format_continuous_value(
         self,
         f: int,
         idx: int,
diff --git a/ocean/maxsat/_managers/_feature.py b/ocean/maxsat/_managers/_feature.py
index 9d57ec1..6b8d7af 100644
--- a/ocean/maxsat/_managers/_feature.py
+++ b/ocean/maxsat/_managers/_feature.py
@@ -37,7 +37,7 @@ def explanation(self) -> Explanation:
         return self.mapper
 
     def vget(self, i: int) -> int:
-        raise NotImplementedError
+        return self.mapper.vget(i)
 
     def _set_mapper(self, mapper: Mapper[Feature]) -> None:
         def create(key: Key, feature: Feature) -> FeatureVar:
diff --git a/ocean/maxsat/_managers/_garbage.py b/ocean/maxsat/_managers/_garbage.py
index b8b21c0..5d3e21c 100644
--- a/ocean/maxsat/_managers/_garbage.py
+++ b/ocean/maxsat/_managers/_garbage.py
@@ -1,5 +1,5 @@
 class GarbageManager:
-    type GarbageObject = object
+    type GarbageObject = int
 
     # Garbage collector for the model.
     # - Used to keep track of the variables and constraints created,
@@ -13,4 +13,7 @@ def add_garbage(self, *args: GarbageObject) -> None:
         self._garbage.extend(args)
 
     def remove_garbage(self) -> None:
-        raise NotImplementedError
+        self._garbage.clear()
+
+    def garbage_list(self) -> list[GarbageObject]:
+        return self._garbage
diff --git a/ocean/maxsat/_managers/_tree.py b/ocean/maxsat/_managers/_tree.py
index e3c1d74..7bd4940 100644
--- a/ocean/maxsat/_managers/_tree.py
+++ b/ocean/maxsat/_managers/_tree.py
@@ -90,11 +90,28 @@ def _set_weights(self, weights: NonNegativeArray1D | None = None) -> None:
 
     def weighted_function(
         self,
-        weights: NonNegativeArray1D,
-    ) -> dict[tuple[NonNegativeInt, NonNegativeInt], object]:
-        raise NotImplementedError
+    ) -> dict[tuple[NonNegativeInt, NonNegativeInt], list[int]]:
+        func: dict[tuple[NonNegativeInt, NonNegativeInt], list[int]] = {}
+        n_classes = self.n_classes
+        n_outputs = self.shape[-2]
+        for op in range(n_outputs):
+            for c in range(n_classes):
+                leaf_vars: list[int] = []
+                for tree in self.estimators:
+                    for leaf in tree.leaves:
+                        leaf_class = int(np.argmax(leaf.value[op, :]))
+                        if leaf_class == c:
+                            leaf_vars.append(tree[leaf.node_id])
+                func[op, c] = leaf_vars
+        return func
 
     def _get_function(
         self,
-    ) -> dict[tuple[NonNegativeInt, NonNegativeInt], object]:
-        return self.weighted_function(weights=self.weights)
+    ) -> dict[tuple[NonNegativeInt, NonNegativeInt], list[int]]:
+        return self.weighted_function()
+
+    @property
+    def function(
+        self,
+    ) -> dict[tuple[NonNegativeInt, NonNegativeInt], list[int]]:
+        return self._function
diff --git a/ocean/maxsat/_model.py b/ocean/maxsat/_model.py
index 9af8be6..c595b8d 100644
--- a/ocean/maxsat/_model.py
+++ b/ocean/maxsat/_model.py
@@ -1,9 +1,13 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
 from enum import Enum
 from typing import TYPE_CHECKING
 
+import numpy as np
+from pydantic import validate_call
+from pysat.pb import PBEnc
+
+from ..typing import NonNegativeInt
 from ._base import BaseModel
 from ._builder.model import ModelBuilder, ModelBuilderFactory
 from ._managers import FeatureManager, GarbageManager, TreeManager
@@ -14,15 +18,15 @@
     from ..abc import Mapper
     from ..feature import Feature
     from ..tree import Tree
-    from ..typing import NonNegativeArray1D, NonNegativeInt
-
+    from ..typing import Array1D, Key, NonNegativeArray1D
+    from ._variables import FeatureVar
 
-@dataclass
-class Model(BaseModel, FeatureManager, TreeManager, GarbageManager):
-    DEFAULT_EPSILON: int = 1
 
+class Model(BaseModel, FeatureManager, GarbageManager, TreeManager):
     # Model builder for the ensemble.
-    _builder: ModelBuilder | None = None
+    _builder: ModelBuilder
+    DEFAULT_EPSILON: int = 1
+    _obj_scale: int = int(1e8)
 
     class Type(Enum):
         MAXSAT = "MAXSAT"
@@ -32,10 +36,10 @@ def __init__(
         trees: Iterable[Tree],
         mapper: Mapper[Feature],
         *,
-        model_type: Type = Type.MAXSAT,
         weights: NonNegativeArray1D | None = None,
         max_samples: NonNegativeInt = 0,
         epsilon: int = DEFAULT_EPSILON,
+        model_type: Type = Type.MAXSAT,
     ) -> None:
         BaseModel.__init__(self)
         TreeManager.__init__(
@@ -52,7 +56,247 @@ def __init__(
         self._set_builder(model_type=model_type)
 
     def build(self) -> None:
-        raise NotImplementedError
+        self.build_features(self)
+        self.build_trees(self)
+        self._builder.build(self, trees=self.trees, mapper=self.mapper)
+
+    def add_objective(
+        self,
+        x: Array1D,
+        *,
+        norm: int = 1,
+    ) -> None:
+        if x.size != self.mapper.n_columns:
+            msg = f"Expected {self.mapper.n_columns} values, got {x.size}"
+            raise ValueError(msg)
+        if norm != 1:
+            msg = f"Unsupported norm: {norm}"
+            raise ValueError(msg)
+
+        x_arr = np.asarray(x, dtype=float).ravel()
+        variables = self.mapper.values()
+        names = [n for n, _ in self.mapper.items()]
+        k = 0
+        indexer = self.mapper.idx
+
+        for v, name in zip(variables, names, strict=True):
+            if v.is_one_hot_encoded:
+                for code in v.codes:
+                    idx = indexer.get(name, code)
+                    self._add_soft_l1_ohe(x_arr[idx], v, code=code)
+                    k += 1
+            elif v.is_continuous:
+                self._add_soft_l1_continuous(x_arr[k], v)
+                k += 1
+            elif v.is_discrete:
+                self._add_soft_l1_discrete(x_arr[k], v)
+                k += 1
+            elif v.is_binary:
+                self._add_soft_l1_binary(x_arr[k], v)
+                k += 1
+            else:
+                k += 1
+
+    def _add_soft_l1_binary(self, x_val: float, v: FeatureVar) -> None:
+        """Add soft clause for binary feature."""
+        weight = int(self._obj_scale)
+        x_var = v.xget()
+        binary_threshold = 0.5
+        if x_val > binary_threshold:
+            # If x=1, penalize flipping to 0
+            self.add_soft([x_var], weight=weight)
+        else:
+            # If x=0, penalize flipping to 1
+            self.add_soft([-x_var], weight=weight)
+
+    def _add_soft_l1_ohe(
+        self,
+        x_val: float,
+        v: FeatureVar,
+        code: Key,
+    ) -> None:
+        """Add soft clause for one-hot encoded feature."""
+        weight = int(self._obj_scale / 2)  # OHE uses half weight
+        x_var = v.xget(code=code)
+        binary_threshold = 0.5
+        if x_val > binary_threshold:
+            self.add_soft([x_var], weight=weight)
+        else:
+            self.add_soft([-x_var], weight=weight)
+
+    def _add_soft_l1_continuous(self, x_val: float, v: FeatureVar) -> None:
+        """Add soft clauses for continuous feature intervals."""
+        levels = v.levels
+        intervals_cost = self._get_intervals_cost(levels, x_val)
+
+        for i in range(len(levels) - 1):
+            cost = intervals_cost[i]
+            if cost > 0:
+                mu_var = v.xget(mu=i)
+                self.add_soft([-mu_var], weight=cost)
+
+    def _add_soft_l1_discrete(self, x_val: float, v: FeatureVar) -> None:
+        """
+        Add soft clauses for discrete feature.
+
+        For discrete features, mu[i] means value == levels[i].
+        Penalize each level based on distance from x_val.
+        """
+        levels = v.levels
+
+        for i in range(len(levels)):
+            level_val = levels[i]
+            if level_val == x_val:
+                # No cost if this is the same value
+                continue
+            cost = int(abs(x_val - level_val) * self._obj_scale)
+            if cost > 0:
+                mu_var = v.xget(mu=i)
+                self.add_soft([-mu_var], weight=cost)
+
+    def _get_intervals_cost(self, levels: Array1D, x: float) -> list[int]:
+        """
+        Compute cost for each interval based on distance from x.
+
+        Returns:
+            List of integer costs for each interval based on distance from x.
+
+        """
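+        # Example (hypothetical): levels = [0, 1, 2] and x = 0.5 give costs
+        # [0, int(0.5 * self._obj_scale)]: x falls inside (0, 1], and the
+        # nearest endpoint of (1, 2] lies 0.5 away from x.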
+        intervals_cost = np.zeros(len(levels) - 1, dtype=int)
+        for i in range(len(intervals_cost)):
+            if levels[i] < x <= levels[i + 1]:
+                continue
+            if levels[i] > x:
+                intervals_cost[i] = int(abs(x - levels[i]) * self._obj_scale)
+            elif levels[i + 1] < x:
+                intervals_cost[i] = int(
+                    abs(x - levels[i + 1]) * self._obj_scale
+                )
+        return intervals_cost.tolist()
+
+    @validate_call
+    def set_majority_class(
+        self,
+        y: NonNegativeInt,
+        *,
+        op: NonNegativeInt = 0,
+    ) -> None:
+        """
+        Set hard constraints to enforce majority vote for class y.
+
+        Raises:
+            ValueError: If y is greater than or equal to the number of classes.
+
+        """
+        if y >= self.n_classes:
+            msg = f"Expected class < {self.n_classes}, got {y}"
+            raise ValueError(msg)
+
+        self._set_majority_class(y, op=op)
+
+    def _set_majority_class(
+        self,
+        y: NonNegativeInt,
+        *,
+        op: NonNegativeInt = 0,
+    ) -> None:
+        """
+        Add hard constraints to enforce class y gets majority vote.
+
+        For sklearn's RandomForestClassifier, the predicted class is the one
+        with the highest mean probability across all trees (soft voting).
+
+        We encode this as: for each class c != y,
+            sum(prob_y - prob_c) >= epsilon
+        where epsilon > 0 if c < y (for tie-breaking), else epsilon >= 0.
+
+        Since MaxSAT doesn't directly support weighted sums, we use a
+        discretized approach with auxiliary variables.
+        """
+        scale = 10000  # Scale factor for probabilities
+
+        for class_ in range(self.n_classes):
+            if class_ == y:
+                continue
+
+            # Compute the score difference for each leaf in each tree
+            # We need: sum over trees of (prob_y - prob_c) >= epsilon
+
+            # Collect each leaf's (variable, score difference) contribution
+            tree_contributions: list[list[tuple[int, int]]] = []
+            for tree in self.trees:
+                contribs: list[tuple[int, int]] = []
+                for leaf in tree.leaves:
+                    prob_y = leaf.value[op, y]
+                    prob_c = leaf.value[op, class_]
+                    diff = int((prob_y - prob_c) * scale)
+                    leaf_var = tree[leaf.node_id]
+                    contribs.append((leaf_var, diff))
+                tree_contributions.append(contribs)
+
+            # Threshold for comparison
+            epsilon = self._epsilon if class_ < y else 0
+
+            # Encode the weighted-sum constraint with a pseudo-Boolean encoding
+            self._encode_weighted_sum_constraint(tree_contributions, epsilon)
+
+    def _encode_weighted_sum_constraint(
+        self,
+        tree_contributions: list[list[tuple[int, int]]],
+        threshold: int,
+    ) -> None:
+        """
+        Encode: sum of contributions >= threshold using pseudo-Boolean encoding.
+
+        This approach avoids exponential enumeration.
+
+        """
+        lits: list[int] = []
+        weights: list[int] = []
+        shift = 0  # sum over |negative weights|
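+        # Example (hypothetical): 3*x1 - 2*x2 >= 1 is rewritten below as
+        # 3*x1 + 2*(not x2) >= 3, since -2*x2 = 2*(1 - x2) - 2 shifts the
+        # bound by +2.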
+
+        for contribs in tree_contributions:
+            for leaf_var, diff in contribs:
+                if diff == 0:
+                    continue  # contributes nothing, can be ignored
+
+                if diff > 0:
+                    # positive coefficient: weight * x
+                    lits.append(leaf_var)  # x
+                    weights.append(diff)
+                else:
+                    # negative coefficient: -a * x
+                    a = -diff
+                    # transform -a*x into a*(-x) and shift the bound by +a
+                    lits.append(-leaf_var)  # -x
+                    weights.append(a)
+                    shift += a
+
+        effective_bound = threshold + shift
+
+        if not lits:  # degenerate case
+            if effective_bound > 0:
+                self.add_hard([])  # UNSAT
+            return
+
+        # Encode sum(weights_i * lits_i) >= effective_bound
+        pb = PBEnc.atleast(
+            lits=lits,
+            weights=weights,
+            bound=effective_bound,
+            vpool=self.vpool,
+        )
+
+        for clause in pb.clauses:
+            self.add_garbage(
+                self.add_hard(clause, return_id=True)  # pyright: ignore[reportUnknownArgumentType]
+            )
+
+    def cleanup(self) -> None:
+        self._clean_soft()
+        for idx in sorted(self.garbage_list(), reverse=True):
+            self.hard.pop(idx)
+        self.remove_garbage()
 
     def _set_builder(self, model_type: Type) -> None:
         match model_type:
diff --git a/ocean/maxsat/_solver.py b/ocean/maxsat/_solver.py
deleted file mode 100644
index e86b7f4..0000000
--- a/ocean/maxsat/_solver.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from __future__ import annotations
-
-from typing import cast
-
-from pysat.examples.rc2 import RC2
-from pysat.formula import WCNF
-
-
-class MaxSATSolver:
-    """Thin RC2 wrapper to keep a stable interface."""
-
-    def __init__(self, solver_name: str = "glucose3") -> None:
-        self.solver_name = solver_name
-
-    @staticmethod
-    def new_wcnf() -> WCNF:
-        return WCNF()
-
-    @staticmethod
-    def add_hard(w: WCNF, clause: list[int]) -> None:
-        w.append(clause, weight=-1)
-
-    @staticmethod
-    def add_soft(w: WCNF, clause: list[int], weight: int) -> None:
-        w.append(clause, weight=weight)
-
-    def solve(self, w: WCNF) -> list[int]:
-        with RC2(w, solver=self.solver_name, adapt=True) as rc2:
-            model = cast("list[int] | None", rc2.compute())
-            if model is None:
-                msg = "UNSAT: no counterfactual found."
-                raise RuntimeError(msg)
-            return model
diff --git a/ocean/maxsat/_variables/_feature.py b/ocean/maxsat/_variables/_feature.py
index 3d4dc7d..0986a15 100644
--- a/ocean/maxsat/_variables/_feature.py
+++ b/ocean/maxsat/_variables/_feature.py
@@ -1,3 +1,5 @@
+from collections.abc import Mapping
+
 from ...feature import Feature
 from ...feature._keeper import FeatureKeeper
 from ...typing import Key
@@ -7,50 +9,106 @@ class FeatureVar(Var, FeatureKeeper):
     X_VAR_NAME_FMT: str = "x[{name}]"
 
+    _x: int
+    _u: Mapping[Key, int]
+    _mu: Mapping[Key, int]
+
     def __init__(self, feature: Feature, name: str) -> None:
         Var.__init__(self, name=name)
         FeatureKeeper.__init__(self, feature=feature)
 
     def build(self, model: BaseModel) -> None:
-        raise NotImplementedError
-
-    def xget(self, code: Key | None = None) -> None:
-        raise NotImplementedError
-
-    def mget(self, key: int) -> None:
-        raise NotImplementedError
-
-    def objvarget(self) -> None:
-        raise NotImplementedError
-
-    def _add_x(self, model: BaseModel) -> None:
-        raise NotImplementedError
-
-    def _add_u(self, model: BaseModel) -> None:
-        raise NotImplementedError
-
-    def _set_mu(self, model: BaseModel, m: int) -> None:
-        raise NotImplementedError
-
-    def _add_mu(self, model: BaseModel) -> None:
-        raise NotImplementedError
+        if self.is_binary:
+            self._x = self._add_x(model)
+        if self.is_numeric:
+            self._mu = self._add_mu(model)
+            # Add exactly-one constraint for mu variables (exactly one interval)
+            model.add_exactly_one(list(self._mu.values()))
+        if self.is_one_hot_encoded:
+            self._u = self._add_u(model)
+
+    def xget(self, code: Key | None = None, mu: Key | None = None) -> int:
+        if mu is not None and code is not None:
+            msg = "Cannot get both 'mu' and 'code' at the same time"
+            raise ValueError(msg)
+        if self.is_one_hot_encoded:
+            return self._xget_one_hot_encoded(code)
+        if code is not None:
+            msg = "Get by code is only supported for one-hot encoded features"
+            raise ValueError(msg)
+        if self.is_numeric:
+            return self._xget_numeric(mu)
+        if mu is not None:
+            msg = "Get by 'mu' is only supported for numeric features"
+            raise ValueError(msg)
+        return self._x
+
+    def _add_x(self, model: BaseModel) -> int:
+        if not self.is_binary:
+            msg = "The '_add_x' method is only supported for binary features"
+            raise ValueError(msg)
+        name = self.X_VAR_NAME_FMT.format(name=self._name)
+        return self._add_binary(model, name)
+
+    def _add_u(self, model: BaseModel) -> Mapping[Key, int]:
+        name = self._name
+        u = self._add_one_hot_encoded(model=model, name=name)
+        model.add_exactly_one(list(u.values()))
+        return u
 
     def _add_one_hot_encoded(
         self,
         model: BaseModel,
         name: str,
-    ) -> None:
-        raise NotImplementedError
+    ) -> Mapping[Key, int]:
+        return {
+            code: model.add_var(name=f"{name}[{code}]") for code in self.codes
+        }
+
+    def _add_mu(self, model: BaseModel) -> Mapping[Key, int]:
+        name = self._name
+        if self.is_discrete:
+            # For discrete features: one mu variable per level (value)
+            # mu[i] means value == levels[i]
+            n_values = len(self.levels)
+            return {
+                lv: model.add_var(name=f"{name}[{lv}]")
+                for lv in range(n_values)
+            }
+        # For continuous features: n-1 mu variables for n levels (intervals)
+        # mu[i] means value in interval (levels[i], levels[i+1]]
+        n_intervals = len(self.levels) - 1
+        return {
+            lv: model.add_var(name=f"{name}[{lv}]") for lv in range(n_intervals)
+        }
 
     @staticmethod
-    def _add_binary(model: BaseModel, name: str) -> None:
-        raise NotImplementedError
-
-    def _add_continuous(self, model: BaseModel, name: str) -> None:
-        raise NotImplementedError
-
-    def _add_discrete(self, model: BaseModel, name: str) -> None:
-        raise NotImplementedError
-
-    def _xget_one_hot_encoded(self, code: Key | None) -> None:
-        raise NotImplementedError
+    def _add_binary(model: BaseModel, name: str) -> int:
+        return model.add_var(name=name)
+
+    def _xget_one_hot_encoded(self, code: Key | None) -> int:
+        if code is None:
+            msg = "Code is required for one-hot encoded features get"
+            raise ValueError(msg)
+        if code not in self.codes:
+            msg = f"Code '{code}' not found in the feature codes"
+            raise ValueError(msg)
+        return self._u[code]
+
+    def _xget_numeric(self, mu: Key | None) -> int:
+        if mu is None:
+            msg = "mu is required to get numeric features"
+            raise ValueError(msg)
+        if self.is_discrete:
+            # For discrete: mu[i] represents value levels[i]
+            n_values = len(self.levels)
+            if mu not in range(n_values):
+                msg = f"mu '{mu}' not in values (0 to {n_values - 1})"
+                raise ValueError(msg)
+        else:
+            # For continuous: mu[i] represents interval (levels[i], levels[i+1]]
+            n_intervals = len(self.levels) - 1
+            if mu not in range(n_intervals):
+                msg = f"mu '{mu}' not in intervals (0 to {n_intervals - 1})"
+                raise ValueError(msg)
+        return self._mu[mu]
diff --git a/ocean/maxsat/_variables/_tree.py b/ocean/maxsat/_variables/_tree.py
index 2ee57c8..1393681 100644
--- a/ocean/maxsat/_variables/_tree.py
+++ b/ocean/maxsat/_variables/_tree.py
@@ -10,6 +10,8 @@ class TreeVar(Var, TreeKeeper, Mapping[NonNegativeInt, object]):
     PATH_VAR_NAME_FMT: str = "{name}_path"
 
+    _path: Mapping[NonNegativeInt, int]
+
     def __init__(
         self,
         tree: TreeLike,
@@ -19,7 +21,9 @@ def __init__(
         TreeKeeper.__init__(self, tree=tree)
 
     def build(self, model: BaseModel) -> None:
-        raise NotImplementedError
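+        # One Boolean per leaf; the exactly-one constraint below makes each
+        # tree select a single root-to-leaf path in any model.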
+        name = self.PATH_VAR_NAME_FMT.format(name=self._name)
+        self._path = self._add_path(model=model, name=name)
+        model.add_exactly_one(list(self._path.values()))
 
     def __len__(self) -> int:
         return self.n_nodes
@@ -28,12 +32,15 @@ def __iter__(self) -> Iterator[NonNegativeInt]:
         return iter(range(self.n_nodes))
 
     @validate_call
-    def __getitem__(self, node_id: NonNegativeInt) -> None:
-        raise NotImplementedError
+    def __getitem__(self, node_id: NonNegativeInt) -> int:
+        return self._path[node_id]
 
     def _add_path(
         self,
         model: BaseModel,
         name: str,
-    ) -> None:
-        raise NotImplementedError
+    ) -> Mapping[NonNegativeInt, int]:
+        return {
+            leaf.node_id: model.add_var(name=f"{name}[{leaf.node_id}]")
+            for leaf in self.leaves
+        }
diff --git a/ocean/mip/_explainer.py b/ocean/mip/_explainer.py
index b420b56..690d62b 100644
--- a/ocean/mip/_explainer.py
+++ b/ocean/mip/_explainer.py
@@ -2,7 +2,7 @@
 import warnings
 
 import gurobipy as gp
-from sklearn.ensemble import IsolationForest
+from sklearn.ensemble import AdaBoostClassifier, IsolationForest
 
 from ..abc import Mapper
 from ..feature import Feature
@@ -37,7 +37,7 @@ def __init__(
         ensembles = (ensemble,) if isolation is None else (ensemble, isolation)
         n_isolators, max_samples = self._get_isolation_params(isolation)
         trees = parse_ensembles(*ensembles, mapper=mapper)
-        if trees[0].adaboost:
+        if isinstance(ensemble, AdaBoostClassifier):
             weights = ensemble.estimator_weights_
         Model.__init__(
             self,
diff --git a/ocean/mip/_managers/_tree.py b/ocean/mip/_managers/_tree.py
index e70d26f..ed80eb5 100644
--- a/ocean/mip/_managers/_tree.py
+++ b/ocean/mip/_managers/_tree.py
@@ -160,7 +160,12 @@ def create(item: tuple[int, Tree]) -> TreeVar:
             if tree.adaboost:
                 self._adaboost = tree.adaboost
             name = self.TREE_VAR_FMT.format(t=t)
-            return TreeVar(tree, name=name, flow_type=flow_type, _adaboost=self._adaboost)
+            return TreeVar(
+                tree,
+                name=name,
+                flow_type=flow_type,
+                _adaboost=self._adaboost,
+            )
 
         tree_vars = tuple(map(create, enumerate(trees)))
         if len(tree_vars) == 0:
diff --git a/ocean/mip/_variables/_tree.py b/ocean/mip/_variables/_tree.py
index 4bbc2c2..fee7908 100644
--- a/ocean/mip/_variables/_tree.py
+++ b/ocean/mip/_variables/_tree.py
@@ -1,8 +1,8 @@
 from collections.abc import Iterator, Mapping
 from enum import Enum
 
-import numpy as np
 import gurobipy as gp
+import numpy as np
 from pydantic import validate_call
 
 from ...tree._keeper import TreeKeeper, TreeLike
@@ -36,6 +36,7 @@ def __init__(
         TreeKeeper.__init__(self, tree=tree)
         self._set_builder(flow_type=flow_type)
         self._adaboost = _adaboost
+
     @property
     def value(self) -> gp.MLinExpr:
         return self._value
@@ -91,12 +92,14 @@ def _get_value(self) -> gp.MLinExpr:
         for leaf in self.leaves:
             if self._adaboost:
                 # one-hot encode the confidence vector
-                val = np.zeros(leaf.value.shape[1])
+                val: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = (
+                    np.zeros(leaf.value.shape[1])
+                )
                 val[np.argmax(leaf.value)] = 1
             else:
                 val = leaf.value
             value += self._flow[leaf.node_id] * val
-
+
         return value
 
     def _get_length(self) -> gp.LinExpr:
diff --git a/ocean/tree/_parse.py b/ocean/tree/_parse.py
index eed4542..ef36693 100644
--- a/ocean/tree/_parse.py
+++ b/ocean/tree/_parse.py
@@ -4,8 +4,8 @@
 from itertools import chain
 
 import xgboost as xgb
-from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
 from sklearn.ensemble import AdaBoostClassifier
+from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
 
 from ..abc import Mapper
 from ..feature import Feature
@@ -70,22 +70,30 @@ def _parse_node(
     return _build_node(tree, node_id, mapper=mapper)
 
 
-def _parse_tree(sklearn_tree: SKLearnTree, *,
-                mapper: Mapper[Feature],
-                is_adaboost: bool = False) -> Tree:
-    tree = SKLearnTreeProtocol(sklearn_tree)
-    root = _parse_node(tree, 0, mapper=mapper)
+def _parse_tree(
+    sklearn_tree: SKLearnTree,
+    *,
+    mapper: Mapper[Feature],
+    is_adaboost: bool = False,
+) -> Tree:
+    sk_tree = SKLearnTreeProtocol(sklearn_tree)
+    root = _parse_node(sk_tree, 0, mapper=mapper)
     tree = Tree(root=root)
     if is_adaboost:
         tree.adaboost = True
     return tree
 
-def parse_tree(tree: SKLearnDecisionTree, *,
-               mapper: Mapper[Feature],
-               is_adaboost: bool = False) -> Tree:
+
+def parse_tree(
+    tree: SKLearnDecisionTree,
+    *,
+    mapper: Mapper[Feature],
+    is_adaboost: bool = False,
+) -> Tree:
     getter = operator.attrgetter("tree_")
     return _parse_tree(getter(tree), mapper=mapper, is_adaboost=is_adaboost)
 
+
 def parse_trees(
     trees: Iterable[SKLearnDecisionTree],
     *,
diff --git a/ocean/typing/__init__.py b/ocean/typing/__init__.py
index 04fe37d..3c2e20b 100644
--- a/ocean/typing/__init__.py
+++ b/ocean/typing/__init__.py
@@ -5,9 +5,15 @@
 import pandas as pd
 import xgboost as xgb
 from pydantic import Field
-from sklearn.ensemble import IsolationForest, RandomForestClassifier, AdaBoostClassifier
-
-type BaseExplainableEnsemble = RandomForestClassifier | xgb.XGBClassifier | AdaBoostClassifier
+from sklearn.ensemble import (
+    AdaBoostClassifier,
+    IsolationForest,
+    RandomForestClassifier,
+)
+
+type BaseExplainableEnsemble = (
+    RandomForestClassifier | xgb.XGBClassifier | AdaBoostClassifier
+)
 type ParsableEnsemble = BaseExplainableEnsemble | IsolationForest | xgb.Booster
 
 type Number = float
diff --git a/pyproject.toml b/pyproject.toml
index cb2defe..d450c12 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,6 @@ dependencies = [
   "pydantic",
   "scikit-learn",
   "xgboost",
-  "python-sat",
 ]
 
 optional-dependencies.dev = [
@@ -60,6 +59,7 @@ optional-dependencies.test = [
   "pyright",
   "pytest",
   "pytest-cov",
+  "python-sat",
   "ruff",
   "scipy-stubs",
   "tox",
@@ -70,6 +70,10 @@ optional-dependencies.example = [
   "matplotlib",
 ]
 
+optional-dependencies.maxsat = [
+  "python-sat[pblib,aiger]; platform_system != 'Windows'",
+]
+
 urls.Homepage = "https://github.com/vidalt/OCEAN"
 urls.Repository = "https://github.com/vidalt/OCEAN"
 urls.Issues = "https://github.com/vidalt/OCEAN/issues"
diff --git a/tests/maxsat/__init__.py b/tests/maxsat/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/tests/maxsat/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/maxsat/conftest.py b/tests/maxsat/conftest.py
new file mode 100644
index 0000000..7d4d8b6
--- /dev/null
+++ b/tests/maxsat/conftest.py
@@ -0,0 +1,8 @@
+import sys
+
+import pytest
+
+pytestmark = pytest.mark.skipif(
+    sys.platform.startswith("win"),
+    reason="MaxSAT tests are disabled on Windows.",
+)
diff --git a/tests/maxsat/model/__init__.py b/tests/maxsat/model/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/tests/maxsat/model/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/maxsat/model/test_feasible.py b/tests/maxsat/model/test_feasible.py
new file mode 100644
index 0000000..489c88a
--- /dev/null
+++ b/tests/maxsat/model/test_feasible.py
@@ -0,0 +1,99 @@
+import numpy as np
+import pytest
+
+from ocean.maxsat import ENV, Model
+from ocean.tree import parse_trees
+
+from ..utils import (
+    MAX_DEPTH,
+    N_CLASSES,
+    N_ESTIMATORS,
+    N_SAMPLES,
+    SEEDS,
+    train_rf,
+    validate_paths,
+    validate_sklearn_pred,
+    validate_solution,
+)
+
+
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("n_estimators", N_ESTIMATORS)
+@pytest.mark.parametrize("max_depth", MAX_DEPTH)
+@pytest.mark.parametrize("n_samples", N_SAMPLES)
+@pytest.mark.parametrize("n_classes", N_CLASSES)
+class TestNoIsolation:
+    @staticmethod
+    def test_build(
+        seed: int,
+        n_estimators: int,
+        max_depth: int,
+        n_samples: int,
+        n_classes: int,
+    ) -> None:
+        clf, mapper, data = train_rf(
+            seed,
+            n_estimators,
+            max_depth,
+            n_samples,
+            n_classes,
+            return_data=True,
+        )
+        trees = tuple(parse_trees(clf, mapper=mapper))
+
+        # Get predictions and pick a class to target
+        predictions = np.array(clf.predict(data), dtype=np.int64)
+        target_class = int(predictions[0])
+
+        model = Model(trees=trees, mapper=mapper)
+        model.build()
+        model.set_majority_class(y=target_class)
+
+        solver = ENV.solver
+        solver_model = solver.solve(model)
+
+        explanation = model.explanation
+
+        validate_solution(explanation)
+        validate_paths(
+            *model.trees, explanation=explanation, solver_model=solver_model
+        )
+        validate_sklearn_pred(clf, explanation, m_class=target_class)
+
+    @staticmethod
+    def test_set_majority_class(
+        seed: int,
+        n_estimators: int,
+        max_depth: int,
+        n_samples: int,
+        n_classes: int,
+    ) -> None:
+        clf, mapper, data = train_rf(
+            seed,
+            n_estimators,
+            max_depth,
+            n_samples,
+            n_classes,
+            return_data=True,
+        )
+        trees = tuple(parse_trees(clf, mapper=mapper))
+
+        predictions = np.array(clf.predict(data), dtype=np.int64)
+        classes = set(map(int, predictions.flatten()))
+
+        for class_ in classes:
+            model = Model(trees=trees, mapper=mapper)
+            model.build()
+            model.set_majority_class(y=class_)
+
+            solver = ENV.solver
+            solver_model = solver.solve(model)
+
+            explanation = model.explanation
+
+            validate_solution(explanation)
+            validate_paths(
+                *model.trees, explanation=explanation, solver_model=solver_model
+            )
+            validate_sklearn_pred(clf, explanation, m_class=class_)
+            model.cleanup()
diff --git a/tests/maxsat/model/test_init.py b/tests/maxsat/model/test_init.py
new file mode 100644
index 0000000..9cd5692
--- /dev/null
+++ b/tests/maxsat/model/test_init.py
@@ -0,0 +1,111 @@
+import numpy as np
+import pytest
+
+from ocean.abc import Mapper
+from ocean.maxsat import Model
+from ocean.tree import parse_trees
+
+from ..utils import (
+    MAX_DEPTH,
+    N_CLASSES,
+    N_ESTIMATORS,
+    N_SAMPLES,
+    SEEDS,
+    train_rf,
+)
+
+
+def test_no_trees() -> None:
+    msg = r"At least one tree is required."
+    with pytest.raises(ValueError, match=msg):
+        Model(trees=[], mapper=Mapper())
+
+
+def test_no_features() -> None:
+    msg = r"At least one feature is required."
+    rf, mapper = train_rf(42, 2, 2, 100, 2)
+    trees = tuple(parse_trees(rf, mapper=mapper))
+    with pytest.raises(ValueError, match=msg):
+        Model(trees=trees, mapper=Mapper())
+
+
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("n_estimators", N_ESTIMATORS)
+@pytest.mark.parametrize("max_depth", MAX_DEPTH)
+@pytest.mark.parametrize("n_samples", N_SAMPLES)
+@pytest.mark.parametrize("n_classes", N_CLASSES)
+class TestNoIsolation:
+    @staticmethod
+    def test_no_weights(
+        seed: int,
+        n_estimators: int,
+        max_depth: int,
+        n_samples: int,
+        n_classes: int,
+    ) -> None:
+        clf, mapper = train_rf(
+            seed,
+            n_estimators,
+            max_depth,
+            n_samples,
+            n_classes,
+        )
+        trees = parse_trees(clf, mapper=mapper)
+        model = Model(trees=trees, mapper=mapper)
+        expected_weights = np.ones(n_estimators, dtype=float)
+        assert model is not None
+        assert model.n_estimators == n_estimators
+        assert model.n_classes == n_classes
+        assert model.weights.shape == expected_weights.shape
+        assert np.isclose(model.weights, expected_weights).all()
+
+    @staticmethod
+    def test_weights(
+        seed: int,
+        n_estimators: int,
+        max_depth: int,
+        n_samples: int,
+        n_classes: int,
+    ) -> None:
+        clf, mapper = train_rf(
+            seed,
+            n_estimators,
+            max_depth,
+            n_samples,
+            n_classes,
+        )
+        trees = parse_trees(clf, mapper=mapper)
+        generator = np.random.default_rng(seed)
+        weights = generator.random(n_estimators).flatten()
+        model = Model(trees=trees, mapper=mapper, weights=weights)
+        assert model is not None
+        assert model.n_estimators == n_estimators
+        assert model.n_classes == n_classes
+        assert model.weights.shape == weights.shape
+        assert np.isclose(model.weights, weights).all()
+
+    @staticmethod
+    def test_invalid_weights(
+        seed: int,
+        n_estimators: int,
+        max_depth: int,
+        n_samples: int,
+        n_classes: int,
+    ) -> None:
+        clf, mapper = train_rf(
+            seed,
+            n_estimators,
+            max_depth,
+            n_samples,
+            n_classes,
+        )
+        trees = tuple(parse_trees(clf, mapper=mapper))
+        generator = np.random.default_rng(seed)
+        shapes = [generator.integers(n_estimators + 1, 2 * n_estimators + 1)]
+        if n_estimators > 2:
+            shapes += [generator.integers(1, n_estimators - 1)]
+        for shape in shapes:
+            weights = generator.random(shape).flatten()
+            msg = r"The number of weights must match the number of trees."
+            with pytest.raises(ValueError, match=msg):
+                Model(trees=trees, mapper=mapper, weights=weights)
diff --git a/tests/maxsat/model/test_objective.py b/tests/maxsat/model/test_objective.py
new file mode 100644
index 0000000..e1daeb6
--- /dev/null
+++ b/tests/maxsat/model/test_objective.py
@@ -0,0 +1,118 @@
+import numpy as np
+import pytest
+
+from ocean.maxsat import ENV, Model
+from ocean.tree import parse_trees
+
+from ..utils import (
+    MAX_DEPTH,
+    N_CLASSES,
+    N_ESTIMATORS,
+    N_SAMPLES,
+    SEEDS,
+    check_solution,
+    train_rf,
+    validate_paths,
+    validate_sklearn_pred,
+    validate_solution,
+)
+
+P_QUERIES = 0.2
+
+
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("n_estimators", N_ESTIMATORS)
+@pytest.mark.parametrize("max_depth", MAX_DEPTH)
+@pytest.mark.parametrize("n_samples", N_SAMPLES)
+@pytest.mark.parametrize("n_classes", N_CLASSES)
+class TestNoIsolation:
+    @staticmethod
+    def test_build(
+        seed: int,
+        n_estimators: int,
+        max_depth: int,
+        n_samples: int,
+        n_classes: int,
+    ) -> None:
+        clf, mapper, data = train_rf(
+            seed,
+            n_estimators,
+            max_depth,
+            n_samples,
+            n_classes,
+            return_data=True,
+        )
+        trees = tuple(parse_trees(clf, mapper=mapper))
+
+        n_queries = int(data.shape[0] * P_QUERIES)
+        generator = np.random.default_rng(seed)
+        queries = generator.choice(
+            range(len(data)), size=n_queries, replace=False
+        )
+
+        for query in queries:
+            x = np.array(data.to_numpy()[query], dtype=np.float64).flatten()
+
+            model = Model(trees=trees, mapper=mapper)
+            model.build()
+            model.add_objective(x=x)
+
+            solver = ENV.solver
+            solver_model = solver.solve(model)
+
+            explanation = model.explanation
+
+            validate_solution(explanation)
+            validate_paths(
+                *model.trees, explanation=explanation, solver_model=solver_model
+            )
+            check_solution(x, explanation)
+            model.cleanup()
+
+    @staticmethod
+    def test_set_majority_class(
+        seed: int,
+        n_estimators: int,
+        max_depth: int,
+        n_samples: int,
+        n_classes: int,
+    ) -> None:
+        clf, mapper, data = train_rf(
+            seed,
+            n_estimators,
+            max_depth,
+            n_samples,
+            n_classes,
+            return_data=True,
+        )
+        trees = tuple(parse_trees(clf, mapper=mapper))
+
+        predictions = np.array(clf.predict(data.to_numpy()), dtype=np.int64)
+        classes = set(map(int, predictions.flatten()))
+
+        generator = np.random.default_rng(seed)
+        query = generator.integers(len(data))
+
+        x = np.array(data.to_numpy()[query], dtype=np.float64).flatten()
+        y = int(predictions[query])
+
+        for class_ in classes:
+            model = Model(trees=trees, mapper=mapper)
+            model.build()
+            model.set_majority_class(y=class_)
+            model.add_objective(x=x)
+
+            solver = ENV.solver
+            solver_model = solver.solve(model)
+
+            explanation = model.explanation
+
+            validate_solution(explanation)
+            validate_paths(
+                *model.trees, explanation=explanation, solver_model=solver_model
+            )
+            validate_sklearn_pred(clf, explanation, m_class=class_)
+            if class_ == y:
+                check_solution(x, explanation)
+
+            model.cleanup()
diff --git a/tests/maxsat/test_explainer.py b/tests/maxsat/test_explainer.py
new file mode 100644
index 0000000..ae9ffb7
--- /dev/null
+++ b/tests/maxsat/test_explainer.py
@@ -0,0 +1,45 @@
+import numpy as np
+import pytest
+from sklearn.ensemble import RandomForestClassifier
+
+from ocean import MaxSATExplainer
+
+from .utils import generate_data
+
+
+@pytest.mark.parametrize("seed", [42, 43, 44])
+@pytest.mark.parametrize("n_estimators", [5])
+@pytest.mark.parametrize("max_depth", [2, 3])
+@pytest.mark.parametrize("n_classes", [2])
+@pytest.mark.parametrize("n_samples", [100, 200, 500])
+def test_maxsat_explain(
+    seed: int,
+    n_estimators: int,
+    max_depth: int,
+    n_classes: int,
+    n_samples: int,
+) -> None:
+    data, y, mapper = generate_data(seed, n_samples, n_classes)
+    clf = RandomForestClassifier(
+        random_state=seed,
+        n_estimators=n_estimators,
+        max_depth=max_depth,
+    )
+    clf.fit(data, y)
+    model = MaxSATExplainer(clf, mapper=mapper)
+
+    x = data.iloc[0, :].to_numpy().astype(float).flatten()
+    y = clf.predict([x])[0]  # pyright: ignore[reportUnknownVariableType]
+    classes = np.unique(clf.predict(data.to_numpy())).astype(np.int64)  # pyright: ignore[reportUnknownArgumentType]
+    for target in classes[classes != y]:
+        exp = model.explain(
+            x,
+            y=target,
+            norm=1,
+            random_seed=seed,
+        )
+        assert model.Status == "OPTIMAL"
+        assert exp is not None
+        assert clf.predict([exp.to_numpy()])[0] == target
+        model.cleanup()
diff --git a/tests/maxsat/utils.py b/tests/maxsat/utils.py
new file mode 100644
index 0000000..0b8463d
--- /dev/null
+++ b/tests/maxsat/utils.py
@@ -0,0 +1,169 @@
+from collections import defaultdict
+from typing import TYPE_CHECKING, Literal, overload
+
+import numpy as np
+import pandas as pd
+from sklearn.ensemble import RandomForestClassifier
+
+from ocean.abc import Mapper
+from ocean.feature import Feature
+from ocean.maxsat import Explanation, TreeVar
+from ocean.typing import Array1D, NonNegativeInt
+
+from ..utils import generate_data
+
+if TYPE_CHECKING:
+    from ocean.typing import Key
+
+
+def check_solution(x: Array1D, explanation: Explanation) -> None:
+    n = explanation.n_columns
+    x_sol = explanation.x
+    for i in range(n):
+        name = explanation.names[i]
+        # For now we only check the non continuous features
+        # as the continuous features are epsilon away from
+        # the explanation
+        if not explanation[name].is_continuous:
+            assert np.isclose(x[i], x_sol[i])
+
+
+def validate_solution(explanation: Explanation) -> None:
+    x = explanation.x
+    n = explanation.n_columns
+    codes: dict[Key, float] = defaultdict(float)
+    for i in range(n):
+        name = explanation.names[i]
+        feature = explanation[name]
+        value = x[i]
+        if feature.is_one_hot_encoded:
+            assert np.any(np.isclose(value, [0.0, 1.0]))
+            codes[name] += value
+
+        if feature.is_binary:
+            assert np.any(np.isclose(value, [0.0, 1.0]))
+        elif feature.is_numeric:
+            assert feature.levels[0] <= value <= feature.levels[-1]
+
+    for value in codes.values():
+        assert np.isclose(value, 1.0)
+
+
+def find_leaf(tree: TreeVar, explanation: Explanation) -> NonNegativeInt:
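+    # Replay the tree's split decisions on the explanation point to find
+    # the leaf it should reach; used to cross-check the solver's choice.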
+    node = tree.root
+    x = explanation.x
+    while not node.is_leaf:
+        name = node.feature
+        if explanation[name].is_one_hot_encoded:
+            code = node.code
+            i = explanation.idx.get(name, code)
+            value = x[i]
+        else:
+            i = explanation.idx.get(name)
+            value = x[i]
+
+        if explanation[name].is_numeric:
+            threshold = node.threshold
+            node = node.left if value <= threshold else node.right
+        elif np.isclose(value, 0.0):
+            node = node.left
+        else:
+            node = node.right
+    return node.node_id
+
+
+def validate_path(
+    tree: TreeVar, explanation: Explanation, solver_model: list[int]
+) -> None:
+    """Check that exactly one leaf is selected and matches the solution."""
+    n_active = 0
+    active_leaf = tree.root.node_id
+    for node in tree.leaves:
+        assert node.is_leaf
+        leaf_var = tree[node.node_id]
+        is_active = leaf_var in solver_model
+        if is_active:
+            n_active += 1
+            active_leaf = node.node_id
+
+    assert n_active == 1, (
+        f"Expected one leaf to be active, but {n_active} were found."
+    )
+    x_id_leaf = find_leaf(tree, explanation)
+    assert active_leaf == x_id_leaf, (
+        f"Expected leaf {active_leaf}, but found {x_id_leaf}."
+    )
+
+
+def validate_paths(
+    *trees: TreeVar, explanation: Explanation, solver_model: list[int]
+) -> None:
+    for tree in trees:
+        validate_path(tree, explanation, solver_model)
+
+
+def validate_sklearn_pred(
+    clf: RandomForestClassifier,
+    explanation: Explanation,
+    m_class: NonNegativeInt,
+) -> None:
+    x = explanation.x.reshape(1, -1)
+    prediction = np.asarray(clf.predict(x), dtype=np.int64)
+    assert (prediction == m_class).all(), (
+        f"Expected class {m_class}, got {prediction[0]}"
+    )
+
+
+@overload
+def train_rf(
+    seed: int,
+    n_estimators: int,
+    max_depth: int,
+    n_samples: int,
+    n_classes: int,
+    *,
+    return_data: Literal[False] = False,
+) -> tuple[RandomForestClassifier, Mapper[Feature]]: ...
+
+
+@overload
+def train_rf(
+    seed: int,
+    n_estimators: int,
+    max_depth: int,
+    n_samples: int,
+    n_classes: int,
+    *,
+    return_data: Literal[True],
+) -> tuple[RandomForestClassifier, Mapper[Feature], pd.DataFrame]: ...
+
+
+def train_rf(
+    seed: int,
+    n_estimators: int,
+    max_depth: int,
+    n_samples: int,
+    n_classes: int,
+    *,
+    return_data: bool = False,
+) -> (
+    tuple[RandomForestClassifier, Mapper[Feature]]
+    | tuple[RandomForestClassifier, Mapper[Feature], pd.DataFrame]
+):
+    data, y, mapper = generate_data(seed, n_samples, n_classes)
+    clf = RandomForestClassifier(
+        random_state=seed,
+        n_estimators=n_estimators,
+        max_depth=max_depth,
+    )
+    clf.fit(data, y)
+    if return_data:
+        return clf, mapper, data
+    return clf, mapper
+
+
+SEEDS = [43, 44, 45]
+N_ESTIMATORS = [1, 4]
+MAX_DEPTH = [3]
+N_CLASSES = [2, 4]
+N_SAMPLES = [100, 200]
diff --git a/tests/test_explainer.py b/tests/test_explainer.py
index 0362c0d..54d0daf 100644
--- a/tests/test_explainer.py
+++ b/tests/test_explainer.py
@@ -4,7 +4,10 @@
 from sklearn.ensemble import RandomForestClassifier
 from xgboost import XGBClassifier
 
-from ocean import ConstraintProgrammingExplainer, MixedIntegerProgramExplainer
+from ocean import (
+    ConstraintProgrammingExplainer,
+    MixedIntegerProgramExplainer,
+)
 
 from .utils import ENV, generate_data
diff --git a/tox.ini b/tox.ini
index 0119c38..4bf587b 100644
--- a/tox.ini
+++ b/tox.ini
@@ -9,8 +9,9 @@ basepython = python3
 
 [testenv]
 deps =
     pytest
+extras = maxsat
 commands =
-    pip install .[test]
+    pip install .[test,maxsat]
     pytest -vv
 
 [testenv:py312]
@@ -37,6 +38,7 @@
     pandas-stubs
    pydantic
     pytest
+    python-sat
     scikit-learn
     scipy
     scipy-stubs