Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ ignore_missing_imports = True

[mypy-sklearn.*]
ignore_missing_imports = True

[mypy-pysat.*]
ignore_missing_imports = True
5 changes: 4 additions & 1 deletion ocean/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from . import abc, cp, datasets, feature, mip, tree
from . import abc, cp, datasets, feature, maxsat, mip, tree

MixedIntegerProgramExplainer = mip.Explainer
ConstraintProgrammingExplainer = cp.Explainer
MaxSATExplainer = maxsat.Explainer

__all__ = [
"ConstraintProgrammingExplainer",
"MaxSATExplainer",
"MixedIntegerProgramExplainer",
"abc",
"cp",
"datasets",
"feature",
"maxsat",
"mip",
"tree",
]
3 changes: 3 additions & 0 deletions ocean/maxsat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ._explainer import Explainer

__all__ = ["Explainer"]
25 changes: 25 additions & 0 deletions ocean/maxsat/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from abc import ABC
from typing import Any, Protocol

from pysat.formula import WCNF


class BaseModel(ABC, WCNF):
def __init__(self) -> None:
WCNF.__init__(self)

def __setattr__(self, name: str, value: Any) -> None: # noqa: ANN401
object.__setattr__(self, name, value)

def build_vars(self, *variables: "Var") -> None:
for variable in variables:
variable.build(model=self)


class Var(Protocol):
_name: str

def __init__(self, name: str) -> None:
self._name = name

def build(self, model: BaseModel) -> None: ...
Empty file.
145 changes: 145 additions & 0 deletions ocean/maxsat/_builder/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from collections.abc import Iterable
from typing import Protocol

from ...abc import Mapper
from ...tree._node import Node
from .._base import BaseModel
from .._variables import FeatureVar, TreeVar


class ModelBuilder(Protocol):
def build(
self,
model: BaseModel,
*,
trees: Iterable[TreeVar],
mapper: Mapper[FeatureVar],
) -> None:
"""
Build the model constraints for the given trees and features.

Parameters
----------
model : BaseModel
The model to which the constraints will be added.
trees : tuple[TreeVar, ...]
The tree variables for which the constraints will be built.
mapper : Mapper[FeatureVar]
The feature variables for which the constraints will be built.

"""
raise NotImplementedError


class MaxSATBuilder(ModelBuilder):
def build(
self,
model: BaseModel,
*,
trees: Iterable[TreeVar],
mapper: Mapper[FeatureVar],
) -> None:
for tree in trees:
self._build(model, tree=tree, mapper=mapper)

def _build(
self,
model: BaseModel,
*,
tree: TreeVar,
mapper: Mapper[FeatureVar],
) -> None:
for leaf in tree.leaves:
self._build_path(model, tree=tree, leaf=leaf, mapper=mapper)

def _build_path(
self,
model: BaseModel,
*,
tree: TreeVar,
leaf: Node,
mapper: Mapper[FeatureVar],
) -> None:
y = tree[leaf.node_id]
self._propagate(model, node=leaf, mapper=mapper, y=y)

def _propagate(
self,
model: BaseModel,
*,
node: Node,
mapper: Mapper[FeatureVar],
y: object,
) -> None:
parent = node.parent
if parent is None:
return
v = mapper[parent.feature]
self._expand(model, node=parent, y=y, v=v, sigma=node.sigma)
self._propagate(model, node=parent, mapper=mapper, y=y)

def _expand(
self,
model: BaseModel,
*,
node: Node,
y: object,
v: FeatureVar,
sigma: bool,
) -> None:
if v.is_binary:
self._bset(model, y=y, v=v, sigma=sigma)
elif v.is_continuous:
self._cset(model, node=node, y=y, v=v, sigma=sigma)
elif v.is_discrete:
self._dset(model, node=node, y=y, v=v, sigma=sigma)
elif v.is_one_hot_encoded:
self._eset(model, node=node, y=y, v=v, sigma=sigma)

@staticmethod
def _bset(
model: BaseModel,
*,
y: object,
v: FeatureVar,
sigma: bool,
) -> None:
msg = "Raise NotImplementedError"
raise NotImplementedError(msg)

@staticmethod
def _cset(
model: BaseModel,
*,
node: Node,
y: object,
v: FeatureVar,
sigma: bool,
) -> None:
raise NotImplementedError

@staticmethod
def _dset(
model: BaseModel,
*,
node: Node,
y: object,
v: FeatureVar,
sigma: bool,
) -> None:
raise NotImplementedError

@staticmethod
def _eset(
model: BaseModel,
*,
node: Node,
y: object,
v: FeatureVar,
sigma: bool,
) -> None:
raise NotImplementedError


class ModelBuilderFactory:
MAXSAT: type[MaxSATBuilder] = MaxSATBuilder
66 changes: 66 additions & 0 deletions ocean/maxsat/_explainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from ..tree import parse_ensembles
from ..typing import (
Array1D,
BaseExplainableEnsemble,
BaseExplainer,
NonNegativeInt,
PositiveInt,
)
from ._model import Model
from ._solver import MaxSATSolver

if TYPE_CHECKING:
from ..abc import Mapper
from ..feature import Feature
from ._explanation import Explanation


class Explainer(Model, BaseExplainer):
def __init__(
self,
ensemble: BaseExplainableEnsemble,
*,
mapper: Mapper[Feature],
weights: Array1D | None = None,
epsilon: int = Model.DEFAULT_EPSILON,
model_type: Model.Type = Model.Type.MAXSAT,
) -> None:
ensembles = (ensemble,)
trees = parse_ensembles(*ensembles, mapper=mapper)
Model.__init__(
self,
trees,
mapper=mapper,
weights=weights,
epsilon=epsilon,
model_type=model_type,
)
self.build()
self.solver = MaxSATSolver

def get_objective_value(self) -> float:
raise NotImplementedError

def get_solving_status(self) -> str:
raise NotImplementedError

def get_anytime_solutions(self) -> list[dict[str, float]] | None:
raise NotImplementedError

def explain(
self,
x: Array1D,
*,
y: NonNegativeInt,
norm: PositiveInt,
return_callback: bool = False,
verbose: bool = False,
max_time: int = 60,
num_workers: int | None = None,
random_seed: int = 42,
) -> Explanation | None:
raise NotImplementedError
97 changes: 97 additions & 0 deletions ocean/maxsat/_explanation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING

import numpy as np

from ..abc import Mapper
from ..typing import Array1D, BaseExplanation, Key, Number
from ._variables import FeatureVar

if TYPE_CHECKING:
import pandas as pd


class Explanation(Mapper[FeatureVar], BaseExplanation):
_epsilon: float = float(np.finfo(np.float32).eps)
_x: Array1D = np.zeros((0,), dtype=int)

def vget(self, i: int) -> int:
msg = "Not implemented."
raise NotImplementedError(msg)

def to_series(self) -> "pd.Series[float]":
msg = "Not implemented."
raise NotImplementedError(msg)

def to_numpy(self) -> Array1D:
return (
self.to_series()
.to_frame()
.T[self.columns]
.to_numpy()
.flatten()
.astype(np.float64)
)

@property
def x(self) -> Array1D:
return self.to_numpy()

@property
def value(self) -> Mapping[Key, Key | Number]:
msg = "Not implemented."
raise NotImplementedError(msg)

def format_value(
self,
f: int,
idx: int,
levels: list[float],
) -> float:
if self.query.shape[0] == 0:
return float(levels[idx] + levels[idx + 1]) / 2
j = 0
query_arr = np.asarray(self.query, dtype=float).ravel()
while query_arr[f] > levels[j + 1]:
j += 1
if j == idx:
value = float(query_arr[f])
elif j < idx:
value = float(levels[idx]) + self._epsilon
else:
value = float(levels[idx + 1]) - self._epsilon
return value

def format_discrete_value(
self,
f: int,
val: int,
thresholds: Array1D,
) -> float:
if self.query.shape[0] == 0:
return val
query_arr = np.asarray(self.query, dtype=float).ravel()
j_x = np.searchsorted(thresholds, query_arr[f], side="left")
j_val = np.searchsorted(thresholds, val, side="left")
if j_x != j_val:
return float(val)
return float(query_arr[f])

@property
def query(self) -> Array1D:
return self._x

@query.setter
def query(self, value: Array1D) -> None:
self._x = value

def __repr__(self) -> str:
mapping = self.value
prefix = f"{self.__class__.__name__}:\n"
root = self._repr(mapping)
suffix = ""

return prefix + root + suffix


__all__ = ["Explanation"]
9 changes: 9 additions & 0 deletions ocean/maxsat/_managers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ._feature import FeatureManager
from ._garbage import GarbageManager
from ._tree import TreeManager

__all__ = [
"FeatureManager",
"GarbageManager",
"TreeManager",
]
Loading