diff --git a/mypy.ini b/mypy.ini index afd5a9c..29004c1 100644 --- a/mypy.ini +++ b/mypy.ini @@ -10,7 +10,7 @@ warn_return_any = True warn_unreachable = True warn_unused_configs = True warn_unused_ignores = True -plugins = pydantic.mypy, numpy.typing.mypy_plugin +plugins = pydantic.mypy [mypy-anytree] ignore_missing_imports = True diff --git a/ocean/mip/_builders/model.py b/ocean/mip/_builders/model.py index eb2aa60..0ad8a4b 100644 --- a/ocean/mip/_builders/model.py +++ b/ocean/mip/_builders/model.py @@ -140,7 +140,7 @@ def _cset( # :: mu[j] <= 1 - flow[node.left], # :: mu[j] >= epsilon * flow[node.right]. - epsilon = self._find_best_epsilon(model, var) + epsilon = self._find_best_epsilon(model, var, self._epsilon) threshold = node.threshold j = int(np.searchsorted(var.levels, threshold)) @@ -227,9 +227,14 @@ def _eset( model.addConstr(x >= tree[node.right.node_id]) @staticmethod - def _find_best_epsilon(model: BaseModel, var: FeatureVar) -> float: + def _find_best_epsilon( + model: BaseModel, + var: FeatureVar, + epsilon: float, + ) -> float: # Find the best epsilon value for the given feature variable. - # This + # This is done by finding the minimum difference between + # the split levels and the tolerance of the solver. tol: float = model.getParamInfo("FeasibilityTol")[2] min_tol: float = 1e-9 delta: float = min(*np.diff(var.levels)) @@ -242,6 +247,8 @@ def _find_best_epsilon(model: BaseModel, var: FeatureVar) -> float: msg += " Consider not scaling the data or using bigger intervals." warnings.warn(msg, category=UserWarning, stacklevel=2) return eps + if delta * epsilon > tol: + return epsilon while 2 * tol / delta >= 1.0: tol /= 2 feas_tol = model.getParamInfo("FeasibilityTol")[2] diff --git a/tests/mip/model/test_objective.py b/tests/mip/model/test_objective.py index 2b5eade..6aab3aa 100644 --- a/tests/mip/model/test_objective.py +++ b/tests/mip/model/test_objective.py @@ -130,7 +130,6 @@ def test_set_majority_class( assert model.Status == gp.GRB.OPTIMAL explanation = model.explanation - validate_solution(explanation) validate_paths(*model.trees, explanation=explanation) validate_sklearn_paths(clf, explanation, model.trees)