Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
9fc2d0c
fix: Add variable management methods and clause handling in BaseModel
AwaKhouna Nov 23, 2025
57fa4c0
fix: Implement vget method to handle one-hot encoded and numeric valu…
AwaKhouna Nov 23, 2025
42bd50f
fix: Refactor Model class to ensure _builder is initialized and build…
AwaKhouna Nov 23, 2025
3c2b8eb
fix: Update MaxSATBuilder methods to enforce integer type for 'y' par…
AwaKhouna Nov 23, 2025
fa28247
fix: Implement vget method in FeatureManager and enhance build method…
AwaKhouna Nov 23, 2025
205707c
Merge branch 'main' into maxsat-implementation
AwaKhouna Nov 23, 2025
5422b8b
fix: Refactor tree and explainer modules to improve AdaBoost handling…
AwaKhouna Nov 25, 2025
2ec0aa9
fix: Refactor MaxSATSolver integration and enhance explanation handli…
AwaKhouna Nov 25, 2025
00f0f27
fix: Remove the GarbageManager class
AwaKhouna Nov 27, 2025
2ef684f
feat: Enhance MaxSAT model and explainer with new methods and improve…
AwaKhouna Nov 27, 2025
ae1d3ec
fix: Improve handling of sigma conditions in MaxSATBuilder methods
AwaKhouna Nov 27, 2025
98fc563
fix: Enhance FeatureVar with exact constraints and improved mu variab…
AwaKhouna Nov 27, 2025
1216a3d
feat: Add MaxSAT explainer to query example and update explainer options
AwaKhouna Nov 27, 2025
6351dda
feat: Ajouter des tests pour le modèle MaxSAT et l'explainer MaxSAT
AwaKhouna Nov 27, 2025
4d618fb
feat: Add the garbage manager and improve the MaxSAT explainer
AwaKhouna Nov 27, 2025
d1f4bab
fix: Update python-sat dependency to include optional features in dev…
AwaKhouna Dec 1, 2025
e7456ae
fix: Fix optional dependencies for python-sat in the test and example…
AwaKhouna Dec 1, 2025
c63a966
fix: ignore the pysat pblib and maxsat tests on windows platforms
AwaKhouna Dec 1, 2025
c3c5b51
fix: add python-sat to test dependencies without the extras libraries
AwaKhouna Dec 1, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ The package provides multiple classes and functions to wrap the tree ensemble mo
```python
from sklearn.ensemble import RandomForestClassifier

from ocean import MixedIntegerProgramExplainer, ConstraintProgrammingExplainer
from ocean import (
ConstraintProgrammingExplainer,
MaxSATExplainer,
MixedIntegerProgramExplainer,
)
from ocean.datasets import load_adult

# Load the adult dataset
Expand All @@ -47,20 +51,28 @@ rf.fit(data, target)

# Predict the class of the random instance
y = int(rf.predict(x).item())
x = x.to_numpy().flatten()

# Explain the prediction using MIPEXplainer
mip_model = MixedIntegerProgramExplainer(rf, mapper=mapper)
x = x.to_numpy().flatten()
mip_explanation = mip_model.explain(x, y=1 - y, norm=1)

# Explain the prediction using CPEExplainer
cp_model = ConstraintProgrammingExplainer(rf, mapper=mapper)
x = x.to_numpy().flatten()
cp_explanation = cp_model.explain(x, y=1 - y, norm=1)

# Show the explanation
print("MIP: ",mip_explanation, "\n")
print("CP : ",cp_explanation)
maxsat_model = MaxSATExplainer(rf, mapper=mapper)
maxsat_explanation = maxsat_model.explain(x, y=1 - y, norm=1)

# Show the explanations and their objective values
print("MIP objective value:", mip_model.get_objective_value())
print("MIP", mip_explanation, "\n")

print("CP objective value:", cp_model.get_objective_value())
print("CP", cp_explanation, "\n")

print("MaxSAT objective value:", maxsat_model.get_objective_value())
print("MaxSAT", maxsat_explanation, "\n")

```

Expand Down Expand Up @@ -94,6 +106,20 @@ Occupation : 1
Relationship : 0
Sex : 0
WorkClass : 4

MaxSAT objective value: 3.0
MaxSAT Explanation:
Age : 39.0
CapitalGain : 2174.0
CapitalLoss : 0.0
EducationNumber : 13.0
HoursPerWeek : 40.0
MaritalStatus : 3
NativeCountry : 0
Occupation : 1
Relationship : 0
Sex : 0
WorkClass : 4
```


Expand All @@ -106,7 +132,7 @@ See the [examples folder](https://github.com/vidalt/OCEAN/tree/main/examples) fo
| ------------------------------- | ---------- | ------------------------------------------ |
| **MIP formulation** | ✅ Done | Based on Parmentier & Vidal (2020/2021). |
| **Constraint Programming (CP)** | ✅ Done | Based on an upcoming paper. |
| **MaxSAT formulation** | ⏳ Upcoming | Planned addition to the toolbox. |
| **MaxSAT formulation** | ✅ Done | Planned addition to the toolbox. |
| **Heuristics** | ⏳ Upcoming | Fast approximate methods. |
| **Other methods** | ⏳ Upcoming | Additional formulations under exploration. |
| **AdaBoost support** | ✅ Ready | Fully supported in ocean. |
Expand Down
78 changes: 78 additions & 0 deletions examples/maxsat_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from sklearn.ensemble import RandomForestClassifier

from ocean import (
ConstraintProgrammingExplainer,
MaxSATExplainer,
MixedIntegerProgramExplainer,
)
from ocean.datasets import load_adult

# Load the adult dataset
(data, target), mapper = load_adult(scale=True)

# Train a random forest classifier
rf = RandomForestClassifier(n_estimators=20, max_depth=3, random_state=42)
rf.fit(data, target)

# Select an instance to explain from the dataset
x = data.iloc[0].to_frame().T
x_np = x.to_numpy().flatten()

# Predict the class of the instance
y_pred = int(rf.predict(x).item())
target_class = 1 - y_pred # Binary classification - choose opposite class

print(f"Instance shape: {x_np.shape}")
print(f"Original prediction: {y_pred}")
print(f"Target counterfactual class: {target_class}")

# Explain the prediction using MaxSATExplainer
print("\n--- MaxSAT Explainer ---")
try:
maxsat_model = MaxSATExplainer(rf, mapper=mapper)
maxsat_explanation = maxsat_model.explain(x_np, y=target_class, norm=1)
if maxsat_explanation is not None:
cf_np = maxsat_explanation.to_numpy()
print("MaxSAT CF:", cf_np)
print("MaxSAT CF prediction:", rf.predict([cf_np])[0])
print("Objective value:", maxsat_model.get_objective_value())
print("Status:", maxsat_model.get_solving_status())
else:
print("MaxSAT: No counterfactual found.")
except (ValueError, RuntimeError, ImportError) as e:
import traceback

print(f"MaxSAT Error: {e}")
traceback.print_exc()

# Explain the prediction using MIPExplainer
print("\n--- MIP Explainer ---")
try:
mip_model = MixedIntegerProgramExplainer(rf, mapper=mapper)
mip_explanation = mip_model.explain(x_np, y=target_class, norm=1)
if mip_explanation is not None:
cf_np = mip_explanation.to_numpy()
print("MIP CF:", cf_np)
print("MIP CF prediction:", rf.predict([cf_np])[0])
print("Objective value:", mip_model.get_objective_value())
print("Status:", mip_model.get_solving_status())
else:
print("MIP: No counterfactual found.")
except (ValueError, RuntimeError, ImportError) as e:
print(f"MIP Error: {e}")

# Explain the prediction using CPExplainer
print("\n--- CP Explainer ---")
try:
cp_model = ConstraintProgrammingExplainer(rf, mapper=mapper)
cp_explanation = cp_model.explain(x_np, y=target_class, norm=1)
if cp_explanation is not None:
cf_np = cp_explanation.to_numpy()
print("CP CF:", cf_np)
print("CP CF prediction:", rf.predict([cf_np])[0])
print("Objective value:", cp_model.get_objective_value())
print("Status:", cp_model.get_solving_status())
else:
print("CP: No counterfactual found.")
except (ValueError, RuntimeError, ImportError) as e:
print(f"CP Error: {e}")
14 changes: 10 additions & 4 deletions examples/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from ocean import (
ConstraintProgrammingExplainer,
MaxSATExplainer,
MixedIntegerProgramExplainer,
)
from ocean.abc import Mapper
Expand All @@ -28,6 +29,7 @@
EXPLAINERS = {
"mip": MixedIntegerProgramExplainer,
"cp": ConstraintProgrammingExplainer,
"maxsat": MaxSATExplainer,
}
MODELS = {
"rf": RandomForestClassifier,
Expand Down Expand Up @@ -75,8 +77,8 @@ def create_argument_parser() -> ArgumentParser:
help="List of explainers to use",
type=str,
nargs="+",
choices=["mip", "cp"],
default=["mip", "cp"],
choices=["mip", "cp", "maxsat"],
default=["mip", "cp", "maxsat"],
)
parser.add_argument(
"-m",
Expand Down Expand Up @@ -153,7 +155,8 @@ def fit_model_with_console(
def build_explainer(
explainer_name: str,
explainer_class: type[MixedIntegerProgramExplainer]
| type[ConstraintProgrammingExplainer],
| type[ConstraintProgrammingExplainer]
| type[MaxSATExplainer],
args: Args,
model: RandomForestClassifier | XGBClassifier,
mapper: Mapper[Feature],
Expand All @@ -163,7 +166,10 @@ def build_explainer(
if explainer_class is MixedIntegerProgramExplainer:
ENV.setParam("Seed", args.seed)
exp = explainer_class(model, mapper=mapper, env=ENV)
elif explainer_class is ConstraintProgrammingExplainer:
elif (
explainer_class is ConstraintProgrammingExplainer
or explainer_class is MaxSATExplainer
):
exp = explainer_class(model, mapper=mapper)
else:
msg = f"Unknown explainer type: {explainer_class}"
Expand Down
14 changes: 12 additions & 2 deletions examples/readme.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from sklearn.ensemble import RandomForestClassifier

from ocean import ConstraintProgrammingExplainer, MixedIntegerProgramExplainer
from ocean import (
ConstraintProgrammingExplainer,
MaxSATExplainer,
MixedIntegerProgramExplainer,
)
from ocean.datasets import load_adult

# Load the adult dataset
Expand All @@ -15,19 +19,25 @@

# Predict the class of the random instance
y = int(rf.predict(x).item())
x = x.to_numpy().flatten()

# Explain the prediction using MIPEXplainer
mip_model = MixedIntegerProgramExplainer(rf, mapper=mapper)
x = x.to_numpy().flatten()
mip_explanation = mip_model.explain(x, y=1 - y, norm=1)

# Explain the prediction using CPEExplainer
cp_model = ConstraintProgrammingExplainer(rf, mapper=mapper)
cp_explanation = cp_model.explain(x, y=1 - y, norm=1)

maxsat_model = MaxSATExplainer(rf, mapper=mapper)
maxsat_explanation = maxsat_model.explain(x, y=1 - y, norm=1)

# Show the explanations and their objective values
print("MIP objective value:", mip_model.get_objective_value())
print("MIP", mip_explanation, "\n")

print("CP objective value:", cp_model.get_objective_value())
print("CP", cp_explanation, "\n")

print("MaxSAT objective value:", maxsat_model.get_objective_value())
print("MaxSAT", maxsat_explanation, "\n")
6 changes: 4 additions & 2 deletions ocean/cp/_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings

from ortools.sat.python import cp_model as cp
from sklearn.ensemble import AdaBoostClassifier

from ..abc import Mapper
from ..feature import Feature
Expand All @@ -11,6 +12,7 @@
Array1D,
BaseExplainableEnsemble,
BaseExplainer,
NonNegativeArray1D,
NonNegativeInt,
PositiveInt,
)
Expand All @@ -25,13 +27,13 @@ def __init__(
ensemble: BaseExplainableEnsemble,
*,
mapper: Mapper[Feature],
weights: Array1D | None = None,
weights: NonNegativeArray1D | None = None,
epsilon: int = Model.DEFAULT_EPSILON,
model_type: Model.Type = Model.Type.CP,
) -> None:
ensembles = (ensemble,)
trees = parse_ensembles(*ensembles, mapper=mapper)
if trees[0].adaboost:
if isinstance(ensemble, AdaBoostClassifier):
weights = ensemble.estimator_weights_
Model.__init__(
self,
Expand Down
7 changes: 2 additions & 5 deletions ocean/cp/_managers/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,18 +140,15 @@ def weighted_function(
if self._adaboost:
# no need to scale since values are 0/1
# scaling is done later for tree weights
if np.argmax(leaf.value[op, :]) == c:
coeff = 1
else:
coeff = 0
coeff = int(np.argmax(leaf.value[op, :]) == c)
else:
coeff = int(leaf.value[op, c] * scale)
coefs.append(coeff)
variables.append(tree[leaf.node_id])
tree_expr = cp.LinearExpr.WeightedSum(variables, coefs)
tree_exprs.append(tree_expr)
if self._adaboost:
tree_weights.append(int(weight * scale))
tree_weights.append(int(weight * scale))
else:
tree_weights.append(int(weight))
expr = cp.LinearExpr.WeightedSum(tree_exprs, tree_weights)
Expand Down
18 changes: 17 additions & 1 deletion ocean/maxsat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
from ._base import BaseModel
from ._env import ENV
from ._explainer import Explainer
from ._explanation import Explanation
from ._managers import FeatureManager, TreeManager
from ._model import Model
from ._variables import FeatureVar, TreeVar

__all__ = ["Explainer"]
__all__ = [
"ENV",
"BaseModel",
"Explainer",
"Explanation",
"FeatureManager",
"FeatureVar",
"Model",
"TreeManager",
"TreeVar",
]
48 changes: 47 additions & 1 deletion ocean/maxsat/_base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from abc import ABC
from typing import Any, Protocol

from pysat.formula import WCNF
from pysat.formula import WCNF, IDPool


class BaseModel(ABC, WCNF):
vpool: IDPool

def __init__(self) -> None:
WCNF.__init__(self)
self.vpool = IDPool() # Create new pool for each instance

def __setattr__(self, name: str, value: Any) -> None: # noqa: ANN401
object.__setattr__(self, name, value)
Expand All @@ -15,6 +18,49 @@ def build_vars(self, *variables: "Var") -> None:
for variable in variables:
variable.build(model=self)

def add_var(self, name: str) -> int:
if name in self.vpool.obj2id: # var has been already created
msg = f"Variable with name '{name}' already exists."
raise ValueError(msg)
return self.vpool.id(f"{name}") # type: ignore[no-any-return]

def get_var(self, name: str) -> int:
if name not in self.vpool.obj2id: # var has not been created
msg = f"Variable with name '{name}' does not exist."
raise ValueError(msg)
return self.vpool.obj2id[name] # type: ignore[no-any-return]

def add_hard(self, lits: list[int], return_id: bool = False) -> int: # noqa: FBT001, FBT002
"""
Add a hard clause (must be satisfied).

Returns:
The clause ID if return_id is True, otherwise -1.

"""
# weight=None => hard clause in WCNF
self.append(lits)
if return_id:
return len(self.hard) - 1 # pyright: ignore[reportUnknownArgumentType]
return -1

def add_soft(self, lits: list[int], weight: int = 1) -> None:
"""Add a soft clause with a given weight."""
self.append(lits, weight=weight)

def add_exactly_one(self, lits: list[int]) -> None:
"""Add constraint that exactly one path is selected."""
self.add_hard(lits) # at least one
for i in range(len(lits)):
for j in range(i + 1, len(lits)):
self.add_hard([-lits[i], -lits[j]]) # at most one

def _clean_soft(self) -> None:
"""Reset the model to only contain hard constraints."""
self.soft.clear()
self.wght.clear()
self.topw = 1


class Var(Protocol):
_name: str
Expand Down
Loading
Loading