Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions lib/mockingjay.ex
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ defmodule Mockingjay do
this protocol is implemented by `EXGBoost` in its `EXGBoost.Compile` module. This protocol is used to extract the trees from the model
and to get the number of classes and features in the model.

## Adapters
Mockingjay also provides adapters for `EXGBoost`, `Catboost`, and `LightGBM` models. These adapters are used to implement the `Mockingjay.DecisionTree`
protocol for these models. See the `Mockingjay.Adapters` module for more information.

## Strategies

Mockingjay supports three strategies for compiling decision trees: `:gemm`, `:tree_traversal`, and `:perfect_tree_traversal`,
Expand Down Expand Up @@ -76,11 +80,8 @@ defmodule Mockingjay do
defp aggregate(x, n_trees, n_classes) do
cond do
n_classes > 1 and n_trees > 1 ->
n_gbdt_classes = if n_classes > 2, do: n_classes, else: 1
n_trees_per_class = trunc(n_trees / n_gbdt_classes)

x
|> Nx.reshape({:auto, n_gbdt_classes, n_trees_per_class})
|> Nx.reshape({:auto, n_classes, n_trees})
|> Nx.sum(axes: [2])

n_classes > 1 and n_trees == 1 ->
Expand Down
15 changes: 15 additions & 0 deletions lib/mockingjay/adapters.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
defmodule Mockingjay.Adapters do
  @moduledoc """
  The Adapters module provides adapters for `EXGBoost`, `Catboost` and `LightGBM` models.
  These adapters are used to implement the `Mockingjay.DecisionTree` protocol for these models.

  The `EXGBoost` adapter works with `EXGBoost` `Booster` structs, and thus can be used directly with `EXGBoost` models.

  The `Catboost` and `LightGBM` adapters work by creating mock modules for these libraries that implement the `Mockingjay.DecisionTree` protocol.
  These adapters can be used with models from these libraries by passing the model to the `Mockingjay.convert` function.

  Refer to each adapter module for more information on how to load models from each library. Please note that as these are mock modules,
  they only serve to load the JSON model files and implement the `Mockingjay.DecisionTree` protocol. They do not provide any other functionality
  from the original libraries.
  """
end
113 changes: 113 additions & 0 deletions lib/mockingjay/adapters/catboost.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
defmodule Mockingjay.Adapters.Catboost do
  @moduledoc """
  Adapter implementing the `Mockingjay.DecisionTree` protocol for CatBoost models.

  There is no Elixir CatBoost library, so this module acts as a mock: it loads a
  CatBoost JSON model file and wraps it in a struct for which the protocol is
  implemented. It provides no functionality from CatBoost itself beyond that.
  """

  @enforce_keys [:booster]
  defstruct [:booster]

  @doc """
  Loads a CatBoost model from the JSON file at `model_path`.

  Returns a `%Mockingjay.Adapters.Catboost{}` struct wrapping the decoded JSON.
  Raises if the file does not exist or is not a regular file.
  """
  def load_model(model_path) do
    if not (File.exists?(model_path) and File.regular?(model_path)) do
      raise "Could not find model file at #{model_path}"
    end

    json = model_path |> File.read!() |> Jason.decode!()
    %__MODULE__{booster: json}
  end

  # Recursively builds the nested tree map for an oblivious tree.
  #
  # CatBoost oblivious trees apply the same split to every node on a level, so
  # the tree is fully described by the list of level splits plus the flat list
  # of leaves: the first half of the leaves belongs to the left subtree and the
  # second half to the right. `splits` must be ordered root-first (the JSON
  # stores them root-last; `to_tree/2` reverses before calling).
  defp build_tree([], [leaf]), do: %{value: leaf}

  defp build_tree([], _leaves) do
    raise "Bad model: leafs must have length 1"
  end

  defp build_tree([split | rest], leaves) do
    # leaves has length 2 ** depth (checked in to_tree/2), so halving is exact.
    {left_leaves, right_leaves} = Enum.split(leaves, div(length(leaves), 2))

    %{
      left: build_tree(rest, left_leaves),
      right: build_tree(rest, right_leaves),
      value: %{threshold: split["border"], feature: split["float_feature_index"]}
    }
  end

  @doc """
  Converts one CatBoost oblivious tree (a decoded JSON map) into a nested map
  with `:left`/`:right`/`:value` keys suitable for `Mockingjay.Tree.from_map/1`.

  A classifier model is detected by the leaf count equaling
  `2 ** length(splits) * n_classes`; its leaves are chunked into one list of
  per-class scores per leaf (argmax is applied later to recover the label).
  A leaf count of exactly `2 ** length(splits)` indicates a regression model.
  Raises on any other leaf count.
  """
  def to_tree(%{} = booster_json, n_classes) do
    leaf_values = Map.get(booster_json, "leaf_values")
    # Splits are stored root-last in the JSON; reverse so the root comes first.
    splits = booster_json |> Map.get("splits") |> Enum.reverse()

    n_leaves = 2 ** length(splits)

    cond do
      length(leaf_values) == n_leaves * n_classes ->
        # Classifier model: each leaf carries one score per class.
        leaf_values
        |> Enum.chunk_every(n_classes)
        |> then(&build_tree(splits, &1))

      length(leaf_values) == n_leaves ->
        # Regression model: one scalar per leaf.
        build_tree(splits, leaf_values)

      true ->
        raise "Bad model: leaf_values must have length 2 ** length(splits) * n_classes (#{n_leaves * n_classes}): got #{length(leaf_values)}"
    end
  end

  defimpl Mockingjay.DecisionTree do
    def trees(booster) do
      trees = Map.get(booster.booster, "oblivious_trees")
      n_classes = num_classes(booster)

      if is_nil(trees) do
        raise "Could not find trees in model, found keys #{inspect(Map.keys(booster.booster))}"
      end

      Enum.map(trees, fn tree ->
        tree
        |> Mockingjay.Adapters.Catboost.to_tree(n_classes)
        |> Mockingjay.Tree.from_map()
      end)
    end

    def num_classes(booster) do
      model_info = Map.get(booster.booster, "model_info")
      keys = Map.keys(model_info)

      cond do
        # Classifier models carry 'class_params'; regression models only have
        # a 'params' key.
        "class_params" in keys ->
          class_names = get_in(model_info, ["class_params", "class_names"])
          class_to_label = get_in(model_info, ["class_params", "class_to_label"])

          if length(class_names) != length(class_to_label) do
            raise "Bad model: class_names and class_to_label must have the same length, got #{length(class_names)} and #{length(class_to_label)}"
          end

          length(class_names)

        "params" in keys ->
          1

        true ->
          # inspect/1 so the key list renders readably instead of the binaries
          # being concatenated by string interpolation.
          raise "Bad model: model must have either 'class_params' or 'params' key -- could not determine number of classes, got #{inspect(keys)}"
      end
    end

    def num_features(booster) do
      float_features = get_in(booster.booster, ["features_info", "float_features"]) || []

      categorical_features =
        get_in(booster.booster, ["features_info", "categorical_features"]) || []

      # Sum the lengths directly instead of materializing a concatenated list.
      length(float_features) + length(categorical_features)
    end

    # CatBoost splits send a sample right when feature value > border.
    def condition(_booster), do: :greater
  end
end
94 changes: 94 additions & 0 deletions lib/mockingjay/adapters/exgboost.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
defmodule Mockingjay.Adapters.Exgboost do
  @moduledoc """
  Adapter implementing the `Mockingjay.DecisionTree` protocol for
  `EXGBoost.Booster` structs.

  Trees are extracted from the JSON produced by `EXGBoost.dump_weights/1` and
  converted into nested maps consumed by `Mockingjay.Tree.from_map/1`.
  """

  @doc """
  Converts a single XGBoost tree (decoded JSON map) into a nested map with
  `:left`/`:right`/`:value` keys.

  The tree is encoded as parallel arrays (`left_children`, `right_children`,
  `split_conditions`, `split_indices`) indexed by node id; child ids of `-1`
  mark a leaf, whose value lives in `split_conditions`. Returns `%{}` for an
  empty tree and `%{value: v}` for a single-leaf tree.
  """
  def to_tree(%{} = tree_map) do
    nodes =
      Enum.zip([
        tree_map["left_children"],
        tree_map["right_children"],
        tree_map["split_conditions"],
        tree_map["split_indices"]
      ])

    case nodes do
      [] ->
        %{}

      # Single-node tree: the whole tree is one leaf. The leaf value is read
      # from split_conditions (3rd element), consistent with the leaf clause of
      # _to_tree/2 — the original read the 4th element (split_indices) here.
      [{_left, _right, value, _split_index}] ->
        %{value: value}

      [_root | _rest] ->
        [root | rest] = Enum.with_index(nodes)
        _to_tree(root, rest)
    end
  end

  @doc false
  # Builds the subtree rooted at `current` (a `{node_tuple, index}` pair),
  # looking up children by node id in `rest`.
  def _to_tree(current, rest) do
    {node, _index} = current

    case node do
      # Child ids of -1 mark a leaf; split_conditions holds the leaf value.
      {-1, -1, value, _feature_id} ->
        %{value: value}

      {left_id, right_id, threshold, feature_id} ->
        # split_with never raises when one partition is empty, unlike matching
        # both :true and :false keys out of Enum.group_by/2.
        {[left_next], left_rest} =
          Enum.split_with(rest, fn {_node, index} -> index == left_id end)

        {[right_next], right_rest} =
          Enum.split_with(rest, fn {_node, index} -> index == right_id end)

        %{
          left: _to_tree(left_next, left_rest),
          right: _to_tree(right_next, right_rest),
          value: %{threshold: threshold, feature: feature_id}
        }
    end
  end

  defimpl Mockingjay.DecisionTree, for: EXGBoost.Booster do
    # Decodes the booster's JSON weight dump.
    defp model_json(booster) do
      booster |> EXGBoost.dump_weights() |> Jason.decode!()
    end

    def trees(booster) do
      trees = get_in(model_json(booster), ["learner", "gradient_booster", "model", "trees"])

      if is_nil(trees) do
        raise "Could not find trees in model"
      end

      Enum.map(trees, fn tree ->
        tree
        |> Mockingjay.Adapters.Exgboost.to_tree()
        |> Mockingjay.Tree.from_map()
      end)
    end

    def num_classes(booster) do
      num_classes = get_in(model_json(booster), ["learner", "learner_model_param", "num_class"])

      if is_nil(num_classes) do
        raise "Could not find num_classes in model"
      end

      String.to_integer(num_classes)
    end

    def num_features(booster) do
      num_features = get_in(model_json(booster), ["learner", "learner_model_param", "num_feature"])

      if is_nil(num_features) do
        raise "Could not find num_features in model"
      end

      String.to_integer(num_features)
    end

    # XGBoost sends a sample left when feature value < threshold.
    def condition(_booster), do: :less
  end
end
28 changes: 28 additions & 0 deletions lib/mockingjay/adapters/lightgbm.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
defmodule Mockingjay.Adapters.Lightgbm do
  @moduledoc """
  Adapter stub for the `Mockingjay.DecisionTree` protocol for LightGBM models.

  There is no Elixir LightGBM library, so this module acts as a mock: it loads
  a LightGBM JSON model file and wraps it in a struct for which the protocol is
  implemented. The protocol functions themselves are not yet implemented.
  """

  @enforce_keys [:model]
  defstruct [:model]

  @doc """
  Loads a LightGBM model from the JSON file at `model_path`.

  Returns a `%Mockingjay.Adapters.Lightgbm{}` struct wrapping the decoded JSON.
  Raises if the file does not exist or is not a regular file.
  """
  def load_model(model_path) do
    if not (File.exists?(model_path) and File.regular?(model_path)) do
      raise "Could not find model file at #{model_path}"
    end

    json = model_path |> File.read!() |> Jason.decode!()
    %__MODULE__{model: json}
  end

  defimpl Mockingjay.DecisionTree do
    # TODO: implement tree extraction from the LightGBM JSON dump.
    def trees(_booster) do
    end

    # Renamed from `n_classes/1`: the protocol callback implemented by the
    # other adapters is `num_classes/1`, so the previous name left the
    # protocol unimplemented.
    # TODO: implement.
    def num_classes(_booster) do
    end

    # TODO: implement.
    def num_features(_booster) do
    end

    # TODO: determine the split-condition direction LightGBM uses.
    def condition(_booster) do
    end
  end
end
21 changes: 13 additions & 8 deletions lib/mockingjay/strategies/gemm.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ defmodule Mockingjay.Strategies.GEMM do
n_weak_learner_classes =
trees
|> hd()
|> Tree.get_decision_values()
|> Tree.get_leaf_nodes()
|> hd()
|> case do
value when is_list(value) ->
length(value)
value when is_list(value.value) ->
length(value.value)

_value ->
1
Expand Down Expand Up @@ -83,8 +83,7 @@ defmodule Mockingjay.Strategies.GEMM do
:n_classes,
:max_decision_nodes,
:max_leaf_nodes,
:n_weak_learner_classes,
:custom_forward
:n_weak_learner_classes
])

_forward(x, arg, opts)
Expand All @@ -99,6 +98,7 @@ defmodule Mockingjay.Strategies.GEMM do
max_decision_nodes = opts[:max_decision_nodes]
max_leaf_nodes = opts[:max_leaf_nodes]
n_weak_learner_classes = opts[:n_weak_learner_classes]
n_trees_per_class = div(n_trees, n_classes)

mat_A
|> Nx.dot([1], x, [1])
Expand All @@ -111,7 +111,7 @@ defmodule Mockingjay.Strategies.GEMM do
|> then(&Nx.dot(mat_E, [2], [0], &1, [1], [0]))
|> Nx.reshape({n_trees, n_weak_learner_classes, :auto})
|> Nx.transpose()
|> Nx.reshape({:auto, n_trees, n_classes})
|> Nx.reshape({:auto, n_trees_per_class, n_classes})
end

# Leaves are ordered as DFS rather than BFS that internal nodes are
Expand Down Expand Up @@ -205,7 +205,13 @@ defmodule Mockingjay.Strategies.GEMM do
if n_weak_learner_classes == 1 do
{[index, 0, node_index], node.value}
else
{[index, trunc(node.value), node_index], 1}
cat_value = node.value |> Nx.tensor() |> Nx.argmax() |> Nx.to_number()

{[
index,
cat_value,
node_index
], 1}
end
end)
end)
Expand All @@ -221,7 +227,6 @@ defmodule Mockingjay.Strategies.GEMM do
d = Nx.indexed_put(d_zero, d_indices, d_updates)

e_updates = Nx.tensor(updates_list)

e_zero = Nx.broadcast(0, {n_trees, n_weak_learner_classes, max_leaf_nodes})

e = Nx.indexed_put(e_zero, e_indices, e_updates)
Expand Down
Loading