diff --git a/mypy.ini b/mypy.ini
index 29004c1..4b4ed49 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -17,3 +17,6 @@ ignore_missing_imports = True
 
 [mypy-sklearn.*]
 ignore_missing_imports = True
+
+[mypy-pysat.*]
+ignore_missing_imports = True
diff --git a/ocean/__init__.py b/ocean/__init__.py
index d859e99..ac7e2eb 100644
--- a/ocean/__init__.py
+++ b/ocean/__init__.py
@@ -1,15 +1,18 @@
-from . import abc, cp, datasets, feature, mip, tree
+from . import abc, cp, datasets, feature, maxsat, mip, tree
 
 MixedIntegerProgramExplainer = mip.Explainer
 ConstraintProgrammingExplainer = cp.Explainer
+MaxSATExplainer = maxsat.Explainer
 
 __all__ = [
     "ConstraintProgrammingExplainer",
+    "MaxSATExplainer",
     "MixedIntegerProgramExplainer",
     "abc",
     "cp",
     "datasets",
     "feature",
+    "maxsat",
     "mip",
     "tree",
 ]
diff --git a/ocean/maxsat/__init__.py b/ocean/maxsat/__init__.py
new file mode 100644
index 0000000..6915f01
--- /dev/null
+++ b/ocean/maxsat/__init__.py
@@ -0,0 +1,3 @@
+from ._explainer import Explainer
+
+__all__ = ["Explainer"]
diff --git a/ocean/maxsat/_base.py b/ocean/maxsat/_base.py
new file mode 100644
index 0000000..4a17993
--- /dev/null
+++ b/ocean/maxsat/_base.py
@@ -0,0 +1,31 @@
+from abc import ABC
+from typing import Any, Protocol
+
+from pysat.formula import WCNF
+
+
+class BaseModel(ABC, WCNF):
+    """Weighted CNF formula on which model variables are built."""
+
+    def __init__(self) -> None:
+        WCNF.__init__(self)
+
+    def __setattr__(self, name: str, value: Any) -> None:  # noqa: ANN401
+        # Plain attribute assignment, delegated to ``object``.
+        object.__setattr__(self, name, value)
+
+    def build_vars(self, *variables: "Var") -> None:
+        """Call ``build(model=self)`` on each variable, in order."""
+        for var in variables:
+            var.build(model=self)
+
+
+class Var(Protocol):
+    """Named object that can be built on a ``BaseModel``."""
+
+    _name: str
+
+    def __init__(self, name: str) -> None:
+        self._name = name
+
+    def build(self, model: BaseModel) -> None: ...
diff --git a/ocean/maxsat/_builder/__init__.py b/ocean/maxsat/_builder/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/ocean/maxsat/_builder/model.py b/ocean/maxsat/_builder/model.py
new file mode 100644
index 0000000..35100fe
--- /dev/null
+++ b/ocean/maxsat/_builder/model.py
@@ -0,0 +1,149 @@
+from collections.abc import Iterable
+from typing import Protocol
+
+from ...abc import Mapper
+from ...tree._node import Node
+from .._base import BaseModel
+from .._variables import FeatureVar, TreeVar
+
+
+class ModelBuilder(Protocol):
+    def build(
+        self,
+        model: BaseModel,
+        *,
+        trees: Iterable[TreeVar],
+        mapper: Mapper[FeatureVar],
+    ) -> None:
+        """
+        Build the model constraints for the given trees and features.
+
+        Parameters
+        ----------
+        model : BaseModel
+            The model to which the constraints will be added.
+        trees : Iterable[TreeVar]
+            The tree variables for which the constraints will be built.
+        mapper : Mapper[FeatureVar]
+            The feature variables for which the constraints will be built.
+
+        """
+        raise NotImplementedError
+
+
+class MaxSATBuilder(ModelBuilder):
+    """Visit every leaf of every tree and expand each ancestor split."""
+
+    def build(
+        self,
+        model: BaseModel,
+        *,
+        trees: Iterable[TreeVar],
+        mapper: Mapper[FeatureVar],
+    ) -> None:
+        for tree in trees:
+            self._build(model, tree=tree, mapper=mapper)
+
+    def _build(
+        self,
+        model: BaseModel,
+        *,
+        tree: TreeVar,
+        mapper: Mapper[FeatureVar],
+    ) -> None:
+        for leaf in tree.leaves:
+            self._build_path(model, tree=tree, leaf=leaf, mapper=mapper)
+
+    def _build_path(
+        self,
+        model: BaseModel,
+        *,
+        tree: TreeVar,
+        leaf: Node,
+        mapper: Mapper[FeatureVar],
+    ) -> None:
+        y = tree[leaf.node_id]
+        self._propagate(model, node=leaf, mapper=mapper, y=y)
+
+    def _propagate(
+        self,
+        model: BaseModel,
+        *,
+        node: Node,
+        mapper: Mapper[FeatureVar],
+        y: object,
+    ) -> None:
+        """Expand each ancestor split of ``node``, from parent to root."""
+        parent = node.parent
+        if parent is None:
+            return
+        v = mapper[parent.feature]
+        self._expand(model, node=parent, y=y, v=v, sigma=node.sigma)
+        self._propagate(model, node=parent, mapper=mapper, y=y)
+
+    def _expand(
+        self,
+        model: BaseModel,
+        *,
+        node: Node,
+        y: object,
+        v: FeatureVar,
+        sigma: bool,
+    ) -> None:
+        """Dispatch to the setter matching the feature type of ``v``."""
+        if v.is_binary:
+            self._bset(model, y=y, v=v, sigma=sigma)
+        elif v.is_continuous:
+            self._cset(model, node=node, y=y, v=v, sigma=sigma)
+        elif v.is_discrete:
+            self._dset(model, node=node, y=y, v=v, sigma=sigma)
+        elif v.is_one_hot_encoded:
+            self._eset(model, node=node, y=y, v=v, sigma=sigma)
+
+    @staticmethod
+    def _bset(
+        model: BaseModel,
+        *,
+        y: object,
+        v: FeatureVar,
+        sigma: bool,
+    ) -> None:
+        msg = "Raise NotImplementedError"
+        raise NotImplementedError(msg)
+
+    @staticmethod
+    def _cset(
+        model: BaseModel,
+        *,
+        node: Node,
+        y: object,
+        v: FeatureVar,
+        sigma: bool,
+    ) -> None:
+        raise NotImplementedError
+
+    @staticmethod
+    def _dset(
+        model: BaseModel,
+        *,
+        node: Node,
+        y: object,
+        v: FeatureVar,
+        sigma: bool,
+    ) -> None:
+        raise NotImplementedError
+
+    @staticmethod
+    def _eset(
+        model: BaseModel,
+        *,
+        node: Node,
+        y: object,
+        v: FeatureVar,
+        sigma: bool,
+    ) -> None:
+        raise NotImplementedError
+
+
+class ModelBuilderFactory:
+    MAXSAT: type[MaxSATBuilder] = MaxSATBuilder
diff --git a/ocean/maxsat/_explainer.py b/ocean/maxsat/_explainer.py
new file mode 100644
index 0000000..fcec0b8
--- /dev/null
+++ b/ocean/maxsat/_explainer.py
@@ -0,0 +1,70 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from ..tree import parse_ensembles
+from ..typing import (
+    Array1D,
+    BaseExplainableEnsemble,
+    BaseExplainer,
+    NonNegativeInt,
+    PositiveInt,
+)
+from ._model import Model
+from ._solver import MaxSATSolver
+
+if TYPE_CHECKING:
+    from ..abc import Mapper
+    from ..feature import Feature
+    from ._explanation import Explanation
+
+
+class Explainer(Model, BaseExplainer):
+    """MaxSAT-based explainer for a single tree ensemble."""
+
+    def __init__(
+        self,
+        ensemble: BaseExplainableEnsemble,
+        *,
+        mapper: Mapper[Feature],
+        weights: Array1D | None = None,
+        epsilon: int = Model.DEFAULT_EPSILON,
+        model_type: Model.Type = Model.Type.MAXSAT,
+    ) -> None:
+        ensembles = (ensemble,)
+        trees = parse_ensembles(*ensembles, mapper=mapper)
+        Model.__init__(
+            self,
+            trees,
+            mapper=mapper,
+            weights=weights,
+            epsilon=epsilon,
+            model_type=model_type,
+        )
+        self.build()
+        # Hold a solver instance: ``MaxSATSolver.solve`` is an instance
+        # method and cannot be called on the class itself.
+        self.solver = MaxSATSolver()
+
+    def get_objective_value(self) -> float:
+        raise NotImplementedError
+
+    def get_solving_status(self) -> str:
+        raise NotImplementedError
+
+    def get_anytime_solutions(self) -> list[dict[str, float]] | None:
+        raise NotImplementedError
+
+    def explain(
+        self,
+        x: Array1D,
+        *,
+        y: NonNegativeInt,
+        norm: PositiveInt,
+        return_callback: bool = False,
+        verbose: bool = False,
+        max_time: int = 60,
+        num_workers: int | None = None,
+        random_seed: int = 42,
+    ) -> Explanation | None:
+        raise NotImplementedError
diff --git a/ocean/maxsat/_explanation.py b/ocean/maxsat/_explanation.py
new file mode 100644
index 0000000..b3da47d
--- /dev/null
+++ b/ocean/maxsat/_explanation.py
@@ -0,0 +1,110 @@
+from collections.abc import Mapping
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+from ..abc import Mapper
+from ..typing import Array1D, BaseExplanation, Key, Number
+from ._variables import FeatureVar
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+
+class Explanation(Mapper[FeatureVar], BaseExplanation):
+    """Mapping of feature variables together with the explained query."""
+
+    _epsilon: float = float(np.finfo(np.float32).eps)
+    _x: Array1D = np.zeros((0,), dtype=int)
+
+    def vget(self, i: int) -> int:
+        msg = "Not implemented."
+        raise NotImplementedError(msg)
+
+    def to_series(self) -> "pd.Series[float]":
+        msg = "Not implemented."
+        raise NotImplementedError(msg)
+
+    def to_numpy(self) -> Array1D:
+        return (
+            self.to_series()
+            .to_frame()
+            .T[self.columns]
+            .to_numpy()
+            .flatten()
+            .astype(np.float64)
+        )
+
+    @property
+    def x(self) -> Array1D:
+        return self.to_numpy()
+
+    @property
+    def value(self) -> Mapping[Key, Key | Number]:
+        msg = "Not implemented."
+        raise NotImplementedError(msg)
+
+    def format_value(
+        self,
+        f: int,
+        idx: int,
+        levels: list[float],
+    ) -> float:
+        """
+        Return a value for feature ``f`` lying in interval ``idx``.
+
+        Without a query the interval midpoint is returned. Otherwise the
+        query value is kept when it already lies in the interval, else the
+        bound of the interval nearest to the query is used, nudged inwards.
+
+        """
+        if self.query.shape[0] == 0:
+            return float(levels[idx] + levels[idx + 1]) / 2
+        query_arr = np.asarray(self.query, dtype=float).ravel()
+        # Never step past the last interval, so a query above every level
+        # is clamped instead of indexing beyond the end of ``levels``.
+        j, last = 0, len(levels) - 2
+        while j < last and query_arr[f] > levels[j + 1]:
+            j += 1
+        if j == idx:
+            value = float(query_arr[f])
+        elif j < idx:
+            value = float(levels[idx]) + self._epsilon
+        else:
+            value = float(levels[idx + 1]) - self._epsilon
+        return value
+
+    def format_discrete_value(
+        self,
+        f: int,
+        val: int,
+        thresholds: Array1D,
+    ) -> float:
+        """Return the query value if it shares ``val``'s bin, else ``val``."""
+        if self.query.shape[0] == 0:
+            return float(val)
+        query_arr = np.asarray(self.query, dtype=float).ravel()
+        j_x = np.searchsorted(thresholds, query_arr[f], side="left")
+        j_val = np.searchsorted(thresholds, val, side="left")
+        if j_x != j_val:
+            return float(val)
+        return float(query_arr[f])
+
+    @property
+    def query(self) -> Array1D:
+        return self._x
+
+    @query.setter
+    def query(self, value: Array1D) -> None:
+        self._x = value
+
+    def __repr__(self) -> str:
+        mapping = self.value
+        prefix = f"{self.__class__.__name__}:\n"
+        root = self._repr(mapping)
+        suffix = ""
+
+        return prefix + root + suffix
+
+
+__all__ = ["Explanation"]
diff --git a/ocean/maxsat/_managers/__init__.py b/ocean/maxsat/_managers/__init__.py
new file mode 100644
index 0000000..450d21d
--- /dev/null
+++ b/ocean/maxsat/_managers/__init__.py
@@ -0,0 +1,9 @@
+from ._feature import FeatureManager
+from ._garbage import GarbageManager
+from ._tree import TreeManager
+
+__all__ = [
+    "FeatureManager",
+    "GarbageManager",
+    "TreeManager",
+]
diff --git a/ocean/maxsat/_managers/_feature.py b/ocean/maxsat/_managers/_feature.py
new file mode 100644
index 0000000..9d57ec1
--- /dev/null
+++ b/ocean/maxsat/_managers/_feature.py
@@ -0,0 +1,51 @@
+from ...abc import Mapper
+from ...feature import Feature
+from ...typing import (
+    Key,
+    PositiveInt,
+)
+from .._base import BaseModel
+from .._explanation import Explanation
+from .._variables import FeatureVar
+
+
+class FeatureManager:
+    FEATURE_VAR_FMT: str = "feature[{key}]"
+
+    _mapper: Explanation
+
+    def __init__(self, mapper: Mapper[Feature]) -> None:
+        self._set_mapper(mapper)
+
+    def build_features(self, model: BaseModel) -> None:
+        model.build_vars(*self.mapper.values())
+
+    @property
+    def n_columns(self) -> PositiveInt:
+        return self.mapper.n_columns
+
+    @property
+    def n_features(self) -> PositiveInt:
+        return len(self.mapper)
+
+    @property
+    def mapper(self) -> Explanation:
+        return self._mapper
+
+    @property
+    def explanation(self) -> Explanation:
+        return self.mapper
+
+    def vget(self, i: int) -> int:
+        raise NotImplementedError
+
+    def _set_mapper(self, mapper: Mapper[Feature]) -> None:
+        def create(key: Key, feature: Feature) -> FeatureVar:
+            name = self.FEATURE_VAR_FMT.format(key=key)
+            return FeatureVar(feature, name=name)
+
+        if len(mapper) == 0:
+            msg = "At least one feature is required."
+            raise ValueError(msg)
+
+        self._mapper = Explanation(mapper.apply(create))
diff --git a/ocean/maxsat/_managers/_garbage.py b/ocean/maxsat/_managers/_garbage.py
new file mode 100644
index 0000000..b8b21c0
--- /dev/null
+++ b/ocean/maxsat/_managers/_garbage.py
@@ -0,0 +1,16 @@
+class GarbageManager:
+    type GarbageObject = object
+
+    # Garbage collector for the model.
+    # - Used to keep track of the variables and constraints created,
+    #   and to remove them when the model is cleared.
+    _garbage: list[GarbageObject]
+
+    def __init__(self) -> None:
+        self._garbage = []
+
+    def add_garbage(self, *args: GarbageObject) -> None:
+        self._garbage.extend(args)
+
+    def remove_garbage(self) -> None:
+        raise NotImplementedError
diff --git a/ocean/maxsat/_managers/_tree.py b/ocean/maxsat/_managers/_tree.py
new file mode 100644
index 0000000..e3c1d74
--- /dev/null
+++ b/ocean/maxsat/_managers/_tree.py
@@ -0,0 +1,100 @@
+from collections.abc import Iterable
+
+import numpy as np
+
+from ...tree import Tree
+from ...typing import (
+    NonNegativeArray1D,
+    NonNegativeInt,
+    PositiveInt,
+)
+from .._base import BaseModel
+from .._variables import TreeVar
+
+
+class TreeManager:
+    TREE_VAR_FMT: str = "tree[{t}]"
+
+    # Tree variables in the ensemble.
+    _trees: tuple[TreeVar, *tuple[TreeVar, ...]]
+
+    # Weights for the estimators in the ensemble.
+    _weights: NonNegativeArray1D
+
+    def __init__(
+        self,
+        trees: Iterable[Tree],
+        *,
+        weights: NonNegativeArray1D | None = None,
+    ) -> None:
+        self._set_trees(trees=trees)
+        self._set_weights(weights=weights)
+
+    def build_trees(self, model: BaseModel) -> None:
+        model.build_vars(*self.trees)
+        self._function = self._get_function()
+
+    @property
+    def n_trees(self) -> PositiveInt:
+        return len(self.trees)
+
+    @property
+    def n_estimators(self) -> PositiveInt:
+        return self.n_trees
+
+    @property
+    def trees(self) -> tuple[TreeVar, *tuple[TreeVar, ...]]:
+        return self._trees
+
+    @property
+    def estimators(self) -> tuple[TreeVar, *tuple[TreeVar, ...]]:
+        return self._trees[0], *self._trees[1 : self.n_estimators]
+
+    @property
+    def shape(self) -> tuple[NonNegativeInt, ...]:
+        return self._trees[0].shape
+
+    @property
+    def n_classes(self) -> NonNegativeInt:
+        return self.shape[-1]
+
+    @property
+    def weights(self) -> NonNegativeArray1D:
+        return self._weights
+
+    def _set_trees(
+        self,
+        trees: Iterable[Tree],
+    ) -> None:
+        def create(item: tuple[int, Tree]) -> TreeVar:
+            t, tree = item
+            name = self.TREE_VAR_FMT.format(t=t)
+            return TreeVar(tree, name=name)
+
+        tree_vars = tuple(map(create, enumerate(trees)))
+        if len(tree_vars) == 0:
+            msg = "At least one tree is required."
+            raise ValueError(msg)
+
+        self._trees = tree_vars[0], *tree_vars[1:]
+
+    def _set_weights(self, weights: NonNegativeArray1D | None = None) -> None:
+        if weights is None:
+            weights = np.ones(self.n_estimators, dtype=np.float64)
+
+        if len(weights) != self.n_estimators:
+            msg = "The number of weights must match the number of trees."
+            raise ValueError(msg)
+
+        self._weights = weights
+
+    def weighted_function(
+        self,
+        weights: NonNegativeArray1D,
+    ) -> dict[tuple[NonNegativeInt, NonNegativeInt], object]:
+        raise NotImplementedError
+
+    def _get_function(
+        self,
+    ) -> dict[tuple[NonNegativeInt, NonNegativeInt], object]:
+        return self.weighted_function(weights=self.weights)
diff --git a/ocean/maxsat/_model.py b/ocean/maxsat/_model.py
new file mode 100644
index 0000000..9af8be6
--- /dev/null
+++ b/ocean/maxsat/_model.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import TYPE_CHECKING
+
+from ._base import BaseModel
+from ._builder.model import ModelBuilder, ModelBuilderFactory
+from ._managers import FeatureManager, GarbageManager, TreeManager
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
+    from ..abc import Mapper
+    from ..feature import Feature
+    from ..tree import Tree
+    from ..typing import NonNegativeArray1D, NonNegativeInt
+
+
+@dataclass
+class Model(BaseModel, FeatureManager, TreeManager, GarbageManager):
+    """WCNF model combining the feature, tree and garbage managers."""
+
+    DEFAULT_EPSILON: int = 1
+
+    # Model builder for the ensemble.
+    _builder: ModelBuilder | None = None
+
+    class Type(Enum):
+        MAXSAT = "MAXSAT"
+
+    def __init__(
+        self,
+        trees: Iterable[Tree],
+        mapper: Mapper[Feature],
+        *,
+        model_type: Type = Type.MAXSAT,
+        weights: NonNegativeArray1D | None = None,
+        max_samples: NonNegativeInt = 0,
+        epsilon: int = DEFAULT_EPSILON,
+    ) -> None:
+        BaseModel.__init__(self)
+        TreeManager.__init__(
+            self,
+            trees=trees,
+            weights=weights,
+        )
+        FeatureManager.__init__(self, mapper=mapper)
+        GarbageManager.__init__(self)
+
+        self._max_samples = max_samples
+        self._epsilon = epsilon
+        self._set_builder(model_type=model_type)
+
+    def build(self) -> None:
+        raise NotImplementedError
+
+    def _set_builder(self, model_type: Type) -> None:
+        match model_type:
+            case Model.Type.MAXSAT:
+                self._builder = ModelBuilderFactory.MAXSAT()
diff --git a/ocean/maxsat/_solver.py b/ocean/maxsat/_solver.py
new file mode 100644
index 0000000..e86b7f4
--- /dev/null
+++ b/ocean/maxsat/_solver.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+from typing import cast
+
+from pysat.examples.rc2 import RC2
+from pysat.formula import WCNF
+
+
+class MaxSATSolver:
+    """Thin RC2 wrapper to keep a stable interface."""
+
+    def __init__(self, solver_name: str = "glucose3") -> None:
+        self.solver_name = solver_name
+
+    @staticmethod
+    def new_wcnf() -> WCNF:
+        """Return an empty weighted CNF formula."""
+        return WCNF()
+
+    @staticmethod
+    def add_hard(w: WCNF, clause: list[int]) -> None:
+        """Add ``clause`` to ``w`` as a hard clause."""
+        # PySAT marks a clause as hard when no weight is given; any
+        # non-zero weight (``-1`` included) would make it soft instead.
+        w.append(clause)
+
+    @staticmethod
+    def add_soft(w: WCNF, clause: list[int], weight: int) -> None:
+        """
+        Add ``clause`` to ``w`` as a soft clause with cost ``weight``.
+
+        Raises
+        ------
+        ValueError
+            If ``weight`` is not strictly positive.
+
+        """
+        # A zero weight is falsy for PySAT and would be stored as hard.
+        if weight <= 0:
+            msg = "Soft clause weight must be strictly positive."
+            raise ValueError(msg)
+        w.append(clause, weight=weight)
+
+    def solve(self, w: WCNF) -> list[int]:
+        """
+        Run RC2 on ``w`` and return the model it computes.
+
+        Raises
+        ------
+        RuntimeError
+            If RC2 returns no model.
+
+        """
+        with RC2(w, solver=self.solver_name, adapt=True) as rc2:
+            model = cast("list[int] | None", rc2.compute())
+            if model is None:
+                msg = "UNSAT: no counterfactual found."
+                raise RuntimeError(msg)
+            return model
diff --git a/ocean/maxsat/_variables/__init__.py b/ocean/maxsat/_variables/__init__.py
new file mode 100644
index 0000000..3ab204d
--- /dev/null
+++ b/ocean/maxsat/_variables/__init__.py
@@ -0,0 +1,4 @@
+from ._feature import FeatureVar
+from ._tree import TreeVar
+
+__all__ = ["FeatureVar", "TreeVar"]
diff --git a/ocean/maxsat/_variables/_feature.py b/ocean/maxsat/_variables/_feature.py
new file mode 100644
index 0000000..3d4dc7d
--- /dev/null
+++ b/ocean/maxsat/_variables/_feature.py
@@ -0,0 +1,56 @@
+from ...feature import Feature
+from ...feature._keeper import FeatureKeeper
+from ...typing import Key
+from .._base import BaseModel, Var
+
+
+class FeatureVar(Var, FeatureKeeper):
+    X_VAR_NAME_FMT: str = "x[{name}]"
+
+    def __init__(self, feature: Feature, name: str) -> None:
+        Var.__init__(self, name=name)
+        FeatureKeeper.__init__(self, feature=feature)
+
+    def build(self, model: BaseModel) -> None:
+        raise NotImplementedError
+
+    def xget(self, code: Key | None = None) -> None:
+        raise NotImplementedError
+
+    def mget(self, key: int) -> None:
+        raise NotImplementedError
+
+    def objvarget(self) -> None:
+        raise NotImplementedError
+
+    def _add_x(self, model: BaseModel) -> None:
+        raise NotImplementedError
+
+    def _add_u(self, model: BaseModel) -> None:
+        raise NotImplementedError
+
+    def _set_mu(self, model: BaseModel, m: int) -> None:
+        raise NotImplementedError
+
+    def _add_mu(self, model: BaseModel) -> None:
+        raise NotImplementedError
+
+    def _add_one_hot_encoded(
+        self,
+        model: BaseModel,
+        name: str,
+    ) -> None:
+        raise NotImplementedError
+
+    @staticmethod
+    def _add_binary(model: BaseModel, name: str) -> None:
+        raise NotImplementedError
+
+    def _add_continuous(self, model: BaseModel, name: str) -> None:
+        raise NotImplementedError
+
+    def _add_discrete(self, model: BaseModel, name: str) -> None:
+        raise NotImplementedError
+
+    def _xget_one_hot_encoded(self, code: Key | None) -> None:
+        raise NotImplementedError
diff --git a/ocean/maxsat/_variables/_tree.py b/ocean/maxsat/_variables/_tree.py
new file mode 100644
index 0000000..2ee57c8
--- /dev/null
+++ b/ocean/maxsat/_variables/_tree.py
@@ -0,0 +1,39 @@
+from collections.abc import Iterator, Mapping
+
+from pydantic import validate_call
+
+from ...tree._keeper import TreeKeeper, TreeLike
+from ...typing import NonNegativeInt
+from .._base import BaseModel, Var
+
+
+class TreeVar(Var, TreeKeeper, Mapping[NonNegativeInt, object]):
+    PATH_VAR_NAME_FMT: str = "{name}_path"
+
+    def __init__(
+        self,
+        tree: TreeLike,
+        name: str,
+    ) -> None:
+        Var.__init__(self, name=name)
+        TreeKeeper.__init__(self, tree=tree)
+
+    def build(self, model: BaseModel) -> None:
+        raise NotImplementedError
+
+    def __len__(self) -> int:
+        return self.n_nodes
+
+    def __iter__(self) -> Iterator[NonNegativeInt]:
+        return iter(range(self.n_nodes))
+
+    @validate_call
+    def __getitem__(self, node_id: NonNegativeInt) -> None:
+        raise NotImplementedError
+
+    def _add_path(
+        self,
+        model: BaseModel,
+        name: str,
+    ) -> None:
+        raise NotImplementedError
diff --git a/pyproject.toml b/pyproject.toml
index a930ea3..cb2defe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,6 +38,7 @@ dependencies = [
   "pydantic",
   "scikit-learn",
   "xgboost",
+  "python-sat",
 ]
 optional-dependencies.dev = [