Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions ocean/cp/_managers/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,16 @@
from ortools.sat.python import cp_model as cp

from ...tree import Tree
from ...typing import (
NonNegativeArray1D,
NonNegativeInt,
PositiveInt,
)
from ...typing import Array1D, NonNegativeArray1D, NonNegativeInt, PositiveInt
from .._base import BaseModel
from .._variables import TreeVar


class TreeManager:
TREE_VAR_FMT: str = "tree[{t}]"
DEFAULT_SCORE_SCALE: int = int(1e10)
XGBOOST_DEFAULT_CLASS: int = 1
NUM_BINARY_CLASS: int = 2

# Tree variables in the ensemble.
_trees: tuple[TreeVar, *tuple[TreeVar, ...]]
Expand All @@ -29,6 +27,12 @@ class TreeManager:
# Scale for the scores.
_score_scale: int = DEFAULT_SCORE_SCALE

# Base score for the ensemble.
_logit: Array1D

# Flag to indicate if the model is using XGBoost trees.
_xgboost: bool = False

def __init__(
self,
trees: Iterable[Tree],
Expand Down Expand Up @@ -89,6 +93,9 @@ def _set_trees(
) -> None:
def create(item: tuple[int, Tree]) -> TreeVar:
t, tree = item
if tree.xgboost:
self._logit = tree.logit
self._xgboost = tree.xgboost
name = self.TREE_VAR_FMT.format(t=t)
return TreeVar(tree, name=name)

Expand Down Expand Up @@ -131,6 +138,14 @@ def weighted_function(
tree_exprs.append(tree_expr)
tree_weights.append(int(weight))
expr = cp.LinearExpr.WeightedSum(tree_exprs, tree_weights)
if self._xgboost:
if (
n_classes == self.NUM_BINARY_CLASS
and c == self.XGBOOST_DEFAULT_CLASS
):
expr += int(self._logit[0] * scale)
elif n_classes > self.NUM_BINARY_CLASS:
expr += int(self._logit[c] * scale)
exprs[op, c] = expr
return exprs

Expand Down
28 changes: 28 additions & 0 deletions ocean/mip/_managers/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ...tree import Tree
from ...tree._utils import average_length
from ...typing import (
Array1D,
NonNegativeArray1D,
NonNegativeInt,
NonNegativeNumber,
Expand All @@ -17,6 +18,7 @@

class TreeManager:
TREE_VAR_FMT: str = "tree[{t}]"
NUM_BINARY_CLASS: int = 2

# Tree variables in the ensemble.
_trees: tuple[TreeVar, *tuple[TreeVar, ...]]
Expand All @@ -36,6 +38,12 @@ class TreeManager:
# Function of the ensemble.
_function: gp.MLinExpr

# Base score for the ensemble.
_logit: Array1D

# Flag to indicate if the model is using XGBoost trees.
_xgboost: bool = False

def __init__(
self,
trees: Iterable[Tree],
Expand Down Expand Up @@ -120,6 +128,21 @@ def weighted(tree: TreeVar, weight: float) -> gp.MLinExpr:
zeros = gp.MLinExpr.zeros(self.shape)
return sum(map(weighted, self.estimators, weights), zeros)

def xgb_margin_function(
    self,
    weights: NonNegativeArray1D,
) -> gp.MLinExpr:
    """Build the ensemble output as an XGBoost-style margin expression.

    Starts from the stored base-score margin (``self._logit``) broadcast
    over the output shape, then adds each estimator tree's contribution
    scaled by its weight.

    :param weights: Per-tree non-negative weights, aligned with
        ``self.estimators`` (``zip(..., strict=True)`` raises on a
        length mismatch).
    :return: Matrix linear expression with shape ``self.shape``.
    """
    # Base margin broadcast across every (sample, class) cell.
    margin_values = gp.MLinExpr.zeros(self.shape) + self._logit
    if self.n_classes == self.NUM_BINARY_CLASS:
        # Binary case: add a weighted copy of the base margin to the
        # positive-class column only (column 1).
        # NOTE(review): the broadcast above already added the margin to
        # both columns, so column 1 receives it twice — once unweighted
        # and once scaled by weights[0]. Confirm this double
        # contribution is intended.
        margin_values += (
            weights[0] * self._logit[0] * np.array([[0.0, 1.0]])
        )

    # Accumulate the weighted per-tree value expressions.
    for tree, weight in zip(self.estimators, weights, strict=True):
        margin_values += weight * tree.value

    return margin_values

def _set_trees(
self,
trees: Iterable[Tree],
Expand All @@ -128,6 +151,9 @@ def _set_trees(
) -> None:
def create(item: tuple[int, Tree]) -> TreeVar:
t, tree = item
if tree.xgboost:
self._logit = tree.logit
self._xgboost = tree.xgboost
name = self.TREE_VAR_FMT.format(t=t)
return TreeVar(tree, name=name, flow_type=flow_type)

Expand All @@ -152,4 +178,6 @@ def _get_length(self) -> gp.LinExpr:
return sum((tree.length for tree in self.isolators), gp.LinExpr())

def _get_function(self) -> gp.MLinExpr:
    """Return the ensemble output expression.

    XGBoost ensembles are routed through ``xgb_margin_function`` so the
    stored base-score margin is included; all other ensembles use the
    plain weighted sum.
    """
    builder = (
        self.xgb_margin_function if self._xgboost else self.weighted_function
    )
    return builder(weights=self.weights)
47 changes: 39 additions & 8 deletions ocean/tree/_parse_xgb.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,31 @@
import json
from collections.abc import Iterable
from typing import Any

import numpy as np
import xgboost as xgb

from ..abc import Mapper
from ..feature import Feature
from ..typing import NonNegativeInt, XGBTree
from ..typing import Array1D, NonNegativeInt, XGBTree
from ._node import Node
from ._tree import Tree


def _parse_base_score(cfg: dict[str, Any]) -> Array1D:
    """Extract the base score(s) from an XGBoost ``save_config`` dict.

    The ``base_score`` field of ``learner_model_param`` is a JSON-encoded
    string: a single number for regression/binary objectives, or a list
    with one entry per class for multi-class objectives.

    :param cfg: Parsed JSON output of ``xgb.Booster.save_config()``.
    :return: 1-D float array of base scores (length 1 in the scalar case).
    """
    parsed = json.loads(
        cfg["learner"]["learner_model_param"]["base_score"]
    )
    # json.loads yields an int for inputs like "1" or "0"; treat any
    # scalar number as the single-score case instead of falling through
    # and trying to iterate over it (which would raise TypeError).
    if isinstance(parsed, (int, float)):
        return np.array([float(parsed)])
    return np.array([float(s) for s in parsed])


def _logit(p: Array1D) -> Array1D:
    """Map probabilities to log-odds, clipping away exact 0 and 1.

    Clipping to ``[1e-12, 1 - 1e-12]`` keeps the log and the division
    finite for degenerate probabilities.
    """
    clipped = np.clip(p, 1e-12, 1 - 1e-12)
    return np.log(clipped / (1 - clipped))


def _get_column_value(
xgb_tree: XGBTree, node_id: NonNegativeInt, column: str
) -> str | float | int:
Expand All @@ -27,12 +43,11 @@ def _build_xgb_leaf(
weight = float(_get_column_value(xgb_tree, node_id, "Gain"))

if num_trees_per_round == 1:
value = np.array([[weight, -weight]])
value = np.array([[0.0, weight]])
else:
k = int(tree_id % num_trees_per_round)
value = np.zeros((1, int(num_trees_per_round)), dtype=float)
value[0, k] = weight

value[0, k] += weight
return Node(node_id, n_samples=0, value=value)


Expand Down Expand Up @@ -90,7 +105,7 @@ def _build_xgb_node(

threshold = None
if mapper[name].is_numeric:
threshold = float(_get_column_value(xgb_tree, node_id, "Split"))
threshold = float(_get_column_value(xgb_tree, node_id, "Split")) - 1e-8
mapper[name].add(threshold)

left_id = _get_child_id(xgb_tree, node_id, "Yes")
Expand Down Expand Up @@ -150,6 +165,7 @@ def _parse_xgb_tree(
tree_id: NonNegativeInt,
num_trees_per_round: NonNegativeInt,
mapper: Mapper[Feature],
base_score_margin: Array1D,
) -> Tree:
root = _parse_xgb_node(
xgb_tree,
Expand All @@ -158,7 +174,14 @@ def _parse_xgb_tree(
num_trees_per_round=num_trees_per_round,
mapper=mapper,
)
return Tree(root=root)
tree = Tree(root=root)
tree.logit = (
base_score_margin
if len(base_score_margin) > 1
else _logit(base_score_margin)
)
tree.xgboost = True
return tree


def parse_xgb_tree(
Expand All @@ -167,12 +190,14 @@ def parse_xgb_tree(
tree_id: NonNegativeInt,
num_trees_per_round: NonNegativeInt,
mapper: Mapper[Feature],
base_score_margin: Array1D,
) -> Tree:
return _parse_xgb_tree(
xgb_tree,
tree_id=tree_id,
num_trees_per_round=num_trees_per_round,
mapper=mapper,
base_score_margin=base_score_margin,
)


Expand All @@ -181,13 +206,15 @@ def parse_xgb_trees(
*,
num_trees_per_round: NonNegativeInt,
mapper: Mapper[Feature],
base_score_margin: Array1D,
) -> tuple[Tree, ...]:
return tuple(
parse_xgb_tree(
tree,
tree_id=tree_id,
num_trees_per_round=num_trees_per_round,
mapper=mapper,
base_score_margin=base_score_margin,
)
for tree_id, tree in enumerate(trees)
)
Expand All @@ -197,6 +224,7 @@ def parse_xgb_ensemble(
ensemble: xgb.Booster, *, mapper: Mapper[Feature]
) -> tuple[Tree, ...]:
df = ensemble.trees_to_dataframe()
cfg = json.loads(ensemble.save_config())
groups = df.groupby("Tree")
trees = tuple(
groups.get_group(tree_id).reset_index(drop=True)
Expand All @@ -205,7 +233,10 @@ def parse_xgb_ensemble(

num_rounds = ensemble.num_boosted_rounds() or 1
num_trees_per_round = max(1, len(trees) // num_rounds)

base_score_margin = _parse_base_score(cfg)
return parse_xgb_trees(
trees, num_trees_per_round=num_trees_per_round, mapper=mapper
trees,
num_trees_per_round=num_trees_per_round,
mapper=mapper,
base_score_margin=base_score_margin,
)
19 changes: 18 additions & 1 deletion ocean/tree/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from pydantic import validate_call

from ..typing import NonNegativeInt, PositiveInt
from ..typing import Array1D, NonNegativeInt, PositiveInt
from ._node import Node


class Tree:
root: Node
_shape: tuple[NonNegativeInt, ...]
_xgboost: bool = False

def __init__(self, root: Node) -> None:
self.root = root
Expand All @@ -30,6 +31,22 @@ def leaves(self) -> tuple[Node, *tuple[Node, ...]]:
def shape(self) -> tuple[NonNegativeInt, ...]:
return self._shape

@property
def logit(self) -> Array1D:
    """Base-score margin attached to this tree by the XGBoost parser.

    NOTE(review): despite the attribute name ``_base_score_prob``, the
    stored value appears to be a margin (log-odds) in the binary case
    and raw base scores in the multi-class case — confirm against the
    parser that assigns it.
    """
    return self._base_score_prob

@logit.setter
def logit(self, value: Array1D) -> None:
    """Store the base-score margin for this tree (set by the parser)."""
    self._base_score_prob = value

@property
def xgboost(self) -> bool:
    """Whether this tree was parsed from an XGBoost model (default False)."""
    return self._xgboost

@xgboost.setter
def xgboost(self, value: bool) -> None:
    """Mark the tree as XGBoost-derived (set by the XGBoost parser)."""
    self._xgboost = value

@validate_call
def nodes_at(self, depth: NonNegativeInt) -> Iterator[Node]:
    """Return an iterator over the nodes at the given depth.

    Delegates to ``_nodes_at`` starting from the root;
    ``@validate_call`` enforces the ``NonNegativeInt`` constraint at
    runtime.
    """
    return self._nodes_at(self.root, depth=depth)
Expand Down
Loading