diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 30b11f0..2e8c853 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -29,7 +29,7 @@ jobs: with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v4 + - uses: actions/cache@v1 env: cache-name: cache-artifacts with: diff --git a/Project.toml b/Project.toml index 8c98be9..3b2705a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "NearestNeighbors" uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.22" +version = "0.4.21" [deps] Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" diff --git a/README.md b/README.md index 54081da..7efd7ac 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ BruteTree(data, metric; leafsize, reorder) # leafsize and reorder are unused for - A matrix of size `nd × np` where `nd` is the dimensionality and `np` is the number of points, or - A vector of vectors with fixed dimensionality `nd`, i.e., `data` should be a `Vector{V}` where `V` is a subtype of `AbstractVector` with defined `length(V)`. For example a `Vector{V}` where `V = SVector{3, Float64}` is ok because `length(V) = 3` is defined. * `metric`: The `Metric` (from `Distances.jl`) to use, defaults to `Euclidean`. `KDTree` works with axis-aligned metrics: `Euclidean`, `Chebyshev`, `Minkowski`, and `Cityblock` while for `BallTree` and `BruteTree` other pre-defined `Metric`s can be used as well as custom metrics (that are subtypes of `Metric`). -* `leafsize`: Determines the number of points (default 25) at which to stop splitting the tree. There is a trade-off between tree traversal and evaluating the metric for an increasing number of points. +* `leafsize`: Determines the number of points (default 10) at which to stop splitting the tree. There is a trade-off between tree traversal and evaluating the metric for an increasing number of points. * `reorder`: If `true` (default), during tree construction this rearranges points to improve cache locality during querying. This will create a copy of the original data. All trees in `NearestNeighbors.jl` are static, meaning points cannot be added or removed after creation. @@ -49,20 +49,19 @@ brutetree = BruteTree(data) A kNN search finds the `k` nearest neighbors to a given point or points. This is done with the methods: ```julia -knn(tree, point[s], k [, skip=always_false]) -> idxs, dists -knn!(idxs, dists, tree, point, k [, skip=always_false]) +knn(tree, point[s], k, skip = always_false) -> idxs, dists +knn!(idxs, dists, tree, point, k, skip = always_false) ``` * `tree`: The tree instance. -* `point[s]`: A vector or matrix of points to find the `k` nearest neighbors for. A vector of numbers represents a single point; a matrix means the `k` nearest neighbors for each point (column) will be computed. `points` can also be a vector of vectors. -* `k`: Number of nearest neighbors to find. -* `skip` (optional): A predicate function to skip certain points, e.g., points already visited. +* `points`: A vector or matrix of points to find the `k` nearest neighbors for. A vector of numbers represents a single point; a matrix means the `k` nearest neighbors for each point (column) will be computed. `points` can also be a vector of vectors. +* `skip` (optional): A predicate to skip certain points, e.g., points already visited. For the single closest neighbor, you can use `nn`: ```julia -nn(tree, point[s] [, skip=always_false]) -> idx, dist +nn(tree, points, skip = always_false) -> idxs, dists ``` Examples: @@ -74,7 +73,7 @@ k = 3 point = rand(3) kdtree = KDTree(data) -idxs, dists = knn(kdtree, point, k) +idxs, dists = knn(kdtree, point, k, true) idxs # 3-element Array{Int64,1}: @@ -90,7 +89,7 @@ dists # Multiple points points = rand(3, 4) -idxs, dists = knn(kdtree, points, k) +idxs, dists = knn(kdtree, points, k, true) idxs # 4-element Array{Array{Int64,1},1}: @@ -110,7 +109,7 @@ idxs using StaticArrays v = @SVector[0.5, 0.3, 0.2]; -idxs, dists = knn(kdtree, v, k) +idxs, dists = knn(kdtree, v, k, true) idxs # 3-element Array{Int64,1}: @@ -134,15 +133,11 @@ knn!(idxs, dists, kdtree, v, k) A range search finds all neighbors within the range `r` of given point(s). This is done with the methods: ```julia -inrange(tree, point[s], radius) -> idxs -inrange!(idxs, tree, point, radius) +inrange(tree, points, r) -> idxs +inrange!(idxs, tree, point, r) ``` -* `tree`: The tree instance. -* `point[s]`: A vector or matrix of points to find neighbors for. -* `radius`: Search radius. - -Note: Distances are not returned, only indices. +Distances are not returned. Example: @@ -169,6 +164,40 @@ inrange!(idxs, balltree, point, r) neighborscount = inrangecount(balltree, point, r) ``` +### Passing a runtime function into the range search +```julia +inrange_callback!(tree, points, radius, callback) +``` + +Example: +```julia +using NearestNeighbors +data = rand(3,10^4) +data_values = rand(10^4) +r = 0.05 +points = rand(3,10) +results = zeros(10) + +# this function will sum the `data_values` corresponding to the `data` that is in range of `points` +# `p_idx` is the index of `points` i.e. 1-10 +# `data_idx` is is the index of the data in the tree that is in range +# `values` is data needed for the operation +# `results` is a storage space for the results +function sum_values!(p_idx, data_idx, values, results) + results[p_idx] += values[data_idx] +end + +# `callback` must be of the form f(p_idx, data_idx) +callback(p_idx, data_idx, p) = sum_values!(p_idx, data_idx, values, results) + +kdtree = KDTree(data) + +# runs the callback with all tree data points in range of points. In this case sums the `data_values` corresponding to the `data` that is in range of `points` +inrange_callback!(tree, points, radius, callback) +``` + + + ## Using On-Disk Data Sets By default, trees store a copy of the `data` provided during construction. For data sets larger than available memory, `DataFreeTree` can be used to strip a tree of its data field and re-link it later. diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index 6edbc1c..f5b07bb 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -7,7 +7,7 @@ using StaticArrays import Base.show export NNTree, BruteTree, KDTree, BallTree, DataFreeTree -export knn, knn!, nn, inrange, inrange!,inrangecount # TODOs? , allpairs, distmat, npairs +export knn, knn!, nn, inrange, inrange!,inrangecount, inrange_callback! # TODOs? , allpairs, distmat, npairs export injectdata export Euclidean, diff --git a/src/ball_tree.jl b/src/ball_tree.jl index 7127952..ee426ac 100644 --- a/src/ball_tree.jl +++ b/src/ball_tree.jl @@ -14,20 +14,9 @@ end """ - BallTree(data [, metric = Euclidean(); leafsize = 25, reorder = true])::BallTree + BallTree(data [, metric = Euclidean(); leafsize = 25, reorder = true]) -> balltree Creates a `BallTree` from the data using the given `metric` and `leafsize`. - -# Arguments -- `data`: Point data as a matrix of size `nd × np` or vector of vectors -- `metric`: Distance metric to use (can be any `Metric` from Distances.jl). Default: `Euclidean()` -- `leafsize`: Number of points at which to stop splitting the tree. Default: `25` -- `reorder`: If `true`, reorder data to improve cache locality. Default: `true` - -# Returns -- `balltree`: A `BallTree` instance - -BallTree works with any metric and is often better for high-dimensional data. """ function BallTree(data::AbstractVector{V}, metric::Metric = Euclidean(); @@ -188,16 +177,18 @@ end function _inrange(tree::BallTree{V}, point::AbstractVector, radius::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) where {V} + point_index::Int = 1, + callback::Union{Nothing, Function} = nothing) where {V} ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball" - return inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder + return inrange_kernel!(tree, 1, point, ball, callback, point_index) # Call the recursive range finder end function inrange_kernel!(tree::BallTree, index::Int, point::AbstractVector, query_ball::HyperSphere, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) + callback::Union{Nothing, Function}, + point_index::Int) if index > length(tree.hyper_spheres) return 0 @@ -215,7 +206,7 @@ function inrange_kernel!(tree::BallTree, # At a leaf node, check all points in the leaf node if isleaf(tree.tree_data.n_internal_nodes, index) r = tree.metric isa MinkowskiMetric ? eval_pow(tree.metric, query_ball.r) : query_ball.r - return add_points_inrange!(idx_in_ball, tree, index, point, r) + return add_points_inrange!(tree, index, point, r, callback, point_index) end count = 0 @@ -223,11 +214,11 @@ function inrange_kernel!(tree::BallTree, # The query ball encloses the sub tree bounding sphere. Add all points in the # sub tree without checking the distance function. if encloses_fast(dist, tree.metric, sphere, query_ball) - count += addall(tree, index, idx_in_ball) + count += addall(tree, index, callback, point_index) else # Recursively call the left and right sub tree. - count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball) - count += inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball) + count += inrange_kernel!(tree, getleft(index), point, query_ball, callback, point_index) + count += inrange_kernel!(tree, getright(index), point, query_ball, callback, point_index) end return count end diff --git a/src/brute_tree.jl b/src/brute_tree.jl index bc882c5..a2fffaf 100644 --- a/src/brute_tree.jl +++ b/src/brute_tree.jl @@ -5,19 +5,9 @@ struct BruteTree{V <: AbstractVector,M <: PreMetric} <: NNTree{V,M} end """ - BruteTree(data [, metric = Euclidean()])::Brutetree + BruteTree(data [, metric = Euclidean()) -> brutetree Creates a `BruteTree` from the data using the given `metric`. - -# Arguments -- `data`: Point data as a matrix of size `nd × np` or vector of vectors -- `metric`: Distance metric to use (can be any `PreMetric` from Distances.jl). Default: `Euclidean()` - -# Returns -- `brutetree`: A `BruteTree` instance - -BruteTree performs exhaustive linear search and is useful as a baseline or for small datasets. -Note: `leafsize` and `reorder` parameters are ignored for BruteTree. """ function BruteTree(data::AbstractVector{V}, metric::PreMetric = Euclidean(); reorder::Bool=false, leafsize::Int=0, storedata::Bool=true) where {V <: AbstractVector} @@ -71,21 +61,23 @@ end function _inrange(tree::BruteTree, point::AbstractVector, radius::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) - return inrange_kernel!(tree, point, radius, idx_in_ball) + point_index::Int = 1, + callback::Union{Nothing, Function} = nothing) + return inrange_kernel!(tree, point, radius, callback, point_index) end function inrange_kernel!(tree::BruteTree, point::AbstractVector, r::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) + callback::Union{Nothing, Function}, + point_index::Int) count = 0 for i in 1:length(tree.data) d = evaluate(tree.metric, tree.data[i], point) if d <= r count += 1 - idx_in_ball !== nothing && push!(idx_in_ball, i) + !isnothing(callback) && callback(point_index, i) end end return count diff --git a/src/evaluation.jl b/src/evaluation.jl index 3b6f13f..358a397 100644 --- a/src/evaluation.jl +++ b/src/evaluation.jl @@ -4,6 +4,7 @@ @inline eval_pow(d::Minkowski, s) = abs(s)^d.p @inline eval_diff(::NonweightedMinkowskiMetric, a, b, dim) = a - b +@inline eval_diff(::Chebyshev, ::Any, b, dim) = b @inline eval_diff(m::WeightedMinkowskiMetric, a, b, dim) = m.weights[dim] * (a-b) function evaluate_maybe_end(d::Distances.UnionMetrics, a::AbstractVector, diff --git a/src/hyperrectangles.jl b/src/hyperrectangles.jl index 32f801b..1254ff4 100644 --- a/src/hyperrectangles.jl +++ b/src/hyperrectangles.jl @@ -40,36 +40,3 @@ get_max_distance_no_end(m, rec, point) = get_min_distance_no_end(m, rec, point) = get_min_max_distance_no_end(distance_function_min, m, rec, point) - -@inline function update_new_min(M::Metric, old_min, hyper_rec, p_dim, split_dim, split_val) - @inbounds begin - lo = hyper_rec.mins[split_dim] - hi = hyper_rec.maxes[split_dim] - end - ddiff = distance_function_min(p_dim, hi, lo) - split_diff = abs(p_dim - split_val) - split_diff_pow = eval_pow(M, split_diff) - ddiff_pow = eval_pow(M, ddiff) - diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim) - return old_min + diff_tot -end - -# Compute per-dimension contributions for max distance -function get_max_distance_contributions(m::Metric, rec::HyperRectangle{V}, point::AbstractVector{T}) where {V,T} - p = Distances.parameters(m) - return V( - @inbounds begin - v = distance_function_max(point[dim], rec.maxes[dim], rec.mins[dim]) - p === nothing ? eval_op(m, v, zero(T)) : eval_op(m, v, zero(T), p[dim]) - end for dim in eachindex(point) - ) -end - -# Compute single dimension contribution for max distance -function get_max_distance_contribution_single(m::Metric, point_dim, min_bound::T, max_bound::T, dim::Integer) where {T} - v = distance_function_max(point_dim, max_bound, min_bound) - p = Distances.parameters(m) - return p === nothing ? eval_op(m, v, zero(T)) : eval_op(m, v, zero(T), p[dim]) -end - - diff --git a/src/inrange.jl b/src/inrange.jl index f10675e..86b04b5 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -1,59 +1,20 @@ check_radius(r) = r < 0 && throw(ArgumentError("the query radius r must be ≧ 0")) """ - inrange(tree::NNTree, points, radius) -> indices + inrange(tree::NNTree, points, radius [, sortres=false]) -> indices -Find all the points in the tree which are closer than `radius` to `points`. - -# Arguments -- `tree`: The tree instance -- `points`: Query point(s) - can be a vector (single point), matrix (multiple points), or vector of vectors -- `radius`: Search radius - -# Returns -- `indices`: Vector of indices of points within the radius +Find all the points in the tree which is closer than `radius` to `points`. If +`sortres = true` the resulting indices are sorted. See also: `inrange!`, `inrangecount`. """ -function inrange(tree::NNTree, - points::AbstractVector{T}, - radius::Number, - sortres=false) where {T <: AbstractVector} - check_input(tree, points) - check_radius(radius) - - idxs = [Vector{Int}() for _ in 1:length(points)] - - for i in 1:length(points) - inrange_point!(tree, points[i], radius, sortres, idxs[i]) - end - return idxs -end - -function inrange_point!(tree, point, radius, sortres, idx) - count = _inrange(tree, point, radius, idx) - if idx !== nothing - if tree.reordered - @inbounds for j in 1:length(idx) - idx[j] = tree.indices[idx[j]] - end - end - sortres && sort!(idx) - end - return count -end +inrange(tree::NNTree{V}, points, radius::Number, sortres=false) where {V} = inrange_callback_default(tree, points, radius, sortres) """ inrange!(idxs, tree, point, radius) Same functionality as `inrange` but stores the results in the input vector `idxs`. -Useful to avoid allocations or specify the element type of the output vector. - -# Arguments -- `idxs`: Pre-allocated vector to store indices (must be empty) -- `tree`: The tree instance -- `point`: Query point -- `radius`: Search radius +Useful if one want to avoid allocations or specify the element type of the output vector. See also: `inrange`, `inrangecount`. """ @@ -61,30 +22,11 @@ function inrange!(idxs::AbstractVector, tree::NNTree{V}, point::AbstractVector{T check_input(tree, point) check_radius(radius) length(idxs) == 0 || throw(ArgumentError("idxs must be empty")) - inrange_point!(tree, point, radius, sortres, idxs) - return idxs -end -function inrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false) where {V, T <: Number} - return inrange!(Int[], tree, point, radius, sortres) -end + f(a, b) = index_returning_runtime_function(a, b, idxs) + inrange_callback!(tree, point, radius, f) -function inrange(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, sortres=false) where {V, T <: Number} - dim = size(points, 1) - inrange_matrix(tree, points, radius, Val(dim), sortres) -end - -function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, sortres) where {V, T <: Number, dim} - # TODO: DRY with inrange for AbstractVector - check_input(tree, points) - check_radius(radius) - n_points = size(points, 2) - idxs = [Vector{Int}() for _ in 1:n_points] - - for i in 1:n_points - point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim))) - inrange_point!(tree, point, radius, sortres, idxs[i]) - end + sortres && sort!(idxs) return idxs end @@ -92,19 +34,11 @@ end inrangecount(tree::NNTree, points, radius) -> count Count all the points in the tree which are closer than `radius` to `points`. - -# Arguments -- `tree`: The tree instance -- `points`: Query point(s) - can be a vector (single point), matrix (multiple points), or vector of vectors -- `radius`: Search radius - -# Returns -- `count`: Number of points within the radius (integer for single point, vector for multiple points) """ function inrangecount(tree::NNTree{V}, point::AbstractVector{T}, radius::Number) where {V, T <: Number} check_input(tree, point) check_radius(radius) - return inrange_point!(tree, point, radius, false, nothing) + return _inrange(tree, point, radius) end function inrangecount(tree::NNTree, @@ -112,7 +46,7 @@ function inrangecount(tree::NNTree, radius::Number) where {T <: AbstractVector} check_input(tree, points) check_radius(radius) - return inrange_point!.(Ref(tree), points, radius, false, nothing) + return _inrange.(Ref(tree), points, radius) end function inrangecount(tree::NNTree{V}, point::AbstractMatrix{T}, radius::Number) where {V, T <: Number} @@ -125,3 +59,104 @@ function inrangecount(tree::NNTree{V}, point::AbstractMatrix{T}, radius::Number) end return inrangecount(tree, new_data, radius) end + +""" + inrange_callback!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, callback::Function) + +Compute a runtime function for all in range queries. +Instead of returning the indicies, the `callback` is called for each point in points +and is given the points, the index of the point, and the index of the neighbor. +The `callback` should return nothing. +The `callback` should be of the form: +callback(point_index::Int, neighbor_index::Int) +where `point_index` is the index of the point in `points`, `neighbor_index` is the index of the neighbor in the tree. + +For example: +```julia +function callback(point_index, neighbor_index, random_storage_of_results, neightbors_data) + # do something with the points + return nothing +end + +random_storage_of_results = rand(3, 100) +neightbors_data = rand(3, 100) +f(a, b) = callback(a, b, random_storage_of_results, neightbors_data) +``` +""" +function inrange_callback!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, callback::F) where {V, T <: AbstractVector, F} + check_input(tree, points) + check_radius(radius) + + for i in eachindex(points) + _inrange(tree, points[i], radius, i, callback) + end + return nothing +end + +function inrange_callback!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, callback::F) where {V, T <: Number, F} + return inrange_callback!(tree, points, radius, callback, Val(size(points, 1))) +end + +function inrange_callback!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, callback::F) where {V, T <: Number, F} + return inrange_callback!(tree, points, radius, callback, Val(length(points))) +end + +function inrange_callback!(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, callback::F, ::Val{dim}) where {V, T <: Number, F, dim} + points = reshape(points, size(points, 1), 1) + return inrange_callback!(tree, points, radius, callback, Val(dim)) +end + +function inrange_callback!(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, callback::F, ::Val{dim}) where {V, T <: Number, F, dim} + check_input(tree, points) + check_radius(radius) + n_points = size(points, 2) + for i in 1:n_points + point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim))) + _inrange(tree, point, radius, i, callback) + end + return nothing +end + +function index_returning_runtime_function(point_index::Int, neighbor_index::Int, idxs) + if eltype(idxs) <: Integer + push!(idxs, eltype(idxs)(neighbor_index)) + else + push!(idxs[point_index], neighbor_index) + end + return nothing +end + +function inrange_callback_default(tree::NNTree{V}, points, radius::Number, sortres=false) where {V} + if points isa AbstractVector{<:AbstractVector} + n_points = length(points) + elseif points isa AbstractMatrix{<:Number} + n_points = size(points, 2) + end + + idxs = [Int[] for _ in 1:n_points] + f(a, b) = index_returning_runtime_function(a, b, idxs) + inrange_callback!(tree, points, radius, f) + + if sortres + for i in eachindex(idxs) + sort!(idxs[i]) + end + end + + if length(idxs) == 1 + return idxs[1] + end + return idxs +end + +function inrange_callback_default(tree::NNTree{V}, points::AbstractVector{T}, radius::Number, sortres=false) where {V, T <: Number} + idxs = Int[] + f(a, b) = index_returning_runtime_function(a, b, idxs) + inrange_callback!(tree, points, radius, f) + + if sortres + sort!(idxs) + end + return idxs +end + diff --git a/src/kd_tree.jl b/src/kd_tree.jl index b84716f..f36c6a6 100644 --- a/src/kd_tree.jl +++ b/src/kd_tree.jl @@ -14,17 +14,7 @@ end KDTree(data [, metric = Euclidean(); leafsize = 25, reorder = true]) -> kdtree Creates a `KDTree` from the data using the given `metric` and `leafsize`. - -# Arguments -- `data`: Point data as a matrix of size `nd × np` or vector of vectors -- `metric`: Distance metric to use (must be a `MinkowskiMetric` like `Euclidean`, `Chebyshev`, `Minkowski`, or `Cityblock`). Default: `Euclidean()` -- `leafsize`: Number of points at which to stop splitting the tree. Default: `25` -- `reorder`: If `true`, reorder data to improve cache locality. Default: `true` - -# Returns -- `kdtree`: A `KDTree` instance - -KDTree works best for low-dimensional data with axis-aligned metrics. +The `metric` must be a `MinkowskiMetric`. """ function KDTree(data::AbstractVector{V}, metric::M = Euclidean(); @@ -73,10 +63,18 @@ function KDTree(data::AbstractVector{V}, indices = indices_reordered end + if metric isa Distances.UnionMetrics + p = parameters(metric) + if p !== nothing && length(p) != length(V) + throw(ArgumentError( + "dimension of input points:$(length(V)) and metric parameter:$(length(p)) must agree")) + end + end + KDTree(storedata ? data : similar(data, 0), hyper_rec, indices, metric, split_vals, split_dims, tree_data, reorder) end -function KDTree(data::AbstractVecOrMat{T}, + function KDTree(data::AbstractVecOrMat{T}, metric::M = Euclidean(); leafsize::Int = 25, storedata::Bool = true, @@ -114,7 +112,16 @@ function build_KDTree(index::Int, mid_idx = find_split(first(range), tree_data.leafsize, n_p) - split_dim = argmax(d -> hyper_rec.maxes[d] - hyper_rec.mins[d], 1:length(V)) + split_dim = 1 + max_spread = zero(T) + # Find dimension and spread where the spread is maximal + for d in 1:length(V) + spread = hyper_rec.maxes[d] - hyper_rec.mins[d] + if spread > max_spread + max_spread = spread + split_dim = d + end + end select_spec!(indices, mid_idx, first(range), last(range), data, split_dim) @@ -166,6 +173,8 @@ function knn_kernel!(tree::KDTree{V}, split_dim = tree.split_dims[index] p_dim = point[split_dim] split_val = tree.split_vals[index] + lo = hyper_rec.mins[split_dim] + hi = hyper_rec.maxes[split_dim] split_diff = p_dim - split_val M = tree.metric # Point is to the right of the split value @@ -174,72 +183,63 @@ function knn_kernel!(tree::KDTree{V}, far = getleft(index) hyper_rec_far = HyperRectangle(hyper_rec.mins, @inbounds setindex(hyper_rec.maxes, split_val, split_dim)) hyper_rec_close = HyperRectangle(@inbounds(setindex(hyper_rec.mins, split_val, split_dim)), hyper_rec.maxes) + ddiff = max(zero(eltype(V)), p_dim - hi) else close = getleft(index) far = getright(index) hyper_rec_far = HyperRectangle(@inbounds(setindex(hyper_rec.mins, split_val, split_dim)), hyper_rec.maxes) hyper_rec_close = HyperRectangle(hyper_rec.mins, @inbounds setindex(hyper_rec.maxes, split_val, split_dim)) + ddiff = max(zero(eltype(V)), lo - p_dim) end # Always call closer sub tree knn_kernel!(tree, close, point, best_idxs, best_dists, min_dist, hyper_rec_close, skip) - if M isa Chebyshev - new_min = get_min_distance_no_end(M, hyper_rec_far, point) - else - new_min = update_new_min(M, min_dist, hyper_rec, p_dim, split_dim, split_val) - end - + split_diff_pow = eval_pow(M, split_diff) + ddiff_pow = eval_pow(M, ddiff) + diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim) + new_min = eval_reduce(M, min_dist, diff_tot) if new_min < best_dists[1] knn_kernel!(tree, far, point, best_idxs, best_dists, new_min, hyper_rec_far, skip) end return end -function _inrange( - tree::KDTree, - point::AbstractVector, - radius::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}} = Int[] - ) +function _inrange(tree::KDTree, + point::AbstractVector, + radius::Number, + point_index::Int = 1, + callback::Union{Nothing, Function} = nothing) init_min = get_min_distance_no_end(tree.metric, tree.hyper_rec, point) - init_max_contribs = get_max_distance_contributions(tree.metric, tree.hyper_rec, point) - init_max = tree.metric isa Chebyshev ? maximum(init_max_contribs) : sum(init_max_contribs) - return inrange_kernel!( - tree, 1, point, eval_pow(tree.metric, radius), idx_in_ball, - tree.hyper_rec, init_min, init_max_contribs, init_max - ) + return inrange_kernel!(tree, 1, point, eval_pow(tree.metric, radius), + tree.hyper_rec, init_min, callback, point_index) end # Explicitly check the distance between leaf node and point while traversing -function inrange_kernel!( - tree::KDTree, - index::Int, - point::AbstractVector, - r::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}, - hyper_rec::HyperRectangle, - min_dist, - max_dist_contribs::SVector, - max_dist - ) +function inrange_kernel!(tree::KDTree, + index::Int, + point::AbstractVector, + r::Number, + hyper_rec::HyperRectangle, + min_dist, + callback::Union{Nothing, Function}, + point_index::Int) # Point is outside hyper rectangle, skip the whole sub tree if min_dist > r return 0 end - if max_dist < r - return addall(tree, index, idx_in_ball) - end - # At a leaf node. Go through all points in node and add those in range if isleaf(tree.tree_data.n_internal_nodes, index) - return add_points_inrange!(idx_in_ball, tree, index, point, r) + return add_points_inrange!(tree, index, point, r, callback, point_index) end split_val = tree.split_vals[index] split_dim = tree.split_dims[index] + lo = hyper_rec.mins[split_dim] + hi = hyper_rec.maxes[split_dim] p_dim = point[split_dim] split_diff = p_dim - split_val + M = tree.metric count = 0 @@ -248,42 +248,27 @@ function inrange_kernel!( far = getleft(index) hyper_rec_far = HyperRectangle(hyper_rec.mins, @inbounds setindex(hyper_rec.maxes, split_val, split_dim)) hyper_rec_close = HyperRectangle(@inbounds(setindex(hyper_rec.mins, split_val, split_dim)), hyper_rec.maxes) + ddiff = max(zero(p_dim - hi), p_dim - hi) else # Point is to the left of the split value close = getleft(index) far = getright(index) hyper_rec_far = HyperRectangle(@inbounds(setindex(hyper_rec.mins, split_val, split_dim)), hyper_rec.maxes) hyper_rec_close = HyperRectangle(hyper_rec.mins, @inbounds setindex(hyper_rec.maxes, split_val, split_dim)) + ddiff = max(zero(lo - p_dim), lo - p_dim) end - # Compute contributions for both close and far subtrees - M = tree.metric - old_contrib = max_dist_contribs[split_dim] - if split_diff > 0 - # Point is to the right - # Close subtree: split_val as new min, far subtree: split_val as new max - new_contrib_close = get_max_distance_contribution_single(M, point[split_dim], split_val, hyper_rec.maxes[split_dim], split_dim) - new_contrib_far = get_max_distance_contribution_single(M, point[split_dim], hyper_rec.mins[split_dim], split_val, split_dim) - else - # Point is to the left - # Close subtree: split_val as new max, far subtree: split_val as new min - new_contrib_close = get_max_distance_contribution_single(M, point[split_dim], hyper_rec.mins[split_dim], split_val, split_dim) - new_contrib_far = get_max_distance_contribution_single(M, point[split_dim], split_val, hyper_rec.maxes[split_dim], split_dim) - end - - # Update contributions and distances for close subtree - new_max_contribs_close = setindex(max_dist_contribs, new_contrib_close, split_dim) - new_max_dist_close = M isa Chebyshev ? maximum(new_max_contribs_close) : max_dist - old_contrib + new_contrib_close - # Call closer sub tree - count += inrange_kernel!(tree, close, point, r, idx_in_ball, hyper_rec_close, min_dist, new_max_contribs_close, new_max_dist_close) - - # Compute new min distance for far subtree - new_min = M isa Chebyshev ? get_min_distance_no_end(M, hyper_rec_far, point) : update_new_min(M, min_dist, hyper_rec, p_dim, split_dim, split_val) - - # Update contributions and distances for far subtree - new_max_contribs_far = setindex(max_dist_contribs, new_contrib_far, split_dim) - new_max_dist_far = M isa Chebyshev ? maximum(new_max_contribs_far) : max_dist - old_contrib + new_contrib_far - - # Call further sub tree - count += inrange_kernel!(tree, far, point, r, idx_in_ball, hyper_rec_far, new_min, new_max_contribs_far, new_max_dist_far) + count += inrange_kernel!(tree, close, point, r, hyper_rec_close, min_dist, callback, point_index) + + # TODO: We could potentially also keep track of the max distance + # between the point and the hyper rectangle and add the whole sub tree + # in case of the max distance being <= r similarly to the BallTree inrange method. + # It would be interesting to benchmark this on some different data sets. + + # Call further sub tree with the new min distance + split_diff_pow = eval_pow(M, split_diff) + ddiff_pow = eval_pow(M, ddiff) + diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim) + new_min = eval_reduce(M, min_dist, diff_tot) + count += inrange_kernel!(tree, far, point, r, hyper_rec_far, new_min, callback, point_index) return count end diff --git a/src/knn.jl b/src/knn.jl index 7e13659..775f7eb 100644 --- a/src/knn.jl +++ b/src/knn.jl @@ -5,20 +5,12 @@ function check_k(tree, k) end """ - knn(tree::NNTree, points, k [, skip=always_false]) -> indices, distances + knn(tree::NNTree, points, k [, sortres=false]) -> indices, distances -Performs a lookup of the `k` nearest neighbors to the `points` from the data -in the `tree`. - -# Arguments -- `tree`: The tree instance -- `points`: Query point(s) - can be a vector (single point), matrix (multiple points), or vector of vectors -- `k`: Number of nearest neighbors to find -- `skip`: Optional predicate function to skip points based on their index (default: `always_false`) - -# Returns -- `indices`: Indices of the k nearest neighbors -- `distances`: Distances to the k nearest neighbors +Performs a lookup of the `k` nearest neigbours to the `points` from the data +in the `tree`. `skip` is an optional predicate +to determine if a point that would be returned should be skipped based on its +index. See also: `knn!`, `nn`. """ @@ -53,18 +45,10 @@ function knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, id end """ - knn!(idxs, dists, tree, point, k [, skip=always_false]) + knn!(idxs, dists, tree, point, k) Same functionality as `knn` but stores the results in the input vectors `idxs` and `dists`. -Useful to avoid allocations or specify the element type of the output vectors. - -# Arguments -- `idxs`: Pre-allocated vector to store indices (must be of length `k`) -- `dists`: Pre-allocated vector to store distances (must be of length `k`) -- `tree`: The tree instance -- `point`: Query point -- `k`: Number of nearest neighbors to find -- `skip`: Optional predicate function to skip points based on their index (default: `always_false`) +Useful if one want to avoid allocations or specify the element type of the output vectors. See also: `knn`, `nn`. """ @@ -107,16 +91,7 @@ end nn(tree::NNTree, point [, skip]) -> index, distance nn(tree::NNTree, points [, skip]) -> indices, distances -Performs a lookup of the single nearest neighbor to the `point(s)` from the data. - -# Arguments -- `tree`: The tree instance -- `point(s)`: Query point(s) - can be a vector (single point), matrix (multiple points), or vector of vectors -- `skip`: Optional predicate function to skip points based on their index (default: `always_false`) - -# Returns -- For single point: `index` and `distance` of the nearest neighbor -- For multiple points: vectors of `indices` and `distances` of the nearest neighbors +Performs a lookup of the single nearest neigbours to the `points` from the data. See also: `knn`. """ diff --git a/src/tree_ops.jl b/src/tree_ops.jl index 39338cf..4f46c30 100644 --- a/src/tree_ops.jl +++ b/src/tree_ops.jl @@ -114,14 +114,16 @@ end # stop computing the distance function as soon as we reach the desired radius. # This will probably prevent SIMD and other optimizations so some care is needed # to evaluate if it is worth it. -@inline function add_points_inrange!(idx_in_ball::Union{Nothing, AbstractVector{<:Integer}}, tree::NNTree, - index::Int, point::AbstractVector, r::Number) +@inline function add_points_inrange!(tree::NNTree, + index::Int, point::AbstractVector, r::Number, + callback::Union{Nothing, Function}, + point_index::Int) count = 0 for z in get_leaf_range(tree.tree_data, index) idx = tree.reordered ? z : tree.indices[z] if check_in_range(tree.metric, tree.data[idx], point, r) count += 1 - idx_in_ball !== nothing && push!(idx_in_ball, idx) + @inbounds !isnothing(callback) && callback(point_index, tree.reordered ? tree.indices[idx] : idx) end end return count @@ -138,18 +140,18 @@ end # Add all points in this subtree since we have determined # they are all within the desired range -function addall(tree::NNTree, index::Int, idx_in_ball::Union{Nothing, Vector{<:Integer}}) +function addall(tree::NNTree, index::Int, callback::Union{Nothing, Function} = nothing, point_index::Int = 1) tree_data = tree.tree_data count = 0 if isleaf(tree_data.n_internal_nodes, index) for z in get_leaf_range(tree_data, index) idx = tree.reordered ? z : tree.indices[z] count += 1 - idx_in_ball !== nothing && push!(idx_in_ball, idx) + @inbounds !isnothing(callback) && callback(point_index, tree.reordered ? tree.indices[idx] : idx) end else - count += addall(tree, getleft(index), idx_in_ball) - count += addall(tree, getright(index), idx_in_ball) + count += addall(tree, getleft(index), callback, point_index) + count += addall(tree, getright(index), callback, point_index) end return count end diff --git a/src/utilities.jl b/src/utilities.jl index 7ad3700..7a7f30a 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -94,4 +94,4 @@ end # Instead of ReinterpretArray wrapper, copy an array, interpreting it as a vector of SVectors copy_svec(::Type{T}, data, ::Val{dim}) where {T, dim} = - [SVector{dim,T}(ntuple(i -> data[n+i], Val(dim))) for n in 0:dim:(length(data)-1)]::Vector{SVector{dim,T}} + [SVector{dim,T}(ntuple(i -> data[n+i], Val(dim))) for n in 0:dim:(length(data)-1)] diff --git a/test/runtests.jl b/test/runtests.jl index d0ee227..7f0a9a7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -80,28 +80,4 @@ using NearestNeighbors: HyperRectangle, get_min_distance_no_end, get_max_distanc @test get_min_distance_no_end(m, hr, point) ≈ NearestNeighbors.eval_pow(m, m(closest_point, point)) @test get_max_distance_no_end(m, hr, point) ≈ NearestNeighbors.eval_pow(m, m(furthest_point, point)) end - - for m in ms - hyper_rec = NearestNeighbors.HyperRectangle{SVector{1, Float32}}(Float32[0.5553872], Float32[0.6169486]) - point = [0.5] - min_dist = NearestNeighbors.get_min_distance_no_end(m, hyper_rec, point) - split_dim = 1 - split_val = 0.5844354f0 - hyper_rec_far = NearestNeighbors.HyperRectangle{SVector{1, Float32}}(Float32[0.5844354], Float32[0.6169486]) - new_min = NearestNeighbors.update_new_min(m, min_dist, hyper_rec, point[split_dim], split_dim, split_val) - new_min_true = NearestNeighbors.get_min_distance_no_end(m, hyper_rec_far, point) - @test new_min ≈ new_min_true - end - - for m in ms - hyper_rec = NearestNeighbors.HyperRectangle{SVector{2, Float64}}([0.07935189250034036, 0.682552911042077], [0.1619776648454222, 0.8046815005307764]) - point = [0.06630748183735935, 0.7541470744398973] - min_dist = NearestNeighbors.get_min_distance_no_end(m, hyper_rec, point) - split_dim = 2 - split_val = 0.7388396209627084 - hyper_rec_far = NearestNeighbors.HyperRectangle{SVector{2, Float64}}([0.07935189250034036, 0.682552911042077], [0.1619776648454222, 0.7388396209627084]) - new_min = NearestNeighbors.update_new_min(m, min_dist, hyper_rec, point[split_dim], split_dim, split_val) - new_min_true = NearestNeighbors.get_min_distance_no_end(m, hyper_rec_far, point) - @test new_min ≈ new_min_true broken = m isa Chebyshev - end end diff --git a/test/test_inrange.jl b/test/test_inrange.jl index 02b0914..456c05f 100644 --- a/test/test_inrange.jl +++ b/test/test_inrange.jl @@ -94,6 +94,29 @@ end end end +@testset "inrange_runtime function" begin + function runtime_test(point_index, neighbor_index, sum_of_random_data, neightbor_points) + sum_of_random_data[1] += sum(neightbor_points[4:6,neighbor_index]) + return nothing + end + + for T in (KDTree, BallTree, BruteTree) + sum_runtime = fill(0.0, 1) + data = rand(6, 100) # first 3 rows are "locations", last 3 rows are random data + f(a, b) = runtime_test(a, b, sum_runtime, data) + + tree = KDTree(data[1:3, :]) + inrange_callback!(tree, [0.5, 0.5, 0.5], 1.0, f) + idxs = inrange(tree, [0.5, 0.5, 0.5], 1.0) + sum_idxs = 0.0 + for i in eachindex(idxs) + sum_idxs += sum(data[4:6, idxs[i]]) + end + + @test sum_idxs == sum_runtime[1] + end +end + @testset "inferrability matrix" begin function foo(data, point) b = KDTree(data) @@ -101,4 +124,4 @@ end end @inferred foo([1.0 3.4; 4.5 3.4], [4.5; 3.4]) -end +end \ No newline at end of file diff --git a/test/test_knn.jl b/test/test_knn.jl index 3667772..4298409 100644 --- a/test/test_knn.jl +++ b/test/test_knn.jl @@ -148,12 +148,3 @@ end @test dists ≈ Float32.(dists2) end end - -@testset "inferrability matrix" begin - function foo(data, point) - b = KDTree(data) - return knn(b, point, 1) - end - - @inferred foo([1.0 3.4; 4.5 3.4], [4.5; 3.4]) -end