From b64ab0f14db3fe67cecd3f100273def2ccc1adc1 Mon Sep 17 00:00:00 2001 From: Awa Khouna Date: Mon, 22 Sep 2025 14:12:53 -0400 Subject: [PATCH] Optimizes status management and improves epsilon accuracy in the Explanation class --- ocean/cp/_explainer.py | 9 ++++----- ocean/cp/_explanation.py | 7 +++---- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/ocean/cp/_explainer.py b/ocean/cp/_explainer.py index 3b350e0..3db26c7 100644 --- a/ocean/cp/_explainer.py +++ b/ocean/cp/_explainer.py @@ -112,14 +112,13 @@ def explain( case _: msg = "Unexpected solver status: " + status raise RuntimeError(msg) - + if not cf_status_ok: self.cleanup() return None - else: - self.explanation.query = x - self.cleanup() - return self.explanation + self.explanation.query = x + self.cleanup() + return self.explanation class MySolCallback(cp.CpSolverSolutionCallback): diff --git a/ocean/cp/_explanation.py b/ocean/cp/_explanation.py index 6495d88..e030b17 100644 --- a/ocean/cp/_explanation.py +++ b/ocean/cp/_explanation.py @@ -11,7 +11,7 @@ class Explanation(Mapper[FeatureVar], BaseExplanation): - _epsilon: float = 1e-6 + _epsilon: float = float(np.finfo(np.float32).eps) _x: Array1D = np.zeros((0,), dtype=int) def vget(self, i: int) -> cp.IntVar: @@ -73,7 +73,6 @@ def get(v: FeatureVar) -> Key | Number: return self.reduce(get) def format_value(self, f: int, idx: int, levels: list[float]) -> float: - eps = min(self._epsilon, 0.5 * min(np.diff(levels))) if self.query.shape[0] == 0: return float(levels[idx] + levels[idx + 1]) / 2 j = 0 @@ -83,9 +82,9 @@ def format_value(self, f: int, idx: int, levels: list[float]) -> float: if j == idx: value = float(query_arr[f]) elif j < idx: - value = float(levels[idx]) + eps + value = float(levels[idx]) + self._epsilon else: - value = float(levels[idx + 1]) - eps + value = float(levels[idx + 1]) - self._epsilon return value @property