diff --git a/inst/include/TreeTools/renumber_tree.h b/inst/include/TreeTools/renumber_tree.h index 418f5909b..fe1c0f92b 100644 --- a/inst/include/TreeTools/renumber_tree.h +++ b/inst/include/TreeTools/renumber_tree.h @@ -14,41 +14,24 @@ namespace TreeTools { -inline void swap(int32 *a, int32 *b) { - const int32 temp = *a; - *a = *b; - *b = temp; -} - -inline void insertion_sort_by_smallest(int32* arr, const int32 arr_len, - const int32* sort_by) { - ASSERT(arr_len > 0); - switch (arr_len) { - case 1: return; - case 2: - if (sort_by[arr[0]] > sort_by[arr[1]]) { - swap(&arr[0], &arr[1]); - } - return; - } - std::sort(arr, arr + arr_len, [&](int32 a, int32 b) { - return sort_by[a] < sort_by[b]; - }); -} +// We'll use this a sentinel type to handle the unweighted case +struct NoWeights {}; +// Used to conditionally create a type +struct DummyDoubleVector {}; struct TreeData { - int32 n_edge; - int32 node_limit; + int32_t n_edge; + int32_t node_limit; - std::vector memory_block; + std::vector memory_block; - int32* parent_of; - int32* n_children; - int32* smallest_desc; - int32* children_start_idx; - int32* children_data; + int32_t* parent_of; + int32_t* n_children; + int32_t* smallest_desc; + int32_t* children_start_idx; + int32_t* children_data; - TreeData(int32 num_edges) + TreeData(int32_t num_edges) : n_edge(num_edges), node_limit(num_edges + 2), memory_block( @@ -69,137 +52,55 @@ struct TreeData { }; struct Frame { - int32 node; - int32 parent_label; - int32 child_index; - int32 child_count; - const int32* node_children; -}; - -struct PreorderState { - TreeData& data; - int32 next_edge; - int32 next_label; - int32 n_tip; - int32 root_node; - Rcpp::IntegerMatrix& ret_edges; - - PreorderState(TreeData& d, int32 nt, int32 rn, Rcpp::IntegerMatrix& edges) - : data(d), next_edge(0), next_label(nt + 2), n_tip(nt), - root_node(rn), ret_edges(edges) {} + int32_t node; + int32_t parent_label; + int32_t child_index; + int32_t child_count; + const int32_t* node_children; }; -template -inline void traverse_preorder(PreorderState& state, - const double* wt_above = nullptr, - Rcpp::NumericVector* ret_weights = nullptr) { - // Use a fixed-size stack for most trees to avoid heap allocation - constexpr size_t STACK_SIZE = 128; - std::array fast_stack; - std::vector heap_stack; - - size_t stack_pos = 0; - bool use_heap = false; - - auto push_frame = [&](const Frame& f) { - if (stack_pos < STACK_SIZE && !use_heap) { - fast_stack[stack_pos++] = f; - } else { - if (!use_heap) { - // Migrate to heap - heap_stack.reserve(STACK_SIZE * 2); - for (size_t i = 0; i < stack_pos; ++i) { - heap_stack.push_back(fast_stack[i]); - } - use_heap = true; - } - heap_stack.push_back(f); - } - }; - - auto pop_frame = [&]() -> Frame& { - return use_heap ? heap_stack.back() : fast_stack[stack_pos - 1]; - }; - - auto pop = [&]() { - if (use_heap) { - heap_stack.pop_back(); - } else { - --stack_pos; - } - }; - - auto empty = [&]() { - return use_heap ? heap_stack.empty() : stack_pos == 0; - }; - - // Initialize with root - int32 child_count = state.data.n_children[state.root_node]; - if (child_count > 0) { - push_frame({state.root_node, state.n_tip + 1, 0, child_count, - state.data.children_data + - state.data.children_start_idx[state.root_node]}); - } - - while (!empty()) { - Frame& top = pop_frame(); - - if (top.child_index == top.child_count) { - pop(); - continue; - } - - int32 child_node = top.node_children[top.child_index]; - - state.ret_edges(state.next_edge, 0) = top.parent_label; - - if constexpr (HasWeights) { - (*ret_weights)[state.next_edge] = wt_above[child_node]; - } - - if (state.data.n_children[child_node] == 0) { - // Leaf node - state.ret_edges(state.next_edge, 1) = child_node; - ++state.next_edge; - ++top.child_index; - } else { - // Internal node - int32 child_label = state.next_label++; - state.ret_edges(state.next_edge, 1) = child_label; - ++state.next_edge; - ++top.child_index; - - push_frame({child_node, child_label, 0, state.data.n_children[child_node], - state.data.children_data + state.data.children_start_idx[child_node]}); - } - } -} - -// Separate functions to avoid template bloat in the setup code -// There's scope for less repetition, but not, it seems, without incurring -// a slowdown -- see https://github.com/ms609/TreeTools/pull/205 -inline Rcpp::IntegerMatrix preorder_unweighted_impl( +template +inline RetType preorder_core( const Rcpp::IntegerVector& parent, - const Rcpp::IntegerVector& child) { + const Rcpp::IntegerVector& child, + const W& weights) +{ + const int32_t n_edge = parent.length(); + if (R_xlen_t(2LL + child.length() + 2LL + child.length()) > R_xlen_t(INT_FAST32_MAX)) { + Rcpp::stop("Too many edges in tree: Contact 'TreeTools' maintainer for support."); + } - const int32 n_edge = parent.length(); if (child.length() != n_edge) { Rcpp::stop("Length of parent and child must match"); } TreeData data(n_edge); - int32 root_node = 0; - int32 n_tip = 0; + int32_t root_node = n_edge * 2; + int32_t n_tip = 0; - for (int32 i = n_edge; i--; ) { - const int32 child_i = child[i]; - const int32 parent_i = parent[i]; + std::conditional_t, DummyDoubleVector, std::vector> wt_above_storage; + const std::vector* wt_above_ptr = nullptr; + + if constexpr (!std::is_same_v) { + if (weights.length() != n_edge) { + Rcpp::stop("weights must match number of edges"); + } + wt_above_storage.resize(data.node_limit); + wt_above_ptr = &wt_above_storage; + } + + for (int32_t i = 0; i < n_edge; ++i) { + const int32_t child_i = child[i]; + const int32_t parent_i = parent[i]; data.parent_of[child_i] = parent_i; ++data.n_children[parent_i]; + if constexpr (!std::is_same_v) { + wt_above_storage[child_i] = weights[i]; + } } - int32 current_idx = 0; - for (int32 i = 1; i < data.node_limit; i++) { + int32_t current_idx = 0; + for (int32_t i = 1; i < data.node_limit; i++) { if (!data.parent_of[i]) { root_node = i; } @@ -211,9 +112,9 @@ inline Rcpp::IntegerMatrix preorder_unweighted_impl( } } - for (int32 tip = 1; tip < n_tip + 1; ++tip) { + for (int32_t tip = 1; tip < n_tip + 1; ++tip) { data.smallest_desc[tip] = tip; - int32 parent_node = data.parent_of[tip]; + int32_t parent_node = data.parent_of[tip]; while (parent_node && !data.smallest_desc[parent_node]) { data.smallest_desc[parent_node] = tip; parent_node = data.parent_of[parent_node]; @@ -221,126 +122,136 @@ inline Rcpp::IntegerMatrix preorder_unweighted_impl( } std::fill(data.n_children, data.n_children + data.node_limit, 0); - for (int32 i = 0; i < n_edge; ++i) { - int32 p = parent[i]; - int32 insert_pos = data.children_start_idx[p] + data.n_children[p]; + for (int32_t i = 0; i < n_edge; ++i) { + int32_t p = parent[i]; + int32_t insert_pos = data.children_start_idx[p] + data.n_children[p]; data.children_data[insert_pos] = child[i]; ++data.n_children[p]; } - - for (int32 node = n_tip + 1; node < data.node_limit; ++node) { - int32* node_children = data.children_data + data.children_start_idx[node]; - insertion_sort_by_smallest(node_children, data.n_children[node], - data.smallest_desc); + + for (int32_t node = n_tip + 1; node < data.node_limit; ++node) { + int32_t* node_children = data.children_data + data.children_start_idx[node]; + std::sort(node_children, node_children + data.n_children[node], + [&](int32_t a, int32_t b) { + return data.smallest_desc[a] < data.smallest_desc[b]; + }); } - Rcpp::IntegerMatrix ret_edges(n_edge, 2); - PreorderState state(data, n_tip, root_node, ret_edges); + int32_t next_edge = 0; + int32_t next_label = n_tip + 2; - traverse_preorder(state); + Rcpp::IntegerMatrix ret_edges(n_edge, 2); - return ret_edges; -} - -inline std::pair preorder_weighted_impl( - const Rcpp::IntegerVector& parent, - const Rcpp::IntegerVector& child, - const Rcpp::DoubleVector& weights) { + std::conditional_t, DummyDoubleVector, Rcpp::NumericVector> ret_weights; - const int32 n_edge = parent.length(); - if (child.length() != n_edge || weights.length() != n_edge) { - Rcpp::stop("Length mismatch"); + if constexpr (!std::is_same_v) { + ret_weights = Rcpp::NumericVector(n_edge); } - TreeData data(n_edge); - std::vector wt_above_storage(data.node_limit); - int32 root_node = 0; - int32 n_tip = 0; + std::stack stack; + int32_t root_label = n_tip + 1; - // Setup with weights - for (int32 i = n_edge; i--; ) { - const int32 child_i = child[i]; - const int32 parent_i = parent[i]; - data.parent_of[child_i] = parent_i; - ++data.n_children[parent_i]; - wt_above_storage[child_i] = weights[i]; + // Initialize with root node children + { + int32_t child_count = data.n_children[root_node]; + if (child_count > 0) { + stack.push(Frame{root_node, root_label, 0, child_count, data.children_data + data.children_start_idx[root_node]}); + } } - // ... rest of setup same as unweighted ... - int32 current_idx = 0; - for (int32 i = 1; i < data.node_limit; i++) { - if (!data.parent_of[i]) { - root_node = i; + while (!stack.empty()) { + Frame& top = stack.top(); + + if (top.child_index == top.child_count) { + stack.pop(); + continue; } - if (!data.n_children[i]) { - ++n_tip; - } else { - data.children_start_idx[i] = current_idx; - current_idx += data.n_children[i]; + + int32_t child_node = top.node_children[top.child_index]; + + ret_edges(next_edge, 0) = top.parent_label; + if constexpr (!std::is_same_v) { + ret_weights[next_edge] = (*wt_above_ptr)[child_node]; } - } - - for (int32 tip = 1; tip < n_tip + 1; ++tip) { - data.smallest_desc[tip] = tip; - int32 parent_node = data.parent_of[tip]; - while (parent_node && !data.smallest_desc[parent_node]) { - data.smallest_desc[parent_node] = tip; - parent_node = data.parent_of[parent_node]; + + if (data.n_children[child_node] == 0) { + ret_edges(next_edge, 1) = child_node; + ++next_edge; + ++top.child_index; + } else { + int32_t child_label = next_label++; + ret_edges(next_edge, 1) = child_label; + ++next_edge; + ++top.child_index; + + int32_t child_count = data.n_children[child_node]; + const int32_t* child_children = data.children_data + data.children_start_idx[child_node]; + stack.push(Frame{child_node, child_label, 0, child_count, child_children}); } } - std::fill(data.n_children, data.n_children + data.node_limit, 0); - for (int32 i = 0; i < n_edge; ++i) { - int32 p = parent[i]; - int32 insert_pos = data.children_start_idx[p] + data.n_children[p]; - data.children_data[insert_pos] = child[i]; - ++data.n_children[p]; + if constexpr (std::is_same_v) { + return ret_edges; + } else { + return std::make_pair(ret_edges, ret_weights); } - - for (int32 node = n_tip + 1; node < data.node_limit; ++node) { - int32* node_children = data.children_data + data.children_start_idx[node]; - insertion_sort_by_smallest(node_children, data.n_children[node], - data.smallest_desc); - } - - Rcpp::IntegerMatrix ret_edges(n_edge, 2); - Rcpp::NumericVector ret_weights(n_edge); - PreorderState state(data, n_tip, root_node, ret_edges); - - traverse_preorder(state, wt_above_storage.data(), &ret_weights); - - return std::make_pair(ret_edges, ret_weights); } +// === PUBLIC EXPORTED FUNCTIONS === // [[Rcpp::export]] inline Rcpp::IntegerMatrix preorder_edges_and_nodes( const Rcpp::IntegerVector parent, - const Rcpp::IntegerVector child) { - return preorder_unweighted_impl(parent, child); + const Rcpp::IntegerVector child) +{ + return preorder_core(parent, child, NoWeights{}); } // [[Rcpp::export]] inline Rcpp::List preorder_weighted( const Rcpp::IntegerVector& parent, const Rcpp::IntegerVector& child, - const Rcpp::DoubleVector& weight) { + const Rcpp::DoubleVector& weight) +{ + std::pair result = + preorder_core>(parent, child, weight); - auto result = preorder_weighted_impl(parent, child, weight); - return Rcpp::List::create( + Rcpp::List ret = Rcpp::List::create( Rcpp::Named("edge") = result.first, Rcpp::Named("edge.length") = result.second ); -} - -inline std::pair preorder_weighted_pair( - const Rcpp::IntegerVector& parent, - const Rcpp::IntegerVector& child, - const Rcpp::DoubleVector& weights) { - return preorder_weighted_impl(parent, child, weights); + return ret; } + + + + + + + inline int32 get_subtree_size(int32 node, int32 *subtree_size, + int32 *n_children, int32 **children_of, + int32 n_edge) { + if (!subtree_size[node]) { + for (int32 i = n_children[node]; i--; ) { + subtree_size[node] += get_subtree_size(children_of[node][i], + subtree_size, n_children, children_of, n_edge); + } + } + return subtree_size[node]; + } + + + + + + + + + + + template struct SmallBuffer { diff --git a/inst/include/TreeTools/root_tree.h b/inst/include/TreeTools/root_tree.h index 7abc5f1b6..2dd7a27ec 100644 --- a/inst/include/TreeTools/root_tree.h +++ b/inst/include/TreeTools/root_tree.h @@ -13,10 +13,11 @@ namespace TreeTools { const Rcpp::IntegerVector parent, const Rcpp::IntegerVector child); - extern inline std::pair preorder_weighted_pair( - const Rcpp::IntegerVector& parent, - const Rcpp::IntegerVector& child, - const Rcpp::DoubleVector& weights); + template + extern inline RetType preorder_core( + const Rcpp::IntegerVector& parent, + const Rcpp::IntegerVector& child, + const W& weights); // edge must be BINARY // edge must be in preorder @@ -112,7 +113,7 @@ namespace TreeTools { Rcpp::IntegerVector parent_col(edge(Rcpp::_, 0)); Rcpp::IntegerVector child_col(edge(Rcpp::_, 1)); Rcpp::NumericVector edge_len = phy["edge.length"]; - std::tie(edge, weight) = preorder_weighted_pair( + std::tie(edge, weight) = preorder_core>( parent_col, child_col, edge_len @@ -180,7 +181,7 @@ namespace TreeTools { new_child[root_edges[spare_edge]] = outgroup; if (weighted) { - std::tie(edge, weight) = preorder_weighted_pair(new_parent, new_child, weight); + std::tie(edge, weight) = preorder_core>(new_parent, new_child, weight); ret["edge"] = edge; ret["edge.length"] = weight; } else { @@ -210,7 +211,7 @@ namespace TreeTools { Rcpp::NumericVector new_wt(n_edge + 1); std::copy(weight.begin(), weight.end(), new_wt.begin()); new_wt[n_edge] = 0; - std::tie(edge, weight) = preorder_weighted_pair(new_parent, new_child, new_wt); + std::tie(edge, weight) = preorder_core>(new_parent, new_child, new_wt); ret["edge"] = edge; ret["edge.length"] = weight; } else {