diff --git a/python/sedonadb/python/sedonadb/dbapi.py b/python/sedonadb/python/sedonadb/dbapi.py index 968b1b004..22596f3cd 100644 --- a/python/sedonadb/python/sedonadb/dbapi.py +++ b/python/sedonadb/python/sedonadb/dbapi.py @@ -38,7 +38,7 @@ def connect(**kwargs: Mapping[str, Any]) -> "Connection": >>> con = sedona.dbapi.connect() >>> with con.cursor() as cur: - ... cur.execute("SELECT 1 as one") + ... _ = cur.execute("SELECT 1 as one") ... cur.fetchall() [(1,)] """ diff --git a/rust/sedona-spatial-join/src/build_index.rs b/rust/sedona-spatial-join/src/build_index.rs new file mode 100644 index 000000000..16e45d526 --- /dev/null +++ b/rust/sedona-spatial-join/src/build_index.rs @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use arrow_schema::SchemaRef; +use datafusion_common::{DataFusionError, Result}; +use datafusion_execution::{memory_pool::MemoryConsumer, SendableRecordBatchStream, TaskContext}; +use datafusion_expr::JoinType; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; +use sedona_common::SedonaOptions; + +use crate::{ + index::{ + BuildSideBatchesCollector, CollectBuildSideMetrics, SpatialIndex, SpatialIndexBuilder, + SpatialJoinBuildMetrics, + }, + operand_evaluator::create_operand_evaluator, + spatial_predicate::SpatialPredicate, +}; + +pub(crate) async fn build_index( + context: Arc, + build_schema: SchemaRef, + build_streams: Vec, + spatial_predicate: SpatialPredicate, + join_type: JoinType, + probe_threads_count: usize, + metrics: ExecutionPlanMetricsSet, +) -> Result { + let session_config = context.session_config(); + let sedona_options = session_config + .options() + .extensions + .get::() + .cloned() + .unwrap_or_default(); + let memory_pool = context.memory_pool(); + let evaluator = + create_operand_evaluator(&spatial_predicate, sedona_options.spatial_join.clone()); + let collector = BuildSideBatchesCollector::new(evaluator); + let num_partitions = build_streams.len(); + let mut collect_metrics_vec = Vec::with_capacity(num_partitions); + let mut reservations = Vec::with_capacity(num_partitions); + for k in 0..num_partitions { + let consumer = + MemoryConsumer::new(format!("SpatialJoinCollectBuildSide[{}]", k)).with_can_spill(true); + let reservation = consumer.register(memory_pool); + reservations.push(reservation); + collect_metrics_vec.push(CollectBuildSideMetrics::new(k, &metrics)); + } + + let build_partitions = collector + .collect_all(build_streams, reservations, collect_metrics_vec) + .await?; + + let contains_external_stream = build_partitions + .iter() + .any(|partition| partition.build_side_batch_stream.is_external()); + if !contains_external_stream { + let mut index_builder = SpatialIndexBuilder::new( + 
build_schema, + spatial_predicate, + sedona_options.spatial_join, + join_type, + probe_threads_count, + Arc::clone(memory_pool), + SpatialJoinBuildMetrics::new(0, &metrics), + )?; + index_builder.add_partitions(build_partitions).await?; + index_builder.finish() + } else { + Err(DataFusionError::ResourcesExhausted("Memory limit exceeded while collecting indexed data. External spatial index builder is not yet implemented.".to_string())) + } +} diff --git a/rust/sedona-spatial-join/src/evaluated_batch.rs b/rust/sedona-spatial-join/src/evaluated_batch.rs new file mode 100644 index 000000000..ff67f1719 --- /dev/null +++ b/rust/sedona-spatial-join/src/evaluated_batch.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::RecordBatch; +use datafusion_expr::ColumnarValue; +use geo::Rect; +use wkb::reader::Wkb; + +use crate::operand_evaluator::EvaluatedGeometryArray; + +/// EvaluatedBatch contains the original record batch from the input stream and the evaluated +/// geometry array. 
+pub(crate) struct EvaluatedBatch { + /// Original record batch polled from the stream + pub batch: RecordBatch, + /// Evaluated geometry array, containing the geometry array containing geometries to be joined, + /// rects of joined geometries, evaluated distance columnar values if we are running a distance + /// join, etc. + pub geom_array: EvaluatedGeometryArray, +} + +impl EvaluatedBatch { + pub fn in_mem_size(&self) -> usize { + // NOTE: sometimes `geom_array` will reuse the memory of `batch`, especially when + // the expression for evaluating the geometry is a simple column reference. In this case, + // the in_mem_size will be overestimated. It is a conservative estimation so there's no risk + // of running out of memory because of underestimation. + self.batch.get_array_memory_size() + self.geom_array.in_mem_size() + } + + pub fn num_rows(&self) -> usize { + self.batch.num_rows() + } + + pub fn wkb(&self, idx: usize) -> Option<&Wkb<'_>> { + let wkbs = self.geom_array.wkbs(); + wkbs[idx].as_ref() + } + + pub fn rects(&self) -> &Vec>> { + &self.geom_array.rects + } + + pub fn distance(&self) -> &Option { + &self.geom_array.distance + } +} + +pub(crate) mod evaluated_batch_stream; diff --git a/rust/sedona-spatial-join/src/evaluated_batch/evaluated_batch_stream.rs b/rust/sedona-spatial-join/src/evaluated_batch/evaluated_batch_stream.rs new file mode 100644 index 000000000..958087f7b --- /dev/null +++ b/rust/sedona-spatial-join/src/evaluated_batch/evaluated_batch_stream.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::pin::Pin; + +use futures::Stream; + +use crate::evaluated_batch::EvaluatedBatch; +use datafusion_common::Result; + +/// A stream that produces [`EvaluatedBatch`] items. This stream may have purely in-memory or +/// out-of-core implementations. The type of the stream could be queried calling `is_external()`. +pub(crate) trait EvaluatedBatchStream: Stream> { + /// Returns true if this stream is an external stream, where batch data were spilled to disk. + fn is_external(&self) -> bool; +} + +pub(crate) type SendableEvaluatedBatchStream = Pin>; + +pub(crate) mod in_mem; diff --git a/rust/sedona-spatial-join/src/evaluated_batch/evaluated_batch_stream/in_mem.rs b/rust/sedona-spatial-join/src/evaluated_batch/evaluated_batch_stream/in_mem.rs new file mode 100644 index 000000000..57671547b --- /dev/null +++ b/rust/sedona-spatial-join/src/evaluated_batch/evaluated_batch_stream/in_mem.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + pin::Pin, + task::{Context, Poll}, + vec::IntoIter, +}; + +use datafusion_common::Result; + +use crate::evaluated_batch::{evaluated_batch_stream::EvaluatedBatchStream, EvaluatedBatch}; + +pub(crate) struct InMemoryEvaluatedBatchStream { + iter: IntoIter, +} + +impl InMemoryEvaluatedBatchStream { + pub fn new(batches: Vec) -> Self { + InMemoryEvaluatedBatchStream { + iter: batches.into_iter(), + } + } +} + +impl EvaluatedBatchStream for InMemoryEvaluatedBatchStream { + fn is_external(&self) -> bool { + false + } +} + +impl futures::Stream for InMemoryEvaluatedBatchStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + self.get_mut() + .iter + .next() + .map(|batch| Poll::Ready(Some(Ok(batch)))) + .unwrap_or(Poll::Ready(None)) + } +} diff --git a/rust/sedona-spatial-join/src/exec.rs b/rust/sedona-spatial-join/src/exec.rs index 6dc935b75..633fb8d3d 100644 --- a/rust/sedona-spatial-join/src/exec.rs +++ b/rust/sedona-spatial-join/src/exec.rs @@ -35,12 +35,12 @@ use datafusion_physical_plan::{ use parking_lot::Mutex; use crate::{ - index::{build_index, SpatialIndex, SpatialJoinBuildMetrics}, - once_fut::OnceAsync, + build_index::build_index, + index::SpatialIndex, spatial_predicate::{KNNPredicate, SpatialPredicate}, stream::{SpatialJoinProbeMetrics, SpatialJoinStream}, - utils::{asymmetric_join_output_partitioning, boundedness_from_children}, - // Re-export from sedona-common + utils::join_utils::{asymmetric_join_output_partitioning, boundedness_from_children}, + 
utils::once_fut::OnceAsync, SedonaOptions, }; @@ -141,7 +141,7 @@ pub struct SpatialJoinExec { } impl SpatialJoinExec { - // Try to create a new [`SpatialJoinExec`] + /// Try to create a new [`SpatialJoinExec`] pub fn try_new( left: Arc, right: Arc, @@ -449,25 +449,22 @@ impl ExecutionPlan for SpatialJoinExec { let num_partitions = build_side.output_partitioning().partition_count(); let mut build_streams = Vec::with_capacity(num_partitions); - let mut build_metrics = Vec::with_capacity(num_partitions); for k in 0..num_partitions { let stream = build_side.execute(k, Arc::clone(&context))?; build_streams.push(stream); - build_metrics.push(SpatialJoinBuildMetrics::new(k, &self.metrics)); } let probe_thread_count = self.right.output_partitioning().partition_count(); Ok(build_index( + Arc::clone(&context), build_side.schema(), build_streams, self.on.clone(), - sedona_options.spatial_join.clone(), - build_metrics, - Arc::clone(context.memory_pool()), self.join_type, probe_thread_count, + self.metrics.clone(), )) })? }; @@ -546,24 +543,21 @@ impl SpatialJoinExec { let num_partitions = build_side.output_partitioning().partition_count(); let mut build_streams = Vec::with_capacity(num_partitions); - let mut build_metrics = Vec::with_capacity(num_partitions); for k in 0..num_partitions { let stream = build_side.execute(k, Arc::clone(&context))?; build_streams.push(stream); - build_metrics.push(SpatialJoinBuildMetrics::new(k, &self.metrics)); } let probe_thread_count = self.right.output_partitioning().partition_count(); Ok(build_index( + Arc::clone(&context), build_side.schema(), build_streams, self.on.clone(), - sedona_options.spatial_join.clone(), - build_metrics, - Arc::clone(context.memory_pool()), self.join_type, probe_thread_count, + self.metrics.clone(), )) })? 
}; diff --git a/rust/sedona-spatial-join/src/index.rs b/rust/sedona-spatial-join/src/index.rs index 3b0a8b324..c6fa542c8 100644 --- a/rust/sedona-spatial-join/src/index.rs +++ b/rust/sedona-spatial-join/src/index.rs @@ -14,1946 +14,30 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. -use once_cell::sync::OnceCell; -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, -}; -use arrow_array::RecordBatch; -use arrow_schema::SchemaRef; -use datafusion_common::{utils::proxy::VecAllocExt, DataFusionError, Result}; -use datafusion_common_runtime::JoinSet; -use datafusion_execution::{ - memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}, - SendableRecordBatchStream, -}; -use datafusion_expr::{ColumnarValue, JoinType}; -use datafusion_physical_plan::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; -use futures::StreamExt; -use geo_index::rtree::distance::{ - DistanceMetric, EuclideanDistance, GeometryAccessor, HaversineDistance, -}; -use geo_index::rtree::{sort::HilbertSort, RTree, RTreeBuilder, RTreeIndex}; -use geo_index::IndexableNum; -use geo_types::{Geometry, Point, Rect}; -use parking_lot::Mutex; -use sedona_expr::statistics::GeoStatistics; -use sedona_functions::st_analyze_aggr::AnalyzeAccumulator; -use sedona_geo::to_geo::item_to_geometry; -use sedona_geo_generic_alg::algorithm::Centroid; -use sedona_schema::datatypes::WKB_GEOMETRY; -use wkb::reader::Wkb; +pub(crate) mod build_side_collector; +mod knn_adapter; +pub(crate) mod spatial_index; +pub(crate) mod spatial_index_builder; -use crate::{ - concurrent_reservation::ConcurrentReservation, - operand_evaluator::{create_operand_evaluator, EvaluatedGeometryArray, OperandEvaluator}, - refine::{create_refiner, IndexQueryResultRefiner}, - spatial_predicate::SpatialPredicate, - utils::need_produce_result_in_final, +pub(crate) use build_side_collector::{ + BuildPartition, BuildSideBatchesCollector, 
CollectBuildSideMetrics, }; -use arrow::array::BooleanBufferBuilder; -use sedona_common::{option::SpatialJoinOptions, ExecutionMode}; - -// Type aliases for better readability -type SpatialRTree = RTree; -type DataIdToBatchPos = Vec<(i32, i32)>; -type RTreeBuildResult = (SpatialRTree, DataIdToBatchPos); - -/// The prealloc size for the refiner reservation. This is used to reduce the frequency of growing -/// the reservation when updating the refiner memory reservation. -const REFINER_RESERVATION_PREALLOC_SIZE: usize = 10 * 1024 * 1024; // 10MB - -/// Metrics for the build phase of the spatial join. -#[derive(Clone, Debug, Default)] -pub(crate) struct SpatialJoinBuildMetrics { - /// Total time for collecting build-side of join - pub(crate) build_time: metrics::Time, - /// Number of batches consumed by build-side - pub(crate) build_input_batches: metrics::Count, - /// Number of rows consumed by build-side - pub(crate) build_input_rows: metrics::Count, - /// Memory used by build-side in bytes - pub(crate) build_mem_used: metrics::Gauge, -} - -impl SpatialJoinBuildMetrics { - pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { - Self { - build_time: MetricBuilder::new(metrics).subset_time("build_time", partition), - build_input_batches: MetricBuilder::new(metrics) - .counter("build_input_batches", partition), - build_input_rows: MetricBuilder::new(metrics).counter("build_input_rows", partition), - build_mem_used: MetricBuilder::new(metrics).gauge("build_mem_used", partition), - } - } -} - -/// Builder for constructing a SpatialIndex from geometry batches. -/// -/// This builder handles: -/// 1. Accumulating geometry batches to be indexed -/// 2. Building the spatial R-tree index -/// 3. Setting up memory tracking and visited bitmaps -/// 4. 
Configuring prepared geometries based on execution mode -pub(crate) struct SpatialIndexBuilder { - spatial_predicate: SpatialPredicate, - options: SpatialJoinOptions, - join_type: JoinType, - probe_threads_count: usize, - metrics: SpatialJoinBuildMetrics, - - /// Batches to be indexed - indexed_batches: Vec, - /// Memory reservation for tracking the memory usage of the spatial index - reservation: MemoryReservation, - - /// Statistics for indexed geometries - stats: GeoStatistics, - - /// Memory pool for managing the memory usage of the spatial index - memory_pool: Arc, -} - -impl SpatialIndexBuilder { - /// Create a new builder with the given configuration. - pub fn new( - spatial_predicate: SpatialPredicate, - options: SpatialJoinOptions, - join_type: JoinType, - probe_threads_count: usize, - memory_pool: Arc, - metrics: SpatialJoinBuildMetrics, - ) -> Result { - let consumer = MemoryConsumer::new("SpatialJoinIndex"); - let reservation = consumer.register(&memory_pool); - - Ok(Self { - spatial_predicate, - options, - join_type, - probe_threads_count, - metrics, - indexed_batches: Vec::new(), - reservation, - stats: GeoStatistics::empty(), - memory_pool, - }) - } - - /// Add a geometry batch to be indexed. - /// - /// This method accumulates geometry batches that will be used to build the spatial index. - /// Each batch contains processed geometry data along with memory usage information. - pub fn add_batch(&mut self, indexed_batch: IndexedBatch) { - let in_mem_size = indexed_batch.in_mem_size(); - self.indexed_batches.push(indexed_batch); - self.reservation.grow(in_mem_size); - self.metrics.build_mem_used.add(in_mem_size); - } - - pub fn with_stats(&mut self, stats: GeoStatistics) -> &mut Self { - self.stats.merge(&stats); - self - } - - /// Build the spatial R-tree index from collected geometry batches. 
- fn build_rtree(&mut self) -> Result { - let build_timer = self.metrics.build_time.timer(); - - let num_rects = self - .indexed_batches - .iter() - .map(|batch| batch.rects().len()) - .sum::(); - - let mut rtree_builder = RTreeBuilder::::new(num_rects as u32); - let mut batch_pos_vec = vec![(0, 0); num_rects]; - let rtree_mem_estimate = num_rects * RTREE_MEMORY_ESTIMATE_PER_RECT; - - self.reservation - .grow(batch_pos_vec.allocated_size() + rtree_mem_estimate); - - for (batch_idx, batch) in self.indexed_batches.iter().enumerate() { - let rects = batch.rects(); - for (idx, rect) in rects { - let min = rect.min(); - let max = rect.max(); - let data_idx = rtree_builder.add(min.x, min.y, max.x, max.y); - batch_pos_vec[data_idx as usize] = (batch_idx as i32, *idx as i32); - } - } - - let rtree = rtree_builder.finish::(); - build_timer.done(); - - self.metrics.build_mem_used.add(self.reservation.size()); - - Ok((rtree, batch_pos_vec)) - } - - /// Build visited bitmaps for tracking left-side indices in outer joins. - fn build_visited_bitmaps(&mut self) -> Result>>> { - if !need_produce_result_in_final(self.join_type) { - return Ok(None); - } - - let mut bitmaps = Vec::with_capacity(self.indexed_batches.len()); - let mut total_buffer_size = 0; - - for batch in &self.indexed_batches { - let batch_rows = batch.batch.num_rows(); - let buffer_size = batch_rows.div_ceil(8); - total_buffer_size += buffer_size; - - let mut bitmap = BooleanBufferBuilder::new(batch_rows); - bitmap.append_n(batch_rows, false); - bitmaps.push(bitmap); - } - - self.reservation.try_grow(total_buffer_size)?; - self.metrics.build_mem_used.add(total_buffer_size); - - Ok(Some(Mutex::new(bitmaps))) - } - - /// Create an rtree data index to consecutive index mapping. 
- fn build_geom_idx_vec(&mut self, batch_pos_vec: &Vec<(i32, i32)>) -> Vec { - let mut num_geometries = 0; - let mut batch_idx_offset = Vec::with_capacity(self.indexed_batches.len() + 1); - batch_idx_offset.push(0); - for batch in &self.indexed_batches { - num_geometries += batch.batch.num_rows(); - batch_idx_offset.push(num_geometries); - } - - let mut geom_idx_vec = Vec::with_capacity(batch_pos_vec.len()); - self.reservation.grow(geom_idx_vec.allocated_size()); - for (batch_idx, row_idx) in batch_pos_vec { - // Convert (batch_idx, row_idx) to a linear, sequential index - let batch_offset = batch_idx_offset[*batch_idx as usize]; - let prepared_idx = batch_offset + *row_idx as usize; - geom_idx_vec.push(prepared_idx); - } - - geom_idx_vec - } - - /// Finish building and return the completed SpatialIndex. - pub fn finish(mut self, schema: SchemaRef) -> Result { - if self.indexed_batches.is_empty() { - return Ok(SpatialIndex::empty( - self.spatial_predicate, - schema, - self.options, - AtomicUsize::new(self.probe_threads_count), - self.reservation, - self.memory_pool.clone(), - )); - } - - let evaluator = create_operand_evaluator(&self.spatial_predicate, self.options.clone()); - let num_geoms = self - .indexed_batches - .iter() - .map(|batch| batch.batch.num_rows()) - .sum::(); - - let (rtree, batch_pos_vec) = self.build_rtree()?; - let geom_idx_vec = self.build_geom_idx_vec(&batch_pos_vec); - let visited_left_side = self.build_visited_bitmaps()?; - - let refiner = create_refiner( - self.options.spatial_library, - &self.spatial_predicate, - self.options.clone(), - num_geoms, - self.stats, - ); - let consumer = MemoryConsumer::new("SpatialJoinRefiner"); - let refiner_reservation = consumer.register(&self.memory_pool); - let refiner_reservation = - ConcurrentReservation::try_new(REFINER_RESERVATION_PREALLOC_SIZE, refiner_reservation) - .unwrap(); - - let cache_size = batch_pos_vec.len(); - let knn_components = - KnnComponents::new(cache_size, &self.indexed_batches, 
self.memory_pool.clone())?; - - Ok(SpatialIndex { - schema, - evaluator, - refiner, - refiner_reservation, - rtree, - data_id_to_batch_pos: batch_pos_vec, - indexed_batches: self.indexed_batches, - geom_idx_vec, - visited_left_side, - probe_threads_counter: AtomicUsize::new(self.probe_threads_count), - knn_components, - reservation: self.reservation, - }) - } -} - -pub(crate) struct SpatialIndex { - schema: SchemaRef, - - /// The spatial predicate evaluator for the spatial predicate. - evaluator: Arc, - - /// The refiner for refining the index query results. - refiner: Arc, - - /// Memory reservation for tracking the memory usage of the refiner - refiner_reservation: ConcurrentReservation, - - /// R-tree index for the geometry batches. It takes MBRs as query windows and returns - /// data indexes. These data indexes should be translated using `data_id_to_batch_pos` to get - /// the original geometry batch index and row index, or translated using `prepared_geom_idx_vec` - /// to get the prepared geometries array index. - rtree: RTree, - - /// Indexed batches containing evaluated geometry arrays. It contains the original record - /// batches and geometry arrays obtained by evaluating the geometry expression on the build side. - indexed_batches: Vec, - /// An array for translating rtree data index to geometry batch index and row index - data_id_to_batch_pos: Vec<(i32, i32)>, - - /// An array for translating rtree data index to consecutive index. Each geometry may be indexed by - /// multiple boxes, so there could be multiple data indexes for the same geometry. A mapping for - /// squashing the index makes it easier for persisting per-geometry auxiliary data for evaluating - /// the spatial predicate. This is extensively used by the spatial predicate evaluators for storing - /// prepared geometries. 
- geom_idx_vec: Vec, - - /// Shared bitmap builders for visited left indices, one per batch - visited_left_side: Option>>, - - /// Counter of running probe-threads, potentially able to update `bitmap`. - /// Each time a probe thread finished probing the index, it will decrement the counter. - /// The last finished probe thread will produce the extra output batches for unmatched - /// build side when running left-outer joins. See also [`report_probe_completed`]. - probe_threads_counter: AtomicUsize, - - /// Shared KNN components (distance metrics and geometry cache) for efficient KNN queries - knn_components: KnnComponents, - - /// Memory reservation for tracking the memory usage of the spatial index - /// Cleared on `SpatialIndex` drop - #[expect(dead_code)] - reservation: MemoryReservation, -} - -/// Indexed batch containing the original record batch and the evaluated geometry array. -pub(crate) struct IndexedBatch { - batch: RecordBatch, - geom_array: EvaluatedGeometryArray, -} - -impl IndexedBatch { - pub fn in_mem_size(&self) -> usize { - // NOTE: sometimes `geom_array` will reuse the memory of `batch`, especially when - // the expression for evaluating the geometry is a simple column reference. In this case, - // the in_mem_size will be overestimated. 
- self.batch.get_array_memory_size() + self.geom_array.in_mem_size() - } - - pub fn wkb(&self, idx: usize) -> Option<&Wkb<'_>> { - let wkbs = self.geom_array.wkbs(); - wkbs[idx].as_ref() - } - - pub fn rects(&self) -> &Vec<(usize, Rect)> { - &self.geom_array.rects - } - - pub fn distance(&self) -> &Option { - &self.geom_array.distance - } -} - -#[derive(Debug)] -pub struct JoinResultMetrics { - pub count: usize, - pub candidate_count: usize, -} - -impl SpatialIndex { - fn empty( - spatial_predicate: SpatialPredicate, - schema: SchemaRef, - options: SpatialJoinOptions, - probe_threads_counter: AtomicUsize, - mut reservation: MemoryReservation, - memory_pool: Arc, - ) -> Self { - let evaluator = create_operand_evaluator(&spatial_predicate, options.clone()); - let refiner = create_refiner( - options.spatial_library, - &spatial_predicate, - options.clone(), - 0, - GeoStatistics::empty(), - ); - let refiner_reservation = reservation.split(0); - let refiner_reservation = ConcurrentReservation::try_new(0, refiner_reservation).unwrap(); - let rtree = RTreeBuilder::::new(0).finish::(); - Self { - schema, - evaluator, - refiner, - refiner_reservation, - rtree, - data_id_to_batch_pos: Vec::new(), - indexed_batches: Vec::new(), - geom_idx_vec: Vec::new(), - visited_left_side: None, - probe_threads_counter, - knn_components: KnnComponents::new(0, &[], memory_pool.clone()).unwrap(), // Empty index has no cache - reservation, - } - } - - pub(crate) fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - /// Create a KNN geometry accessor for accessing geometries with caching - fn create_knn_accessor(&self) -> SedonaKnnAdapter<'_> { - SedonaKnnAdapter::new( - &self.indexed_batches, - &self.data_id_to_batch_pos, - &self.knn_components, - ) - } - - /// Get the batch at the given index. 
- pub(crate) fn get_indexed_batch(&self, batch_idx: usize) -> &RecordBatch { - &self.indexed_batches[batch_idx].batch - } - - /// Query the spatial index with a probe geometry to find matching build-side geometries. - /// - /// This method implements a two-phase spatial join query: - /// 1. **Filter phase**: Uses the R-tree index with the probe geometry's bounding rectangle - /// to quickly identify candidate geometries that might satisfy the spatial predicate - /// 2. **Refinement phase**: Evaluates the exact spatial predicate on candidates to determine - /// actual matches - /// - /// # Arguments - /// * `probe_wkb` - The probe geometry in WKB format - /// * `probe_rect` - The minimum bounding rectangle of the probe geometry - /// * `distance` - Optional distance parameter for distance-based spatial predicates - /// * `build_batch_positions` - Output vector that will be populated with (batch_idx, row_idx) - /// pairs for each matching build-side geometry - /// - /// # Returns - /// * `JoinResultMetrics` containing the number of actual matches (`count`) and the number - /// of candidates from the filter phase (`candidate_count`) - pub(crate) fn query( - &self, - probe_wkb: &Wkb, - probe_rect: &Rect, - distance: &Option, - build_batch_positions: &mut Vec<(i32, i32)>, - ) -> Result { - let min = probe_rect.min(); - let max = probe_rect.max(); - let mut candidates = self.rtree.search(min.x, min.y, max.x, max.y); - if candidates.is_empty() { - return Ok(JoinResultMetrics { - count: 0, - candidate_count: 0, - }); - } - - // Sort and dedup candidates to avoid duplicate results when we index one geometry - // using several boxes. - candidates.sort_unstable(); - candidates.dedup(); - - // Refine the candidates retrieved from the r-tree index by evaluating the actual spatial predicate - self.refine(probe_wkb, &candidates, distance, build_batch_positions) - } - - /// Query the spatial index for k nearest neighbors of a given geometry. 
- /// - /// This method finds the k nearest neighbors to the probe geometry using: - /// 1. R-tree's built-in neighbors() method for efficient KNN search - /// 2. Distance refinement using actual geometry calculations - /// 3. Tie-breaker handling when enabled - /// - /// # Arguments - /// - /// * `probe_wkb` - WKB representation of the probe geometry - /// * `k` - Number of nearest neighbors to find - /// * `use_spheroid` - Whether to use spheroid distance calculation - /// * `include_tie_breakers` - Whether to include additional results with same distance as kth neighbor - /// * `build_batch_positions` - Output vector for matched positions - /// - /// # Returns - /// - /// * `JoinResultMetrics` containing the number of actual matches and candidates processed - pub(crate) fn query_knn( - &self, - probe_wkb: &Wkb, - k: u32, - use_spheroid: bool, - include_tie_breakers: bool, - build_batch_positions: &mut Vec<(i32, i32)>, - ) -> Result { - if k == 0 { - return Ok(JoinResultMetrics { - count: 0, - candidate_count: 0, - }); - } - - // Check if index is empty - if self.indexed_batches.is_empty() || self.data_id_to_batch_pos.is_empty() { - return Ok(JoinResultMetrics { - count: 0, - candidate_count: 0, - }); - } - - // Convert probe WKB to geo::Geometry - let probe_geom = match item_to_geometry(probe_wkb) { - Ok(geom) => geom, - Err(_) => { - // Empty or unsupported geometries (e.g., POINT EMPTY) return empty results - return Ok(JoinResultMetrics { - count: 0, - candidate_count: 0, - }); - } - }; - - // Select the appropriate distance metric - let distance_metric: &dyn DistanceMetric = if use_spheroid { - &self.knn_components.haversine_metric - } else { - &self.knn_components.euclidean_metric - }; - - // Create geometry accessor for on-demand WKB decoding and caching - let geometry_accessor = self.create_knn_accessor(); - - // Use neighbors_geometry to find k nearest neighbors - let initial_results = self.rtree.neighbors_geometry( - &probe_geom, - Some(k as usize), - 
None, // no max_distance filter - distance_metric, - &geometry_accessor, - ); - - if initial_results.is_empty() { - return Ok(JoinResultMetrics { - count: 0, - candidate_count: 0, - }); - } - - let mut final_results = initial_results; - let mut candidate_count = final_results.len(); - - // Handle tie-breakers if enabled - if include_tie_breakers && !final_results.is_empty() && k > 0 { - // Calculate distances for the initial k results to find the k-th distance - let mut distances_with_indices: Vec<(f64, u32)> = Vec::new(); - - for &result_idx in &final_results { - if (result_idx as usize) < self.data_id_to_batch_pos.len() { - if let Some(item_geom) = geometry_accessor.get_geometry(result_idx as usize) { - let distance = distance_metric.distance_to_geometry(&probe_geom, item_geom); - if let Some(distance_f64) = distance.to_f64() { - distances_with_indices.push((distance_f64, result_idx)); - } - } - } - } - - // Sort by distance - distances_with_indices - .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); - - // Find the k-th distance (if we have at least k results) - if distances_with_indices.len() >= k as usize { - let k_idx = (k as usize) - .min(distances_with_indices.len()) - .saturating_sub(1); - let max_distance = distances_with_indices[k_idx].0; - - // For tie-breakers, create spatial envelope around probe centroid and use rtree.search() - - let probe_centroid = probe_geom.centroid().unwrap_or(Point::new(0.0, 0.0)); - let probe_x = probe_centroid.x() as f32; - let probe_y = probe_centroid.y() as f32; - let max_distance_f32 = match f32::from_f64(max_distance) { - Some(val) => val, - None => { - // If conversion fails, return empty results for this probe - return Ok(JoinResultMetrics { - count: 0, - candidate_count: 0, - }); - } - }; - - // Create envelope bounds around probe centroid - let min_x = probe_x - max_distance_f32; - let min_y = probe_y - max_distance_f32; - let max_x = probe_x + max_distance_f32; - let max_y = probe_y + 
max_distance_f32; - - // Use rtree.search() with envelope bounds (like the old code) - let expanded_results = self.rtree.search(min_x, min_y, max_x, max_y); - - candidate_count = expanded_results.len(); - - // Calculate distances for all results and find ties - let mut all_distances_with_indices: Vec<(f64, u32)> = Vec::new(); - - for &result_idx in &expanded_results { - if (result_idx as usize) < self.data_id_to_batch_pos.len() { - if let Some(item_geom) = geometry_accessor.get_geometry(result_idx as usize) - { - let distance = - distance_metric.distance_to_geometry(&probe_geom, item_geom); - if let Some(distance_f64) = distance.to_f64() { - all_distances_with_indices.push((distance_f64, result_idx)); - } - } - } - } - - // Sort by distance - all_distances_with_indices - .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); - - // Include all results up to and including those with the same distance as the k-th result - const DISTANCE_TOLERANCE: f64 = 1e-9; - let mut tie_breaker_results: Vec = Vec::new(); - - for (i, &(distance, result_idx)) in all_distances_with_indices.iter().enumerate() { - if i < k as usize { - // Include the first k results - tie_breaker_results.push(result_idx); - } else if (distance - max_distance).abs() <= DISTANCE_TOLERANCE { - // Include tie-breakers (same distance as k-th result) - tie_breaker_results.push(result_idx); - } else { - // No more ties, stop - break; - } - } - - final_results = tie_breaker_results; - } - } else { - // When tie-breakers are disabled, limit results to exactly k - if final_results.len() > k as usize { - final_results.truncate(k as usize); - } - } - - // Convert results to build_batch_positions using existing data_id_to_batch_pos mapping - for &result_idx in &final_results { - if (result_idx as usize) < self.data_id_to_batch_pos.len() { - build_batch_positions.push(self.data_id_to_batch_pos[result_idx as usize]); - } - } - - Ok(JoinResultMetrics { - count: final_results.len(), - 
candidate_count, - }) - } - - fn refine( - &self, - probe_wkb: &Wkb, - candidates: &[u32], - distance: &Option, - build_batch_positions: &mut Vec<(i32, i32)>, - ) -> Result { - let candidate_count = candidates.len(); - - let mut index_query_results = Vec::with_capacity(candidate_count); - for data_idx in candidates { - let pos = self.data_id_to_batch_pos[*data_idx as usize]; - let (batch_idx, row_idx) = pos; - let indexed_batch = &self.indexed_batches[batch_idx as usize]; - let build_wkb = indexed_batch.wkb(row_idx as usize); - let Some(build_wkb) = build_wkb else { - continue; - }; - let distance = self.evaluator.resolve_distance( - indexed_batch.distance(), - row_idx as usize, - distance, - )?; - let geom_idx = self.geom_idx_vec[*data_idx as usize]; - index_query_results.push(IndexQueryResult { - wkb: build_wkb, - distance, - geom_idx, - position: pos, - }); - } - - if index_query_results.is_empty() { - return Ok(JoinResultMetrics { - count: 0, - candidate_count, - }); - } - - let results = self.refiner.refine(probe_wkb, &index_query_results)?; - let num_results = results.len(); - build_batch_positions.extend(results); - - // Update refiner memory reservation - self.refiner_reservation.resize(self.refiner.mem_usage())?; - - Ok(JoinResultMetrics { - count: num_results, - candidate_count, - }) - } - - /// Check if the index needs more probe statistics to determine the optimal execution mode. - /// - /// # Returns - /// * `bool` - `true` if the index needs more probe statistics, `false` otherwise. - pub(crate) fn need_more_probe_stats(&self) -> bool { - self.refiner.need_more_probe_stats() - } - - /// Merge the probe statistics into the index. - /// - /// # Arguments - /// * `stats` - The probe statistics to merge. - pub(crate) fn merge_probe_stats(&self, stats: GeoStatistics) { - self.refiner.merge_probe_stats(stats); - } - - /// Get the bitmaps for tracking visited left-side indices. 
The bitmaps will be updated - /// by the spatial join stream when producing output batches during index probing phase. - pub(crate) fn visited_left_side(&self) -> Option<&Mutex>> { - self.visited_left_side.as_ref() - } - - /// Decrements counter of running threads, and returns `true` - /// if caller is the last running thread - pub(crate) fn report_probe_completed(&self) -> bool { - self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1 - } - - /// Get the memory usage of the refiner in bytes. - pub(crate) fn get_refiner_mem_usage(&self) -> usize { - self.refiner.mem_usage() - } - - /// Get the actual execution mode used by the refiner - pub(crate) fn get_actual_execution_mode(&self) -> ExecutionMode { - self.refiner.actual_execution_mode() - } -} +pub(crate) use spatial_index::SpatialIndex; +pub(crate) use spatial_index_builder::{SpatialIndexBuilder, SpatialJoinBuildMetrics}; +use wkb::reader::Wkb; -pub struct IndexQueryResult<'a, 'b> { +/// The result of a spatial index query +pub(crate) struct IndexQueryResult<'a, 'b> { pub wkb: &'b Wkb<'a>, pub distance: Option, pub geom_idx: usize, pub position: (i32, i32), } -#[allow(clippy::too_many_arguments)] -pub(crate) async fn build_index( - mut build_schema: SchemaRef, - build_streams: Vec, - spatial_predicate: SpatialPredicate, - options: SpatialJoinOptions, - metrics_vec: Vec, - memory_pool: Arc, - join_type: JoinType, - probe_threads_count: usize, -) -> Result { - // Handle empty streams case - if build_streams.is_empty() { - let consumer = MemoryConsumer::new("SpatialJoinIndex"); - let reservation = consumer.register(&memory_pool); - return Ok(SpatialIndex::empty( - spatial_predicate, - build_schema, - options, - AtomicUsize::new(probe_threads_count), - reservation, - memory_pool, - )); - } - - // Update schema from the first stream - build_schema = build_streams.first().unwrap().schema(); - let metrics = metrics_vec.first().unwrap().clone(); - let evaluator = create_operand_evaluator(&spatial_predicate, 
options.clone()); - - // Spawn all tasks to scan all build streams concurrently - let mut join_set = JoinSet::new(); - let collect_statistics = matches!(options.execution_mode, ExecutionMode::Speculative(_)); - for (partition, (stream, metrics)) in build_streams.into_iter().zip(metrics_vec).enumerate() { - let per_task_evaluator = Arc::clone(&evaluator); - let consumer = MemoryConsumer::new(format!("SpatialJoinFetchBuild[{partition}]")); - let reservation = consumer.register(&memory_pool); - join_set.spawn(async move { - collect_build_partition( - stream, - per_task_evaluator.as_ref(), - &metrics, - reservation, - collect_statistics, - ) - .await - }); - } - - // Process each task as it completes and add batches to builder - let results = join_set.join_all().await; - - // Create the builder to build the index - let mut builder = SpatialIndexBuilder::new( - spatial_predicate, - options, - join_type, - probe_threads_count, - memory_pool.clone(), - metrics, - )?; - for result in results { - let build_partition = - result.map_err(|e| DataFusionError::Execution(format!("Task join error: {e}")))?; - - // Add each geometry batch to the builder - for indexed_batch in build_partition.batches { - builder.add_batch(indexed_batch); - } - builder.with_stats(build_partition.stats); - // build_partition.reservation will be dropped here. 
- } - - // Finish building the index - builder.finish(build_schema) -} - -struct BuildPartition { - batches: Vec, - stats: GeoStatistics, - - /// Memory reservation for tracking the memory usage of the build partition - /// Cleared on `BuildPartition` drop - #[allow(dead_code)] - reservation: MemoryReservation, -} - -async fn collect_build_partition( - mut stream: SendableRecordBatchStream, - evaluator: &dyn OperandEvaluator, - metrics: &SpatialJoinBuildMetrics, - mut reservation: MemoryReservation, - collect_statistics: bool, -) -> Result { - let mut batches = Vec::new(); - let mut analyzer = AnalyzeAccumulator::new(WKB_GEOMETRY, WKB_GEOMETRY); - - while let Some(batch) = stream.next().await { - let build_timer = metrics.build_time.timer(); - let batch = batch?; - - metrics.build_input_rows.add(batch.num_rows()); - metrics.build_input_batches.add(1); - - let geom_array = evaluator.evaluate_build(&batch)?; - let indexed_batch = IndexedBatch { batch, geom_array }; - - // Update statistics for each geometry in the batch - if collect_statistics { - for wkb in indexed_batch.geom_array.wkbs().iter().flatten() { - analyzer.update_statistics(wkb, wkb.buf().len())?; - } - } - - let in_mem_size = indexed_batch.in_mem_size(); - batches.push(indexed_batch); - - reservation.grow(in_mem_size); - metrics.build_mem_used.add(in_mem_size); - build_timer.done(); - } - - Ok(BuildPartition { - batches, - stats: analyzer.finish(), - reservation, - }) -} - -/// Rough estimate for in-memory size of the rtree per rect in bytes -const RTREE_MEMORY_ESTIMATE_PER_RECT: usize = 60; - -/// Shared KNN components that can be reused across queries -struct KnnComponents { - euclidean_metric: EuclideanDistance, - haversine_metric: HaversineDistance, - /// Pre-allocated vector for geometry cache - lock-free access - /// Indexed by rtree data index for O(1) access - geometry_cache: Vec>>, - /// Memory reservation to track geometry cache memory usage - _reservation: MemoryReservation, -} - -impl 
KnnComponents { - fn new( - cache_size: usize, - indexed_batches: &[IndexedBatch], - memory_pool: Arc, - ) -> datafusion_common::Result { - // Create memory consumer and reservation for geometry cache - let consumer = MemoryConsumer::new("SpatialJoinKnnGeometryCache"); - let mut reservation = consumer.register(&memory_pool); - - // Estimate maximum possible memory usage based on WKB sizes - let estimated_memory = Self::estimate_max_memory_usage(indexed_batches); - reservation.try_grow(estimated_memory)?; - - // Pre-allocate OnceCell vector - let geometry_cache = (0..cache_size).map(|_| OnceCell::new()).collect(); - - Ok(Self { - euclidean_metric: EuclideanDistance, - haversine_metric: HaversineDistance::default(), - geometry_cache, - _reservation: reservation, - }) - } - - /// Estimate the maximum memory usage for decoded geometries based on WKB sizes - fn estimate_max_memory_usage(indexed_batches: &[IndexedBatch]) -> usize { - let mut total_wkb_size = 0; - - for batch in indexed_batches { - for wkb in batch.geom_array.wkbs().iter().flatten() { - total_wkb_size += wkb.buf().len(); - } - } - total_wkb_size - } -} - -/// Geometry accessor for SedonaDB KNN queries. -/// This accessor provides on-demand WKB decoding and geometry caching for efficient -/// KNN queries with support for both Euclidean and Haversine distance metrics. 
-struct SedonaKnnAdapter<'a> { - indexed_batches: &'a [IndexedBatch], - data_id_to_batch_pos: &'a [(i32, i32)], - // Reference to KNN components for cache and memory tracking - knn_components: &'a KnnComponents, -} - -impl<'a> SedonaKnnAdapter<'a> { - /// Create a new adapter - fn new( - indexed_batches: &'a [IndexedBatch], - data_id_to_batch_pos: &'a [(i32, i32)], - knn_components: &'a KnnComponents, - ) -> Self { - Self { - indexed_batches, - data_id_to_batch_pos, - knn_components, - } - } -} - -impl<'a> GeometryAccessor for SedonaKnnAdapter<'a> { - /// Get geometry for the given item index with lock-free caching - fn get_geometry(&self, item_index: usize) -> Option<&Geometry> { - let geometry_cache = &self.knn_components.geometry_cache; - - // Bounds check - if item_index >= geometry_cache.len() || item_index >= self.data_id_to_batch_pos.len() { - return None; - } - - // Try to get from cache first - if let Some(geom) = geometry_cache[item_index].get() { - return Some(geom); - } - - // Cache miss - decode from WKB - let (batch_idx, row_idx) = self.data_id_to_batch_pos[item_index]; - let indexed_batch = &self.indexed_batches[batch_idx as usize]; - - if let Some(wkb) = indexed_batch.wkb(row_idx as usize) { - if let Ok(geom) = item_to_geometry(wkb) { - // Try to store in cache - if another thread got there first, we just use theirs - let _ = geometry_cache[item_index].set(geom); - // Return reference to the cached geometry - return geometry_cache[item_index].get(); - } - } - - // Failed to decode - don't cache invalid results - None - } -} - -#[cfg(test)] -mod tests { - use crate::spatial_predicate::{RelationPredicate, SpatialRelationType}; - - use super::*; - use arrow_array::RecordBatch; - use arrow_schema::{DataType, Field}; - use datafusion_execution::memory_pool::GreedyMemoryPool; - use datafusion_physical_expr::expressions::Column; - use geo_traits::Dimensions; - use sedona_common::option::{ExecutionMode, SpatialJoinOptions}; - use 
sedona_geometry::wkb_factory::write_wkb_empty_point; - use sedona_schema::datatypes::WKB_GEOMETRY; - use sedona_testing::create::create_array; - - #[test] - fn test_spatial_index_builder_empty() { - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); - let options = SpatialJoinOptions { - execution_mode: ExecutionMode::PrepareBuild, - ..Default::default() - }; - let metrics = SpatialJoinBuildMetrics::default(); - let schema = Arc::new(arrow_schema::Schema::empty()); - let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( - Arc::new(Column::new("geom", 0)), - Arc::new(Column::new("geom", 1)), - SpatialRelationType::Intersects, - )); - - let builder = SpatialIndexBuilder::new( - spatial_predicate, - options, - JoinType::Inner, - 4, - memory_pool, - metrics, - ) - .unwrap(); - - // Test finishing with empty data - let index = builder.finish(schema.clone()).unwrap(); - assert_eq!(index.schema(), schema); - assert_eq!(index.indexed_batches.len(), 0); - } - - #[test] - fn test_spatial_index_builder_add_batch() { - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); - let options = SpatialJoinOptions { - execution_mode: ExecutionMode::PrepareBuild, - ..Default::default() - }; - let metrics = SpatialJoinBuildMetrics::default(); - - let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( - Arc::new(Column::new("geom", 0)), - Arc::new(Column::new("geom", 1)), - SpatialRelationType::Intersects, - )); - - let mut builder = SpatialIndexBuilder::new( - spatial_predicate, - options, - JoinType::Inner, - 4, - memory_pool, - metrics, - ) - .unwrap(); - - // Create a simple test geometry batch - let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( - "geom", - DataType::Binary, - true, - )])); - let batch = RecordBatch::new_empty(schema.clone()); - let geom_batch = create_array( - &[ - Some("POINT (0.25 0.25)"), - Some("POINT (10 10)"), - None, - Some("POINT (0.25 0.25)"), - ], - &WKB_GEOMETRY, - ); - let 
indexed_batch = IndexedBatch { - batch, - geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), - }; - builder.add_batch(indexed_batch); - assert_eq!(builder.indexed_batches.len(), 1); - - let index = builder.finish(schema.clone()).unwrap(); - assert_eq!(index.schema(), schema); - assert_eq!(index.indexed_batches.len(), 1); - } - - #[test] - fn test_knn_query_execution_with_sample_data() { - // Create a spatial index with sample geometry data - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); - let options = SpatialJoinOptions { - execution_mode: ExecutionMode::PrepareBuild, - ..Default::default() - }; - let metrics = SpatialJoinBuildMetrics::default(); - - let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( - Arc::new(Column::new("geom", 0)), - Arc::new(Column::new("geom", 1)), - SpatialRelationType::Intersects, - )); - - let mut builder = SpatialIndexBuilder::new( - spatial_predicate, - options, - JoinType::Inner, - 4, - memory_pool, - metrics, - ) - .unwrap(); - - // Create sample geometry data - points at known locations - let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( - "geom", - DataType::Binary, - true, - )])); - let batch = RecordBatch::new_empty(schema.clone()); - - // Create geometries at different distances from the query point (0, 0) - let geom_batch = create_array( - &[ - Some("POINT (1 0)"), // Distance: 1.0 - Some("POINT (0 2)"), // Distance: 2.0 - Some("POINT (3 0)"), // Distance: 3.0 - Some("POINT (0 4)"), // Distance: 4.0 - Some("POINT (5 0)"), // Distance: 5.0 - Some("POINT (2 2)"), // Distance: ~2.83 - Some("POINT (1 1)"), // Distance: ~1.41 - ], - &WKB_GEOMETRY, - ); - - let indexed_batch = IndexedBatch { - batch, - geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), - }; - builder.add_batch(indexed_batch); - - let index = builder.finish(schema).unwrap(); - - // Create a query geometry at origin (0, 0) - let query_geom = 
create_array(&[Some("POINT (0 0)")], &WKB_GEOMETRY); - let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); - let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); - - // Test KNN query with k=3 - let mut build_positions = Vec::new(); - let result = index - .query_knn( - query_wkb, - 3, // k=3 - false, // use_spheroid=false - false, // include_tie_breakers=false - &mut build_positions, - ) - .unwrap(); - - // Verify we got 3 results - assert_eq!(build_positions.len(), 3); - assert_eq!(result.count, 3); - assert!(result.candidate_count >= 3); - - // Create a mapping of positions to verify correct ordering - // We expect the 3 closest points: (1,0), (1,1), (0,2) - let expected_closest_indices = vec![0, 6, 1]; // Based on our sample data ordering - let mut found_indices = Vec::new(); - - for (_batch_idx, row_idx) in &build_positions { - found_indices.push(*row_idx as usize); - } - - // Sort to compare sets (order might vary due to implementation) - found_indices.sort(); - let mut expected_sorted = expected_closest_indices; - expected_sorted.sort(); - - assert_eq!(found_indices, expected_sorted); - } - - #[test] - fn test_knn_query_execution_with_different_k_values() { - // Create spatial index with more data points - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); - let options = SpatialJoinOptions { - execution_mode: ExecutionMode::PrepareBuild, - ..Default::default() - }; - let metrics = SpatialJoinBuildMetrics::default(); - - let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( - Arc::new(Column::new("geom", 0)), - Arc::new(Column::new("geom", 1)), - SpatialRelationType::Intersects, - )); - - let mut builder = SpatialIndexBuilder::new( - spatial_predicate, - options, - JoinType::Inner, - 4, - memory_pool, - metrics, - ) - .unwrap(); - - let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( - "geom", - DataType::Binary, - true, - )])); - let batch = 
RecordBatch::new_empty(schema.clone()); - - // Create 10 points at regular intervals - let geom_batch = create_array( - &[ - Some("POINT (1 0)"), // 0: Distance 1 - Some("POINT (2 0)"), // 1: Distance 2 - Some("POINT (3 0)"), // 2: Distance 3 - Some("POINT (4 0)"), // 3: Distance 4 - Some("POINT (5 0)"), // 4: Distance 5 - Some("POINT (6 0)"), // 5: Distance 6 - Some("POINT (7 0)"), // 6: Distance 7 - Some("POINT (8 0)"), // 7: Distance 8 - Some("POINT (9 0)"), // 8: Distance 9 - Some("POINT (10 0)"), // 9: Distance 10 - ], - &WKB_GEOMETRY, - ); - - let indexed_batch = IndexedBatch { - batch, - geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), - }; - builder.add_batch(indexed_batch); - - let index = builder.finish(schema).unwrap(); - - // Query point at origin - let query_geom = create_array(&[Some("POINT (0 0)")], &WKB_GEOMETRY); - let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); - let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); - - // Test different k values - for k in [1, 3, 5, 7, 10] { - let mut build_positions = Vec::new(); - let result = index - .query_knn(query_wkb, k, false, false, &mut build_positions) - .unwrap(); - - // Verify we got exactly k results (or all available if k > total) - let expected_results = std::cmp::min(k as usize, 10); - assert_eq!(build_positions.len(), expected_results); - assert_eq!(result.count, expected_results); - - // Verify the results are the k closest points - let mut row_indices: Vec = build_positions - .iter() - .map(|(_, row_idx)| *row_idx as usize) - .collect(); - row_indices.sort(); - - let expected_indices: Vec = (0..expected_results).collect(); - assert_eq!(row_indices, expected_indices); - } - } - - #[test] - fn test_knn_query_execution_with_spheroid_distance() { - // Create spatial index - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); - let options = SpatialJoinOptions { - execution_mode: ExecutionMode::PrepareBuild, - 
..Default::default() - }; - let metrics = SpatialJoinBuildMetrics::default(); - - let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( - Arc::new(Column::new("geom", 0)), - Arc::new(Column::new("geom", 1)), - SpatialRelationType::Intersects, - )); - - let mut builder = SpatialIndexBuilder::new( - spatial_predicate, - options, - JoinType::Inner, - 4, - memory_pool, - metrics, - ) - .unwrap(); - - let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( - "geom", - DataType::Binary, - true, - )])); - let batch = RecordBatch::new_empty(schema.clone()); - - // Create points with geographic coordinates (longitude, latitude) - let geom_batch = create_array( - &[ - Some("POINT (-74.0 40.7)"), // NYC area - Some("POINT (-73.9 40.7)"), // Slightly east - Some("POINT (-74.1 40.7)"), // Slightly west - Some("POINT (-74.0 40.8)"), // Slightly north - Some("POINT (-74.0 40.6)"), // Slightly south - ], - &WKB_GEOMETRY, - ); - - let indexed_batch = IndexedBatch { - batch, - geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), - }; - builder.add_batch(indexed_batch); - - let index = builder.finish(schema).unwrap(); - - // Query point at NYC - let query_geom = create_array(&[Some("POINT (-74.0 40.7)")], &WKB_GEOMETRY); - let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); - let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); - - // Test with planar distance (spheroid distance is not supported) - let mut build_positions = Vec::new(); - let result = index - .query_knn( - query_wkb, - 3, // k=3 - false, // use_spheroid=false (only supported option) - false, - &mut build_positions, - ) - .unwrap(); - - // Should find results with planar distance calculation - assert!(!build_positions.is_empty()); // At least the exact match - assert!(result.count >= 1); - assert!(result.candidate_count >= 1); - - // Test that spheroid distance now works with Haversine metric - let mut 
build_positions_spheroid = Vec::new(); - let result_spheroid = index.query_knn( - query_wkb, - 3, // k=3 - true, // use_spheroid=true (now supported with Haversine) - false, - &mut build_positions_spheroid, - ); - - // Should succeed and return results - assert!(result_spheroid.is_ok()); - let result_spheroid = result_spheroid.unwrap(); - assert!(!build_positions_spheroid.is_empty()); - assert!(result_spheroid.count >= 1); - assert!(result_spheroid.candidate_count >= 1); - } - - #[test] - fn test_knn_query_execution_edge_cases() { - // Create spatial index - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); - let options = SpatialJoinOptions { - execution_mode: ExecutionMode::PrepareBuild, - ..Default::default() - }; - let metrics = SpatialJoinBuildMetrics::default(); - - let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( - Arc::new(Column::new("geom", 0)), - Arc::new(Column::new("geom", 1)), - SpatialRelationType::Intersects, - )); - - let mut builder = SpatialIndexBuilder::new( - spatial_predicate, - options, - JoinType::Inner, - 4, - memory_pool, - metrics, - ) - .unwrap(); - - let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( - "geom", - DataType::Binary, - true, - )])); - let batch = RecordBatch::new_empty(schema.clone()); - - // Create sample data with some edge cases - let geom_batch = create_array( - &[ - Some("POINT (1 1)"), - Some("POINT (2 2)"), - None, // NULL geometry - Some("POINT (3 3)"), - ], - &WKB_GEOMETRY, - ); - - let indexed_batch = IndexedBatch { - batch, - geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), - }; - builder.add_batch(indexed_batch); - - let index = builder.finish(schema).unwrap(); - - let query_geom = create_array(&[Some("POINT (0 0)")], &WKB_GEOMETRY); - let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); - let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); - - // Test k=0 (should return no results) - let 
mut build_positions = Vec::new(); - let result = index - .query_knn( - query_wkb, - 0, // k=0 - false, - false, - &mut build_positions, - ) - .unwrap(); - - assert_eq!(build_positions.len(), 0); - assert_eq!(result.count, 0); - assert_eq!(result.candidate_count, 0); - - // Test k > available geometries - let mut build_positions = Vec::new(); - let result = index - .query_knn( - query_wkb, - 10, // k=10, but only 3 valid geometries available - false, - false, - &mut build_positions, - ) - .unwrap(); - - // Should return all available valid geometries (excluding NULL) - assert_eq!(build_positions.len(), 3); - assert_eq!(result.count, 3); - } - - #[test] - fn test_knn_query_execution_empty_index() { - // Create empty spatial index - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); - let options = SpatialJoinOptions { - execution_mode: ExecutionMode::PrepareBuild, - ..Default::default() - }; - let metrics = SpatialJoinBuildMetrics::default(); - let schema = Arc::new(arrow_schema::Schema::empty()); - - let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( - Arc::new(Column::new("geom", 0)), - Arc::new(Column::new("geom", 1)), - SpatialRelationType::Intersects, - )); - - let builder = SpatialIndexBuilder::new( - spatial_predicate, - options, - JoinType::Inner, - 4, - memory_pool, - metrics, - ) - .unwrap(); - - let index = builder.finish(schema).unwrap(); - - // Try to query empty index - let query_geom = create_array(&[Some("POINT (0 0)")], &WKB_GEOMETRY); - let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); - let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); - - let mut build_positions = Vec::new(); - let result = index - .query_knn(query_wkb, 5, false, false, &mut build_positions) - .unwrap(); - - // Should return no results for empty index - assert_eq!(build_positions.len(), 0); - assert_eq!(result.count, 0); - assert_eq!(result.candidate_count, 0); - } - - #[test] - fn 
test_knn_query_execution_with_tie_breakers() { - // Create a spatial index with sample geometry data - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); - let options = SpatialJoinOptions { - execution_mode: ExecutionMode::PrepareBuild, - ..Default::default() - }; - let metrics = SpatialJoinBuildMetrics::default(); - - let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( - Arc::new(Column::new("geom", 0)), - Arc::new(Column::new("geom", 1)), - SpatialRelationType::Intersects, - )); - - let mut builder = SpatialIndexBuilder::new( - spatial_predicate, - options, - JoinType::Inner, - 1, // probe_threads_count - memory_pool.clone(), - metrics, - ) - .unwrap(); - - let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( - "geom", - DataType::Binary, - true, - )])); - let batch = RecordBatch::new_empty(schema.clone()); - - // Create points where we have more ties at the k-th distance - // Query point is at (0.0, 0.0) - // We'll create a scenario with k=2 where there are 3 points at the same distance - // This ensures the tie-breaker logic has work to do - let geom_batch = create_array( - &[ - Some("POINT (1.0 0.0)"), // Squared distance 1.0 - Some("POINT (0.0 1.0)"), // Squared distance 1.0 (tie!) - Some("POINT (-1.0 0.0)"), // Squared distance 1.0 (tie!) - Some("POINT (0.0 -1.0)"), // Squared distance 1.0 (tie!) 
- Some("POINT (2.0 0.0)"), // Squared distance 4.0 - Some("POINT (0.0 2.0)"), // Squared distance 4.0 - ], - &WKB_GEOMETRY, - ); - - let indexed_batch = IndexedBatch { - batch, - geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), - }; - builder.add_batch(indexed_batch); - - let index = builder.finish(schema).unwrap(); - - // Query point at the origin (0.0, 0.0) - let query_geom = create_array(&[Some("POINT (0.0 0.0)")], &WKB_GEOMETRY); - let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); - let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); - - // Test without tie-breakers: should return exactly k=2 results - let mut build_positions = Vec::new(); - let result = index - .query_knn( - query_wkb, - 2, // k=2 - false, // use_spheroid - false, // include_tie_breakers - &mut build_positions, - ) - .unwrap(); - - // Should return exactly 2 results (the closest point + 1 of the tied points) - assert_eq!(result.count, 2); - assert_eq!(build_positions.len(), 2); - - // Test with tie-breakers: should return k=2 plus all ties - let mut build_positions_with_ties = Vec::new(); - let result_with_ties = index - .query_knn( - query_wkb, - 2, // k=2 - false, // use_spheroid - true, // include_tie_breakers - &mut build_positions_with_ties, - ) - .unwrap(); - - // Should return more than 2 results because of ties - // We have 4 points at squared distance 1.0 (all tied for closest) - // With k=2 and tie-breakers: - // - Initial neighbors query returns 2 of the 4 tied points - // - Tie-breaker logic should find the other 2 tied points - // - Total should be 4 results (all points at distance 1.0) - - // With 4 points all at the same distance and k=2: - // - Without tie-breakers: should return exactly 2 - // - With tie-breakers: should return all 4 tied points - assert_eq!( - result.count, 2, - "Without tie-breakers should return exactly k=2" - ); - assert_eq!( - result_with_ties.count, 4, - "With tie-breakers should 
return all 4 tied points" - ); - assert_eq!(build_positions_with_ties.len(), 4); - } - - #[test] - fn test_query_knn_with_geometry_distance() { - // Create a spatial index with sample geometry data - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); - let options = SpatialJoinOptions { - execution_mode: ExecutionMode::PrepareBuild, - ..Default::default() - }; - let metrics = SpatialJoinBuildMetrics::default(); - - let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( - Arc::new(Column::new("geom", 0)), - Arc::new(Column::new("geom", 1)), - SpatialRelationType::Intersects, - )); - - let mut builder = SpatialIndexBuilder::new( - spatial_predicate, - options, - JoinType::Inner, - 4, - memory_pool, - metrics, - ) - .unwrap(); - - // Create sample geometry data - points at known locations - let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( - "geom", - DataType::Binary, - true, - )])); - let batch = RecordBatch::new_empty(schema.clone()); - - // Create geometries at different distances from the query point (0, 0) - let geom_batch = create_array( - &[ - Some("POINT (1 0)"), // Distance: 1.0 - Some("POINT (0 2)"), // Distance: 2.0 - Some("POINT (3 0)"), // Distance: 3.0 - Some("POINT (0 4)"), // Distance: 4.0 - Some("POINT (5 0)"), // Distance: 5.0 - Some("POINT (2 2)"), // Distance: ~2.83 - Some("POINT (1 1)"), // Distance: ~1.41 - ], - &WKB_GEOMETRY, - ); - - let indexed_batch = IndexedBatch { - batch, - geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), - }; - builder.add_batch(indexed_batch); - - let index = builder.finish(schema).unwrap(); - - // Create a query geometry at origin (0, 0) - let query_geom = create_array(&[Some("POINT (0 0)")], &WKB_GEOMETRY); - let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); - let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); - - // Test the geometry-based query_knn method with k=3 - let mut build_positions = 
Vec::new(); - let result = index - .query_knn( - query_wkb, - 3, // k=3 - false, // use_spheroid=false - false, // include_tie_breakers=false - &mut build_positions, - ) - .unwrap(); - - // Verify we got results (should be 3 or less) - assert!(!build_positions.is_empty()); - assert!(build_positions.len() <= 3); - assert!(result.count > 0); - assert!(result.count <= 3); - - println!("KNN Geometry test - found {} results", result.count); - println!("Result positions: {build_positions:?}"); - } - - #[test] - fn test_query_knn_with_mixed_geometries() { - // Create a spatial index with complex geometries where geometry-based - // distance should differ from centroid-based distance - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); - let options = SpatialJoinOptions { - execution_mode: ExecutionMode::PrepareBuild, - ..Default::default() - }; - let metrics = SpatialJoinBuildMetrics::default(); - - let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( - Arc::new(Column::new("geom", 0)), - Arc::new(Column::new("geom", 1)), - SpatialRelationType::Intersects, - )); - - let mut builder = SpatialIndexBuilder::new( - spatial_predicate, - options, - JoinType::Inner, - 4, - memory_pool, - metrics, - ) - .unwrap(); - - // Create different geometry types - let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( - "geom", - DataType::Binary, - true, - )])); - let batch = RecordBatch::new_empty(schema.clone()); - - // Mix of points and linestrings - let geom_batch = create_array( - &[ - Some("POINT (1 1)"), // Simple point - Some("LINESTRING (2 0, 2 4)"), // Vertical line - closest point should be (2, 1) - Some("LINESTRING (10 10, 10 20)"), // Far away line - Some("POINT (5 5)"), // Far point - ], - &WKB_GEOMETRY, - ); - - let indexed_batch = IndexedBatch { - batch, - geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), - }; - builder.add_batch(indexed_batch); - - let index = builder.finish(schema).unwrap(); - - 
// Query point close to the linestring - let query_geom = create_array(&[Some("POINT (2.1 1.0)")], &WKB_GEOMETRY); - let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); - let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); - - // Test the geometry-based KNN method with mixed geometry types - let mut build_positions = Vec::new(); - - let result = index - .query_knn( - query_wkb, - 2, // k=2 - false, // use_spheroid=false - false, // include_tie_breakers=false - &mut build_positions, - ) - .unwrap(); - - // Should return results - assert!(!build_positions.is_empty()); - - println!("KNN with mixed geometries: {build_positions:?}"); - - // Should work with mixed geometry types - assert!(result.count > 0); - } - - #[test] - fn test_query_knn_with_tie_breakers_geometry_distance() { - // Create a spatial index with geometries that have identical distances for tie-breaker testing - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); - let options = SpatialJoinOptions { - execution_mode: ExecutionMode::PrepareBuild, - ..Default::default() - }; - let metrics = SpatialJoinBuildMetrics::default(); - - let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( - Arc::new(Column::new("geom", 0)), - Arc::new(Column::new("geom", 1)), - SpatialRelationType::Intersects, - )); - - let mut builder = SpatialIndexBuilder::new( - spatial_predicate, - options, - JoinType::Inner, - 4, - memory_pool, - metrics, - ) - .unwrap(); - - let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( - "geom", - DataType::Binary, - true, - )])); - let batch = RecordBatch::new_empty(schema.clone()); - - // Create points where we have multiple points at the same distance from the query point - // Query point will be at (0, 0), and we'll have 4 points all at distance sqrt(2) ≈ 1.414 - let geom_batch = create_array( - &[ - Some("POINT (1.0 1.0)"), // Distance: sqrt(2) - Some("POINT (1.0 -1.0)"), // Distance: sqrt(2) - tied with 
above - Some("POINT (-1.0 1.0)"), // Distance: sqrt(2) - tied with above - Some("POINT (-1.0 -1.0)"), // Distance: sqrt(2) - tied with above - Some("POINT (2.0 0.0)"), // Distance: 2.0 - farther away - Some("POINT (0.0 2.0)"), // Distance: 2.0 - farther away - ], - &WKB_GEOMETRY, - ); - - let indexed_batch = IndexedBatch { - batch, - geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), - }; - builder.add_batch(indexed_batch); - - let index = builder.finish(schema).unwrap(); - - // Query point at the origin (0.0, 0.0) - let query_geom = create_array(&[Some("POINT (0.0 0.0)")], &WKB_GEOMETRY); - let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); - let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); - - // Test without tie-breakers: should return exactly k=2 results - let mut build_positions = Vec::new(); - let result = index - .query_knn( - query_wkb, - 2, // k=2 - false, // use_spheroid - false, // include_tie_breakers=false - &mut build_positions, - ) - .unwrap(); - - // Should return exactly 2 results - assert_eq!(result.count, 2); - assert_eq!(build_positions.len(), 2); - - // Test with tie-breakers: should return all tied points - let mut build_positions_with_ties = Vec::new(); - let result_with_ties = index - .query_knn( - query_wkb, - 2, // k=2 - false, // use_spheroid - true, // include_tie_breakers=true - &mut build_positions_with_ties, - ) - .unwrap(); - - // Should return more than 2 results because of ties (all 4 points at distance sqrt(2)) - assert!(result_with_ties.count >= 2); - } - - #[test] - fn test_knn_query_with_empty_geometry() { - // Create a spatial index with sample geometry data like other tests - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); - let options = SpatialJoinOptions { - execution_mode: ExecutionMode::PrepareBuild, - ..Default::default() - }; - let metrics = SpatialJoinBuildMetrics::default(); - - let spatial_predicate = 
/// Metrics describing the outcome of a single spatial index query.
///
/// `Default` yields the "nothing matched" value (`count == 0`,
/// `candidate_count == 0`), and the type is `Copy` so it can be passed around
/// and compared freely.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub(crate) struct QueryResultMetrics {
    /// Number of build-side geometries that actually matched the probe geometry.
    pub count: usize,
    /// Number of candidates produced by the index filter phase, before refinement.
    pub candidate_count: usize,
}
(ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion_common::Result; +use datafusion_common_runtime::JoinSet; +use datafusion_execution::{memory_pool::MemoryReservation, SendableRecordBatchStream}; +use datafusion_physical_plan::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; +use futures::StreamExt; +use sedona_expr::statistics::GeoStatistics; +use sedona_functions::st_analyze_aggr::AnalyzeAccumulator; +use sedona_schema::datatypes::WKB_GEOMETRY; + +use crate::{ + evaluated_batch::{ + evaluated_batch_stream::{ + in_mem::InMemoryEvaluatedBatchStream, SendableEvaluatedBatchStream, + }, + EvaluatedBatch, + }, + operand_evaluator::OperandEvaluator, +}; + +pub(crate) struct BuildPartition { + pub build_side_batch_stream: SendableEvaluatedBatchStream, + pub geo_statistics: GeoStatistics, + + /// Memory reservation for tracking the memory usage of the build partition + /// Cleared on `BuildPartition` drop + pub reservation: MemoryReservation, +} + +/// A collector for evaluating the spatial expression on build side batches and collect +/// them as asynchronous streams with additional statistics. 
The asynchronous streams +/// could then be fed into the spatial index builder to build an in-memory or external +/// spatial index, depending on the statistics collected by the collector. +#[derive(Clone)] +pub(crate) struct BuildSideBatchesCollector { + evaluator: Arc, +} + +pub(crate) struct CollectBuildSideMetrics { + /// Number of batches collected + num_batches: metrics::Count, + /// Number of rows collected + num_rows: metrics::Count, + /// Total in-memory size of batches collected. If the batches were spilled, this size is the + /// in-memory size if we load all batches into memory. This does not represent the in-memory size + /// of the resulting BuildPartition. + total_size_bytes: metrics::Gauge, + /// Total time taken to collect and process the build side batches. This does not include the time awaiting + /// for batches from the input stream. + time_taken: metrics::Time, +} + +impl CollectBuildSideMetrics { + pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + Self { + num_batches: MetricBuilder::new(metrics).counter("build_input_batches", partition), + num_rows: MetricBuilder::new(metrics).counter("build_input_rows", partition), + total_size_bytes: MetricBuilder::new(metrics) + .gauge("build_input_total_size_bytes", partition), + time_taken: MetricBuilder::new(metrics) + .subset_time("build_input_collection_time", partition), + } + } +} + +impl BuildSideBatchesCollector { + pub fn new(evaluator: Arc) -> Self { + BuildSideBatchesCollector { evaluator } + } + + pub async fn collect( + &self, + mut stream: SendableRecordBatchStream, + mut reservation: MemoryReservation, + metrics: &CollectBuildSideMetrics, + ) -> Result { + let evaluator = self.evaluator.as_ref(); + let mut in_mem_batches: Vec = Vec::new(); + let mut analyzer = AnalyzeAccumulator::new(WKB_GEOMETRY, WKB_GEOMETRY); + + while let Some(record_batch) = stream.next().await { + let record_batch = record_batch?; + let _timer = metrics.time_taken.timer(); + + // Process the 
record batch and create a BuildSideBatch + let geom_array = evaluator.evaluate_build(&record_batch)?; + + for wkb in geom_array.wkbs().iter().flatten() { + analyzer.update_statistics(wkb, wkb.buf().len())?; + } + + let build_side_batch = EvaluatedBatch { + batch: record_batch, + geom_array, + }; + + let in_mem_size = build_side_batch.in_mem_size(); + metrics.num_batches.add(1); + metrics.num_rows.add(build_side_batch.num_rows()); + metrics.total_size_bytes.add(in_mem_size); + + reservation.try_grow(in_mem_size)?; + in_mem_batches.push(build_side_batch); + } + + Ok(BuildPartition { + build_side_batch_stream: Box::pin(InMemoryEvaluatedBatchStream::new(in_mem_batches)), + geo_statistics: analyzer.finish(), + reservation, + }) + } + + pub async fn collect_all( + &self, + streams: Vec, + reservations: Vec, + metrics_vec: Vec, + ) -> Result> { + if streams.is_empty() { + return Ok(vec![]); + } + + // Spawn all tasks to scan all build streams concurrently + let mut join_set = JoinSet::new(); + for (partition_id, ((stream, metrics), reservation)) in streams + .into_iter() + .zip(metrics_vec) + .zip(reservations) + .enumerate() + { + let collector = self.clone(); + join_set.spawn(async move { + let result = collector.collect(stream, reservation, &metrics).await; + (partition_id, result) + }); + } + + // Wait for all async tasks to finish. Results may be returned in arbitrary order, + // so we need to reorder them by partition_id later. 
+ let results = join_set.join_all().await; + + // Reorder results according to partition ids + let mut partitions: Vec> = Vec::with_capacity(results.len()); + partitions.resize_with(results.len(), || None); + for result in results { + let (partition_id, partition_result) = result; + let partition = partition_result?; + partitions[partition_id] = Some(partition); + } + + Ok(partitions.into_iter().map(|v| v.unwrap()).collect()) + } +} diff --git a/rust/sedona-spatial-join/src/index/knn_adapter.rs b/rust/sedona-spatial-join/src/index/knn_adapter.rs new file mode 100644 index 000000000..519d7137c --- /dev/null +++ b/rust/sedona-spatial-join/src/index/knn_adapter.rs @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use once_cell::sync::OnceCell; +use std::sync::Arc; + +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; +use geo_index::rtree::{distance::GeometryAccessor, EuclideanDistance, HaversineDistance}; +use geo_types::Geometry; +use sedona_geo::to_geo::item_to_geometry; + +use crate::evaluated_batch::EvaluatedBatch; + +/// Shared KNN components that can be reused across queries +pub(crate) struct KnnComponents { + pub euclidean_metric: EuclideanDistance, + pub haversine_metric: HaversineDistance, + /// Pre-allocated vector for geometry cache - lock-free access + /// Indexed by rtree data index for O(1) access + geometry_cache: Vec>>, + /// Memory reservation to track geometry cache memory usage + _reservation: MemoryReservation, +} + +impl KnnComponents { + pub fn new( + cache_size: usize, + indexed_batches: &[EvaluatedBatch], + memory_pool: Arc, + ) -> datafusion_common::Result { + // Create memory consumer and reservation for geometry cache + let consumer = MemoryConsumer::new("SpatialJoinKnnGeometryCache"); + let mut reservation = consumer.register(&memory_pool); + + // Estimate maximum possible memory usage based on WKB sizes + let estimated_memory = Self::estimate_max_memory_usage(indexed_batches); + reservation.try_grow(estimated_memory)?; + + // Pre-allocate OnceCell vector + let geometry_cache = (0..cache_size).map(|_| OnceCell::new()).collect(); + + Ok(Self { + euclidean_metric: EuclideanDistance, + haversine_metric: HaversineDistance::default(), + geometry_cache, + _reservation: reservation, + }) + } + + /// Estimate the maximum memory usage for decoded geometries based on WKB sizes + pub fn estimate_max_memory_usage(indexed_batches: &[EvaluatedBatch]) -> usize { + let mut total_wkb_size = 0; + + for batch in indexed_batches { + for wkb in batch.geom_array.wkbs().iter().flatten() { + total_wkb_size += wkb.buf().len(); + } + } + total_wkb_size + } +} + +/// Geometry accessor for SedonaDB KNN queries. 
+/// This accessor provides on-demand WKB decoding and geometry caching for efficient +/// KNN queries with support for both Euclidean and Haversine distance metrics. +pub(crate) struct SedonaKnnAdapter<'a> { + indexed_batches: &'a [EvaluatedBatch], + data_id_to_batch_pos: &'a [(i32, i32)], + // Reference to KNN components for cache and memory tracking + knn_components: &'a KnnComponents, +} + +impl<'a> SedonaKnnAdapter<'a> { + /// Create a new adapter + pub fn new( + indexed_batches: &'a [EvaluatedBatch], + data_id_to_batch_pos: &'a [(i32, i32)], + knn_components: &'a KnnComponents, + ) -> Self { + Self { + indexed_batches, + data_id_to_batch_pos, + knn_components, + } + } +} + +impl<'a> GeometryAccessor for SedonaKnnAdapter<'a> { + /// Get geometry for the given item index with lock-free caching + fn get_geometry(&self, item_index: usize) -> Option<&Geometry> { + let geometry_cache = &self.knn_components.geometry_cache; + + // Bounds check + if item_index >= geometry_cache.len() || item_index >= self.data_id_to_batch_pos.len() { + return None; + } + + // Try to get from cache first + if let Some(geom) = geometry_cache[item_index].get() { + return Some(geom); + } + + // Cache miss - decode from WKB + let (batch_idx, row_idx) = self.data_id_to_batch_pos[item_index]; + let indexed_batch = &self.indexed_batches[batch_idx as usize]; + + if let Some(wkb) = indexed_batch.wkb(row_idx as usize) { + if let Ok(geom) = item_to_geometry(wkb) { + // Try to store in cache - if another thread got there first, we just use theirs + let _ = geometry_cache[item_index].set(geom); + // Return reference to the cached geometry + return geometry_cache[item_index].get(); + } + } + + // Failed to decode - don't cache invalid results + None + } +} diff --git a/rust/sedona-spatial-join/src/index/spatial_index.rs b/rust/sedona-spatial-join/src/index/spatial_index.rs new file mode 100644 index 000000000..09158d49d --- /dev/null +++ b/rust/sedona-spatial-join/src/index/spatial_index.rs @@ -0,0 
+1,1456 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +use arrow_array::RecordBatch; +use arrow_schema::SchemaRef; +use datafusion_common::Result; +use datafusion_execution::memory_pool::{MemoryPool, MemoryReservation}; +use geo_index::rtree::distance::{DistanceMetric, GeometryAccessor}; +use geo_index::rtree::{sort::HilbertSort, RTree, RTreeBuilder, RTreeIndex}; +use geo_index::IndexableNum; +use geo_types::{Point, Rect}; +use parking_lot::Mutex; +use sedona_expr::statistics::GeoStatistics; +use sedona_geo::to_geo::item_to_geometry; +use sedona_geo_generic_alg::algorithm::Centroid; +use wkb::reader::Wkb; + +use crate::{ + evaluated_batch::EvaluatedBatch, + index::{ + knn_adapter::{KnnComponents, SedonaKnnAdapter}, + IndexQueryResult, QueryResultMetrics, + }, + operand_evaluator::{create_operand_evaluator, OperandEvaluator}, + refine::{create_refiner, IndexQueryResultRefiner}, + spatial_predicate::SpatialPredicate, + utils::concurrent_reservation::ConcurrentReservation, +}; +use arrow::array::BooleanBufferBuilder; +use sedona_common::{option::SpatialJoinOptions, ExecutionMode}; + +pub(crate) struct SpatialIndex { + pub(crate) schema: SchemaRef, 
+ + /// The spatial predicate evaluator for the spatial predicate. + pub(crate) evaluator: Arc, + + /// The refiner for refining the index query results. + pub(crate) refiner: Arc, + + /// Memory reservation for tracking the memory usage of the refiner + pub(crate) refiner_reservation: ConcurrentReservation, + + /// R-tree index for the geometry batches. It takes MBRs as query windows and returns + /// data indexes. These data indexes should be translated using `data_id_to_batch_pos` to get + /// the original geometry batch index and row index, or translated using `prepared_geom_idx_vec` + /// to get the prepared geometries array index. + pub(crate) rtree: RTree, + + /// Indexed batches containing evaluated geometry arrays. It contains the original record + /// batches and geometry arrays obtained by evaluating the geometry expression on the build side. + pub(crate) indexed_batches: Vec, + /// An array for translating rtree data index to geometry batch index and row index + pub(crate) data_id_to_batch_pos: Vec<(i32, i32)>, + + /// An array for translating rtree data index to consecutive index. Each geometry may be indexed by + /// multiple boxes, so there could be multiple data indexes for the same geometry. A mapping for + /// squashing the index makes it easier for persisting per-geometry auxiliary data for evaluating + /// the spatial predicate. This is extensively used by the spatial predicate evaluators for storing + /// prepared geometries. + pub(crate) geom_idx_vec: Vec, + + /// Shared bitmap builders for visited left indices, one per batch + pub(crate) visited_left_side: Option>>, + + /// Counter of running probe-threads, potentially able to update `bitmap`. + /// Each time a probe thread finished probing the index, it will decrement the counter. + /// The last finished probe thread will produce the extra output batches for unmatched + /// build side when running left-outer joins. See also [`report_probe_completed`]. 
+ pub(crate) probe_threads_counter: AtomicUsize, + + /// Shared KNN components (distance metrics and geometry cache) for efficient KNN queries + pub(crate) knn_components: KnnComponents, + + /// Memory reservation for tracking the memory usage of the spatial index + /// Cleared on `SpatialIndex` drop + #[expect(dead_code)] + pub(crate) reservation: MemoryReservation, +} + +impl SpatialIndex { + pub fn empty( + spatial_predicate: SpatialPredicate, + schema: SchemaRef, + options: SpatialJoinOptions, + probe_threads_counter: AtomicUsize, + mut reservation: MemoryReservation, + memory_pool: Arc, + ) -> Self { + let evaluator = create_operand_evaluator(&spatial_predicate, options.clone()); + let refiner = create_refiner( + options.spatial_library, + &spatial_predicate, + options.clone(), + 0, + GeoStatistics::empty(), + ); + let refiner_reservation = reservation.split(0); + let refiner_reservation = ConcurrentReservation::try_new(0, refiner_reservation).unwrap(); + let rtree = RTreeBuilder::::new(0).finish::(); + Self { + schema, + evaluator, + refiner, + refiner_reservation, + rtree, + data_id_to_batch_pos: Vec::new(), + indexed_batches: Vec::new(), + geom_idx_vec: Vec::new(), + visited_left_side: None, + probe_threads_counter, + knn_components: KnnComponents::new(0, &[], memory_pool.clone()).unwrap(), // Empty index has no cache + reservation, + } + } + + pub(crate) fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + /// Create a KNN geometry accessor for accessing geometries with caching + fn create_knn_accessor(&self) -> SedonaKnnAdapter<'_> { + SedonaKnnAdapter::new( + &self.indexed_batches, + &self.data_id_to_batch_pos, + &self.knn_components, + ) + } + + /// Get the batch at the given index. + pub(crate) fn get_indexed_batch(&self, batch_idx: usize) -> &RecordBatch { + &self.indexed_batches[batch_idx].batch + } + + /// Query the spatial index with a probe geometry to find matching build-side geometries. 
+ /// + /// This method implements a two-phase spatial join query: + /// 1. **Filter phase**: Uses the R-tree index with the probe geometry's bounding rectangle + /// to quickly identify candidate geometries that might satisfy the spatial predicate + /// 2. **Refinement phase**: Evaluates the exact spatial predicate on candidates to determine + /// actual matches + /// + /// # Arguments + /// * `probe_wkb` - The probe geometry in WKB format + /// * `probe_rect` - The minimum bounding rectangle of the probe geometry + /// * `distance` - Optional distance parameter for distance-based spatial predicates + /// * `build_batch_positions` - Output vector that will be populated with (batch_idx, row_idx) + /// pairs for each matching build-side geometry + /// + /// # Returns + /// * `JoinResultMetrics` containing the number of actual matches (`count`) and the number + /// of candidates from the filter phase (`candidate_count`) + pub(crate) fn query( + &self, + probe_wkb: &Wkb, + probe_rect: &Rect, + distance: &Option, + build_batch_positions: &mut Vec<(i32, i32)>, + ) -> Result { + let min = probe_rect.min(); + let max = probe_rect.max(); + let mut candidates = self.rtree.search(min.x, min.y, max.x, max.y); + if candidates.is_empty() { + return Ok(QueryResultMetrics { + count: 0, + candidate_count: 0, + }); + } + + // Sort and dedup candidates to avoid duplicate results when we index one geometry + // using several boxes. + candidates.sort_unstable(); + candidates.dedup(); + + // Refine the candidates retrieved from the r-tree index by evaluating the actual spatial predicate + self.refine(probe_wkb, &candidates, distance, build_batch_positions) + } + + /// Query the spatial index for k nearest neighbors of a given geometry. + /// + /// This method finds the k nearest neighbors to the probe geometry using: + /// 1. R-tree's built-in neighbors() method for efficient KNN search + /// 2. Distance refinement using actual geometry calculations + /// 3. 
Tie-breaker handling when enabled + /// + /// # Arguments + /// + /// * `probe_wkb` - WKB representation of the probe geometry + /// * `k` - Number of nearest neighbors to find + /// * `use_spheroid` - Whether to use spheroid distance calculation + /// * `include_tie_breakers` - Whether to include additional results with same distance as kth neighbor + /// * `build_batch_positions` - Output vector for matched positions + /// + /// # Returns + /// + /// * `JoinResultMetrics` containing the number of actual matches and candidates processed + pub(crate) fn query_knn( + &self, + probe_wkb: &Wkb, + k: u32, + use_spheroid: bool, + include_tie_breakers: bool, + build_batch_positions: &mut Vec<(i32, i32)>, + ) -> Result { + if k == 0 { + return Ok(QueryResultMetrics { + count: 0, + candidate_count: 0, + }); + } + + // Check if index is empty + if self.indexed_batches.is_empty() || self.data_id_to_batch_pos.is_empty() { + return Ok(QueryResultMetrics { + count: 0, + candidate_count: 0, + }); + } + + // Convert probe WKB to geo::Geometry + let probe_geom = match item_to_geometry(probe_wkb) { + Ok(geom) => geom, + Err(_) => { + // Empty or unsupported geometries (e.g., POINT EMPTY) return empty results + return Ok(QueryResultMetrics { + count: 0, + candidate_count: 0, + }); + } + }; + + // Select the appropriate distance metric + let distance_metric: &dyn DistanceMetric = if use_spheroid { + &self.knn_components.haversine_metric + } else { + &self.knn_components.euclidean_metric + }; + + // Create geometry accessor for on-demand WKB decoding and caching + let geometry_accessor = self.create_knn_accessor(); + + // Use neighbors_geometry to find k nearest neighbors + let initial_results = self.rtree.neighbors_geometry( + &probe_geom, + Some(k as usize), + None, // no max_distance filter + distance_metric, + &geometry_accessor, + ); + + if initial_results.is_empty() { + return Ok(QueryResultMetrics { + count: 0, + candidate_count: 0, + }); + } + + let mut final_results = 
initial_results; + let mut candidate_count = final_results.len(); + + // Handle tie-breakers if enabled + if include_tie_breakers && !final_results.is_empty() && k > 0 { + // Calculate distances for the initial k results to find the k-th distance + let mut distances_with_indices: Vec<(f64, u32)> = Vec::new(); + + for &result_idx in &final_results { + if (result_idx as usize) < self.data_id_to_batch_pos.len() { + if let Some(item_geom) = geometry_accessor.get_geometry(result_idx as usize) { + let distance = distance_metric.distance_to_geometry(&probe_geom, item_geom); + if let Some(distance_f64) = distance.to_f64() { + distances_with_indices.push((distance_f64, result_idx)); + } + } + } + } + + // Sort by distance + distances_with_indices + .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); + + // Find the k-th distance (if we have at least k results) + if distances_with_indices.len() >= k as usize { + let k_idx = (k as usize) + .min(distances_with_indices.len()) + .saturating_sub(1); + let max_distance = distances_with_indices[k_idx].0; + + // For tie-breakers, create spatial envelope around probe centroid and use rtree.search() + + let probe_centroid = probe_geom.centroid().unwrap_or(Point::new(0.0, 0.0)); + let probe_x = probe_centroid.x() as f32; + let probe_y = probe_centroid.y() as f32; + let max_distance_f32 = match f32::from_f64(max_distance) { + Some(val) => val, + None => { + // If conversion fails, return empty results for this probe + return Ok(QueryResultMetrics { + count: 0, + candidate_count: 0, + }); + } + }; + + // Create envelope bounds around probe centroid + let min_x = probe_x - max_distance_f32; + let min_y = probe_y - max_distance_f32; + let max_x = probe_x + max_distance_f32; + let max_y = probe_y + max_distance_f32; + + // Use rtree.search() with envelope bounds (like the old code) + let expanded_results = self.rtree.search(min_x, min_y, max_x, max_y); + + candidate_count = expanded_results.len(); + + // Calculate 
distances for all results and find ties + let mut all_distances_with_indices: Vec<(f64, u32)> = Vec::new(); + + for &result_idx in &expanded_results { + if (result_idx as usize) < self.data_id_to_batch_pos.len() { + if let Some(item_geom) = geometry_accessor.get_geometry(result_idx as usize) + { + let distance = + distance_metric.distance_to_geometry(&probe_geom, item_geom); + if let Some(distance_f64) = distance.to_f64() { + all_distances_with_indices.push((distance_f64, result_idx)); + } + } + } + } + + // Sort by distance + all_distances_with_indices + .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); + + // Include all results up to and including those with the same distance as the k-th result + const DISTANCE_TOLERANCE: f64 = 1e-9; + let mut tie_breaker_results: Vec = Vec::new(); + + for (i, &(distance, result_idx)) in all_distances_with_indices.iter().enumerate() { + if i < k as usize { + // Include the first k results + tie_breaker_results.push(result_idx); + } else if (distance - max_distance).abs() <= DISTANCE_TOLERANCE { + // Include tie-breakers (same distance as k-th result) + tie_breaker_results.push(result_idx); + } else { + // No more ties, stop + break; + } + } + + final_results = tie_breaker_results; + } + } else { + // When tie-breakers are disabled, limit results to exactly k + if final_results.len() > k as usize { + final_results.truncate(k as usize); + } + } + + // Convert results to build_batch_positions using existing data_id_to_batch_pos mapping + for &result_idx in &final_results { + if (result_idx as usize) < self.data_id_to_batch_pos.len() { + build_batch_positions.push(self.data_id_to_batch_pos[result_idx as usize]); + } + } + + Ok(QueryResultMetrics { + count: final_results.len(), + candidate_count, + }) + } + + fn refine( + &self, + probe_wkb: &Wkb, + candidates: &[u32], + distance: &Option, + build_batch_positions: &mut Vec<(i32, i32)>, + ) -> Result { + let candidate_count = candidates.len(); + + let mut 
index_query_results = Vec::with_capacity(candidate_count); + for data_idx in candidates { + let pos = self.data_id_to_batch_pos[*data_idx as usize]; + let (batch_idx, row_idx) = pos; + let indexed_batch = &self.indexed_batches[batch_idx as usize]; + let build_wkb = indexed_batch.wkb(row_idx as usize); + let Some(build_wkb) = build_wkb else { + continue; + }; + let distance = self.evaluator.resolve_distance( + indexed_batch.distance(), + row_idx as usize, + distance, + )?; + let geom_idx = self.geom_idx_vec[*data_idx as usize]; + index_query_results.push(IndexQueryResult { + wkb: build_wkb, + distance, + geom_idx, + position: pos, + }); + } + + if index_query_results.is_empty() { + return Ok(QueryResultMetrics { + count: 0, + candidate_count, + }); + } + + let results = self.refiner.refine(probe_wkb, &index_query_results)?; + let num_results = results.len(); + build_batch_positions.extend(results); + + // Update refiner memory reservation + self.refiner_reservation.resize(self.refiner.mem_usage())?; + + Ok(QueryResultMetrics { + count: num_results, + candidate_count, + }) + } + + /// Check if the index needs more probe statistics to determine the optimal execution mode. + /// + /// # Returns + /// * `bool` - `true` if the index needs more probe statistics, `false` otherwise. + pub(crate) fn need_more_probe_stats(&self) -> bool { + self.refiner.need_more_probe_stats() + } + + /// Merge the probe statistics into the index. + /// + /// # Arguments + /// * `stats` - The probe statistics to merge. + pub(crate) fn merge_probe_stats(&self, stats: GeoStatistics) { + self.refiner.merge_probe_stats(stats); + } + + /// Get the bitmaps for tracking visited left-side indices. The bitmaps will be updated + /// by the spatial join stream when producing output batches during index probing phase. 
+ pub(crate) fn visited_left_side(&self) -> Option<&Mutex>> { + self.visited_left_side.as_ref() + } + + /// Decrements counter of running threads, and returns `true` + /// if caller is the last running thread + pub(crate) fn report_probe_completed(&self) -> bool { + self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1 + } + + /// Get the memory usage of the refiner in bytes. + pub(crate) fn get_refiner_mem_usage(&self) -> usize { + self.refiner.mem_usage() + } + + /// Get the actual execution mode used by the refiner + pub(crate) fn get_actual_execution_mode(&self) -> ExecutionMode { + self.refiner.actual_execution_mode() + } +} + +#[cfg(test)] +mod tests { + use crate::{ + index::{SpatialIndexBuilder, SpatialJoinBuildMetrics}, + operand_evaluator::EvaluatedGeometryArray, + spatial_predicate::{RelationPredicate, SpatialRelationType}, + }; + + use super::*; + use arrow_array::RecordBatch; + use arrow_schema::{DataType, Field}; + use datafusion_execution::memory_pool::GreedyMemoryPool; + use datafusion_expr::JoinType; + use datafusion_physical_expr::expressions::Column; + use geo_traits::Dimensions; + use sedona_common::option::{ExecutionMode, SpatialJoinOptions}; + use sedona_geometry::wkb_factory::write_wkb_empty_point; + use sedona_schema::datatypes::WKB_GEOMETRY; + use sedona_testing::create::create_array; + + #[test] + fn test_spatial_index_builder_empty() { + let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); + let options = SpatialJoinOptions { + execution_mode: ExecutionMode::PrepareBuild, + ..Default::default() + }; + let metrics = SpatialJoinBuildMetrics::default(); + let schema = Arc::new(arrow_schema::Schema::empty()); + let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("geom", 0)), + Arc::new(Column::new("geom", 1)), + SpatialRelationType::Intersects, + )); + + let builder = SpatialIndexBuilder::new( + schema.clone(), + spatial_predicate, + options, + JoinType::Inner, + 4, + 
memory_pool, + metrics, + ) + .unwrap(); + + // Test finishing with empty data + let index = builder.finish().unwrap(); + assert_eq!(index.schema(), schema); + assert_eq!(index.indexed_batches.len(), 0); + } + + #[test] + fn test_spatial_index_builder_add_batch() { + let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); + let options = SpatialJoinOptions { + execution_mode: ExecutionMode::PrepareBuild, + ..Default::default() + }; + let metrics = SpatialJoinBuildMetrics::default(); + + let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("geom", 0)), + Arc::new(Column::new("geom", 1)), + SpatialRelationType::Intersects, + )); + + // Create a simple test geometry batch + let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( + "geom", + DataType::Binary, + true, + )])); + + let mut builder = SpatialIndexBuilder::new( + schema.clone(), + spatial_predicate, + options, + JoinType::Inner, + 4, + memory_pool, + metrics, + ) + .unwrap(); + + let batch = RecordBatch::new_empty(schema.clone()); + let geom_batch = create_array( + &[ + Some("POINT (0.25 0.25)"), + Some("POINT (10 10)"), + None, + Some("POINT (0.25 0.25)"), + ], + &WKB_GEOMETRY, + ); + let indexed_batch = EvaluatedBatch { + batch, + geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), + }; + builder.add_batch(indexed_batch); + + let index = builder.finish().unwrap(); + assert_eq!(index.schema(), schema); + assert_eq!(index.indexed_batches.len(), 1); + } + + #[test] + fn test_knn_query_execution_with_sample_data() { + // Create a spatial index with sample geometry data + let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); + let options = SpatialJoinOptions { + execution_mode: ExecutionMode::PrepareBuild, + ..Default::default() + }; + let metrics = SpatialJoinBuildMetrics::default(); + + let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("geom", 0)), + 
Arc::new(Column::new("geom", 1)), + SpatialRelationType::Intersects, + )); + + // Create sample geometry data - points at known locations + let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( + "geom", + DataType::Binary, + true, + )])); + + let mut builder = SpatialIndexBuilder::new( + schema.clone(), + spatial_predicate, + options, + JoinType::Inner, + 4, + memory_pool, + metrics, + ) + .unwrap(); + + let batch = RecordBatch::new_empty(schema.clone()); + + // Create geometries at different distances from the query point (0, 0) + let geom_batch = create_array( + &[ + Some("POINT (1 0)"), // Distance: 1.0 + Some("POINT (0 2)"), // Distance: 2.0 + Some("POINT (3 0)"), // Distance: 3.0 + Some("POINT (0 4)"), // Distance: 4.0 + Some("POINT (5 0)"), // Distance: 5.0 + Some("POINT (2 2)"), // Distance: ~2.83 + Some("POINT (1 1)"), // Distance: ~1.41 + ], + &WKB_GEOMETRY, + ); + + let indexed_batch = EvaluatedBatch { + batch, + geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), + }; + builder.add_batch(indexed_batch); + + let index = builder.finish().unwrap(); + + // Create a query geometry at origin (0, 0) + let query_geom = create_array(&[Some("POINT (0 0)")], &WKB_GEOMETRY); + let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); + let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); + + // Test KNN query with k=3 + let mut build_positions = Vec::new(); + let result = index + .query_knn( + query_wkb, + 3, // k=3 + false, // use_spheroid=false + false, // include_tie_breakers=false + &mut build_positions, + ) + .unwrap(); + + // Verify we got 3 results + assert_eq!(build_positions.len(), 3); + assert_eq!(result.count, 3); + assert!(result.candidate_count >= 3); + + // Create a mapping of positions to verify correct ordering + // We expect the 3 closest points: (1,0), (1,1), (0,2) + let expected_closest_indices = vec![0, 6, 1]; // Based on our sample data ordering + let mut found_indices = 
Vec::new(); + + for (_batch_idx, row_idx) in &build_positions { + found_indices.push(*row_idx as usize); + } + + // Sort to compare sets (order might vary due to implementation) + found_indices.sort(); + let mut expected_sorted = expected_closest_indices; + expected_sorted.sort(); + + assert_eq!(found_indices, expected_sorted); + } + + #[test] + fn test_knn_query_execution_with_different_k_values() { + // Create spatial index with more data points + let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); + let options = SpatialJoinOptions { + execution_mode: ExecutionMode::PrepareBuild, + ..Default::default() + }; + let metrics = SpatialJoinBuildMetrics::default(); + + let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("geom", 0)), + Arc::new(Column::new("geom", 1)), + SpatialRelationType::Intersects, + )); + + let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( + "geom", + DataType::Binary, + true, + )])); + + let mut builder = SpatialIndexBuilder::new( + schema.clone(), + spatial_predicate, + options, + JoinType::Inner, + 4, + memory_pool, + metrics, + ) + .unwrap(); + + let batch = RecordBatch::new_empty(schema.clone()); + + // Create 10 points at regular intervals + let geom_batch = create_array( + &[ + Some("POINT (1 0)"), // 0: Distance 1 + Some("POINT (2 0)"), // 1: Distance 2 + Some("POINT (3 0)"), // 2: Distance 3 + Some("POINT (4 0)"), // 3: Distance 4 + Some("POINT (5 0)"), // 4: Distance 5 + Some("POINT (6 0)"), // 5: Distance 6 + Some("POINT (7 0)"), // 6: Distance 7 + Some("POINT (8 0)"), // 7: Distance 8 + Some("POINT (9 0)"), // 8: Distance 9 + Some("POINT (10 0)"), // 9: Distance 10 + ], + &WKB_GEOMETRY, + ); + + let indexed_batch = EvaluatedBatch { + batch, + geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), + }; + builder.add_batch(indexed_batch); + + let index = builder.finish().unwrap(); + + // Query point at origin + let query_geom = 
create_array(&[Some("POINT (0 0)")], &WKB_GEOMETRY); + let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); + let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); + + // Test different k values + for k in [1, 3, 5, 7, 10] { + let mut build_positions = Vec::new(); + let result = index + .query_knn(query_wkb, k, false, false, &mut build_positions) + .unwrap(); + + // Verify we got exactly k results (or all available if k > total) + let expected_results = std::cmp::min(k as usize, 10); + assert_eq!(build_positions.len(), expected_results); + assert_eq!(result.count, expected_results); + + // Verify the results are the k closest points + let mut row_indices: Vec = build_positions + .iter() + .map(|(_, row_idx)| *row_idx as usize) + .collect(); + row_indices.sort(); + + let expected_indices: Vec = (0..expected_results).collect(); + assert_eq!(row_indices, expected_indices); + } + } + + #[test] + fn test_knn_query_execution_with_spheroid_distance() { + // Create spatial index + let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); + let options = SpatialJoinOptions { + execution_mode: ExecutionMode::PrepareBuild, + ..Default::default() + }; + let metrics = SpatialJoinBuildMetrics::default(); + + let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("geom", 0)), + Arc::new(Column::new("geom", 1)), + SpatialRelationType::Intersects, + )); + + let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( + "geom", + DataType::Binary, + true, + )])); + + let mut builder = SpatialIndexBuilder::new( + schema.clone(), + spatial_predicate, + options, + JoinType::Inner, + 4, + memory_pool, + metrics, + ) + .unwrap(); + + let batch = RecordBatch::new_empty(schema.clone()); + + // Create points with geographic coordinates (longitude, latitude) + let geom_batch = create_array( + &[ + Some("POINT (-74.0 40.7)"), // NYC area + Some("POINT (-73.9 40.7)"), // Slightly east + Some("POINT 
(-74.1 40.7)"), // Slightly west + Some("POINT (-74.0 40.8)"), // Slightly north + Some("POINT (-74.0 40.6)"), // Slightly south + ], + &WKB_GEOMETRY, + ); + + let indexed_batch = EvaluatedBatch { + batch, + geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), + }; + builder.add_batch(indexed_batch); + + let index = builder.finish().unwrap(); + + // Query point at NYC + let query_geom = create_array(&[Some("POINT (-74.0 40.7)")], &WKB_GEOMETRY); + let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); + let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); + + // Test with planar distance (spheroid distance is not supported) + let mut build_positions = Vec::new(); + let result = index + .query_knn( + query_wkb, + 3, // k=3 + false, // use_spheroid=false (only supported option) + false, + &mut build_positions, + ) + .unwrap(); + + // Should find results with planar distance calculation + assert!(!build_positions.is_empty()); // At least the exact match + assert!(result.count >= 1); + assert!(result.candidate_count >= 1); + + // Test that spheroid distance now works with Haversine metric + let mut build_positions_spheroid = Vec::new(); + let result_spheroid = index.query_knn( + query_wkb, + 3, // k=3 + true, // use_spheroid=true (now supported with Haversine) + false, + &mut build_positions_spheroid, + ); + + // Should succeed and return results + assert!(result_spheroid.is_ok()); + let result_spheroid = result_spheroid.unwrap(); + assert!(!build_positions_spheroid.is_empty()); + assert!(result_spheroid.count >= 1); + assert!(result_spheroid.candidate_count >= 1); + } + + #[test] + fn test_knn_query_execution_edge_cases() { + // Create spatial index + let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); + let options = SpatialJoinOptions { + execution_mode: ExecutionMode::PrepareBuild, + ..Default::default() + }; + let metrics = SpatialJoinBuildMetrics::default(); + + let spatial_predicate = 
SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("geom", 0)), + Arc::new(Column::new("geom", 1)), + SpatialRelationType::Intersects, + )); + + let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( + "geom", + DataType::Binary, + true, + )])); + + let mut builder = SpatialIndexBuilder::new( + schema.clone(), + spatial_predicate, + options, + JoinType::Inner, + 4, + memory_pool, + metrics, + ) + .unwrap(); + + let batch = RecordBatch::new_empty(schema.clone()); + + // Create sample data with some edge cases + let geom_batch = create_array( + &[ + Some("POINT (1 1)"), + Some("POINT (2 2)"), + None, // NULL geometry + Some("POINT (3 3)"), + ], + &WKB_GEOMETRY, + ); + + let indexed_batch = EvaluatedBatch { + batch, + geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), + }; + builder.add_batch(indexed_batch); + + let index = builder.finish().unwrap(); + + let query_geom = create_array(&[Some("POINT (0 0)")], &WKB_GEOMETRY); + let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); + let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); + + // Test k=0 (should return no results) + let mut build_positions = Vec::new(); + let result = index + .query_knn( + query_wkb, + 0, // k=0 + false, + false, + &mut build_positions, + ) + .unwrap(); + + assert_eq!(build_positions.len(), 0); + assert_eq!(result.count, 0); + assert_eq!(result.candidate_count, 0); + + // Test k > available geometries + let mut build_positions = Vec::new(); + let result = index + .query_knn( + query_wkb, + 10, // k=10, but only 3 valid geometries available + false, + false, + &mut build_positions, + ) + .unwrap(); + + // Should return all available valid geometries (excluding NULL) + assert_eq!(build_positions.len(), 3); + assert_eq!(result.count, 3); + } + + #[test] + fn test_knn_query_execution_empty_index() { + // Create empty spatial index + let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); 
+ let options = SpatialJoinOptions { + execution_mode: ExecutionMode::PrepareBuild, + ..Default::default() + }; + let metrics = SpatialJoinBuildMetrics::default(); + let schema = Arc::new(arrow_schema::Schema::empty()); + + let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("geom", 0)), + Arc::new(Column::new("geom", 1)), + SpatialRelationType::Intersects, + )); + + let builder = SpatialIndexBuilder::new( + schema.clone(), + spatial_predicate, + options, + JoinType::Inner, + 4, + memory_pool, + metrics, + ) + .unwrap(); + + let index = builder.finish().unwrap(); + + // Try to query empty index + let query_geom = create_array(&[Some("POINT (0 0)")], &WKB_GEOMETRY); + let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); + let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); + + let mut build_positions = Vec::new(); + let result = index + .query_knn(query_wkb, 5, false, false, &mut build_positions) + .unwrap(); + + // Should return no results for empty index + assert_eq!(build_positions.len(), 0); + assert_eq!(result.count, 0); + assert_eq!(result.candidate_count, 0); + } + + #[test] + fn test_knn_query_execution_with_tie_breakers() { + // Create a spatial index with sample geometry data + let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); + let options = SpatialJoinOptions { + execution_mode: ExecutionMode::PrepareBuild, + ..Default::default() + }; + let metrics = SpatialJoinBuildMetrics::default(); + + let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("geom", 0)), + Arc::new(Column::new("geom", 1)), + SpatialRelationType::Intersects, + )); + + let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( + "geom", + DataType::Binary, + true, + )])); + + let mut builder = SpatialIndexBuilder::new( + schema.clone(), + spatial_predicate, + options, + JoinType::Inner, + 1, // probe_threads_count + memory_pool.clone(), + 
metrics, + ) + .unwrap(); + + let batch = RecordBatch::new_empty(schema.clone()); + + // Create points where we have more ties at the k-th distance + // Query point is at (0.0, 0.0) + // We'll create a scenario with k=2 where there are 3 points at the same distance + // This ensures the tie-breaker logic has work to do + let geom_batch = create_array( + &[ + Some("POINT (1.0 0.0)"), // Squared distance 1.0 + Some("POINT (0.0 1.0)"), // Squared distance 1.0 (tie!) + Some("POINT (-1.0 0.0)"), // Squared distance 1.0 (tie!) + Some("POINT (0.0 -1.0)"), // Squared distance 1.0 (tie!) + Some("POINT (2.0 0.0)"), // Squared distance 4.0 + Some("POINT (0.0 2.0)"), // Squared distance 4.0 + ], + &WKB_GEOMETRY, + ); + + let indexed_batch = EvaluatedBatch { + batch, + geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), + }; + builder.add_batch(indexed_batch); + + let index = builder.finish().unwrap(); + + // Query point at the origin (0.0, 0.0) + let query_geom = create_array(&[Some("POINT (0.0 0.0)")], &WKB_GEOMETRY); + let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); + let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); + + // Test without tie-breakers: should return exactly k=2 results + let mut build_positions = Vec::new(); + let result = index + .query_knn( + query_wkb, + 2, // k=2 + false, // use_spheroid + false, // include_tie_breakers + &mut build_positions, + ) + .unwrap(); + + // Should return exactly 2 results (the closest point + 1 of the tied points) + assert_eq!(result.count, 2); + assert_eq!(build_positions.len(), 2); + + // Test with tie-breakers: should return k=2 plus all ties + let mut build_positions_with_ties = Vec::new(); + let result_with_ties = index + .query_knn( + query_wkb, + 2, // k=2 + false, // use_spheroid + true, // include_tie_breakers + &mut build_positions_with_ties, + ) + .unwrap(); + + // Should return more than 2 results because of ties + // We have 4 points at 
squared distance 1.0 (all tied for closest) + // With k=2 and tie-breakers: + // - Initial neighbors query returns 2 of the 4 tied points + // - Tie-breaker logic should find the other 2 tied points + // - Total should be 4 results (all points at distance 1.0) + + // With 4 points all at the same distance and k=2: + // - Without tie-breakers: should return exactly 2 + // - With tie-breakers: should return all 4 tied points + assert_eq!( + result.count, 2, + "Without tie-breakers should return exactly k=2" + ); + assert_eq!( + result_with_ties.count, 4, + "With tie-breakers should return all 4 tied points" + ); + assert_eq!(build_positions_with_ties.len(), 4); + } + + #[test] + fn test_query_knn_with_geometry_distance() { + // Create a spatial index with sample geometry data + let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); + let options = SpatialJoinOptions { + execution_mode: ExecutionMode::PrepareBuild, + ..Default::default() + }; + let metrics = SpatialJoinBuildMetrics::default(); + + let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("geom", 0)), + Arc::new(Column::new("geom", 1)), + SpatialRelationType::Intersects, + )); + + // Create sample geometry data - points at known locations + let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( + "geom", + DataType::Binary, + true, + )])); + + let mut builder = SpatialIndexBuilder::new( + schema.clone(), + spatial_predicate, + options, + JoinType::Inner, + 4, + memory_pool, + metrics, + ) + .unwrap(); + + let batch = RecordBatch::new_empty(schema.clone()); + + // Create geometries at different distances from the query point (0, 0) + let geom_batch = create_array( + &[ + Some("POINT (1 0)"), // Distance: 1.0 + Some("POINT (0 2)"), // Distance: 2.0 + Some("POINT (3 0)"), // Distance: 3.0 + Some("POINT (0 4)"), // Distance: 4.0 + Some("POINT (5 0)"), // Distance: 5.0 + Some("POINT (2 2)"), // Distance: ~2.83 + Some("POINT (1 1)"), // Distance: 
~1.41 + ], + &WKB_GEOMETRY, + ); + + let indexed_batch = EvaluatedBatch { + batch, + geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), + }; + builder.add_batch(indexed_batch); + + let index = builder.finish().unwrap(); + + // Create a query geometry at origin (0, 0) + let query_geom = create_array(&[Some("POINT (0 0)")], &WKB_GEOMETRY); + let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); + let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); + + // Test the geometry-based query_knn method with k=3 + let mut build_positions = Vec::new(); + let result = index + .query_knn( + query_wkb, + 3, // k=3 + false, // use_spheroid=false + false, // include_tie_breakers=false + &mut build_positions, + ) + .unwrap(); + + // Verify we got results (should be 3 or less) + assert!(!build_positions.is_empty()); + assert!(build_positions.len() <= 3); + assert!(result.count > 0); + assert!(result.count <= 3); + + println!("KNN Geometry test - found {} results", result.count); + println!("Result positions: {build_positions:?}"); + } + + #[test] + fn test_query_knn_with_mixed_geometries() { + // Create a spatial index with complex geometries where geometry-based + // distance should differ from centroid-based distance + let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); + let options = SpatialJoinOptions { + execution_mode: ExecutionMode::PrepareBuild, + ..Default::default() + }; + let metrics = SpatialJoinBuildMetrics::default(); + + let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("geom", 0)), + Arc::new(Column::new("geom", 1)), + SpatialRelationType::Intersects, + )); + + // Create different geometry types + let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( + "geom", + DataType::Binary, + true, + )])); + + let mut builder = SpatialIndexBuilder::new( + schema.clone(), + spatial_predicate, + options, + JoinType::Inner, + 4, + memory_pool, 
+ metrics, + ) + .unwrap(); + + let batch = RecordBatch::new_empty(schema.clone()); + + // Mix of points and linestrings + let geom_batch = create_array( + &[ + Some("POINT (1 1)"), // Simple point + Some("LINESTRING (2 0, 2 4)"), // Vertical line - closest point should be (2, 1) + Some("LINESTRING (10 10, 10 20)"), // Far away line + Some("POINT (5 5)"), // Far point + ], + &WKB_GEOMETRY, + ); + + let indexed_batch = EvaluatedBatch { + batch, + geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), + }; + builder.add_batch(indexed_batch); + + let index = builder.finish().unwrap(); + + // Query point close to the linestring + let query_geom = create_array(&[Some("POINT (2.1 1.0)")], &WKB_GEOMETRY); + let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); + let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); + + // Test the geometry-based KNN method with mixed geometry types + let mut build_positions = Vec::new(); + + let result = index + .query_knn( + query_wkb, + 2, // k=2 + false, // use_spheroid=false + false, // include_tie_breakers=false + &mut build_positions, + ) + .unwrap(); + + // Should return results + assert!(!build_positions.is_empty()); + + println!("KNN with mixed geometries: {build_positions:?}"); + + // Should work with mixed geometry types + assert!(result.count > 0); + } + + #[test] + fn test_query_knn_with_tie_breakers_geometry_distance() { + // Create a spatial index with geometries that have identical distances for tie-breaker testing + let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); + let options = SpatialJoinOptions { + execution_mode: ExecutionMode::PrepareBuild, + ..Default::default() + }; + let metrics = SpatialJoinBuildMetrics::default(); + + let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("geom", 0)), + Arc::new(Column::new("geom", 1)), + SpatialRelationType::Intersects, + )); + + let schema = 
Arc::new(arrow_schema::Schema::new(vec![Field::new( + "geom", + DataType::Binary, + true, + )])); + + let mut builder = SpatialIndexBuilder::new( + schema.clone(), + spatial_predicate, + options, + JoinType::Inner, + 4, + memory_pool, + metrics, + ) + .unwrap(); + + let batch = RecordBatch::new_empty(schema.clone()); + + // Create points where we have multiple points at the same distance from the query point + // Query point will be at (0, 0), and we'll have 4 points all at distance sqrt(2) ≈ 1.414 + let geom_batch = create_array( + &[ + Some("POINT (1.0 1.0)"), // Distance: sqrt(2) + Some("POINT (1.0 -1.0)"), // Distance: sqrt(2) - tied with above + Some("POINT (-1.0 1.0)"), // Distance: sqrt(2) - tied with above + Some("POINT (-1.0 -1.0)"), // Distance: sqrt(2) - tied with above + Some("POINT (2.0 0.0)"), // Distance: 2.0 - farther away + Some("POINT (0.0 2.0)"), // Distance: 2.0 - farther away + ], + &WKB_GEOMETRY, + ); + + let indexed_batch = EvaluatedBatch { + batch, + geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), + }; + builder.add_batch(indexed_batch); + + let index = builder.finish().unwrap(); + + // Query point at the origin (0.0, 0.0) + let query_geom = create_array(&[Some("POINT (0.0 0.0)")], &WKB_GEOMETRY); + let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap(); + let query_wkb = &query_array.wkbs()[0].as_ref().unwrap(); + + // Test without tie-breakers: should return exactly k=2 results + let mut build_positions = Vec::new(); + let result = index + .query_knn( + query_wkb, + 2, // k=2 + false, // use_spheroid + false, // include_tie_breakers=false + &mut build_positions, + ) + .unwrap(); + + // Should return exactly 2 results + assert_eq!(result.count, 2); + assert_eq!(build_positions.len(), 2); + + // Test with tie-breakers: should return all tied points + let mut build_positions_with_ties = Vec::new(); + let result_with_ties = index + .query_knn( + query_wkb, + 2, // k=2 + false, // 
use_spheroid + true, // include_tie_breakers=true + &mut build_positions_with_ties, + ) + .unwrap(); + + // Should return more than 2 results because of ties (all 4 points at distance sqrt(2)) + assert!(result_with_ties.count >= 2); + } + + #[test] + fn test_knn_query_with_empty_geometry() { + // Create a spatial index with sample geometry data like other tests + let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); + let options = SpatialJoinOptions { + execution_mode: ExecutionMode::PrepareBuild, + ..Default::default() + }; + let metrics = SpatialJoinBuildMetrics::default(); + + let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("geom", 0)), + Arc::new(Column::new("geom", 1)), + SpatialRelationType::Intersects, + )); + + // Create geometry batch using the same pattern as other tests + let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( + "geom", + DataType::Binary, + true, + )])); + + let mut builder = SpatialIndexBuilder::new( + schema.clone(), + spatial_predicate, + options, + JoinType::Inner, + 1, // probe_threads_count + memory_pool.clone(), + metrics, + ) + .unwrap(); + + let batch = RecordBatch::new_empty(schema.clone()); + + let geom_batch = create_array( + &[ + Some("POINT (0 0)"), + Some("POINT (1 1)"), + Some("POINT (2 2)"), + ], + &WKB_GEOMETRY, + ); + let indexed_batch = EvaluatedBatch { + batch, + geom_array: EvaluatedGeometryArray::try_new(geom_batch, &WKB_GEOMETRY).unwrap(), + }; + builder.add_batch(indexed_batch); + + let index = builder.finish().unwrap(); + + // Create an empty point WKB + let mut empty_point_wkb = Vec::new(); + write_wkb_empty_point(&mut empty_point_wkb, Dimensions::Xy).unwrap(); + + // Query with the empty point + let mut build_positions = Vec::new(); + let result = index + .query_knn( + &wkb::reader::read_wkb(&empty_point_wkb).unwrap(), + 2, // k=2 + false, // use_spheroid + false, // include_tie_breakers + &mut build_positions, + ) + .unwrap(); + + // 
Should return empty results for empty geometry + assert_eq!(result.count, 0); + assert_eq!(result.candidate_count, 0); + assert!(build_positions.is_empty()); + } +} diff --git a/rust/sedona-spatial-join/src/index/spatial_index_builder.rs b/rust/sedona-spatial-join/src/index/spatial_index_builder.rs new file mode 100644 index 000000000..2b07a829c --- /dev/null +++ b/rust/sedona-spatial-join/src/index/spatial_index_builder.rs @@ -0,0 +1,304 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::BooleanBufferBuilder; +use arrow_schema::SchemaRef; +use datafusion_physical_plan::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; +use sedona_common::SpatialJoinOptions; +use sedona_expr::statistics::GeoStatistics; + +use datafusion_common::{utils::proxy::VecAllocExt, Result}; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; +use datafusion_expr::JoinType; +use futures::StreamExt; +use geo_index::rtree::{sort::HilbertSort, RTree, RTreeBuilder}; +use parking_lot::Mutex; +use std::sync::{atomic::AtomicUsize, Arc}; + +use crate::{ + evaluated_batch::EvaluatedBatch, + index::{knn_adapter::KnnComponents, spatial_index::SpatialIndex, BuildPartition}, + operand_evaluator::create_operand_evaluator, + refine::create_refiner, + spatial_predicate::SpatialPredicate, + utils::{ + concurrent_reservation::ConcurrentReservation, join_utils::need_produce_result_in_final, + }, +}; + +// Type aliases for better readability +type SpatialRTree = RTree; +type DataIdToBatchPos = Vec<(i32, i32)>; +type RTreeBuildResult = (SpatialRTree, DataIdToBatchPos); + +/// Rough estimate for in-memory size of the rtree per rect in bytes +const RTREE_MEMORY_ESTIMATE_PER_RECT: usize = 60; + +/// The prealloc size for the refiner reservation. This is used to reduce the frequency of growing +/// the reservation when updating the refiner memory reservation. +const REFINER_RESERVATION_PREALLOC_SIZE: usize = 10 * 1024 * 1024; // 10MB + +/// Builder for constructing a SpatialIndex from geometry batches. +/// +/// This builder handles: +/// 1. Accumulating geometry batches to be indexed +/// 2. Building the spatial R-tree index +/// 3. Setting up memory tracking and visited bitmaps +/// 4. 
Configuring prepared geometries based on execution mode +pub(crate) struct SpatialIndexBuilder { + schema: SchemaRef, + spatial_predicate: SpatialPredicate, + options: SpatialJoinOptions, + join_type: JoinType, + probe_threads_count: usize, + metrics: SpatialJoinBuildMetrics, + + /// Batches to be indexed + indexed_batches: Vec, + /// Memory reservation for tracking the memory usage of the spatial index + reservation: MemoryReservation, + + /// Statistics for indexed geometries + stats: GeoStatistics, + + /// Memory pool for managing the memory usage of the spatial index + memory_pool: Arc, +} + +/// Metrics for the build phase of the spatial join. +#[derive(Clone, Debug, Default)] +pub(crate) struct SpatialJoinBuildMetrics { + /// Total time for collecting build-side of join + pub(crate) build_time: metrics::Time, + /// Memory used by the spatial-index in bytes + pub(crate) build_mem_used: metrics::Gauge, +} + +impl SpatialJoinBuildMetrics { + pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + Self { + build_time: MetricBuilder::new(metrics).subset_time("build_time", partition), + build_mem_used: MetricBuilder::new(metrics).gauge("build_mem_used", partition), + } + } +} + +impl SpatialIndexBuilder { + /// Create a new builder with the given configuration. + pub fn new( + schema: SchemaRef, + spatial_predicate: SpatialPredicate, + options: SpatialJoinOptions, + join_type: JoinType, + probe_threads_count: usize, + memory_pool: Arc, + metrics: SpatialJoinBuildMetrics, + ) -> Result { + let consumer = MemoryConsumer::new("SpatialJoinIndex"); + let reservation = consumer.register(&memory_pool); + + Ok(Self { + schema, + spatial_predicate, + options, + join_type, + probe_threads_count, + metrics, + indexed_batches: Vec::new(), + reservation, + stats: GeoStatistics::empty(), + memory_pool, + }) + } + + /// Add a geometry batch to be indexed. + /// + /// This method accumulates geometry batches that will be used to build the spatial index. 
+ /// Each batch contains processed geometry data along with memory usage information. + pub fn add_batch(&mut self, indexed_batch: EvaluatedBatch) { + let in_mem_size = indexed_batch.in_mem_size(); + self.indexed_batches.push(indexed_batch); + self.reservation.grow(in_mem_size); + self.metrics.build_mem_used.add(in_mem_size); + } + + pub fn merge_stats(&mut self, stats: GeoStatistics) -> &mut Self { + self.stats.merge(&stats); + self + } + + /// Build the spatial R-tree index from collected geometry batches. + fn build_rtree(&mut self) -> Result { + let build_timer = self.metrics.build_time.timer(); + + let num_rects = self + .indexed_batches + .iter() + .map(|batch| batch.rects().iter().flatten().count()) + .sum::(); + + let mut rtree_builder = RTreeBuilder::::new(num_rects as u32); + let mut batch_pos_vec = vec![(-1, -1); num_rects]; + let rtree_mem_estimate = num_rects * RTREE_MEMORY_ESTIMATE_PER_RECT; + + self.reservation + .grow(batch_pos_vec.allocated_size() + rtree_mem_estimate); + + for (batch_idx, batch) in self.indexed_batches.iter().enumerate() { + let rects = batch.rects(); + for (idx, rect_opt) in rects.iter().enumerate() { + let Some(rect) = rect_opt else { + continue; + }; + let min = rect.min(); + let max = rect.max(); + let data_idx = rtree_builder.add(min.x, min.y, max.x, max.y); + batch_pos_vec[data_idx as usize] = (batch_idx as i32, idx as i32); + } + } + + let rtree = rtree_builder.finish::(); + build_timer.done(); + + self.metrics.build_mem_used.add(self.reservation.size()); + + Ok((rtree, batch_pos_vec)) + } + + /// Build visited bitmaps for tracking left-side indices in outer joins. 
+ fn build_visited_bitmaps(&mut self) -> Result>>> { + if !need_produce_result_in_final(self.join_type) { + return Ok(None); + } + + let mut bitmaps = Vec::with_capacity(self.indexed_batches.len()); + let mut total_buffer_size = 0; + + for batch in &self.indexed_batches { + let batch_rows = batch.batch.num_rows(); + let buffer_size = batch_rows.div_ceil(8); + total_buffer_size += buffer_size; + + let mut bitmap = BooleanBufferBuilder::new(batch_rows); + bitmap.append_n(batch_rows, false); + bitmaps.push(bitmap); + } + + self.reservation.try_grow(total_buffer_size)?; + self.metrics.build_mem_used.add(total_buffer_size); + + Ok(Some(Mutex::new(bitmaps))) + } + + /// Create an rtree data index to consecutive index mapping. + fn build_geom_idx_vec(&mut self, batch_pos_vec: &Vec<(i32, i32)>) -> Vec { + let mut num_geometries = 0; + let mut batch_idx_offset = Vec::with_capacity(self.indexed_batches.len() + 1); + batch_idx_offset.push(0); + for batch in &self.indexed_batches { + num_geometries += batch.batch.num_rows(); + batch_idx_offset.push(num_geometries); + } + + let mut geom_idx_vec = Vec::with_capacity(batch_pos_vec.len()); + self.reservation.grow(geom_idx_vec.allocated_size()); + for (batch_idx, row_idx) in batch_pos_vec { + // Convert (batch_idx, row_idx) to a linear, sequential index + let batch_offset = batch_idx_offset[*batch_idx as usize]; + let prepared_idx = batch_offset + *row_idx as usize; + geom_idx_vec.push(prepared_idx); + } + + geom_idx_vec + } + + /// Finish building and return the completed SpatialIndex. 
+ pub fn finish(mut self) -> Result { + if self.indexed_batches.is_empty() { + return Ok(SpatialIndex::empty( + self.spatial_predicate, + self.schema, + self.options, + AtomicUsize::new(self.probe_threads_count), + self.reservation, + self.memory_pool.clone(), + )); + } + + let evaluator = create_operand_evaluator(&self.spatial_predicate, self.options.clone()); + let num_geoms = self + .indexed_batches + .iter() + .map(|batch| batch.batch.num_rows()) + .sum::(); + + let (rtree, batch_pos_vec) = self.build_rtree()?; + let geom_idx_vec = self.build_geom_idx_vec(&batch_pos_vec); + let visited_left_side = self.build_visited_bitmaps()?; + + let refiner = create_refiner( + self.options.spatial_library, + &self.spatial_predicate, + self.options.clone(), + num_geoms, + self.stats, + ); + let consumer = MemoryConsumer::new("SpatialJoinRefiner"); + let refiner_reservation = consumer.register(&self.memory_pool); + let refiner_reservation = + ConcurrentReservation::try_new(REFINER_RESERVATION_PREALLOC_SIZE, refiner_reservation) + .unwrap(); + + let cache_size = batch_pos_vec.len(); + let knn_components = + KnnComponents::new(cache_size, &self.indexed_batches, self.memory_pool.clone())?; + + Ok(SpatialIndex { + schema: self.schema, + evaluator, + refiner, + refiner_reservation, + rtree, + data_id_to_batch_pos: batch_pos_vec, + indexed_batches: self.indexed_batches, + geom_idx_vec, + visited_left_side, + probe_threads_counter: AtomicUsize::new(self.probe_threads_count), + knn_components, + reservation: self.reservation, + }) + } + + pub async fn add_partitions(&mut self, partitions: Vec) -> Result<()> { + for partition in partitions { + self.add_partition(partition).await?; + } + Ok(()) + } + + pub async fn add_partition(&mut self, mut partition: BuildPartition) -> Result<()> { + let mut stream = partition.build_side_batch_stream; + while let Some(batch) = stream.next().await { + let indexed_batch = batch?; + self.add_batch(indexed_batch); + } + 
self.merge_stats(partition.geo_statistics); + let mem_bytes = partition.reservation.free(); + self.reservation.try_grow(mem_bytes)?; + Ok(()) + } +} diff --git a/rust/sedona-spatial-join/src/lib.rs b/rust/sedona-spatial-join/src/lib.rs index eb43e5787..592ac2e73 100644 --- a/rust/sedona-spatial-join/src/lib.rs +++ b/rust/sedona-spatial-join/src/lib.rs @@ -14,16 +14,16 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. -pub mod concurrent_reservation; + +mod build_index; +mod evaluated_batch; pub mod exec; -pub mod index; -pub mod init_once_array; -pub mod once_fut; +mod index; pub mod operand_evaluator; pub mod optimizer; pub mod refine; pub mod spatial_predicate; -pub mod stream; +mod stream; pub mod utils; pub use exec::SpatialJoinExec; diff --git a/rust/sedona-spatial-join/src/operand_evaluator.rs b/rust/sedona-spatial-join/src/operand_evaluator.rs index 114d4309a..b665ffea2 100644 --- a/rust/sedona-spatial-join/src/operand_evaluator.rs +++ b/rust/sedona-spatial-join/src/operand_evaluator.rs @@ -92,21 +92,14 @@ pub(crate) fn create_operand_evaluator( pub(crate) struct EvaluatedGeometryArray { /// The array of geometries produced by evaluating the geometry expression. pub geometry_array: ArrayRef, - /// The rects of the geometries in the geometry array. Each geometry could be covered by a collection - /// of multiple rects. The first element of the tuple is the index of the geometry in the geometry array. - /// This array is guaranteed to be sorted by the index of the geometry. - pub rects: Vec<(usize, Rect)>, + /// The rects of the geometries in the geometry array. The length of this array is equal to the number of geometries. + /// The rects will be None for empty or null geometries. + pub rects: Vec>>, /// The distance value produced by evaluating the distance expression. pub distance: Option, - /// The array of WKBs of the geometries unwrapped from the geometry array. 
It is a reference to - /// some of the columns of the `geometry_array`. We need to keep it here since the WKB values reference - /// buffers inside the geometry array, but we'll only allow accessing Wkb<'a> where 'a is the lifetime of - /// the GeometryBatchResult to make the interfaces safe. - #[allow(dead_code)] - wkb_array: ArrayRef, - /// WKBs of the geometries in `wkb_array`. The wkb values reference buffers inside the geometry array, + /// WKBs of the geometries in `geometry_array`. The wkb values reference buffers inside the geometry array, /// but we'll only allow accessing Wkb<'a> where 'a is the lifetime of the GeometryBatchResult to make - /// the interfaces safe. The buffers in `wkb_array` are allocated on the heap and won't be moved when + /// the interfaces safe. The buffers in `geometry_array` are allocated on the heap and won't be moved when /// the GeometryBatchResult is moved, so we don't need to worry about pinning. wkbs: Vec>>, } @@ -115,26 +108,31 @@ impl EvaluatedGeometryArray { pub fn try_new(geometry_array: ArrayRef, sedona_type: &SedonaType) -> Result { let num_rows = geometry_array.len(); let mut rect_vec = Vec::with_capacity(num_rows); - let wkb_array = geometry_array.clone(); let mut wkbs = Vec::with_capacity(num_rows); - let mut idx = 0; - wkb_array.iter_as_wkb(sedona_type, num_rows, |wkb_opt| { - if let Some(wkb) = &wkb_opt { + geometry_array.iter_as_wkb(sedona_type, num_rows, |wkb_opt| { + let rect_opt = if let Some(wkb) = &wkb_opt { if let Some(rect) = wkb.bounding_rect() { let min = rect.min(); let max = rect.max(); // f64_box_to_f32 will ensure the resulting `f32` box is no smaller than the `f64` box. 
let (min_x, min_y, max_x, max_y) = f64_box_to_f32(min.x, min.y, max.x, max.y); let rect = Rect::new(coord!(x: min_x, y: min_y), coord!(x: max_x, y: max_y)); - rect_vec.push((idx, rect)); + Some(rect) + } else { + None } - } + } else { + None + }; + rect_vec.push(rect_opt); wkbs.push(wkb_opt); - idx += 1; Ok(()) })?; - // Safety: The wkbs must reference buffers inside the `wkb_array`. + // Safety: The wkbs must reference buffers inside the `geometry_array`. Since the `geometry_array` and + // `wkbs` are both owned by the `EvaluatedGeometryArray`, they have the same lifetime. We'll never + // have a situation where the `EvaluatedGeometryArray` is dropped while the `wkbs` are still in use + // (guaranteed by the scope of the `wkbs` field and lifetime signature of the `wkbs` method). let wkbs = wkbs .into_iter() .map(|wkb| wkb.map(|wkb| unsafe { transmute(wkb) })) @@ -143,7 +141,6 @@ impl EvaluatedGeometryArray { geometry_array, rects: rect_vec, distance: None, - wkb_array, wkbs, }) } @@ -167,8 +164,6 @@ impl EvaluatedGeometryArray { // should be small, so the inaccuracy does not matter too much. let wkb_vec_size = self.wkbs.allocated_size(); - // We do not take wkb_array into consideration, since it is a reference to some of the - // columns of the geometry_array. 
self.geometry_array.get_array_memory_size() + self.rects.allocated_size() + distance_in_mem_size @@ -245,7 +240,10 @@ impl DistanceOperandEvaluator { let distance_columnar_value = distance_columnar_value.cast_to(&DataType::Float64, None)?; match &distance_columnar_value { ColumnarValue::Scalar(ScalarValue::Float64(Some(distance))) => { - result.rects.iter_mut().for_each(|(_, rect)| { + result.rects.iter_mut().for_each(|rect_opt| { + let Some(rect) = rect_opt else { + return; + }; expand_rect_in_place(rect, *distance); }); } @@ -255,9 +253,12 @@ impl DistanceOperandEvaluator { } ColumnarValue::Array(array) => { if let Some(array) = array.as_any().downcast_ref::() { - for (geom_idx, rect) in result.rects.iter_mut() { - if !array.is_null(*geom_idx) { - let dist = array.value(*geom_idx); + for (geom_idx, rect_opt) in result.rects.iter_mut().enumerate() { + if !array.is_null(geom_idx) { + let dist = array.value(geom_idx); + let Some(rect) = rect_opt else { + continue; + }; expand_rect_in_place(rect, dist); } } diff --git a/rust/sedona-spatial-join/src/refine/mod.rs b/rust/sedona-spatial-join/src/refine.rs similarity index 100% rename from rust/sedona-spatial-join/src/refine/mod.rs rename to rust/sedona-spatial-join/src/refine.rs diff --git a/rust/sedona-spatial-join/src/refine/geos.rs b/rust/sedona-spatial-join/src/refine/geos.rs index 9b4bc2980..be6fbf904 100644 --- a/rust/sedona-spatial-join/src/refine/geos.rs +++ b/rust/sedona-spatial-join/src/refine/geos.rs @@ -29,12 +29,12 @@ use wkb::reader::Wkb; use crate::{ index::IndexQueryResult, - init_once_array::InitOnceArray, refine::{ exec_mode_selector::{get_or_update_execution_mode, ExecModeSelector, SelectOptimalMode}, IndexQueryResultRefiner, }, spatial_predicate::{RelationPredicate, SpatialPredicate, SpatialRelationType}, + utils::init_once_array::InitOnceArray, }; /// GEOS-specific optimal mode selector that chooses the best execution mode @@ -283,9 +283,7 @@ impl GeosRefiner { continue; }; if is_newly_created { - 
// TODO: This ia a rough estimate of the memory usage of the prepared geometry and - // may not be accurate. - let prep_geom_size = index_result.wkb.buf().len() * 4; + let prep_geom_size = estimate_prep_geom_in_mem_size(index_result.wkb); self.mem_usage.fetch_add(prep_geom_size, Ordering::Relaxed); } if self.evaluator.evaluate_prepare_build( @@ -321,6 +319,13 @@ impl GeosRefiner { } } +fn estimate_prep_geom_in_mem_size(wkb: &Wkb<'_>) -> usize { + // TODO: This is a rough estimate of the memory usage of the prepared geometry and + // may not be accurate. + // https://github.com/apache/sedona-db/issues/281 + wkb.buf().len() * 4 +} + impl IndexQueryResultRefiner for GeosRefiner { fn refine( &self, diff --git a/rust/sedona-spatial-join/src/refine/tg.rs b/rust/sedona-spatial-join/src/refine/tg.rs index dc42f8b0b..4b1213d6c 100644 --- a/rust/sedona-spatial-join/src/refine/tg.rs +++ b/rust/sedona-spatial-join/src/refine/tg.rs @@ -30,12 +30,12 @@ use wkb::reader::Wkb; use crate::{ index::IndexQueryResult, - init_once_array::InitOnceArray, refine::{ exec_mode_selector::{get_or_update_execution_mode, ExecModeSelector, SelectOptimalMode}, IndexQueryResultRefiner, }, spatial_predicate::{RelationPredicate, SpatialPredicate, SpatialRelationType}, + utils::init_once_array::InitOnceArray, }; /// TG-specific optimal mode selector that chooses the best execution mode diff --git a/rust/sedona-spatial-join/src/stream.rs b/rust/sedona-spatial-join/src/stream.rs index 4fc4a59d3..67b730648 100644 --- a/rust/sedona-spatial-join/src/stream.rs +++ b/rust/sedona-spatial-join/src/stream.rs @@ -33,16 +33,15 @@ use std::collections::HashMap; use std::ops::Range; use std::sync::Arc; +use crate::evaluated_batch::EvaluatedBatch; use crate::index::SpatialIndex; -use crate::once_fut::{OnceAsync, OnceFut}; -use crate::operand_evaluator::{ - create_operand_evaluator, distance_value_at, EvaluatedGeometryArray, OperandEvaluator, -}; +use crate::operand_evaluator::{create_operand_evaluator, 
distance_value_at, OperandEvaluator}; use crate::spatial_predicate::SpatialPredicate; -use crate::utils::{ +use crate::utils::join_utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, get_final_indices_from_bit_map, need_produce_result_in_final, }; +use crate::utils::once_fut::{OnceAsync, OnceFut}; use arrow::array::RecordBatch; use arrow::datatypes::{Schema, SchemaRef}; use sedona_common::option::SpatialJoinOptions; @@ -151,7 +150,7 @@ impl SpatialJoinProbeMetrics { .counter("probe_input_batches", partition), probe_input_rows: MetricBuilder::new(metrics).counter("probe_input_rows", partition), output_batches: MetricBuilder::new(metrics).counter("output_batches", partition), - output_rows: MetricBuilder::new(metrics).counter("output_rows", partition), + output_rows: MetricBuilder::new(metrics).output_rows(partition), join_result_candidates: MetricBuilder::new(metrics) .counter("join_result_candidates", partition), join_result_count: MetricBuilder::new(metrics).counter("join_result_count", partition), @@ -415,8 +414,10 @@ impl SpatialJoinStream { SpatialJoinBatchIterator::new(SpatialJoinBatchIteratorParams { spatial_index: spatial_index.clone(), - probe_batch: probe_batch.clone(), - geom_array, + probe_evaluated_batch: EvaluatedBatch { + batch: probe_batch, + geom_array, + }, join_metrics: self.join_metrics.clone(), max_batch_size: self.target_output_batch_size, probe_side_ordered: self.probe_side_ordered, @@ -456,14 +457,10 @@ struct PartialBuildBatch { pub(crate) struct SpatialJoinBatchIterator { /// The spatial index reference spatial_index: Arc, - /// The probe batch being processed - probe_batch: RecordBatch, - /// The geometry batch result from evaluating the probe batch - geom_array: EvaluatedGeometryArray, + /// The probe side batch being processed + probe_evaluated_batch: EvaluatedBatch, /// Current probe row index being processed current_probe_idx: usize, - /// Current rect index being processed - current_rect_idx: 
usize, /// Join metrics for tracking performance join_metrics: SpatialJoinProbeMetrics, /// Maximum batch size before yielding a result @@ -485,8 +482,7 @@ pub(crate) struct SpatialJoinBatchIterator { /// Parameters for creating a SpatialJoinBatchIterator pub(crate) struct SpatialJoinBatchIteratorParams { pub spatial_index: Arc, - pub probe_batch: RecordBatch, - pub geom_array: EvaluatedGeometryArray, + pub probe_evaluated_batch: EvaluatedBatch, pub join_metrics: SpatialJoinProbeMetrics, pub max_batch_size: usize, pub probe_side_ordered: bool, @@ -498,10 +494,8 @@ impl SpatialJoinBatchIterator { pub(crate) fn new(params: SpatialJoinBatchIteratorParams) -> Result { Ok(Self { spatial_index: params.spatial_index, - probe_batch: params.probe_batch, - geom_array: params.geom_array, + probe_evaluated_batch: params.probe_evaluated_batch, current_probe_idx: 0, - current_rect_idx: 0, join_metrics: params.join_metrics, max_batch_size: params.max_batch_size, probe_side_ordered: params.probe_side_ordered, @@ -524,9 +518,10 @@ impl SpatialJoinBatchIterator { // Process probe rows incrementally until we have enough results or finish let initial_size = self.build_batch_positions.len(); - let wkbs = self.geom_array.wkbs(); - let rects = &self.geom_array.rects; - let distance = &self.geom_array.distance; + let geom_array = &self.probe_evaluated_batch.geom_array; + let wkbs = geom_array.wkbs(); + let rects = &geom_array.rects; + let distance = &geom_array.distance; let num_rows = wkbs.len(); @@ -575,22 +570,11 @@ impl SpatialJoinBatchIterator { self.join_metrics .join_result_count .add(join_result_metrics.count); - - // Skip all rects for this probe index since KNN doesn't use them - while self.current_rect_idx < rects.len() - && rects[self.current_rect_idx].0 == self.current_probe_idx - { - self.current_rect_idx += 1; - } } _ => { // Regular spatial join: process all rects for this probe index - while self.current_rect_idx < rects.len() - && rects[self.current_rect_idx].0 == 
self.current_probe_idx - { - let rect = &rects[self.current_rect_idx].1; - self.current_rect_idx += 1; - + let rect_opt = &rects[self.current_probe_idx]; + if let Some(rect) = rect_opt { let join_result_metrics = self.spatial_index.query( wkb, rect, @@ -676,7 +660,7 @@ impl SpatialJoinBatchIterator { let (build_indices, probe_indices) = match filter { Some(filter) => apply_join_filter_to_indices( &partial_build_batch, - &self.probe_batch, + &self.probe_evaluated_batch.batch, build_indices, probe_indices, filter, @@ -709,7 +693,7 @@ impl SpatialJoinBatchIterator { let result_batch = build_batch_from_indices( schema, &partial_build_batch, - &self.probe_batch, + &self.probe_evaluated_batch.batch, &build_indices, &probe_indices, column_indices, @@ -837,7 +821,6 @@ impl std::fmt::Debug for SpatialJoinBatchIterator { f.debug_struct("SpatialJoinBatchIterator") .field("max_batch_size", &self.max_batch_size) .field("current_probe_idx", &self.current_probe_idx) - .field("current_rect_idx", &self.current_rect_idx) .field("is_complete", &self.is_complete) .field( "build_batch_positions_len", diff --git a/rust/sedona-spatial-join/src/utils.rs b/rust/sedona-spatial-join/src/utils.rs index 83ec18f49..dab1c24d4 100644 --- a/rust/sedona-spatial-join/src/utils.rs +++ b/rust/sedona-spatial-join/src/utils.rs @@ -15,473 +15,7 @@ // specific language governing permissions and limitations // under the License. -/// Most of the code in this module are copied from the `datafusion_physical_plan::joins::utils` module. 
-/// https://github.com/apache/datafusion/blob/48.0.0/datafusion/physical-plan/src/joins/utils.rs -use std::{ops::Range, sync::Arc}; - -use arrow::array::{ - downcast_array, new_null_array, Array, BooleanBufferBuilder, RecordBatch, RecordBatchOptions, - UInt32Builder, UInt64Builder, -}; -use arrow::compute; -use arrow::datatypes::{ArrowNativeType, Schema, UInt32Type, UInt64Type}; -use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, UInt32Array, UInt64Array}; -use datafusion_common::cast::as_boolean_array; -use datafusion_common::{JoinSide, Result}; -use datafusion_expr::JoinType; -use datafusion_physical_expr::Partitioning; -use datafusion_physical_plan::execution_plan::Boundedness; -use datafusion_physical_plan::joins::utils::{ - adjust_right_output_partitioning, ColumnIndex, JoinFilter, -}; -use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; - -/// Some type `join_type` of join need to maintain the matched indices bit map for the left side, and -/// use the bit map to generate the part of result of the join. -/// -/// For example of the `Left` join, in each iteration of right side, can get the matched result, but need -/// to maintain the matched indices bit map to get the unmatched row for the left side. -pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool { - matches!( - join_type, - JoinType::Left - | JoinType::LeftAnti - | JoinType::LeftSemi - | JoinType::LeftMark - | JoinType::Full - ) -} - -/// In the end of join execution, need to use bit map of the matched -/// indices to generate the final left and right indices. -/// -/// For example: -/// -/// 1. left_bit_map: `[true, false, true, true, false]` -/// 2. 
join_type: `Left` -/// -/// The result is: `([1,4], [null, null])` -pub(crate) fn get_final_indices_from_bit_map( - left_bit_map: &BooleanBufferBuilder, - join_type: JoinType, -) -> (UInt64Array, UInt32Array) { - let left_size = left_bit_map.len(); - if join_type == JoinType::LeftMark { - let left_indices = (0..left_size as u64).collect::(); - let right_indices = (0..left_size) - .map(|idx| left_bit_map.get_bit(idx).then_some(0)) - .collect::(); - return (left_indices, right_indices); - } - let left_indices = if join_type == JoinType::LeftSemi { - (0..left_size) - .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64)) - .collect::() - } else { - // just for `Left`, `LeftAnti` and `Full` join - // `LeftAnti`, `Left` and `Full` will produce the unmatched left row finally - (0..left_size) - .filter_map(|idx| (!left_bit_map.get_bit(idx)).then_some(idx as u64)) - .collect::() - }; - // right_indices - // all the element in the right side is None - let mut builder = UInt32Builder::with_capacity(left_indices.len()); - builder.append_nulls(left_indices.len()); - let right_indices = builder.finish(); - (left_indices, right_indices) -} - -pub(crate) fn apply_join_filter_to_indices( - build_input_buffer: &RecordBatch, - probe_batch: &RecordBatch, - build_indices: UInt64Array, - probe_indices: UInt32Array, - filter: &JoinFilter, - build_side: JoinSide, -) -> Result<(UInt64Array, UInt32Array)> { - if build_indices.is_empty() && probe_indices.is_empty() { - return Ok((build_indices, probe_indices)); - }; - - let intermediate_batch = build_batch_from_indices( - filter.schema(), - build_input_buffer, - probe_batch, - &build_indices, - &probe_indices, - filter.column_indices(), - build_side, - )?; - let filter_result = filter - .expression() - .evaluate(&intermediate_batch)? 
- .into_array(intermediate_batch.num_rows())?; - let mask = as_boolean_array(&filter_result)?; - - let left_filtered = compute::filter(&build_indices, mask)?; - let right_filtered = compute::filter(&probe_indices, mask)?; - Ok(( - downcast_array(left_filtered.as_ref()), - downcast_array(right_filtered.as_ref()), - )) -} - -/// Returns a new [RecordBatch] by combining the `left` and `right` according to `indices`. -/// The resulting batch has [Schema] `schema`. -pub(crate) fn build_batch_from_indices( - schema: &Schema, - build_input_buffer: &RecordBatch, - probe_batch: &RecordBatch, - build_indices: &UInt64Array, - probe_indices: &UInt32Array, - column_indices: &[ColumnIndex], - build_side: JoinSide, -) -> Result { - if schema.fields().is_empty() { - let options = RecordBatchOptions::new() - .with_match_field_names(true) - .with_row_count(Some(build_indices.len())); - - return Ok(RecordBatch::try_new_with_options( - Arc::new(schema.clone()), - vec![], - &options, - )?); - } - - // build the columns of the new [RecordBatch]: - // 1. pick whether the column is from the left or right - // 2. based on the pick, `take` items from the different RecordBatches - let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); - - for column_index in column_indices { - let array = if column_index.side == JoinSide::None { - // LeftMark join, the mark column is a true if the indices is not null, otherwise it will be false - Arc::new(compute::is_not_null(probe_indices)?) - } else if column_index.side == build_side { - let array = build_input_buffer.column(column_index.index); - if array.is_empty() || build_indices.null_count() == build_indices.len() { - // Outer join would generate a null index when finding no match at our side. - // Therefore, it's possible we are empty but need to populate an n-length null array, - // where n is the length of the index array. 
- assert_eq!(build_indices.null_count(), build_indices.len()); - new_null_array(array.data_type(), build_indices.len()) - } else { - compute::take(array.as_ref(), build_indices, None)? - } - } else { - let array = probe_batch.column(column_index.index); - if array.is_empty() || probe_indices.null_count() == probe_indices.len() { - assert_eq!(probe_indices.null_count(), probe_indices.len()); - new_null_array(array.data_type(), probe_indices.len()) - } else { - compute::take(array.as_ref(), probe_indices, None)? - } - }; - columns.push(array); - } - Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?) -} - -/// The input is the matched indices for left and right and -/// adjust the indices according to the join type -pub(crate) fn adjust_indices_by_join_type( - left_indices: UInt64Array, - right_indices: UInt32Array, - adjust_range: Range, - join_type: JoinType, - preserve_order_for_right: bool, -) -> Result<(UInt64Array, UInt32Array)> { - match join_type { - JoinType::Inner => { - // matched - Ok((left_indices, right_indices)) - } - JoinType::Left => { - // matched - Ok((left_indices, right_indices)) - // unmatched left row will be produced in the end of loop, and it has been set in the left visited bitmap - } - JoinType::Right => { - // combine the matched and unmatched right result together - append_right_indices( - left_indices, - right_indices, - adjust_range, - preserve_order_for_right, - ) - } - JoinType::Full => append_right_indices(left_indices, right_indices, adjust_range, false), - JoinType::RightSemi => { - // need to remove the duplicated record in the right side - let right_indices = get_semi_indices(adjust_range, &right_indices); - // the left_indices will not be used later for the `right semi` join - Ok((left_indices, right_indices)) - } - JoinType::RightAnti => { - // need to remove the duplicated record in the right side - // get the anti index for the right side - let right_indices = get_anti_indices(adjust_range, &right_indices); - // the 
left_indices will not be used later for the `right anti` join - Ok((left_indices, right_indices)) - } - JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark | JoinType::RightMark => { - // matched or unmatched left row will be produced in the end of loop - // When visit the right batch, we can output the matched left row and don't need to wait the end of loop - Ok(( - UInt64Array::from_iter_values(vec![]), - UInt32Array::from_iter_values(vec![]), - )) - } - } -} - -/// Appends right indices to left indices based on the specified order mode. -/// -/// The function operates in two modes: -/// 1. If `preserve_order_for_right` is true, probe matched and unmatched indices -/// are inserted in order using the `append_probe_indices_in_order()` method. -/// 2. Otherwise, unmatched probe indices are simply appended after matched ones. -/// -/// # Parameters -/// - `left_indices`: UInt64Array of left indices. -/// - `right_indices`: UInt32Array of right indices. -/// - `adjust_range`: Range to adjust the right indices. -/// - `preserve_order_for_right`: Boolean flag to determine the mode of operation. -/// -/// # Returns -/// A tuple of updated `UInt64Array` and `UInt32Array`. -pub(crate) fn append_right_indices( - left_indices: UInt64Array, - right_indices: UInt32Array, - adjust_range: Range, - preserve_order_for_right: bool, -) -> Result<(UInt64Array, UInt32Array)> { - if preserve_order_for_right { - Ok(append_probe_indices_in_order( - left_indices, - right_indices, - adjust_range, - )) - } else { - let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices); - - if right_unmatched_indices.is_empty() { - Ok((left_indices, right_indices)) - } else { - // `into_builder()` can fail here when there is nothing to be filtered and - // left_indices or right_indices has the same reference to the cached indices. - // In that case, we use a slower alternative. 
- - // the new left indices: left_indices + null array - let mut new_left_indices_builder = - left_indices.into_builder().unwrap_or_else(|left_indices| { - let mut builder = UInt64Builder::with_capacity( - left_indices.len() + right_unmatched_indices.len(), - ); - debug_assert_eq!( - left_indices.null_count(), - 0, - "expected left indices to have no nulls" - ); - builder.append_slice(left_indices.values()); - builder - }); - new_left_indices_builder.append_nulls(right_unmatched_indices.len()); - let new_left_indices = UInt64Array::from(new_left_indices_builder.finish()); - - // the new right indices: right_indices + right_unmatched_indices - let mut new_right_indices_builder = - right_indices - .into_builder() - .unwrap_or_else(|right_indices| { - let mut builder = UInt32Builder::with_capacity( - right_indices.len() + right_unmatched_indices.len(), - ); - debug_assert_eq!( - right_indices.null_count(), - 0, - "expected right indices to have no nulls" - ); - builder.append_slice(right_indices.values()); - builder - }); - debug_assert_eq!( - right_unmatched_indices.null_count(), - 0, - "expected right unmatched indices to have no nulls" - ); - new_right_indices_builder.append_slice(right_unmatched_indices.values()); - let new_right_indices = UInt32Array::from(new_right_indices_builder.finish()); - - Ok((new_left_indices, new_right_indices)) - } - } -} - -/// Returns `range` indices which are not present in `input_indices` -pub(crate) fn get_anti_indices( - range: Range, - input_indices: &PrimitiveArray, -) -> PrimitiveArray -where - NativeAdapter: From<::Native>, -{ - let mut bitmap = BooleanBufferBuilder::new(range.len()); - bitmap.append_n(range.len(), false); - input_indices - .iter() - .flatten() - .map(|v| v.as_usize()) - .filter(|v| range.contains(v)) - .for_each(|v| { - bitmap.set_bit(v - range.start, true); - }); - - let offset = range.start; - - // get the anti index - (range) - .filter_map(|idx| (!bitmap.get_bit(idx - 
offset)).then_some(T::Native::from_usize(idx))) - .collect() -} - -/// Returns intersection of `range` and `input_indices` omitting duplicates -pub(crate) fn get_semi_indices( - range: Range, - input_indices: &PrimitiveArray, -) -> PrimitiveArray -where - NativeAdapter: From<::Native>, -{ - let mut bitmap = BooleanBufferBuilder::new(range.len()); - bitmap.append_n(range.len(), false); - input_indices - .iter() - .flatten() - .map(|v| v.as_usize()) - .filter(|v| range.contains(v)) - .for_each(|v| { - bitmap.set_bit(v - range.start, true); - }); - - let offset = range.start; - - // get the semi index - (range) - .filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))) - .collect() -} - -/// Appends probe indices in order by considering the given build indices. -/// -/// This function constructs new build and probe indices by iterating through -/// the provided indices, and appends any missing values between previous and -/// current probe index with a corresponding null build index. -/// -/// # Parameters -/// -/// - `build_indices`: `PrimitiveArray` of `UInt64Type` containing build indices. -/// - `probe_indices`: `PrimitiveArray` of `UInt32Type` containing probe indices. -/// - `range`: The range of indices to consider. -/// -/// # Returns -/// -/// A tuple of two arrays: -/// - A `PrimitiveArray` of `UInt64Type` with the newly constructed build indices. -/// - A `PrimitiveArray` of `UInt32Type` with the newly constructed probe indices. -fn append_probe_indices_in_order( - build_indices: PrimitiveArray, - probe_indices: PrimitiveArray, - range: Range, -) -> (PrimitiveArray, PrimitiveArray) { - // Builders for new indices: - let mut new_build_indices = UInt64Builder::new(); - let mut new_probe_indices = UInt32Builder::new(); - // Set previous index as the start index for the initial loop: - let mut prev_index = range.start as u32; - // Zip the two iterators. 
- debug_assert!(build_indices.len() == probe_indices.len()); - for (build_index, probe_index) in build_indices - .values() - .into_iter() - .zip(probe_indices.values().into_iter()) - { - // Append values between previous and current probe index with null build index: - for value in prev_index..*probe_index { - new_probe_indices.append_value(value); - new_build_indices.append_null(); - } - // Append current indices: - new_probe_indices.append_value(*probe_index); - new_build_indices.append_value(*build_index); - // Set current probe index as previous for the next iteration: - prev_index = probe_index + 1; - } - // Append remaining probe indices after the last valid probe index with null build index. - for value in prev_index..range.end as u32 { - new_probe_indices.append_value(value); - new_build_indices.append_null(); - } - // Build arrays and return: - (new_build_indices.finish(), new_probe_indices.finish()) -} - -pub(crate) fn asymmetric_join_output_partitioning( - left: &Arc, - right: &Arc, - join_type: &JoinType, -) -> Partitioning { - match join_type { - JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( - right.output_partitioning(), - left.schema().fields().len(), - ) - .unwrap_or_else(|_| Partitioning::UnknownPartitioning(1)), - JoinType::RightSemi | JoinType::RightAnti => right.output_partitioning().clone(), - JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti - | JoinType::Full - | JoinType::LeftMark - | JoinType::RightMark => { - Partitioning::UnknownPartitioning(right.output_partitioning().partition_count()) - } - } -} - -/// This function is copied from -/// [`datafusion_physical_plan::physical_plan::execution_plan::boundedness_from_children`]. -/// It is used to determine the boundedness of the join operator based on the boundedness of its children. 
-pub(crate) fn boundedness_from_children<'a>( - children: impl IntoIterator>, -) -> Boundedness { - let mut unbounded_with_finite_mem = false; - - for child in children { - match child.boundedness() { - Boundedness::Unbounded { - requires_infinite_memory: true, - } => { - return Boundedness::Unbounded { - requires_infinite_memory: true, - } - } - Boundedness::Unbounded { - requires_infinite_memory: false, - } => { - unbounded_with_finite_mem = true; - } - Boundedness::Bounded => {} - } - } - - if unbounded_with_finite_mem { - Boundedness::Unbounded { - requires_infinite_memory: false, - } - } else { - Boundedness::Bounded - } -} +pub(crate) mod concurrent_reservation; +pub(crate) mod init_once_array; +pub(crate) mod join_utils; +pub(crate) mod once_fut; diff --git a/rust/sedona-spatial-join/src/concurrent_reservation.rs b/rust/sedona-spatial-join/src/utils/concurrent_reservation.rs similarity index 100% rename from rust/sedona-spatial-join/src/concurrent_reservation.rs rename to rust/sedona-spatial-join/src/utils/concurrent_reservation.rs diff --git a/rust/sedona-spatial-join/src/init_once_array.rs b/rust/sedona-spatial-join/src/utils/init_once_array.rs similarity index 100% rename from rust/sedona-spatial-join/src/init_once_array.rs rename to rust/sedona-spatial-join/src/utils/init_once_array.rs diff --git a/rust/sedona-spatial-join/src/utils/join_utils.rs b/rust/sedona-spatial-join/src/utils/join_utils.rs new file mode 100644 index 000000000..83ec18f49 --- /dev/null +++ b/rust/sedona-spatial-join/src/utils/join_utils.rs @@ -0,0 +1,487 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Most of the code in this module is copied from the `datafusion_physical_plan::joins::utils` module. +/// https://github.com/apache/datafusion/blob/48.0.0/datafusion/physical-plan/src/joins/utils.rs +use std::{ops::Range, sync::Arc}; + +use arrow::array::{ + downcast_array, new_null_array, Array, BooleanBufferBuilder, RecordBatch, RecordBatchOptions, + UInt32Builder, UInt64Builder, +}; +use arrow::compute; +use arrow::datatypes::{ArrowNativeType, Schema, UInt32Type, UInt64Type}; +use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, UInt32Array, UInt64Array}; +use datafusion_common::cast::as_boolean_array; +use datafusion_common::{JoinSide, Result}; +use datafusion_expr::JoinType; +use datafusion_physical_expr::Partitioning; +use datafusion_physical_plan::execution_plan::Boundedness; +use datafusion_physical_plan::joins::utils::{ + adjust_right_output_partitioning, ColumnIndex, JoinFilter, +}; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; + +/// Some join types (`join_type`) need to maintain a bitmap of matched indices for the left side, and +/// use that bitmap to generate part of the join result. +/// +/// For example, in a `Left` join each iteration over the right side yields the matched rows, but the +/// matched-indices bitmap must be maintained to produce the unmatched rows of the left side.
+pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::Left + | JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::LeftMark + | JoinType::Full + ) +} + +/// In the end of join execution, need to use bit map of the matched +/// indices to generate the final left and right indices. +/// +/// For example: +/// +/// 1. left_bit_map: `[true, false, true, true, false]` +/// 2. join_type: `Left` +/// +/// The result is: `([1,4], [null, null])` +pub(crate) fn get_final_indices_from_bit_map( + left_bit_map: &BooleanBufferBuilder, + join_type: JoinType, +) -> (UInt64Array, UInt32Array) { + let left_size = left_bit_map.len(); + if join_type == JoinType::LeftMark { + let left_indices = (0..left_size as u64).collect::(); + let right_indices = (0..left_size) + .map(|idx| left_bit_map.get_bit(idx).then_some(0)) + .collect::(); + return (left_indices, right_indices); + } + let left_indices = if join_type == JoinType::LeftSemi { + (0..left_size) + .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64)) + .collect::() + } else { + // just for `Left`, `LeftAnti` and `Full` join + // `LeftAnti`, `Left` and `Full` will produce the unmatched left row finally + (0..left_size) + .filter_map(|idx| (!left_bit_map.get_bit(idx)).then_some(idx as u64)) + .collect::() + }; + // right_indices + // all the element in the right side is None + let mut builder = UInt32Builder::with_capacity(left_indices.len()); + builder.append_nulls(left_indices.len()); + let right_indices = builder.finish(); + (left_indices, right_indices) +} + +pub(crate) fn apply_join_filter_to_indices( + build_input_buffer: &RecordBatch, + probe_batch: &RecordBatch, + build_indices: UInt64Array, + probe_indices: UInt32Array, + filter: &JoinFilter, + build_side: JoinSide, +) -> Result<(UInt64Array, UInt32Array)> { + if build_indices.is_empty() && probe_indices.is_empty() { + return Ok((build_indices, probe_indices)); + }; + + let intermediate_batch = 
build_batch_from_indices( + filter.schema(), + build_input_buffer, + probe_batch, + &build_indices, + &probe_indices, + filter.column_indices(), + build_side, + )?; + let filter_result = filter + .expression() + .evaluate(&intermediate_batch)? + .into_array(intermediate_batch.num_rows())?; + let mask = as_boolean_array(&filter_result)?; + + let left_filtered = compute::filter(&build_indices, mask)?; + let right_filtered = compute::filter(&probe_indices, mask)?; + Ok(( + downcast_array(left_filtered.as_ref()), + downcast_array(right_filtered.as_ref()), + )) +} + +/// Returns a new [RecordBatch] by combining the `left` and `right` according to `indices`. +/// The resulting batch has [Schema] `schema`. +pub(crate) fn build_batch_from_indices( + schema: &Schema, + build_input_buffer: &RecordBatch, + probe_batch: &RecordBatch, + build_indices: &UInt64Array, + probe_indices: &UInt32Array, + column_indices: &[ColumnIndex], + build_side: JoinSide, +) -> Result { + if schema.fields().is_empty() { + let options = RecordBatchOptions::new() + .with_match_field_names(true) + .with_row_count(Some(build_indices.len())); + + return Ok(RecordBatch::try_new_with_options( + Arc::new(schema.clone()), + vec![], + &options, + )?); + } + + // build the columns of the new [RecordBatch]: + // 1. pick whether the column is from the left or right + // 2. based on the pick, `take` items from the different RecordBatches + let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); + + for column_index in column_indices { + let array = if column_index.side == JoinSide::None { + // LeftMark join, the mark column is a true if the indices is not null, otherwise it will be false + Arc::new(compute::is_not_null(probe_indices)?) + } else if column_index.side == build_side { + let array = build_input_buffer.column(column_index.index); + if array.is_empty() || build_indices.null_count() == build_indices.len() { + // Outer join would generate a null index when finding no match at our side. 
+ // Therefore, it's possible we are empty but need to populate an n-length null array, + // where n is the length of the index array. + assert_eq!(build_indices.null_count(), build_indices.len()); + new_null_array(array.data_type(), build_indices.len()) + } else { + compute::take(array.as_ref(), build_indices, None)? + } + } else { + let array = probe_batch.column(column_index.index); + if array.is_empty() || probe_indices.null_count() == probe_indices.len() { + assert_eq!(probe_indices.null_count(), probe_indices.len()); + new_null_array(array.data_type(), probe_indices.len()) + } else { + compute::take(array.as_ref(), probe_indices, None)? + } + }; + columns.push(array); + } + Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?) +} + +/// The input is the matched indices for left and right and +/// adjust the indices according to the join type +pub(crate) fn adjust_indices_by_join_type( + left_indices: UInt64Array, + right_indices: UInt32Array, + adjust_range: Range, + join_type: JoinType, + preserve_order_for_right: bool, +) -> Result<(UInt64Array, UInt32Array)> { + match join_type { + JoinType::Inner => { + // matched + Ok((left_indices, right_indices)) + } + JoinType::Left => { + // matched + Ok((left_indices, right_indices)) + // unmatched left row will be produced in the end of loop, and it has been set in the left visited bitmap + } + JoinType::Right => { + // combine the matched and unmatched right result together + append_right_indices( + left_indices, + right_indices, + adjust_range, + preserve_order_for_right, + ) + } + JoinType::Full => append_right_indices(left_indices, right_indices, adjust_range, false), + JoinType::RightSemi => { + // need to remove the duplicated record in the right side + let right_indices = get_semi_indices(adjust_range, &right_indices); + // the left_indices will not be used later for the `right semi` join + Ok((left_indices, right_indices)) + } + JoinType::RightAnti => { + // need to remove the duplicated record in 
the right side + // get the anti index for the right side + let right_indices = get_anti_indices(adjust_range, &right_indices); + // the left_indices will not be used later for the `right anti` join + Ok((left_indices, right_indices)) + } + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark | JoinType::RightMark => { + // matched or unmatched left row will be produced in the end of loop + // When visit the right batch, we can output the matched left row and don't need to wait the end of loop + Ok(( + UInt64Array::from_iter_values(vec![]), + UInt32Array::from_iter_values(vec![]), + )) + } + } +} + +/// Appends right indices to left indices based on the specified order mode. +/// +/// The function operates in two modes: +/// 1. If `preserve_order_for_right` is true, probe matched and unmatched indices +/// are inserted in order using the `append_probe_indices_in_order()` method. +/// 2. Otherwise, unmatched probe indices are simply appended after matched ones. +/// +/// # Parameters +/// - `left_indices`: UInt64Array of left indices. +/// - `right_indices`: UInt32Array of right indices. +/// - `adjust_range`: Range to adjust the right indices. +/// - `preserve_order_for_right`: Boolean flag to determine the mode of operation. +/// +/// # Returns +/// A tuple of updated `UInt64Array` and `UInt32Array`. +pub(crate) fn append_right_indices( + left_indices: UInt64Array, + right_indices: UInt32Array, + adjust_range: Range, + preserve_order_for_right: bool, +) -> Result<(UInt64Array, UInt32Array)> { + if preserve_order_for_right { + Ok(append_probe_indices_in_order( + left_indices, + right_indices, + adjust_range, + )) + } else { + let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices); + + if right_unmatched_indices.is_empty() { + Ok((left_indices, right_indices)) + } else { + // `into_builder()` can fail here when there is nothing to be filtered and + // left_indices or right_indices has the same reference to the cached indices. 
+ // In that case, we use a slower alternative. + + // the new left indices: left_indices + null array + let mut new_left_indices_builder = + left_indices.into_builder().unwrap_or_else(|left_indices| { + let mut builder = UInt64Builder::with_capacity( + left_indices.len() + right_unmatched_indices.len(), + ); + debug_assert_eq!( + left_indices.null_count(), + 0, + "expected left indices to have no nulls" + ); + builder.append_slice(left_indices.values()); + builder + }); + new_left_indices_builder.append_nulls(right_unmatched_indices.len()); + let new_left_indices = UInt64Array::from(new_left_indices_builder.finish()); + + // the new right indices: right_indices + right_unmatched_indices + let mut new_right_indices_builder = + right_indices + .into_builder() + .unwrap_or_else(|right_indices| { + let mut builder = UInt32Builder::with_capacity( + right_indices.len() + right_unmatched_indices.len(), + ); + debug_assert_eq!( + right_indices.null_count(), + 0, + "expected right indices to have no nulls" + ); + builder.append_slice(right_indices.values()); + builder + }); + debug_assert_eq!( + right_unmatched_indices.null_count(), + 0, + "expected right unmatched indices to have no nulls" + ); + new_right_indices_builder.append_slice(right_unmatched_indices.values()); + let new_right_indices = UInt32Array::from(new_right_indices_builder.finish()); + + Ok((new_left_indices, new_right_indices)) + } + } +} + +/// Returns `range` indices which are not present in `input_indices` +pub(crate) fn get_anti_indices( + range: Range, + input_indices: &PrimitiveArray, +) -> PrimitiveArray +where + NativeAdapter: From<::Native>, +{ + let mut bitmap = BooleanBufferBuilder::new(range.len()); + bitmap.append_n(range.len(), false); + input_indices + .iter() + .flatten() + .map(|v| v.as_usize()) + .filter(|v| range.contains(v)) + .for_each(|v| { + bitmap.set_bit(v - range.start, true); + }); + + let offset = range.start; + + // get the anti index + (range) + .filter_map(|idx| 
(!bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))) + .collect() +} + +/// Returns intersection of `range` and `input_indices` omitting duplicates +pub(crate) fn get_semi_indices( + range: Range, + input_indices: &PrimitiveArray, +) -> PrimitiveArray +where + NativeAdapter: From<::Native>, +{ + let mut bitmap = BooleanBufferBuilder::new(range.len()); + bitmap.append_n(range.len(), false); + input_indices + .iter() + .flatten() + .map(|v| v.as_usize()) + .filter(|v| range.contains(v)) + .for_each(|v| { + bitmap.set_bit(v - range.start, true); + }); + + let offset = range.start; + + // get the semi index + (range) + .filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))) + .collect() +} + +/// Appends probe indices in order by considering the given build indices. +/// +/// This function constructs new build and probe indices by iterating through +/// the provided indices, and appends any missing values between previous and +/// current probe index with a corresponding null build index. +/// +/// # Parameters +/// +/// - `build_indices`: `PrimitiveArray` of `UInt64Type` containing build indices. +/// - `probe_indices`: `PrimitiveArray` of `UInt32Type` containing probe indices. +/// - `range`: The range of indices to consider. +/// +/// # Returns +/// +/// A tuple of two arrays: +/// - A `PrimitiveArray` of `UInt64Type` with the newly constructed build indices. +/// - A `PrimitiveArray` of `UInt32Type` with the newly constructed probe indices. +fn append_probe_indices_in_order( + build_indices: PrimitiveArray, + probe_indices: PrimitiveArray, + range: Range, +) -> (PrimitiveArray, PrimitiveArray) { + // Builders for new indices: + let mut new_build_indices = UInt64Builder::new(); + let mut new_probe_indices = UInt32Builder::new(); + // Set previous index as the start index for the initial loop: + let mut prev_index = range.start as u32; + // Zip the two iterators. 
+ debug_assert!(build_indices.len() == probe_indices.len()); + for (build_index, probe_index) in build_indices + .values() + .into_iter() + .zip(probe_indices.values().into_iter()) + { + // Append values between previous and current probe index with null build index: + for value in prev_index..*probe_index { + new_probe_indices.append_value(value); + new_build_indices.append_null(); + } + // Append current indices: + new_probe_indices.append_value(*probe_index); + new_build_indices.append_value(*build_index); + // Set current probe index as previous for the next iteration: + prev_index = probe_index + 1; + } + // Append remaining probe indices after the last valid probe index with null build index. + for value in prev_index..range.end as u32 { + new_probe_indices.append_value(value); + new_build_indices.append_null(); + } + // Build arrays and return: + (new_build_indices.finish(), new_probe_indices.finish()) +} + +pub(crate) fn asymmetric_join_output_partitioning( + left: &Arc, + right: &Arc, + join_type: &JoinType, +) -> Partitioning { + match join_type { + JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( + right.output_partitioning(), + left.schema().fields().len(), + ) + .unwrap_or_else(|_| Partitioning::UnknownPartitioning(1)), + JoinType::RightSemi | JoinType::RightAnti => right.output_partitioning().clone(), + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::Full + | JoinType::LeftMark + | JoinType::RightMark => { + Partitioning::UnknownPartitioning(right.output_partitioning().partition_count()) + } + } +} + +/// This function is copied from +/// [`datafusion_physical_plan::physical_plan::execution_plan::boundedness_from_children`]. +/// It is used to determine the boundedness of the join operator based on the boundedness of its children. 
+pub(crate) fn boundedness_from_children<'a>( + children: impl IntoIterator>, +) -> Boundedness { + let mut unbounded_with_finite_mem = false; + + for child in children { + match child.boundedness() { + Boundedness::Unbounded { + requires_infinite_memory: true, + } => { + return Boundedness::Unbounded { + requires_infinite_memory: true, + } + } + Boundedness::Unbounded { + requires_infinite_memory: false, + } => { + unbounded_with_finite_mem = true; + } + Boundedness::Bounded => {} + } + } + + if unbounded_with_finite_mem { + Boundedness::Unbounded { + requires_infinite_memory: false, + } + } else { + Boundedness::Bounded + } +} diff --git a/rust/sedona-spatial-join/src/once_fut.rs b/rust/sedona-spatial-join/src/utils/once_fut.rs similarity index 100% rename from rust/sedona-spatial-join/src/once_fut.rs rename to rust/sedona-spatial-join/src/utils/once_fut.rs