diff --git a/Cargo.lock b/Cargo.lock index 983a74ed3cf03..87c18826096c2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2259,7 +2259,6 @@ dependencies = [ "regex", "sha2", "tokio", - "unicode-segmentation", "uuid", ] diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index 6621706c35c81..35900e16c18ed 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`ParquetOpener`] state machine for opening Parquet files +//! [`ParquetOpener`] and [`ParquetMorselizer`] state machines for opening Parquet files use crate::page_filter::PagePruningAccessPlanFilter; use crate::row_filter::build_projection_read_plan; @@ -26,11 +26,16 @@ use crate::{ }; use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::datatypes::DataType; +use datafusion_common::internal_err; use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener}; +use datafusion_datasource::morsel::{ + Morsel, MorselPlan, MorselPlanner, Morselizer, PendingMorselPlanner, +}; use datafusion_physical_expr::projection::{ProjectionExprs, Projector}; use datafusion_physical_expr::utils::reassign_expr_columns; use datafusion_physical_expr_adapter::replace_columns_with_literals; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; +use std::fmt; use std::future::Future; use std::mem; use std::pin::Pin; @@ -77,12 +82,26 @@ use parquet::bloom_filter::Sbbf; use parquet::errors::ParquetError; use parquet::file::metadata::{PageIndexPolicy, ParquetMetaDataReader}; -/// Entry point for opening a Parquet file +/// Implements [`FileOpener`] for Parquet +#[derive(Clone)] +pub(super) struct ParquetOpener { + pub(super) morselizer: ParquetMorselizer, +} + +impl FileOpener for ParquetOpener { + fn open(&self, partitioned_file: PartitionedFile) -> Result { + let future = 
ParquetOpenFuture::new(&self.morselizer, partitioned_file)?; + Ok(Box::pin(future)) + } +} + +/// Stateless Parquet morselizer implementation. /// /// Reading a Parquet file is a multi-stage process, with multiple CPU-intensive /// steps interspersed with I/O steps. The code in this module implements the steps /// as an explicit state machine -- see [`ParquetOpenState`] for details. -pub(super) struct ParquetOpener { +#[derive(Clone)] +pub(super) struct ParquetMorselizer { /// Execution partition index pub(crate) partition_index: usize, /// Projection to apply on top of the table schema (i.e. can reference partition columns). @@ -137,6 +156,23 @@ pub(super) struct ParquetOpener { pub reverse_row_groups: bool, } +impl fmt::Debug for ParquetMorselizer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ParquetMorselizer") + .field("partition_index", &self.partition_index) + .field("preserve_order", &self.preserve_order) + .field("enable_page_index", &self.enable_page_index) + .field("enable_bloom_filter", &self.enable_bloom_filter) + .finish() + } +} + +impl Morselizer for ParquetMorselizer { + fn plan_file(&self, file: PartitionedFile) -> Result> { + Ok(Box::new(ParquetMorselPlanner::try_new(self, file)?)) + } +} + /// States for [`ParquetOpenFuture`] /// /// These states correspond to the steps required to read and apply various @@ -216,6 +252,27 @@ enum ParquetOpenState { Done, } +impl fmt::Debug for ParquetOpenState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let state = match self { + ParquetOpenState::Start { .. 
} => "Start", + #[cfg(feature = "parquet_encryption")] + ParquetOpenState::LoadEncryption(_) => "LoadEncryption", + ParquetOpenState::PruneFile(_) => "PruneFile", + ParquetOpenState::LoadMetadata(_) => "LoadMetadata", + ParquetOpenState::PrepareFilters(_) => "PrepareFilters", + ParquetOpenState::LoadPageIndex(_) => "LoadPageIndex", + ParquetOpenState::PruneWithStatistics(_) => "PruneWithStatistics", + ParquetOpenState::LoadBloomFilters(_) => "LoadBloomFilters", + ParquetOpenState::PruneWithBloomFilters(_) => "PruneWithBloomFilters", + ParquetOpenState::BuildStream(_) => "BuildStream", + ParquetOpenState::Ready(_) => "Ready", + ParquetOpenState::Done => "Done", + }; + f.write_str(state) + } +} + struct PreparedParquetOpen { partition_index: usize, partitioned_file: PartitionedFile, @@ -290,37 +347,13 @@ struct BloomFiltersLoadedParquetOpen { row_group_bloom_filters: Vec, } -/// Implements state machine described in [`ParquetOpenState`] -struct ParquetOpenFuture { - state: ParquetOpenState, -} - -impl ParquetOpenFuture { - #[cfg(feature = "parquet_encryption")] - fn new(prepared: PreparedParquetOpen, encryption_context: EncryptionContext) -> Self { - Self { - state: ParquetOpenState::Start { - prepared: Box::new(prepared), - encryption_context: Arc::new(encryption_context), - }, - } - } - - #[cfg(not(feature = "parquet_encryption"))] - fn new(prepared: PreparedParquetOpen) -> Self { - Self { - state: ParquetOpenState::Start { - prepared: Box::new(prepared), - }, - } - } -} - impl ParquetOpenState { /// Applies one CPU-only state transition. /// /// `Load*` states do not transition here and are returned unchanged so the /// driver loop can poll their inner futures separately. 
+ /// + /// Implements state machine described in [`ParquetOpenState`] fn transition(self) -> Result { match self { ParquetOpenState::Start { @@ -392,93 +425,208 @@ impl ParquetOpenState { } } +/// Adapter for a [`MorselPlanner`] to the [`FileOpener`] API +/// +/// Compatibility adapter that drives a morsel planner through the +/// [`FileOpener`] API. +struct ParquetOpenFuture { + planner: Option>, + pending_io: Option, + ready_morsels: VecDeque>, +} + +impl ParquetOpenFuture { + fn new( + morselizer: &ParquetMorselizer, + partitioned_file: PartitionedFile, + ) -> Result { + Ok(Self { + planner: Some(morselizer.plan_file(partitioned_file)?), + pending_io: None, + ready_morsels: VecDeque::new(), + }) + } +} + impl Future for ParquetOpenFuture { type Output = Result>>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { - let state = mem::replace(&mut self.state, ParquetOpenState::Done); - let mut state = state.transition()?; - - match state { - #[cfg(feature = "parquet_encryption")] - ParquetOpenState::LoadEncryption(mut future) => { - state = match future.poll_unpin(cx) { - Poll::Ready(result) => ParquetOpenState::PruneFile(result?), - Poll::Pending => { - self.state = ParquetOpenState::LoadEncryption(future); - return Poll::Pending; - } - }; - } - ParquetOpenState::LoadMetadata(mut future) => { - state = match future.poll_unpin(cx) { - Poll::Ready(result) => { - ParquetOpenState::PrepareFilters(Box::new(result?)) - } - Poll::Pending => { - self.state = ParquetOpenState::LoadMetadata(future); - return Poll::Pending; - } - }; - } - ParquetOpenState::LoadPageIndex(mut future) => { - state = match future.poll_unpin(cx) { - Poll::Ready(result) => { - ParquetOpenState::PruneWithStatistics(Box::new(result?)) - } - Poll::Pending => { - self.state = ParquetOpenState::LoadPageIndex(future); - return Poll::Pending; - } - }; - } - ParquetOpenState::LoadBloomFilters(mut future) => { - state = match future.poll_unpin(cx) { - Poll::Ready(result) => { - 
ParquetOpenState::PruneWithBloomFilters(Box::new(result?)) - } - Poll::Pending => { - self.state = ParquetOpenState::LoadBloomFilters(future); - return Poll::Pending; - } - }; - } - ParquetOpenState::Ready(stream) => { - return Poll::Ready(Ok(stream)); - } - ParquetOpenState::Done => { - return Poll::Ready(Ok(futures::stream::empty().boxed())); + // If planner I/O completed, resume with the returned planner. + if let Some(io_future) = self.pending_io.as_mut() { + let maybe_planner = ready!(io_future.poll_unpin(cx)); + // Clear `pending_io` before handling the result so an error + // cannot leave both continuation paths populated. + self.pending_io = None; + if self.planner.is_some() { + return Poll::Ready(internal_err!( + "ParquetOpenFuture does not support concurrent planners" + )); } + self.planner = Some(maybe_planner?); + } + + // If a stream morsel is ready, return it. + if let Some(morsel) = self.ready_morsels.pop_front() { + return Poll::Ready(Ok(morsel.into_stream())); + } - // For all other states, loop again and try to transition - // immediately. All states are explicitly listed here to ensure any - // new states are handled correctly - ParquetOpenState::Start { .. } => {} - ParquetOpenState::PruneFile(_) => {} - ParquetOpenState::PrepareFilters(_) => {} - ParquetOpenState::PruneWithStatistics(_) => {} - ParquetOpenState::PruneWithBloomFilters(_) => {} - ParquetOpenState::BuildStream(_) => {} + // This shim must always own either a planner, a pending planner + // future, or a ready morsel. Reaching this branch means the + // continuation was lost. + let Some(planner) = self.planner.take() else { + return Poll::Ready(internal_err!( + "ParquetOpenFuture polled after completion" + )); }; - self.state = state; + // Planner completed without producing a stream morsel. + // (e.g. all row groups were pruned) + let Some(mut plan) = planner.plan()? 
else { + return Poll::Ready(Ok(futures::stream::empty().boxed())); + }; + + let mut child_planners = plan.take_ready_planners(); + if child_planners.len() > 1 { + return Poll::Ready(internal_err!( + "Parquet FileOpener adapter does not support child morsel planners" + )); + } + self.planner = child_planners.pop(); + + self.ready_morsels = plan.take_morsels().into(); + + if let Some(io_future) = plan.take_pending_planner() { + self.pending_io = Some(io_future); + } } } } -impl FileOpener for ParquetOpener { - fn open(&self, partitioned_file: PartitionedFile) -> Result { - let prepared = self.prepare_open_file(partitioned_file)?; +/// Implements the Morsel API +struct ParquetStreamMorsel { + stream: BoxStream<'static, Result>, +} + +impl ParquetStreamMorsel { + fn new(stream: BoxStream<'static, Result>) -> Self { + Self { stream } + } +} + +impl fmt::Debug for ParquetStreamMorsel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ParquetStreamMorsel") + .finish_non_exhaustive() + } +} + +impl Morsel for ParquetStreamMorsel { + fn into_stream(self: Box) -> BoxStream<'static, Result> { + self.stream + } +} + +/// Per-file planner that owns the current [`ParquetOpenState`]. +struct ParquetMorselPlanner { + /// Ready to perform CPU-only planning work. 
+ state: ParquetOpenState, +} + +impl fmt::Debug for ParquetMorselPlanner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("ParquetMorselPlanner::Ready") + .field(&self.state) + .finish() + } +} + +impl ParquetMorselPlanner { + fn try_new(morselizer: &ParquetMorselizer, file: PartitionedFile) -> Result { + let prepared = morselizer.prepare_open_file(file)?; #[cfg(feature = "parquet_encryption")] - let future = ParquetOpenFuture::new(prepared, self.get_encryption_context()); + let state = ParquetOpenState::Start { + prepared: Box::new(prepared), + encryption_context: Arc::new(morselizer.get_encryption_context()), + }; #[cfg(not(feature = "parquet_encryption"))] - let future = ParquetOpenFuture::new(prepared); - Ok(Box::pin(future)) + let state = ParquetOpenState::Start { + prepared: Box::new(prepared), + }; + Ok(Self { state }) + } + + /// Schedule an I/O future that resolves to the next planner to run. + /// + /// This helper + /// + /// 1. drives one I/O phase to completion + /// 2. wraps the resulting state in a new [`ParquetMorselPlanner`] + /// 3. 
returns a [`MorselPlan`] containing the boxed future for the caller + /// to poll + /// + fn schedule_io(future: F) -> MorselPlan + where + F: Future> + Send + 'static, + { + let io_future = async move { + let next_state = future.await?; + Ok(Box::new(ParquetMorselPlanner { state: next_state }) as _) + }; + MorselPlan::new().with_pending_planner(io_future) + } +} + +impl MorselPlanner for ParquetMorselPlanner { + fn plan(self: Box) -> Result> { + if let ParquetOpenState::Done = self.state { + return Ok(None); + } + + let state = self.state.transition()?; + + match state { + #[cfg(feature = "parquet_encryption")] + ParquetOpenState::LoadEncryption(future) => { + Ok(Some(Self::schedule_io(async move { + Ok(ParquetOpenState::PruneFile(future.await?)) + }))) + } + ParquetOpenState::LoadMetadata(future) => { + Ok(Some(Self::schedule_io(async move { + Ok(ParquetOpenState::PrepareFilters(Box::new(future.await?))) + }))) + } + ParquetOpenState::LoadPageIndex(future) => { + Ok(Some(Self::schedule_io(async move { + Ok(ParquetOpenState::PruneWithStatistics(Box::new( + future.await?, + ))) + }))) + } + ParquetOpenState::LoadBloomFilters(future) => { + Ok(Some(Self::schedule_io(async move { + Ok(ParquetOpenState::PruneWithBloomFilters(Box::new( + future.await?, + ))) + }))) + } + ParquetOpenState::Ready(stream) => { + let morsels: Vec> = + vec![Box::new(ParquetStreamMorsel::new(stream))]; + Ok(Some(MorselPlan::new().with_morsels(morsels))) + } + ParquetOpenState::Done => Ok(None), + cpu_state => Ok(Some( + MorselPlan::new() + .with_planners(vec![Box::new(Self { state: cpu_state })]), + )), + } } } -impl ParquetOpener { +impl ParquetMorselizer { /// Perform the CPU-only setup for opening a parquet file. 
fn prepare_open_file( &self, @@ -1447,7 +1595,7 @@ impl EncryptionContext { } } -impl ParquetOpener { +impl ParquetMorselizer { #[cfg(feature = "parquet_encryption")] fn get_encryption_context(&self) -> EncryptionContext { EncryptionContext::new( @@ -1576,7 +1724,7 @@ fn should_enable_page_index( mod test { use std::sync::Arc; - use super::{ConstantColumns, constant_columns_from_stats}; + use super::{ConstantColumns, ParquetMorselizer, constant_columns_from_stats}; use crate::{DefaultParquetFileReaderFactory, RowGroupAccess, opener::ParquetOpener}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use bytes::{BufMut, BytesMut}; @@ -1731,11 +1879,12 @@ mod test { ProjectionExprs::from_indices(&all_indices, &file_schema) }; - ParquetOpener { + let morselizer = ParquetMorselizer { partition_index: self.partition_index, projection, batch_size: self.batch_size, limit: self.limit, + preserve_order: self.preserve_order, predicate: self.predicate, table_schema, metadata_size_hint: self.metadata_size_hint, @@ -1757,8 +1906,8 @@ mod test { encryption_factory: None, max_predicate_cache_size: self.max_predicate_cache_size, reverse_row_groups: self.reverse_row_groups, - preserve_order: self.preserve_order, - } + }; + ParquetOpener { morselizer } } } diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index 3a64137a2a3f8..1e54e98dfd04b 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -23,8 +23,8 @@ use std::sync::Arc; use crate::DefaultParquetFileReaderFactory; use crate::ParquetFileReaderFactory; -use crate::opener::ParquetOpener; use crate::opener::build_pruning_predicates; +use crate::opener::{ParquetMorselizer, ParquetOpener}; use crate::row_filter::can_expr_be_pushed_down_with_schemas; use datafusion_common::config::ConfigOptions; #[cfg(feature = "parquet_encryption")] @@ -543,32 +543,34 @@ impl FileSource for ParquetSource { .map(|time_unit| 
parse_coerce_int96_string(time_unit.as_str()).unwrap()); let opener = Arc::new(ParquetOpener { - partition_index: partition, - projection: self.projection.clone(), - batch_size: self - .batch_size - .expect("Batch size must set before creating ParquetOpener"), - limit: base_config.limit, - preserve_order: base_config.preserve_order, - predicate: self.predicate.clone(), - table_schema: self.table_schema.clone(), - metadata_size_hint: self.metadata_size_hint, - metrics: self.metrics().clone(), - parquet_file_reader_factory, - pushdown_filters: self.pushdown_filters(), - reorder_filters: self.reorder_filters(), - force_filter_selections: self.force_filter_selections(), - enable_page_index: self.enable_page_index(), - enable_bloom_filter: self.bloom_filter_on_read(), - enable_row_group_stats_pruning: self.table_parquet_options.global.pruning, - coerce_int96, - #[cfg(feature = "parquet_encryption")] - file_decryption_properties, - expr_adapter_factory, - #[cfg(feature = "parquet_encryption")] - encryption_factory: self.get_encryption_factory_with_config(), - max_predicate_cache_size: self.max_predicate_cache_size(), - reverse_row_groups: self.reverse_row_groups, + morselizer: ParquetMorselizer { + partition_index: partition, + projection: self.projection.clone(), + batch_size: self + .batch_size + .expect("Batch size must set before creating ParquetOpener"), + limit: base_config.limit, + preserve_order: base_config.preserve_order, + predicate: self.predicate.clone(), + table_schema: self.table_schema.clone(), + metadata_size_hint: self.metadata_size_hint, + metrics: self.metrics().clone(), + parquet_file_reader_factory, + pushdown_filters: self.pushdown_filters(), + reorder_filters: self.reorder_filters(), + force_filter_selections: self.force_filter_selections(), + enable_page_index: self.enable_page_index(), + enable_bloom_filter: self.bloom_filter_on_read(), + enable_row_group_stats_pruning: self.table_parquet_options.global.pruning, + coerce_int96, + #[cfg(feature = 
"parquet_encryption")] + file_decryption_properties, + expr_adapter_factory, + #[cfg(feature = "parquet_encryption")] + encryption_factory: self.get_encryption_factory_with_config(), + max_predicate_cache_size: self.max_predicate_cache_size(), + reverse_row_groups: self.reverse_row_groups, + }, }); Ok(opener) } diff --git a/datafusion/datasource/src/mod.rs b/datafusion/datasource/src/mod.rs index bcc4627050d4a..a9600271c28ce 100644 --- a/datafusion/datasource/src/mod.rs +++ b/datafusion/datasource/src/mod.rs @@ -38,6 +38,7 @@ pub mod file_scan_config; pub mod file_sink_config; pub mod file_stream; pub mod memory; +pub mod morsel; pub mod projection; pub mod schema_adapter; pub mod sink; diff --git a/datafusion/datasource/src/morsel/mod.rs b/datafusion/datasource/src/morsel/mod.rs new file mode 100644 index 0000000000000..5f200d7794690 --- /dev/null +++ b/datafusion/datasource/src/morsel/mod.rs @@ -0,0 +1,229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Structures for Morsel Driven IO. +//! +//! NOTE: As of DataFusion 54.0.0, these are experimental APIs that may change +//! substantially. +//! +//! Morsel Driven IO is a technique for parallelizing the reading of large files +//! 
by dividing them into smaller "morsels" that are processed independently. +//! +//! It is inspired by the paper [Morsel-Driven Parallelism: A NUMA-Aware Query +//! Evaluation Framework for the Many-Core Age](https://db.in.tum.de/~leis/papers/morsels.pdf). + +use crate::PartitionedFile; +use arrow::array::RecordBatch; +use datafusion_common::Result; +use futures::FutureExt; +use futures::future::BoxFuture; +use futures::stream::BoxStream; +use std::fmt::Debug; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// A Morsel of work ready to resolve to a stream of [`RecordBatch`]es. +/// +/// This represents a single morsel of work that is ready to be processed. It +/// has all data necessary (does not need any I/O) and is ready to be turned +/// into a stream of [`RecordBatch`]es for processing by the execution engine. +pub trait Morsel: Send + Debug { + /// Consume this morsel and produce a stream of [`RecordBatch`]es for processing. + /// + /// Note: This may do CPU work to decode already-loaded data, but should not + /// do any I/O work such as reading from the file. + fn into_stream(self: Box) -> BoxStream<'static, Result>; +} + +/// A Morselizer takes a single [`PartitionedFile`] and creates the initial planner +/// for that file. +/// +/// This is the entry point for morsel driven I/O. +pub trait Morselizer: Send + Sync + Debug { + /// Return the initial [`MorselPlanner`] for this file. + /// + /// Morselizing a file may involve CPU work, such as parsing parquet + /// metadata and evaluating pruning predicates. It should NOT do any I/O + /// work, such as reading from the file. Any needed I/O should be done using + /// [`MorselPlan::with_pending_planner`]. + fn plan_file(&self, file: PartitionedFile) -> Result>; +} + +/// A Morsel Planner is responsible for creating morsels for a given scan. +/// +/// The [`MorselPlanner`] is the unit of I/O. There is only ever a single I/O +/// outstanding for a specific planner. 
DataFusion may run +/// multiple planners in parallel, which corresponds to multiple parallel +/// I/O requests. +/// +/// It is not a Rust `Stream` so that it can explicitly separate CPU bound +/// work from I/O work. +/// +/// The design is similar to `ParquetPushDecoder`: when `plan` is called, it +/// should do CPU work to produce the next morsels or discover the next I/O +/// phase. +/// +/// Best practice is to spawn I/O in a Tokio task on a separate runtime to +/// ensure that CPU work doesn't block or slow down I/O work, but this is not +/// strictly required by the API. +pub trait MorselPlanner: Send + Debug { + /// Attempt to plan morsels. This may involve CPU work, such as parsing + /// parquet metadata and evaluating pruning predicates. + /// + /// It should NOT do any I/O work, such as reading from the file. If I/O is + /// required, the returned [`MorselPlan`] should contain a pending planner + /// future that the caller polls to drive the I/O work to completion. Once + /// that future resolves, it yields a planner ready for work. + /// + /// Note this function is **not async** to make it explicitly clear that if + /// I/O is required, it should be done in the returned `io_future`. + /// + /// Returns `None` if the planner has no more work to do. + /// + /// # Empty Morsel Plans + /// + /// It may return `None`, which means no batches will be read from the file + /// (e.g. due to late-pruning based on statistics). + /// + /// # Output Ordering + /// + /// See the comments on [`MorselPlan`] for the logical output order. + fn plan(self: Box) -> Result>; +} + +/// Return result of [`MorselPlanner::plan`]. +/// +/// # Logical Ordering +/// +/// For plans where the output order of rows is maintained, the output order of +/// a [`MorselPlanner`] is logically defined as follows: +/// 1. All morsels that are directly produced +/// 2. 
Recursively, all morsels produced by the returned `planners` +#[derive(Default)] +pub struct MorselPlan { + /// Morsels ready for CPU work + morsels: Vec>, + /// Planners that are ready for CPU work. + ready_planners: Vec>, + /// A future with planner I/O that resolves to a CPU ready planner. + /// + /// DataFusion will poll this future occasionally to drive the I/O work to + /// completion. Once it resolves, planning continues with the returned + /// planner. + pending_planner: Option, +} + +impl MorselPlan { + /// Create an empty morsel plan. + pub fn new() -> Self { + Self::default() + } + + /// Set the ready morsels. + pub fn with_morsels(mut self, morsels: Vec>) -> Self { + self.morsels = morsels; + self + } + + /// Set the ready child planners. + pub fn with_planners(mut self, planners: Vec>) -> Self { + self.ready_planners = planners; + self + } + + /// Set the pending planner for an I/O phase. + pub fn with_pending_planner(mut self, io_future: F) -> Self + where + F: Future>> + Send + 'static, + { + self.pending_planner = Some(PendingMorselPlanner::new(io_future)); + self + } + + /// Set the pending planner for an I/O phase. + pub fn set_pending_planner(&mut self, io_future: F) + where + F: Future>> + Send + 'static, + { + self.pending_planner = Some(PendingMorselPlanner::new(io_future)); + } + + /// Take the ready morsels. + pub fn take_morsels(&mut self) -> Vec> { + std::mem::take(&mut self.morsels) + } + + /// Take the ready child planners. + pub fn take_ready_planners(&mut self) -> Vec> { + std::mem::take(&mut self.ready_planners) + } + + /// Take the pending I/O future, if any. + pub fn take_pending_planner(&mut self) -> Option { + self.pending_planner.take() + } + + /// Returns `true` if this plan contains an I/O future. + pub fn has_io_future(&self) -> bool { + self.pending_planner.is_some() + } +} + +/// Wrapper for I/O that must complete before planning can continue. 
+pub struct PendingMorselPlanner { + future: BoxFuture<'static, Result>>, +} + +impl PendingMorselPlanner { + /// Create a new pending planner future. + /// + /// Example + /// ``` + /// # use datafusion_common::DataFusionError; + /// # use datafusion_datasource::morsel::{MorselPlanner, PendingMorselPlanner}; + /// let work = async move { + /// let planner: Box = { + /// // Do I/O work here, then return the next planner to run. + /// # unimplemented!(); + /// }; + /// Ok(planner) as Result<_, DataFusionError> + /// }; + /// let pending_io = PendingMorselPlanner::new(work); + /// ``` + pub fn new(future: F) -> Self + where + F: Future>> + Send + 'static, + { + Self { + future: future.boxed(), + } + } + + /// Consume this wrapper and return the underlying future. + pub fn into_future(self) -> BoxFuture<'static, Result>> { + self.future + } +} + +/// Forwards polling to the underlying future. +impl Future for PendingMorselPlanner { + type Output = Result>; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // forward request to inner + self.future.as_mut().poll(cx) + } +} diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 7503c337517ef..02b8e842280bf 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -59,7 +59,7 @@ regex_expressions = ["regex"] # enable string functions string_expressions = ["uuid"] # enable unicode functions -unicode_expressions = ["unicode-segmentation"] +unicode_expressions = [] [lib] name = "datafusion_functions" @@ -87,7 +87,6 @@ num-traits = { workspace = true } rand = { workspace = true } regex = { workspace = true, optional = true } sha2 = { workspace = true, optional = true } -unicode-segmentation = { version = "^1.13.2", optional = true } uuid = { workspace = true, features = ["v4"], optional = true } [dev-dependencies] diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 
b05296721655e..0b67883c17c87 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -154,23 +154,20 @@ impl ScalarUDFImpl for ArrowCastFunc { fn simplify( &self, - mut args: Vec, + args: Vec, info: &SimplifyContext, ) -> Result { // convert this into a real cast - let target_type = data_type_from_args(self.name(), &args)?; - // remove second (type) argument - args.pop().unwrap(); - let arg = args.pop().unwrap(); - - let source_type = info.get_data_type(&arg)?; + let [source_arg, type_arg] = take_function_args(self.name(), args)?; + let target_type = data_type_from_type_arg(self.name(), &type_arg)?; + let source_type = info.get_data_type(&source_arg)?; let new_expr = if source_type == target_type { // the argument's data type is already the correct type - arg + source_arg } else { // Use an actual cast to get the correct type Expr::Cast(datafusion_expr::Cast { - expr: Box::new(arg), + expr: Box::new(source_arg), field: target_type.into_nullable_field_ref(), }) }; @@ -183,10 +180,8 @@ impl ScalarUDFImpl for ArrowCastFunc { } } -/// Returns the requested type from the arguments -pub(crate) fn data_type_from_args(name: &str, args: &[Expr]) -> Result { - let [_, type_arg] = take_function_args(name, args)?; - +/// Returns the requested type from the type argument +pub(crate) fn data_type_from_type_arg(name: &str, type_arg: &Expr) -> Result { let Expr::Literal(ScalarValue::Utf8(Some(val)), _) = type_arg else { return exec_err!( "{name} requires its second argument to be a constant string, got {:?}", diff --git a/datafusion/functions/src/core/arrow_try_cast.rs b/datafusion/functions/src/core/arrow_try_cast.rs index 61a5291c05ed9..d27b29ba5736d 100644 --- a/datafusion/functions/src/core/arrow_try_cast.rs +++ b/datafusion/functions/src/core/arrow_try_cast.rs @@ -31,7 +31,7 @@ use datafusion_expr::{ }; use datafusion_macros::user_doc; -use super::arrow_cast::data_type_from_args; +use 
super::arrow_cast::data_type_from_type_arg; /// Like [`arrow_cast`](super::arrow_cast::ArrowCastFunc) but returns NULL on cast failure instead of erroring. /// @@ -127,20 +127,18 @@ impl ScalarUDFImpl for ArrowTryCastFunc { fn simplify( &self, - mut args: Vec, + args: Vec, info: &SimplifyContext, ) -> Result { - let target_type = data_type_from_args(self.name(), &args)?; - // remove second (type) argument - args.pop().unwrap(); - let arg = args.pop().unwrap(); + let [source_arg, type_arg] = take_function_args(self.name(), args)?; + let target_type = data_type_from_type_arg(self.name(), &type_arg)?; - let source_type = info.get_data_type(&arg)?; + let source_type = info.get_data_type(&source_arg)?; let new_expr = if source_type == target_type { - arg + source_arg } else { Expr::TryCast(datafusion_expr::TryCast { - expr: Box::new(arg), + expr: Box::new(source_arg), field: target_type.into_nullable_field_ref(), }) }; diff --git a/datafusion/functions/src/core/cast_to_type.rs b/datafusion/functions/src/core/cast_to_type.rs new file mode 100644 index 0000000000000..abc7d440e04ba --- /dev/null +++ b/datafusion/functions/src/core/cast_to_type.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
[`CastToTypeFunc`]: Implementation of the `cast_to_type` function + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{Result, internal_err, utils::take_function_args}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; +use datafusion_macros::user_doc; + +/// Casts the first argument to the data type of the second argument. +/// +/// Only the type of the second argument is used; its value is ignored. +/// This is useful in macros or generic SQL where you need to preserve +/// or match types dynamically. +/// +/// For example: +/// ```sql +/// select cast_to_type('42', NULL::INTEGER); +/// ``` +#[user_doc( + doc_section(label = "Other Functions"), + description = "Casts the first argument to the data type of the second argument. Only the type of the second argument is used; its value is ignored.", + syntax_example = "cast_to_type(expression, reference)", + sql_example = r#"```sql +> select cast_to_type('42', NULL::INTEGER) as a; ++----+ +| a | ++----+ +| 42 | ++----+ + +> select cast_to_type(1 + 2, NULL::DOUBLE) as b; ++-----+ +| b | ++-----+ +| 3.0 | ++-----+ +```"#, + argument( + name = "expression", + description = "The expression to cast. It can be a constant, column, or function, and any combination of operators." + ), + argument( + name = "reference", + description = "Reference expression whose data type determines the target cast type. The value is ignored." 
+ ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct CastToTypeFunc { + signature: Signature, +} + +impl Default for CastToTypeFunc { + fn default() -> Self { + Self::new() + } +} + +impl CastToTypeFunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_exact(TypeSignatureClass::Any), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for CastToTypeFunc { + fn name(&self) -> &str { + "cast_to_type" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be called instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let [source_field, reference_field] = + take_function_args(self.name(), args.arg_fields)?; + let target_type = reference_field.data_type().clone(); + // Nullability is inherited only from the first argument (the value + // being cast). The second argument is used solely for its type, so + // its own nullability is irrelevant. The one exception is when the + // target type is Null – that type is inherently nullable. + let nullable = source_field.is_nullable() || target_type == DataType::Null; + Ok(Field::new(self.name(), target_type, nullable).into()) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("cast_to_type should have been simplified to cast") + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [source_arg, type_arg] = take_function_args(self.name(), args)?; + let target_type = info.get_data_type(&type_arg)?; + let source_type = info.get_data_type(&source_arg)?; + let new_expr = if source_type == target_type { + // the argument's data type is already the correct type + source_arg + } else { + let nullable = info.nullable(&source_arg)? 
|| target_type == DataType::Null; + // Use an actual cast to get the correct type + Expr::Cast(datafusion_expr::Cast { + expr: Box::new(source_arg), + field: Field::new("", target_type, nullable).into(), + }) + }; + Ok(ExprSimplifyResult::Simplified(new_expr)) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index e8737612a1dcf..d3c48573667c9 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -24,6 +24,7 @@ pub mod arrow_cast; pub mod arrow_metadata; pub mod arrow_try_cast; pub mod arrowtypeof; +pub mod cast_to_type; pub mod coalesce; pub mod expr_ext; pub mod getfield; @@ -37,6 +38,7 @@ pub mod nvl2; pub mod overlay; pub mod planner; pub mod r#struct; +pub mod try_cast_to_type; pub mod union_extract; pub mod union_tag; pub mod version; @@ -44,6 +46,8 @@ pub mod version; // create UDFs make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast); make_udf_function!(arrow_try_cast::ArrowTryCastFunc, arrow_try_cast); +make_udf_function!(cast_to_type::CastToTypeFunc, cast_to_type); +make_udf_function!(try_cast_to_type::TryCastToTypeFunc, try_cast_to_type); make_udf_function!(nullif::NullIfFunc, nullif); make_udf_function!(nvl::NVLFunc, nvl); make_udf_function!(nvl2::NVL2Func, nvl2); @@ -75,6 +79,14 @@ pub mod expr_fn { arrow_try_cast, "Casts a value to a specific Arrow data type, returning NULL if the cast fails", arg1 arg2 + ),( + cast_to_type, + "Casts the first argument to the data type of the second argument", + arg1 arg2 + ),( + try_cast_to_type, + "Casts the first argument to the data type of the second argument, returning NULL on failure", + arg1 arg2 ),( nvl, "Returns value2 if value1 is NULL; otherwise it returns value1", @@ -147,6 +159,8 @@ pub fn functions() -> Vec> { nullif(), arrow_cast(), arrow_try_cast(), + cast_to_type(), + try_cast_to_type(), arrow_metadata(), nvl(), nvl2(), diff --git 
a/datafusion/functions/src/core/try_cast_to_type.rs b/datafusion/functions/src/core/try_cast_to_type.rs new file mode 100644 index 0000000000000..4c5af4cc6d228 --- /dev/null +++ b/datafusion/functions/src/core/try_cast_to_type.rs @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`TryCastToTypeFunc`]: Implementation of the `try_cast_to_type` function + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{ + Result, datatype::DataTypeExt, internal_err, utils::take_function_args, +}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; +use datafusion_macros::user_doc; + +/// Like [`cast_to_type`](super::cast_to_type::CastToTypeFunc) but returns NULL +/// on cast failure instead of erroring. +/// +/// This is implemented by simplifying `try_cast_to_type(expr, ref)` into +/// `Expr::TryCast` during optimization. +#[user_doc( + doc_section(label = "Other Functions"), + description = "Casts the first argument to the data type of the second argument, returning NULL if the cast fails. 
Only the type of the second argument is used; its value is ignored.", + syntax_example = "try_cast_to_type(expression, reference)", + sql_example = r#"```sql +> select try_cast_to_type('123', NULL::INTEGER) as a, + try_cast_to_type('not_a_number', NULL::INTEGER) as b; + ++-----+------+ +| a | b | ++-----+------+ +| 123 | NULL | ++-----+------+ +```"#, + argument( + name = "expression", + description = "The expression to cast. It can be a constant, column, or function, and any combination of operators." + ), + argument( + name = "reference", + description = "Reference expression whose data type determines the target cast type. The value is ignored." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct TryCastToTypeFunc { + signature: Signature, +} + +impl Default for TryCastToTypeFunc { + fn default() -> Self { + Self::new() + } +} + +impl TryCastToTypeFunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_exact(TypeSignatureClass::Any), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TryCastToTypeFunc { + fn name(&self) -> &str { + "try_cast_to_type" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be called instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + // TryCast can always return NULL (on cast failure), so always nullable + let [_, reference_field] = take_function_args(self.name(), args.arg_fields)?; + let target_type = reference_field.data_type().clone(); + Ok(Field::new(self.name(), target_type, true).into()) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("try_cast_to_type should have been simplified to try_cast") + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [source_arg, type_arg] = 
take_function_args(self.name(), args)?; + let target_type = info.get_data_type(&type_arg)?; + let source_type = info.get_data_type(&source_arg)?; + let new_expr = if source_type == target_type { + source_arg + } else { + Expr::TryCast(datafusion_expr::TryCast { + expr: Box::new(source_arg), + field: target_type.into_nullable_field_ref(), + }) + }; + Ok(ExprSimplifyResult::Simplified(new_expr)) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} diff --git a/datafusion/functions/src/unicode/common.rs b/datafusion/functions/src/unicode/common.rs index 002776e6c6538..0158325e98a19 100644 --- a/datafusion/functions/src/unicode/common.rs +++ b/datafusion/functions/src/unicode/common.rs @@ -78,6 +78,39 @@ impl LeftRightSlicer for RightSlicer { } } +/// Returns the byte offset of the `n`th codepoint in `string`, +/// or `string.len()` if the string has fewer than `n` codepoints. +#[inline] +pub(crate) fn byte_offset_of_char(string: &str, n: usize) -> usize { + string + .char_indices() + .nth(n) + .map_or(string.len(), |(i, _)| i) +} + +/// If `string` has more than `n` codepoints, returns the byte offset of +/// the `n`-th codepoint boundary. Otherwise returns the total codepoint count. +#[inline] +pub(crate) fn char_count_or_boundary(string: &str, n: usize) -> StringCharLen { + let mut count = 0; + for (byte_idx, _) in string.char_indices() { + if count == n { + return StringCharLen::ByteOffset(byte_idx); + } + count += 1; + } + StringCharLen::CharCount(count) +} + +/// Result of [`char_count_or_boundary`]. +pub(crate) enum StringCharLen { + /// The string has more than `n` codepoints; contains the byte offset + /// at the `n`-th codepoint boundary. + ByteOffset(usize), + /// The string has `n` or fewer codepoints; contains the exact count. 
+ CharCount(usize), +} + /// Calculate the byte length of the substring of `n` chars from string `string` #[inline] fn left_right_byte_length(string: &str, n: i64) -> usize { @@ -88,11 +121,9 @@ fn left_right_byte_length(string: &str, n: i64) -> usize { .map(|(index, _)| index) .unwrap_or(0), Ordering::Equal => 0, - Ordering::Greater => string - .char_indices() - .nth(n.unsigned_abs().min(usize::MAX as u64) as usize) - .map(|(index, _)| index) - .unwrap_or(string.len()), + Ordering::Greater => { + byte_offset_of_char(string, n.unsigned_abs().min(usize::MAX as u64) as usize) + } } } diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index d7487c385e84f..d27bc8633e730 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -24,7 +24,6 @@ use arrow::array::{ OffsetSizeTrait, StringArrayType, StringViewArray, }; use arrow::datatypes::DataType; -use unicode_segmentation::UnicodeSegmentation; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::as_int64_array; @@ -178,7 +177,9 @@ impl ScalarUDFImpl for LPadFunc { } } -use super::common::{try_as_scalar_i64, try_as_scalar_str}; +use super::common::{ + StringCharLen, char_count_or_boundary, try_as_scalar_i64, try_as_scalar_str, +}; /// Optimized lpad for constant target_len and fill arguments. 
fn lpad_scalar_args<'a, V: StringArrayType<'a> + Copy, T: OffsetSizeTrait>( @@ -270,27 +271,22 @@ fn lpad_scalar_unicode<'a, V: StringArrayType<'a> + Copy, T: OffsetSizeTrait>( let data_capacity = string_array.len().saturating_mul(target_len * 4); let mut builder = GenericStringBuilder::::with_capacity(string_array.len(), data_capacity); - let mut graphemes_buf = Vec::new(); for maybe_string in string_array.iter() { match maybe_string { - Some(string) => { - graphemes_buf.clear(); - graphemes_buf.extend(string.graphemes(true)); - - if target_len < graphemes_buf.len() { - let end: usize = - graphemes_buf[..target_len].iter().map(|g| g.len()).sum(); - builder.append_value(&string[..end]); - } else if fill_chars.is_empty() { - builder.append_value(string); - } else { - let pad_chars = target_len - graphemes_buf.len(); - let pad_bytes = char_byte_offsets[pad_chars]; - builder.write_str(&padding_buf[..pad_bytes])?; + Some(string) => match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + if !fill_chars.is_empty() { + let pad_chars = target_len - char_count; + let pad_bytes = char_byte_offsets[pad_chars]; + builder.write_str(&padding_buf[..pad_bytes])?; + } builder.append_value(string); } - } + }, None => builder.append_null(), } } @@ -378,7 +374,6 @@ where { let array = if let Some(fill_array) = fill_array { let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - let mut graphemes_buf = Vec::new(); let mut fill_chars_buf = Vec::new(); for ((string, target_len), fill) in string_array @@ -407,8 +402,7 @@ where } if string.is_ascii() && fill.is_ascii() { - // ASCII fast path: byte length == character length, - // so we skip expensive grapheme segmentation. + // ASCII fast path: byte length == character length. 
let str_len = string.len(); if target_len < str_len { builder.append_value(&string[..target_len]); @@ -428,26 +422,24 @@ where builder.append_value(string); } } else { - // Reuse buffers by clearing and refilling - graphemes_buf.clear(); - graphemes_buf.extend(string.graphemes(true)); - fill_chars_buf.clear(); fill_chars_buf.extend(fill.chars()); - if target_len < graphemes_buf.len() { - let end: usize = - graphemes_buf[..target_len].iter().map(|g| g.len()).sum(); - builder.append_value(&string[..end]); - } else if fill_chars_buf.is_empty() { - builder.append_value(string); - } else { - for l in 0..target_len - graphemes_buf.len() { - let c = - *fill_chars_buf.get(l % fill_chars_buf.len()).unwrap(); - builder.write_char(c)?; + match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + if !fill_chars_buf.is_empty() { + for l in 0..target_len - char_count { + let c = *fill_chars_buf + .get(l % fill_chars_buf.len()) + .unwrap(); + builder.write_char(c)?; + } + } + builder.append_value(string); } - builder.append_value(string); } } } else { @@ -458,7 +450,6 @@ where builder.finish() } else { let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - let mut graphemes_buf = Vec::new(); for (string, target_len) in string_array.iter().zip(length_array.iter()) { if let (Some(string), Some(target_len)) = (string, target_len) { @@ -491,19 +482,16 @@ where builder.append_value(string); } } else { - // Reuse buffer by clearing and refilling - graphemes_buf.clear(); - graphemes_buf.extend(string.graphemes(true)); - - if target_len < graphemes_buf.len() { - let end: usize = - graphemes_buf[..target_len].iter().map(|g| g.len()).sum(); - builder.append_value(&string[..end]); - } else { - for _ in 0..(target_len - graphemes_buf.len()) { - builder.write_str(" ")?; + match char_count_or_boundary(string, target_len) { + 
StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + for _ in 0..(target_len - char_count) { + builder.write_str(" ")?; + } + builder.append_value(string); } - builder.append_value(string); } } } else { diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index 44ce4640422d6..b3e14f93526ab 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -24,7 +24,6 @@ use arrow::array::{ OffsetSizeTrait, StringArrayType, StringViewArray, }; use arrow::datatypes::DataType; -use unicode_segmentation::UnicodeSegmentation; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::as_int64_array; @@ -178,7 +177,9 @@ impl ScalarUDFImpl for RPadFunc { } } -use super::common::{try_as_scalar_i64, try_as_scalar_str}; +use super::common::{ + StringCharLen, char_count_or_boundary, try_as_scalar_i64, try_as_scalar_str, +}; /// Optimized rpad for constant target_len and fill arguments. 
fn rpad_scalar_args<'a, V: StringArrayType<'a> + Copy, T: OffsetSizeTrait>( @@ -271,28 +272,23 @@ fn rpad_scalar_unicode<'a, V: StringArrayType<'a> + Copy, T: OffsetSizeTrait>( let data_capacity = string_array.len().saturating_mul(target_len * 4); let mut builder = GenericStringBuilder::::with_capacity(string_array.len(), data_capacity); - let mut graphemes_buf = Vec::new(); for maybe_string in string_array.iter() { match maybe_string { - Some(string) => { - graphemes_buf.clear(); - graphemes_buf.extend(string.graphemes(true)); - - if target_len < graphemes_buf.len() { - let end: usize = - graphemes_buf[..target_len].iter().map(|g| g.len()).sum(); - builder.append_value(&string[..end]); - } else if fill_chars.is_empty() { - builder.append_value(string); - } else { - let pad_chars = target_len - graphemes_buf.len(); - let pad_bytes = char_byte_offsets[pad_chars]; + Some(string) => match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { builder.write_str(string)?; - builder.write_str(&padding_buf[..pad_bytes])?; + if !fill_chars.is_empty() { + let pad_chars = target_len - char_count; + let pad_bytes = char_byte_offsets[pad_chars]; + builder.write_str(&padding_buf[..pad_bytes])?; + } builder.append_value(""); } - } + }, None => builder.append_null(), } } @@ -377,7 +373,6 @@ where { let array = if let Some(fill_array) = fill_array { let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - let mut graphemes_buf = Vec::new(); let mut fill_chars_buf = Vec::new(); for ((string, target_len), fill) in string_array @@ -406,8 +401,7 @@ where } if string.is_ascii() && fill.is_ascii() { - // ASCII fast path: byte length == character length, - // so we skip expensive grapheme segmentation. + // ASCII fast path: byte length == character length. 
let str_len = string.len(); if target_len < str_len { builder.append_value(&string[..target_len]); @@ -428,26 +422,25 @@ where builder.append_value(""); } } else { - graphemes_buf.clear(); - graphemes_buf.extend(string.graphemes(true)); - fill_chars_buf.clear(); fill_chars_buf.extend(fill.chars()); - if target_len < graphemes_buf.len() { - let end: usize = - graphemes_buf[..target_len].iter().map(|g| g.len()).sum(); - builder.append_value(&string[..end]); - } else if fill_chars_buf.is_empty() { - builder.append_value(string); - } else { - builder.write_str(string)?; - for l in 0..target_len - graphemes_buf.len() { - let c = - *fill_chars_buf.get(l % fill_chars_buf.len()).unwrap(); - builder.write_char(c)?; + match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + builder.write_str(string)?; + if !fill_chars_buf.is_empty() { + for l in 0..target_len - char_count { + let c = *fill_chars_buf + .get(l % fill_chars_buf.len()) + .unwrap(); + builder.write_char(c)?; + } + } + builder.append_value(""); } - builder.append_value(""); } } } else { @@ -458,7 +451,6 @@ where builder.finish() } else { let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - let mut graphemes_buf = Vec::new(); for (string, target_len) in string_array.iter().zip(length_array.iter()) { if let (Some(string), Some(target_len)) = (string, target_len) { @@ -492,19 +484,17 @@ where builder.append_value(""); } } else { - graphemes_buf.clear(); - graphemes_buf.extend(string.graphemes(true)); - - if target_len < graphemes_buf.len() { - let end: usize = - graphemes_buf[..target_len].iter().map(|g| g.len()).sum(); - builder.append_value(&string[..end]); - } else { - builder.write_str(string)?; - for _ in 0..(target_len - graphemes_buf.len()) { - builder.write_str(" ")?; + match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + 
builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + builder.write_str(string)?; + for _ in 0..(target_len - char_count) { + builder.write_str(" ")?; + } + builder.append_value(""); } - builder.append_value(""); } } } else { diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index 5f95c095a644e..29dc660b86f62 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -21,7 +21,6 @@ use arrow::array::{ }; use arrow::datatypes::DataType; use datafusion_common::HashMap; -use unicode_segmentation::UnicodeSegmentation; use crate::utils::make_scalar_function; use datafusion_common::{Result, exec_err}; @@ -97,11 +96,10 @@ impl ScalarUDFImpl for TranslateFunc { try_as_scalar_str(&args.args[1]), try_as_scalar_str(&args.args[2]), ) { - let to_graphemes: Vec<&str> = to_str.graphemes(true).collect(); + let to_chars: Vec = to_str.chars().collect(); - let mut from_map: HashMap<&str, usize> = HashMap::new(); - for (index, c) in from_str.graphemes(true).enumerate() { - // Ignore characters that already exist in from_map + let mut from_map: HashMap = HashMap::new(); + for (index, c) in from_str.chars().enumerate() { from_map.entry(c).or_insert(index); } @@ -117,7 +115,7 @@ impl ScalarUDFImpl for TranslateFunc { translate_with_map( arr, &from_map, - &to_graphemes, + &to_chars, ascii_table.as_ref(), builder, ) @@ -129,7 +127,7 @@ impl ScalarUDFImpl for TranslateFunc { translate_with_map( arr, &from_map, - &to_graphemes, + &to_chars, ascii_table.as_ref(), builder, ) @@ -141,7 +139,7 @@ impl ScalarUDFImpl for TranslateFunc { translate_with_map( arr, &from_map, - &to_graphemes, + &to_chars, ascii_table.as_ref(), builder, ) @@ -215,48 +213,27 @@ where let from_array_iter = ArrayIter::new(from_array); let to_array_iter = ArrayIter::new(to_array); - // Reusable buffers to avoid allocating for each row - let mut from_map: HashMap<&str, usize> = 
HashMap::new(); - let mut from_graphemes: Vec<&str> = Vec::new(); - let mut to_graphemes: Vec<&str> = Vec::new(); - let mut string_graphemes: Vec<&str> = Vec::new(); - let mut result_graphemes: Vec<&str> = Vec::new(); + let mut from_map: HashMap = HashMap::new(); + let mut to_chars: Vec = Vec::new(); + let mut result_buf = String::new(); for ((string, from), to) in string_array_iter.zip(from_array_iter).zip(to_array_iter) { match (string, from, to) { (Some(string), Some(from), Some(to)) => { - // Clear and reuse buffers from_map.clear(); - from_graphemes.clear(); - to_graphemes.clear(); - string_graphemes.clear(); - result_graphemes.clear(); - - // Build from_map using reusable buffer - from_graphemes.extend(from.graphemes(true)); - for (index, c) in from_graphemes.iter().enumerate() { - // Ignore characters that already exist in from_map - from_map.entry(*c).or_insert(index); - } + to_chars.clear(); + result_buf.clear(); - // Build to_graphemes - to_graphemes.extend(to.graphemes(true)); - - // Process string and build result - string_graphemes.extend(string.graphemes(true)); - for c in &string_graphemes { - match from_map.get(*c) { - Some(n) => { - if let Some(replacement) = to_graphemes.get(*n) { - result_graphemes.push(*replacement); - } - } - None => result_graphemes.push(*c), - } + for (index, c) in from.chars().enumerate() { + from_map.entry(c).or_insert(index); } - builder.append_value(&result_graphemes.concat()); + to_chars.extend(to.chars()); + + translate_char_by_char(string, &from_map, &to_chars, &mut result_buf); + + builder.append_value(&result_buf); } _ => builder.append_null(), } @@ -265,6 +242,27 @@ where Ok(builder.finish()) } +/// Translate `input` character-by-character using `from_map` and `to_chars`, +/// appending the result to `buf`. 
+#[inline] +fn translate_char_by_char( + input: &str, + from_map: &HashMap, + to_chars: &[char], + buf: &mut String, +) { + for c in input.chars() { + match from_map.get(&c) { + Some(n) => { + if let Some(&replacement) = to_chars.get(*n) { + buf.push(replacement); + } + } + None => buf.push(c), + } + } +} + /// Sentinel value in the ASCII translate table indicating the character should /// be deleted (the `from` character has no corresponding `to` character). Any /// value > 127 works since valid ASCII is 0–127. @@ -301,11 +299,11 @@ fn build_ascii_translate_table(from: &str, to: &str) -> Option<[u8; 128]> { /// Optimized translate for constant `from` and `to` arguments: uses a pre-built /// translation map instead of rebuilding it for every row. When an ASCII byte /// lookup table is provided, ASCII input rows use the lookup table; non-ASCII -/// inputs fallback to using the map. +/// inputs fall back to the char-based map. fn translate_with_map<'a, V, O>( string_array: V, - from_map: &HashMap<&str, usize>, - to_graphemes: &[&str], + from_map: &HashMap, + to_chars: &[char], ascii_table: Option<&[u8; 128]>, mut builder: O, ) -> Result @@ -313,7 +311,7 @@ where V: ArrayAccessor, O: StringLikeArrayBuilder, { - let mut result_graphemes: Vec<&str> = Vec::new(); + let mut result_buf = String::new(); let mut ascii_buf: Vec = Vec::new(); for string in ArrayIter::new(string_array) { @@ -335,21 +333,9 @@ where std::str::from_utf8_unchecked(&ascii_buf) }); } else { - // Slow path: grapheme-based translation - result_graphemes.clear(); - - for c in s.graphemes(true) { - match from_map.get(c) { - Some(n) => { - if let Some(replacement) = to_graphemes.get(*n) { - result_graphemes.push(*replacement); - } - } - None => result_graphemes.push(c), - } - } - - builder.append_value(&result_graphemes.concat()); + result_buf.clear(); + translate_char_by_char(s, from_map, to_chars, &mut result_buf); + builder.append_value(&result_buf); } } None => builder.append_null(), @@ -445,7 +431,7 
@@ mod tests { StringArray ); // Non-ASCII input with ASCII scalar from/to: exercises the - // grapheme fallback within translate_with_map. + // char-based fallback within translate_with_map. test_function!( TranslateFunc::new(), vec![ diff --git a/datafusion/sqllogictest/test_files/cast_to_type.slt b/datafusion/sqllogictest/test_files/cast_to_type.slt new file mode 100644 index 0000000000000..128846c0f5157 --- /dev/null +++ b/datafusion/sqllogictest/test_files/cast_to_type.slt @@ -0,0 +1,347 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +####### +## Tests for cast_to_type function +####### + +# Basic string to integer cast +query I +SELECT cast_to_type('42', 1::INTEGER); +---- +42 + +# String to double cast +query R +SELECT cast_to_type('3.14', 1.0::DOUBLE); +---- +3.14 + +# Integer to string cast +query T +SELECT cast_to_type(42, 'a'::VARCHAR); +---- +42 + +# Integer to double cast +query R +SELECT cast_to_type(42, 0.0::DOUBLE); +---- +42 + +# Same-type is a no-op +query I +SELECT cast_to_type(42, 0::INTEGER); +---- +42 + +# Second argument is a typed NULL double +query R +SELECT cast_to_type('3.14', NULL::DOUBLE); +---- +3.14 + +# Second argument is a typed NULL integer +query I +SELECT cast_to_type(42, NULL::INTEGER); +---- +42 + +# Second argument is a typed NULL string +query T +SELECT cast_to_type('42', NULL::VARCHAR); +---- +42 + +# NULL first argument +query I +SELECT cast_to_type(NULL, 0::INTEGER); +---- +NULL + +# CASE expression as first argument +query I +SELECT cast_to_type(CASE WHEN true THEN '1' ELSE '2' END, NULL::INTEGER); +---- +1 + +# Arithmetic expression as first argument +query R +SELECT cast_to_type(1 + 2, NULL::DOUBLE); +---- +3 + +# Nested cast_to_type +query T +SELECT cast_to_type(cast_to_type('3.14', NULL::DOUBLE), NULL::VARCHAR); +---- +3.14 + +# Subquery as second argument +query I +SELECT cast_to_type('42', (SELECT NULL::INTEGER)); +---- +42 + +# Column reference as second argument +statement ok +CREATE TABLE t1 (int_col INTEGER, text_col VARCHAR, double_col DOUBLE); + +statement ok +INSERT INTO t1 VALUES (1, 'hello', 3.14), (2, 'world', 2.72); + +query I +SELECT cast_to_type('99', int_col) FROM t1 LIMIT 1; +---- +99 + +query T +SELECT cast_to_type(123, text_col) FROM t1 LIMIT 1; +---- +123 + +query R +SELECT cast_to_type('1.5', double_col) FROM t1 LIMIT 1; +---- +1.5 + +# Case statement as second argument +query I +SELECT cast_to_type('42', CASE WHEN random() < 2 THEN 1 ELSE 0 END); +---- +42 + +# Use with column values as first argument +query R +SELECT 
cast_to_type(int_col, 1.0::DOUBLE) FROM t1; +---- +1 +2 + +# Cast column to match another column's type +query T +SELECT cast_to_type(int_col, text_col) FROM t1; +---- +1 +2 + +# Boolean cast +query B +SELECT cast_to_type(1, NULL::BOOLEAN); +---- +true + +# String to date cast +query D +SELECT cast_to_type('2024-01-15', NULL::DATE); +---- +2024-01-15 + +# Error on invalid cast +statement error Cannot cast string 'not_a_number' to value of Int32 type +SELECT cast_to_type('not_a_number', NULL::INTEGER); + +# Error on invalid target type +statement error Unsupported SQL type INVALID +SELECT cast_to_type('42', NULL::INVALID); + +statement ok +DROP TABLE t1; + +####### +## Nullability tests for cast_to_type +####### + +statement ok +set datafusion.catalog.information_schema = true; + +# Non-nullable input -> non-nullable output +statement ok +CREATE VIEW v_cast_nonnull AS SELECT cast_to_type(42, NULL::INTEGER) as a; + +query TTT +SELECT column_name, data_type, is_nullable FROM information_schema.columns WHERE table_name = 'v_cast_nonnull'; +---- +a Int32 NO + +statement ok +DROP VIEW v_cast_nonnull; + +# Nullable input -> nullable output +statement ok +CREATE TABLE t_nullable (x INTEGER); + +statement ok +INSERT INTO t_nullable VALUES (1), (NULL); + +statement ok +CREATE VIEW v_cast_null AS SELECT cast_to_type(x, 1.0::DOUBLE) as a FROM t_nullable; + +query TTT +SELECT column_name, data_type, is_nullable FROM information_schema.columns WHERE table_name = 'v_cast_null'; +---- +a Float64 YES + +# If we cast to the null type itself the result is nullable even if the input is not +statement ok +CREATE VIEW v_cast_to_null AS SELECT cast_to_type(42, null) as a; + +query TTT +SELECT column_name, data_type, is_nullable FROM information_schema.columns WHERE table_name = 'v_cast_to_null'; +---- +a Null YES + +statement ok +DROP VIEW v_cast_null; + +statement ok +DROP TABLE t_nullable; + +####### +## Tests for try_cast_to_type function (fallible variant returning NULL) +####### + 
+# Basic string to integer cast +query I +SELECT try_cast_to_type('42', NULL::INTEGER); +---- +42 + +# Invalid cast returns NULL instead of error +query I +SELECT try_cast_to_type('not_a_number', NULL::INTEGER); +---- +NULL + +# String to double cast +query R +SELECT try_cast_to_type('3.14', NULL::DOUBLE); +---- +3.14 + +# Invalid double returns NULL +query R +SELECT try_cast_to_type('abc', NULL::DOUBLE); +---- +NULL + +# Integer to string cast (always succeeds) +query T +SELECT try_cast_to_type(42, NULL::VARCHAR); +---- +42 + +# Same-type is a no-op +query I +SELECT try_cast_to_type(42, 0::INTEGER); +---- +42 + +# NULL first argument +query I +SELECT try_cast_to_type(NULL, 0::INTEGER); +---- +NULL + +# CASE expression as first argument +query I +SELECT try_cast_to_type(CASE WHEN true THEN '1' ELSE '2' END, NULL::INTEGER); +---- +1 + +# Arithmetic expression as first argument +query R +SELECT try_cast_to_type(1 + 2, NULL::DOUBLE); +---- +3 + +# Nested: try_cast_to_type inside cast_to_type +query T +SELECT cast_to_type(try_cast_to_type('3.14', NULL::DOUBLE), NULL::VARCHAR); +---- +3.14 + +# Subquery as second argument +query I +SELECT try_cast_to_type('42', (SELECT NULL::INTEGER)); +---- +42 + +# Column reference as second argument +statement ok +CREATE TABLE t2 (int_col INTEGER, text_col VARCHAR); + +statement ok +INSERT INTO t2 VALUES (1, 'hello'), (2, 'world'); + +query I +SELECT try_cast_to_type('99', int_col) FROM t2 LIMIT 1; +---- +99 + +query I +SELECT try_cast_to_type(text_col, int_col) FROM t2; +---- +NULL +NULL + +# Cast column to match another column's type +query T +SELECT try_cast_to_type(int_col, text_col) FROM t2; +---- +1 +2 + +# Boolean cast +query B +SELECT try_cast_to_type(1, NULL::BOOLEAN); +---- +true + +# String to date - valid +query D +SELECT try_cast_to_type('2024-01-15', NULL::DATE); +---- +2024-01-15 + +# String to date - invalid returns NULL +query D +SELECT try_cast_to_type('not_a_date', NULL::DATE); +---- +NULL + +statement ok +DROP 
TABLE t2; + +####### +## Nullability tests for try_cast_to_type +####### + +# try_cast_to_type is always nullable (cast can fail) +statement ok +CREATE VIEW v_trycast AS SELECT try_cast_to_type(42, 1::INTEGER) as a; + +query TTT +SELECT column_name, data_type, is_nullable FROM information_schema.columns WHERE table_name = 'v_trycast'; +---- +a Int32 YES + +statement ok +DROP VIEW v_trycast; + +statement ok +set datafusion.catalog.information_schema = false; diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt b/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt index 99e0b10430186..e27ff1e9c1a00 100644 --- a/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt +++ b/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt @@ -128,43 +128,6 @@ from (values ('a'), ('b')) as tbl(col); NULL NULL NULL NULL NULL NULL -# Extract domain from URL using anchored pattern with trailing .* -# This tests that the full URL suffix is replaced, not just the matched prefix -query T -SELECT regexp_replace(url, '^https?://(?:www\.)?([^/]+)/.*$', '\1') FROM (VALUES - ('https://www.example.com/path/to/page?q=1'), - ('http://test.org/foo/bar'), - ('https://example.com/'), - ('not-a-url') -) AS t(url); ----- -example.com -test.org -example.com -not-a-url - -# More than one capture group should disable the short-regex fast path. -# This still uses replacement \1, but captures_len() will be > 2, so the -# implementation must fall back to the normal regexp_replace path. -query T -SELECT regexp_replace(url, '^https?://((www\.)?([^/]+))/.*$', '\1') FROM (VALUES - ('https://www.example.com/path/to/page?q=1'), - ('http://test.org/foo/bar'), - ('not-a-url') -) AS t(url); ----- -www.example.com -test.org -not-a-url - -# If the overall pattern matches but capture group 1 does not participate, -# regexp_replace(..., '\1') should substitute the empty string, not keep -# the original input. 
-query B -SELECT regexp_replace('bzzz', '^(a)?b.*$', '\1') = ''; ----- -true - # Stripping trailing .*$ must not change match semantics for inputs with # newlines when the original pattern does not use the 's' flag. query B @@ -183,3 +146,111 @@ SELECT regexp_replace( ) = concat('x', chr(10), 'rest'); ---- true + + +# Fixture for testing optimizations in regexp_replace +statement ok +CREATE TABLE regexp_replace_optimized_cases ( + value string, + regexp string, + replacement string, + expected string +); + +# Extract domain from URL using anchored pattern with trailing .* +# This tests that the full URL suffix is replaced, not just the matched prefix. +statement ok +INSERT INTO regexp_replace_optimized_cases VALUES + ('https://www.example.com/path/to/page?q=1', '^https?://(?:www\.)?([^/]+)/.*$', '\1', 'example.com'), + ('http://test.org/foo/bar', '^https?://(?:www\.)?([^/]+)/.*$', '\1', 'test.org'), + ('https://example.com/', '^https?://(?:www\.)?([^/]+)/.*$', '\1', 'example.com'), + ('not-a-url', '^https?://(?:www\.)?([^/]+)/.*$', '\1', 'not-a-url'); + +# More than one capture group should disable the short-regex fast path. +# This still uses replacement \1, but captures_len() will be > 2, so the +# implementation must fall back to the normal regexp_replace path. +statement ok +INSERT INTO regexp_replace_optimized_cases VALUES + ('https://www.example.com/path/to/page?q=1', '^https?://((www\.)?([^/]+))/.*$', '\1', 'www.example.com'), + ('http://test.org/foo/bar', '^https?://((www\.)?([^/]+))/.*$', '\1', 'test.org'), + ('not-a-url', '^https?://((www\.)?([^/]+))/.*$', '\1', 'not-a-url'); + +# If the overall pattern matches but capture group 1 does not participate, +# regexp_replace(..., '\1') should substitute the empty string, not keep +# the original input. 
+statement ok +INSERT INTO regexp_replace_optimized_cases VALUES + ('bzzz', '^(a)?b.*$', '\1', ''); + + +query TB +SELECT value, regexp_replace(value, regexp, replacement) = expected +FROM regexp_replace_optimized_cases +ORDER BY regexp, value, replacement, expected; +---- +bzzz true +http://test.org/foo/bar true +https://www.example.com/path/to/page?q=1 true +not-a-url true +http://test.org/foo/bar true +https://example.com/ true +https://www.example.com/path/to/page?q=1 true +not-a-url true + +query TB +SELECT value, regexp_replace( + arrow_cast(value, 'LargeUtf8'), + arrow_cast(regexp, 'LargeUtf8'), + arrow_cast(replacement, 'LargeUtf8') + ) = arrow_cast(expected, 'LargeUtf8') +FROM regexp_replace_optimized_cases +ORDER BY regexp, value, replacement, expected; +---- +bzzz true +http://test.org/foo/bar true +https://www.example.com/path/to/page?q=1 true +not-a-url true +http://test.org/foo/bar true +https://example.com/ true +https://www.example.com/path/to/page?q=1 true +not-a-url true + +query TB +SELECT value, regexp_replace( + arrow_cast(value, 'Utf8View'), + arrow_cast(regexp, 'Utf8View'), + arrow_cast(replacement, 'Utf8View') + ) = arrow_cast(expected, 'Utf8View') +FROM regexp_replace_optimized_cases +ORDER BY regexp, value, replacement, expected; +---- +bzzz true +http://test.org/foo/bar true +https://www.example.com/path/to/page?q=1 true +not-a-url true +http://test.org/foo/bar true +https://example.com/ true +https://www.example.com/path/to/page?q=1 true +not-a-url true + +query TB +SELECT value, regexp_replace( + arrow_cast(value, 'Dictionary(Int32, Utf8)'), + arrow_cast(regexp, 'Dictionary(Int32, Utf8)'), + arrow_cast(replacement, 'Dictionary(Int32, Utf8)') + ) = arrow_cast(expected, 'Dictionary(Int32, Utf8)') +FROM regexp_replace_optimized_cases +ORDER BY regexp, value, replacement, expected; +---- +bzzz true +http://test.org/foo/bar true +https://www.example.com/path/to/page?q=1 true +not-a-url true +http://test.org/foo/bar true +https://example.com/ 
true +https://www.example.com/path/to/page?q=1 true +not-a-url true + +# cleanup +statement ok +DROP TABLE regexp_replace_optimized_cases; diff --git a/datafusion/sqllogictest/test_files/string/string_literal.slt b/datafusion/sqllogictest/test_files/string/string_literal.slt index d4fe8ee178719..97f2a40c13fea 100644 --- a/datafusion/sqllogictest/test_files/string/string_literal.slt +++ b/datafusion/sqllogictest/test_files/string/string_literal.slt @@ -312,6 +312,35 @@ SELECT lpad(NULL, 5, 'xy') ---- NULL +# lpad counts Unicode codepoints, not grapheme clusters. +# chr(769) is U+0301 COMBINING ACUTE ACCENT — 'e' || chr(769) is 2 codepoints +# but renders as a single grapheme cluster. + +# Input with combining character: 'e' + combining accent + 'x' = 3 codepoints. +# Padding to 4 means 1 space prepended. +query BII +SELECT lpad('e' || chr(769) || 'x', 4) = ' ' || 'e' || chr(769) || 'x', + character_length('e' || chr(769) || 'x'), + character_length(lpad('e' || chr(769) || 'x', 4)) +---- +true 3 4 + +# Truncating input with combining character: 'e' + combining accent + 'x' + 'y' +# = 4 codepoints. Truncating to 3 keeps first 3 codepoints: 'e' + combining accent + 'x'. +query BI +SELECT lpad('e' || chr(769) || 'xy', 3) = 'e' || chr(769) || 'x', + character_length(lpad('e' || chr(769) || 'xy', 3)) +---- +true 3 + +# Fill string with combining character: fill is 'e' + combining accent = 2 codepoints. +# Padding 'x' (1 codepoint) to length 5 means 4 fill codepoints = 2 cycles of fill. +query BI +SELECT lpad('x', 5, 'e' || chr(769)) = 'e' || chr(769) || 'e' || chr(769) || 'x', + character_length(lpad('x', 5, 'e' || chr(769))) +---- +true 5 + query T SELECT regexp_replace('foobar', 'bar', 'xx', 'gi') ---- @@ -583,6 +612,34 @@ SELECT rpad(arrow_cast(NULL, 'Utf8View'), 5, 'xy') ---- NULL +# rpad counts Unicode codepoints, not grapheme clusters. +# chr(769) is U+0301 COMBINING ACUTE ACCENT. + +# Input with combining character: 'e' + combining accent + 'x' = 3 codepoints. 
+# Padding to 4 means 1 space appended. +query BII +SELECT rpad('e' || chr(769) || 'x', 4) = 'e' || chr(769) || 'x' || ' ', + character_length('e' || chr(769) || 'x'), + character_length(rpad('e' || chr(769) || 'x', 4)) +---- +true 3 4 + +# Truncating input with combining character: 'e' + combining accent + 'x' + 'y' +# = 4 codepoints. Truncating to 3 keeps first 3 codepoints: 'e' + combining accent + 'x'. +query BI +SELECT rpad('e' || chr(769) || 'xy', 3) = 'e' || chr(769) || 'x', + character_length(rpad('e' || chr(769) || 'xy', 3)) +---- +true 3 + +# Fill string with combining character: fill is 'e' + combining accent = 2 codepoints. +# Padding 'x' (1 codepoint) to length 5 means 4 fill codepoints = 2 cycles of fill. +query BI +SELECT rpad('x', 5, 'e' || chr(769)) = 'x' || 'e' || chr(769) || 'e' || chr(769), + character_length(rpad('x', 5, 'e' || chr(769))) +---- +true 5 + query I SELECT char_length('') ---- @@ -1829,3 +1886,27 @@ query T SELECT arrow_typeof(translate(arrow_cast('12345', 'Utf8View'), '143', 'ax')) ---- Utf8View + +# translate operates on Unicode codepoints, not grapheme clusters. +# chr(769) is U+0301 COMBINING ACUTE ACCENT. + +# Replacing a combining accent (a single codepoint) with another character. +# 'e' || chr(769) is 2 codepoints; translating chr(769) → 'X' replaces just the accent. +query B +SELECT translate('e' || chr(769), chr(769), 'X') = 'eX' +---- +true + +# Replacing the base character but not the combining accent. +query B +SELECT translate('e' || chr(769) || 'y', 'e', 'a') = 'a' || chr(769) || 'y' +---- +true + +# Deleting a combining accent (from longer than to). +# 'e' || chr(769) || 'x' with chr(769) in `from` but no corresponding `to` entry → deleted. 
+query BI +SELECT translate('e' || chr(769) || 'x', chr(769), '') = 'ex', + character_length(translate('e' || chr(769) || 'x', chr(769), '')) +---- +true 2 diff --git a/docs/source/library-user-guide/upgrading/54.0.0.md b/docs/source/library-user-guide/upgrading/54.0.0.md index 4e6178345bcce..c5d03ebf8878c 100644 --- a/docs/source/library-user-guide/upgrading/54.0.0.md +++ b/docs/source/library-user-guide/upgrading/54.0.0.md @@ -355,4 +355,16 @@ which results in a few changes: `timestamp-*` is UTC timezone-aware, while `local-timestamp-*` is timezone-naive. +### `lpad`, `rpad`, and `translate` now operate on Unicode codepoints instead of grapheme clusters + +Previously, `lpad`, `rpad`, and `translate` used Unicode grapheme cluster +segmentation to measure and manipulate strings. They now use Unicode codepoints, +which is consistent with the SQL standard and most other SQL implementations. It +also matches the behavior of other string-related functions in DataFusion. + +The difference is only observable for strings containing combining characters +(e.g., U+0301 COMBINING ACUTE ACCENT) or other multi-codepoint grapheme +clusters (e.g., ZWJ emoji sequences). For ASCII and most common Unicode text, +behavior is unchanged. 
+ [#17861]: https://github.com/apache/datafusion/pull/17861 diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index c303b43fc8844..d1b80f1f90b8b 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -5285,7 +5285,9 @@ union_tag(union_expression) - [arrow_metadata](#arrow_metadata) - [arrow_try_cast](#arrow_try_cast) - [arrow_typeof](#arrow_typeof) +- [cast_to_type](#cast_to_type) - [get_field](#get_field) +- [try_cast_to_type](#try_cast_to_type) - [version](#version) ### `arrow_cast` @@ -5405,6 +5407,37 @@ arrow_typeof(expression) +---------------------------+------------------------+ ``` +### `cast_to_type` + +Casts the first argument to the data type of the second argument. Only the type of the second argument is used; its value is ignored. + +```sql +cast_to_type(expression, reference) +``` + +#### Arguments + +- **expression**: The expression to cast. It can be a constant, column, or function, and any combination of operators. +- **reference**: Reference expression whose data type determines the target cast type. The value is ignored. + +#### Example + +```sql +> select cast_to_type('42', NULL::INTEGER) as a; ++----+ +| a | ++----+ +| 42 | ++----+ + +> select cast_to_type(1 + 2, NULL::DOUBLE) as b; ++-----+ +| b | ++-----+ +| 3.0 | ++-----+ +``` + ### `get_field` Returns a field within a map or a struct with the given key. @@ -5457,6 +5490,32 @@ get_field(expression, field_name[, field_name2, ...]) +--------+ ``` +### `try_cast_to_type` + +Casts the first argument to the data type of the second argument, returning NULL if the cast fails. Only the type of the second argument is used; its value is ignored. + +```sql +try_cast_to_type(expression, reference) +``` + +#### Arguments + +- **expression**: The expression to cast. It can be a constant, column, or function, and any combination of operators. 
+- **reference**: Reference expression whose data type determines the target cast type. The value is ignored. + +#### Example + +```sql +> select try_cast_to_type('123', NULL::INTEGER) as a, + try_cast_to_type('not_a_number', NULL::INTEGER) as b; + ++-----+------+ +| a | b | ++-----+------+ +| 123 | NULL | ++-----+------+ +``` + ### `version` Returns the version of DataFusion.