From 9fc2d0cf799fc8a25a5ff10365adb867cfb61e90 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Sun, 23 Nov 2025 13:05:23 -0500 Subject: [PATCH 01/18] fix: Add variable management methods and clause handling in BaseModel --- ocean/maxsat/_base.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/ocean/maxsat/_base.py b/ocean/maxsat/_base.py index 4a17993..7a07668 100644 --- a/ocean/maxsat/_base.py +++ b/ocean/maxsat/_base.py @@ -1,10 +1,12 @@ from abc import ABC from typing import Any, Protocol -from pysat.formula import WCNF +from pysat.formula import WCNF, IDPool class BaseModel(ABC, WCNF): + vpool: IDPool = IDPool() + def __init__(self) -> None: WCNF.__init__(self) @@ -15,6 +17,34 @@ def build_vars(self, *variables: "Var") -> None: for variable in variables: variable.build(model=self) + def add_var(self, name: str) -> int: + if name in self.vpool.obj2id: # var has been already created + msg = f"Variable with name '{name}' already exists." + raise ValueError(msg) + return self.vpool.id(f"{name}") + + def get_var(self, name: str) -> int: + if name not in self.vpool.obj2id: # var has not been created + msg = f"Variable with name '{name}' does not exist." + raise ValueError(msg) + return self.vpool.obj2id[name] + + def add_hard(self, lits: list[int]) -> None: + """Add a hard clause (must be satisfied).""" + # weight=None => hard clause in WCNF + self.append(lits) + + def add_soft(self, lits: list[int], weight: int = 1) -> None: + """Add a soft clause with a given weight.""" + self.append(lits, weight=weight) + + def add_exactly_one(self, lits: list[int]) -> None: + """Add constraint that exactly one path is selected.""" + self.add_hard(lits) # at least one + for i in range(len(lits)): + for j in range(i + 1, len(lits)): + self.add_hard([-lits[i], -lits[j]]) # at most one + class Var(Protocol): _name: str From 57fa4c015a3a6e17d6721a7970990c974e770f02 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Sun, 23 Nov 2025 13:05:34 -0500 Subject: [PATCH 02/18] fix: Implement vget method to handle one-hot encoded and numeric values in Explanation class --- ocean/maxsat/_explanation.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ocean/maxsat/_explanation.py b/ocean/maxsat/_explanation.py index b3da47d..886aa05 100644 --- a/ocean/maxsat/_explanation.py +++ b/ocean/maxsat/_explanation.py @@ -16,8 +16,16 @@ class Explanation(Mapper[FeatureVar], BaseExplanation): _x: Array1D = np.zeros((0,), dtype=int) def vget(self, i: int) -> int: - msg = "Not implemented." - raise NotImplementedError(msg) + name = self.names[i] + if self[name].is_one_hot_encoded: + code = self.codes[i] + return self[name].xget(code=code) + if self[name].is_numeric: + j = int( + np.searchsorted(self[name].levels, self._x[i], side="left") # type: ignore[arg-type] + ) + return self[name].xget(mu=j) + return self[name].xget() def to_series(self) -> "pd.Series[float]": msg = "Not implemented." 
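For reference, the clause helpers introduced in PATCH 01 follow PySAT's WCNF conventions: append(lits) with no weight records a hard clause, append(lits, weight=w) a soft one, and add_exactly_one combines one at-least-one clause with pairwise at-most-one clauses. A minimal standalone sketch of that encoding, written directly against the documented pysat API (the variable ids 1-3 and the soft preference are placeholders for illustration, not values used by the patches):

    # Sketch of the pairwise exactly-one encoding used by BaseModel.add_exactly_one.
    from pysat.examples.rc2 import RC2
    from pysat.formula import WCNF

    wcnf = WCNF()
    lits = [1, 2, 3]                # placeholder variable ids
    wcnf.append(lits)               # hard: at least one literal is true
    for i in range(len(lits)):
        for j in range(i + 1, len(lits)):
            wcnf.append([-lits[i], -lits[j]])  # hard: no two literals true together
    wcnf.append([2], weight=1)      # soft: prefer variable 2 when possible
    with RC2(wcnf) as rc2:
        print(rc2.compute())        # e.g. [-1, 2, -3]: exactly one variable is true
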
From 42bd50f81d703cc3f7acc6675329a9fdd0fd5b80 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Sun, 23 Nov 2025 13:05:48 -0500 Subject: [PATCH 03/18] fix: Refactor Model class to ensure _builder is initialized and build method implementation --- ocean/maxsat/_model.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ocean/maxsat/_model.py b/ocean/maxsat/_model.py index 9af8be6..0aeab76 100644 --- a/ocean/maxsat/_model.py +++ b/ocean/maxsat/_model.py @@ -19,10 +19,9 @@ @dataclass class Model(BaseModel, FeatureManager, TreeManager, GarbageManager): - DEFAULT_EPSILON: int = 1 - # Model builder for the ensemble. - _builder: ModelBuilder | None = None + _builder: ModelBuilder + DEFAULT_EPSILON: int = 1 class Type(Enum): MAXSAT = "MAXSAT" @@ -32,10 +31,10 @@ def __init__( trees: Iterable[Tree], mapper: Mapper[Feature], *, - model_type: Type = Type.MAXSAT, weights: NonNegativeArray1D | None = None, max_samples: NonNegativeInt = 0, epsilon: int = DEFAULT_EPSILON, + model_type: Type = Type.MAXSAT, ) -> None: BaseModel.__init__(self) TreeManager.__init__( @@ -52,7 +51,9 @@ def __init__( self._set_builder(model_type=model_type) def build(self) -> None: - raise NotImplementedError + self.build_features(self) + self.build_trees(self) + self._builder.build(self, trees=self.trees, mapper=self.mapper) def _set_builder(self, model_type: Type) -> None: match model_type: From 3c2b8eb359732703a8664c6faef98e163a40b14b Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Sun, 23 Nov 2025 13:06:08 -0500 Subject: [PATCH 04/18] fix: Update MaxSATBuilder methods to enforce integer type for 'y' parameter and implement missing logic in _bset, _cset, _dset, and _eset methods --- ocean/maxsat/_builder/model.py | 42 +++++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/ocean/maxsat/_builder/model.py b/ocean/maxsat/_builder/model.py index 35100fe..c716e44 100644 --- a/ocean/maxsat/_builder/model.py +++ b/ocean/maxsat/_builder/model.py @@ -1,6 +1,8 @@ from collections.abc import Iterable from typing import Protocol +import numpy as np + from ...abc import Mapper from ...tree._node import Node from .._base import BaseModel @@ -69,7 +71,7 @@ def _propagate( *, node: Node, mapper: Mapper[FeatureVar], - y: object, + y: int, ) -> None: parent = node.parent if parent is None: @@ -83,7 +85,7 @@ def _expand( model: BaseModel, *, node: Node, - y: object, + y: int, v: FeatureVar, sigma: bool, ) -> None: @@ -100,45 +102,63 @@ def _expand( def _bset( model: BaseModel, *, - y: object, + y: int, v: FeatureVar, sigma: bool, ) -> None: - msg = "Raise NotImplementedError" - raise NotImplementedError(msg) + if sigma: + model.add_hard([-y, v.xget()]) + else: + model.add_hard([-y, -v.xget()]) @staticmethod def _cset( model: BaseModel, *, node: Node, - y: object, + y: int, v: FeatureVar, sigma: bool, ) -> None: - raise NotImplementedError + threshold = node.threshold + j = int(np.searchsorted(v.levels, threshold, side="left")) + mu = v.xget(mu=j - 1) + if sigma: + model.add_hard([-y, mu]) + else: + model.add_hard([-y, -mu]) @staticmethod def _dset( model: BaseModel, *, node: Node, - y: object, + y: int, v: FeatureVar, sigma: bool, ) -> None: - raise NotImplementedError + threshold = node.threshold + j = int(np.searchsorted(v.levels, threshold, side="left")) + x = v.xget(mu=j) + if sigma: + model.add_hard([-y, x]) + else: + model.add_hard([-y, x]) @staticmethod def _eset( model: BaseModel, *, node: Node, - y: object, + y: int, v: FeatureVar, sigma: bool, ) -> None: - raise 
NotImplementedError + x = v.xget(code=node.code) + if sigma: + model.add_hard([-y, x]) + else: + model.add_hard([-y, -x]) class ModelBuilderFactory: From fa28247e09dfe510fd070b5d32279d6d987044fb Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Sun, 23 Nov 2025 13:06:32 -0500 Subject: [PATCH 05/18] fix: Implement vget method in FeatureManager and enhance build method in FeatureVar and TreeVar classes --- ocean/maxsat/_managers/_feature.py | 2 +- ocean/maxsat/_variables/_feature.py | 107 +++++++++++++++++++--------- ocean/maxsat/_variables/_tree.py | 17 +++-- 3 files changed, 85 insertions(+), 41 deletions(-) diff --git a/ocean/maxsat/_managers/_feature.py b/ocean/maxsat/_managers/_feature.py index 9d57ec1..6b8d7af 100644 --- a/ocean/maxsat/_managers/_feature.py +++ b/ocean/maxsat/_managers/_feature.py @@ -37,7 +37,7 @@ def explanation(self) -> Explanation: return self.mapper def vget(self, i: int) -> int: - raise NotImplementedError + return self.mapper.vget(i) def _set_mapper(self, mapper: Mapper[Feature]) -> None: def create(key: Key, feature: Feature) -> FeatureVar: diff --git a/ocean/maxsat/_variables/_feature.py b/ocean/maxsat/_variables/_feature.py index 3d4dc7d..329fd25 100644 --- a/ocean/maxsat/_variables/_feature.py +++ b/ocean/maxsat/_variables/_feature.py @@ -1,3 +1,5 @@ +from collections.abc import Mapping + from ...feature import Feature from ...feature._keeper import FeatureKeeper from ...typing import Key @@ -7,50 +9,85 @@ class FeatureVar(Var, FeatureKeeper): X_VAR_NAME_FMT: str = "x[{name}]" + _x: int + _u: Mapping[Key, int] + _mu: Mapping[Key, int] + def __init__(self, feature: Feature, name: str) -> None: Var.__init__(self, name=name) FeatureKeeper.__init__(self, feature=feature) def build(self, model: BaseModel) -> None: - raise NotImplementedError - - def xget(self, code: Key | None = None) -> None: - raise NotImplementedError - - def mget(self, key: int) -> None: - raise NotImplementedError - - def objvarget(self) -> None: - raise NotImplementedError - - def _add_x(self, model: BaseModel) -> None: - raise NotImplementedError - - def _add_u(self, model: BaseModel) -> None: - raise NotImplementedError - - def _set_mu(self, model: BaseModel, m: int) -> None: - raise NotImplementedError - - def _add_mu(self, model: BaseModel) -> None: - raise NotImplementedError + if self.is_binary: + self._x = self._add_x(model) + if self.is_numeric: + self._mu = self._add_mu(model) + if self.is_one_hot_encoded: + self._u = self._add_u(model) + + def xget(self, code: Key | None = None, mu: Key | None = None) -> int: + if mu is not None and code is not None: + msg = "Cannot get both 'mu' and 'code' at the same time" + raise ValueError(msg) + if self.is_one_hot_encoded: + return self._xget_one_hot_encoded(code) + if code is not None: + msg = "Get by code is only supported for one-hot encoded features" + raise ValueError(msg) + if self.is_numeric: + return self._xget_numeric(mu) + if mu is not None: + msg = "Get by 'mu' is only supported for numeric features" + raise ValueError(msg) + return self._x + + def _add_x(self, model: BaseModel) -> int: + if not self.is_binary: + msg = "The '_add_x' method is only supported for binary features" + raise ValueError(msg) + name = self.X_VAR_NAME_FMT.format(name=self._name) + return self._add_binary(model, name) + + def _add_u(self, model: BaseModel) -> Mapping[Key, int]: + name = self._name.format(name=self._name) + u = self._add_one_hot_encoded(model=model, name=name) + model.add_exactly_one(list(u.values())) + return u def _add_one_hot_encoded( self, 
model: BaseModel, name: str, - ) -> None: - raise NotImplementedError + ) -> Mapping[Key, int]: + return { + code: model.add_var(name=f"{name}[{code}]") for code in self.codes + } + + def _add_mu(self, model: BaseModel) -> Mapping[Key, int]: + name = self._name.format(name=self._name) + return { + lv: model.add_var(name=f"{name}[{lv}]") + for lv in range(len(self.levels)) + } @staticmethod - def _add_binary(model: BaseModel, name: str) -> None: - raise NotImplementedError - - def _add_continuous(self, model: BaseModel, name: str) -> None: - raise NotImplementedError - - def _add_discrete(self, model: BaseModel, name: str) -> None: - raise NotImplementedError - - def _xget_one_hot_encoded(self, code: Key | None) -> None: - raise NotImplementedError + def _add_binary(model: BaseModel, name: str) -> int: + return model.add_var(name=name) + + def _xget_one_hot_encoded(self, code: Key | None) -> int: + if code is None: + msg = "Code is required for one-hot encoded features get" + raise ValueError(msg) + if code not in self.codes: + msg = f"Code '{code}' not found in the feature codes" + raise ValueError(msg) + return self._u[code] + + def _xget_numeric(self, mu: Key | None) -> int: + if mu is None: + msg = "Mu is required for numeric features get" + raise ValueError(msg) + if mu not in range(len(self.levels)): + msg = f"Mu '{mu}' not found in the feature levels" + raise ValueError(msg) + return self._mu[mu] diff --git a/ocean/maxsat/_variables/_tree.py b/ocean/maxsat/_variables/_tree.py index 2ee57c8..1393681 100644 --- a/ocean/maxsat/_variables/_tree.py +++ b/ocean/maxsat/_variables/_tree.py @@ -10,6 +10,8 @@ class TreeVar(Var, TreeKeeper, Mapping[NonNegativeInt, object]): PATH_VAR_NAME_FMT: str = "{name}_path" + _path: Mapping[NonNegativeInt, int] + def __init__( self, tree: TreeLike, @@ -19,7 +21,9 @@ def __init__( TreeKeeper.__init__(self, tree=tree) def build(self, model: BaseModel) -> None: - raise NotImplementedError + name = self.PATH_VAR_NAME_FMT.format(name=self._name) + self._path = self._add_path(model=model, name=name) + model.add_exactly_one(list(self._path.values())) def __len__(self) -> int: return self.n_nodes @@ -28,12 +32,15 @@ def __iter__(self) -> Iterator[NonNegativeInt]: return iter(range(self.n_nodes)) @validate_call - def __getitem__(self, node_id: NonNegativeInt) -> None: - raise NotImplementedError + def __getitem__(self, node_id: NonNegativeInt) -> int: + return self._path[node_id] def _add_path( self, model: BaseModel, name: str, - ) -> None: - raise NotImplementedError + ) -> Mapping[NonNegativeInt, int]: + return { + leaf.node_id: model.add_var(name=f"{name}[{leaf.node_id}]") + for leaf in self.leaves + } From 5422b8b3950752bf75ef40b33aef5d499470ed22 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Tue, 25 Nov 2025 11:57:36 -0500 Subject: [PATCH 06/18] fix: Refactor tree and explainer modules to improve AdaBoost handling and type annotations --- ocean/cp/_explainer.py | 6 ++++-- ocean/cp/_managers/_tree.py | 7 ++----- ocean/mip/_explainer.py | 4 ++-- ocean/mip/_managers/_tree.py | 7 ++++++- ocean/mip/_variables/_tree.py | 9 ++++++--- ocean/tree/_parse.py | 26 +++++++++++++++++--------- ocean/typing/__init__.py | 12 +++++++++--- 7 files changed, 46 insertions(+), 25 deletions(-) diff --git a/ocean/cp/_explainer.py b/ocean/cp/_explainer.py index 4adc628..bd04a16 100644 --- a/ocean/cp/_explainer.py +++ b/ocean/cp/_explainer.py @@ -3,6 +3,7 @@ import warnings from ortools.sat.python import cp_model as cp +from sklearn.ensemble import AdaBoostClassifier from ..abc 
import Mapper from ..feature import Feature @@ -11,6 +12,7 @@ Array1D, BaseExplainableEnsemble, BaseExplainer, + NonNegativeArray1D, NonNegativeInt, PositiveInt, ) @@ -25,13 +27,13 @@ def __init__( ensemble: BaseExplainableEnsemble, *, mapper: Mapper[Feature], - weights: Array1D | None = None, + weights: NonNegativeArray1D | None = None, epsilon: int = Model.DEFAULT_EPSILON, model_type: Model.Type = Model.Type.CP, ) -> None: ensembles = (ensemble,) trees = parse_ensembles(*ensembles, mapper=mapper) - if trees[0].adaboost: + if isinstance(ensemble, AdaBoostClassifier): weights = ensemble.estimator_weights_ Model.__init__( self, diff --git a/ocean/cp/_managers/_tree.py b/ocean/cp/_managers/_tree.py index dbfef77..ed22ac3 100644 --- a/ocean/cp/_managers/_tree.py +++ b/ocean/cp/_managers/_tree.py @@ -140,10 +140,7 @@ def weighted_function( if self._adaboost: # no need to scale since values are 0/1 # scaling is done later for tree weights - if np.argmax(leaf.value[op, :]) == c: - coeff = 1 - else: - coeff = 0 + coeff = int(np.argmax(leaf.value[op, :]) == c) else: coeff = int(leaf.value[op, c] * scale) coefs.append(coeff) @@ -151,7 +148,7 @@ def weighted_function( tree_expr = cp.LinearExpr.WeightedSum(variables, coefs) tree_exprs.append(tree_expr) if self._adaboost: - tree_weights.append(int(weight * scale)) + tree_weights.append(int(weight * scale)) else: tree_weights.append(int(weight)) expr = cp.LinearExpr.WeightedSum(tree_exprs, tree_weights) diff --git a/ocean/mip/_explainer.py b/ocean/mip/_explainer.py index b420b56..690d62b 100644 --- a/ocean/mip/_explainer.py +++ b/ocean/mip/_explainer.py @@ -2,7 +2,7 @@ import warnings import gurobipy as gp -from sklearn.ensemble import IsolationForest +from sklearn.ensemble import AdaBoostClassifier, IsolationForest from ..abc import Mapper from ..feature import Feature @@ -37,7 +37,7 @@ def __init__( ensembles = (ensemble,) if isolation is None else (ensemble, isolation) n_isolators, max_samples = self._get_isolation_params(isolation) trees = parse_ensembles(*ensembles, mapper=mapper) - if trees[0].adaboost: + if isinstance(ensemble, AdaBoostClassifier): weights = ensemble.estimator_weights_ Model.__init__( self, diff --git a/ocean/mip/_managers/_tree.py b/ocean/mip/_managers/_tree.py index e70d26f..ed80eb5 100644 --- a/ocean/mip/_managers/_tree.py +++ b/ocean/mip/_managers/_tree.py @@ -160,7 +160,12 @@ def create(item: tuple[int, Tree]) -> TreeVar: if tree.adaboost: self._adaboost = tree.adaboost name = self.TREE_VAR_FMT.format(t=t) - return TreeVar(tree, name=name, flow_type=flow_type, _adaboost=self._adaboost) + return TreeVar( + tree, + name=name, + flow_type=flow_type, + _adaboost=self._adaboost, + ) tree_vars = tuple(map(create, enumerate(trees))) if len(tree_vars) == 0: diff --git a/ocean/mip/_variables/_tree.py b/ocean/mip/_variables/_tree.py index 4bbc2c2..fee7908 100644 --- a/ocean/mip/_variables/_tree.py +++ b/ocean/mip/_variables/_tree.py @@ -1,8 +1,8 @@ from collections.abc import Iterator, Mapping from enum import Enum -import numpy as np import gurobipy as gp +import numpy as np from pydantic import validate_call from ...tree._keeper import TreeKeeper, TreeLike @@ -36,6 +36,7 @@ def __init__( TreeKeeper.__init__(self, tree=tree) self._set_builder(flow_type=flow_type) self._adaboost = _adaboost + @property def value(self) -> gp.MLinExpr: return self._value @@ -91,12 +92,14 @@ def _get_value(self) -> gp.MLinExpr: for leaf in self.leaves: if self._adaboost: # one-hot encode the confidence vector - val = np.zeros(leaf.value.shape[1]) + val: 
np.ndarray[tuple[int, ...], np.dtype[np.float64]] = ( + np.zeros(leaf.value.shape[1]) + ) val[np.argmax(leaf.value)] = 1 else: val = leaf.value value += self._flow[leaf.node_id] * val - + return value def _get_length(self) -> gp.LinExpr: diff --git a/ocean/tree/_parse.py b/ocean/tree/_parse.py index eed4542..ef36693 100644 --- a/ocean/tree/_parse.py +++ b/ocean/tree/_parse.py @@ -4,8 +4,8 @@ from itertools import chain import xgboost as xgb -from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from sklearn.ensemble import AdaBoostClassifier +from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from ..abc import Mapper from ..feature import Feature @@ -70,22 +70,30 @@ def _parse_node( return _build_node(tree, node_id, mapper=mapper) -def _parse_tree(sklearn_tree: SKLearnTree, *, - mapper: Mapper[Feature], - is_adaboost: bool = False) -> Tree: - tree = SKLearnTreeProtocol(sklearn_tree) - root = _parse_node(tree, 0, mapper=mapper) +def _parse_tree( + sklearn_tree: SKLearnTree, + *, + mapper: Mapper[Feature], + is_adaboost: bool = False, +) -> Tree: + sk_tree = SKLearnTreeProtocol(sklearn_tree) + root = _parse_node(sk_tree, 0, mapper=mapper) tree = Tree(root=root) if is_adaboost: tree.adaboost = True return tree -def parse_tree(tree: SKLearnDecisionTree, *, - mapper: Mapper[Feature], - is_adaboost: bool = False) -> Tree: + +def parse_tree( + tree: SKLearnDecisionTree, + *, + mapper: Mapper[Feature], + is_adaboost: bool = False, +) -> Tree: getter = operator.attrgetter("tree_") return _parse_tree(getter(tree), mapper=mapper, is_adaboost=is_adaboost) + def parse_trees( trees: Iterable[SKLearnDecisionTree], *, diff --git a/ocean/typing/__init__.py b/ocean/typing/__init__.py index 04fe37d..3c2e20b 100644 --- a/ocean/typing/__init__.py +++ b/ocean/typing/__init__.py @@ -5,9 +5,15 @@ import pandas as pd import xgboost as xgb from pydantic import Field -from sklearn.ensemble import IsolationForest, RandomForestClassifier, AdaBoostClassifier - -type BaseExplainableEnsemble = RandomForestClassifier | xgb.XGBClassifier | AdaBoostClassifier +from sklearn.ensemble import ( + AdaBoostClassifier, + IsolationForest, + RandomForestClassifier, +) + +type BaseExplainableEnsemble = ( + RandomForestClassifier | xgb.XGBClassifier | AdaBoostClassifier +) type ParsableEnsemble = BaseExplainableEnsemble | IsolationForest | xgb.Booster type Number = float From 2ec0aa9feff0a16ec2bca3d14e08e458d0adad0f Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Tue, 25 Nov 2025 12:15:23 -0500 Subject: [PATCH 07/18] fix: Refactor MaxSATSolver integration and enhance explanation handling in the environment --- ocean/maxsat/_base.py | 4 +-- ocean/maxsat/_env.py | 64 ++++++++++++++++++++++++++++++++++++ ocean/maxsat/_explainer.py | 4 +-- ocean/maxsat/_explanation.py | 29 +++++++++++----- ocean/maxsat/_solver.py | 33 ------------------- 5 files changed, 88 insertions(+), 46 deletions(-) create mode 100644 ocean/maxsat/_env.py delete mode 100644 ocean/maxsat/_solver.py diff --git a/ocean/maxsat/_base.py b/ocean/maxsat/_base.py index 7a07668..66cc9a7 100644 --- a/ocean/maxsat/_base.py +++ b/ocean/maxsat/_base.py @@ -21,13 +21,13 @@ def add_var(self, name: str) -> int: if name in self.vpool.obj2id: # var has been already created msg = f"Variable with name '{name}' already exists." 
raise ValueError(msg) - return self.vpool.id(f"{name}") + return self.vpool.id(f"{name}") # type: ignore[no-any-return] def get_var(self, name: str) -> int: if name not in self.vpool.obj2id: # var has not been created msg = f"Variable with name '{name}' does not exist." raise ValueError(msg) - return self.vpool.obj2id[name] + return self.vpool.obj2id[name] # type: ignore[no-any-return] def add_hard(self, lits: list[int]) -> None: """Add a hard clause (must be satisfied).""" diff --git a/ocean/maxsat/_env.py b/ocean/maxsat/_env.py new file mode 100644 index 0000000..5311c07 --- /dev/null +++ b/ocean/maxsat/_env.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +from pysat.examples.rc2 import RC2Stratified + +if TYPE_CHECKING: + from pysat.formula import WCNF + + +class Env: + _solver: MaxSATSolver + + def __init__(self) -> None: + self._solver = MaxSATSolver() + + @property + def solver(self) -> MaxSATSolver: + return self._solver + + @solver.setter + def solver(self, solver: MaxSATSolver) -> None: + self._solver = solver + + +class MaxSATSolver: + """Thin RC2 wrapper to keep a stable interface.""" + + _model: list[int] | None = None + + def __init__( + self, + solver_name: str = "cadical195", + TimeLimit: int = 60, + n_threads: int = 1, + ) -> None: + self.solver_name = solver_name + self.TimeLimit = TimeLimit + self.n_threads = n_threads + + def solve(self, w: WCNF) -> list[int]: + with RC2Stratified( + w, + solver=self.solver_name, + adapt=True, + blo="full", + exhaust=False, + minz=True, + ) as rc2: + model = cast("list[int] | None", rc2.compute()) + if model is None: + msg = "UNSAT: no counterfactual found." + raise RuntimeError(msg) + self._model = model + return model + + def model(self, v: int) -> float: + if self._model is None: + msg = "No model found, please run 'solve' first." + raise ValueError(msg) + return self._model[v] + + +ENV = Env() diff --git a/ocean/maxsat/_explainer.py b/ocean/maxsat/_explainer.py index fcec0b8..f2d16f3 100644 --- a/ocean/maxsat/_explainer.py +++ b/ocean/maxsat/_explainer.py @@ -10,8 +10,8 @@ NonNegativeInt, PositiveInt, ) +from ._env import ENV from ._model import Model -from ._solver import MaxSATSolver if TYPE_CHECKING: from ..abc import Mapper @@ -40,7 +40,7 @@ def __init__( model_type=model_type, ) self.build() - self.solver = MaxSATSolver + self.solver = ENV.solver def get_objective_value(self) -> float: raise NotImplementedError diff --git a/ocean/maxsat/_explanation.py b/ocean/maxsat/_explanation.py index 886aa05..a6fe402 100644 --- a/ocean/maxsat/_explanation.py +++ b/ocean/maxsat/_explanation.py @@ -1,15 +1,13 @@ from collections.abc import Mapping -from typing import TYPE_CHECKING import numpy as np +import pandas as pd from ..abc import Mapper from ..typing import Array1D, BaseExplanation, Key, Number +from ._env import ENV from ._variables import FeatureVar -if TYPE_CHECKING: - import pandas as pd - class Explanation(Mapper[FeatureVar], BaseExplanation): _epsilon: float = float(np.finfo(np.float32).eps) @@ -21,15 +19,28 @@ def vget(self, i: int) -> int: code = self.codes[i] return self[name].xget(code=code) if self[name].is_numeric: - j = int( - np.searchsorted(self[name].levels, self._x[i], side="left") # type: ignore[arg-type] + j: int = int( + np.searchsorted(self[name].levels, self._x[i], side="left") # pyright: ignore[reportUnknownArgumentType] ) return self[name].xget(mu=j) return self[name].xget() def to_series(self) -> "pd.Series[float]": - msg = "Not implemented." 
- raise NotImplementedError(msg) + values: list[float] = [ + ENV.solver.model(v) for v in map(self.vget, range(self.n_columns)) + ] + for f in range(self.n_columns): + name = self.names[f] + value = ENV.solver.model(self.vget(f)) + if self[name].is_continuous: + values[f] = self.format_continuous_value( + f, int(value), list(self[name].levels) + ) + elif self[name].is_discrete: + values[f] = self.format_discrete_value( + f, int(value), self[name].thresholds + ) + return pd.Series(values, index=self.columns) def to_numpy(self) -> Array1D: return ( @@ -50,7 +61,7 @@ def value(self) -> Mapping[Key, Key | Number]: msg = "Not implemented." raise NotImplementedError(msg) - def format_value( + def format_continuous_value( self, f: int, idx: int, diff --git a/ocean/maxsat/_solver.py b/ocean/maxsat/_solver.py deleted file mode 100644 index e86b7f4..0000000 --- a/ocean/maxsat/_solver.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import annotations - -from typing import cast - -from pysat.examples.rc2 import RC2 -from pysat.formula import WCNF - - -class MaxSATSolver: - """Thin RC2 wrapper to keep a stable interface.""" - - def __init__(self, solver_name: str = "glucose3") -> None: - self.solver_name = solver_name - - @staticmethod - def new_wcnf() -> WCNF: - return WCNF() - - @staticmethod - def add_hard(w: WCNF, clause: list[int]) -> None: - w.append(clause, weight=-1) - - @staticmethod - def add_soft(w: WCNF, clause: list[int], weight: int) -> None: - w.append(clause, weight=weight) - - def solve(self, w: WCNF) -> list[int]: - with RC2(w, solver=self.solver_name, adapt=True) as rc2: - model = cast("list[int] | None", rc2.compute()) - if model is None: - msg = "UNSAT: no counterfactual found." - raise RuntimeError(msg) - return model From 00f0f2727486c36f35c655b3030bc42e9759da63 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Thu, 27 Nov 2025 12:06:36 -0500 Subject: [PATCH 08/18] fix: Remove the GarbageManager class --- ocean/maxsat/_managers/_garbage.py | 16 ---------------- 1 file changed, 16 deletions(-) delete mode 100644 ocean/maxsat/_managers/_garbage.py diff --git a/ocean/maxsat/_managers/_garbage.py b/ocean/maxsat/_managers/_garbage.py deleted file mode 100644 index b8b21c0..0000000 --- a/ocean/maxsat/_managers/_garbage.py +++ /dev/null @@ -1,16 +0,0 @@ -class GarbageManager: - type GarbageObject = object - - # Garbage collector for the model. - # - Used to keep track of the variables and constraints created, - # and to remove them when the model is cleared. 
- _garbage: list[GarbageObject] - - def __init__(self) -> None: - self._garbage = [] - - def add_garbage(self, *args: GarbageObject) -> None: - self._garbage.extend(args) - - def remove_garbage(self) -> None: - raise NotImplementedError From 2ef684f64c8a2ad77c8f6e2aaf226dc192de4e6e Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Thu, 27 Nov 2025 12:06:56 -0500 Subject: [PATCH 09/18] feat: Enhance MaxSAT model and explainer with new methods and improved structure --- ocean/maxsat/__init__.py | 18 ++- ocean/maxsat/_base.py | 9 +- ocean/maxsat/_env.py | 30 +++- ocean/maxsat/_explainer.py | 50 ++++++- ocean/maxsat/_explanation.py | 58 ++++++-- ocean/maxsat/_model.py | 278 ++++++++++++++++++++++++++++++++++- 6 files changed, 414 insertions(+), 29 deletions(-) diff --git a/ocean/maxsat/__init__.py b/ocean/maxsat/__init__.py index 6915f01..c311836 100644 --- a/ocean/maxsat/__init__.py +++ b/ocean/maxsat/__init__.py @@ -1,3 +1,19 @@ +from ._base import BaseModel +from ._env import ENV from ._explainer import Explainer +from ._explanation import Explanation +from ._managers import FeatureManager, TreeManager +from ._model import Model +from ._variables import FeatureVar, TreeVar -__all__ = ["Explainer"] +__all__ = [ + "ENV", + "BaseModel", + "Explainer", + "Explanation", + "FeatureManager", + "FeatureVar", + "Model", + "TreeManager", + "TreeVar", +] diff --git a/ocean/maxsat/_base.py b/ocean/maxsat/_base.py index 66cc9a7..b49d71d 100644 --- a/ocean/maxsat/_base.py +++ b/ocean/maxsat/_base.py @@ -5,10 +5,11 @@ class BaseModel(ABC, WCNF): - vpool: IDPool = IDPool() + vpool: IDPool def __init__(self) -> None: WCNF.__init__(self) + self.vpool = IDPool() # Create new pool for each instance def __setattr__(self, name: str, value: Any) -> None: # noqa: ANN401 object.__setattr__(self, name, value) @@ -45,6 +46,12 @@ def add_exactly_one(self, lits: list[int]) -> None: for j in range(i + 1, len(lits)): self.add_hard([-lits[i], -lits[j]]) # at most one + def _clean_soft(self) -> None: + """Reset the model to only contain hard constraints.""" + self.soft.clear() + self.wght.clear() + self.topw = 1 + class Var(Protocol): _name: str diff --git a/ocean/maxsat/_env.py b/ocean/maxsat/_env.py index 5311c07..f3567ac 100644 --- a/ocean/maxsat/_env.py +++ b/ocean/maxsat/_env.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, cast -from pysat.examples.rc2 import RC2Stratified +from pysat.examples.rc2 import RC2 if TYPE_CHECKING: from pysat.formula import WCNF @@ -27,6 +27,7 @@ class MaxSATSolver: """Thin RC2 wrapper to keep a stable interface.""" _model: list[int] | None = None + _cost: float = float("inf") def __init__( self, @@ -39,11 +40,10 @@ def __init__( self.n_threads = n_threads def solve(self, w: WCNF) -> list[int]: - with RC2Stratified( + with RC2( w, solver=self.solver_name, adapt=True, - blo="full", exhaust=False, minz=True, ) as rc2: @@ -52,13 +52,35 @@ def solve(self, w: WCNF) -> list[int]: msg = "UNSAT: no counterfactual found." raise RuntimeError(msg) self._model = model + self._cost = rc2.cost return model def model(self, v: int) -> float: + """ + Return 1.0 if variable v is true in the model, 0.0 otherwise. + + Args: + v: Variable to check in the model. + + Returns: + 1.0 if variable v is true in the model, 0.0 otherwise. + + Raises: + ValueError: If no model has been found solve() must be called first. + + """ if self._model is None: msg = "No model found, please run 'solve' first." raise ValueError(msg) - return self._model[v] + # The model is a list of signed literals. 
+ # Variable v is true if v is in the model, false if -v is in the model. + if v in self._model: + return 1.0 + return 0.0 + + @property + def cost(self) -> float: + return self._cost ENV = Env() diff --git a/ocean/maxsat/_explainer.py b/ocean/maxsat/_explainer.py index f2d16f3..e395fec 100644 --- a/ocean/maxsat/_explainer.py +++ b/ocean/maxsat/_explainer.py @@ -1,7 +1,10 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING +from sklearn.ensemble import AdaBoostClassifier + from ..tree import parse_ensembles from ..typing import ( Array1D, @@ -20,6 +23,10 @@ class Explainer(Model, BaseExplainer): + """MaxSAT-based explainer for tree ensemble classifiers.""" + + Status: str = "UNKNOWN" + def __init__( self, ensemble: BaseExplainableEnsemble, @@ -31,6 +38,8 @@ def __init__( ) -> None: ensembles = (ensemble,) trees = parse_ensembles(*ensembles, mapper=mapper) + if isinstance(ensemble, AdaBoostClassifier): + weights = ensemble.estimator_weights_ Model.__init__( self, trees, @@ -43,10 +52,10 @@ def __init__( self.solver = ENV.solver def get_objective_value(self) -> float: - raise NotImplementedError + return self.solver.cost / self._obj_scale def get_solving_status(self) -> str: - raise NotImplementedError + return self.Status def get_anytime_solutions(self) -> list[dict[str, float]] | None: raise NotImplementedError @@ -57,10 +66,35 @@ def explain( *, y: NonNegativeInt, norm: PositiveInt, - return_callback: bool = False, - verbose: bool = False, - max_time: int = 60, - num_workers: int | None = None, - random_seed: int = 42, + verbose: bool = False, # noqa: ARG002 + max_time: int = 60, # noqa: ARG002 + random_seed: int = 42, # noqa: ARG002 ) -> Explanation | None: - raise NotImplementedError + # Add objective soft clauses + self.add_objective(x, norm=norm) + + # Add hard constraints for target class + self.set_majority_class(y=y) + + try: + # Solve the MaxSAT problem + self.solver.solve(self) + self.Status = "OPTIMAL" + except RuntimeError as e: + if "UNSAT" in str(e): + self.Status = "INFEASIBLE" + msg = "There are no feasible counterfactuals for this query." + msg += " If there should be one, please check the model " + msg += "constraints or report this issue to the developers." + warnings.warn(msg, category=UserWarning, stacklevel=2) + self.cleanup() + return None + raise + else: + # Store the query in the explanation + self.explanation.query = x + + # Clean up for next solve + self.cleanup() + + return self.explanation diff --git a/ocean/maxsat/_explanation.py b/ocean/maxsat/_explanation.py index a6fe402..519caf2 100644 --- a/ocean/maxsat/_explanation.py +++ b/ocean/maxsat/_explanation.py @@ -25,21 +25,61 @@ def vget(self, i: int) -> int: return self[name].xget(mu=j) return self[name].xget() + def _get_active_mu_index( + self, + name: Key, + for_discrete: bool = False, # noqa: FBT001, FBT002 + ) -> int: + """ + Find which mu variable is set to true for a numeric feature. + + Returns: + Index of the active mu variable, or 0 if none found. 
+ + """ + if for_discrete: + # For discrete: one mu per level + n_vars = len(self[name].levels) + else: + # For continuous: one mu per interval + n_vars = len(self[name].levels) - 1 + for mu_idx in range(n_vars): + var = self[name].xget(mu=mu_idx) + if ENV.solver.model(var) > 0: + return mu_idx + return 0 # Default to first if none found + def to_series(self) -> "pd.Series[float]": - values: list[float] = [ - ENV.solver.model(v) for v in map(self.vget, range(self.n_columns)) - ] + values: list[float] = [] for f in range(self.n_columns): name = self.names[f] - value = ENV.solver.model(self.vget(f)) - if self[name].is_continuous: - values[f] = self.format_continuous_value( - f, int(value), list(self[name].levels) + if self[name].is_one_hot_encoded: + code = self.codes[f] + var = self[name].xget(code=code) + values.append(ENV.solver.model(var)) + elif self[name].is_continuous: + mu_idx = self._get_active_mu_index(name, for_discrete=False) + values.append( + self.format_continuous_value( + f, mu_idx, list(self[name].levels) + ) ) elif self[name].is_discrete: - values[f] = self.format_discrete_value( - f, int(value), self[name].thresholds + # For discrete features, mu[i] means value == levels[i] + mu_idx = self._get_active_mu_index(name, for_discrete=True) + levels = list(self[name].levels) + discrete_val = int(levels[mu_idx]) + values.append( + self.format_discrete_value( + f, discrete_val, self[name].levels + ) ) + elif self[name].is_binary: + var = self[name].xget() + values.append(ENV.solver.model(var)) + else: + var = self[name].xget() + values.append(ENV.solver.model(var)) return pd.Series(values, index=self.columns) def to_numpy(self) -> Array1D: diff --git a/ocean/maxsat/_model.py b/ocean/maxsat/_model.py index 0aeab76..1d1fb6d 100644 --- a/ocean/maxsat/_model.py +++ b/ocean/maxsat/_model.py @@ -1,12 +1,16 @@ from __future__ import annotations -from dataclasses import dataclass +import itertools from enum import Enum from typing import TYPE_CHECKING +import numpy as np +from pydantic import validate_call + +from ..typing import NonNegativeInt from ._base import BaseModel from ._builder.model import ModelBuilder, ModelBuilderFactory -from ._managers import FeatureManager, GarbageManager, TreeManager +from ._managers import FeatureManager, TreeManager if TYPE_CHECKING: from collections.abc import Iterable @@ -14,14 +18,15 @@ from ..abc import Mapper from ..feature import Feature from ..tree import Tree - from ..typing import NonNegativeArray1D, NonNegativeInt + from ..typing import Array1D, Key, NonNegativeArray1D, NonNegativeInt + from ._variables import FeatureVar -@dataclass -class Model(BaseModel, FeatureManager, TreeManager, GarbageManager): +class Model(BaseModel, FeatureManager, TreeManager): # Model builder for the ensemble. 
_builder: ModelBuilder DEFAULT_EPSILON: int = 1 + _obj_scale: int = int(1e8) class Type(Enum): MAXSAT = "MAXSAT" @@ -43,7 +48,6 @@ def __init__( weights=weights, ) FeatureManager.__init__(self, mapper=mapper) - GarbageManager.__init__(self) self._set_weights(weights=weights) self._max_samples = max_samples @@ -55,6 +59,268 @@ def build(self) -> None: self.build_trees(self) self._builder.build(self, trees=self.trees, mapper=self.mapper) + def add_objective( + self, + x: Array1D, + *, + norm: int = 1, + ) -> None: + if x.size != self.mapper.n_columns: + msg = f"Expected {self.mapper.n_columns} values, got {x.size}" + raise ValueError(msg) + if norm != 1: + msg = f"Unsupported norm: {norm}" + raise ValueError(msg) + + x_arr = np.asarray(x, dtype=float).ravel() + variables = self.mapper.values() + names = [n for n, _ in self.mapper.items()] + k = 0 + indexer = self.mapper.idx + + for v, name in zip(variables, names, strict=True): + if v.is_one_hot_encoded: + for code in v.codes: + idx = indexer.get(name, code) + self._add_soft_l1_ohe(x_arr[idx], v, code=code) + k += 1 + elif v.is_continuous: + self._add_soft_l1_continuous(x_arr[k], v) + k += 1 + elif v.is_discrete: + self._add_soft_l1_discrete(x_arr[k], v) + k += 1 + elif v.is_binary: + self._add_soft_l1_binary(x_arr[k], v) + k += 1 + else: + k += 1 + + def _add_soft_l1_binary(self, x_val: float, v: FeatureVar) -> None: + """Add soft clause for binary feature.""" + weight = int(self._obj_scale) + x_var = v.xget() + binary_threshold = 0.5 + if x_val > binary_threshold: + # If x=1, penalize flipping to 0 + self.add_soft([x_var], weight=weight) + else: + # If x=0, penalize flipping to 1 + self.add_soft([-x_var], weight=weight) + + def _add_soft_l1_ohe( + self, + x_val: float, + v: FeatureVar, + code: Key, + ) -> None: + """Add soft clause for one-hot encoded feature.""" + weight = int(self._obj_scale / 2) # OHE uses half weight + x_var = v.xget(code=code) + binary_threshold = 0.5 + if x_val > binary_threshold: + self.add_soft([x_var], weight=weight) + else: + self.add_soft([-x_var], weight=weight) + + def _add_soft_l1_continuous(self, x_val: float, v: FeatureVar) -> None: + """Add soft clauses for continuous feature intervals.""" + levels = v.levels + intervals_cost = self._get_intervals_cost(levels, x_val) + + for i in range(len(levels) - 1): + cost = intervals_cost[i] + if cost > 0: + mu_var = v.xget(mu=i) + self.add_soft([-mu_var], weight=cost) + + def _add_soft_l1_discrete(self, x_val: float, v: FeatureVar) -> None: + """ + Add soft clauses for discrete feature. + + For discrete features, mu[i] means value == levels[i]. + Penalize each level based on distance from x_val. + """ + levels = v.levels + + for i in range(len(levels)): + level_val = levels[i] + if level_val == x_val: + # No cost if this is the same value + continue + cost = int(abs(x_val - level_val) * self._obj_scale) + if cost > 0: + mu_var = v.xget(mu=i) + self.add_soft([-mu_var], weight=cost) + + def _get_intervals_cost(self, levels: Array1D, x: float) -> list[int]: + """ + Compute cost for each interval based on distance from x. + + Returns: + List of integer costs for each interval based on distance from x. 
+ + """ + intervals_cost = np.zeros(len(levels) - 1, dtype=int) + for i in range(len(intervals_cost)): + if levels[i] < x <= levels[i + 1]: + continue + if levels[i] > x: + intervals_cost[i] = int(abs(x - levels[i]) * self._obj_scale) + elif levels[i + 1] < x: + intervals_cost[i] = int( + abs(x - levels[i + 1]) * self._obj_scale + ) + return intervals_cost.tolist() + + @validate_call + def set_majority_class( + self, + y: NonNegativeInt, + *, + op: NonNegativeInt = 0, + ) -> None: + """ + Set hard constraints to enforce majority vote for class y. + + Raises: + ValueError: If y is greater than or equal to the number of classes. + + """ + if y >= self.n_classes: + msg = f"Expected class < {self.n_classes}, got {y}" + raise ValueError(msg) + + self._set_majority_class(y, op=op) + + def _set_majority_class( + self, + y: NonNegativeInt, + *, + op: NonNegativeInt = 0, + ) -> None: + """ + Add hard constraints to enforce class y gets majority vote. + + For sklearn's RandomForestClassifier, the predicted class is the one + with the highest mean probability across all trees (soft voting). + + We encode this as: for each class c != y, + sum(prob_y - prob_c) >= epsilon + where epsilon > 0 if c < y (for tie-breaking), else epsilon >= 0. + + Since MaxSAT doesn't directly support weighted sums, we use a + discretized approach with auxiliary variables. + """ + scale = 10000 # Scale factor for probabilities + + for class_ in range(self.n_classes): + if class_ == y: + continue + + # Compute the score difference for each leaf in each tree + # We need: sum over trees of (prob_y - prob_c) >= epsilon + + # For each tree, compute min and max possible contributions + tree_contributions: list[list[tuple[int, int]]] = [] + for tree in self.trees: + contribs: list[tuple[int, int]] = [] + for leaf in tree.leaves: + prob_y = leaf.value[op, y] + prob_c = leaf.value[op, class_] + diff = int((prob_y - prob_c) * scale) + leaf_var = tree[leaf.node_id] + contribs.append((leaf_var, diff)) + tree_contributions.append(contribs) + + # Threshold for comparison + epsilon = self._epsilon if class_ < y else 0 + + # Use iterative bounds propagation to encode the constraint + # For each tree, we track the range of possible partial sums + self._encode_weighted_sum_constraint(tree_contributions, epsilon) + + def _encode_weighted_sum_constraint( + self, + tree_contributions: list[list[tuple[int, int]]], + threshold: int, + ) -> None: + """ + Encode constraint: sum of selected contributions >= threshold. + + Uses dynamic programming to compute reachable sums and encodes + constraints to ensure the sum meets the threshold. 
+ """ + n_trees = len(tree_contributions) + + # For each tree, compute min and max contribution + tree_bounds: list[tuple[int, int]] = [] + for contribs in tree_contributions: + min_c = min(c[1] for c in contribs) + max_c = max(c[1] for c in contribs) + tree_bounds.append((min_c, max_c)) + + # Compute global bounds + global_min = sum(b[0] for b in tree_bounds) + global_max = sum(b[1] for b in tree_bounds) + + # If even the maximum sum is below threshold, UNSAT + if global_max < threshold: + self.add_hard([]) + return + + # If the minimum sum meets threshold, constraint is always satisfied + if global_min >= threshold: + return + + # Compute prefix max and suffix max for bounds checking + prefix_max = [0] * (n_trees + 1) + for i in range(n_trees): + prefix_max[i + 1] = prefix_max[i] + tree_bounds[i][1] + + suffix_max = [0] * (n_trees + 1) + for i in range(n_trees - 1, -1, -1): + suffix_max[i] = suffix_max[i + 1] + tree_bounds[i][1] + + # For each leaf, check if it can possibly be part of a valid solution + # A leaf with contribution c at tree t is forbidden if: + # prefix_max[t] + c + suffix_max[t+1] < threshold (i.e., even with + # best choices for all other trees, can't reach threshold) + for t_idx, contribs in enumerate(tree_contributions): + for leaf_var, contrib in contribs: + best_case = prefix_max[t_idx] + contrib + suffix_max[t_idx + 1] + if best_case < threshold: + self.add_hard([-leaf_var]) + + # Enumerate and forbid all bad combinations + # A combination is bad if sum(contributions) < threshold + max_trees = 10 + if n_trees <= max_trees: # Only for small forests + self._enumerate_bad_combinations(tree_contributions, threshold) + + def _enumerate_bad_combinations( + self, + tree_contributions: list[list[tuple[int, int]]], + threshold: int, + ) -> None: + """Enumerate and forbid all combinations with sum < threshold.""" + # Get leaf indices for each tree + tree_leaves = [ + [(lv, c) for lv, c in contribs] for contribs in tree_contributions + ] + + # Enumerate all combinations + for combo in itertools.product(*tree_leaves): + total = sum(c for _, c in combo) + if total < threshold: + # This combination is bad, TODO find efficient way + # At least one of the leaves must NOT be selected + clause = [-lv for lv, _ in combo] + self.add_hard(clause) + + def cleanup(self) -> None: + self._clean_soft() + def _set_builder(self, model_type: Type) -> None: match model_type: case Model.Type.MAXSAT: From ae1d3ecee380b6745f5a5007dee78d1b420e987b Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Thu, 27 Nov 2025 12:07:17 -0500 Subject: [PATCH 10/18] fix: Improve handling of sigma conditions in MaxSATBuilder methods --- ocean/maxsat/_builder/model.py | 54 +++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/ocean/maxsat/_builder/model.py b/ocean/maxsat/_builder/model.py index c716e44..12a66ff 100644 --- a/ocean/maxsat/_builder/model.py +++ b/ocean/maxsat/_builder/model.py @@ -106,10 +106,12 @@ def _bset( v: FeatureVar, sigma: bool, ) -> None: + # sigma=True => left child (x <= 0.5, i.e., x=0) + # sigma=False => right child (x > 0.5, i.e., x=1) if sigma: - model.add_hard([-y, v.xget()]) - else: model.add_hard([-y, -v.xget()]) + else: + model.add_hard([-y, v.xget()]) @staticmethod def _cset( @@ -120,13 +122,26 @@ def _cset( v: FeatureVar, sigma: bool, ) -> None: + # For continuous features: + # j = searchsorted(levels, threshold) gives index where threshold fits + # sigma=True => left child (x <= threshold) + # sigma=False => right child (x > threshold) 
threshold = node.threshold j = int(np.searchsorted(v.levels, threshold, side="left")) - mu = v.xget(mu=j - 1) + n_intervals = len(v.levels) - 1 + if sigma: - model.add_hard([-y, mu]) + # Left branch: x <= threshold, so x is in interval 0, 1, ..., j-1 + # Forbid intervals j, j+1, ..., n-2 + for i in range(j, n_intervals): + mu = v.xget(mu=i) + model.add_hard([-y, -mu]) else: - model.add_hard([-y, -mu]) + # Right branch: x > threshold, so x is in interval j, j+1, ..., n-2 + # Forbid intervals 0, 1, ..., j-1 + for i in range(j): + mu = v.xget(mu=i) + model.add_hard([-y, -mu]) @staticmethod def _dset( @@ -137,13 +152,28 @@ def _dset( v: FeatureVar, sigma: bool, ) -> None: + # For discrete features: + # sigma=True => left child (x <= threshold) + # sigma=False => right child (x > threshold) + # + # mu[i] => value == levels[i] threshold = node.threshold - j = int(np.searchsorted(v.levels, threshold, side="left")) - x = v.xget(mu=j) + n_values = len(v.levels) + if sigma: - model.add_hard([-y, x]) + # Left branch: x <= threshold + # Forbid values where levels[i] > threshold + for i in range(n_values): + if v.levels[i] > threshold: + mu = v.xget(mu=i) + model.add_hard([-y, -mu]) else: - model.add_hard([-y, x]) + # Right branch: x > threshold + # Forbid values where levels[i] <= threshold + for i in range(n_values): + if v.levels[i] <= threshold: + mu = v.xget(mu=i) + model.add_hard([-y, -mu]) @staticmethod def _eset( @@ -154,11 +184,13 @@ def _eset( v: FeatureVar, sigma: bool, ) -> None: + # sigma=True (left child): category != code, so u[code] = False + # sigma=False (right child): category == code, so u[code] = True x = v.xget(code=node.code) if sigma: - model.add_hard([-y, x]) - else: model.add_hard([-y, -x]) + else: + model.add_hard([-y, x]) class ModelBuilderFactory: From 98fc563485cc655cb1132b24c2f0954fdd1b4c36 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Thu, 27 Nov 2025 12:07:44 -0500 Subject: [PATCH 11/18] fix: Enhance FeatureVar with exact constraints and improved mu variable handling --- ocean/maxsat/_managers/__init__.py | 2 -- ocean/maxsat/_managers/_tree.py | 27 ++++++++++++++++++----- ocean/maxsat/_variables/_feature.py | 33 +++++++++++++++++++++++------ 3 files changed, 49 insertions(+), 13 deletions(-) diff --git a/ocean/maxsat/_managers/__init__.py b/ocean/maxsat/_managers/__init__.py index 450d21d..056784d 100644 --- a/ocean/maxsat/_managers/__init__.py +++ b/ocean/maxsat/_managers/__init__.py @@ -1,9 +1,7 @@ from ._feature import FeatureManager -from ._garbage import GarbageManager from ._tree import TreeManager __all__ = [ "FeatureManager", - "GarbageManager", "TreeManager", ] diff --git a/ocean/maxsat/_managers/_tree.py b/ocean/maxsat/_managers/_tree.py index e3c1d74..7bd4940 100644 --- a/ocean/maxsat/_managers/_tree.py +++ b/ocean/maxsat/_managers/_tree.py @@ -90,11 +90,28 @@ def _set_weights(self, weights: NonNegativeArray1D | None = None) -> None: def weighted_function( self, - weights: NonNegativeArray1D, - ) -> dict[tuple[NonNegativeInt, NonNegativeInt], object]: - raise NotImplementedError + ) -> dict[tuple[NonNegativeInt, NonNegativeInt], list[int]]: + func: dict[tuple[NonNegativeInt, NonNegativeInt], list[int]] = {} + n_classes = self.n_classes + n_outputs = self.shape[-2] + for op in range(n_outputs): + for c in range(n_classes): + leaf_vars: list[int] = [] + for tree in self.estimators: + for leaf in tree.leaves: + leaf_class = int(np.argmax(leaf.value[op, :])) + if leaf_class == c: + leaf_vars.append(tree[leaf.node_id]) + func[op, c] = leaf_vars + return func def 
_get_function( self, - ) -> dict[tuple[NonNegativeInt, NonNegativeInt], object]: - return self.weighted_function(weights=self.weights) + ) -> dict[tuple[NonNegativeInt, NonNegativeInt], list[int]]: + return self.weighted_function() + + @property + def function( + self, + ) -> dict[tuple[NonNegativeInt, NonNegativeInt], list[int]]: + return self._function diff --git a/ocean/maxsat/_variables/_feature.py b/ocean/maxsat/_variables/_feature.py index 329fd25..0986a15 100644 --- a/ocean/maxsat/_variables/_feature.py +++ b/ocean/maxsat/_variables/_feature.py @@ -22,6 +22,8 @@ def build(self, model: BaseModel) -> None: self._x = self._add_x(model) if self.is_numeric: self._mu = self._add_mu(model) + # Add exactly-one constraint for mu variables (exactly one interval) + model.add_exactly_one(list(self._mu.values())) if self.is_one_hot_encoded: self._u = self._add_u(model) @@ -65,9 +67,19 @@ def _add_one_hot_encoded( def _add_mu(self, model: BaseModel) -> Mapping[Key, int]: name = self._name.format(name=self._name) + if self.is_discrete: + # For discrete features: one mu variable per level (value) + # mu[i] means value == levels[i] + n_values = len(self.levels) + return { + lv: model.add_var(name=f"{name}[{lv}]") + for lv in range(n_values) + } + # For continuous features: n-1 mu variables for n levels (intervals) + # mu[i] means value in interval (levels[i], levels[i+1]] + n_intervals = len(self.levels) - 1 return { - lv: model.add_var(name=f"{name}[{lv}]") - for lv in range(len(self.levels)) + lv: model.add_var(name=f"{name}[{lv}]") for lv in range(n_intervals) } @staticmethod @@ -85,9 +97,18 @@ def _xget_one_hot_encoded(self, code: Key | None) -> int: def _xget_numeric(self, mu: Key | None) -> int: if mu is None: - msg = "Mu is required for numeric features get" - raise ValueError(msg) - if mu not in range(len(self.levels)): - msg = f"Mu '{mu}' not found in the feature levels" + msg = "mu is required to get numeric features" raise ValueError(msg) + if self.is_discrete: + # For discrete: mu[i] represents value levels[i] + n_values = len(self.levels) + if mu not in range(n_values): + msg = f"mu '{mu}' not in values (0 to {n_values - 1})" + raise ValueError(msg) + else: + # For continuous: mu[i] represents interval (levels[i], levels[i+1]] + n_intervals = len(self.levels) - 1 + if mu not in range(n_intervals): + msg = f"mu '{mu}' not in intervals (0 to {n_intervals - 1})" + raise ValueError(msg) return self._mu[mu] From 1216a3de0834b61d327a851d2102dda5cf244d58 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Thu, 27 Nov 2025 12:08:02 -0500 Subject: [PATCH 12/18] feat: Add MaxSAT explainer to query example and update explainer options --- examples/maxsat_example.py | 78 ++++++++++++++++++++++++++++++++++++++ examples/query.py | 14 +++++-- 2 files changed, 88 insertions(+), 4 deletions(-) create mode 100644 examples/maxsat_example.py diff --git a/examples/maxsat_example.py b/examples/maxsat_example.py new file mode 100644 index 0000000..4078275 --- /dev/null +++ b/examples/maxsat_example.py @@ -0,0 +1,78 @@ +from sklearn.ensemble import RandomForestClassifier + +from ocean import ( + ConstraintProgrammingExplainer, + MaxSATExplainer, + MixedIntegerProgramExplainer, +) +from ocean.datasets import load_adult + +# Load the adult dataset +(data, target), mapper = load_adult(scale=True) + +# Train a random forest classifier +rf = RandomForestClassifier(n_estimators=5, max_depth=3, random_state=42) +rf.fit(data, target) + +# Select an instance to explain from the dataset +x = data.iloc[19].to_frame().T +x_np 
= x.to_numpy().flatten() + +# Predict the class of the instance +y_pred = int(rf.predict(x).item()) +target_class = 1 - y_pred # Binary classification - choose opposite class + +print(f"Instance shape: {x_np.shape}") +print(f"Original prediction: {y_pred}") +print(f"Target counterfactual class: {target_class}") + +# Explain the prediction using MaxSATExplainer +print("\n--- MaxSAT Explainer ---") +try: + maxsat_model = MaxSATExplainer(rf, mapper=mapper) + maxsat_explanation = maxsat_model.explain(x_np, y=target_class, norm=1) + if maxsat_explanation is not None: + cf_np = maxsat_explanation.to_numpy() + print("MaxSAT CF:", cf_np) + print("MaxSAT CF prediction:", rf.predict([cf_np])[0]) + print("Objective value:", maxsat_model.get_objective_value()) + print("Status:", maxsat_model.get_solving_status()) + else: + print("MaxSAT: No counterfactual found.") +except (ValueError, RuntimeError, ImportError) as e: + import traceback + + print(f"MaxSAT Error: {e}") + traceback.print_exc() + +# Explain the prediction using MIPExplainer +print("\n--- MIP Explainer ---") +try: + mip_model = MixedIntegerProgramExplainer(rf, mapper=mapper) + mip_explanation = mip_model.explain(x_np, y=target_class, norm=1) + if mip_explanation is not None: + cf_np = mip_explanation.to_numpy() + print("MIP CF:", cf_np) + print("MIP CF prediction:", rf.predict([cf_np])[0]) + print("Objective value:", mip_model.get_objective_value()) + print("Status:", mip_model.get_solving_status()) + else: + print("MIP: No counterfactual found.") +except (ValueError, RuntimeError, ImportError) as e: + print(f"MIP Error: {e}") + +# Explain the prediction using CPExplainer +print("\n--- CP Explainer ---") +try: + cp_model = ConstraintProgrammingExplainer(rf, mapper=mapper) + cp_explanation = cp_model.explain(x_np, y=target_class, norm=1) + if cp_explanation is not None: + cf_np = cp_explanation.to_numpy() + print("CP CF:", cf_np) + print("CP CF prediction:", rf.predict([cf_np])[0]) + print("Objective value:", cp_model.get_objective_value()) + print("Status:", cp_model.get_solving_status()) + else: + print("CP: No counterfactual found.") +except (ValueError, RuntimeError, ImportError) as e: + print(f"CP Error: {e}") diff --git a/examples/query.py b/examples/query.py index 53284b5..5cb6531 100644 --- a/examples/query.py +++ b/examples/query.py @@ -13,6 +13,7 @@ from ocean import ( ConstraintProgrammingExplainer, + MaxSATExplainer, MixedIntegerProgramExplainer, ) from ocean.abc import Mapper @@ -28,6 +29,7 @@ EXPLAINERS = { "mip": MixedIntegerProgramExplainer, "cp": ConstraintProgrammingExplainer, + "maxsat": MaxSATExplainer, } MODELS = { "rf": RandomForestClassifier, @@ -75,8 +77,8 @@ def create_argument_parser() -> ArgumentParser: help="List of explainers to use", type=str, nargs="+", - choices=["mip", "cp"], - default=["mip", "cp"], + choices=["mip", "cp", "maxsat"], + default=["mip", "cp", "maxsat"], ) parser.add_argument( "-m", @@ -153,7 +155,8 @@ def fit_model_with_console( def build_explainer( explainer_name: str, explainer_class: type[MixedIntegerProgramExplainer] - | type[ConstraintProgrammingExplainer], + | type[ConstraintProgrammingExplainer] + | type[MaxSATExplainer], args: Args, model: RandomForestClassifier | XGBClassifier, mapper: Mapper[Feature], @@ -163,7 +166,10 @@ def build_explainer( if explainer_class is MixedIntegerProgramExplainer: ENV.setParam("Seed", args.seed) exp = explainer_class(model, mapper=mapper, env=ENV) - elif explainer_class is ConstraintProgrammingExplainer: + elif ( + explainer_class is 
ConstraintProgrammingExplainer
+        or explainer_class is MaxSATExplainer
+    ):
         exp = explainer_class(model, mapper=mapper)
     else:
         msg = f"Unknown explainer type: {explainer_class}"

From 6351dda1ee461178a38b2078f0a544b10e6dc536 Mon Sep 17 00:00:00 2001
From: Awa Khouna
Date: Thu, 27 Nov 2025 12:08:16 -0500
Subject: [PATCH 13/18] feat: Add tests for the MaxSAT model and the MaxSAT explainer
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/maxsat/__init__.py             |   1 +
 tests/maxsat/model/__init__.py       |   1 +
 tests/maxsat/model/test_feasible.py  |  99 ++++++++++++++++
 tests/maxsat/model/test_init.py      | 111 ++++++++++++++++++
 tests/maxsat/model/test_objective.py | 118 +++++++++++++++++++
 tests/maxsat/utils.py                | 169 +++++++++++++++++++++++++++
 tests/test_explainer.py              |  44 ++++++-
 7 files changed, 542 insertions(+), 1 deletion(-)
 create mode 100644 tests/maxsat/__init__.py
 create mode 100644 tests/maxsat/model/__init__.py
 create mode 100644 tests/maxsat/model/test_feasible.py
 create mode 100644 tests/maxsat/model/test_init.py
 create mode 100644 tests/maxsat/model/test_objective.py
 create mode 100644 tests/maxsat/utils.py

diff --git a/tests/maxsat/__init__.py b/tests/maxsat/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/tests/maxsat/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/maxsat/model/__init__.py b/tests/maxsat/model/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/tests/maxsat/model/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/maxsat/model/test_feasible.py b/tests/maxsat/model/test_feasible.py
new file mode 100644
index 0000000..489c88a
--- /dev/null
+++ b/tests/maxsat/model/test_feasible.py
@@ -0,0 +1,99 @@
+import numpy as np
+import pytest
+
+from ocean.maxsat import ENV, Model
+from ocean.tree import parse_trees
+
+from ..utils import (
+    MAX_DEPTH,
+    N_CLASSES,
+    N_ESTIMATORS,
+    N_SAMPLES,
+    SEEDS,
+    train_rf,
+    validate_paths,
+    validate_sklearn_pred,
+    validate_solution,
+)
+
+
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("n_estimators", N_ESTIMATORS)
+@pytest.mark.parametrize("max_depth", MAX_DEPTH)
+@pytest.mark.parametrize("n_samples", N_SAMPLES)
+@pytest.mark.parametrize("n_classes", N_CLASSES)
+class TestNoIsolation:
+    @staticmethod
+    def test_build(
+        seed: int,
+        n_estimators: int,
+        max_depth: int,
+        n_samples: int,
+        n_classes: int,
+    ) -> None:
+        clf, mapper, data = train_rf(
+            seed,
+            n_estimators,
+            max_depth,
+            n_samples,
+            n_classes,
+            return_data=True,
+        )
+        trees = tuple(parse_trees(clf, mapper=mapper))
+
+        # Get predictions and pick a class to target
+        predictions = np.array(clf.predict(data), dtype=np.int64)
+        target_class = int(predictions[0])
+
+        model = Model(trees=trees, mapper=mapper)
+        model.build()
+        model.set_majority_class(y=target_class)
+
+        solver = ENV.solver
+        solver_model = solver.solve(model)
+
+        explanation = model.explanation
+
+        validate_solution(explanation)
+        validate_paths(
+            *model.trees, explanation=explanation, solver_model=solver_model
+        )
+        validate_sklearn_pred(clf, explanation, m_class=target_class)
+
+    @staticmethod
+    def test_set_majority_class(
+        seed: int,
+        n_estimators: int,
+        max_depth: int,
+        n_samples: int,
+        n_classes: int,
+    ) -> None:
+        clf, mapper, data = train_rf(
+            seed,
+            n_estimators,
+            max_depth,
+            n_samples,
+            n_classes,
+            return_data=True,
+        )
+        trees = tuple(parse_trees(clf, mapper=mapper))
+
+        predictions = np.array(clf.predict(data), 
dtype=np.int64) + classes = set(map(int, predictions.flatten())) + + for class_ in classes: + model = Model(trees=trees, mapper=mapper) + model.build() + model.set_majority_class(y=class_) + + solver = ENV.solver + solver_model = solver.solve(model) + + explanation = model.explanation + + validate_solution(explanation) + validate_paths( + *model.trees, explanation=explanation, solver_model=solver_model + ) + validate_sklearn_pred(clf, explanation, m_class=class_) + model.cleanup() diff --git a/tests/maxsat/model/test_init.py b/tests/maxsat/model/test_init.py new file mode 100644 index 0000000..9cd5692 --- /dev/null +++ b/tests/maxsat/model/test_init.py @@ -0,0 +1,111 @@ +import numpy as np +import pytest + +from ocean.abc import Mapper +from ocean.maxsat import Model +from ocean.tree import parse_trees + +from ..utils import ( + MAX_DEPTH, + N_CLASSES, + N_ESTIMATORS, + N_SAMPLES, + SEEDS, + train_rf, +) + + +def test_no_trees() -> None: + msg = r"At least one tree is required." + with pytest.raises(ValueError, match=msg): + Model(trees=[], mapper=Mapper()) + + +def test_no_features() -> None: + msg = r"At least one feature is required." + rf, mapper = train_rf(42, 2, 2, 100, 2) + trees = tuple(parse_trees(rf, mapper=mapper)) + with pytest.raises(ValueError, match=msg): + Model(trees=trees, mapper=Mapper()) + + +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("n_estimators", N_ESTIMATORS) +@pytest.mark.parametrize("max_depth", MAX_DEPTH) +@pytest.mark.parametrize("n_samples", N_SAMPLES) +@pytest.mark.parametrize("n_classes", N_CLASSES) +class TestNoIsolation: + @staticmethod + def test_no_weights( + seed: int, + n_estimators: int, + max_depth: int, + n_samples: int, + n_classes: int, + ) -> None: + clf, mapper = train_rf( + seed, + n_estimators, + max_depth, + n_samples, + n_classes, + ) + trees = parse_trees(clf, mapper=mapper) + model = Model(trees=trees, mapper=mapper) + expected_weights = np.ones(n_estimators, dtype=float) + assert model is not None + assert model.n_estimators == n_estimators + assert model.n_classes == n_classes + assert model.weights.shape == expected_weights.shape + assert np.isclose(model.weights, expected_weights).all() + + @staticmethod + def test_weights( + seed: int, + n_estimators: int, + max_depth: int, + n_samples: int, + n_classes: int, + ) -> None: + clf, mapper = train_rf( + seed, + n_estimators, + max_depth, + n_samples, + n_classes, + ) + trees = parse_trees(clf, mapper=mapper) + generator = np.random.default_rng(seed) + weights = generator.random(n_estimators).flatten() + model = Model(trees=trees, mapper=mapper, weights=weights) + assert model is not None + assert model.n_estimators == n_estimators + assert model.n_classes == n_classes + assert model.weights.shape == weights.shape + assert np.isclose(model.weights, weights).all() + + @staticmethod + def test_invalid_weights( + seed: int, + n_estimators: int, + max_depth: int, + n_samples: int, + n_classes: int, + ) -> None: + clf, mapper = train_rf( + seed, + n_estimators, + max_depth, + n_samples, + n_classes, + ) + trees = tuple(parse_trees(clf, mapper=mapper)) + generator = np.random.default_rng(seed) + shapes = [generator.integers(n_estimators + 1, 2 * n_estimators + 1)] + if n_estimators > 2: + shapes += [generator.integers(1, n_estimators - 1)] + for shape in shapes: + weights = generator.random(shape).flatten() + msg = r"The number of weights must match the number of trees." 
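+            # Construction must fail: the length of the weight vector does
+            # not match the number of trees in the ensemble.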
+ with pytest.raises(ValueError, match=msg): + Model(trees=trees, mapper=mapper, weights=weights) diff --git a/tests/maxsat/model/test_objective.py b/tests/maxsat/model/test_objective.py new file mode 100644 index 0000000..e1daeb6 --- /dev/null +++ b/tests/maxsat/model/test_objective.py @@ -0,0 +1,118 @@ +import numpy as np +import pytest + +from ocean.maxsat import ENV, Model +from ocean.tree import parse_trees + +from ..utils import ( + MAX_DEPTH, + N_CLASSES, + N_ESTIMATORS, + N_SAMPLES, + SEEDS, + check_solution, + train_rf, + validate_paths, + validate_sklearn_pred, + validate_solution, +) + +P_QUERIES = 0.2 + + +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("n_estimators", N_ESTIMATORS) +@pytest.mark.parametrize("max_depth", MAX_DEPTH) +@pytest.mark.parametrize("n_samples", N_SAMPLES) +@pytest.mark.parametrize("n_classes", N_CLASSES) +class TestNoIsolation: + @staticmethod + def test_build( + seed: int, + n_estimators: int, + max_depth: int, + n_samples: int, + n_classes: int, + ) -> None: + clf, mapper, data = train_rf( + seed, + n_estimators, + max_depth, + n_samples, + n_classes, + return_data=True, + ) + trees = tuple(parse_trees(clf, mapper=mapper)) + + n_queries = int(data.shape[0] * P_QUERIES) + generator = np.random.default_rng(seed) + queries = generator.choice( + range(len(data)), size=n_queries, replace=False + ) + + for query in queries: + x = np.array(data.to_numpy()[query], dtype=np.float64).flatten() + + model = Model(trees=trees, mapper=mapper) + model.build() + model.add_objective(x=x) + + solver = ENV.solver + solver_model = solver.solve(model) + + explanation = model.explanation + + validate_solution(explanation) + validate_paths( + *model.trees, explanation=explanation, solver_model=solver_model + ) + check_solution(x, explanation) + model.cleanup() + + @staticmethod + def test_set_majority_class( + seed: int, + n_estimators: int, + max_depth: int, + n_samples: int, + n_classes: int, + ) -> None: + clf, mapper, data = train_rf( + seed, + n_estimators, + max_depth, + n_samples, + n_classes, + return_data=True, + ) + trees = tuple(parse_trees(clf, mapper=mapper)) + + predictions = np.array(clf.predict(data.to_numpy()), dtype=np.int64) + classes = set(map(int, predictions.flatten())) + + generator = np.random.default_rng(seed) + query = generator.integers(len(data)) + + x = np.array(data.to_numpy()[query], dtype=np.float64).flatten() + y = int(predictions[query]) + + for class_ in classes: + model = Model(trees=trees, mapper=mapper) + model.build() + model.set_majority_class(y=class_) + model.add_objective(x=x) + + solver = ENV.solver + solver_model = solver.solve(model) + + explanation = model.explanation + + validate_solution(explanation) + validate_paths( + *model.trees, explanation=explanation, solver_model=solver_model + ) + validate_sklearn_pred(clf, explanation, m_class=class_) + if class_ == y: + check_solution(x, explanation) + + model.cleanup() diff --git a/tests/maxsat/utils.py b/tests/maxsat/utils.py new file mode 100644 index 0000000..0b8463d --- /dev/null +++ b/tests/maxsat/utils.py @@ -0,0 +1,169 @@ +from collections import defaultdict +from typing import TYPE_CHECKING, Literal, overload + +import numpy as np +import pandas as pd +from sklearn.ensemble import RandomForestClassifier + +from ocean.abc import Mapper +from ocean.feature import Feature +from ocean.maxsat import Explanation, TreeVar +from ocean.typing import Array1D, NonNegativeInt + +from ..utils import generate_data + +if TYPE_CHECKING: + from ocean.typing import Key + 
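+# Shared validation helpers for the MaxSAT model tests:
+# - `check_solution` compares the non-continuous features of an explanation
+#   to the query point,
+# - `validate_solution` checks feature domains (each one-hot group sums to
+#   one, binary values lie in {0, 1}, numeric values stay within levels),
+# - `find_leaf`/`validate_path` replay the explanation through each tree and
+#   check that it reaches the unique leaf activated by the solver.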
+
+def check_solution(x: Array1D, explanation: Explanation) -> None:
+    n = explanation.n_columns
+    x_sol = explanation.x
+    for i in range(n):
+        name = explanation.names[i]
+        # For now we only check the non-continuous features,
+        # as the continuous features may be epsilon away from
+        # the explanation.
+        if not explanation[name].is_continuous:
+            assert np.isclose(x[i], x_sol[i])
+
+
+def validate_solution(explanation: Explanation) -> None:
+    x = explanation.x
+    n = explanation.n_columns
+    codes: dict[Key, float] = defaultdict(float)
+    for i in range(n):
+        name = explanation.names[i]
+        feature = explanation[name]
+        value = x[i]
+        if feature.is_one_hot_encoded:
+            assert np.any(np.isclose(value, [0.0, 1.0]))
+            codes[name] += value
+
+        if feature.is_binary:
+            assert np.any(np.isclose(value, [0.0, 1.0]))
+        elif feature.is_numeric:
+            assert feature.levels[0] <= value <= feature.levels[-1]
+
+    for value in codes.values():
+        assert np.isclose(value, 1.0)
+
+
+def find_leaf(tree: TreeVar, explanation: Explanation) -> NonNegativeInt:
+    node = tree.root
+    x = explanation.x
+    while not node.is_leaf:
+        name = node.feature
+        if explanation[name].is_one_hot_encoded:
+            code = node.code
+            i = explanation.idx.get(name, code)
+            value = x[i]
+        else:
+            i = explanation.idx.get(name)
+            value = x[i]
+
+        if explanation[name].is_numeric:
+            threshold = node.threshold
+            node = node.left if value <= threshold else node.right
+        elif np.isclose(value, 0.0):
+            node = node.left
+        else:
+            node = node.right
+    return node.node_id
+
+
+def validate_path(
+    tree: TreeVar, explanation: Explanation, solver_model: list[int]
+) -> None:
+    """Check that exactly one leaf is selected and matches the solution."""
+    n_active = 0
+    active_leaf = tree.root.node_id
+    for node in tree.leaves:
+        assert node.is_leaf
+        leaf_var = tree[node.node_id]
+        is_active = leaf_var in solver_model
+        if is_active:
+            n_active += 1
+            active_leaf = node.node_id
+
+    assert n_active == 1, (
+        f"Expected one leaf to be active, but {n_active} were found."
+    )
+    x_id_leaf = find_leaf(tree, explanation)
+    assert active_leaf == x_id_leaf, (
+        f"Expected leaf {x_id_leaf}, but found {active_leaf}."
+    )
+
+
+def validate_paths(
+    *trees: TreeVar, explanation: Explanation, solver_model: list[int]
+) -> None:
+    for tree in trees:
+        validate_path(tree, explanation, solver_model)
+
+
+def validate_sklearn_pred(
+    clf: RandomForestClassifier,
+    explanation: Explanation,
+    m_class: NonNegativeInt,
+) -> None:
+    x = explanation.x.reshape(1, -1)
+    prediction = np.asarray(clf.predict(x), dtype=np.int64)
+    assert (prediction == m_class).all(), (
+        f"Expected class {m_class}, got {prediction[0]}"
+    )
+
+
+@overload
+def train_rf(
+    seed: int,
+    n_estimators: int,
+    max_depth: int,
+    n_samples: int,
+    n_classes: int,
+    *,
+    return_data: Literal[False] = False,
+) -> tuple[RandomForestClassifier, Mapper[Feature]]: ...
+
+
+@overload
+def train_rf(
+    seed: int,
+    n_estimators: int,
+    max_depth: int,
+    n_samples: int,
+    n_classes: int,
+    *,
+    return_data: Literal[True],
+) -> tuple[RandomForestClassifier, Mapper[Feature], pd.DataFrame]: ...
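+# Runtime implementation behind the two overloads above; the overloads only
+# narrow the return type for static type checkers.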
+ + +def train_rf( + seed: int, + n_estimators: int, + max_depth: int, + n_samples: int, + n_classes: int, + *, + return_data: bool = False, +) -> ( + tuple[RandomForestClassifier, Mapper[Feature]] + | tuple[RandomForestClassifier, Mapper[Feature], pd.DataFrame] +): + data, y, mapper = generate_data(seed, n_samples, n_classes) + clf = RandomForestClassifier( + random_state=seed, + n_estimators=n_estimators, + max_depth=max_depth, + ) + clf.fit(data, y) + if return_data: + return clf, mapper, data + return clf, mapper + + +SEEDS = [43, 44, 45] +N_ESTIMATORS = [1, 4] +MAX_DEPTH = [3] +N_CLASSES = [2, 4] +N_SAMPLES = [100, 200] diff --git a/tests/test_explainer.py b/tests/test_explainer.py index 0362c0d..d167144 100644 --- a/tests/test_explainer.py +++ b/tests/test_explainer.py @@ -4,7 +4,11 @@ from sklearn.ensemble import RandomForestClassifier from xgboost import XGBClassifier -from ocean import ConstraintProgrammingExplainer, MixedIntegerProgramExplainer +from ocean import ( + ConstraintProgrammingExplainer, + MaxSATExplainer, + MixedIntegerProgramExplainer, +) from .utils import ENV, generate_data @@ -194,3 +198,41 @@ def test_cp_explain_xgb( except gp.GurobiError as e: pytest.skip(f"Skipping test due to {e}") + + +@pytest.mark.parametrize("seed", [42, 43, 44]) +@pytest.mark.parametrize("n_estimators", [5]) +@pytest.mark.parametrize("max_depth", [2, 3]) +@pytest.mark.parametrize("n_classes", [2]) +@pytest.mark.parametrize("n_samples", [100, 200, 500]) +def test_maxsat_explain( + seed: int, + n_estimators: int, + max_depth: int, + n_classes: int, + n_samples: int, +) -> None: + data, y, mapper = generate_data(seed, n_samples, n_classes) + clf = RandomForestClassifier( + random_state=seed, + n_estimators=n_estimators, + max_depth=max_depth, + ) + clf.fit(data, y) + model = MaxSATExplainer(clf, mapper=mapper) + + x = data.iloc[0, :].to_numpy().astype(float).flatten() + # pyright: ignore[reportUnknownVariableType] + y = clf.predict([x])[0] + classes = np.unique(clf.predict(data.to_numpy())).astype(np.int64) # pyright: ignore[reportUnknownArgumentType] + for target in classes[classes != y]: + exp = model.explain( + x, + y=target, + norm=1, + random_seed=seed, + ) + assert model.Status == "OPTIMAL" + assert exp is not None + assert clf.predict([exp.to_numpy()])[0] == target + model.cleanup() From 4d618fb6a9f8164e659458a08b27a85b95a1c962 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Thu, 27 Nov 2025 16:00:18 -0500 Subject: [PATCH 14/18] feat: Add the garbage manager and improve the MaxSAT explainer --- README.md | 40 +++++++++-- examples/maxsat_example.py | 4 +- examples/readme.py | 14 +++- ocean/maxsat/_base.py | 13 +++- ocean/maxsat/_explanation.py | 27 ++++++- ocean/maxsat/_managers/__init__.py | 2 + ocean/maxsat/_managers/_garbage.py | 19 +++++ ocean/maxsat/_model.py | 111 ++++++++++++----------------- pyproject.toml | 2 +- 9 files changed, 149 insertions(+), 83 deletions(-) create mode 100644 ocean/maxsat/_managers/_garbage.py diff --git a/README.md b/README.md index 733906e..7c927ef 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,11 @@ The package provides multiple classes and functions to wrap the tree ensemble mo ```python from sklearn.ensemble import RandomForestClassifier -from ocean import MixedIntegerProgramExplainer, ConstraintProgrammingExplainer +from ocean import ( + ConstraintProgrammingExplainer, + MaxSATExplainer, + MixedIntegerProgramExplainer, +) from ocean.datasets import load_adult # Load the adult dataset @@ -47,20 +51,28 @@ rf.fit(data, target) # Predict the class 
of the random instance
 y = int(rf.predict(x).item())
+x = x.to_numpy().flatten()
 
 # Explain the prediction using MIPEXplainer
 mip_model = MixedIntegerProgramExplainer(rf, mapper=mapper)
-x = x.to_numpy().flatten()
 mip_explanation = mip_model.explain(x, y=1 - y, norm=1)
 
 # Explain the prediction using CPEExplainer
 cp_model = ConstraintProgrammingExplainer(rf, mapper=mapper)
-x = x.to_numpy().flatten()
 cp_explanation = cp_model.explain(x, y=1 - y, norm=1)
 
-# Show the explanation
-print("MIP: ",mip_explanation, "\n")
-print("CP : ",cp_explanation)
+maxsat_model = MaxSATExplainer(rf, mapper=mapper)
+maxsat_explanation = maxsat_model.explain(x, y=1 - y, norm=1)
+
+# Show the explanations and their objective values
+print("MIP objective value:", mip_model.get_objective_value())
+print("MIP", mip_explanation, "\n")
+
+print("CP objective value:", cp_model.get_objective_value())
+print("CP", cp_explanation, "\n")
+
+print("MaxSAT objective value:", maxsat_model.get_objective_value())
+print("MaxSAT", maxsat_explanation, "\n")
 ```
 
@@ -94,6 +106,20 @@ Occupation : 1
 Relationship : 0
 Sex : 0
 WorkClass : 4
+
+MaxSAT objective value: 3.0
+MaxSAT Explanation:
+Age : 39.0
+CapitalGain : 2174.0
+CapitalLoss : 0.0
+EducationNumber : 13.0
+HoursPerWeek : 40.0
+MaritalStatus : 3
+NativeCountry : 0
+Occupation : 1
+Relationship : 0
+Sex : 0
+WorkClass : 4
 ```
 
@@ -106,7 +132,7 @@ See the [examples folder](https://github.com/vidalt/OCEAN/tree/main/examples) fo
 | ------------------------------- | ---------- | ------------------------------------------ |
 | **MIP formulation** | ✅ Done | Based on Parmentier & Vidal (2020/2021). |
 | **Constraint Programming (CP)** | ✅ Done | Based on an upcoming paper. |
-| **MaxSAT formulation** | ⏳ Upcoming | Planned addition to the toolbox. |
+| **MaxSAT formulation** | ✅ Done | Now available in the toolbox. |
 | **Heuristics** | ⏳ Upcoming | Fast approximate methods. |
 | **Other methods** | ⏳ Upcoming | Additional formulations under exploration. |
 | **AdaBoost support** | ✅ Ready | Fully supported in ocean. 
| diff --git a/examples/maxsat_example.py b/examples/maxsat_example.py index 4078275..23b53a8 100644 --- a/examples/maxsat_example.py +++ b/examples/maxsat_example.py @@ -11,11 +11,11 @@ (data, target), mapper = load_adult(scale=True) # Train a random forest classifier -rf = RandomForestClassifier(n_estimators=5, max_depth=3, random_state=42) +rf = RandomForestClassifier(n_estimators=20, max_depth=3, random_state=42) rf.fit(data, target) # Select an instance to explain from the dataset -x = data.iloc[19].to_frame().T +x = data.iloc[0].to_frame().T x_np = x.to_numpy().flatten() # Predict the class of the instance diff --git a/examples/readme.py b/examples/readme.py index 410b8b3..00ad19e 100644 --- a/examples/readme.py +++ b/examples/readme.py @@ -1,6 +1,10 @@ from sklearn.ensemble import RandomForestClassifier -from ocean import ConstraintProgrammingExplainer, MixedIntegerProgramExplainer +from ocean import ( + ConstraintProgrammingExplainer, + MaxSATExplainer, + MixedIntegerProgramExplainer, +) from ocean.datasets import load_adult # Load the adult dataset @@ -15,19 +19,25 @@ # Predict the class of the random instance y = int(rf.predict(x).item()) +x = x.to_numpy().flatten() # Explain the prediction using MIPEXplainer mip_model = MixedIntegerProgramExplainer(rf, mapper=mapper) -x = x.to_numpy().flatten() mip_explanation = mip_model.explain(x, y=1 - y, norm=1) # Explain the prediction using CPEExplainer cp_model = ConstraintProgrammingExplainer(rf, mapper=mapper) cp_explanation = cp_model.explain(x, y=1 - y, norm=1) +maxsat_model = MaxSATExplainer(rf, mapper=mapper) +maxsat_explanation = maxsat_model.explain(x, y=1 - y, norm=1) + # Show the explanations and their objective values print("MIP objective value:", mip_model.get_objective_value()) print("MIP", mip_explanation, "\n") print("CP objective value:", cp_model.get_objective_value()) print("CP", cp_explanation, "\n") + +print("MaxSAT objective value:", maxsat_model.get_objective_value()) +print("MaxSAT", maxsat_explanation, "\n") diff --git a/ocean/maxsat/_base.py b/ocean/maxsat/_base.py index b49d71d..d79d4f5 100644 --- a/ocean/maxsat/_base.py +++ b/ocean/maxsat/_base.py @@ -30,10 +30,19 @@ def get_var(self, name: str) -> int: raise ValueError(msg) return self.vpool.obj2id[name] # type: ignore[no-any-return] - def add_hard(self, lits: list[int]) -> None: - """Add a hard clause (must be satisfied).""" + def add_hard(self, lits: list[int], return_id: bool = False) -> int: # noqa: FBT001, FBT002 + """ + Add a hard clause (must be satisfied). + + Returns: + The clause ID if return_id is True, otherwise -1. + + """ # weight=None => hard clause in WCNF self.append(lits) + if return_id: + return len(self.hard) - 1 # pyright: ignore[reportUnknownArgumentType] + return -1 def add_soft(self, lits: list[int], weight: int = 1) -> None: """Add a soft clause with a given weight.""" diff --git a/ocean/maxsat/_explanation.py b/ocean/maxsat/_explanation.py index 519caf2..48c508f 100644 --- a/ocean/maxsat/_explanation.py +++ b/ocean/maxsat/_explanation.py @@ -98,8 +98,31 @@ def x(self) -> Array1D: @property def value(self) -> Mapping[Key, Key | Number]: - msg = "Not implemented." 
- raise NotImplementedError(msg) + def get(v: FeatureVar) -> Key | Number: + if v.is_one_hot_encoded: + for code in v.codes: + if ENV.solver.model(v.xget(code)) > 0: + return code + if v.is_numeric: + f = [val for _, val in self.items()].index(v) + if v.is_discrete: + idx = self._get_active_mu_index( + self.names[f], for_discrete=True + ) + val = int(v.levels[idx]) + return self.format_discrete_value(f, val, v.levels) + idx = self._get_active_mu_index( + self.names[f], for_discrete=False + ) + return self.format_continuous_value( + f, + idx, + list(v.levels), + ) + x = v.xget() + return int(ENV.solver.model(x)) + + return self.reduce(get) def format_continuous_value( self, diff --git a/ocean/maxsat/_managers/__init__.py b/ocean/maxsat/_managers/__init__.py index 056784d..450d21d 100644 --- a/ocean/maxsat/_managers/__init__.py +++ b/ocean/maxsat/_managers/__init__.py @@ -1,7 +1,9 @@ from ._feature import FeatureManager +from ._garbage import GarbageManager from ._tree import TreeManager __all__ = [ "FeatureManager", + "GarbageManager", "TreeManager", ] diff --git a/ocean/maxsat/_managers/_garbage.py b/ocean/maxsat/_managers/_garbage.py new file mode 100644 index 0000000..5d3e21c --- /dev/null +++ b/ocean/maxsat/_managers/_garbage.py @@ -0,0 +1,19 @@ +class GarbageManager: + type GarbageObject = int + + # Garbage collector for the model. + # - Used to keep track of the variables and constraints created, + # and to remove them when the model is cleared. + _garbage: list[GarbageObject] + + def __init__(self) -> None: + self._garbage = [] + + def add_garbage(self, *args: GarbageObject) -> None: + self._garbage.extend(args) + + def remove_garbage(self) -> None: + self._garbage.clear() + + def garbage_list(self) -> list[GarbageObject]: + return self._garbage diff --git a/ocean/maxsat/_model.py b/ocean/maxsat/_model.py index 1d1fb6d..c595b8d 100644 --- a/ocean/maxsat/_model.py +++ b/ocean/maxsat/_model.py @@ -1,16 +1,16 @@ from __future__ import annotations -import itertools from enum import Enum from typing import TYPE_CHECKING import numpy as np from pydantic import validate_call +from pysat.pb import PBEnc from ..typing import NonNegativeInt from ._base import BaseModel from ._builder.model import ModelBuilder, ModelBuilderFactory -from ._managers import FeatureManager, TreeManager +from ._managers import FeatureManager, GarbageManager, TreeManager if TYPE_CHECKING: from collections.abc import Iterable @@ -22,7 +22,7 @@ from ._variables import FeatureVar -class Model(BaseModel, FeatureManager, TreeManager): +class Model(BaseModel, FeatureManager, GarbageManager, TreeManager): # Model builder for the ensemble. _builder: ModelBuilder DEFAULT_EPSILON: int = 1 @@ -48,6 +48,7 @@ def __init__( weights=weights, ) FeatureManager.__init__(self, mapper=mapper) + GarbageManager.__init__(self) self._set_weights(weights=weights) self._max_samples = max_samples @@ -246,80 +247,56 @@ def _encode_weighted_sum_constraint( threshold: int, ) -> None: """ - Encode constraint: sum of selected contributions >= threshold. + Encode: sum of contributions >= threshold using pseudo-Boolean encoding. - Uses dynamic programming to compute reachable sums and encodes - constraints to ensure the sum meets the threshold. + This approach avoids exponential enumeration. 
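+
+        Negative coefficients are normalised first: a term ``-a * x`` with
+        ``a > 0`` is rewritten as ``a * (-x)`` with the bound shifted by
+        ``+a``. For example, ``3*x1 - 2*x2 >= 1`` becomes
+        ``3*x1 + 2*(-x2) >= 3``, so the PB encoder only receives
+        non-negative weights.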
""" - n_trees = len(tree_contributions) + lits: list[int] = [] + weights: list[int] = [] + shift = 0 # sum over |negative weights| - # For each tree, compute min and max contribution - tree_bounds: list[tuple[int, int]] = [] for contribs in tree_contributions: - min_c = min(c[1] for c in contribs) - max_c = max(c[1] for c in contribs) - tree_bounds.append((min_c, max_c)) - - # Compute global bounds - global_min = sum(b[0] for b in tree_bounds) - global_max = sum(b[1] for b in tree_bounds) - - # If even the maximum sum is below threshold, UNSAT - if global_max < threshold: - self.add_hard([]) + for leaf_var, diff in contribs: + if diff == 0: + continue # contributes nothing, can be ignored + + if diff > 0: + # positive coefficient: weight * x + lits.append(leaf_var) # x + weights.append(diff) + else: + # negative coefficient: -a * x + a = -diff + # transform -a*x into a*(-x) and shift the bound by +a + lits.append(-leaf_var) # -x + weights.append(a) + shift += a + + effective_bound = threshold + shift + + if not lits: # degenerate case + if effective_bound > 0: + self.add_hard([]) # UNSAT return - # If the minimum sum meets threshold, constraint is always satisfied - if global_min >= threshold: - return + # Encode sum(weights_i * lits_i) >= effective_bound + pb = PBEnc.atleast( + lits=lits, + weights=weights, + bound=effective_bound, + vpool=self.vpool, + ) - # Compute prefix max and suffix max for bounds checking - prefix_max = [0] * (n_trees + 1) - for i in range(n_trees): - prefix_max[i + 1] = prefix_max[i] + tree_bounds[i][1] - - suffix_max = [0] * (n_trees + 1) - for i in range(n_trees - 1, -1, -1): - suffix_max[i] = suffix_max[i + 1] + tree_bounds[i][1] - - # For each leaf, check if it can possibly be part of a valid solution - # A leaf with contribution c at tree t is forbidden if: - # prefix_max[t] + c + suffix_max[t+1] < threshold (i.e., even with - # best choices for all other trees, can't reach threshold) - for t_idx, contribs in enumerate(tree_contributions): - for leaf_var, contrib in contribs: - best_case = prefix_max[t_idx] + contrib + suffix_max[t_idx + 1] - if best_case < threshold: - self.add_hard([-leaf_var]) - - # Enumerate and forbid all bad combinations - # A combination is bad if sum(contributions) < threshold - max_trees = 10 - if n_trees <= max_trees: # Only for small forests - self._enumerate_bad_combinations(tree_contributions, threshold) - - def _enumerate_bad_combinations( - self, - tree_contributions: list[list[tuple[int, int]]], - threshold: int, - ) -> None: - """Enumerate and forbid all combinations with sum < threshold.""" - # Get leaf indices for each tree - tree_leaves = [ - [(lv, c) for lv, c in contribs] for contribs in tree_contributions - ] - - # Enumerate all combinations - for combo in itertools.product(*tree_leaves): - total = sum(c for _, c in combo) - if total < threshold: - # This combination is bad, TODO find efficient way - # At least one of the leaves must NOT be selected - clause = [-lv for lv, _ in combo] - self.add_hard(clause) + for clause in pb.clauses: + self.add_garbage( + self.add_hard(clause, return_id=True) # pyright: ignore[reportUnknownArgumentType] + ) def cleanup(self) -> None: self._clean_soft() + for idx in sorted(self.garbage_list(), reverse=True): + self.hard.pop(idx) + self.remove_garbage() def _set_builder(self, model_type: Type) -> None: match model_type: diff --git a/pyproject.toml b/pyproject.toml index cb2defe..12c31e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "pydantic", 
"scikit-learn", "xgboost", - "python-sat", + "python-sat[pblib,aiger]", ] optional-dependencies.dev = [ From d1f4bab62b29736ce299024b3dbba8e511ef5bd7 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Mon, 1 Dec 2025 12:19:04 -0500 Subject: [PATCH 15/18] fix: Update python-sat dependency to include optional features in dev, test, and example sections --- pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 12c31e8..350cc77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "pydantic", "scikit-learn", "xgboost", - "python-sat[pblib,aiger]", + "python-sat", ] optional-dependencies.dev = [ @@ -51,6 +51,7 @@ optional-dependencies.dev = [ "ruff", "scipy-stubs", "tox", + "python-sat[pblib,aiger]", ] optional-dependencies.test = [ @@ -63,11 +64,13 @@ optional-dependencies.test = [ "ruff", "scipy-stubs", "tox", + "python-sat[pblib,aiger]", ] optional-dependencies.example = [ "rich", "matplotlib", + "python-sat[pblib,aiger]", ] urls.Homepage = "https://github.com/vidalt/OCEAN" From e7456ae8372e031f0a11125b6b882e17ec3085d3 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Mon, 1 Dec 2025 12:24:08 -0500 Subject: [PATCH 16/18] fix: Fix optional dependencies for python-sat in the test and example sections. --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 350cc77..a4e0989 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,13 +64,13 @@ optional-dependencies.test = [ "ruff", "scipy-stubs", "tox", - "python-sat[pblib,aiger]", + "python-sat", ] optional-dependencies.example = [ "rich", "matplotlib", - "python-sat[pblib,aiger]", + "python-sat[aiger,pblib]", ] urls.Homepage = "https://github.com/vidalt/OCEAN" From c63a966217caf0e5d786e84c7324f0ff0796b3a2 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Mon, 1 Dec 2025 13:00:35 -0500 Subject: [PATCH 17/18] fix: ignore the pysat pblib and maxsat tests on windows platforms --- pyproject.toml | 8 +++--- tests/maxsat/conftest.py | 8 ++++++ tests/maxsat/test_explainer.py | 45 ++++++++++++++++++++++++++++++++++ tests/test_explainer.py | 39 ----------------------------- tox.ini | 4 ++- 5 files changed, 60 insertions(+), 44 deletions(-) create mode 100644 tests/maxsat/conftest.py create mode 100644 tests/maxsat/test_explainer.py diff --git a/pyproject.toml b/pyproject.toml index a4e0989..0c699d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,6 @@ dependencies = [ "pydantic", "scikit-learn", "xgboost", - "python-sat", ] optional-dependencies.dev = [ @@ -51,7 +50,6 @@ optional-dependencies.dev = [ "ruff", "scipy-stubs", "tox", - "python-sat[pblib,aiger]", ] optional-dependencies.test = [ @@ -64,13 +62,15 @@ optional-dependencies.test = [ "ruff", "scipy-stubs", "tox", - "python-sat", ] optional-dependencies.example = [ "rich", "matplotlib", - "python-sat[aiger,pblib]", +] + +optional-dependencies.maxsat = [ + "python-sat[pblib,aiger]; platform_system != 'Windows'", ] urls.Homepage = "https://github.com/vidalt/OCEAN" diff --git a/tests/maxsat/conftest.py b/tests/maxsat/conftest.py new file mode 100644 index 0000000..7d4d8b6 --- /dev/null +++ b/tests/maxsat/conftest.py @@ -0,0 +1,8 @@ +import sys + +import pytest + +pytestmark = pytest.mark.skipif( + sys.platform.startswith("win"), + reason="MaxSAT tests are disabled on Windows.", +) diff --git a/tests/maxsat/test_explainer.py b/tests/maxsat/test_explainer.py new file mode 100644 index 0000000..ae9ffb7 --- /dev/null +++ 
b/tests/maxsat/test_explainer.py @@ -0,0 +1,45 @@ +import numpy as np +import pytest +from sklearn.ensemble import RandomForestClassifier + +from ocean import MaxSATExplainer + +from .utils import generate_data + + +@pytest.mark.parametrize("seed", [42, 43, 44]) +@pytest.mark.parametrize("n_estimators", [5]) +@pytest.mark.parametrize("max_depth", [2, 3]) +@pytest.mark.parametrize("n_classes", [2]) +@pytest.mark.parametrize("n_samples", [100, 200, 500]) +def test_maxsat_explain( + seed: int, + n_estimators: int, + max_depth: int, + n_classes: int, + n_samples: int, +) -> None: + data, y, mapper = generate_data(seed, n_samples, n_classes) + clf = RandomForestClassifier( + random_state=seed, + n_estimators=n_estimators, + max_depth=max_depth, + ) + clf.fit(data, y) + model = MaxSATExplainer(clf, mapper=mapper) + + x = data.iloc[0, :].to_numpy().astype(float).flatten() + # pyright: ignore[reportUnknownVariableType] + y = clf.predict([x])[0] + classes = np.unique(clf.predict(data.to_numpy())).astype(np.int64) # pyright: ignore[reportUnknownArgumentType] + for target in classes[classes != y]: + exp = model.explain( + x, + y=target, + norm=1, + random_seed=seed, + ) + assert model.Status == "OPTIMAL" + assert exp is not None + assert clf.predict([exp.to_numpy()])[0] == target + model.cleanup() diff --git a/tests/test_explainer.py b/tests/test_explainer.py index d167144..54d0daf 100644 --- a/tests/test_explainer.py +++ b/tests/test_explainer.py @@ -6,7 +6,6 @@ from ocean import ( ConstraintProgrammingExplainer, - MaxSATExplainer, MixedIntegerProgramExplainer, ) @@ -198,41 +197,3 @@ def test_cp_explain_xgb( except gp.GurobiError as e: pytest.skip(f"Skipping test due to {e}") - - -@pytest.mark.parametrize("seed", [42, 43, 44]) -@pytest.mark.parametrize("n_estimators", [5]) -@pytest.mark.parametrize("max_depth", [2, 3]) -@pytest.mark.parametrize("n_classes", [2]) -@pytest.mark.parametrize("n_samples", [100, 200, 500]) -def test_maxsat_explain( - seed: int, - n_estimators: int, - max_depth: int, - n_classes: int, - n_samples: int, -) -> None: - data, y, mapper = generate_data(seed, n_samples, n_classes) - clf = RandomForestClassifier( - random_state=seed, - n_estimators=n_estimators, - max_depth=max_depth, - ) - clf.fit(data, y) - model = MaxSATExplainer(clf, mapper=mapper) - - x = data.iloc[0, :].to_numpy().astype(float).flatten() - # pyright: ignore[reportUnknownVariableType] - y = clf.predict([x])[0] - classes = np.unique(clf.predict(data.to_numpy())).astype(np.int64) # pyright: ignore[reportUnknownArgumentType] - for target in classes[classes != y]: - exp = model.explain( - x, - y=target, - norm=1, - random_seed=seed, - ) - assert model.Status == "OPTIMAL" - assert exp is not None - assert clf.predict([exp.to_numpy()])[0] == target - model.cleanup() diff --git a/tox.ini b/tox.ini index 0119c38..4bf587b 100644 --- a/tox.ini +++ b/tox.ini @@ -9,8 +9,9 @@ basepython = python3 [testenv] deps = pytest +extras = maxsat commands = - pip install .[test] + pip install .[test,maxsat] pytest -vv [testenv:py312] @@ -37,6 +38,7 @@ deps = pandas-stubs pydantic pytest + python-sat scikit-learn scipy scipy-stubs From c3c5b5132f9d51c31d5c4a10b32cb34ce2648c74 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Mon, 1 Dec 2025 13:06:44 -0500 Subject: [PATCH 18/18] fix: add python-sat to test dependencies without the extras libraries --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 0c699d8..d450c12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ 
optional-dependencies.test = [ "pyright", "pytest", "pytest-cov", + "python-sat", "ruff", "scipy-stubs", "tox",