Skip to content

Conversation

@acalejos
Copy link
Owner

@acalejos acalejos commented Jul 5, 2023

Currently, all strategies work with JIT compilation with the caveat that Perfect Tree Traversal is only working with calls to print_value that seemingly transfer tensors to the correct backend. When those calls are absent, EXLA.jit does not work on PTT, so more works needs to be done to fix that

gather_indices =
nodes |> print_value() |> Nx.take(acc) |> Nx.reshape({:auto, opts[:num_trees]})

features = Nx.take_along_axis(x, gather_indices, axis: 1) |> Nx.reshape({:auto})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
features = Nx.take_along_axis(x, gather_indices, axis: 1) |> Nx.reshape({:auto})
features = Nx.take_along_axis(x, gather_indices, axis: 1) |> Nx.flatten()

acalejos added 6 commits July 14, 2023 23:26
Merging in the other protocol implementations since in that branch I also
redid testing suites and brough in exgboost as a dependency
…:acalejos/mockingjay into make_tree_travs_jit_compilable
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants