From 263b72eebdadf1d72d88290bdb4b0cee97ea74f5 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Fri, 9 May 2025 10:40:09 -0600 Subject: [PATCH 01/18] added runtime function to inrange searches --- src/ball_tree.jl | 16 ++++++++++------ src/brute_tree.jl | 11 ++++++++--- src/inrange.jl | 24 ++++++++++++++++++++++++ src/kd_tree.jl | 16 ++++++++++------ src/tree_ops.jl | 5 ++++- 5 files changed, 56 insertions(+), 16 deletions(-) diff --git a/src/ball_tree.jl b/src/ball_tree.jl index 1be8cc3..6d2245d 100644 --- a/src/ball_tree.jl +++ b/src/ball_tree.jl @@ -177,16 +177,20 @@ end function _inrange(tree::BallTree{V}, point::AbstractVector, radius::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) where {V} + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + point_index::Int = 1, + runtime_function::Union{Nothing, Function} = nothing) where {V} ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball" - return inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder + return inrange_kernel!(tree, 1, point, ball, idx_in_ball, runtime_function, point_index) # Call the recursive range finder end function inrange_kernel!(tree::BallTree, index::Int, point::AbstractVector, query_ball::HyperSphere, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + runtime_function::Union{Nothing, Function}, + point_index::Int) if index > length(tree.hyper_spheres) return 0 @@ -204,7 +208,7 @@ function inrange_kernel!(tree::BallTree, # At a leaf node, check all points in the leaf node if isleaf(tree.tree_data.n_internal_nodes, index) r = tree.metric isa MinkowskiMetric ? eval_pow(tree.metric, query_ball.r) : query_ball.r - return add_points_inrange!(idx_in_ball, tree, index, point, r) + return add_points_inrange!(idx_in_ball, tree, index, point, r, runtime_function, point_index) end count = 0 @@ -215,8 +219,8 @@ function inrange_kernel!(tree::BallTree, count += addall(tree, index, idx_in_ball) else # Recursively call the left and right sub tree. - count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball) - count += inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball) + count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball, runtime_function, point_index) + count += inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball, runtime_function, point_index) end return count end diff --git a/src/brute_tree.jl b/src/brute_tree.jl index ed5ce9a..313b1a5 100644 --- a/src/brute_tree.jl +++ b/src/brute_tree.jl @@ -61,21 +61,26 @@ end function _inrange(tree::BruteTree, point::AbstractVector, radius::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) - return inrange_kernel!(tree, point, radius, idx_in_ball) + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + point_index::Int = 1, + runtime_function::Union{Nothing, Function} = nothing) + return inrange_kernel!(tree, point, radius, idx_in_ball, runtime_function, point_index) end function inrange_kernel!(tree::BruteTree, point::AbstractVector, r::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + runtime_function::Union{Nothing, Function}, + point_index::Int) count = 0 for i in 1:length(tree.data) d = evaluate(tree.metric, tree.data[i], point) if d <= r count += 1 idx_in_ball !== nothing && push!(idx_in_ball, i) + !isnothing(runtime_function) && runtime_function(point_index, idx, point) end end return count diff --git a/src/inrange.jl b/src/inrange.jl index d271639..017ffbe 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -104,3 +104,27 @@ function inrangecount(tree::NNTree{V}, point::AbstractMatrix{T}, radius::Number) end return inrangecount(tree, new_data, radius) end + +""" + inrange_runtime(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, runtime_function::Function) where {V, T <: Number} + + Compute a runtime function for all in range queries. +""" +function inrange_runtime!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, runtime_function::F) where {V, T <: AbstractVector, F} + check_input(tree, points) + check_radius(radius) + + for i in eachindex(points) + _inrange(tree, points[i], radius, nothing, i, runtime_function) + end + return nothing +end + +function inrange_runtime!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, runtime_function::F) where {V, T <: Number, F} + check_input(tree, points) + check_radius(radius) + for i in axes(points,2) + _inrange(tree, view(points,:,i), radius, nothing, i, runtime_function) + end + return nothing +end \ No newline at end of file diff --git a/src/kd_tree.jl b/src/kd_tree.jl index 5518d7d..8327cdb 100644 --- a/src/kd_tree.jl +++ b/src/kd_tree.jl @@ -207,10 +207,12 @@ end function _inrange(tree::KDTree, point::AbstractVector, radius::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}} = Int[]) + idx_in_ball::Union{Nothing, Vector{<:Integer}} = Int[], + point_index::Int = 1, + runtime_function::Union{Nothing, Function} = nothing) init_min = get_min_distance_no_end(tree.metric, tree.hyper_rec, point) return inrange_kernel!(tree, 1, point, eval_pow(tree.metric, radius), idx_in_ball, - tree.hyper_rec, init_min) + tree.hyper_rec, init_min, runtime_function, point_index) end # Explicitly check the distance between leaf node and point while traversing @@ -220,7 +222,9 @@ function inrange_kernel!(tree::KDTree, r::Number, idx_in_ball::Union{Nothing, Vector{<:Integer}}, hyper_rec::HyperRectangle, - min_dist) + min_dist, + runtime_function::Union{Nothing, Function}, + point_index::Int) # Point is outside hyper rectangle, skip the whole sub tree if min_dist > r return 0 @@ -228,7 +232,7 @@ function inrange_kernel!(tree::KDTree, # At a leaf node. Go through all points in node and add those in range if isleaf(tree.tree_data.n_internal_nodes, index) - return add_points_inrange!(idx_in_ball, tree, index, point, r) + return add_points_inrange!(idx_in_ball, tree, index, point, r, runtime_function, point_index) end split_val = tree.split_vals[index] @@ -255,7 +259,7 @@ function inrange_kernel!(tree::KDTree, ddiff = max(zero(lo - p_dim), lo - p_dim) end # Call closer sub tree - count += inrange_kernel!(tree, close, point, r, idx_in_ball, hyper_rec_close, min_dist) + count += inrange_kernel!(tree, close, point, r, idx_in_ball, hyper_rec_close, min_dist, runtime_function, point_index) # TODO: We could potentially also keep track of the max distance # between the point and the hyper rectangle and add the whole sub tree @@ -267,6 +271,6 @@ function inrange_kernel!(tree::KDTree, ddiff_pow = eval_pow(M, ddiff) diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim) new_min = eval_reduce(M, min_dist, diff_tot) - count += inrange_kernel!(tree, far, point, r, idx_in_ball, hyper_rec_far, new_min) + count += inrange_kernel!(tree, far, point, r, idx_in_ball, hyper_rec_far, new_min, runtime_function, point_index) return count end diff --git a/src/tree_ops.jl b/src/tree_ops.jl index 39338cf..d4aa99b 100644 --- a/src/tree_ops.jl +++ b/src/tree_ops.jl @@ -115,13 +115,16 @@ end # This will probably prevent SIMD and other optimizations so some care is needed # to evaluate if it is worth it. @inline function add_points_inrange!(idx_in_ball::Union{Nothing, AbstractVector{<:Integer}}, tree::NNTree, - index::Int, point::AbstractVector, r::Number) + index::Int, point::AbstractVector, r::Number, + runtime_function::Union{Nothing, Function}, + point_index::Int) count = 0 for z in get_leaf_range(tree.tree_data, index) idx = tree.reordered ? z : tree.indices[z] if check_in_range(tree.metric, tree.data[idx], point, r) count += 1 idx_in_ball !== nothing && push!(idx_in_ball, idx) + !isnothing(runtime_function) && runtime_function(point_index, idx, point) end end return count From 0196134ce37bd250d9e58e1953104e577209bb83 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Fri, 9 May 2025 11:54:40 -0600 Subject: [PATCH 02/18] bug fix --- src/tree_ops.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tree_ops.jl b/src/tree_ops.jl index d4aa99b..6fd90ea 100644 --- a/src/tree_ops.jl +++ b/src/tree_ops.jl @@ -124,7 +124,8 @@ end if check_in_range(tree.metric, tree.data[idx], point, r) count += 1 idx_in_ball !== nothing && push!(idx_in_ball, idx) - !isnothing(runtime_function) && runtime_function(point_index, idx, point) + !isnothing(runtime_function) && runtime_function(point_index, tree.reordered ? tree.indices[idx] : idx, point) + end end return count From 89ba4c653dda4791c4d6ca05ef6e2289dda21870 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Fri, 9 May 2025 12:48:44 -0600 Subject: [PATCH 03/18] accelerated --- src/inrange.jl | 11 +++++++++-- src/tree_ops.jl | 4 ++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/inrange.jl b/src/inrange.jl index 017ffbe..83f11eb 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -121,10 +121,17 @@ function inrange_runtime!(tree::NNTree{V}, points::AbstractVector{T}, radius::Nu end function inrange_runtime!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, runtime_function::F) where {V, T <: Number, F} + dim = size(points, 1) + return inrange_runtime!(tree, points, radius, runtime_function, Val(dim)) +end + +function inrange_runtime!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, runtime_function::F, ::Val{dim}) where {V, T <: Number, F, dim} check_input(tree, points) check_radius(radius) - for i in axes(points,2) - _inrange(tree, view(points,:,i), radius, nothing, i, runtime_function) + n_points = size(points, 2) + for i in 1:n_points + point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim))) + _inrange(tree, point, radius, nothing, i, runtime_function) end return nothing end \ No newline at end of file diff --git a/src/tree_ops.jl b/src/tree_ops.jl index 6fd90ea..5dccd6e 100644 --- a/src/tree_ops.jl +++ b/src/tree_ops.jl @@ -119,12 +119,12 @@ end runtime_function::Union{Nothing, Function}, point_index::Int) count = 0 - for z in get_leaf_range(tree.tree_data, index) + @inbounds for z in get_leaf_range(tree.tree_data, index) idx = tree.reordered ? z : tree.indices[z] if check_in_range(tree.metric, tree.data[idx], point, r) count += 1 idx_in_ball !== nothing && push!(idx_in_ball, idx) - !isnothing(runtime_function) && runtime_function(point_index, tree.reordered ? tree.indices[idx] : idx, point) + @inbounds !isnothing(runtime_function) && runtime_function(point_index, tree.reordered ? tree.indices[idx] : idx, point) end end From 41907a2cb03fc6a3b7ae12ba2dd6956014ddce21 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Fri, 9 May 2025 21:16:34 -0600 Subject: [PATCH 04/18] small reformat --- src/inrange.jl | 3 +-- src/tree_ops.jl | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/inrange.jl b/src/inrange.jl index 83f11eb..20748e8 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -121,8 +121,7 @@ function inrange_runtime!(tree::NNTree{V}, points::AbstractVector{T}, radius::Nu end function inrange_runtime!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, runtime_function::F) where {V, T <: Number, F} - dim = size(points, 1) - return inrange_runtime!(tree, points, radius, runtime_function, Val(dim)) + return inrange_runtime!(tree, points, radius, runtime_function, Val(size(points, 1))) end function inrange_runtime!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, runtime_function::F, ::Val{dim}) where {V, T <: Number, F, dim} diff --git a/src/tree_ops.jl b/src/tree_ops.jl index 5dccd6e..c937b1b 100644 --- a/src/tree_ops.jl +++ b/src/tree_ops.jl @@ -125,7 +125,6 @@ end count += 1 idx_in_ball !== nothing && push!(idx_in_ball, idx) @inbounds !isnothing(runtime_function) && runtime_function(point_index, tree.reordered ? tree.indices[idx] : idx, point) - end end return count From 200344b9dcc058b3f9b40f3d43fd1fe7f7675c40 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Sat, 10 May 2025 14:52:32 -0600 Subject: [PATCH 05/18] added test and docstring --- src/inrange.jl | 24 +++++++++++++++++++++++- test/test_inrange.jl | 22 ++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/src/inrange.jl b/src/inrange.jl index 20748e8..7aedf3f 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -108,7 +108,29 @@ end """ inrange_runtime(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, runtime_function::Function) where {V, T <: Number} - Compute a runtime function for all in range queries. +Compute a runtime function for all in range queries. +Instead of returning the indicies, the `runtime_function` is called for each point in points +and is given the points, the index of the point, and the index of the neighbor. +The `runtime_function` should return nothing. +The `runtime_function` should be of the form: +runtime_function(point_index::Int, neighbor_index::Int, point::AbstractVector{T}) +where `point_index` is the index of the point in `points`, `neighbor_index` is the index of the neighbor in the tree, +and `point` is the point in points. + +The `runtime_function` should not modify the tree or the points. + +Ananymous functions can be used as well. +For example: +```julia +function runtime_function(point_index, neighbor_index, point, random_storage_of_results, neightbors_data) + # do something with the points + return nothing +end + +random_storage_of_results = rand(3, 100) +neightbors_data = rand(3, 100) +f(a, b, c) = runtime_function(a, b, c, random_storage_of_results, neightbors_data) +``` """ function inrange_runtime!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, runtime_function::F) where {V, T <: AbstractVector, F} check_input(tree, points) diff --git a/test/test_inrange.jl b/test/test_inrange.jl index 83fc56d..b48aa81 100644 --- a/test/test_inrange.jl +++ b/test/test_inrange.jl @@ -93,3 +93,25 @@ end @test idxs == idxs2 end end + +@testset "inrange_runtime function" begin + function runtime_test(point_index, neighbor_index, point, sum_of_random_data, neightbor_points) + sum_of_random_data += sum(neightbor_points[4:6,neighbor_index]) + return nothing + end + + for T in (KDTree, BallTree, BruteTree) + sum_runtime = 0.0 + data = rand(6, 100) # first 3 rows are l"ocations", last 3 rows are random data + f(a, b, c) = runtime_test(a, b, c, sum_runtime, data) + + tree = T(data[1:3, :]) + idxs = inrange(tree, [0.5, 0.5, 0.5], 1.0) + sum_idxs = 0.0 + for i in eachindex(idxs) + sum_idxs == sum(data[4:6, idxs[i]]) + end + + @test sum_idxs == sum_runtime + end +end \ No newline at end of file From e8cdbaa90fbfbec93660c66d80ae97682cd5db61 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Sat, 10 May 2025 14:54:29 -0600 Subject: [PATCH 06/18] fix doc string --- src/inrange.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/inrange.jl b/src/inrange.jl index 7aedf3f..a1dfd59 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -106,7 +106,7 @@ function inrangecount(tree::NNTree{V}, point::AbstractMatrix{T}, radius::Number) end """ - inrange_runtime(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, runtime_function::Function) where {V, T <: Number} + inrange_runtime!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, runtime_function::Function) where {V, T <: Number} Compute a runtime function for all in range queries. Instead of returning the indicies, the `runtime_function` is called for each point in points From 8f632aa13350e3356ecb9826f4afbcdbb3107bb2 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Sat, 10 May 2025 14:55:05 -0600 Subject: [PATCH 07/18] fix docstring --- src/inrange.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/inrange.jl b/src/inrange.jl index a1dfd59..9e5ef07 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -106,7 +106,7 @@ function inrangecount(tree::NNTree{V}, point::AbstractMatrix{T}, radius::Number) end """ - inrange_runtime!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, runtime_function::Function) where {V, T <: Number} + inrange_runtime!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, runtime_function::Function) Compute a runtime function for all in range queries. Instead of returning the indicies, the `runtime_function` is called for each point in points From 91e96b5955d28699df7f5ea1f67d785ad46c7fde Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Sat, 10 May 2025 15:21:23 -0600 Subject: [PATCH 08/18] update readme --- README.md | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/README.md b/README.md index 12814af..f7b4b50 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,41 @@ inrange!(idxs, balltree, point, r) neighborscount = inrangecount(balltree, point, r) ``` +### Passing a runtime function into the range search +```julia +inrange_runtime!(tree, points, radius, runtime_function) +``` + +Example: +```julia +using NearestNeighbors +data = rand(3,10^4) +data_values = rand(10^4) +r = 0.05 +points = rand(3,10) +results = zeros(10) + +# this function will sum the `data_values` corresponding to the `data` that is in range of `points` +# `p_idx` is the index of `points` i.e. 1-10 +# `data_idx` is is the index of the data in the tree that is in range +# `p` is `points[:,p_idx]` +# `values` is data needed for the operation +# `results` is a storage space for the results +function sum_values!(p_idx, data_idx, p, values, results) + results[p_idx] += values[data_idx] +end + +# `runtime_function` must be of the form f(p_idx, data_idx, p) +runtime_function(p_idx, data_idx, p) = sum_values!(p_idx, data_idx, p, values, results) + +kdtree = KDTree(data) + +# runs the runtime_function with all tree data points in range of points. In this case sums the `data_values` corresponding to the `data` that is in range of `points` +inrange_runtime!(tree, points, radius, runtime_function) +``` + + + ## Using On-Disk Data Sets By default, trees store a copy of the `data` provided during construction. For data sets larger than available memory, `DataFreeTree` can be used to strip a tree of its data field and re-link it later. From 2a7cea5fc1fef92d34c42e146e9027bc63eb1140 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Mon, 12 May 2025 10:20:39 -0600 Subject: [PATCH 09/18] fixed brute tree implimentation for runtime functions --- src/brute_tree.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brute_tree.jl b/src/brute_tree.jl index 313b1a5..1b8c038 100644 --- a/src/brute_tree.jl +++ b/src/brute_tree.jl @@ -80,7 +80,7 @@ function inrange_kernel!(tree::BruteTree, if d <= r count += 1 idx_in_ball !== nothing && push!(idx_in_ball, i) - !isnothing(runtime_function) && runtime_function(point_index, idx, point) + !isnothing(runtime_function) && runtime_function(point_index, i, point) end end return count From 12ffe5e86ca0ca2c51578751d22b1328a64476e0 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Mon, 12 May 2025 10:52:11 -0600 Subject: [PATCH 10/18] fix test --- src/NearestNeighbors.jl | 2 +- src/inrange.jl | 5 +++++ test/test_inrange.jl | 13 +++++++------ 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index 6edbc1c..b56057c 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -7,7 +7,7 @@ using StaticArrays import Base.show export NNTree, BruteTree, KDTree, BallTree, DataFreeTree -export knn, knn!, nn, inrange, inrange!,inrangecount # TODOs? , allpairs, distmat, npairs +export knn, knn!, nn, inrange, inrange!,inrangecount, inrange_runtime! # TODOs? , allpairs, distmat, npairs export injectdata export Euclidean, diff --git a/src/inrange.jl b/src/inrange.jl index 9e5ef07..6f26534 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -146,6 +146,11 @@ function inrange_runtime!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Nu return inrange_runtime!(tree, points, radius, runtime_function, Val(size(points, 1))) end +function inrange_runtime!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, runtime_function::F) where {V, T <: Number, F} + points = reshape(points, size(points, 1), 1) + return inrange_runtime!(tree, points, radius, runtime_function, Val(size(points, 1))) +end + function inrange_runtime!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, runtime_function::F, ::Val{dim}) where {V, T <: Number, F, dim} check_input(tree, points) check_radius(radius) diff --git a/test/test_inrange.jl b/test/test_inrange.jl index b48aa81..0a5f5c9 100644 --- a/test/test_inrange.jl +++ b/test/test_inrange.jl @@ -96,22 +96,23 @@ end @testset "inrange_runtime function" begin function runtime_test(point_index, neighbor_index, point, sum_of_random_data, neightbor_points) - sum_of_random_data += sum(neightbor_points[4:6,neighbor_index]) + sum_of_random_data[1] += sum(neightbor_points[4:6,neighbor_index]) return nothing end for T in (KDTree, BallTree, BruteTree) - sum_runtime = 0.0 - data = rand(6, 100) # first 3 rows are l"ocations", last 3 rows are random data + sum_runtime = fill(0.0, 1) + data = rand(6, 100) # first 3 rows are "locations", last 3 rows are random data f(a, b, c) = runtime_test(a, b, c, sum_runtime, data) - tree = T(data[1:3, :]) + tree = KDTree(data[1:3, :]) + inrange_runtime!(tree, [0.5, 0.5, 0.5], 1.0, f) idxs = inrange(tree, [0.5, 0.5, 0.5], 1.0) sum_idxs = 0.0 for i in eachindex(idxs) - sum_idxs == sum(data[4:6, idxs[i]]) + sum_idxs += sum(data[4:6, idxs[i]]) end - @test sum_idxs == sum_runtime + @test sum_idxs == sum_runtime[1] end end \ No newline at end of file From 3df422ae3271305d960f920deb39989fdece5365 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Fri, 11 Jul 2025 08:06:40 -0600 Subject: [PATCH 11/18] Removed unnecessary docstring --- src/inrange.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/inrange.jl b/src/inrange.jl index 6f26534..575c562 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -117,9 +117,6 @@ runtime_function(point_index::Int, neighbor_index::Int, point::AbstractVector{T} where `point_index` is the index of the point in `points`, `neighbor_index` is the index of the neighbor in the tree, and `point` is the point in points. -The `runtime_function` should not modify the tree or the points. - -Ananymous functions can be used as well. For example: ```julia function runtime_function(point_index, neighbor_index, point, random_storage_of_results, neightbors_data) From d562574b2ff88d882078912aa78a1e2666364117 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Fri, 11 Jul 2025 08:33:01 -0600 Subject: [PATCH 12/18] change runtime to callback and inrange_runtime! to inrange_callback! --- README.md | 10 +++++----- src/NearestNeighbors.jl | 2 +- src/ball_tree.jl | 12 ++++++------ src/brute_tree.jl | 8 ++++---- src/inrange.jl | 30 +++++++++++++++--------------- src/kd_tree.jl | 12 ++++++------ src/tree_ops.jl | 4 ++-- test/test_inrange.jl | 2 +- 8 files changed, 40 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index f7b4b50..4c6b816 100644 --- a/README.md +++ b/README.md @@ -166,7 +166,7 @@ neighborscount = inrangecount(balltree, point, r) ### Passing a runtime function into the range search ```julia -inrange_runtime!(tree, points, radius, runtime_function) +inrange_callback!(tree, points, radius, callback) ``` Example: @@ -188,13 +188,13 @@ function sum_values!(p_idx, data_idx, p, values, results) results[p_idx] += values[data_idx] end -# `runtime_function` must be of the form f(p_idx, data_idx, p) -runtime_function(p_idx, data_idx, p) = sum_values!(p_idx, data_idx, p, values, results) +# `callback` must be of the form f(p_idx, data_idx, p) +callback(p_idx, data_idx, p) = sum_values!(p_idx, data_idx, p, values, results) kdtree = KDTree(data) -# runs the runtime_function with all tree data points in range of points. In this case sums the `data_values` corresponding to the `data` that is in range of `points` -inrange_runtime!(tree, points, radius, runtime_function) +# runs the callback with all tree data points in range of points. In this case sums the `data_values` corresponding to the `data` that is in range of `points` +inrange_callback!(tree, points, radius, callback) ``` diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index b56057c..f5b07bb 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -7,7 +7,7 @@ using StaticArrays import Base.show export NNTree, BruteTree, KDTree, BallTree, DataFreeTree -export knn, knn!, nn, inrange, inrange!,inrangecount, inrange_runtime! # TODOs? , allpairs, distmat, npairs +export knn, knn!, nn, inrange, inrange!,inrangecount, inrange_callback! # TODOs? , allpairs, distmat, npairs export injectdata export Euclidean, diff --git a/src/ball_tree.jl b/src/ball_tree.jl index 6d2245d..96f68aa 100644 --- a/src/ball_tree.jl +++ b/src/ball_tree.jl @@ -179,9 +179,9 @@ function _inrange(tree::BallTree{V}, radius::Number, idx_in_ball::Union{Nothing, Vector{<:Integer}}, point_index::Int = 1, - runtime_function::Union{Nothing, Function} = nothing) where {V} + callback::Union{Nothing, Function} = nothing) where {V} ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball" - return inrange_kernel!(tree, 1, point, ball, idx_in_ball, runtime_function, point_index) # Call the recursive range finder + return inrange_kernel!(tree, 1, point, ball, idx_in_ball, callback, point_index) # Call the recursive range finder end function inrange_kernel!(tree::BallTree, @@ -189,7 +189,7 @@ function inrange_kernel!(tree::BallTree, point::AbstractVector, query_ball::HyperSphere, idx_in_ball::Union{Nothing, Vector{<:Integer}}, - runtime_function::Union{Nothing, Function}, + callback::Union{Nothing, Function}, point_index::Int) if index > length(tree.hyper_spheres) @@ -208,7 +208,7 @@ function inrange_kernel!(tree::BallTree, # At a leaf node, check all points in the leaf node if isleaf(tree.tree_data.n_internal_nodes, index) r = tree.metric isa MinkowskiMetric ? eval_pow(tree.metric, query_ball.r) : query_ball.r - return add_points_inrange!(idx_in_ball, tree, index, point, r, runtime_function, point_index) + return add_points_inrange!(idx_in_ball, tree, index, point, r, callback, point_index) end count = 0 @@ -219,8 +219,8 @@ function inrange_kernel!(tree::BallTree, count += addall(tree, index, idx_in_ball) else # Recursively call the left and right sub tree. - count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball, runtime_function, point_index) - count += inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball, runtime_function, point_index) + count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball, callback, point_index) + count += inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball, callback, point_index) end return count end diff --git a/src/brute_tree.jl b/src/brute_tree.jl index 1b8c038..a113e32 100644 --- a/src/brute_tree.jl +++ b/src/brute_tree.jl @@ -63,8 +63,8 @@ function _inrange(tree::BruteTree, radius::Number, idx_in_ball::Union{Nothing, Vector{<:Integer}}, point_index::Int = 1, - runtime_function::Union{Nothing, Function} = nothing) - return inrange_kernel!(tree, point, radius, idx_in_ball, runtime_function, point_index) + callback::Union{Nothing, Function} = nothing) + return inrange_kernel!(tree, point, radius, idx_in_ball, callback, point_index) end @@ -72,7 +72,7 @@ function inrange_kernel!(tree::BruteTree, point::AbstractVector, r::Number, idx_in_ball::Union{Nothing, Vector{<:Integer}}, - runtime_function::Union{Nothing, Function}, + callback::Union{Nothing, Function}, point_index::Int) count = 0 for i in 1:length(tree.data) @@ -80,7 +80,7 @@ function inrange_kernel!(tree::BruteTree, if d <= r count += 1 idx_in_ball !== nothing && push!(idx_in_ball, i) - !isnothing(runtime_function) && runtime_function(point_index, i, point) + !isnothing(callback) && callback(point_index, i, point) end end return count diff --git a/src/inrange.jl b/src/inrange.jl index 575c562..175df34 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -106,55 +106,55 @@ function inrangecount(tree::NNTree{V}, point::AbstractMatrix{T}, radius::Number) end """ - inrange_runtime!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, runtime_function::Function) + inrange_callback!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, callback::Function) Compute a runtime function for all in range queries. -Instead of returning the indicies, the `runtime_function` is called for each point in points +Instead of returning the indicies, the `callback` is called for each point in points and is given the points, the index of the point, and the index of the neighbor. -The `runtime_function` should return nothing. -The `runtime_function` should be of the form: -runtime_function(point_index::Int, neighbor_index::Int, point::AbstractVector{T}) +The `callback` should return nothing. +The `callback` should be of the form: +callback(point_index::Int, neighbor_index::Int, point::AbstractVector{T}) where `point_index` is the index of the point in `points`, `neighbor_index` is the index of the neighbor in the tree, and `point` is the point in points. For example: ```julia -function runtime_function(point_index, neighbor_index, point, random_storage_of_results, neightbors_data) +function callback(point_index, neighbor_index, point, random_storage_of_results, neightbors_data) # do something with the points return nothing end random_storage_of_results = rand(3, 100) neightbors_data = rand(3, 100) -f(a, b, c) = runtime_function(a, b, c, random_storage_of_results, neightbors_data) +f(a, b, c) = callback(a, b, c, random_storage_of_results, neightbors_data) ``` """ -function inrange_runtime!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, runtime_function::F) where {V, T <: AbstractVector, F} +function inrange_callback!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, callback::F) where {V, T <: AbstractVector, F} check_input(tree, points) check_radius(radius) for i in eachindex(points) - _inrange(tree, points[i], radius, nothing, i, runtime_function) + _inrange(tree, points[i], radius, nothing, i, callback) end return nothing end -function inrange_runtime!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, runtime_function::F) where {V, T <: Number, F} - return inrange_runtime!(tree, points, radius, runtime_function, Val(size(points, 1))) +function inrange_callback!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, callback::F) where {V, T <: Number, F} + return inrange_callback!(tree, points, radius, callback, Val(size(points, 1))) end -function inrange_runtime!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, runtime_function::F) where {V, T <: Number, F} +function inrange_callback!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, callback::F) where {V, T <: Number, F} points = reshape(points, size(points, 1), 1) - return inrange_runtime!(tree, points, radius, runtime_function, Val(size(points, 1))) + return inrange_callback!(tree, points, radius, callback, Val(size(points, 1))) end -function inrange_runtime!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, runtime_function::F, ::Val{dim}) where {V, T <: Number, F, dim} +function inrange_callback!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, callback::F, ::Val{dim}) where {V, T <: Number, F, dim} check_input(tree, points) check_radius(radius) n_points = size(points, 2) for i in 1:n_points point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim))) - _inrange(tree, point, radius, nothing, i, runtime_function) + _inrange(tree, point, radius, nothing, i, callback) end return nothing end \ No newline at end of file diff --git a/src/kd_tree.jl b/src/kd_tree.jl index 8327cdb..c9a5dcf 100644 --- a/src/kd_tree.jl +++ b/src/kd_tree.jl @@ -209,10 +209,10 @@ function _inrange(tree::KDTree, radius::Number, idx_in_ball::Union{Nothing, Vector{<:Integer}} = Int[], point_index::Int = 1, - runtime_function::Union{Nothing, Function} = nothing) + callback::Union{Nothing, Function} = nothing) init_min = get_min_distance_no_end(tree.metric, tree.hyper_rec, point) return inrange_kernel!(tree, 1, point, eval_pow(tree.metric, radius), idx_in_ball, - tree.hyper_rec, init_min, runtime_function, point_index) + tree.hyper_rec, init_min, callback, point_index) end # Explicitly check the distance between leaf node and point while traversing @@ -223,7 +223,7 @@ function inrange_kernel!(tree::KDTree, idx_in_ball::Union{Nothing, Vector{<:Integer}}, hyper_rec::HyperRectangle, min_dist, - runtime_function::Union{Nothing, Function}, + callback::Union{Nothing, Function}, point_index::Int) # Point is outside hyper rectangle, skip the whole sub tree if min_dist > r @@ -232,7 +232,7 @@ function inrange_kernel!(tree::KDTree, # At a leaf node. Go through all points in node and add those in range if isleaf(tree.tree_data.n_internal_nodes, index) - return add_points_inrange!(idx_in_ball, tree, index, point, r, runtime_function, point_index) + return add_points_inrange!(idx_in_ball, tree, index, point, r, callback, point_index) end split_val = tree.split_vals[index] @@ -259,7 +259,7 @@ function inrange_kernel!(tree::KDTree, ddiff = max(zero(lo - p_dim), lo - p_dim) end # Call closer sub tree - count += inrange_kernel!(tree, close, point, r, idx_in_ball, hyper_rec_close, min_dist, runtime_function, point_index) + count += inrange_kernel!(tree, close, point, r, idx_in_ball, hyper_rec_close, min_dist, callback, point_index) # TODO: We could potentially also keep track of the max distance # between the point and the hyper rectangle and add the whole sub tree @@ -271,6 +271,6 @@ function inrange_kernel!(tree::KDTree, ddiff_pow = eval_pow(M, ddiff) diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim) new_min = eval_reduce(M, min_dist, diff_tot) - count += inrange_kernel!(tree, far, point, r, idx_in_ball, hyper_rec_far, new_min, runtime_function, point_index) + count += inrange_kernel!(tree, far, point, r, idx_in_ball, hyper_rec_far, new_min, callback, point_index) return count end diff --git a/src/tree_ops.jl b/src/tree_ops.jl index c937b1b..ecde7e7 100644 --- a/src/tree_ops.jl +++ b/src/tree_ops.jl @@ -116,7 +116,7 @@ end # to evaluate if it is worth it. @inline function add_points_inrange!(idx_in_ball::Union{Nothing, AbstractVector{<:Integer}}, tree::NNTree, index::Int, point::AbstractVector, r::Number, - runtime_function::Union{Nothing, Function}, + callback::Union{Nothing, Function}, point_index::Int) count = 0 @inbounds for z in get_leaf_range(tree.tree_data, index) @@ -124,7 +124,7 @@ end if check_in_range(tree.metric, tree.data[idx], point, r) count += 1 idx_in_ball !== nothing && push!(idx_in_ball, idx) - @inbounds !isnothing(runtime_function) && runtime_function(point_index, tree.reordered ? tree.indices[idx] : idx, point) + @inbounds !isnothing(callback) && callback(point_index, tree.reordered ? tree.indices[idx] : idx, point) end end return count diff --git a/test/test_inrange.jl b/test/test_inrange.jl index 0a5f5c9..9694e4b 100644 --- a/test/test_inrange.jl +++ b/test/test_inrange.jl @@ -106,7 +106,7 @@ end f(a, b, c) = runtime_test(a, b, c, sum_runtime, data) tree = KDTree(data[1:3, :]) - inrange_runtime!(tree, [0.5, 0.5, 0.5], 1.0, f) + inrange_callback!(tree, [0.5, 0.5, 0.5], 1.0, f) idxs = inrange(tree, [0.5, 0.5, 0.5], 1.0) sum_idxs = 0.0 for i in eachindex(idxs) From 58dfb2d759313d3519633dfd6bbf4b7459b04914 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Fri, 11 Jul 2025 09:12:32 -0600 Subject: [PATCH 13/18] remove point from callback function --- README.md | 7 +++---- src/inrange.jl | 9 ++++----- src/tree_ops.jl | 2 +- test/test_inrange.jl | 4 ++-- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 4c6b816..7efd7ac 100644 --- a/README.md +++ b/README.md @@ -181,15 +181,14 @@ results = zeros(10) # this function will sum the `data_values` corresponding to the `data` that is in range of `points` # `p_idx` is the index of `points` i.e. 1-10 # `data_idx` is is the index of the data in the tree that is in range -# `p` is `points[:,p_idx]` # `values` is data needed for the operation # `results` is a storage space for the results -function sum_values!(p_idx, data_idx, p, values, results) +function sum_values!(p_idx, data_idx, values, results) results[p_idx] += values[data_idx] end -# `callback` must be of the form f(p_idx, data_idx, p) -callback(p_idx, data_idx, p) = sum_values!(p_idx, data_idx, p, values, results) +# `callback` must be of the form f(p_idx, data_idx) +callback(p_idx, data_idx, p) = sum_values!(p_idx, data_idx, values, results) kdtree = KDTree(data) diff --git a/src/inrange.jl b/src/inrange.jl index 175df34..7050f81 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -113,20 +113,19 @@ Instead of returning the indicies, the `callback` is called for each point in po and is given the points, the index of the point, and the index of the neighbor. The `callback` should return nothing. The `callback` should be of the form: -callback(point_index::Int, neighbor_index::Int, point::AbstractVector{T}) -where `point_index` is the index of the point in `points`, `neighbor_index` is the index of the neighbor in the tree, -and `point` is the point in points. +callback(point_index::Int, neighbor_index::Int) +where `point_index` is the index of the point in `points`, `neighbor_index` is the index of the neighbor in the tree. For example: ```julia -function callback(point_index, neighbor_index, point, random_storage_of_results, neightbors_data) +function callback(point_index, neighbor_index, random_storage_of_results, neightbors_data) # do something with the points return nothing end random_storage_of_results = rand(3, 100) neightbors_data = rand(3, 100) -f(a, b, c) = callback(a, b, c, random_storage_of_results, neightbors_data) +f(a, b) = callback(a, b, random_storage_of_results, neightbors_data) ``` """ function inrange_callback!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, callback::F) where {V, T <: AbstractVector, F} diff --git a/src/tree_ops.jl b/src/tree_ops.jl index ecde7e7..5901b66 100644 --- a/src/tree_ops.jl +++ b/src/tree_ops.jl @@ -124,7 +124,7 @@ end if check_in_range(tree.metric, tree.data[idx], point, r) count += 1 idx_in_ball !== nothing && push!(idx_in_ball, idx) - @inbounds !isnothing(callback) && callback(point_index, tree.reordered ? tree.indices[idx] : idx, point) + @inbounds !isnothing(callback) && callback(point_index, tree.reordered ? tree.indices[idx] : idx) end end return count diff --git a/test/test_inrange.jl b/test/test_inrange.jl index 9694e4b..1f1053f 100644 --- a/test/test_inrange.jl +++ b/test/test_inrange.jl @@ -95,7 +95,7 @@ end end @testset "inrange_runtime function" begin - function runtime_test(point_index, neighbor_index, point, sum_of_random_data, neightbor_points) + function runtime_test(point_index, neighbor_index, sum_of_random_data, neightbor_points) sum_of_random_data[1] += sum(neightbor_points[4:6,neighbor_index]) return nothing end @@ -103,7 +103,7 @@ end for T in (KDTree, BallTree, BruteTree) sum_runtime = fill(0.0, 1) data = rand(6, 100) # first 3 rows are "locations", last 3 rows are random data - f(a, b, c) = runtime_test(a, b, c, sum_runtime, data) + f(a, b) = runtime_test(a, b, sum_runtime, data) tree = KDTree(data[1:3, :]) inrange_callback!(tree, [0.5, 0.5, 0.5], 1.0, f) From adfb2312abc6b7138eb24b900feefd626d5bcfc0 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Fri, 11 Jul 2025 11:51:11 -0600 Subject: [PATCH 14/18] inrange now calls inrange_callback --- src/ball_tree.jl | 2 +- src/brute_tree.jl | 2 +- src/inrange.jl | 102 ++++++++++++++++++++++++++++++---------------- src/tree_ops.jl | 9 ++-- 4 files changed, 74 insertions(+), 41 deletions(-) diff --git a/src/ball_tree.jl b/src/ball_tree.jl index 96f68aa..8c79963 100644 --- a/src/ball_tree.jl +++ b/src/ball_tree.jl @@ -216,7 +216,7 @@ function inrange_kernel!(tree::BallTree, # The query ball encloses the sub tree bounding sphere. Add all points in the # sub tree without checking the distance function. if encloses_fast(dist, tree.metric, sphere, query_ball) - count += addall(tree, index, idx_in_ball) + count += addall(tree, index, idx_in_ball, callback, point_index) else # Recursively call the left and right sub tree. count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball, callback, point_index) diff --git a/src/brute_tree.jl b/src/brute_tree.jl index a113e32..b9a3a0c 100644 --- a/src/brute_tree.jl +++ b/src/brute_tree.jl @@ -80,7 +80,7 @@ function inrange_kernel!(tree::BruteTree, if d <= r count += 1 idx_in_ball !== nothing && push!(idx_in_ball, i) - !isnothing(callback) && callback(point_index, i, point) + !isnothing(callback) && callback(point_index, i) end end return count diff --git a/src/inrange.jl b/src/inrange.jl index 7050f81..e362413 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -8,20 +8,20 @@ Find all the points in the tree which is closer than `radius` to `points`. If See also: `inrange!`, `inrangecount`. """ -function inrange(tree::NNTree, - points::AbstractVector{T}, - radius::Number, - sortres=false) where {T <: AbstractVector} - check_input(tree, points) - check_radius(radius) +# function inrange(tree::NNTree, +# points::AbstractVector{T}, +# radius::Number, +# sortres=false) where {T <: AbstractVector} +# check_input(tree, points) +# check_radius(radius) - idxs = [Vector{Int}() for _ in 1:length(points)] +# idxs = [Vector{Int}() for _ in 1:length(points)] - for i in 1:length(points) - inrange_point!(tree, points[i], radius, sortres, idxs[i]) - end - return idxs -end +# for i in 1:length(points) +# inrange_point!(tree, points[i], radius, sortres, idxs[i]) +# end +# return idxs +# end function inrange_point!(tree, point, radius, sortres, idx) count = _inrange(tree, point, radius, idx) @@ -52,28 +52,28 @@ function inrange!(idxs::AbstractVector, tree::NNTree{V}, point::AbstractVector{T return idxs end -function inrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false) where {V, T <: Number} - return inrange!(Int[], tree, point, radius, sortres) -end - -function inrange(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, sortres=false) where {V, T <: Number} - dim = size(points, 1) - inrange_matrix(tree, points, radius, Val(dim), sortres) -end - -function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, sortres) where {V, T <: Number, dim} - # TODO: DRY with inrange for AbstractVector - check_input(tree, points) - check_radius(radius) - n_points = size(points, 2) - idxs = [Vector{Int}() for _ in 1:n_points] - - for i in 1:n_points - point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim))) - inrange_point!(tree, point, radius, sortres, idxs[i]) - end - return idxs -end +# function inrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false) where {V, T <: Number} +# return inrange!(Int[], tree, point, radius, sortres) +# end + +# function inrange(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, sortres=false) where {V, T <: Number} +# dim = size(points, 1) +# inrange_matrix(tree, points, radius, Val(dim), sortres) +# end + +# function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, sortres) where {V, T <: Number, dim} +# # TODO: DRY with inrange for AbstractVector +# check_input(tree, points) +# check_radius(radius) +# n_points = size(points, 2) +# idxs = [Vector{Int}() for _ in 1:n_points] + +# for i in 1:n_points +# point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim))) +# inrange_point!(tree, point, radius, sortres, idxs[i]) +# end +# return idxs +# end """ inrangecount(tree::NNTree, points, radius) -> count @@ -156,4 +156,36 @@ function inrange_callback!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::N _inrange(tree, point, radius, nothing, i, callback) end return nothing -end \ No newline at end of file +end + +function index_returning_runtime_function(point_index::Int, neighbor_index::Int, idxs) + push!(idxs[point_index], neighbor_index) + return nothing +end + +function inrange_callback_default(tree::NNTree{V}, points, radius::Number, sortres=false) where {V} + if points isa AbstractVector{<:AbstractVector} + n_points = length(points) + elseif points isa AbstractMatrix{<:Number} + n_points = size(points, 2) + elseif points isa AbstractVector{<:Number} + n_points = 1 + end + + idxs = [Int[] for _ in 1:n_points] + f(a, b) = index_returning_runtime_function(a, b, idxs) + inrange_callback!(tree, points, radius, f) + + if sortres + for i in eachindex(idxs) + sort!(idxs[i]) + end + end + + if length(idxs) == 1 + idxs = idxs[1] # If only one point, return a single vector instead of a vector of vectors + end + return idxs +end + +inrange(tree::NNTree{V}, points, radius::Number, sortres=false) where {V} = inrange_callback_default(tree, points, radius, sortres) \ No newline at end of file diff --git a/src/tree_ops.jl b/src/tree_ops.jl index 5901b66..f3c0e11 100644 --- a/src/tree_ops.jl +++ b/src/tree_ops.jl @@ -119,7 +119,7 @@ end callback::Union{Nothing, Function}, point_index::Int) count = 0 - @inbounds for z in get_leaf_range(tree.tree_data, index) + for z in get_leaf_range(tree.tree_data, index) idx = tree.reordered ? z : tree.indices[z] if check_in_range(tree.metric, tree.data[idx], point, r) count += 1 @@ -141,7 +141,7 @@ end # Add all points in this subtree since we have determined # they are all within the desired range -function addall(tree::NNTree, index::Int, idx_in_ball::Union{Nothing, Vector{<:Integer}}) +function addall(tree::NNTree, index::Int, idx_in_ball::Union{Nothing, Vector{<:Integer}}, callback::Union{Nothing, Function} = nothing, point_index::Int = 1) tree_data = tree.tree_data count = 0 if isleaf(tree_data.n_internal_nodes, index) @@ -149,10 +149,11 @@ function addall(tree::NNTree, index::Int, idx_in_ball::Union{Nothing, Vector{<:I idx = tree.reordered ? z : tree.indices[z] count += 1 idx_in_ball !== nothing && push!(idx_in_ball, idx) + @inbounds !isnothing(callback) && callback(point_index, tree.reordered ? tree.indices[idx] : idx) end else - count += addall(tree, getleft(index), idx_in_ball) - count += addall(tree, getright(index), idx_in_ball) + count += addall(tree, getleft(index), idx_in_ball, callback, point_index) + count += addall(tree, getright(index), idx_in_ball, callback, point_index) end return count end From ec1159c3c52f31a818b8be3866ccbdcd6ef29160 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Fri, 11 Jul 2025 13:35:45 -0600 Subject: [PATCH 15/18] replace inrange_point! --- src/inrange.jl | 68 ++++++++++---------------------------------------- 1 file changed, 13 insertions(+), 55 deletions(-) diff --git a/src/inrange.jl b/src/inrange.jl index e362413..a54f233 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -8,33 +8,7 @@ Find all the points in the tree which is closer than `radius` to `points`. If See also: `inrange!`, `inrangecount`. """ -# function inrange(tree::NNTree, -# points::AbstractVector{T}, -# radius::Number, -# sortres=false) where {T <: AbstractVector} -# check_input(tree, points) -# check_radius(radius) - -# idxs = [Vector{Int}() for _ in 1:length(points)] - -# for i in 1:length(points) -# inrange_point!(tree, points[i], radius, sortres, idxs[i]) -# end -# return idxs -# end - -function inrange_point!(tree, point, radius, sortres, idx) - count = _inrange(tree, point, radius, idx) - if idx !== nothing - if tree.reordered - @inbounds for j in 1:length(idx) - idx[j] = tree.indices[idx[j]] - end - end - sortres && sort!(idx) - end - return count -end +inrange(tree::NNTree{V}, points, radius::Number, sortres=false) where {V} = inrange_callback_default(tree, points, radius, sortres) """ inrange!(idxs, tree, point, radius) @@ -48,33 +22,14 @@ function inrange!(idxs::AbstractVector, tree::NNTree{V}, point::AbstractVector{T check_input(tree, point) check_radius(radius) length(idxs) == 0 || throw(ArgumentError("idxs must be empty")) - inrange_point!(tree, point, radius, sortres, idxs) + + f(a, b) = index_returning_runtime_function(a, b, idxs) + inrange_callback!(tree, point, radius, f) + + sortres && sort!(idxs) return idxs end -# function inrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false) where {V, T <: Number} -# return inrange!(Int[], tree, point, radius, sortres) -# end - -# function inrange(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, sortres=false) where {V, T <: Number} -# dim = size(points, 1) -# inrange_matrix(tree, points, radius, Val(dim), sortres) -# end - -# function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, sortres) where {V, T <: Number, dim} -# # TODO: DRY with inrange for AbstractVector -# check_input(tree, points) -# check_radius(radius) -# n_points = size(points, 2) -# idxs = [Vector{Int}() for _ in 1:n_points] - -# for i in 1:n_points -# point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim))) -# inrange_point!(tree, point, radius, sortres, idxs[i]) -# end -# return idxs -# end - """ inrangecount(tree::NNTree, points, radius) -> count @@ -83,7 +38,7 @@ Count all the points in the tree which are closer than `radius` to `points`. function inrangecount(tree::NNTree{V}, point::AbstractVector{T}, radius::Number) where {V, T <: Number} check_input(tree, point) check_radius(radius) - return inrange_point!(tree, point, radius, false, nothing) + return _inrange(tree, point, radius, nothing) end function inrangecount(tree::NNTree, @@ -91,7 +46,7 @@ function inrangecount(tree::NNTree, radius::Number) where {T <: AbstractVector} check_input(tree, points) check_radius(radius) - return inrange_point!.(Ref(tree), points, radius, false, nothing) + return _inrange.(Ref(tree), points, radius, nothing) end function inrangecount(tree::NNTree{V}, point::AbstractMatrix{T}, radius::Number) where {V, T <: Number} @@ -159,7 +114,11 @@ function inrange_callback!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::N end function index_returning_runtime_function(point_index::Int, neighbor_index::Int, idxs) - push!(idxs[point_index], neighbor_index) + if eltype(idxs) <: Integer + push!(idxs, eltype(idxs)(neighbor_index)) + else + push!(idxs[point_index], neighbor_index) + end return nothing end @@ -188,4 +147,3 @@ function inrange_callback_default(tree::NNTree{V}, points, radius::Number, sortr return idxs end -inrange(tree::NNTree{V}, points, radius::Number, sortres=false) where {V} = inrange_callback_default(tree, points, radius, sortres) \ No newline at end of file From 85aa68d6b9ada736c87c725ec1b9cdbb644d32b5 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Fri, 11 Jul 2025 13:41:53 -0600 Subject: [PATCH 16/18] Removed unused inputs from inrange functions --- src/ball_tree.jl | 12 +++++------- src/brute_tree.jl | 5 +---- src/inrange.jl | 8 ++++---- src/kd_tree.jl | 10 ++++------ src/tree_ops.jl | 10 ++++------ 5 files changed, 18 insertions(+), 27 deletions(-) diff --git a/src/ball_tree.jl b/src/ball_tree.jl index 8c79963..ee426ac 100644 --- a/src/ball_tree.jl +++ b/src/ball_tree.jl @@ -177,18 +177,16 @@ end function _inrange(tree::BallTree{V}, point::AbstractVector, radius::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}, point_index::Int = 1, callback::Union{Nothing, Function} = nothing) where {V} ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball" - return inrange_kernel!(tree, 1, point, ball, idx_in_ball, callback, point_index) # Call the recursive range finder + return inrange_kernel!(tree, 1, point, ball, callback, point_index) # Call the recursive range finder end function inrange_kernel!(tree::BallTree, index::Int, point::AbstractVector, query_ball::HyperSphere, - idx_in_ball::Union{Nothing, Vector{<:Integer}}, callback::Union{Nothing, Function}, point_index::Int) @@ -208,7 +206,7 @@ function inrange_kernel!(tree::BallTree, # At a leaf node, check all points in the leaf node if isleaf(tree.tree_data.n_internal_nodes, index) r = tree.metric isa MinkowskiMetric ? eval_pow(tree.metric, query_ball.r) : query_ball.r - return add_points_inrange!(idx_in_ball, tree, index, point, r, callback, point_index) + return add_points_inrange!(tree, index, point, r, callback, point_index) end count = 0 @@ -216,11 +214,11 @@ function inrange_kernel!(tree::BallTree, # The query ball encloses the sub tree bounding sphere. Add all points in the # sub tree without checking the distance function. if encloses_fast(dist, tree.metric, sphere, query_ball) - count += addall(tree, index, idx_in_ball, callback, point_index) + count += addall(tree, index, callback, point_index) else # Recursively call the left and right sub tree. - count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball, callback, point_index) - count += inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball, callback, point_index) + count += inrange_kernel!(tree, getleft(index), point, query_ball, callback, point_index) + count += inrange_kernel!(tree, getright(index), point, query_ball, callback, point_index) end return count end diff --git a/src/brute_tree.jl b/src/brute_tree.jl index b9a3a0c..a2fffaf 100644 --- a/src/brute_tree.jl +++ b/src/brute_tree.jl @@ -61,17 +61,15 @@ end function _inrange(tree::BruteTree, point::AbstractVector, radius::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}, point_index::Int = 1, callback::Union{Nothing, Function} = nothing) - return inrange_kernel!(tree, point, radius, idx_in_ball, callback, point_index) + return inrange_kernel!(tree, point, radius, callback, point_index) end function inrange_kernel!(tree::BruteTree, point::AbstractVector, r::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}, callback::Union{Nothing, Function}, point_index::Int) count = 0 @@ -79,7 +77,6 @@ function inrange_kernel!(tree::BruteTree, d = evaluate(tree.metric, tree.data[i], point) if d <= r count += 1 - idx_in_ball !== nothing && push!(idx_in_ball, i) !isnothing(callback) && callback(point_index, i) end end diff --git a/src/inrange.jl b/src/inrange.jl index a54f233..cf2e921 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -38,7 +38,7 @@ Count all the points in the tree which are closer than `radius` to `points`. function inrangecount(tree::NNTree{V}, point::AbstractVector{T}, radius::Number) where {V, T <: Number} check_input(tree, point) check_radius(radius) - return _inrange(tree, point, radius, nothing) + return _inrange(tree, point, radius) end function inrangecount(tree::NNTree, @@ -46,7 +46,7 @@ function inrangecount(tree::NNTree, radius::Number) where {T <: AbstractVector} check_input(tree, points) check_radius(radius) - return _inrange.(Ref(tree), points, radius, nothing) + return _inrange.(Ref(tree), points, radius) end function inrangecount(tree::NNTree{V}, point::AbstractMatrix{T}, radius::Number) where {V, T <: Number} @@ -88,7 +88,7 @@ function inrange_callback!(tree::NNTree{V}, points::AbstractVector{T}, radius::N check_radius(radius) for i in eachindex(points) - _inrange(tree, points[i], radius, nothing, i, callback) + _inrange(tree, points[i], radius, i, callback) end return nothing end @@ -108,7 +108,7 @@ function inrange_callback!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::N n_points = size(points, 2) for i in 1:n_points point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim))) - _inrange(tree, point, radius, nothing, i, callback) + _inrange(tree, point, radius, i, callback) end return nothing end diff --git a/src/kd_tree.jl b/src/kd_tree.jl index c9a5dcf..f36c6a6 100644 --- a/src/kd_tree.jl +++ b/src/kd_tree.jl @@ -207,11 +207,10 @@ end function _inrange(tree::KDTree, point::AbstractVector, radius::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}} = Int[], point_index::Int = 1, callback::Union{Nothing, Function} = nothing) init_min = get_min_distance_no_end(tree.metric, tree.hyper_rec, point) - return inrange_kernel!(tree, 1, point, eval_pow(tree.metric, radius), idx_in_ball, + return inrange_kernel!(tree, 1, point, eval_pow(tree.metric, radius), tree.hyper_rec, init_min, callback, point_index) end @@ -220,7 +219,6 @@ function inrange_kernel!(tree::KDTree, index::Int, point::AbstractVector, r::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}, hyper_rec::HyperRectangle, min_dist, callback::Union{Nothing, Function}, @@ -232,7 +230,7 @@ function inrange_kernel!(tree::KDTree, # At a leaf node. Go through all points in node and add those in range if isleaf(tree.tree_data.n_internal_nodes, index) - return add_points_inrange!(idx_in_ball, tree, index, point, r, callback, point_index) + return add_points_inrange!(tree, index, point, r, callback, point_index) end split_val = tree.split_vals[index] @@ -259,7 +257,7 @@ function inrange_kernel!(tree::KDTree, ddiff = max(zero(lo - p_dim), lo - p_dim) end # Call closer sub tree - count += inrange_kernel!(tree, close, point, r, idx_in_ball, hyper_rec_close, min_dist, callback, point_index) + count += inrange_kernel!(tree, close, point, r, hyper_rec_close, min_dist, callback, point_index) # TODO: We could potentially also keep track of the max distance # between the point and the hyper rectangle and add the whole sub tree @@ -271,6 +269,6 @@ function inrange_kernel!(tree::KDTree, ddiff_pow = eval_pow(M, ddiff) diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim) new_min = eval_reduce(M, min_dist, diff_tot) - count += inrange_kernel!(tree, far, point, r, idx_in_ball, hyper_rec_far, new_min, callback, point_index) + count += inrange_kernel!(tree, far, point, r, hyper_rec_far, new_min, callback, point_index) return count end diff --git a/src/tree_ops.jl b/src/tree_ops.jl index f3c0e11..4f46c30 100644 --- a/src/tree_ops.jl +++ b/src/tree_ops.jl @@ -114,7 +114,7 @@ end # stop computing the distance function as soon as we reach the desired radius. # This will probably prevent SIMD and other optimizations so some care is needed # to evaluate if it is worth it. -@inline function add_points_inrange!(idx_in_ball::Union{Nothing, AbstractVector{<:Integer}}, tree::NNTree, +@inline function add_points_inrange!(tree::NNTree, index::Int, point::AbstractVector, r::Number, callback::Union{Nothing, Function}, point_index::Int) @@ -123,7 +123,6 @@ end idx = tree.reordered ? z : tree.indices[z] if check_in_range(tree.metric, tree.data[idx], point, r) count += 1 - idx_in_ball !== nothing && push!(idx_in_ball, idx) @inbounds !isnothing(callback) && callback(point_index, tree.reordered ? tree.indices[idx] : idx) end end @@ -141,19 +140,18 @@ end # Add all points in this subtree since we have determined # they are all within the desired range -function addall(tree::NNTree, index::Int, idx_in_ball::Union{Nothing, Vector{<:Integer}}, callback::Union{Nothing, Function} = nothing, point_index::Int = 1) +function addall(tree::NNTree, index::Int, callback::Union{Nothing, Function} = nothing, point_index::Int = 1) tree_data = tree.tree_data count = 0 if isleaf(tree_data.n_internal_nodes, index) for z in get_leaf_range(tree_data, index) idx = tree.reordered ? z : tree.indices[z] count += 1 - idx_in_ball !== nothing && push!(idx_in_ball, idx) @inbounds !isnothing(callback) && callback(point_index, tree.reordered ? tree.indices[idx] : idx) end else - count += addall(tree, getleft(index), idx_in_ball, callback, point_index) - count += addall(tree, getright(index), idx_in_ball, callback, point_index) + count += addall(tree, getleft(index), callback, point_index) + count += addall(tree, getright(index), callback, point_index) end return count end From b207f2a53f687147ebb317fc784cbf8e62d84f9a Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Fri, 11 Jul 2025 13:43:42 -0600 Subject: [PATCH 17/18] add test from master --- test/test_inrange.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/test_inrange.jl b/test/test_inrange.jl index 1f1053f..456c05f 100644 --- a/test/test_inrange.jl +++ b/test/test_inrange.jl @@ -115,4 +115,13 @@ end @test sum_idxs == sum_runtime[1] end +end + +@testset "inferrability matrix" begin + function foo(data, point) + b = KDTree(data) + return inrange(b, point, 0.1) + end + + @inferred foo([1.0 3.4; 4.5 3.4], [4.5; 3.4]) end \ No newline at end of file From 605be74e2434edf9abda394f96cf2513b51812c7 Mon Sep 17 00:00:00 2001 From: BTV25 <70768698+BTV25@users.noreply.github.com> Date: Sat, 12 Jul 2025 14:49:55 -0600 Subject: [PATCH 18/18] added function to bypass internal Val --- src/inrange.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/inrange.jl b/src/inrange.jl index 4eb4ad3..86b04b5 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -98,8 +98,12 @@ function inrange_callback!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::N end function inrange_callback!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, callback::F) where {V, T <: Number, F} + return inrange_callback!(tree, points, radius, callback, Val(length(points))) +end + +function inrange_callback!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, callback::F, ::Val{dim}) where {V, T <: Number, F, dim} points = reshape(points, size(points, 1), 1) - return inrange_callback!(tree, points, radius, callback, Val(size(points, 1))) + return inrange_callback!(tree, points, radius, callback, Val(dim)) end function inrange_callback!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, callback::F, ::Val{dim}) where {V, T <: Number, F, dim}