Skip to content
This repository was archived by the owner on Sep 24, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions examples/simple_example_both.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import time

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from ocean import ConstraintProgrammingExplainer, MixedIntegerProgramExplainer
from ocean.datasets import load_adult


plot_anytime_distances = True
num_workers = 8 # Both CP and MILP solving support multithreading
random_state = 42
Expand Down Expand Up @@ -87,9 +87,7 @@
rf.predict([explanation_oceancp.to_numpy()])[0],
")",
)
print(
"CP Sollist = ", cp_model.get_anytime_solutions()
)
print("CP Sollist = ", cp_model.get_anytime_solutions())
else:
print("CP: No CF found.")

Expand Down
60 changes: 32 additions & 28 deletions ocean/cp/_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,34 +80,38 @@ def explain(
_ = self.solver.Solve(self, solution_callback=self.callback)
status = self.solver.status_name()
self.Status = status
if status == "INFEASIBLE":
msg = "There are no feasible counterfactuals for this query."
msg += " If there should be one, please check the model "
msg += "constraints or report this issue to the developers."
warnings.warn(msg, category=UserWarning, stacklevel=2)
return None
elif status == "MODEL_INVALID":
msg = "The constraint programming model is invalid. "
msg += "Please check the model constraints or report"
msg += " this issue to the developers."
raise RuntimeError(msg)
elif status == "UNKNOWN":
msg = "The constraint programming solver could "
msg += "not find any valid CF within the given time frame."
msg += " Try increasing the time limit."
warnings.warn(msg, category=UserWarning, stacklevel=2)
return None
elif status == "FEASIBLE":
msg = "A valid CF was found, but it might be "
msg += "suboptimal as the constraint programming "
msg += "solver could not prove optimality within "
msg += "the given time frame. \n It can however certify"
msg += " that no counterfactual can be closer than"
msg += f" {self.solver.BestObjectiveBound()}."
warnings.warn(msg, category=UserWarning, stacklevel=2)
elif status != "OPTIMAL":
msg = "Unexpected solver status: " + status
raise RuntimeError(msg)

match status:
case "OPTIMAL":
pass
case "FEASIBLE":
msg = "A valid CF was found, but it might be "
msg += "suboptimal as the constraint programming "
msg += "solver could not prove optimality within "
msg += "the given time frame. \n It can however certify"
msg += " that no counterfactual can be closer than"
msg += f" {self.solver.BestObjectiveBound()}."
warnings.warn(msg, category=UserWarning, stacklevel=2)
case "INFEASIBLE":
msg = "There are no feasible counterfactuals for this query."
msg += " If there should be one, please check the model "
msg += "constraints or report this issue to the developers."
warnings.warn(msg, category=UserWarning, stacklevel=2)
return None
case "MODEL_INVALID":
msg = "The constraint programming model is invalid. "
msg += "Please check the model constraints or report"
msg += " this issue to the developers."
raise RuntimeError(msg)
case "UNKNOWN":
msg = "The constraint programming solver could "
msg += "not find any valid CF within the given time frame."
msg += " Try increasing the time limit."
warnings.warn(msg, category=UserWarning, stacklevel=2)
return None
case _:
msg = "Unexpected solver status: " + status
raise RuntimeError(msg)
self.explanation.query = x
return self.explanation

Expand Down
5 changes: 3 additions & 2 deletions ocean/cp/_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ def format_value(self, f: int, idx: int, levels: list[float]) -> float:
if self.query.shape[0] == 0:
return float(levels[idx] + levels[idx + 1]) / 2
j = 0
while self.query[f] > levels[j + 1]:
query_arr = np.asarray(self.query, dtype=float).ravel()
while query_arr[f] > levels[j + 1]:
j += 1
if j == idx:
value = float(self.query[f])
value = float(query_arr[f])
elif j < idx:
value = float(levels[idx]) + self._epsilon
else:
Expand Down
5 changes: 3 additions & 2 deletions ocean/cp/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,18 @@ def _add_objective(self, x: Array1D, norm: int) -> cp.ObjLinearExprT:
if norm != 1:
msg = f"Unsupported norm: {norm}"
raise ValueError(msg)
x_arr = np.asarray(x, dtype=float).ravel()

variables = self.mapper.values()
objective: cp.LinearExpr = 0 # type: ignore[assignment]
k = 0
for v in variables:
if v.is_one_hot_encoded:
for code in v.codes:
objective += self.L1(x[k], v, code=code)
objective += self.L1(x_arr[k], v, code=code)
k += 1
else:
objective += self.L1(x[k], v)
objective += self.L1(x_arr[k], v)
k += 1
return objective

Expand Down
9 changes: 4 additions & 5 deletions ocean/mip/_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,17 @@ def explain(
else:
self.optimize()
status = self.get_solving_status()

if status == "INFEASIBLE":
msg = "There are no feasible counterfactuals for this query."
msg += " If there should be one, please check the model "
msg += "constraints or report this issue to the developers."
warnings.warn(msg, category=UserWarning, stacklevel=2)
return None
elif status != "OPTIMAL":
if status != "OPTIMAL":
if self.SolCount > 0:
msg = "A valid CF was found, but it might be "
msg += "suboptimal as the MILP "
msg = "A valid CF was found, but it might be "
msg += "suboptimal as the MILP "
msg += "solver could not prove optimality within "
msg += "the given time frame. \n It can however certify"
msg += " that no counterfactual can be closer than"
Expand All @@ -138,7 +138,6 @@ def explain(
msg += " valid CF for an un-handled reason."
msg += "Unexpected solver status: " + status
raise RuntimeError(msg)

return self.explanation

@staticmethod
Expand Down
26 changes: 13 additions & 13 deletions tests/test_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_mip_explain(
max_depth: int,
n_classes: int,
n_samples: int,
num_workers:int,
num_workers: int,
) -> None:
data, y, mapper = generate_data(seed, n_samples, n_classes)
clf = RandomForestClassifier(
Expand All @@ -30,18 +30,18 @@ def test_mip_explain(
clf.fit(data, y)
model = MixedIntegerProgramExplainer(clf, mapper=mapper, env=ENV)

x = data.iloc[0, :].to_numpy().astype(float).flatten()
x = data.iloc[0, :].to_numpy().astype(float).flatten()
# pyright: ignore[reportUnknownVariableType]

try:
model.explain(x, y=0, norm=1,
num_workers=num_workers,
model.explain(x, y=0, norm=1,
num_workers=num_workers,
random_seed=seed)
assert model.Status == gp.GRB.OPTIMAL
model.cleanup()
model.explain(x, y=0, norm=1,
return_callback=True,
num_workers=num_workers,
model.explain(x, y=0, norm=1,
return_callback=True,
num_workers=num_workers,
random_seed=seed)
assert len(model.callback.sollist) != 0

Expand Down Expand Up @@ -72,19 +72,19 @@ def test_cp_explain(
clf.fit(data, y)
model = ConstraintProgrammingExplainer(clf, mapper=mapper)

x = data.iloc[0, :].to_numpy().astype(float).flatten()
x = data.iloc[0, :].to_numpy().astype(float).flatten()
# pyright: ignore[reportUnknownVariableType]

try:
_ = model.explain(x, y=0, norm=1,
_ = model.explain(x, y=0, norm=1,
return_callback=False,
num_workers=num_workers,
num_workers=num_workers,
random_seed=seed)
assert model.callback is None or len(model.callback.sollist) == 0
model.cleanup()
_ = model.explain(x, y=0, norm=1,
return_callback=True,
num_workers=num_workers,
_ = model.explain(x, y=0, norm=1,
return_callback=True,
num_workers=num_workers,
random_seed=seed)
assert model.callback is None or len(model.callback.sollist) != 0
except gp.GurobiError as e:
Expand Down