From 8d3c552c9831113a6827ba0cbbbb33d57d832511 Mon Sep 17 00:00:00 2001 From: acalejos Date: Tue, 4 Jul 2023 23:22:22 -0400 Subject: [PATCH 1/5] working w/ caveat for PTT --- lib/mockingjay/strategies/gemm.ex | 3 +- .../strategies/perfect_tree_traversal.ex | 80 +++++++++---------- lib/mockingjay/strategies/tree_traversal.ex | 56 ++++++------- 3 files changed, 67 insertions(+), 72 deletions(-) diff --git a/lib/mockingjay/strategies/gemm.ex b/lib/mockingjay/strategies/gemm.ex index 74c2e03..278d153 100644 --- a/lib/mockingjay/strategies/gemm.ex +++ b/lib/mockingjay/strategies/gemm.ex @@ -83,8 +83,7 @@ defmodule Mockingjay.Strategies.GEMM do :n_classes, :max_decision_nodes, :max_leaf_nodes, - :n_weak_learner_classes, - :custom_forward + :n_weak_learner_classes ]) _forward(x, arg, opts) diff --git a/lib/mockingjay/strategies/perfect_tree_traversal.ex b/lib/mockingjay/strategies/perfect_tree_traversal.ex index ccbbbb7..4108ae1 100644 --- a/lib/mockingjay/strategies/perfect_tree_traversal.ex +++ b/lib/mockingjay/strategies/perfect_tree_traversal.ex @@ -117,59 +117,48 @@ defmodule Mockingjay.Strategies.PerfectTreeTraversal do indices = 0..(nt - 1)//2 |> Enum.into([]) |> Nx.tensor(type: :s64) - [ - indices: indices, - num_trees: num_trees, - max_tree_depth: max_tree_depth, - features: features, - thresholds: thresholds, + arg = %{ root_features: root_features, root_thresholds: root_thresholds, + features: features, + thresholds: thresholds, values: values, + indices: indices + } + + opts = [ + num_trees: num_trees, + max_tree_depth: max_tree_depth, condition: Mockingjay.Strategy.cond_to_fun(condition), n_classes: n_classes ] + + {arg, opts} end @impl true - def forward(x, opts \\ []) do + deftransform forward(x, {arg, opts}) do opts = Keyword.validate!(opts, [ - :custom_forward, - :root_features, - :root_thresholds, :condition, - :indices, :num_trees, :n_classes, - :thresholds, - :features, - :max_tree_depth, - :values + :max_tree_depth ]) - _forward( - x, - opts[:root_features], - opts[:root_thresholds], - opts[:features], - opts[:thresholds], - opts[:values], - opts[:indices], - Keyword.take(opts, [:num_trees, :condition, :n_classes]) - ) + _forward(x, arg, opts) end - defnp _forward( - x, - root_features, - root_thresholds, - features, - thresholds, - values, - indices, - opts \\ [] - ) do + defnp _forward(x, arg, opts) do + %{ + root_features: root_features, + root_thresholds: root_thresholds, + features: features, + thresholds: thresholds, + values: values, + indices: indices + } = arg + prev_indices = x |> Nx.take(root_features, axis: 1) @@ -178,7 +167,7 @@ defmodule Mockingjay.Strategies.PerfectTreeTraversal do |> Nx.reshape({:auto}) |> forward_reduce_features(x, features, thresholds, opts) - Nx.take(values, prev_indices) + Nx.take(values |> print_value(), prev_indices) |> Nx.reshape({:auto, opts[:num_trees], opts[:n_classes]}) end @@ -188,16 +177,23 @@ defmodule Mockingjay.Strategies.PerfectTreeTraversal do Tuple.to_list(thresholds), prev_indices, fn nodes, biases, acc -> - gather_indices = nodes |> Nx.take(acc) |> Nx.reshape({:auto, opts[:num_trees]}) - features = Nx.take_along_axis(x, gather_indices, axis: 1) |> Nx.reshape({:auto}) - - acc - |> Nx.multiply(@factor) - |> Nx.add(opts[:condition].(features, Nx.take(biases, acc))) + _inner_reduce(x, nodes, biases, acc, opts) end ) end + defnp _inner_reduce(x, nodes, biases, acc, opts \\ []) do + gather_indices = + nodes |> print_value() |> Nx.take(acc) |> Nx.reshape({:auto, opts[:num_trees]}) + + features = Nx.take_along_axis(x, gather_indices, axis: 1) |> Nx.reshape({:auto}) + + acc + |> print_value() + |> Nx.multiply(@factor) + |> Nx.add(opts[:condition].(features, Nx.take(biases |> print_value(), acc))) + end + defp make_tree_perfect(tree, current_depth, max_depth) do case tree do %Tree{left: nil, right: nil} -> diff --git a/lib/mockingjay/strategies/tree_traversal.ex b/lib/mockingjay/strategies/tree_traversal.ex index 8dd5c9c..6815f1c 100644 --- a/lib/mockingjay/strategies/tree_traversal.ex +++ b/lib/mockingjay/strategies/tree_traversal.ex @@ -115,51 +115,49 @@ defmodule Mockingjay.Strategies.TreeTraversal do Nx.iota({1, num_trees}, type: :s64) |> Nx.multiply(num_nodes) - [ - nodes_offset: nodes_offset, - num_trees: num_trees, - max_tree_depth: max_tree_depth, + arg = %{ + features: features, lefts: lefts, rights: rights, - features: features, thresholds: thresholds, - values: values, + nodes_offset: nodes_offset, + values: values + } + + opts = [ + num_trees: num_trees, + max_tree_depth: max_tree_depth, condition: Mockingjay.Strategy.cond_to_fun(condition), n_classes: n_classes ] + + {arg, opts} end @impl true - deftransform forward(x, opts \\ []) do + deftransform forward(x, {arg, opts}) do opts = Keyword.validate!(opts, [ - :custom_forward, :max_tree_depth, :num_trees, :n_classes, - :nodes_offset, - :lefts, - :rights, - :features, - :thresholds, - :values, :condition, unroll: false ]) - _forward( - x, - opts[:features], - opts[:lefts], - opts[:rights], - opts[:thresholds], - opts[:nodes_offset], - opts[:values], - opts - ) + _forward(x, arg, opts) end - defn _forward(x, features, lefts, rights, thresholds, nodes_offset, values, opts \\ []) do + defnp _forward(x, arg, opts \\ []) do + %{ + features: features, + lefts: lefts, + rights: rights, + thresholds: thresholds, + nodes_offset: nodes_offset, + values: values + } = arg + max_tree_depth = opts[:max_tree_depth] num_trees = opts[:num_trees] n_classes = opts[:n_classes] @@ -173,8 +171,10 @@ defmodule Mockingjay.Strategies.TreeTraversal do |> Nx.broadcast({batch_size, num_trees}) |> Nx.reshape({:auto}) - {indices, _} = - while {tree_nodes = indices, {features, lefts, rights, thresholds, nodes_offset, x}}, + # Values isn't used in the loop but this transfers it to the correct backend + {indices, {_, _, _, _, _, _, values}} = + while {tree_nodes = indices, + {features, lefts, rights, thresholds, nodes_offset, x, values}}, _ <- 1..max_tree_depth, unroll: unroll do feature_nodes = Nx.take(features, tree_nodes) |> Nx.reshape({:auto, num_trees}) @@ -192,7 +192,7 @@ defmodule Mockingjay.Strategies.TreeTraversal do |> Nx.add(nodes_offset) |> Nx.reshape({:auto}) - {result, {features, lefts, rights, thresholds, nodes_offset, x}} + {result, {features, lefts, rights, thresholds, nodes_offset, x, values}} end values From 51bdaacc1efc6ba4a9c8d20c2688ce5ff72a24b5 Mon Sep 17 00:00:00 2001 From: acalejos Date: Fri, 14 Jul 2023 23:26:43 -0400 Subject: [PATCH 2/5] Implement protocol for Catboost --- lib/mockingjay.ex | 4 + lib/mockingjay/adapters.ex | 15 + lib/mockingjay/adapters/catboost.ex | 113 ++ lib/mockingjay/adapters/exgboost.ex | 94 + lib/mockingjay/adapters/lightgbm.ex | 28 + lib/mockingjay/tree.ex | 60 +- mix.exs | 8 +- mix.lock | 9 + test/mockingjay/catboost_test.exs | 24 + test/mockingjay/exgboost_test.exs | 92 + test/support/catboost_classifier.json | 2279 +++++++++++++++++++++++++ test/support/catboost_regressor.json | 2076 ++++++++++++++++++++++ 12 files changed, 4787 insertions(+), 15 deletions(-) create mode 100644 lib/mockingjay/adapters.ex create mode 100644 lib/mockingjay/adapters/catboost.ex create mode 100644 lib/mockingjay/adapters/exgboost.ex create mode 100644 lib/mockingjay/adapters/lightgbm.ex create mode 100644 test/mockingjay/catboost_test.exs create mode 100644 test/mockingjay/exgboost_test.exs create mode 100644 test/support/catboost_classifier.json create mode 100644 test/support/catboost_regressor.json diff --git a/lib/mockingjay.ex b/lib/mockingjay.ex index 2faabbb..ec91cce 100644 --- a/lib/mockingjay.ex +++ b/lib/mockingjay.ex @@ -12,6 +12,10 @@ defmodule Mockingjay do this protocol is implemented by `EXGBoost` in its `EXGBoost.Compile` module. This protocol is used to extract the trees from the model and to get the number of classes and features in the model. + ## Adapters + Mockingjay also provides adapters for `EXGBoost` and `Catboost` models. These adapters are used to implement the `Mockingjay.DecisionTree` + protocol for these models. See the `Mockingjay.Adapters` module for more information. + ## Strategies Mockingjay supports three strategies for compiling decision trees: `:gemm`, `:tree_traversal`, and `:perfect_tree_traversal`, diff --git a/lib/mockingjay/adapters.ex b/lib/mockingjay/adapters.ex new file mode 100644 index 0000000..e45f21f --- /dev/null +++ b/lib/mockingjay/adapters.ex @@ -0,0 +1,15 @@ +defmodule Mockingjay.Adapters do + @moduledoc """ + The Adapter module provides adapters for `EXGBoost`,`Catboost` and 'LightGBM' models. + These adapters are used to implement the `Mockingjay.DecisionTree` protocol for these models. + + The 'EXGBoost' adapter works with 'EXGBoost' 'Booster' structs, and thus can be used directly with 'EXGBoost' models. + + The 'Catboost' and 'LightGBM' adapter work by creating mock modules for these libraries that implement the 'Mockingjay.DecisionTree' protocol. + These adapters can be used with models from these libraries by passing the model to the 'Mockingjay.convert' function. + + Refer to each adapter module for more information on how to load models from each library. Please note that as these are mock modules, + they only serve to load the JSON model files and implement the 'Mockingjay.DecisionTree' protocol. They do not provide any other functionality + from the original libraries. + """ +end diff --git a/lib/mockingjay/adapters/catboost.ex b/lib/mockingjay/adapters/catboost.ex new file mode 100644 index 0000000..322f1ab --- /dev/null +++ b/lib/mockingjay/adapters/catboost.ex @@ -0,0 +1,113 @@ +defmodule Mockingjay.Adapters.Catboost do + @enforce_keys [:booster] + defstruct [:booster] + # This is a "mocked" function for Catboost while there is no Elixir Catboost library + # Here we simply make a mock module that can load a catboost json model file and implement the DecisionTree protocol + def load_model(model_path) do + unless File.exists?(model_path) and File.regular?(model_path) do + raise "Could not find model file at #{model_path}" + end + + json = File.read!(model_path) |> Jason.decode!() + %__MODULE__{booster: json} + end + + defp _to_tree(splits, leafs) do + case splits do + [] -> + unless length(leafs) == 1 do + raise "Bad model: leafs must have length 1" + end + + %{value: hd(leafs)} + + [split | rest] -> + # This should always be even since we checked before that + # its length is a power of 2 + half = (length(leafs) / 2) |> round() + left_leaves = Enum.take(leafs, half) + right_leaves = Enum.drop(leafs, half) + + %{ + left: _to_tree(rest, left_leaves), + right: _to_tree(rest, right_leaves), + value: %{threshold: split["border"], feature: split["float_feature_index"]} + } + end + end + + def to_tree(%{} = booster_json, n_classes) do + leaf_values = Map.get(booster_json, "leaf_values") + splits = Map.get(booster_json, "splits") |> Enum.reverse() + + cond do + length(leaf_values) == 2 ** length(splits) * n_classes -> + # Classifier model + # Will need to argmax to get the class label from the leaf values + + leaf_values = Enum.chunk_every(leaf_values, n_classes) + _to_tree(splits, leaf_values) + + length(leaf_values) == 2 ** length(splits) -> + # Regression model + _to_tree(splits, leaf_values) + + true -> + raise "Bad model: leaf_values must have length 2 ** length(splits) * n_classes (#{2 ** length(splits) * n_classes}): got #{length(leaf_values)}" + end + end + + defimpl Mockingjay.DecisionTree do + def trees(booster) do + trees = Map.get(booster.booster, "oblivious_trees") + n_classes = num_classes(booster) + + if is_nil(trees) do + raise "Could not find trees in model, found keys #{inspect(Map.keys(booster.booster))}" + end + + trees + |> Enum.map(fn tree -> + Mockingjay.Adapters.Catboost.to_tree(tree, n_classes) |> Mockingjay.Tree.from_map() + end) + end + + def num_classes(booster) do + model_info = Map.get(booster.booster, "model_info") + keys = Map.keys(model_info) + + cond do + # Regression models don't have 'class_params' but classifier models still have 'params' key + "class_params" in keys -> + class_names = get_in(model_info, ["class_params", "class_names"]) + + class_to_label = get_in(model_info, ["class_params", "class_to_label"]) + + unless length(class_names) == length(class_to_label) do + raise "Bad model: class_names and class_to_label must have the same length, got #{length(class_names)} and #{length(class_to_label)}" + end + + length(class_names) + + "params" in keys -> + 1 + + true -> + raise "Bad model: model must have either 'class_params' or 'params' key -- could not determine number of classes, got #{keys}" + end + end + + def num_features(booster) do + float_features = get_in(booster.booster, ["features_info", "float_features"]) || [] + + categorical_features = + get_in(booster.booster, ["features_info", "categorical_features"]) || [] + + length(float_features ++ categorical_features) + end + + def condition(booster) do + :greater + end + end +end diff --git a/lib/mockingjay/adapters/exgboost.ex b/lib/mockingjay/adapters/exgboost.ex new file mode 100644 index 0000000..886b418 --- /dev/null +++ b/lib/mockingjay/adapters/exgboost.ex @@ -0,0 +1,94 @@ +defmodule Mockingjay.Adapters.Exgboost do + alias Exgboost.Booster + + def to_tree(%{} = tree_map) do + nodes = + Enum.zip([ + tree_map["left_children"], + tree_map["right_children"], + tree_map["split_conditions"], + tree_map["split_indices"] + ]) + + case nodes do + [{_left, _right, _threshold, value}] -> + %{value: value} + + [_root | _rest] -> + nodes = Enum.with_index(nodes) + [current | rest] = nodes + _to_tree(current, rest) + + [] -> + %{} + end + end + + def _to_tree(current, rest) do + {current, _} = current + + case current do + {-1, -1, value, _feature_id} -> + %{ + value: value + } + + {left_id, right_id, threshold, feature_id} -> + %{true: [left_next], false: left_rest} = + Enum.group_by(rest, fn {_elem, index} -> index == left_id end) + + %{true: [right_next], false: right_rest} = + Enum.group_by(rest, fn {_elem, index} -> index == right_id end) + + %{ + left: _to_tree(left_next, left_rest), + right: _to_tree(right_next, right_rest), + value: %{threshold: threshold, feature: feature_id} + } + end + end + + defimpl Mockingjay.DecisionTree, for: EXGBoost.Booster do + def trees(booster) do + model = EXGBoost.dump_weights(booster) |> Jason.decode!() + trees = get_in(model, ["learner", "gradient_booster", "model", "trees"]) + + if is_nil(trees) do + raise "Could not find trees in model" + end + + trees + |> Enum.map(fn tree -> + Mockingjay.Adapters.Exgboost.to_tree(tree) |> Mockingjay.Tree.from_map() + end) + end + + def num_classes(booster) do + num_classes = + EXGBoost.dump_weights(booster) + |> Jason.decode!() + |> get_in(["learner", "learner_model_param", "num_class"]) + + if is_nil(num_classes) do + raise "Could not find num_classes in model" + end + + String.to_integer(num_classes) + end + + def num_features(booster) do + model = EXGBoost.dump_weights(booster) |> Jason.decode!() + num_features = get_in(model, ["learner", "learner_model_param", "num_feature"]) + + if is_nil(num_features) do + raise "Could not find num_features in model" + end + + String.to_integer(num_features) + end + + def condition(_booster) do + :less + end + end +end diff --git a/lib/mockingjay/adapters/lightgbm.ex b/lib/mockingjay/adapters/lightgbm.ex new file mode 100644 index 0000000..423ecc5 --- /dev/null +++ b/lib/mockingjay/adapters/lightgbm.ex @@ -0,0 +1,28 @@ +defmodule Mockingjay.Adapters.Lightgbm do + @enforce_keys [:model] + defstruct [:model] + # This is a "mocked" function for Catboost while there is no Elixir Catboost library + # Here we simply make a mock module that can load a catboost json model file and implement the DecisionTree protocol + def load_model(model_path) do + unless File.exists?(model_path) and File.regular?(model_path) do + raise "Could not find model file at #{model_path}" + end + + json = File.read!(model_path) |> Jason.decode!() + %__MODULE__{model: json} + end + + defimpl Mockingjay.DecisionTree do + def trees(booster) do + end + + def n_classes(booster) do + end + + def num_features(booster) do + end + + def condition(booster) do + end + end +end diff --git a/lib/mockingjay/tree.ex b/lib/mockingjay/tree.ex index e91a0ab..bdc7aec 100644 --- a/lib/mockingjay/tree.ex +++ b/lib/mockingjay/tree.ex @@ -62,16 +62,48 @@ defmodule Mockingjay.Tree do def from_map(%__MODULE__{} = t), do: t def from_map(%{} = map) do - case map do - %{left: nil, right: nil, value: value} when is_number(value) -> - %__MODULE__{ - id: make_ref(), - left: nil, - right: nil, - value: value - } + %{ + left: %{ + left: %{ + value: [ + 1.0996503496504055, + 0.0900349650349933, + -0.2421328671328476, + -0.4737762237762082, + -0.47377622377620826 + ] + }, + right: %{ + value: [ + 0.18779904306214146, + 0.4449760765549639, + 0.26555023923438986, + -0.44019138755987713, + -0.45813397129193495 + ] + }, + value: %{feature: 78, threshold: 4.637722969055176} + }, + right: %{ + left: %{ + value: [1.1875, -0.3125000000000001, -0.2500000000000001, -0.3125, -0.3125] + }, + right: %{ + value: [ + -0.08333333333333333, + 0.4999999999999997, + 0.0833333333333333, + -0.24999999999999994, + -0.25 + ] + }, + value: %{feature: 78, threshold: 4.637722969055176} + }, + value: %{feature: 122, threshold: -1.053920030593872} + } - %{value: value} when is_number(value) -> + case map do + %{left: nil, right: nil, value: value} -> %__MODULE__{ id: make_ref(), left: nil, @@ -79,9 +111,6 @@ defmodule Mockingjay.Tree do value: value } - %{left: nil, right: nil, value: value} -> - raise ArgumentError, "Leaf nodes must have a numeric value. Got: #{inspect(value)}" - %{left: left, right: right, value: %{threshold: threshold, feature: feature}} when is_number(threshold) and is_number(feature) -> %__MODULE__{ @@ -96,7 +125,12 @@ defmodule Mockingjay.Tree do "Non-leaf nodes must have a numeric threshold and feature. Got: #{inspect(map)}" %{value: value} -> - raise ArgumentError, "Leaf nodes must have a numeric value. Got: #{inspect(value)}" + %__MODULE__{ + id: make_ref(), + left: nil, + right: nil, + value: value + } _ -> raise ArgumentError, "Invalid tree map: #{inspect(map)}" diff --git a/mix.exs b/mix.exs index 36f1bb6..c50f188 100644 --- a/mix.exs +++ b/mix.exs @@ -32,9 +32,13 @@ defmodule Mockingjay.MixProject do defp deps do [ - {:nx, "~> 0.5"}, + {:nx, "~> 0.5", override: true}, {:axon, "~> 0.5"}, - {:ex_doc, "~> 0.29", only: :docs} + {:ex_doc, "~> 0.29", only: :docs}, + # {:exgboost, github: "acalejos/exgboost", optional: true}, + {:exgboost, path: "/Users/andres/Documents/exgboost", optional: true}, + {:scidata, "~> 0.1.9", only: :test}, + {:scholar, "~> 0.1", only: :test} ] end diff --git a/mix.lock b/mix.lock index 4101754..449ed08 100644 --- a/mix.lock +++ b/mix.lock @@ -1,12 +1,21 @@ %{ "axon": {:hex, :axon, "0.5.1", "1ae3a2193df45e51fca912158320b2ca87cb7fba4df242bd3ebe245504d0ea1a", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:kino_vega_lite, "~> 0.1.7", [hex: :kino_vega_lite, repo: "hexpm", optional: true]}, {:nx, "~> 0.5.0", [hex: :nx, repo: "hexpm", optional: false]}, {:table_rex, "~> 3.1.1", [hex: :table_rex, repo: "hexpm", optional: true]}], "hexpm", "d36f2a11c34c6c2b458f54df5c71ffdb7ed91c6a9ccd908faba909c84cc6a38e"}, + "castore": {:hex, :castore, "0.1.22", "4127549e411bedd012ca3a308dede574f43819fe9394254ca55ab4895abfa1a2", [:mix], [], "hexpm", "c17576df47eb5aa1ee40cc4134316a99f5cad3e215d5c77b8dd3cfef12a22cac"}, + "cc_precompiler": {:hex, :cc_precompiler, "0.1.7", "77de20ac77f0e53f20ca82c563520af0237c301a1ec3ab3bc598e8a96c7ee5d9", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "2768b28bf3c2b4f788c995576b39b8cb5d47eb788526d93bd52206c1d8bf4b75"}, "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, "earmark_parser": {:hex, :earmark_parser, "1.4.32", "fa739a0ecfa34493de19426681b23f6814573faee95dfd4b4aafe15a7b5b32c6", [:mix], [], "hexpm", "b8b0dd77d60373e77a3d7e8afa598f325e49e8663a51bcc2b88ef41838cca755"}, + "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, "ex_doc": {:hex, :ex_doc, "0.29.4", "6257ecbb20c7396b1fe5accd55b7b0d23f44b6aa18017b415cb4c2b91d997729", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "2c6699a737ae46cb61e4ed012af931b57b699643b24dabe2400a8168414bc4f5"}, + "exgboost": {:git, "https://github.com/acalejos/exgboost.git", "9b8458bca8dc45c82afe4d74f521a42e9c1bfb53", []}, + "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, + "nimble_csv": {:hex, :nimble_csv, "1.2.0", "4e26385d260c61eba9d4412c71cea34421f296d5353f914afe3f2e71cce97722", [:mix], [], "hexpm", "d0628117fcc2148178b034044c55359b26966c6eaa8e2ce15777be3bbc91b12a"}, + "nimble_options": {:hex, :nimble_options, "1.0.2", "92098a74df0072ff37d0c12ace58574d26880e522c22801437151a159392270e", [:mix], [], "hexpm", "fd12a8db2021036ce12a309f26f564ec367373265b53e25403f0ee697380f1b8"}, "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, "nx": {:hex, :nx, "0.5.3", "6ad5534f9b82429dafa12329952708c2fdd6ab01b306e86333fdea72383147ee", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "d1072fc4423809ed09beb729e73c200ce177ddecac425d9eb6fba643669623ec"}, + "scholar": {:hex, :scholar, "0.1.0", "ec6ef480663765193c3ef1b70d4178dbbab21b10aee01a916df3c6f074465375", [:mix], [{:nimble_options, "~> 0.5.2 or ~> 1.0", [hex: :nimble_options, repo: "hexpm", optional: false]}, {:nx, "~> 0.5.1", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "1d4f7d85af3868619252efe8190b837eaa2569f9ca4d109b83e428512cb653cf"}, + "scidata": {:hex, :scidata, "0.1.10", "94ca1bb71e37d3d74df12288bf1bf677e246d5c49ac65b1616a701061d6ce74f", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.1", [hex: :nimble_csv, repo: "hexpm", optional: false]}, {:stb_image, "~> 0.4", [hex: :stb_image, repo: "hexpm", optional: true]}], "hexpm", "c7db0cc00a7cdbaa70784c2e4ef84a1cd4572303458ac45a427796bddc12af02"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, } diff --git a/test/mockingjay/catboost_test.exs b/test/mockingjay/catboost_test.exs new file mode 100644 index 0000000..3387121 --- /dev/null +++ b/test/mockingjay/catboost_test.exs @@ -0,0 +1,24 @@ +defmodule CatboostTest do + use ExUnit.Case, async: true + alias Mockingjay.DecisionTree + alias Mockingjay.Adapters.Catboost + + test "load json" do + clf_booster = Catboost.load_model("test/support/catboost_classifier.json") + reg_booster = Catboost.load_model("test/support/catboost_regressor.json") + end + + test "protocol implementation" do + clf_booster = Catboost.load_model("test/support/catboost_classifier.json") + reg_booster = Catboost.load_model("test/support/catboost_regressor.json") + + for {booster, expected_num_class} <- [{clf_booster, 5}, {reg_booster, 1}] do + trees = DecisionTree.trees(booster) + + assert is_list(trees) + assert is_struct(hd(trees) |> IO.inspect(label: "Tree"), Mockingjay.Tree) + assert DecisionTree.num_classes(booster) == expected_num_class + assert DecisionTree.num_features(booster) == 137 + end + end +end diff --git a/test/mockingjay/exgboost_test.exs b/test/mockingjay/exgboost_test.exs new file mode 100644 index 0000000..83587e3 --- /dev/null +++ b/test/mockingjay/exgboost_test.exs @@ -0,0 +1,92 @@ +defmodule EXGBoostTest do + use ExUnit.Case, async: true + alias Mockingjay.DecisionTree + + setup do + {x, y} = Scidata.Iris.download() + data = Enum.zip(x, y) |> Enum.shuffle() + {train, test} = Enum.split(data, ceil(length(data) * 0.8)) + {x_train, y_train} = Enum.unzip(train) + {x_test, y_test} = Enum.unzip(test) + + x_train = Nx.tensor(x_train) + y_train = Nx.tensor(y_train) + + x_test = Nx.tensor(x_test) + y_test = Nx.tensor(y_test) + + %{ + x_train: x_train, + y_train: y_train, + x_test: x_test, + y_test: y_test + } + end + + test "protocol implementation", context do + booster = + EXGBoost.train(context.x_train, context.y_train, num_class: 3, objective: :multi_softprob) + + trees = DecisionTree.trees(booster) + + trees_params = + EXGBoost.dump_weights(booster) + |> Jason.decode!() + |> get_in(["learner", "gradient_booster", "model", "trees"]) + + Enum.each(Enum.zip(trees, trees_params), fn {tree, tree_param} -> + assert length(Mockingjay.Tree.bfs(tree)) == + String.to_integer(get_in(tree_param, ["tree_param", "num_nodes"])) + end) + + assert is_list(trees) + assert is_struct(hd(trees), Mockingjay.Tree) + assert DecisionTree.num_classes(booster) == 3 + assert DecisionTree.num_features(booster) == 4 + end + + test "compiles", context do + booster = + EXGBoost.train(context.x_train, context.y_train, num_class: 3, objective: :multi_softprob) + + gemm_predict = EXGBoost.compile(booster, strategy: :gemm) + tt_predict = EXGBoost.compile(booster, strategy: :tree_traversal) + ptt_predict = EXGBoost.compile(booster, strategy: :perfect_tree_traversal) + auto_predict = EXGBoost.compile(booster, strategy: :auto) + # host_jit = EXLA.jit(compiled_predict) + + preds1 = + EXGBoost.predict(booster, context.x_test) + |> Nx.argmax(axis: -1) + + preds2 = gemm_predict.(context.x_test) |> Nx.argmax(axis: -1) + preds3 = tt_predict.(context.x_test) |> Nx.argmax(axis: -1) + preds4 = ptt_predict.(context.x_test) |> Nx.argmax(axis: -1) + preds5 = auto_predict.(context.x_test) |> Nx.argmax(axis: -1) + + base_acc = + Scholar.Metrics.accuracy(context.y_test, preds1) + |> Nx.to_number() + + gmm_accuracy = + Scholar.Metrics.accuracy(context.y_test, preds2) + |> Nx.to_number() + + tt_accuracy = + Scholar.Metrics.accuracy(context.y_test, preds3) + |> Nx.to_number() + + ptt_accuracy = + Scholar.Metrics.accuracy(context.y_test, preds4) + |> Nx.to_number() + + auto_accuracy = + Scholar.Metrics.accuracy(context.y_test, preds5) + |> Nx.to_number() + + assert gmm_accuracy >= base_acc + assert tt_accuracy >= base_acc + assert ptt_accuracy >= base_acc + assert auto_accuracy >= base_acc + end +end diff --git a/test/support/catboost_classifier.json b/test/support/catboost_classifier.json new file mode 100644 index 0000000..ed60bf1 --- /dev/null +++ b/test/support/catboost_classifier.json @@ -0,0 +1,2279 @@ +{ + "features_info": + { + "float_features": + [ + { + "borders": + [ + ], + "feature_id":"1", + "feature_index":0, + "flat_feature_index":0, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"2", + "feature_index":1, + "flat_feature_index":1, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"3", + "feature_index":2, + "flat_feature_index":2, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"4", + "feature_index":3, + "flat_feature_index":3, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"5", + "feature_index":4, + "flat_feature_index":4, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"6", + "feature_index":5, + "flat_feature_index":5, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"7", + "feature_index":6, + "flat_feature_index":6, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"8", + "feature_index":7, + "flat_feature_index":7, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"9", + "feature_index":8, + "flat_feature_index":8, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"10", + "feature_index":9, + "flat_feature_index":9, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"11", + "feature_index":10, + "flat_feature_index":10, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 124 + ], + "feature_id":"12", + "feature_index":11, + "flat_feature_index":11, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"13", + "feature_index":12, + "flat_feature_index":12, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"14", + "feature_index":13, + "flat_feature_index":13, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"15", + "feature_index":14, + "flat_feature_index":14, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"16", + "feature_index":15, + "flat_feature_index":15, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 10.407758712768555 + ], + "feature_id":"17", + "feature_index":16, + "flat_feature_index":16, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"18", + "feature_index":17, + "flat_feature_index":17, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"19", + "feature_index":18, + "flat_feature_index":18, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"20", + "feature_index":19, + "flat_feature_index":19, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"21", + "feature_index":20, + "flat_feature_index":20, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"22", + "feature_index":21, + "flat_feature_index":21, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"23", + "feature_index":22, + "flat_feature_index":22, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"24", + "feature_index":23, + "flat_feature_index":23, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"25", + "feature_index":24, + "flat_feature_index":24, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 41.5 + ], + "feature_id":"26", + "feature_index":25, + "flat_feature_index":25, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 2.5 + ], + "feature_id":"27", + "feature_index":26, + "flat_feature_index":26, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"28", + "feature_index":27, + "flat_feature_index":27, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 2.5 + ], + "feature_id":"29", + "feature_index":28, + "flat_feature_index":28, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"30", + "feature_index":29, + "flat_feature_index":29, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"31", + "feature_index":30, + "flat_feature_index":30, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"32", + "feature_index":31, + "flat_feature_index":31, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"33", + "feature_index":32, + "flat_feature_index":32, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"34", + "feature_index":33, + "flat_feature_index":33, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 0.5 + ], + "feature_id":"35", + "feature_index":34, + "flat_feature_index":34, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"36", + "feature_index":35, + "flat_feature_index":35, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"37", + "feature_index":36, + "flat_feature_index":36, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"38", + "feature_index":37, + "flat_feature_index":37, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"39", + "feature_index":38, + "flat_feature_index":38, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"40", + "feature_index":39, + "flat_feature_index":39, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"41", + "feature_index":40, + "flat_feature_index":40, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"42", + "feature_index":41, + "flat_feature_index":41, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"43", + "feature_index":42, + "flat_feature_index":42, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"44", + "feature_index":43, + "flat_feature_index":43, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"45", + "feature_index":44, + "flat_feature_index":44, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"46", + "feature_index":45, + "flat_feature_index":45, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"47", + "feature_index":46, + "flat_feature_index":46, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"48", + "feature_index":47, + "flat_feature_index":47, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"49", + "feature_index":48, + "flat_feature_index":48, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 0.10263150185346603 + ], + "feature_id":"50", + "feature_index":49, + "flat_feature_index":49, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"51", + "feature_index":50, + "flat_feature_index":50, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"52", + "feature_index":51, + "flat_feature_index":51, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"53", + "feature_index":52, + "flat_feature_index":52, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"54", + "feature_index":53, + "flat_feature_index":53, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"55", + "feature_index":54, + "flat_feature_index":54, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"56", + "feature_index":55, + "flat_feature_index":55, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"57", + "feature_index":56, + "flat_feature_index":56, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"58", + "feature_index":57, + "flat_feature_index":57, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"59", + "feature_index":58, + "flat_feature_index":58, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"60", + "feature_index":59, + "flat_feature_index":59, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"61", + "feature_index":60, + "flat_feature_index":60, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"62", + "feature_index":61, + "flat_feature_index":61, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"63", + "feature_index":62, + "flat_feature_index":62, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"64", + "feature_index":63, + "flat_feature_index":63, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"65", + "feature_index":64, + "flat_feature_index":64, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"66", + "feature_index":65, + "flat_feature_index":65, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"67", + "feature_index":66, + "flat_feature_index":66, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"68", + "feature_index":67, + "flat_feature_index":67, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 0.0079079996794462204 + ], + "feature_id":"69", + "feature_index":68, + "flat_feature_index":68, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"70", + "feature_index":69, + "flat_feature_index":69, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"71", + "feature_index":70, + "flat_feature_index":70, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"72", + "feature_index":71, + "flat_feature_index":71, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 24.802555084228516 + ], + "feature_id":"73", + "feature_index":72, + "flat_feature_index":72, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"74", + "feature_index":73, + "flat_feature_index":73, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"75", + "feature_index":74, + "flat_feature_index":74, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"76", + "feature_index":75, + "flat_feature_index":75, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"77", + "feature_index":76, + "flat_feature_index":76, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"78", + "feature_index":77, + "flat_feature_index":77, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 4.6377229690551758 + ], + "feature_id":"79", + "feature_index":78, + "flat_feature_index":78, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"80", + "feature_index":79, + "flat_feature_index":79, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"81", + "feature_index":80, + "flat_feature_index":80, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 223.0445556640625 + ], + "feature_id":"82", + "feature_index":81, + "flat_feature_index":81, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"83", + "feature_index":82, + "flat_feature_index":82, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"84", + "feature_index":83, + "flat_feature_index":83, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"85", + "feature_index":84, + "flat_feature_index":84, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"86", + "feature_index":85, + "flat_feature_index":85, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"87", + "feature_index":86, + "flat_feature_index":86, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"88", + "feature_index":87, + "flat_feature_index":87, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 7.703493595123291 + ], + "feature_id":"89", + "feature_index":88, + "flat_feature_index":88, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"90", + "feature_index":89, + "flat_feature_index":89, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"91", + "feature_index":90, + "flat_feature_index":90, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"92", + "feature_index":91, + "flat_feature_index":91, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"93", + "feature_index":92, + "flat_feature_index":92, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 8.8898735046386719 + ], + "feature_id":"94", + "feature_index":93, + "flat_feature_index":93, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"95", + "feature_index":94, + "flat_feature_index":94, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"96", + "feature_index":95, + "flat_feature_index":95, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"97", + "feature_index":96, + "flat_feature_index":96, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"98", + "feature_index":97, + "flat_feature_index":97, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"99", + "feature_index":98, + "flat_feature_index":98, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"100", + "feature_index":99, + "flat_feature_index":99, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"101", + "feature_index":100, + "flat_feature_index":100, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"102", + "feature_index":101, + "flat_feature_index":101, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"103", + "feature_index":102, + "flat_feature_index":102, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"104", + "feature_index":103, + "flat_feature_index":103, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"105", + "feature_index":104, + "flat_feature_index":104, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"106", + "feature_index":105, + "flat_feature_index":105, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"107", + "feature_index":106, + "flat_feature_index":106, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 20.003637313842773 + ], + "feature_id":"108", + "feature_index":107, + "flat_feature_index":107, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"109", + "feature_index":108, + "flat_feature_index":108, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"110", + "feature_index":109, + "flat_feature_index":109, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 18.918785095214844 + ], + "feature_id":"111", + "feature_index":110, + "flat_feature_index":110, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"112", + "feature_index":111, + "flat_feature_index":111, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"113", + "feature_index":112, + "flat_feature_index":112, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + -10.170950889587402 + ], + "feature_id":"114", + "feature_index":113, + "flat_feature_index":113, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"115", + "feature_index":114, + "flat_feature_index":114, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"116", + "feature_index":115, + "flat_feature_index":115, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"117", + "feature_index":116, + "flat_feature_index":116, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"118", + "feature_index":117, + "flat_feature_index":117, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"119", + "feature_index":118, + "flat_feature_index":118, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"120", + "feature_index":119, + "flat_feature_index":119, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"121", + "feature_index":120, + "flat_feature_index":120, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + -19.004238128662109 + ], + "feature_id":"122", + "feature_index":121, + "flat_feature_index":121, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + -1.0539200305938721 + ], + "feature_id":"123", + "feature_index":122, + "flat_feature_index":122, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"124", + "feature_index":123, + "flat_feature_index":123, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"125", + "feature_index":124, + "flat_feature_index":124, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"126", + "feature_index":125, + "flat_feature_index":125, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"127", + "feature_index":126, + "flat_feature_index":126, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"128", + "feature_index":127, + "flat_feature_index":127, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"129", + "feature_index":128, + "flat_feature_index":128, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"130", + "feature_index":129, + "flat_feature_index":129, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"131", + "feature_index":130, + "flat_feature_index":130, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 8.5 + ], + "feature_id":"132", + "feature_index":131, + "flat_feature_index":131, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"133", + "feature_index":132, + "flat_feature_index":132, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"134", + "feature_index":133, + "flat_feature_index":133, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"135", + "feature_index":134, + "flat_feature_index":134, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 31 + ], + "feature_id":"136", + "feature_index":135, + "flat_feature_index":135, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"137", + "feature_index":136, + "flat_feature_index":136, + "has_nans":false, + "nan_value_treatment":"AsIs" + } + ] + }, + "model_info": + { + "catboost_version_info":"Arc info:\n Branch: unknown-vcs-branch\n Commit: 0000000000000000000000000000000000000000\n Author: \n Summary: No VCS\n\n", + "class_params": + { + "class_label_type":"Float", + "class_names": + [ + 0, + 1, + 2, + 3, + 4 + ], + "class_to_label": + [ + 0, + 1, + 2, + 3, + 4 + ], + "classes_count":0 + }, + "model_guid":"a66298e7-38a0d1b5-ede14a2-c2cc9143", + "output_options":"{\"name\":\"experiment\",\"verbose\":0,\"test_error_log\":\"test_error.tsv\",\"json_log\":\"catboost_training.json\",\"result_model_file\":\"model\",\"roc_file\":\"\",\"eval_file_name\":\"\",\"use_best_model\":false,\"allow_writing_files\":true,\"prediction_type\":[\"RawFormulaVal\"],\"fstr_internal_file\":\"\",\"output_columns\":[\"SampleId\",\"RawFormulaVal\",\"Label\"],\"snapshot_interval\":600,\"time_left_log\":\"time_left.tsv\",\"fstr_type\":\"FeatureImportance\",\"profile_log\":\"catboost_profile.log\",\"train_dir\":\"catboost_info\",\"learn_error_log\":\"learn_error.tsv\",\"training_options_file\":\"\",\"snapshot_file\":\"experiment.cbsnapshot\",\"save_snapshot\":false,\"model_format\":[\"CatboostBinary\"],\"final_feature_calcer_computation_mode\":\"Default\",\"metric_period\":1,\"output_borders\":\"\",\"final_ctr_computation_mode\":\"Default\",\"best_model_min_trees\":1,\"fstr_regular_file\":\"\"}", + "params": + { + "boosting_options": + { + "approx_on_full_history":false, + "boost_from_average":false, + "boosting_type":"Plain", + "fold_len_multiplier":2, + "fold_permutation_block":0, + "iterations":10, + "learning_rate":0.5, + "model_shrink_mode":"Constant", + "model_shrink_rate":0, + "od_config": + { + "stop_pvalue":0, + "type":"None", + "wait_iterations":20 + }, + "permutation_count":4, + "posterior_sampling":false + }, + "cat_feature_params": + { + "combinations_ctrs": + [ + { + "ctr_binarization": + { + "border_count":15, + "border_type":"Uniform" + }, + "ctr_type":"Borders", + "prior_estimation":"No", + "priors": + [ + [ + 0, + 1 + ], + [ + 0.5, + 1 + ], + [ + 1, + 1 + ] + ], + "target_binarization": + { + "border_count":4, + "border_type":"MinEntropy" + } + }, + { + "ctr_binarization": + { + "border_count":15, + "border_type":"Uniform" + }, + "ctr_type":"Counter", + "prior_estimation":"No", + "priors": + [ + [ + 0, + 1 + ] + ] + } + ], + "counter_calc_method":"SkipTest", + "ctr_leaf_count_limit":18446744073709551615, + "max_ctr_complexity":1, + "one_hot_max_size":2, + "per_feature_ctrs": + { + }, + "simple_ctrs": + [ + { + "ctr_binarization": + { + "border_count":15, + "border_type":"Uniform" + }, + "ctr_type":"Borders", + "prior_estimation":"No", + "priors": + [ + [ + 0, + 1 + ], + [ + 0.5, + 1 + ], + [ + 1, + 1 + ] + ], + "target_binarization": + { + "border_count":4, + "border_type":"MinEntropy" + } + }, + { + "ctr_binarization": + { + "border_count":15, + "border_type":"Uniform" + }, + "ctr_type":"Counter", + "prior_estimation":"No", + "priors": + [ + [ + 0, + 1 + ] + ] + } + ], + "store_all_simple_ctr":false, + "target_binarization": + { + "border_count":1, + "border_type":"MinEntropy" + } + }, + "data_processing_options": + { + "allow_const_label":false, + "auto_class_weights":"None", + "class_names": + [ + 0, + 1, + 2, + 3, + 4 + ], + "class_weights": + [ + ], + "classes_count":0, + "dev_default_value_fraction_for_sparse":0.82999998331069946, + "dev_group_features":false, + "dev_leafwise_scoring":false, + "dev_sparse_array_indexing":"Indices", + "embedding_processing_options": + { + "embedding_processing": + { + "default": + [ + "LDA", + "KNN" + ] + } + }, + "eval_fraction":0, + "float_features_binarization": + { + "border_count":254, + "border_type":"GreedyLogSum", + "dev_max_subset_size_for_build_borders":200000, + "nan_mode":"Min" + }, + "force_unit_auto_pair_weights":false, + "has_time":false, + "ignored_features": + [ + ], + "per_float_feature_quantization": + { + }, + "target_border":null, + "text_processing_options": + { + "dictionaries": + [ + { + "dictionary_id":"BiGram", + "end_of_sentence_token_policy":"Skip", + "end_of_word_token_policy":"Insert", + "gram_order":"2", + "max_dictionary_size":"50000", + "occurrence_lower_bound":"5", + "skip_step":"0", + "start_token_id":"0", + "token_level_type":"Word" + }, + { + "dictionary_id":"Word", + "end_of_sentence_token_policy":"Skip", + "end_of_word_token_policy":"Insert", + "gram_order":"1", + "max_dictionary_size":"50000", + "occurrence_lower_bound":"5", + "skip_step":"0", + "start_token_id":"0", + "token_level_type":"Word" + } + ], + "feature_processing": + { + "default": + [ + { + "dictionaries_names": + [ + "BiGram", + "Word" + ], + "feature_calcers": + [ + "BoW" + ], + "tokenizers_names": + [ + "Space" + ] + }, + { + "dictionaries_names": + [ + "Word" + ], + "feature_calcers": + [ + "NaiveBayes" + ], + "tokenizers_names": + [ + "Space" + ] + } + ] + }, + "tokenizers": + [ + { + "delimiter":" ", + "languages": + [ + ], + "lemmatizing":"0", + "lowercasing":"0", + "number_process_policy":"LeaveAsIs", + "number_token":"🔢", + "separator_type":"ByDelimiter", + "skip_empty":"1", + "split_by_set":"0", + "subtokens_policy":"SingleToken", + "token_types": + [ + "Number", + "Unknown", + "Word" + ], + "tokenizer_id":"Space" + } + ] + } + }, + "detailed_profile":false, + "flat_params": + { + "depth":2, + "iterations":10, + "loss_function":"MultiClass", + "random_seed":0, + "verbose":0 + }, + "logging_level":"Silent", + "loss_function": + { + "params": + { + }, + "type":"MultiClass" + }, + "metadata": + { + }, + "metrics": + { + "custom_metrics": + [ + ], + "eval_metric": + { + "params": + { + }, + "type":"MultiClass" + }, + "objective_metric": + { + "params": + { + }, + "type":"MultiClass" + } + }, + "pool_metainfo_options": + { + "tags": + { + } + }, + "random_seed":0, + "system_options": + { + "file_with_hosts":"hosts.txt", + "node_port":0, + "node_type":"SingleHost", + "thread_count":10, + "used_ram_limit":"" + }, + "task_type":"CPU", + "tree_learner_options": + { + "bayesian_matrix_reg":0.10000000149011612, + "bootstrap": + { + "bagging_temperature":1, + "type":"Bayesian" + }, + "depth":2, + "dev_efb_max_buckets":1024, + "dev_leafwise_approxes":false, + "dev_score_calc_obj_block_size":5000000, + "grow_policy":"SymmetricTree", + "l2_leaf_reg":3, + "leaf_estimation_backtracking":"AnyImprovement", + "leaf_estimation_iterations":1, + "leaf_estimation_method":"Newton", + "max_leaves":4, + "min_data_in_leaf":1, + "model_size_reg":0.5, + "monotone_constraints": + { + }, + "penalties": + { + "feature_weights": + { + }, + "first_feature_use_penalties": + { + }, + "penalties_coefficient":1, + "per_object_feature_penalties": + { + } + }, + "random_strength":1, + "rsm":1, + "sampling_frequency":"PerTree", + "score_function":"Cosine", + "sparse_features_conflict_fraction":0 + } + }, + "train_finish_time":"2023-07-11T03:00:51Z" + }, + "oblivious_trees": + [ + { + "leaf_values": + [ + 1.0996503496504055, + 0.090034965034993295, + -0.2421328671328476, + -0.47377622377620821, + -0.47377622377620826, + 0.18779904306214146, + 0.44497607655496391, + 0.26555023923438986, + -0.44019138755987713, + -0.45813397129193495, + 1.1875, + -0.31250000000000011, + -0.25000000000000011, + -0.3125, + -0.3125, + -0.083333333333333329, + 0.49999999999999972, + 0.083333333333333301, + -0.24999999999999994, + -0.25 + ], + "leaf_weights": + [ + 557, + 403, + 25, + 15 + ], + "splits": + [ + { + "border":4.6377229690551758, + "float_feature_index":78, + "split_index":9, + "split_type":"FloatFeature" + }, + { + "border":-1.0539200305938721, + "float_feature_index":122, + "split_index":17, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.2260605751440336, + 0.23531086132866974, + 0.039108635124860416, + -0.22514108456448823, + -0.27533898703307047, + 0.49245128665243182, + 0.1596331205374483, + -0.13518239751286185, + -0.24889623737316394, + -0.26800577230382722, + 0.12768441931606134, + 0.35838414532948393, + 0.20125910789241883, + -0.32925288304643041, + -0.35807478949165161, + 0.3733861869210664, + -0.044851127831861369, + 0.21105595011057066, + -0.31186030949704024, + -0.22773069970274556 + ], + "leaf_weights": + [ + 68, + 223, + 567, + 142 + ], + "splits": + [ + { + "border":10.407758712768555, + "float_feature_index":16, + "split_index":1, + "split_type":"FloatFeature" + }, + { + "border":-19.004238128662109, + "float_feature_index":121, + "split_index":16, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.30322987474001128, + 0.15067325984719901, + 0.060315532657711161, + -0.24750733016657547, + -0.26671133707826661, + -0.21500021120323076, + 0.24982582725176083, + 0.20101607474521779, + -0.1179184800898066, + -0.11792321070394124, + 0.10706110243015572, + 0.3186864261367417, + 0.18761112991802159, + -0.2976473020787182, + -0.31571135640621972, + -0.21401333980735943, + -0.24861635318208319, + 0.85687723722944753, + -0.19770860906890458, + -0.19653893517110063 + ], + "leaf_weights": + [ + 626, + 13, + 335, + 26 + ], + "splits": + [ + { + "border":20.003637313842773, + "float_feature_index":107, + "split_index":13, + "split_type":"FloatFeature" + }, + { + "border":0.5, + "float_feature_index":34, + "split_index":5, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.33314554234648891, + 0.16825987316305704, + -0.11648537023942601, + -0.18089925096965179, + -0.20402079430057724, + 0.092662651235827051, + 0.16710393081258634, + 0.27168804499396798, + -0.24467150204426674, + -0.28678312499811748, + -0.00046690460750491337, + 0.17027482720336781, + 0.24548444905155281, + -0.185529610201145, + -0.22976276144627739, + 0.08849801009090412, + 0.14585428715004656, + 0.19117387644274145, + -0.24652302609888357, + -0.17900314758481195 + ], + "leaf_weights": + [ + 389, + 190, + 317, + 104 + ], + "splits": + [ + { + "border":0.10263150185346603, + "float_feature_index":49, + "split_index":6, + "split_type":"FloatFeature" + }, + { + "border":-10.170950889587402, + "float_feature_index":113, + "split_index":15, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.22777436995758954, + 0.10632364744248361, + 0.081720972755231464, + -0.19653248194052131, + -0.21928650821472842, + 0.0036970123625032149, + 0.17657285933507696, + 0.087227481731722345, + -0.10722520325212968, + -0.16027215017718727, + 0.23598151159516953, + 0.24520659808289103, + -0.13054120949147338, + -0.17740974438146837, + -0.17323715580511964, + -0.04016158940998666, + 0.0079865643613920186, + 0.44505804231747287, + -0.2088899806051997, + -0.20399303666367644 + ], + "leaf_weights": + [ + 434, + 339, + 135, + 92 + ], + "splits": + [ + { + "border":2.5, + "float_feature_index":26, + "split_index":3, + "split_type":"FloatFeature" + }, + { + "border":41.5, + "float_feature_index":25, + "split_index":2, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.14217743347355216, + 0.10799764945032921, + 0.097938891102720457, + -0.1530888383996391, + -0.19502513562693596, + -0.24625490544600348, + 0.08839381880208938, + 0.34351993494912964, + -0.094155090262857413, + -0.091503758042358177, + 0.062783511620311885, + 0.31013506940427321, + -0.17786716498820154, + -0.099421081323112143, + -0.09563033471327094, + -0.28925955907610168, + 0.1181047737946504, + 0.20418129856841655, + -0.017076373916100843, + -0.015950139370864459 + ], + "leaf_weights": + [ + 921, + 32, + 43, + 4 + ], + "splits": + [ + { + "border":24.802555084228516, + "float_feature_index":72, + "split_index":8, + "split_type":"FloatFeature" + }, + { + "border":223.0445556640625, + "float_feature_index":81, + "split_index":10, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + -0.22007066451161272, + 0.27901422364451733, + 0.053726441413690061, + -0.060789540489797837, + -0.051880460056796592, + 0.14084355050530242, + 0.08175494902799188, + 0.10222863603685897, + -0.10853955456906017, + -0.21628758100117962, + -0.20427842365298121, + 0.29496597630070487, + -0.063858057100830964, + -0.013838577501709278, + -0.012990918045183299, + -0.13652972878791608, + 0.11993613258479305, + -0.001042875714270501, + -0.10740890322441982, + 0.12504537514181302 + ], + "leaf_weights": + [ + 21, + 865, + 4, + 110 + ], + "splits": + [ + { + "border":8.5, + "float_feature_index":131, + "split_index":18, + "split_type":"FloatFeature" + }, + { + "border":0.0079079996794462204, + "float_feature_index":68, + "split_index":7, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.10370927493075215, + 0.030222911806032191, + -0.054209237674861861, + 0.054027190072161196, + -0.13375013913408346, + 0.0035910182954583237, + 0.1518987313984575, + 0.028040321232255382, + -0.17607326492494235, + -0.007456806001230094, + 0.04968366244568987, + 0.0562433437643374, + 0.22672662673800298, + -0.14078169890850042, + -0.19187193403953257, + -0.087060584816570613, + 0.085946140681893518, + 0.099339797233289673, + -0.11769732595930756, + 0.019471972860694835 + ], + "leaf_weights": + [ + 502, + 261, + 181, + 56 + ], + "splits": + [ + { + "border":8.8898735046386719, + "float_feature_index":93, + "split_index":12, + "split_type":"FloatFeature" + }, + { + "border":7.703493595123291, + "float_feature_index":88, + "split_index":11, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.076502991422632186, + 0.067091580590373442, + 0.056139802333490134, + -0.081801675683306777, + -0.11793269866317857, + -0.018256523746251842, + -0.21297481308205177, + 0.25348131177678396, + -0.012976026207467344, + -0.0092739487410130206, + -0.22129594836325564, + 0.088363514050145045, + 0.038673259650776295, + 0.034359683260323563, + 0.059899491402010806, + 0, + 0, + 0, + 0, + 0 + ], + "leaf_weights": + [ + 949, + 5, + 46, + 0 + ], + "splits": + [ + { + "border":2.5, + "float_feature_index":28, + "split_index":4, + "split_type":"FloatFeature" + }, + { + "border":31, + "float_feature_index":135, + "split_index":19, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + -0.37386395194911243, + 0.05208760418731518, + 0.37109746685447348, + 0.02884935538112238, + -0.078170474473797602, + 0.032600605121987293, + -0.11157786097459373, + 0.050269632346439913, + -0.13115591087796583, + 0.15986353438413281, + 0.27103592266631699, + 0.0023455339404997583, + -0.080358958432657585, + -0.067083448895804407, + -0.12593904927834632, + -0.016232911305038306, + 0.05848896571047691, + 0.027001374466226028, + 0.00041657296033062533, + -0.069674001831997276 + ], + "leaf_weights": + [ + 49, + 83, + 204, + 664 + ], + "splits": + [ + { + "border":18.918785095214844, + "float_feature_index":110, + "split_index":14, + "split_type":"FloatFeature" + }, + { + "border":124, + "float_feature_index":11, + "split_index":0, + "split_type":"FloatFeature" + } + ] + } + ], + "scale_and_bias": + [ + 1, + [ + 0, + 0, + 0, + 0, + 0 + ] + ] +} \ No newline at end of file diff --git a/test/support/catboost_regressor.json b/test/support/catboost_regressor.json new file mode 100644 index 0000000..4327a55 --- /dev/null +++ b/test/support/catboost_regressor.json @@ -0,0 +1,2076 @@ +{ + "features_info": + { + "float_features": + [ + { + "borders": + [ + ], + "feature_id":"1", + "feature_index":0, + "flat_feature_index":0, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"2", + "feature_index":1, + "flat_feature_index":1, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"3", + "feature_index":2, + "flat_feature_index":2, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"4", + "feature_index":3, + "flat_feature_index":3, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"5", + "feature_index":4, + "flat_feature_index":4, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"6", + "feature_index":5, + "flat_feature_index":5, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"7", + "feature_index":6, + "flat_feature_index":6, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"8", + "feature_index":7, + "flat_feature_index":7, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"9", + "feature_index":8, + "flat_feature_index":8, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"10", + "feature_index":9, + "flat_feature_index":9, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"11", + "feature_index":10, + "flat_feature_index":10, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 10.5 + ], + "feature_id":"12", + "feature_index":11, + "flat_feature_index":11, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"13", + "feature_index":12, + "flat_feature_index":12, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 2.5 + ], + "feature_id":"14", + "feature_index":13, + "flat_feature_index":13, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 3.5 + ], + "feature_id":"15", + "feature_index":14, + "flat_feature_index":14, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"16", + "feature_index":15, + "flat_feature_index":15, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"17", + "feature_index":16, + "flat_feature_index":16, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"18", + "feature_index":17, + "flat_feature_index":17, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 19.769384384155273 + ], + "feature_id":"19", + "feature_index":18, + "flat_feature_index":18, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"20", + "feature_index":19, + "flat_feature_index":19, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"21", + "feature_index":20, + "flat_feature_index":20, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"22", + "feature_index":21, + "flat_feature_index":21, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"23", + "feature_index":22, + "flat_feature_index":22, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"24", + "feature_index":23, + "flat_feature_index":23, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"25", + "feature_index":24, + "flat_feature_index":24, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"26", + "feature_index":25, + "flat_feature_index":25, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"27", + "feature_index":26, + "flat_feature_index":26, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"28", + "feature_index":27, + "flat_feature_index":27, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"29", + "feature_index":28, + "flat_feature_index":28, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 1.5 + ], + "feature_id":"30", + "feature_index":29, + "flat_feature_index":29, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 3.5 + ], + "feature_id":"31", + "feature_index":30, + "flat_feature_index":30, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"32", + "feature_index":31, + "flat_feature_index":31, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"33", + "feature_index":32, + "flat_feature_index":32, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"34", + "feature_index":33, + "flat_feature_index":33, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"35", + "feature_index":34, + "flat_feature_index":34, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"36", + "feature_index":35, + "flat_feature_index":35, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"37", + "feature_index":36, + "flat_feature_index":36, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"38", + "feature_index":37, + "flat_feature_index":37, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"39", + "feature_index":38, + "flat_feature_index":38, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"40", + "feature_index":39, + "flat_feature_index":39, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"41", + "feature_index":40, + "flat_feature_index":40, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"42", + "feature_index":41, + "flat_feature_index":41, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"43", + "feature_index":42, + "flat_feature_index":42, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"44", + "feature_index":43, + "flat_feature_index":43, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"45", + "feature_index":44, + "flat_feature_index":44, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"46", + "feature_index":45, + "flat_feature_index":45, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"47", + "feature_index":46, + "flat_feature_index":46, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"48", + "feature_index":47, + "flat_feature_index":47, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"49", + "feature_index":48, + "flat_feature_index":48, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"50", + "feature_index":49, + "flat_feature_index":49, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"51", + "feature_index":50, + "flat_feature_index":50, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"52", + "feature_index":51, + "flat_feature_index":51, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"53", + "feature_index":52, + "flat_feature_index":52, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"54", + "feature_index":53, + "flat_feature_index":53, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"55", + "feature_index":54, + "flat_feature_index":54, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"56", + "feature_index":55, + "flat_feature_index":55, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"57", + "feature_index":56, + "flat_feature_index":56, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"58", + "feature_index":57, + "flat_feature_index":57, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"59", + "feature_index":58, + "flat_feature_index":58, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"60", + "feature_index":59, + "flat_feature_index":59, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"61", + "feature_index":60, + "flat_feature_index":60, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"62", + "feature_index":61, + "flat_feature_index":61, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"63", + "feature_index":62, + "flat_feature_index":62, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"64", + "feature_index":63, + "flat_feature_index":63, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"65", + "feature_index":64, + "flat_feature_index":64, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"66", + "feature_index":65, + "flat_feature_index":65, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"67", + "feature_index":66, + "flat_feature_index":66, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"68", + "feature_index":67, + "flat_feature_index":67, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"69", + "feature_index":68, + "flat_feature_index":68, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"70", + "feature_index":69, + "flat_feature_index":69, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"71", + "feature_index":70, + "flat_feature_index":70, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"72", + "feature_index":71, + "flat_feature_index":71, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"73", + "feature_index":72, + "flat_feature_index":72, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"74", + "feature_index":73, + "flat_feature_index":73, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 14.753328323364258 + ], + "feature_id":"75", + "feature_index":74, + "flat_feature_index":74, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 33.714653015136719 + ], + "feature_id":"76", + "feature_index":75, + "flat_feature_index":75, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"77", + "feature_index":76, + "flat_feature_index":76, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 5.98291015625 + ], + "feature_id":"78", + "feature_index":77, + "flat_feature_index":77, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 4.6377229690551758 + ], + "feature_id":"79", + "feature_index":78, + "flat_feature_index":78, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"80", + "feature_index":79, + "flat_feature_index":79, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"81", + "feature_index":80, + "flat_feature_index":80, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"82", + "feature_index":81, + "flat_feature_index":81, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 8.7581996917724609 + ], + "feature_id":"83", + "feature_index":82, + "flat_feature_index":82, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"84", + "feature_index":83, + "flat_feature_index":83, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"85", + "feature_index":84, + "flat_feature_index":84, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"86", + "feature_index":85, + "flat_feature_index":85, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 5.6147050857543945 + ], + "feature_id":"87", + "feature_index":86, + "flat_feature_index":86, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"88", + "feature_index":87, + "flat_feature_index":87, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"89", + "feature_index":88, + "flat_feature_index":88, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"90", + "feature_index":89, + "flat_feature_index":89, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"91", + "feature_index":90, + "flat_feature_index":90, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"92", + "feature_index":91, + "flat_feature_index":91, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"93", + "feature_index":92, + "flat_feature_index":92, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"94", + "feature_index":93, + "flat_feature_index":93, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"95", + "feature_index":94, + "flat_feature_index":94, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"96", + "feature_index":95, + "flat_feature_index":95, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"97", + "feature_index":96, + "flat_feature_index":96, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"98", + "feature_index":97, + "flat_feature_index":97, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"99", + "feature_index":98, + "flat_feature_index":98, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"100", + "feature_index":99, + "flat_feature_index":99, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"101", + "feature_index":100, + "flat_feature_index":100, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"102", + "feature_index":101, + "flat_feature_index":101, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"103", + "feature_index":102, + "flat_feature_index":102, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"104", + "feature_index":103, + "flat_feature_index":103, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"105", + "feature_index":104, + "flat_feature_index":104, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"106", + "feature_index":105, + "flat_feature_index":105, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"107", + "feature_index":106, + "flat_feature_index":106, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"108", + "feature_index":107, + "flat_feature_index":107, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 10.594841957092285, + 20.248912811279297 + ], + "feature_id":"109", + "feature_index":108, + "flat_feature_index":108, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"110", + "feature_index":109, + "flat_feature_index":109, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"111", + "feature_index":110, + "flat_feature_index":110, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"112", + "feature_index":111, + "flat_feature_index":111, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"113", + "feature_index":112, + "flat_feature_index":112, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"114", + "feature_index":113, + "flat_feature_index":113, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + -25.435811996459961 + ], + "feature_id":"115", + "feature_index":114, + "flat_feature_index":114, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"116", + "feature_index":115, + "flat_feature_index":115, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"117", + "feature_index":116, + "flat_feature_index":116, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"118", + "feature_index":117, + "flat_feature_index":117, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"119", + "feature_index":118, + "flat_feature_index":118, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"120", + "feature_index":119, + "flat_feature_index":119, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"121", + "feature_index":120, + "flat_feature_index":120, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + -12.996198654174805 + ], + "feature_id":"122", + "feature_index":121, + "flat_feature_index":121, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + -26.762472152709961 + ], + "feature_id":"123", + "feature_index":122, + "flat_feature_index":122, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"124", + "feature_index":123, + "flat_feature_index":123, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"125", + "feature_index":124, + "flat_feature_index":124, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"126", + "feature_index":125, + "flat_feature_index":125, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 1.5 + ], + "feature_id":"127", + "feature_index":126, + "flat_feature_index":126, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"128", + "feature_index":127, + "flat_feature_index":127, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 235 + ], + "feature_id":"129", + "feature_index":128, + "flat_feature_index":128, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"130", + "feature_index":129, + "flat_feature_index":129, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"131", + "feature_index":130, + "flat_feature_index":130, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"132", + "feature_index":131, + "flat_feature_index":131, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"133", + "feature_index":132, + "flat_feature_index":132, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"134", + "feature_index":133, + "flat_feature_index":133, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 4.5 + ], + "feature_id":"135", + "feature_index":134, + "flat_feature_index":134, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"136", + "feature_index":135, + "flat_feature_index":135, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + ], + "feature_id":"137", + "feature_index":136, + "flat_feature_index":136, + "has_nans":false, + "nan_value_treatment":"AsIs" + } + ] + }, + "model_info": + { + "catboost_version_info":"Arc info:\n Branch: unknown-vcs-branch\n Commit: 0000000000000000000000000000000000000000\n Author: \n Summary: No VCS\n\n", + "model_guid":"1e7e6196-aac12f39-f85c172f-a327afbd", + "output_options":"{\"name\":\"experiment\",\"verbose\":0,\"test_error_log\":\"test_error.tsv\",\"json_log\":\"catboost_training.json\",\"result_model_file\":\"model\",\"roc_file\":\"\",\"eval_file_name\":\"\",\"use_best_model\":false,\"allow_writing_files\":true,\"prediction_type\":[\"RawFormulaVal\"],\"fstr_internal_file\":\"\",\"output_columns\":[\"SampleId\",\"RawFormulaVal\",\"Label\"],\"snapshot_interval\":600,\"time_left_log\":\"time_left.tsv\",\"fstr_type\":\"FeatureImportance\",\"profile_log\":\"catboost_profile.log\",\"train_dir\":\"catboost_info\",\"learn_error_log\":\"learn_error.tsv\",\"training_options_file\":\"\",\"snapshot_file\":\"experiment.cbsnapshot\",\"save_snapshot\":false,\"model_format\":[\"CatboostBinary\"],\"final_feature_calcer_computation_mode\":\"Default\",\"metric_period\":1,\"output_borders\":\"\",\"final_ctr_computation_mode\":\"Default\",\"best_model_min_trees\":1,\"fstr_regular_file\":\"\"}", + "params": + { + "boosting_options": + { + "approx_on_full_history":false, + "boost_from_average":true, + "boosting_type":"Plain", + "fold_len_multiplier":2, + "fold_permutation_block":0, + "iterations":10, + "learning_rate":0.5, + "model_shrink_mode":"Constant", + "model_shrink_rate":0, + "od_config": + { + "stop_pvalue":0, + "type":"None", + "wait_iterations":20 + }, + "permutation_count":4, + "posterior_sampling":false + }, + "cat_feature_params": + { + "combinations_ctrs": + [ + { + "ctr_binarization": + { + "border_count":15, + "border_type":"Uniform" + }, + "ctr_type":"Borders", + "prior_estimation":"No", + "priors": + [ + [ + 0, + 1 + ], + [ + 0.5, + 1 + ], + [ + 1, + 1 + ] + ], + "target_binarization": + { + "border_count":1, + "border_type":"MinEntropy" + } + }, + { + "ctr_binarization": + { + "border_count":15, + "border_type":"Uniform" + }, + "ctr_type":"Counter", + "prior_estimation":"No", + "priors": + [ + [ + 0, + 1 + ] + ] + } + ], + "counter_calc_method":"SkipTest", + "ctr_leaf_count_limit":18446744073709551615, + "max_ctr_complexity":1, + "one_hot_max_size":2, + "per_feature_ctrs": + { + }, + "simple_ctrs": + [ + { + "ctr_binarization": + { + "border_count":15, + "border_type":"Uniform" + }, + "ctr_type":"Borders", + "prior_estimation":"No", + "priors": + [ + [ + 0, + 1 + ], + [ + 0.5, + 1 + ], + [ + 1, + 1 + ] + ], + "target_binarization": + { + "border_count":1, + "border_type":"MinEntropy" + } + }, + { + "ctr_binarization": + { + "border_count":15, + "border_type":"Uniform" + }, + "ctr_type":"Counter", + "prior_estimation":"No", + "priors": + [ + [ + 0, + 1 + ] + ] + } + ], + "store_all_simple_ctr":false, + "target_binarization": + { + "border_count":1, + "border_type":"MinEntropy" + } + }, + "data_processing_options": + { + "allow_const_label":false, + "auto_class_weights":"None", + "class_names": + [ + ], + "class_weights": + [ + ], + "classes_count":0, + "dev_default_value_fraction_for_sparse":0.82999998331069946, + "dev_group_features":false, + "dev_leafwise_scoring":false, + "dev_sparse_array_indexing":"Indices", + "embedding_processing_options": + { + "embedding_processing": + { + "default": + [ + "LDA", + "KNN" + ] + } + }, + "eval_fraction":0, + "float_features_binarization": + { + "border_count":254, + "border_type":"GreedyLogSum", + "dev_max_subset_size_for_build_borders":200000, + "nan_mode":"Min" + }, + "force_unit_auto_pair_weights":false, + "has_time":false, + "ignored_features": + [ + ], + "per_float_feature_quantization": + { + }, + "target_border":null, + "text_processing_options": + { + "dictionaries": + [ + { + "dictionary_id":"BiGram", + "end_of_sentence_token_policy":"Skip", + "end_of_word_token_policy":"Insert", + "gram_order":"2", + "max_dictionary_size":"50000", + "occurrence_lower_bound":"5", + "skip_step":"0", + "start_token_id":"0", + "token_level_type":"Word" + }, + { + "dictionary_id":"Word", + "end_of_sentence_token_policy":"Skip", + "end_of_word_token_policy":"Insert", + "gram_order":"1", + "max_dictionary_size":"50000", + "occurrence_lower_bound":"5", + "skip_step":"0", + "start_token_id":"0", + "token_level_type":"Word" + } + ], + "feature_processing": + { + "default": + [ + { + "dictionaries_names": + [ + "BiGram", + "Word" + ], + "feature_calcers": + [ + "BoW" + ], + "tokenizers_names": + [ + "Space" + ] + } + ] + }, + "tokenizers": + [ + { + "delimiter":" ", + "languages": + [ + ], + "lemmatizing":"0", + "lowercasing":"0", + "number_process_policy":"LeaveAsIs", + "number_token":"🔢", + "separator_type":"ByDelimiter", + "skip_empty":"1", + "split_by_set":"0", + "subtokens_policy":"SingleToken", + "token_types": + [ + "Number", + "Unknown", + "Word" + ], + "tokenizer_id":"Space" + } + ] + } + }, + "detailed_profile":false, + "flat_params": + { + "depth":2, + "iterations":10, + "loss_function":"RMSE", + "random_seed":0, + "verbose":0 + }, + "logging_level":"Silent", + "loss_function": + { + "params": + { + }, + "type":"RMSE" + }, + "metadata": + { + }, + "metrics": + { + "custom_metrics": + [ + ], + "eval_metric": + { + "params": + { + }, + "type":"RMSE" + }, + "objective_metric": + { + "params": + { + }, + "type":"RMSE" + } + }, + "pool_metainfo_options": + { + "tags": + { + } + }, + "random_seed":0, + "system_options": + { + "file_with_hosts":"hosts.txt", + "node_port":0, + "node_type":"SingleHost", + "thread_count":10, + "used_ram_limit":"" + }, + "task_type":"CPU", + "tree_learner_options": + { + "bayesian_matrix_reg":0.10000000149011612, + "bootstrap": + { + "mvs_reg":null, + "subsample":0.80000001192092896, + "type":"MVS" + }, + "depth":2, + "dev_efb_max_buckets":1024, + "dev_leafwise_approxes":false, + "dev_score_calc_obj_block_size":5000000, + "grow_policy":"SymmetricTree", + "l2_leaf_reg":3, + "leaf_estimation_backtracking":"AnyImprovement", + "leaf_estimation_iterations":1, + "leaf_estimation_method":"Newton", + "max_leaves":4, + "min_data_in_leaf":1, + "model_size_reg":0.5, + "monotone_constraints": + { + }, + "penalties": + { + "feature_weights": + { + }, + "first_feature_use_penalties": + { + }, + "penalties_coefficient":1, + "per_object_feature_penalties": + { + } + }, + "random_strength":1, + "rsm":1, + "sampling_frequency":"PerTree", + "score_function":"Cosine", + "sparse_features_conflict_fraction":0 + } + }, + "train_finish_time":"2023-07-11T03:12:33Z" + }, + "oblivious_trees": + [ + { + "leaf_values": + [ + -0.22147451522303563, + 0.27130434305771539, + -0.063933939368159207, + 0.17836907186413048 + ], + "leaf_weights": + [ + 252, + 20, + 330, + 398 + ], + "splits": + [ + { + "border":4.6377229690551758, + "float_feature_index":78, + "split_index":9, + "split_type":"FloatFeature" + }, + { + "border":-26.762472152709961, + "float_feature_index":122, + "split_index":16, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + -0.1027057318217771, + 0.14109578800784825, + 0, + 0.1429880835006494 + ], + "leaf_weights": + [ + 580, + 414, + 0, + 6 + ], + "splits": + [ + { + "border":3.5, + "float_feature_index":30, + "split_index":5, + "split_type":"FloatFeature" + }, + { + "border":1.5, + "float_feature_index":29, + "split_index":4, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.020148719957822723, + 0.78355534315845921, + -0.05006291928409741, + 0.093270262012351418 + ], + "leaf_weights": + [ + 102, + 9, + 657, + 232 + ], + "splits": + [ + { + "border":20.248912811279297, + "float_feature_index":108, + "split_index":13, + "split_type":"FloatFeature" + }, + { + "border":1.5, + "float_feature_index":126, + "split_index":17, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + -0.061480519876815291, + 0.080732473582315514, + 0.33705637488502921, + 0.065161926801400594 + ], + "leaf_weights": + [ + 587, + 358, + 14, + 41 + ], + "splits": + [ + { + "border":-12.996198654174805, + "float_feature_index":121, + "split_index":15, + "split_type":"FloatFeature" + }, + { + "border":5.98291015625, + "float_feature_index":77, + "split_index":8, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + -0.081261917752663435, + 0.024952196665882127, + 0.092116996387399827, + 0.071687282428878257 + ], + "leaf_weights": + [ + 316, + 515, + 86, + 83 + ], + "splits": + [ + { + "border":-25.435811996459961, + "float_feature_index":114, + "split_index":14, + "split_type":"FloatFeature" + }, + { + "border":235, + "float_feature_index":128, + "split_index":18, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + -0.045860069338519834, + 0.16221633934911087, + 0.022053619224936132, + 0.37127731022198202 + ], + "leaf_weights": + [ + 411, + 12, + 567, + 10 + ], + "splits": + [ + { + "border":4.5, + "float_feature_index":134, + "split_index":19, + "split_type":"FloatFeature" + }, + { + "border":19.769384384155273, + "float_feature_index":18, + "split_index":3, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.065004395849513938, + -0.11991603989566227, + -0.025988777276491074, + 0.12592267878180907 + ], + "leaf_weights": + [ + 88, + 6, + 777, + 129 + ], + "splits": + [ + { + "border":14.753328323364258, + "float_feature_index":74, + "split_index":6, + "split_type":"FloatFeature" + }, + { + "border":2.5, + "float_feature_index":13, + "split_index":1, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.25015765656540179, + -0.011085223248840677, + 0, + 0.049484059242209277 + ], + "leaf_weights": + [ + 18, + 880, + 0, + 102 + ], + "splits": + [ + { + "border":3.5, + "float_feature_index":14, + "split_index":2, + "split_type":"FloatFeature" + }, + { + "border":8.7581996917724609, + "float_feature_index":82, + "split_index":10, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + -0.016934519010446836, + -0.049914426296902924, + -0.063644694466597068, + 0.045150290779615992 + ], + "leaf_weights": + [ + 161, + 97, + 229, + 513 + ], + "splits": + [ + { + "border":10.594841957092285, + "float_feature_index":108, + "split_index":12, + "split_type":"FloatFeature" + }, + { + "border":33.714653015136719, + "float_feature_index":75, + "split_index":7, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.27357486662880259, + -0.10586060075148507, + -0.10261505509561733, + 0.011008999759450447 + ], + "leaf_weights": + [ + 28, + 153, + 4, + 815 + ], + "splits": + [ + { + "border":10.5, + "float_feature_index":11, + "split_index":0, + "split_type":"FloatFeature" + }, + { + "border":5.6147050857543945, + "float_feature_index":86, + "split_index":11, + "split_type":"FloatFeature" + } + ] + } + ], + "scale_and_bias": + [ + 1, + [ + 0.72600001096725464 + ] + ] +} \ No newline at end of file From 194144f488262c29d50d56605fe1ec91f9d54b01 Mon Sep 17 00:00:00 2001 From: acalejos Date: Wed, 19 Jul 2023 15:35:10 -0400 Subject: [PATCH 3/5] Add catboost tests and WIP --- lib/mockingjay/strategies/gemm.ex | 18 +- test/mockingjay/catboost_test.exs | 43 +- test/mockingjay/exgboost_test.exs | 13 +- test/support/catboost_iris.json | 1081 +++++++++++++++++++++++++++++ 4 files changed, 1140 insertions(+), 15 deletions(-) create mode 100644 test/support/catboost_iris.json diff --git a/lib/mockingjay/strategies/gemm.ex b/lib/mockingjay/strategies/gemm.ex index 74c2e03..72e19a1 100644 --- a/lib/mockingjay/strategies/gemm.ex +++ b/lib/mockingjay/strategies/gemm.ex @@ -32,11 +32,11 @@ defmodule Mockingjay.Strategies.GEMM do n_weak_learner_classes = trees |> hd() - |> Tree.get_decision_values() + |> Tree.get_leaf_nodes() |> hd() |> case do - value when is_list(value) -> - length(value) + value when is_list(value.value) -> + length(value.value) _value -> 1 @@ -48,7 +48,7 @@ defmodule Mockingjay.Strategies.GEMM do max(h2, length(Tree.get_leaf_nodes(tree)))} end) - n_trees = length(trees) + n_trees = length(trees) |> IO.inspect(label: "n_trees") {mat_A, mat_B} = generate_matrices_AB(trees, num_features, max_decision_nodes) mat_C = generate_matrix_C(trees, max_decision_nodes, max_leaf_nodes) @@ -205,7 +205,13 @@ defmodule Mockingjay.Strategies.GEMM do if n_weak_learner_classes == 1 do {[index, 0, node_index], node.value} else - {[index, trunc(node.value), node_index], 1} + cat_value = node.value |> Nx.tensor() |> Nx.argmax() |> Nx.to_number() + + {[ + index, + cat_value, + node_index + ], 1} end end) end) @@ -221,7 +227,7 @@ defmodule Mockingjay.Strategies.GEMM do d = Nx.indexed_put(d_zero, d_indices, d_updates) e_updates = Nx.tensor(updates_list) - + IO.inspect(n_weak_learner_classes, label: "n_weak_learner_classes") e_zero = Nx.broadcast(0, {n_trees, n_weak_learner_classes, max_leaf_nodes}) e = Nx.indexed_put(e_zero, e_indices, e_updates) diff --git a/test/mockingjay/catboost_test.exs b/test/mockingjay/catboost_test.exs index 3387121..96632f2 100644 --- a/test/mockingjay/catboost_test.exs +++ b/test/mockingjay/catboost_test.exs @@ -4,21 +4,56 @@ defmodule CatboostTest do alias Mockingjay.Adapters.Catboost test "load json" do - clf_booster = Catboost.load_model("test/support/catboost_classifier.json") + clf_booster = Catboost.load_model("test/support/catboost_iris.json") reg_booster = Catboost.load_model("test/support/catboost_regressor.json") end test "protocol implementation" do - clf_booster = Catboost.load_model("test/support/catboost_classifier.json") + clf_booster = Catboost.load_model("test/support/catboost_iris.json") reg_booster = Catboost.load_model("test/support/catboost_regressor.json") - for {booster, expected_num_class} <- [{clf_booster, 5}, {reg_booster, 1}] do + for {booster, expected_num_class, expected_num_feature} <- [ + {clf_booster, 3, 4}, + {reg_booster, 1, 137} + ] do trees = DecisionTree.trees(booster) assert is_list(trees) assert is_struct(hd(trees) |> IO.inspect(label: "Tree"), Mockingjay.Tree) assert DecisionTree.num_classes(booster) == expected_num_class - assert DecisionTree.num_features(booster) == 137 + assert DecisionTree.num_features(booster) == expected_num_feature end end + + test "iris performance" do + {x, y} = Scidata.Iris.download() + # data = Enum.zip(x, y) |> Enum.shuffle() + # {train, test} = Enum.split(data, ceil(length(data) * 0.8)) + # {x_train, y_train} = Enum.unzip(train) + # {x_test, y_test} = Enum.unzip(test) + x = Nx.tensor(x) + y = Nx.tensor(y) + + # x_train = Nx.tensor(x_train) + # y_train = Nx.tensor(y_train) + + # x_test = Nx.tensor(x_test) + # y_test = Nx.tensor(y_test) + + booster = Catboost.load_model("test/support/catboost_iris.json") + + gemm_predict = Mockingjay.convert(booster, strategy: :gemm, post_transform: :linear) + # tt_predict = Mockingjay.convert(booster, strategy: :tree_traversal) + # ptt_predict = Mockingjay.convert(booster, strategy: :perfect_tree_traversal) + # auto_predict = Mockingjay.convert(booster, strategy: :auto) + + gemm_preds = gemm_predict.(x) |> IO.inspect() |> Nx.argmax(axis: -1) + + gemm_accuracy = + Scholar.Metrics.accuracy(y, gemm_preds) + |> Nx.to_number() + |> IO.inspect(label: "gemm_accuracy") + + assert false + end end diff --git a/test/mockingjay/exgboost_test.exs b/test/mockingjay/exgboost_test.exs index 83587e3..95522b9 100644 --- a/test/mockingjay/exgboost_test.exs +++ b/test/mockingjay/exgboost_test.exs @@ -45,20 +45,22 @@ defmodule EXGBoostTest do assert DecisionTree.num_features(booster) == 4 end - test "compiles", context do + test "converts", context do booster = EXGBoost.train(context.x_train, context.y_train, num_class: 3, objective: :multi_softprob) - gemm_predict = EXGBoost.compile(booster, strategy: :gemm) - tt_predict = EXGBoost.compile(booster, strategy: :tree_traversal) - ptt_predict = EXGBoost.compile(booster, strategy: :perfect_tree_traversal) - auto_predict = EXGBoost.compile(booster, strategy: :auto) + gemm_predict = Mockingjay.convert(booster, strategy: :gemm) + tt_predict = Mockingjay.convert(booster, strategy: :tree_traversal) + ptt_predict = Mockingjay.convert(booster, strategy: :perfect_tree_traversal) + auto_predict = Mockingjay.convert(booster, strategy: :auto) # host_jit = EXLA.jit(compiled_predict) preds1 = EXGBoost.predict(booster, context.x_test) |> Nx.argmax(axis: -1) + context.x_test |> IO.inspect(label: "context.x_test") + preds2 = gemm_predict.(context.x_test) |> Nx.argmax(axis: -1) preds3 = tt_predict.(context.x_test) |> Nx.argmax(axis: -1) preds4 = ptt_predict.(context.x_test) |> Nx.argmax(axis: -1) @@ -88,5 +90,6 @@ defmodule EXGBoostTest do assert tt_accuracy >= base_acc assert ptt_accuracy >= base_acc assert auto_accuracy >= base_acc + assert false end end diff --git a/test/support/catboost_iris.json b/test/support/catboost_iris.json new file mode 100644 index 0000000..1d76ef6 --- /dev/null +++ b/test/support/catboost_iris.json @@ -0,0 +1,1081 @@ +{ + "features_info": + { + "float_features": + [ + { + "borders": + [ + 6.6499996185302734 + ], + "feature_id":"", + "feature_index":0, + "flat_feature_index":0, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 2.25, + 2.3499999046325684, + 2.4500000476837158, + 2.8499999046325684, + 3.4500000476837158 + ], + "feature_id":"", + "feature_index":1, + "flat_feature_index":1, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 3.6500000953674316, + 4.8500003814697266, + 5.0500001907348633, + 5.1499996185302734, + 5.3500003814697266 + ], + "feature_id":"", + "feature_index":2, + "flat_feature_index":2, + "has_nans":false, + "nan_value_treatment":"AsIs" + }, + { + "borders": + [ + 0.44999998807907104, + 0.55000001192092896, + 0.80000001192092896, + 1.25, + 1.3499999046325684, + 1.4500000476837158, + 1.6500000953674316, + 1.75, + 2.3499999046325684 + ], + "feature_id":"", + "feature_index":3, + "flat_feature_index":3, + "has_nans":false, + "nan_value_treatment":"AsIs" + } + ] + }, + "model_info": + { + "catboost_version_info":"Arc info:\n Branch: unknown-vcs-branch\n Commit: 0000000000000000000000000000000000000000\n Author: \n Summary: No VCS\n\n", + "class_params": + { + "class_label_type":"Integer", + "class_names": + [ + 0, + 1, + 2 + ], + "class_to_label": + [ + 0, + 1, + 2 + ], + "classes_count":0 + }, + "model_guid":"5d4a4e59-b7c39d71-8763a050-14fc97f4", + "output_options":"{\"name\":\"experiment\",\"verbose\":0,\"test_error_log\":\"test_error.tsv\",\"json_log\":\"catboost_training.json\",\"result_model_file\":\"model\",\"roc_file\":\"\",\"eval_file_name\":\"\",\"use_best_model\":false,\"allow_writing_files\":true,\"prediction_type\":[\"RawFormulaVal\"],\"fstr_internal_file\":\"\",\"output_columns\":[\"SampleId\",\"RawFormulaVal\",\"Label\"],\"snapshot_interval\":600,\"time_left_log\":\"time_left.tsv\",\"fstr_type\":\"FeatureImportance\",\"profile_log\":\"catboost_profile.log\",\"train_dir\":\"catboost_info\",\"learn_error_log\":\"learn_error.tsv\",\"training_options_file\":\"\",\"snapshot_file\":\"experiment.cbsnapshot\",\"save_snapshot\":false,\"model_format\":[\"CatboostBinary\"],\"final_feature_calcer_computation_mode\":\"Default\",\"metric_period\":1,\"output_borders\":\"\",\"final_ctr_computation_mode\":\"Default\",\"best_model_min_trees\":1,\"fstr_regular_file\":\"\"}", + "params": + { + "boosting_options": + { + "approx_on_full_history":false, + "boost_from_average":false, + "boosting_type":"Plain", + "fold_len_multiplier":2, + "fold_permutation_block":0, + "iterations":10, + "learning_rate":0.5, + "model_shrink_mode":"Constant", + "model_shrink_rate":0, + "od_config": + { + "stop_pvalue":0, + "type":"None", + "wait_iterations":20 + }, + "permutation_count":4, + "posterior_sampling":false + }, + "cat_feature_params": + { + "combinations_ctrs": + [ + { + "ctr_binarization": + { + "border_count":15, + "border_type":"Uniform" + }, + "ctr_type":"Borders", + "prior_estimation":"No", + "priors": + [ + [ + 0, + 1 + ], + [ + 0.5, + 1 + ], + [ + 1, + 1 + ] + ], + "target_binarization": + { + "border_count":2, + "border_type":"MinEntropy" + } + }, + { + "ctr_binarization": + { + "border_count":15, + "border_type":"Uniform" + }, + "ctr_type":"Counter", + "prior_estimation":"No", + "priors": + [ + [ + 0, + 1 + ] + ] + } + ], + "counter_calc_method":"SkipTest", + "ctr_leaf_count_limit":18446744073709551615, + "max_ctr_complexity":1, + "one_hot_max_size":2, + "per_feature_ctrs": + { + }, + "simple_ctrs": + [ + { + "ctr_binarization": + { + "border_count":15, + "border_type":"Uniform" + }, + "ctr_type":"Borders", + "prior_estimation":"No", + "priors": + [ + [ + 0, + 1 + ], + [ + 0.5, + 1 + ], + [ + 1, + 1 + ] + ], + "target_binarization": + { + "border_count":2, + "border_type":"MinEntropy" + } + }, + { + "ctr_binarization": + { + "border_count":15, + "border_type":"Uniform" + }, + "ctr_type":"Counter", + "prior_estimation":"No", + "priors": + [ + [ + 0, + 1 + ] + ] + } + ], + "store_all_simple_ctr":false, + "target_binarization": + { + "border_count":1, + "border_type":"MinEntropy" + } + }, + "data_processing_options": + { + "allow_const_label":false, + "auto_class_weights":"None", + "class_names": + [ + 0, + 1, + 2 + ], + "class_weights": + [ + ], + "classes_count":0, + "dev_default_value_fraction_for_sparse":0.82999998331069946, + "dev_group_features":false, + "dev_leafwise_scoring":false, + "dev_sparse_array_indexing":"Indices", + "embedding_processing_options": + { + "embedding_processing": + { + "default": + [ + "LDA", + "KNN" + ] + } + }, + "eval_fraction":0, + "float_features_binarization": + { + "border_count":254, + "border_type":"GreedyLogSum", + "dev_max_subset_size_for_build_borders":200000, + "nan_mode":"Min" + }, + "force_unit_auto_pair_weights":false, + "has_time":false, + "ignored_features": + [ + ], + "per_float_feature_quantization": + { + }, + "target_border":null, + "text_processing_options": + { + "dictionaries": + [ + { + "dictionary_id":"BiGram", + "end_of_sentence_token_policy":"Skip", + "end_of_word_token_policy":"Insert", + "gram_order":"2", + "max_dictionary_size":"50000", + "occurrence_lower_bound":"1", + "skip_step":"0", + "start_token_id":"0", + "token_level_type":"Word" + }, + { + "dictionary_id":"Word", + "end_of_sentence_token_policy":"Skip", + "end_of_word_token_policy":"Insert", + "gram_order":"1", + "max_dictionary_size":"50000", + "occurrence_lower_bound":"1", + "skip_step":"0", + "start_token_id":"0", + "token_level_type":"Word" + } + ], + "feature_processing": + { + "default": + [ + { + "dictionaries_names": + [ + "BiGram", + "Word" + ], + "feature_calcers": + [ + "BoW" + ], + "tokenizers_names": + [ + "Space" + ] + }, + { + "dictionaries_names": + [ + "Word" + ], + "feature_calcers": + [ + "NaiveBayes" + ], + "tokenizers_names": + [ + "Space" + ] + } + ] + }, + "tokenizers": + [ + { + "delimiter":" ", + "languages": + [ + ], + "lemmatizing":"0", + "lowercasing":"0", + "number_process_policy":"LeaveAsIs", + "number_token":"🔢", + "separator_type":"ByDelimiter", + "skip_empty":"1", + "split_by_set":"0", + "subtokens_policy":"SingleToken", + "token_types": + [ + "Number", + "Unknown", + "Word" + ], + "tokenizer_id":"Space" + } + ] + } + }, + "detailed_profile":false, + "flat_params": + { + "depth":3, + "iterations":10, + "loss_function":"MultiClass", + "random_seed":0, + "verbose":0 + }, + "logging_level":"Silent", + "loss_function": + { + "params": + { + }, + "type":"MultiClass" + }, + "metadata": + { + }, + "metrics": + { + "custom_metrics": + [ + ], + "eval_metric": + { + "params": + { + }, + "type":"MultiClass" + }, + "objective_metric": + { + "params": + { + }, + "type":"MultiClass" + } + }, + "pool_metainfo_options": + { + "tags": + { + } + }, + "random_seed":0, + "system_options": + { + "file_with_hosts":"hosts.txt", + "node_port":0, + "node_type":"SingleHost", + "thread_count":10, + "used_ram_limit":"" + }, + "task_type":"CPU", + "tree_learner_options": + { + "bayesian_matrix_reg":0.10000000149011612, + "bootstrap": + { + "bagging_temperature":1, + "type":"Bayesian" + }, + "depth":3, + "dev_efb_max_buckets":1024, + "dev_leafwise_approxes":false, + "dev_score_calc_obj_block_size":5000000, + "grow_policy":"SymmetricTree", + "l2_leaf_reg":3, + "leaf_estimation_backtracking":"AnyImprovement", + "leaf_estimation_iterations":1, + "leaf_estimation_method":"Newton", + "max_leaves":8, + "min_data_in_leaf":1, + "model_size_reg":0.5, + "monotone_constraints": + { + }, + "penalties": + { + "feature_weights": + { + }, + "first_feature_use_penalties": + { + }, + "penalties_coefficient":1, + "per_object_feature_penalties": + { + } + }, + "random_strength":1, + "rsm":1, + "sampling_frequency":"PerTree", + "score_function":"Cosine", + "sparse_features_conflict_fraction":0 + } + }, + "train_finish_time":"2023-07-19T00:13:50Z" + }, + "oblivious_trees": + [ + { + "leaf_values": + [ + 0.84745762711864536, + -0.42372881355932229, + -0.42372881355932229, + -0.41666666666666696, + 0.83333333333333437, + -0.41666666666666696, + 0, + 0, + 0, + -0.15384615384615383, + -0.038461538461538429, + 0.19230769230769237, + 0, + 0, + 0, + -0.21874999999999994, + 0.062500000000000042, + 0.15625000000000003, + 0, + 0, + 0, + -0.41509433962264175, + -0.38679245283018904, + 0.80188679245283112 + ], + "leaf_weights": + [ + 50, + 45, + 0, + 4, + 0, + 7, + 0, + 44 + ], + "splits": + [ + { + "border":0.80000001192092896, + "float_feature_index":3, + "split_index":13, + "split_type":"FloatFeature" + }, + { + "border":1.6500000953674316, + "float_feature_index":3, + "split_index":17, + "split_type":"FloatFeature" + }, + { + "border":4.8500003814697266, + "float_feature_index":2, + "split_index":7, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.05371409384946401, + -0.026857046924732009, + -0.026857046924732009, + -0.13378690974898891, + 0.26757381949797776, + -0.13378690974898894, + 0, + 0, + 0, + -0.039986660898481866, + -0.051615151110821995, + 0.09160181200930391, + 0.43916158306858172, + -0.21958079153429225, + -0.21958079153429222, + -0.21214867350977784, + 0.36190573163947565, + -0.1497570581296977, + 0, + 0, + 0, + -0.26711032255601891, + -0.12725538161472449, + 0.39436570417074179 + ], + "leaf_weights": + [ + 1, + 9, + 0, + 1, + 47, + 42, + 0, + 50 + ], + "splits": + [ + { + "border":0.44999998807907104, + "float_feature_index":3, + "split_index":11, + "split_type":"FloatFeature" + }, + { + "border":4.8500003814697266, + "float_feature_index":2, + "split_index":7, + "split_type":"FloatFeature" + }, + { + "border":2.4500000476837158, + "float_feature_index":1, + "split_index":3, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.34918772181349389, + -0.17459386090674711, + -0.17459386090674706, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.025199632212335312, + 0.098820682482484198, + -0.12402031469481938, + 0, + 0, + 0, + -0.22767126947440264, + 0.33747314653009375, + -0.10980187705569329, + -0.20715844237747794, + -0.14375006806909443, + 0.35090851044657534 + ], + "leaf_weights": + [ + 48, + 0, + 0, + 0, + 8, + 0, + 46, + 48 + ], + "splits": + [ + { + "border":1.6500000953674316, + "float_feature_index":3, + "split_index":17, + "split_type":"FloatFeature" + }, + { + "border":3.6500000953674316, + "float_feature_index":2, + "split_index":6, + "split_type":"FloatFeature" + }, + { + "border":0.44999998807907104, + "float_feature_index":3, + "split_index":11, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.31215375600815054, + -0.16139705047810729, + -0.15075670553004436, + -0.24716825933127606, + 0.25992227998765965, + -0.012754020656382923, + 0, + 0, + 0, + -0.12902852588563279, + -0.166165946434645, + 0.29519447232027851, + 0, + 0, + 0, + -0.011422312095993073, + -0.014310032381433689, + 0.025732344477426752, + 0, + 0, + 0, + -0.045918473735984879, + -0.056291979412333541, + 0.10221045314831836 + ], + "leaf_weights": + [ + 50, + 65, + 0, + 29, + 0, + 1, + 0, + 5 + ], + "splits": + [ + { + "border":0.80000001192092896, + "float_feature_index":3, + "split_index":13, + "split_type":"FloatFeature" + }, + { + "border":5.1499996185302734, + "float_feature_index":2, + "split_index":9, + "split_type":"FloatFeature" + }, + { + "border":2.3499999046325684, + "float_feature_index":3, + "split_index":19, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.26027472637771182, + -0.13228665759616232, + -0.12798806878154978, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.14757045143532338, + 0.18622704398398404, + -0.038656592548659864, + -0.12169566443189685, + -0.11783616320542475, + 0.2395318276373217, + -0.061497588450030574, + 0.17444404164777497, + -0.11294645319774443, + -0.091286548785742591, + -0.10796412182370456, + 0.19925067060944743 + ], + "leaf_weights": + [ + 49, + 0, + 0, + 0, + 48, + 25, + 7, + 21 + ], + "splits": + [ + { + "border":1.75, + "float_feature_index":3, + "split_index":18, + "split_type":"FloatFeature" + }, + { + "border":6.6499996185302734, + "float_feature_index":0, + "split_index":0, + "split_type":"FloatFeature" + }, + { + "border":0.55000001192092896, + "float_feature_index":3, + "split_index":12, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + -0.016512677613825251, + 0.071664007875940947, + -0.055151330262115647, + -0.086972291831742252, + 0.13088099125894129, + -0.04390869942719882, + -0.025596303379558448, + -0.067231954321729376, + 0.092828257701287831, + -0.10811412203852651, + -0.068707463383110151, + 0.17682158542163673, + 0, + 0, + 0, + 0.19752437109286178, + -0.021257256824621505, + -0.17626711426823929, + 0, + 0, + 0, + -0.16613167360946043, + 0.061782767037950208, + 0.10434890657151058 + ], + "leaf_weights": + [ + 6, + 19, + 2, + 20, + 0, + 61, + 0, + 42 + ], + "splits": + [ + { + "border":2.3499999046325684, + "float_feature_index":1, + "split_index":2, + "split_type":"FloatFeature" + }, + { + "border":1.4500000476837158, + "float_feature_index":3, + "split_index":16, + "split_type":"FloatFeature" + }, + { + "border":2.8499999046325684, + "float_feature_index":1, + "split_index":4, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.094599568809820414, + 0.098181768476517536, + -0.19278133728634017, + -0.035827429530033955, + -0.037712165420130256, + 0.073539594950164211, + -0.031644374130690781, + 0.10543743927593076, + -0.073793065145240047, + -0.014071313610320308, + -0.053020407250554027, + 0.067091720860874365, + -0.064099445159599519, + -0.073877379367392609, + 0.13797682452699223, + -0.05598952825199642, + -0.086266436069089333, + 0.14225596432108578, + -0.0085307173927249638, + 0.13132104978435397, + -0.12279033239162902, + -0.065628710714512414, + -0.091180699723714093, + 0.15680941043822638 + ], + "leaf_weights": + [ + 93, + 3, + 5, + 1, + 9, + 17, + 1, + 21 + ], + "splits": + [ + { + "border":5.0500001907348633, + "float_feature_index":2, + "split_index":8, + "split_type":"FloatFeature" + }, + { + "border":6.6499996185302734, + "float_feature_index":0, + "split_index":0, + "split_type":"FloatFeature" + }, + { + "border":1.6500000953674316, + "float_feature_index":3, + "split_index":17, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + 0.20229917333244413, + -0.12217460382349277, + -0.080124569508951179, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.1410956880307826, + 0.20722569018112538, + -0.06613000215034201, + -0.087836772890955428, + -0.0085643925957112864, + 0.096401165486666704, + -0.025510955048121609, + -0.091542707401980089, + 0.11705366245010165, + -0.068380551454905195, + -0.090236808610626343, + 0.15861736006553151 + ], + "leaf_weights": + [ + 50, + 0, + 0, + 0, + 50, + 20, + 2, + 28 + ], + "splits": + [ + { + "border":1.6500000953674316, + "float_feature_index":3, + "split_index":17, + "split_type":"FloatFeature" + }, + { + "border":5.3500003814697266, + "float_feature_index":2, + "split_index":10, + "split_type":"FloatFeature" + }, + { + "border":0.80000001192092896, + "float_feature_index":3, + "split_index":13, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + -0.011460547387877912, + 0.023547718772386563, + -0.012087171384508648, + 0.14048401154483015, + -0.01892527553470048, + -0.1215587360101297, + -0.021922120351086694, + -0.086098734203301525, + 0.10802085455438826, + -0.12243371205915753, + 0.17527834035526058, + -0.052844628296102862, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.091587788906639211, + -0.11290738963767658, + 0.20449517854431595 + ], + "leaf_weights": + [ + 2, + 63, + 2, + 41, + 0, + 0, + 0, + 42 + ], + "splits": + [ + { + "border":2.25, + "float_feature_index":1, + "split_index":1, + "split_type":"FloatFeature" + }, + { + "border":1.25, + "float_feature_index":3, + "split_index":14, + "split_type":"FloatFeature" + }, + { + "border":5.0500001907348633, + "float_feature_index":2, + "split_index":8, + "split_type":"FloatFeature" + } + ] + }, + { + "leaf_values": + [ + -0.01111746897746046, + 0.022835579212510018, + -0.011718110235049589, + 0.036231275585282982, + 0.078349059124290812, + -0.11458033470957384, + 0, + 0, + 0, + 0.097772443355430491, + -0.064230761527239208, + -0.033541681828191283, + -0.022015436190169182, + -0.07838860676619526, + 0.10040404295636443, + -0.14613920719576645, + 0.042379351518289816, + 0.10375985567747636, + 0, + 0, + 0, + -0.0078034395089186491, + -0.010973070590466861, + 0.018776510099385486 + ], + "leaf_weights": + [ + 2, + 54, + 0, + 22, + 2, + 67, + 0, + 3 + ], + "splits": + [ + { + "border":2.25, + "float_feature_index":1, + "split_index":1, + "split_type":"FloatFeature" + }, + { + "border":3.4500000476837158, + "float_feature_index":1, + "split_index":5, + "split_type":"FloatFeature" + }, + { + "border":1.3499999046325684, + "float_feature_index":3, + "split_index":15, + "split_type":"FloatFeature" + } + ] + } + ], + "scale_and_bias": + [ + 1, + [ + 0, + 0, + 0 + ] + ] +} \ No newline at end of file From 178cbafee257ce61b7da8a3fbe58ba34d87baaac Mon Sep 17 00:00:00 2001 From: acalejos Date: Wed, 19 Jul 2023 15:41:21 -0400 Subject: [PATCH 4/5] Make tests fail --- lib/mockingjay/strategies/perfect_tree_traversal.ex | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/lib/mockingjay/strategies/perfect_tree_traversal.ex b/lib/mockingjay/strategies/perfect_tree_traversal.ex index 4108ae1..c3d1cb2 100644 --- a/lib/mockingjay/strategies/perfect_tree_traversal.ex +++ b/lib/mockingjay/strategies/perfect_tree_traversal.ex @@ -167,7 +167,7 @@ defmodule Mockingjay.Strategies.PerfectTreeTraversal do |> Nx.reshape({:auto}) |> forward_reduce_features(x, features, thresholds, opts) - Nx.take(values |> print_value(), prev_indices) + Nx.take(values, prev_indices) |> Nx.reshape({:auto, opts[:num_trees], opts[:n_classes]}) end @@ -183,15 +183,13 @@ defmodule Mockingjay.Strategies.PerfectTreeTraversal do end defnp _inner_reduce(x, nodes, biases, acc, opts \\ []) do - gather_indices = - nodes |> print_value() |> Nx.take(acc) |> Nx.reshape({:auto, opts[:num_trees]}) + gather_indices = nodes |> Nx.take(acc) |> Nx.reshape({:auto, opts[:num_trees]}) features = Nx.take_along_axis(x, gather_indices, axis: 1) |> Nx.reshape({:auto}) acc - |> print_value() |> Nx.multiply(@factor) - |> Nx.add(opts[:condition].(features, Nx.take(biases |> print_value(), acc))) + |> Nx.add(opts[:condition].(features, Nx.take(biases, acc))) end defp make_tree_perfect(tree, current_depth, max_depth) do From 40fae5be622d232a8a78a5cb518708e20ab01e36 Mon Sep 17 00:00:00 2001 From: acalejos Date: Sat, 2 Sep 2023 15:29:35 -0400 Subject: [PATCH 5/5] Change output from n_trees -> n_trees_per_class --- lib/mockingjay.ex | 5 +---- lib/mockingjay/strategies/gemm.ex | 6 +++--- lib/mockingjay/strategies/perfect_tree_traversal.ex | 4 +++- lib/mockingjay/strategies/tree_traversal.ex | 3 ++- test/mockingjay/catboost_test.exs | 2 +- test/mockingjay/exgboost_test.exs | 3 +-- 6 files changed, 11 insertions(+), 12 deletions(-) diff --git a/lib/mockingjay.ex b/lib/mockingjay.ex index ec91cce..0e43b23 100644 --- a/lib/mockingjay.ex +++ b/lib/mockingjay.ex @@ -80,11 +80,8 @@ defmodule Mockingjay do defp aggregate(x, n_trees, n_classes) do cond do n_classes > 1 and n_trees > 1 -> - n_gbdt_classes = if n_classes > 2, do: n_classes, else: 1 - n_trees_per_class = trunc(n_trees / n_gbdt_classes) - x - |> Nx.reshape({:auto, n_gbdt_classes, n_trees_per_class}) + |> Nx.reshape({:auto, n_classes, n_trees}) |> Nx.sum(axes: [2]) n_classes > 1 and n_trees == 1 -> diff --git a/lib/mockingjay/strategies/gemm.ex b/lib/mockingjay/strategies/gemm.ex index 72e19a1..c35f9ae 100644 --- a/lib/mockingjay/strategies/gemm.ex +++ b/lib/mockingjay/strategies/gemm.ex @@ -48,7 +48,7 @@ defmodule Mockingjay.Strategies.GEMM do max(h2, length(Tree.get_leaf_nodes(tree)))} end) - n_trees = length(trees) |> IO.inspect(label: "n_trees") + n_trees = length(trees) {mat_A, mat_B} = generate_matrices_AB(trees, num_features, max_decision_nodes) mat_C = generate_matrix_C(trees, max_decision_nodes, max_leaf_nodes) @@ -99,6 +99,7 @@ defmodule Mockingjay.Strategies.GEMM do max_decision_nodes = opts[:max_decision_nodes] max_leaf_nodes = opts[:max_leaf_nodes] n_weak_learner_classes = opts[:n_weak_learner_classes] + n_trees_per_class = div(n_trees, n_classes) mat_A |> Nx.dot([1], x, [1]) @@ -111,7 +112,7 @@ defmodule Mockingjay.Strategies.GEMM do |> then(&Nx.dot(mat_E, [2], [0], &1, [1], [0])) |> Nx.reshape({n_trees, n_weak_learner_classes, :auto}) |> Nx.transpose() - |> Nx.reshape({:auto, n_trees, n_classes}) + |> Nx.reshape({:auto, n_trees_per_class, n_classes}) end # Leaves are ordered as DFS rather than BFS that internal nodes are @@ -227,7 +228,6 @@ defmodule Mockingjay.Strategies.GEMM do d = Nx.indexed_put(d_zero, d_indices, d_updates) e_updates = Nx.tensor(updates_list) - IO.inspect(n_weak_learner_classes, label: "n_weak_learner_classes") e_zero = Nx.broadcast(0, {n_trees, n_weak_learner_classes, max_leaf_nodes}) e = Nx.indexed_put(e_zero, e_indices, e_updates) diff --git a/lib/mockingjay/strategies/perfect_tree_traversal.ex b/lib/mockingjay/strategies/perfect_tree_traversal.ex index ccbbbb7..7352a73 100644 --- a/lib/mockingjay/strategies/perfect_tree_traversal.ex +++ b/lib/mockingjay/strategies/perfect_tree_traversal.ex @@ -170,6 +170,8 @@ defmodule Mockingjay.Strategies.PerfectTreeTraversal do indices, opts \\ [] ) do + n_trees_per_class = div(opts[:num_trees], opts[:n_classes]) + prev_indices = x |> Nx.take(root_features, axis: 1) @@ -179,7 +181,7 @@ defmodule Mockingjay.Strategies.PerfectTreeTraversal do |> forward_reduce_features(x, features, thresholds, opts) Nx.take(values, prev_indices) - |> Nx.reshape({:auto, opts[:num_trees], opts[:n_classes]}) + |> Nx.reshape({:auto, n_trees_per_class, opts[:n_classes]}) end deftransformp forward_reduce_features(prev_indices, x, features, thresholds, opts \\ []) do diff --git a/lib/mockingjay/strategies/tree_traversal.ex b/lib/mockingjay/strategies/tree_traversal.ex index 8dd5c9c..494d112 100644 --- a/lib/mockingjay/strategies/tree_traversal.ex +++ b/lib/mockingjay/strategies/tree_traversal.ex @@ -165,6 +165,7 @@ defmodule Mockingjay.Strategies.TreeTraversal do n_classes = opts[:n_classes] condition = opts[:condition] unroll = opts[:unroll] + n_trees_per_class = div(num_trees, n_classes) batch_size = Nx.axis_size(x, 0) @@ -197,6 +198,6 @@ defmodule Mockingjay.Strategies.TreeTraversal do values |> Nx.take(indices) - |> Nx.reshape({:auto, num_trees, n_classes}) + |> Nx.reshape({:auto, n_trees_per_class, n_classes}) end end diff --git a/test/mockingjay/catboost_test.exs b/test/mockingjay/catboost_test.exs index 96632f2..73bfd92 100644 --- a/test/mockingjay/catboost_test.exs +++ b/test/mockingjay/catboost_test.exs @@ -19,7 +19,7 @@ defmodule CatboostTest do trees = DecisionTree.trees(booster) assert is_list(trees) - assert is_struct(hd(trees) |> IO.inspect(label: "Tree"), Mockingjay.Tree) + assert is_struct(hd(trees) assert DecisionTree.num_classes(booster) == expected_num_class assert DecisionTree.num_features(booster) == expected_num_feature end diff --git a/test/mockingjay/exgboost_test.exs b/test/mockingjay/exgboost_test.exs index 95522b9..e6ece8e 100644 --- a/test/mockingjay/exgboost_test.exs +++ b/test/mockingjay/exgboost_test.exs @@ -59,7 +59,7 @@ defmodule EXGBoostTest do EXGBoost.predict(booster, context.x_test) |> Nx.argmax(axis: -1) - context.x_test |> IO.inspect(label: "context.x_test") + context.x_test preds2 = gemm_predict.(context.x_test) |> Nx.argmax(axis: -1) preds3 = tt_predict.(context.x_test) |> Nx.argmax(axis: -1) @@ -90,6 +90,5 @@ defmodule EXGBoostTest do assert tt_accuracy >= base_acc assert ptt_accuracy >= base_acc assert auto_accuracy >= base_acc - assert false end end