From 5d6ee4f2d7b278f1d535c2e9bd0920c11f2ca1c6 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Wed, 20 Aug 2025 14:49:08 -0400 Subject: [PATCH] refactor: fix the warnings, improves status management in CP and MILP explanations, optimizes table processing, and cleans up the code --- examples/simple_example_both.py | 6 ++-- ocean/cp/_explainer.py | 60 ++++++++++++++++++--------------- ocean/cp/_explanation.py | 5 +-- ocean/cp/_model.py | 5 +-- ocean/mip/_explainer.py | 9 +++-- tests/test_explainer.py | 26 +++++++------- 6 files changed, 57 insertions(+), 54 deletions(-) diff --git a/examples/simple_example_both.py b/examples/simple_example_both.py index a07263c..266c6a8 100644 --- a/examples/simple_example_both.py +++ b/examples/simple_example_both.py @@ -1,11 +1,11 @@ import time + from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import train_test_split from ocean import ConstraintProgrammingExplainer, MixedIntegerProgramExplainer from ocean.datasets import load_adult - plot_anytime_distances = True num_workers = 8 # Both CP and MILP solving support multithreading random_state = 42 @@ -87,9 +87,7 @@ rf.predict([explanation_oceancp.to_numpy()])[0], ")", ) - print( - "CP Sollist = ", cp_model.get_anytime_solutions() - ) + print("CP Sollist = ", cp_model.get_anytime_solutions()) else: print("CP: No CF found.") diff --git a/ocean/cp/_explainer.py b/ocean/cp/_explainer.py index 9bcb8ad..e49c43c 100644 --- a/ocean/cp/_explainer.py +++ b/ocean/cp/_explainer.py @@ -80,34 +80,38 @@ def explain( _ = self.solver.Solve(self, solution_callback=self.callback) status = self.solver.status_name() self.Status = status - if status == "INFEASIBLE": - msg = "There are no feasible counterfactuals for this query." - msg += " If there should be one, please check the model " - msg += "constraints or report this issue to the developers." - warnings.warn(msg, category=UserWarning, stacklevel=2) - return None - elif status == "MODEL_INVALID": - msg = "The constraint programming model is invalid. " - msg += "Please check the model constraints or report" - msg += " this issue to the developers." - raise RuntimeError(msg) - elif status == "UNKNOWN": - msg = "The constraint programming solver could " - msg += "not find any valid CF within the given time frame." - msg += " Try increasing the time limit." - warnings.warn(msg, category=UserWarning, stacklevel=2) - return None - elif status == "FEASIBLE": - msg = "A valid CF was found, but it might be " - msg += "suboptimal as the constraint programming " - msg += "solver could not prove optimality within " - msg += "the given time frame. \n It can however certify" - msg += " that no counterfactual can be closer than" - msg += f" {self.solver.BestObjectiveBound()}." - warnings.warn(msg, category=UserWarning, stacklevel=2) - elif status != "OPTIMAL": - msg = "Unexpected solver status: " + status - raise RuntimeError(msg) + + match status: + case "OPTIMAL": + pass + case "FEASIBLE": + msg = "A valid CF was found, but it might be " + msg += "suboptimal as the constraint programming " + msg += "solver could not prove optimality within " + msg += "the given time frame. \n It can however certify" + msg += " that no counterfactual can be closer than" + msg += f" {self.solver.BestObjectiveBound()}." + warnings.warn(msg, category=UserWarning, stacklevel=2) + case "INFEASIBLE": + msg = "There are no feasible counterfactuals for this query." + msg += " If there should be one, please check the model " + msg += "constraints or report this issue to the developers." + warnings.warn(msg, category=UserWarning, stacklevel=2) + return None + case "MODEL_INVALID": + msg = "The constraint programming model is invalid. " + msg += "Please check the model constraints or report" + msg += " this issue to the developers." + raise RuntimeError(msg) + case "UNKNOWN": + msg = "The constraint programming solver could " + msg += "not find any valid CF within the given time frame." + msg += " Try increasing the time limit." + warnings.warn(msg, category=UserWarning, stacklevel=2) + return None + case _: + msg = "Unexpected solver status: " + status + raise RuntimeError(msg) self.explanation.query = x return self.explanation diff --git a/ocean/cp/_explanation.py b/ocean/cp/_explanation.py index 9d5a634..4bb175d 100644 --- a/ocean/cp/_explanation.py +++ b/ocean/cp/_explanation.py @@ -76,10 +76,11 @@ def format_value(self, f: int, idx: int, levels: list[float]) -> float: if self.query.shape[0] == 0: return float(levels[idx] + levels[idx + 1]) / 2 j = 0 - while self.query[f] > levels[j + 1]: + query_arr = np.asarray(self.query, dtype=float).ravel() + while query_arr[f] > levels[j + 1]: j += 1 if j == idx: - value = float(self.query[f]) + value = float(query_arr[f]) elif j < idx: value = float(levels[idx]) + self._epsilon else: diff --git a/ocean/cp/_model.py b/ocean/cp/_model.py index 46d8015..a7870b3 100644 --- a/ocean/cp/_model.py +++ b/ocean/cp/_model.py @@ -119,6 +119,7 @@ def _add_objective(self, x: Array1D, norm: int) -> cp.ObjLinearExprT: if norm != 1: msg = f"Unsupported norm: {norm}" raise ValueError(msg) + x_arr = np.asarray(x, dtype=float).ravel() variables = self.mapper.values() objective: cp.LinearExpr = 0 # type: ignore[assignment] @@ -126,10 +127,10 @@ def _add_objective(self, x: Array1D, norm: int) -> cp.ObjLinearExprT: for v in variables: if v.is_one_hot_encoded: for code in v.codes: - objective += self.L1(x[k], v, code=code) + objective += self.L1(x_arr[k], v, code=code) k += 1 else: - objective += self.L1(x[k], v) + objective += self.L1(x_arr[k], v) k += 1 return objective diff --git a/ocean/mip/_explainer.py b/ocean/mip/_explainer.py index 653c040..3373d8a 100644 --- a/ocean/mip/_explainer.py +++ b/ocean/mip/_explainer.py @@ -105,17 +105,17 @@ def explain( else: self.optimize() status = self.get_solving_status() - + if status == "INFEASIBLE": msg = "There are no feasible counterfactuals for this query." msg += " If there should be one, please check the model " msg += "constraints or report this issue to the developers." warnings.warn(msg, category=UserWarning, stacklevel=2) return None - elif status != "OPTIMAL": + if status != "OPTIMAL": if self.SolCount > 0: - msg = "A valid CF was found, but it might be " - msg += "suboptimal as the MILP " + msg = "A valid CF was found, but it might be " + msg += "suboptimal as the MILP " msg += "solver could not prove optimality within " msg += "the given time frame. \n It can however certify" msg += " that no counterfactual can be closer than" @@ -138,7 +138,6 @@ def explain( msg += " valid CF for an un-handled reason." msg += "Unexpected solver status: " + status raise RuntimeError(msg) - return self.explanation @staticmethod diff --git a/tests/test_explainer.py b/tests/test_explainer.py index c08462e..a34bd19 100644 --- a/tests/test_explainer.py +++ b/tests/test_explainer.py @@ -19,7 +19,7 @@ def test_mip_explain( max_depth: int, n_classes: int, n_samples: int, - num_workers:int, + num_workers: int, ) -> None: data, y, mapper = generate_data(seed, n_samples, n_classes) clf = RandomForestClassifier( @@ -30,18 +30,18 @@ def test_mip_explain( clf.fit(data, y) model = MixedIntegerProgramExplainer(clf, mapper=mapper, env=ENV) - x = data.iloc[0, :].to_numpy().astype(float).flatten() + x = data.iloc[0, :].to_numpy().astype(float).flatten() # pyright: ignore[reportUnknownVariableType] try: - model.explain(x, y=0, norm=1, - num_workers=num_workers, + model.explain(x, y=0, norm=1, + num_workers=num_workers, random_seed=seed) assert model.Status == gp.GRB.OPTIMAL model.cleanup() - model.explain(x, y=0, norm=1, - return_callback=True, - num_workers=num_workers, + model.explain(x, y=0, norm=1, + return_callback=True, + num_workers=num_workers, random_seed=seed) assert len(model.callback.sollist) != 0 @@ -72,19 +72,19 @@ def test_cp_explain( clf.fit(data, y) model = ConstraintProgrammingExplainer(clf, mapper=mapper) - x = data.iloc[0, :].to_numpy().astype(float).flatten() + x = data.iloc[0, :].to_numpy().astype(float).flatten() # pyright: ignore[reportUnknownVariableType] try: - _ = model.explain(x, y=0, norm=1, + _ = model.explain(x, y=0, norm=1, return_callback=False, - num_workers=num_workers, + num_workers=num_workers, random_seed=seed) assert model.callback is None or len(model.callback.sollist) == 0 model.cleanup() - _ = model.explain(x, y=0, norm=1, - return_callback=True, - num_workers=num_workers, + _ = model.explain(x, y=0, norm=1, + return_callback=True, + num_workers=num_workers, random_seed=seed) assert model.callback is None or len(model.callback.sollist) != 0 except gp.GurobiError as e: