diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..390633e --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,7 @@ +{ + "permissions": { + "allow": [ + "Bash(cargo clippy:*)" + ] + } +} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 14bf19f..177fcb8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,6 +35,8 @@ jobs: steps: - uses: actions/checkout@v3 - uses: crate-ci/typos@v1.13.10 + with: + config: .typos.toml check: name: Check @@ -132,7 +134,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - uses: EmbarkStudios/cargo-deny-action@v1 + - uses: EmbarkStudios/cargo-deny-action@v2 with: command: check license diff --git a/.typos.toml b/.typos.toml new file mode 100644 index 0000000..0c8feb8 --- /dev/null +++ b/.typos.toml @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# typos configuration file +# Place this file in the project root (same level as Cargo.toml) +[files] +extend-exclude = ["Cargo.toml", "**/Cargo.toml"] diff --git a/Cargo.toml b/Cargo.toml index f2ee4fd..7372894 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,22 +25,23 @@ authors = ["Matthew Cramerus "] license = "Apache-2.0" description = "Materialized Views & Query Rewriting in DataFusion" keywords = ["arrow", "arrow-rs", "datafusion"] -rust-version = "1.80" +rust-version = "1.88.0" [dependencies] -arrow = "55" -arrow-schema = "55" -async-trait = "0.1" +aquamarine = "0.6.0" +arrow = "57.0.0" +arrow-schema = "57.0.0" +async-trait = "0.1.89" dashmap = "6" -datafusion = "47" -datafusion-common = "47" -datafusion-expr = "47" -datafusion-functions = "47" -datafusion-functions-aggregate = "47" -datafusion-optimizer = "47" -datafusion-physical-expr = "47" -datafusion-physical-plan = "47" -datafusion-sql = "47" +datafusion = { git = "https://github.com/massive-com/arrow-datafusion", rev = "3f7b02d" } +datafusion-common = { git = "https://github.com/massive-com/arrow-datafusion", rev = "3f7b02d" } +datafusion-expr = { git = "https://github.com/massive-com/arrow-datafusion", rev = "3f7b02d" } +datafusion-functions = { git = "https://github.com/massive-com/arrow-datafusion", rev = "3f7b02d" } +datafusion-functions-aggregate = { git = "https://github.com/massive-com/arrow-datafusion", rev = "3f7b02d" } +datafusion-optimizer = { git = "https://github.com/massive-com/arrow-datafusion", rev = "3f7b02d" } +datafusion-physical-expr = { git = "https://github.com/massive-com/arrow-datafusion", rev = "3f7b02d" } +datafusion-physical-plan = { git = "https://github.com/massive-com/arrow-datafusion", rev = "3f7b02d" } +datafusion-sql = { git = "https://github.com/massive-com/arrow-datafusion", rev = "3f7b02d" } futures = "0.3" itertools = "0.14" log = "0.4" @@ -49,7 +50,13 @@ ordered-float = "5.0.0" [dev-dependencies] anyhow = "1.0.95" +criterion = "0.4" env_logger = "0.11.6" 
tempfile = "3.14.0" tokio = "1.42.0" url = "2.5.4" + +[[bench]] +name = "materialized_views_benchmark" +harness = false +path = "benches/materialized_views_benchmark.rs" diff --git a/benches/materialized_views_benchmark.rs b/benches/materialized_views_benchmark.rs new file mode 100644 index 0000000..7bc2f52 --- /dev/null +++ b/benches/materialized_views_benchmark.rs @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use std::sync::Arc; +use std::time::Duration; + +use datafusion::datasource::provider_as_source; +use datafusion::datasource::TableProvider; +use datafusion::prelude::SessionContext; +use datafusion_common::Result as DfResult; +use datafusion_expr::LogicalPlan; +use datafusion_materialized_views::rewrite::normal_form::SpjNormalForm; +use datafusion_sql::TableReference; +use tokio::runtime::Builder; + +// Utility: generate CREATE TABLE SQL with n columns named c0..c{n-1} +fn make_create_table_sql(table_name: &str, ncols: usize) -> String { + let cols = (0..ncols) + .map(|i| format!("c{} INT", i)) + .collect::>() + .join(", "); + format!( + "CREATE TABLE {table} ({cols})", + table = table_name, + cols = cols + ) +} + +// Utility: generate a base SELECT that projects all columns and has a couple filters +fn make_base_sql(table_name: &str, ncols: usize) -> String { + let cols = (0..ncols) + .map(|i| format!("c{}", i)) + .collect::>() + .join(", "); + let mut where_clauses = vec![]; + if ncols > 0 { + where_clauses.push("c0 >= 0".to_string()); + } + if ncols > 1 { + where_clauses.push("c0 + c1 >= 0".to_string()); + } + let where_part = if where_clauses.is_empty() { + "".to_string() + } else { + format!(" WHERE {}", where_clauses.join(" AND ")) + }; + format!("SELECT {cols} FROM {table}{where}", cols = cols, table = table_name, where = where_part) +} + +// Utility: generate a query that is stricter and selects subset (so rewrite_from has a chance) +fn make_query_sql(table_name: &str, ncols: usize) -> String { + let take = std::cmp::max(1, ncols / 2); + let cols = (0..take) + .map(|i| format!("c{}", i)) + .collect::>() + .join(", "); + + let mut where_clauses = vec![]; + if ncols > 0 { + where_clauses.push("c0 >= 10".to_string()); + } + if ncols > 1 { + where_clauses.push("c0 * c1 > 100".to_string()); + } + if ncols > 10 { + where_clauses.push(format!("c{} >= 0", ncols - 1)); + } + + 
let where_part = if where_clauses.is_empty() { + "".to_string() + } else { + format!(" WHERE {}", where_clauses.join(" AND ")) + }; + + format!("SELECT {cols} FROM {table}{where}", cols = cols, table = table_name, where = where_part) +} + +// Build fixture: create SessionContext, the table, then return LogicalPlans for base & query and table provider +fn build_fixture_for_cols( + rt: &tokio::runtime::Runtime, + ncols: usize, +) -> DfResult<(LogicalPlan, LogicalPlan, Arc)> { + rt.block_on(async move { + let ctx = SessionContext::new(); + + // create table + let table_name = "t"; + let create_sql = make_create_table_sql(table_name, ncols); + ctx.sql(&create_sql).await?.collect().await?; // create table in catalog + + // base and query plans (optimize to normalize) + let base_sql = make_base_sql(table_name, ncols); + let query_sql = make_query_sql(table_name, ncols); + + let base_df = ctx.sql(&base_sql).await?; + let base_plan = base_df.into_optimized_plan()?; + + let query_df = ctx.sql(&query_sql).await?; + let query_plan = query_df.into_optimized_plan()?; + + // get table provider (Arc) + let table_ref = TableReference::bare(table_name); + let provider: Arc = ctx.table_provider(table_ref.clone()).await?; + + Ok((base_plan, query_plan, provider)) + }) +} + +// Criterion benchmark +fn criterion_benchmark(c: &mut Criterion) { + // columns to test + let col_cases = vec![1usize, 10, 20, 40, 80, 160, 320]; + + // build a tokio runtime that's broadly compatible + let rt = Builder::new_current_thread() + .enable_all() + .build() + .expect("tokio runtime"); + + let mut group = c.benchmark_group("view_matcher_spj"); + group.warm_up_time(Duration::from_secs(1)); + group.measurement_time(Duration::from_secs(5)); + group.sample_size(30); + + for &ncols in &col_cases { + // Build fixture + let (base_plan, query_plan, provider) = + build_fixture_for_cols(&rt, ncols).expect("fixture"); + + // Measure SpjNormalForm::new for base_plan and query_plan separately + let id_base = 
BenchmarkId::new("spj_normal_form_new", format!("cols={}", ncols)); + group.throughput(Throughput::Elements(1)); + group.bench_with_input(id_base, &base_plan, |b, plan| { + b.iter(|| { + let _nf = SpjNormalForm::new(plan).unwrap(); + }); + }); + + let id_query_nf = BenchmarkId::new("spj_normal_form_new_query", format!("cols={}", ncols)); + group.bench_with_input(id_query_nf, &query_plan, |b, plan| { + b.iter(|| { + let _nf = SpjNormalForm::new(plan).unwrap(); + }); + }); + + // Precompute normal forms once (to measure rewrite_from cost only) + let base_nf = SpjNormalForm::new(&base_plan).expect("base_nf"); + let query_nf = SpjNormalForm::new(&query_plan).expect("query_nf"); + + // qualifier for rewrite_from and a source created from the provider + let qualifier = TableReference::bare("mv"); + let source = provider_as_source(Arc::clone(&provider)); + + // Benchmark rewrite_from (this is the heavy check) + let id_rewrite = BenchmarkId::new("rewrite_from", format!("cols={}", ncols)); + group.bench_with_input(id_rewrite, &ncols, |b, &_n| { + b.iter(|| { + let _res = query_nf.rewrite_from(&base_nf, qualifier.clone(), Arc::clone(&source)); + }); + }); + } + + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/deny.toml b/deny.toml index a24420a..b516c5d 100644 --- a/deny.toml +++ b/deny.toml @@ -24,6 +24,8 @@ allow = [ "BSD-3-Clause", "CC0-1.0", "Unicode-3.0", - "Zlib" + "Zlib", + "ISC", + "bzip2-1.0.6" ] version = 2 diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..8cd75dd --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[toolchain] +channel = "1.91.0" +components = ["rust-analyzer", "rustfmt", "clippy"] \ No newline at end of file diff --git a/src/materialized.rs b/src/materialized.rs index e089591..8e093c0 100644 --- a/src/materialized.rs +++ b/src/materialized.rs @@ -41,6 +41,7 @@ use datafusion::{ catalog::TableProvider, datasource::listing::{ListingTable, ListingTableUrl}, }; +use datafusion_common::DataFusionError; use datafusion_expr::LogicalPlan; use itertools::Itertools; @@ -110,6 +111,14 @@ pub trait Materialized: ListingTableLike { fn config(&self) -> MaterializedConfig { MaterializedConfig::default() } + + /// Which partition columns are 'static'. + /// Static partition columns are those that are used in incremental view maintenance. + /// These should be a prefix of the full set of partition columns returned by [`ListingTableLike::partition_columns`]. + /// The rest of the partition columns are 'dynamic' and their values will be generated at runtime during incremental refresh. + fn static_partition_columns(&self) -> Vec<String> { + <Self as ListingTableLike>::partition_columns(self) + } } /// Register a [`Materialized`] implementation in this registry. @@ -122,13 +131,38 @@ pub fn register_materialized() { } /// Attempt to cast the given TableProvider into a [`Materialized`]. -/// If the table's type has not been registered using [`register_materialized`], will return `None`.
-pub fn cast_to_materialized(table: &dyn TableProvider) -> Option<&dyn Materialized> { - TABLE_TYPE_REGISTRY.cast_to_materialized(table).or_else(|| { - TABLE_TYPE_REGISTRY - .cast_to_decorator(table) - .and_then(|decorator| cast_to_materialized(decorator.base())) - }) +/// If the table's type has not been registered using [`register_materialized`], will return `Ok(None)`. +/// +/// Does a runtime check on some invariants of `Materialized` and returns an error if they are violated. +/// In particular, checks that the static partition columns are a prefix of all partition columns. +pub fn cast_to_materialized( + table: &dyn TableProvider, +) -> Result<Option<&dyn Materialized>, DataFusionError> { + let materialized = match TABLE_TYPE_REGISTRY + .cast_to_materialized(table) + .map(Ok) + .or_else(|| { + TABLE_TYPE_REGISTRY + .cast_to_decorator(table) + .and_then(|decorator| cast_to_materialized(decorator.base()).transpose()) + }) + .transpose()? + { + None => return Ok(None), + Some(m) => m, + }; + + let static_partition_cols = materialized.static_partition_columns(); + let all_partition_cols = materialized.partition_columns(); + + if materialized.partition_columns()[..static_partition_cols.len()] != static_partition_cols[..] + { + return Err(DataFusionError::Plan(format!( + "Materialized view's static partition columns ({static_partition_cols:?}) must be a prefix of all partition columns ({all_partition_cols:?})" + ))); + } + + Ok(Some(materialized)) } /// A `TableProvider` that decorates other `TableProvider`s.
diff --git a/src/materialized/dependencies.rs b/src/materialized/dependencies.rs index 5b42692..c00609a 100644 --- a/src/materialized/dependencies.rs +++ b/src/materialized/dependencies.rs @@ -24,7 +24,7 @@ use datafusion::{ use datafusion_common::{ alias::AliasGenerator, internal_err, - tree_node::{Transformed, TreeNode}, + tree_node::{Transformed, TreeNode, TreeNodeRecursion}, DFSchema, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ @@ -39,6 +39,19 @@ use crate::materialized::META_COLUMN; use super::{cast_to_materialized, row_metadata::RowMetadataRegistry, util, Materialized}; +/// Options for dependency analysis. +#[derive(Debug, Clone, Default)] +pub struct DependencyOptions { + /// If true, require that all branches of a UNION in the MV query have matching + /// field names at each position. + /// + /// SQL UNION combines results positionally, so branches are permitted to have + /// different field names for the same column. Enabling this check ensures that + /// UNION branches are semantically consistent, which can catch unintentional + /// column misalignment when tracking partition-level dependencies. 
+ pub strict_union_schema_names: bool, +} + /// A table function that shows build targets and dependencies for a materialized view: /// /// ```ignore @@ -65,11 +78,13 @@ pub fn mv_dependencies( catalog_list: Arc, row_metadata_registry: Arc, options: &ConfigOptions, + dependency_options: DependencyOptions, ) -> Arc { Arc::new(FileDependenciesUdtf::new( catalog_list, row_metadata_registry, options, + dependency_options, )) } @@ -78,6 +93,7 @@ struct FileDependenciesUdtf { catalog_list: Arc, row_metadata_registry: Arc, config_options: ConfigOptions, + dependency_options: DependencyOptions, } impl FileDependenciesUdtf { @@ -85,11 +101,13 @@ impl FileDependenciesUdtf { catalog_list: Arc, row_metadata_registry: Arc, config_options: &ConfigOptions, + dependency_options: DependencyOptions, ) -> Self { Self { catalog_list, config_options: config_options.clone(), row_metadata_registry, + dependency_options, } } } @@ -106,7 +124,7 @@ impl TableFunctionImpl for FileDependenciesUdtf { let table = util::get_table(self.catalog_list.as_ref(), &table_ref) .map_err(|e| DataFusionError::Plan(e.to_string()))?; - let mv = cast_to_materialized(table.as_ref()).ok_or(DataFusionError::Plan(format!( + let mv = cast_to_materialized(table.as_ref())?.ok_or(DataFusionError::Plan(format!( "mv_dependencies: table '{table_name} is not a materialized view. 
(Materialized TableProviders must be registered using register_materialized"), ))?; @@ -115,6 +133,7 @@ impl TableFunctionImpl for FileDependenciesUdtf { mv, self.row_metadata_registry.as_ref(), &self.config_options, + &self.dependency_options, )?, None, ))) @@ -135,13 +154,15 @@ pub fn stale_files( row_metadata_registry: Arc, file_metadata: Arc, config_options: &ConfigOptions, + dependency_options: DependencyOptions, ) -> Arc { Arc::new(StaleFilesUdtf { - mv_dependencies: FileDependenciesUdtf { + mv_dependencies: FileDependenciesUdtf::new( catalog_list, row_metadata_registry, - config_options: config_options.clone(), - }, + config_options, + dependency_options, + ), file_metadata, }) } @@ -166,6 +187,15 @@ impl TableFunctionImpl for StaleFilesUdtf { &self.mv_dependencies.config_options.catalog.default_schema, ); + let table = util::get_table(self.mv_dependencies.catalog_list.as_ref(), &table_ref) + .map_err(|e| DataFusionError::Plan(e.to_string()))?; + let mv = cast_to_materialized(table.as_ref())?.ok_or(DataFusionError::Plan(format!( + "mv_dependencies: table '{table_name} is not a materialized view. (Materialized TableProviders must be registered using register_materialized"), + ))?; + + let url = mv.table_paths()[0].to_string(); + let num_static_partition_cols = mv.static_partition_columns().len(); + let logical_plan = LogicalPlanBuilder::scan_with_filters("dependencies", dependencies, None, vec![])? .aggregate( @@ -187,16 +217,21 @@ impl TableFunctionImpl for StaleFilesUdtf { )? .aggregate( vec![ - // We want to omit the file name along with any "special" partitions - // from the path before comparing it to the target partition. Special - // partitions must be leaf most nodes and are designated by a leading - // underscore. These are useful for adding additional information to a - // filename without affecting partitioning or staleness checks. 
- regexp_replace( - col("file_path"), - lit(r"(/_[^/=]+=[^/]+)*/[^/]*$"), - lit("/"), - None, + // Omit the file name along with any "special" partitions. + // This can include dynamic partition columns as well as some internal + // metadata columns that are not part of the schema + // + // We implement this by only taking the first N columns, + // where N is the number of static partition columns. + array_element( + regexp_match( + col("file_path"), + lit(format!( + "{url}(?:[^/=]+=[^/]+/){{{num_static_partition_cols}}}" + )), + None, + ), + lit(1), ) .alias("existing_target"), ], @@ -231,34 +266,64 @@ impl TableFunctionImpl for StaleFilesUdtf { /// Extract table name from args passed to TableFunctionImpl::call() fn get_table_name(args: &[Expr]) -> Result<&String> { match &args[0] { - Expr::Literal(ScalarValue::Utf8(Some(table_name))) => Ok(table_name), + Expr::Literal(ScalarValue::Utf8(Some(table_name)), _) => Ok(table_name), _ => Err(DataFusionError::Plan( "expected a single string literal argument to mv_dependencies".to_string(), )), } } +/// Validates that all branches of every UNION node in the plan have the same +/// field names at each position as the union's output schema. +fn check_union_schema_names(plan: &LogicalPlan) -> Result<()> { + use datafusion_expr::logical_plan::Union; + plan.apply(|node| { + if let LogicalPlan::Union(Union { inputs, schema }) = node { + for (i, field) in schema.fields().iter().enumerate() { + for input in inputs.iter() { + let input_name = input.schema().field(i).name(); + if input_name != field.name() { + return Err(DataFusionError::Plan(format!( + "strict union schema check failed: field at position {i} \ + is '{}' in one branch but '{}' in another", + field.name(), + input_name, + ))); + } + } + } + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(()) +} + /// Returns a logical plan that, when executed, lists expected build targets /// for this materialized view, together with the dependencies for each target. 
pub fn mv_dependencies_plan( materialized_view: &dyn Materialized, row_metadata_registry: &RowMetadataRegistry, config_options: &ConfigOptions, + dependency_options: &DependencyOptions, ) -> Result { use datafusion_expr::logical_plan::*; + if dependency_options.strict_union_schema_names { + check_union_schema_names(&materialized_view.query())?; + } + let plan = materialized_view.query().clone(); - let partition_cols = materialized_view.partition_columns(); - let partition_col_indices = plan + let static_partition_cols = materialized_view.static_partition_columns(); + let static_partition_col_indices = plan .schema() .fields() .iter() .enumerate() - .filter_map(|(i, f)| partition_cols.contains(f.name()).then_some(i)) + .filter_map(|(i, f)| static_partition_cols.contains(f.name()).then_some(i)) .collect(); - let pruned_plan_with_source_files = if partition_cols.is_empty() { + let pruned_plan_with_source_files = if static_partition_cols.is_empty() { get_source_files_all_partitions( materialized_view, &config_options.catalog, @@ -266,14 +331,14 @@ pub fn mv_dependencies_plan( ) } else { // Prune non-partition columns from all table scans - let pruned_plan = pushdown_projection_inexact(plan, &partition_col_indices)?; + let pruned_plan = pushdown_projection_inexact(plan, &static_partition_col_indices)?; // Now bubble up file metadata to the top of the plan push_up_file_metadata(pruned_plan, &config_options.catalog, row_metadata_registry) }?; // We now have data in the following form: - // (partition_col0, partition_col1, ..., __meta) + // (static_partition_col0, static_partition_col1, ..., __meta) // The last column is a list of structs containing the row metadata // We need to unnest it @@ -289,7 +354,7 @@ pub fn mv_dependencies_plan( LogicalPlanBuilder::from(pruned_plan_with_source_files) .unnest_column(files)? 
.project(vec![ - construct_target_path_from_partition_columns(materialized_view).alias("target"), + construct_target_path_from_static_partition_columns(materialized_view).alias("target"), get_field(files_col.clone(), "table_catalog").alias("source_table_catalog"), get_field(files_col.clone(), "table_schema").alias("source_table_schema"), get_field(files_col.clone(), "table_name").alias("source_table_name"), @@ -300,14 +365,16 @@ pub fn mv_dependencies_plan( .build() } -fn construct_target_path_from_partition_columns(materialized_view: &dyn Materialized) -> Expr { +fn construct_target_path_from_static_partition_columns( + materialized_view: &dyn Materialized, +) -> Expr { let table_path = lit(materialized_view.table_paths()[0] .as_str() // Trim the / (we'll add it back later if we need it) .trim_end_matches("/")); // Construct the paths for the build targets let mut hive_column_path_elements = materialized_view - .partition_columns() + .static_partition_columns() .iter() .map(|column_name| concat([lit(column_name.as_str()), lit("="), col(column_name)].to_vec())) .collect::>(); @@ -603,6 +670,21 @@ fn pushdown_projection_inexact(plan: LogicalPlan, indices: &HashSet) -> R .map(Expr::Column) .collect_vec(); + // GUARD: if after pushdown the set of relevant unnest columns is empty, + // avoid constructing an Unnest node with zero exec columns (which will + // later error in Unnest::try_new). Instead, simply project the + // desired output columns from the child plan (after pushing down the child projection). + // Related PR: https://github.com/apache/datafusion/pull/16632, after that we must + // also check for empty exec columns here. + if columns_to_unnest.is_empty() { + return LogicalPlanBuilder::from(pushdown_projection_inexact( + Arc::unwrap_or_clone(unnest.input), + &child_indices, + )?) + .project(columns_to_project)? 
+ .build(); + } + LogicalPlanBuilder::from(pushdown_projection_inexact( Arc::unwrap_or_clone(unnest.input), &child_indices, @@ -922,7 +1004,7 @@ mod test { use std::{any::Any, collections::HashSet, sync::Arc}; use arrow::util::pretty::pretty_format_batches; - use arrow_schema::SchemaRef; + use arrow_schema::{DataType, Field, FieldRef, Fields, SchemaRef}; use datafusion::{ assert_batches_eq, assert_batches_sorted_eq, catalog::{Session, TableProvider}, @@ -930,8 +1012,9 @@ mod test { execution::session_state::SessionStateBuilder, prelude::{DataFrame, SessionConfig, SessionContext}, }; - use datafusion_common::{Column, Result, ScalarValue}; - use datafusion_expr::{Expr, JoinType, LogicalPlan, TableType}; + use datafusion_common::{Column, DFSchema, Result, ScalarValue}; + use datafusion_expr::builder::unnest; + use datafusion_expr::{EmptyRelation, Expr, JoinType, LogicalPlan, TableType}; use datafusion_physical_plan::ExecutionPlan; use itertools::Itertools; @@ -942,13 +1025,14 @@ mod test { Decorator, ListingTableLike, Materialized, }; - use super::{mv_dependencies, stale_files}; + use super::{mv_dependencies, mv_dependencies_plan, stale_files, DependencyOptions}; /// A mock materialized view. 
#[derive(Debug)] struct MockMaterializedView { table_path: ListingTableUrl, partition_columns: Vec, + static_partition_columns: Option>, // default = all partition columns query: LogicalPlan, file_ext: &'static str, } @@ -996,6 +1080,12 @@ mod test { fn query(&self) -> LogicalPlan { self.query.clone() } + + fn static_partition_columns(&self) -> Vec { + self.static_partition_columns + .clone() + .unwrap_or_else(|| self.partition_columns.clone()) + } } #[derive(Debug)] @@ -1147,6 +1237,7 @@ mod test { Arc::clone(ctx.state().catalog_list()), row_metadata_registry.clone(), ctx.copied_config().options(), + DependencyOptions::default(), ), ); @@ -1157,6 +1248,7 @@ mod test { Arc::clone(&row_metadata_registry), metadata_table, ctx.copied_config().options(), + DependencyOptions::default(), ), ); @@ -1165,12 +1257,14 @@ mod test { #[tokio::test] async fn test_deps() { + #[derive(Debug, Default)] struct TestCase { name: &'static str, query_to_analyze: &'static str, table_name: &'static str, - table_path: ListingTableUrl, + table_path: &'static str, partition_cols: Vec<&'static str>, + static_partition_cols: Option>, file_extension: &'static str, expected_output: Vec<&'static str>, file_metadata: &'static str, @@ -1182,7 +1276,7 @@ mod test { query_to_analyze: "SELECT column1 AS partition_column, concat(column2, column3) AS some_value FROM t1", table_name: "m1", - table_path: ListingTableUrl::parse("s3://m1/").unwrap(), + table_path: "s3://m1/", partition_cols: vec!["partition_column"], file_extension: ".parquet", expected_output: vec![ @@ -1209,12 +1303,13 @@ mod test { "| s3://m1/partition_column=2023/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:44 | false |", "+--------------------------------+----------------------+-----------------------+----------+", ], + ..Default::default() }, - TestCase { name: "omit 'special' partition columns", + TestCase { name: "omit internal metadata partition columns", query_to_analyze: "SELECT column1 AS partition_column, concat(column2, column3) 
AS some_value FROM t1", table_name: "m1", - table_path: ListingTableUrl::parse("s3://m1/").unwrap(), + table_path: "s3://m1/", partition_cols: vec!["partition_column"], file_extension: ".parquet", expected_output: vec![ @@ -1241,6 +1336,7 @@ mod test { "| s3://m1/partition_column=2023/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:44 | false |", "+--------------------------------+----------------------+-----------------------+----------+", ], + ..Default::default() }, TestCase { name: "transform year/month/day partition into timestamp partition", @@ -1250,7 +1346,7 @@ mod test { feed FROM t2", table_name: "m2", - table_path: ListingTableUrl::parse("s3://m2/").unwrap(), + table_path: "s3://m2/", partition_cols: vec!["timestamp", "feed"], file_extension: ".parquet", expected_output: vec![ @@ -1285,12 +1381,63 @@ mod test { "| s3://m2/timestamp=2024-12-06T00:00:00/feed=Z/ | 2023-07-10T16:00:00 | 2023-07-11T16:45:44 | true |", "+-----------------------------------------------+----------------------+-----------------------+----------+", ], + ..Default::default() + }, + TestCase { + name: "omit dynamic partition columns", + query_to_analyze: " + SELECT + year, + month, + day, + column2, + COUNT(*) AS ct + FROM t2 + GROUP BY year, month, day, column2 + ", + table_name: "m_dynamic", + table_path: "s3://m_dynamic/", + partition_cols: vec!["year", "month", "day", "column2"], + static_partition_cols: Some(vec!["year", "month", "day"]), + file_extension: ".parquet", + expected_output: vec![ + "+-------------------------------------------+----------------------+---------------------+-------------------+----------------------------------------------------------+----------------------+", + "| target | source_table_catalog | source_table_schema | source_table_name | source_uri | source_last_modified |", + 
"+-------------------------------------------+----------------------+---------------------+-------------------+----------------------------------------------------------+----------------------+", + "| s3://m_dynamic/year=2023/month=01/day=01/ | datafusion | test | t2 | s3://t2/year=2023/month=01/day=01/feed=A/data.01.parquet | 2023-07-11T16:29:26 |", + "| s3://m_dynamic/year=2023/month=01/day=02/ | datafusion | test | t2 | s3://t2/year=2023/month=01/day=02/feed=B/data.01.parquet | 2023-07-11T16:45:22 |", + "| s3://m_dynamic/year=2023/month=01/day=03/ | datafusion | test | t2 | s3://t2/year=2023/month=01/day=03/feed=C/data.01.parquet | 2023-07-11T16:45:44 |", + "| s3://m_dynamic/year=2024/month=12/day=04/ | datafusion | test | t2 | s3://t2/year=2024/month=12/day=04/feed=X/data.01.parquet | 2023-07-11T16:29:26 |", + "| s3://m_dynamic/year=2024/month=12/day=05/ | datafusion | test | t2 | s3://t2/year=2024/month=12/day=05/feed=Y/data.01.parquet | 2023-07-11T16:45:22 |", + "| s3://m_dynamic/year=2024/month=12/day=06/ | datafusion | test | t2 | s3://t2/year=2024/month=12/day=06/feed=Z/data.01.parquet | 2023-07-11T16:45:44 |", + "+-------------------------------------------+----------------------+---------------------+-------------------+----------------------------------------------------------+----------------------+", + ], + file_metadata: " + ('datafusion', 'test', 'm_dynamic', 's3://m_dynamic/year=2023/month=01/day=01/column2=1/data.01.parquet', '2023-07-12T16:00:00Z', 0), + ('datafusion', 'test', 'm_dynamic', 's3://m_dynamic/year=2023/month=01/day=02/column2=2/data.01.parquet', '2023-07-12T16:00:00Z', 0), + ('datafusion', 'test', 'm_dynamic', 's3://m_dynamic/year=2023/month=01/day=03/column2=3/data.01.parquet', '2023-07-10T16:00:00Z', 0), + ('datafusion', 'test', 'm_dynamic', 's3://m_dynamic/year=2024/month=12/day=04/column2=4/data.01.parquet', '2023-07-12T16:00:00Z', 0), + ('datafusion', 'test', 'm_dynamic', 
's3://m_dynamic/year=2024/month=12/day=05/column2=5/data.01.parquet', '2023-07-12T16:00:00Z', 0), + ('datafusion', 'test', 'm_dynamic', 's3://m_dynamic/year=2024/month=12/day=06/column2=6/data.01.parquet', '2023-07-10T16:00:00Z', 0) + ", + expected_stale_files_output: vec![ + "+-------------------------------------------+----------------------+-----------------------+----------+", + "| target | target_last_modified | sources_last_modified | is_stale |", + "+-------------------------------------------+----------------------+-----------------------+----------+", + "| s3://m_dynamic/year=2023/month=01/day=01/ | 2023-07-12T16:00:00 | 2023-07-11T16:29:26 | false |", + "| s3://m_dynamic/year=2023/month=01/day=02/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:22 | false |", + "| s3://m_dynamic/year=2023/month=01/day=03/ | 2023-07-10T16:00:00 | 2023-07-11T16:45:44 | true |", + "| s3://m_dynamic/year=2024/month=12/day=04/ | 2023-07-12T16:00:00 | 2023-07-11T16:29:26 | false |", + "| s3://m_dynamic/year=2024/month=12/day=05/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:22 | false |", + "| s3://m_dynamic/year=2024/month=12/day=06/ | 2023-07-10T16:00:00 | 2023-07-11T16:45:44 | true |", + "+-------------------------------------------+----------------------+-----------------------+----------+", + ], }, TestCase { name: "materialized view has no partitions", query_to_analyze: "SELECT column1 AS output FROM t3", table_name: "m3", - table_path: ListingTableUrl::parse("s3://m3/").unwrap(), + table_path: "s3://m3/", partition_cols: vec![], file_extension: ".parquet", expected_output: vec![ @@ -1311,12 +1458,13 @@ mod test { "| s3://m3/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:44 | false |", "+----------+----------------------+-----------------------+----------+", ], + ..Default::default() }, TestCase { name: "simple equijoin on year", query_to_analyze: "SELECT * FROM t2 INNER JOIN t3 USING (year)", table_name: "m4", - table_path: ListingTableUrl::parse("s3://m4/").unwrap(), + table_path: 
"s3://m4/", partition_cols: vec!["year"], file_extension: ".parquet", expected_output: vec![ @@ -1345,6 +1493,7 @@ mod test { "| s3://m4/year=2024/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:44 | false |", "+--------------------+----------------------+-----------------------+----------+", ], + ..Default::default() }, TestCase { name: "triangular join on year", @@ -1357,7 +1506,7 @@ mod test { INNER JOIN t3 ON (t2.year <= t3.year)", table_name: "m4", - table_path: ListingTableUrl::parse("s3://m4/").unwrap(), + table_path: "s3://m4/", partition_cols: vec!["year"], file_extension: ".parquet", expected_output: vec![ @@ -1387,6 +1536,7 @@ mod test { "| s3://m4/year=2024/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:44 | false |", "+--------------------+----------------------+-----------------------+----------+", ], + ..Default::default() }, TestCase { name: "triangular left join, strict <", @@ -1399,7 +1549,7 @@ mod test { LEFT JOIN t3 ON (t2.year < t3.year)", table_name: "m4", - table_path: ListingTableUrl::parse("s3://m4/").unwrap(), + table_path: "s3://m4/", partition_cols: vec!["year"], file_extension: ".parquet", expected_output: vec![ @@ -1427,6 +1577,7 @@ mod test { "| s3://m4/year=2024/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:44 | false |", "+--------------------+----------------------+-----------------------+----------+", ], + ..Default::default() }, ]; @@ -1447,7 +1598,7 @@ mod test { .enumerate() .filter_map(|(i, c)| case.partition_cols.contains(&c.name.as_str()).then_some(i)) .collect(); - println!("indices: {:?}", partition_col_indices); + println!("indices: {partition_col_indices:?}"); let analyzed = pushdown_projection_inexact(plan.clone(), &partition_col_indices)?; println!( "inexact projection pushdown:\n{}", @@ -1460,12 +1611,16 @@ mod test { // Register table with a decorator to exercise this functionality Arc::new(DecoratorTable { inner: Arc::new(MockMaterializedView { - table_path: case.table_path.clone(), + table_path: 
ListingTableUrl::parse(case.table_path).unwrap(), partition_columns: case .partition_cols .iter() .map(|s| s.to_string()) .collect(), + static_partition_columns: case + .static_partition_cols + .as_ref() + .map(|list| list.iter().map(|s| s.to_string()).collect()), query: plan, file_ext: case.file_extension, }), @@ -1482,6 +1637,15 @@ mod test { .collect() .await?; + context + .sql(&format!( + "SELECT * FROM file_metadata WHERE table_name = '{}'", + case.table_name + )) + .await? + .show() + .await?; + let df = context .sql(&format!( "SELECT * FROM mv_dependencies('{}', 'v2')", @@ -1720,19 +1884,19 @@ mod test { ", projection: &["year"], expected_plan: vec![ - "+--------------+--------------------------------------------------+", - "| plan_type | plan |", - "+--------------+--------------------------------------------------+", - "| logical_plan | Union |", - "| | Projection: coalesce(t1.year, t2.year) AS year |", - "| | Full Join: Using t1.year = t2.year |", - "| | SubqueryAlias: t1 |", - "| | Projection: t1.column1 AS year |", - "| | TableScan: t1 projection=[column1] |", - "| | SubqueryAlias: t2 |", - "| | TableScan: t2 projection=[year] |", - "| | TableScan: t3 projection=[year] |", - "+--------------+--------------------------------------------------+", + "+--------------+--------------------------------------------------------------------+", + "| plan_type | plan |", + "+--------------+--------------------------------------------------------------------+", + "| logical_plan | Union |", + "| | Projection: coalesce(CAST(t1.year AS Utf8View), t2.year) AS year |", + "| | Full Join: Using CAST(t1.year AS Utf8View) = t2.year |", + "| | SubqueryAlias: t1 |", + "| | Projection: t1.column1 AS year |", + "| | TableScan: t1 projection=[column1] |", + "| | SubqueryAlias: t2 |", + "| | TableScan: t2 projection=[year] |", + "| | TableScan: t3 projection=[year] |", + "+--------------+--------------------------------------------------------------------+", ], expected_output: 
vec![ "+------+", @@ -1837,4 +2001,234 @@ mod test { Ok(()) } + + /// Build a `LogicalPlan::Union` directly from two `EmptyRelation` inputs. + /// DataFusion's SQL planner normalizes branch column names with alias + /// projections, so constructing the node manually is the only way to + /// produce a Union whose inputs have genuinely different field names. + fn make_union(left_fields: Vec<&str>, right_fields: Vec<&str>) -> Result { + use arrow_schema::Schema; + use datafusion_expr::logical_plan::Union; + + let make_schema = |names: &[&str]| -> Result> { + let fields: Vec<_> = names + .iter() + .map(|&n| Field::new(n, arrow_schema::DataType::Utf8, true)) + .collect(); + DFSchema::try_from(Schema::new(fields)).map(Arc::new) + }; + + let left_schema = make_schema(&left_fields)?; + let right_schema = make_schema(&right_fields)?; + + let left = LogicalPlan::EmptyRelation(datafusion_expr::logical_plan::EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&left_schema), + }); + let right = LogicalPlan::EmptyRelation(datafusion_expr::logical_plan::EmptyRelation { + produce_one_row: false, + schema: right_schema, + }); + + Ok(LogicalPlan::Union(Union { + inputs: vec![Arc::new(left), Arc::new(right)], + schema: left_schema, // output schema = first branch (DataFusion convention) + })) + } + + #[test] + fn test_strict_union_schema_names_matching() -> Result<()> { + use super::check_union_schema_names; + + // Both branches have the same field name — check must pass. + let plan = make_union(vec!["year"], vec!["year"])?; + check_union_schema_names(&plan) + } + + #[test] + fn test_strict_union_schema_names_mismatched() -> Result<()> { + use super::check_union_schema_names; + + // Second branch has `column1` while the output schema says `year`. 
+ let plan = make_union(vec!["year"], vec!["column1"])?; + let err = check_union_schema_names(&plan).unwrap_err(); + assert!( + err.to_string().contains("strict union schema check failed"), + "unexpected error message: {err}" + ); + // Error should name the offending fields. + assert!(err.to_string().contains("year"), "{err}"); + assert!(err.to_string().contains("column1"), "{err}"); + Ok(()) + } + + #[test] + fn test_strict_union_schema_names_multi_column_partial_mismatch() -> Result<()> { + use super::check_union_schema_names; + + // First column matches, second does not. + let plan = make_union(vec!["year", "month"], vec!["year", "day"])?; + let err = check_union_schema_names(&plan).unwrap_err(); + assert!( + err.to_string().contains("strict union schema check failed"), + "unexpected error message: {err}" + ); + assert!(err.to_string().contains("month"), "{err}"); + assert!(err.to_string().contains("day"), "{err}"); + Ok(()) + } + + #[test] + fn test_strict_union_schema_names_nested_inner_mismatch() -> Result<()> { + use super::check_union_schema_names; + use datafusion_expr::logical_plan::Union; + + // Outer UNION has matching names (both say `year`), + // but wraps an inner UNION whose second branch says `column1`. + // The strict check must walk the full plan tree and catch the inner mismatch. 
+ let inner = make_union(vec!["year"], vec!["column1"])?; + let outer_right = make_union(vec!["year"], vec!["year"])?; + + let outer_schema = inner.schema().clone(); + let nested = LogicalPlan::Union(Union { + inputs: vec![Arc::new(inner), Arc::new(outer_right)], + schema: outer_schema, + }); + + let err = check_union_schema_names(&nested).unwrap_err(); + assert!( + err.to_string().contains("strict union schema check failed"), + "unexpected error message: {err}" + ); + Ok(()) + } + + #[tokio::test] + async fn test_strict_union_via_mv_dependencies_plan() -> Result<()> { + // Verify DependencyOptions.strict_union_schema_names is wired through + // mv_dependencies_plan: non-strict passes, strict fails. + let ctx = setup().await?; + + let mismatched_query = make_union(vec!["year"], vec!["column1"])?; + let mv = MockMaterializedView { + table_path: ListingTableUrl::parse("s3://mv/").unwrap(), + partition_columns: vec!["year".into()], + static_partition_columns: None, + query: mismatched_query, + file_ext: "parquet", + }; + let row_metadata_registry = { + let metadata_table = ctx.table_provider("file_metadata").await?; + Arc::new(RowMetadataRegistry::new_with_default_source(Arc::new( + ObjectStoreRowMetadataSource::with_file_metadata(metadata_table), + ))) + }; + + // Non-strict: must not error on schema name mismatch. + mv_dependencies_plan( + &mv, + &row_metadata_registry, + ctx.copied_config().options(), + &DependencyOptions::default(), + )?; + + // Strict: must reject the mismatched UNION. 
+ let err = mv_dependencies_plan( + &mv, + &row_metadata_registry, + ctx.copied_config().options(), + &DependencyOptions { + strict_union_schema_names: true, + }, + ) + .unwrap_err(); + assert!( + err.to_string().contains("strict union schema check failed"), + "unexpected error message: {err}" + ); + + Ok(()) + } + + #[test] + fn test_pushdown_unnest_guard_partition_date_only() -> Result<()> { + // This test simulates a simplified MV scenario: + // + // WITH events_structs AS ( + // SELECT id, date, unnest(events) AS evs + // FROM base_table + // ), + // flattened_events AS ( + // SELECT id, date, evs.event_type, evs.event_time + // FROM events_structs + // ), + // SELECT id, date, MAX(...) ... + // GROUP BY id, date + // + // The partition column is "date". During dependency plan + // building we only request "date" from this subtree, + // so pushdown_projection_inexact receives indices for + // the `date` column only. The guard must kick in: + // unnest(events) becomes unused, and the plan should + // collapse to just projecting `date` from the child. + + // 1. Build schema for base table + let id = Field::new("id", DataType::Utf8, true); + let date = Field::new("date", DataType::Utf8, true); + + // events: list> + let event_type = Field::new("event_type", DataType::Utf8, true); + let event_time = Field::new("event_time", DataType::Utf8, true); + let events_struct = Field::new( + "item", + DataType::Struct(Fields::from(vec![event_type, event_time])), + true, + ); + let events = Field::new( + "events", + DataType::List(FieldRef::from(Box::new(events_struct))), + true, + ); + + // Build DFSchema: (id, date, events) + let qualified_fields = vec![ + (None, Arc::new(id.clone())), + (None, Arc::new(date.clone())), + (None, Arc::new(events.clone())), + ]; + let df_schema = + DFSchema::new_with_metadata(qualified_fields, std::collections::HashMap::new())?; + + // 2. 
Build a dummy child plan (EmptyRelation with the schema) + let empty = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(df_schema), + }); + + // 3. Wrap it with an Unnest node on the "events" column + let events_col = Column::from_name("events"); + let unnest_plan = unnest(empty.clone(), vec![events_col.clone()])?; + + // 4. Partition column is "date". Look up its actual index dynamically. + let date_idx = unnest_plan + .schema() + .index_of_column(&Column::from_name("date"))?; + let mut indices: HashSet = HashSet::new(); + indices.insert(date_idx); + + // 5. Call pushdown_projection_inexact with {date} + let res = pushdown_projection_inexact(unnest_plan, &indices)?; + + // 6. Assert the result schema only contains `date` + let cols: Vec = res + .schema() + .fields() + .iter() + .map(|f| f.name().to_string()) + .collect(); + + assert_eq!(cols, vec!["date"]); + + Ok(()) + } } diff --git a/src/materialized/file_metadata.rs b/src/materialized/file_metadata.rs index d5a5c8e..85e5838 100644 --- a/src/materialized/file_metadata.rs +++ b/src/materialized/file_metadata.rs @@ -17,8 +17,9 @@ use arrow::array::{StringBuilder, TimestampNanosecondBuilder, UInt64Builder}; use arrow::record_batch::RecordBatch; -use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit}; +use arrow_schema::{DataType, Field, TimeUnit}; use async_trait::async_trait; +use datafusion::arrow::datatypes::{Schema, SchemaRef}; use datafusion::catalog::SchemaProvider; use datafusion::catalog::{CatalogProvider, Session}; use datafusion::datasource::listing::ListingTableUrl; @@ -35,7 +36,7 @@ use datafusion::physical_plan::{ use datafusion::{ catalog::CatalogProviderList, execution::TaskContext, physical_plan::SendableRecordBatchStream, }; -use datafusion_common::{DataFusionError, Result, ScalarValue, ToDFSchema}; +use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::{Expr, Operator, TableProviderFilterPushDown, TableType}; 
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use futures::stream::{self, BoxStream}; @@ -103,7 +104,7 @@ impl TableProvider for FileMetadata { filters: &[Expr], limit: Option, ) -> Result> { - let dfschema = self.table_schema.clone().to_dfschema()?; + let dfschema = DFSchema::try_from(self.table_schema.as_ref().clone())?; let filters = filters .iter() @@ -226,7 +227,7 @@ impl ExecutionPlan for FileMetadataExec { .map(|record_batch| { record_batch .project(&projection) - .map_err(|e| DataFusionError::ArrowError(e, None)) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) }) .collect::>(); } @@ -858,7 +859,7 @@ mod test { .await?; ctx.sql( - "INSERT INTO t1 VALUES + "INSERT INTO t1 VALUES (1, '2021'), (2, '2022'), (3, '2023'), @@ -882,7 +883,7 @@ mod test { .await?; ctx.sql( - "INSERT INTO private.t1 VALUES + "INSERT INTO private.t1 VALUES (1, '2021', '01'), (2, '2022', '02'), (3, '2023', '03'), @@ -906,7 +907,7 @@ mod test { .await?; ctx.sql( - "INSERT INTO datafusion_mv.public.t3 VALUES + "INSERT INTO datafusion_mv.public.t3 VALUES (1, '2021-01-01'), (2, '2022-02-02'), (3, '2023-03-03'), @@ -929,8 +930,8 @@ mod test { ctx.sql( // Remove timestamps and trim (randomly generated) file names since they're not stable in tests "CREATE VIEW file_metadata_test_view AS SELECT - * EXCLUDE(file_path, last_modified), - regexp_replace(file_path, '/[^/]*$', '/') AS file_path + * EXCLUDE(file_path, last_modified), + regexp_replace(file_path, '/[^/]*$', '/') AS file_path FROM file_metadata", ) .await diff --git a/src/materialized/hive_partition.rs b/src/materialized/hive_partition.rs index 43ebfde..ad381cb 100644 --- a/src/materialized/hive_partition.rs +++ b/src/materialized/hive_partition.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use arrow::array::{Array, StringArray, StringBuilder}; -use arrow_schema::DataType; +use datafusion::arrow::datatypes::DataType; use datafusion_common::{DataFusionError, Result, ScalarValue}; use 
datafusion_expr::{ @@ -79,7 +79,7 @@ pub fn hive_partition_udf() -> ScalarUDF { ScalarUDF::new_from_impl(udf_impl) } -#[derive(Debug)] +#[derive(Debug, Hash, PartialEq, Eq)] struct HivePartitionUdf { pub signature: Signature, } diff --git a/src/materialized/row_metadata.rs b/src/materialized/row_metadata.rs index 6476f39..fa12cdf 100644 --- a/src/materialized/row_metadata.rs +++ b/src/materialized/row_metadata.rs @@ -98,7 +98,7 @@ impl RowMetadataRegistry { .get(&table.to_string()) .map(|o| Arc::clone(o.value())) .or_else(|| self.default_source.clone()) - .ok_or_else(|| DataFusionError::Internal(format!("No metadata source for {}", table))) + .ok_or_else(|| DataFusionError::Internal(format!("No metadata source for {table}"))) } } diff --git a/src/rewrite/exploitation.rs b/src/rewrite/exploitation.rs index b6fb761..8bc2cd8 100644 --- a/src/rewrite/exploitation.rs +++ b/src/rewrite/exploitation.rs @@ -23,7 +23,7 @@ use datafusion::catalog::TableProvider; use datafusion::datasource::provider_as_source; use datafusion::execution::context::SessionState; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; -use datafusion::physical_expr::{LexRequirement, PhysicalSortExpr, PhysicalSortRequirement}; +use datafusion::physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; use datafusion::physical_expr_common::sort_expr::format_physical_sort_requirement_list; use datafusion::physical_optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; @@ -32,6 +32,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, Tre use datafusion_common::{DataFusionError, Result, TableReference}; use datafusion_expr::{Extension, LogicalPlan, UserDefinedLogicalNode, UserDefinedLogicalNodeCore}; use datafusion_optimizer::OptimizerRule; +use datafusion_physical_expr::OrderingRequirements; use itertools::Itertools; use ordered_float::OrderedFloat; @@ -41,7 +42,9 @@ use 
super::normal_form::SpjNormalForm; use super::QueryRewriteOptions; /// A cost function. Used to evaluate the best physical plan among multiple equivalent choices. -pub type CostFn = Arc f64 + Send + Sync>; +pub type CostFn = Arc< + dyn for<'a> Fn(Box + 'a>) -> Vec + Send + Sync, +>; /// A logical optimizer that generates candidate logical plans in the form of [`OneOf`] nodes. #[derive(Debug)] @@ -56,7 +59,7 @@ impl ViewMatcher { for (resolved_table_ref, table) in super::util::list_tables(session_state.catalog_list().as_ref()).await? { - let Some(mv) = cast_to_materialized(table.as_ref()) else { + let Some(mv) = cast_to_materialized(table.as_ref())? else { continue; }; @@ -83,6 +86,11 @@ impl ViewMatcher { Ok(ViewMatcher { mv_plans }) } + + /// Returns the materialized views and their corresponding normal forms. + pub fn mv_plans(&self) -> &HashMap, SpjNormalForm)> { + &self.mv_plans + } } impl OptimizerRule for ViewMatcher { @@ -272,6 +280,13 @@ pub struct OneOf { branches: Vec, } +impl OneOf { + /// Create a new OneOf node with the given branches. 
+ pub fn new(branches: Vec) -> Self { + Self { branches } + } +} + impl UserDefinedLogicalNodeCore for OneOf { fn name(&self) -> &str { "OneOf" @@ -316,7 +331,7 @@ pub struct OneOfExec { // Optionally declare a required input ordering // This will inform DataFusion to add sorts to children, // which will improve cost estimation of candidates - required_input_ordering: Option, + required_input_ordering: Option, // Index of the candidate with the best cost best: usize, // Cost function to use in optimization @@ -337,7 +352,7 @@ impl OneOfExec { /// Create a new `OneOfExec` pub fn try_new( candidates: Vec>, - required_input_ordering: Option, + required_input_ordering: Option, cost: CostFn, ) -> Result { if candidates.is_empty() { @@ -345,9 +360,10 @@ impl OneOfExec { "can't create OneOfExec with empty children".to_string(), )); } - let best = candidates + + let best = cost(Box::new(candidates.iter().map(|c| c.as_ref()))) .iter() - .position_min_by_key(|candidate| OrderedFloat(cost(candidate.as_ref()))) + .position_min_by_key(|&cost| OrderedFloat(*cost)) .unwrap(); Ok(Self { @@ -366,7 +382,7 @@ impl OneOfExec { /// Modify this plan's required input ordering. 
/// Used for sort pushdown - pub fn with_required_input_ordering(self, requirement: Option) -> Self { + pub fn with_required_input_ordering(self, requirement: Option) -> Self { Self { required_input_ordering: requirement, ..self @@ -387,7 +403,7 @@ impl ExecutionPlan for OneOfExec { self.candidates[self.best].properties() } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { vec![self.required_input_ordering.clone(); self.children().len()] } @@ -427,17 +443,24 @@ impl ExecutionPlan for OneOfExec { } fn statistics(&self) -> Result { - self.candidates[self.best].statistics() + self.candidates[self.best].partition_statistics(None) + } + + fn partition_statistics( + &self, + partition: Option, + ) -> Result { + self.candidates[self.best].partition_statistics(partition) + } + + fn supports_limit_pushdown(&self) -> bool { + true } } impl DisplayAs for OneOfExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { - let costs = self - .children() - .iter() - .map(|c| (self.cost)(c.as_ref())) - .collect_vec(); + let costs = (self.cost)(Box::new(self.children().iter().map(|arc| arc.as_ref()))); match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!( @@ -448,12 +471,16 @@ impl DisplayAs for OneOfExec { format_physical_sort_requirement_list( &self .required_input_ordering - .clone() - .unwrap_or_default() - .into_iter() - .map(PhysicalSortExpr::from) - .map(PhysicalSortRequirement::from) - .collect_vec() + .as_ref() + .map(|req| { + req.clone() + .into_single() + .into_iter() + .map(PhysicalSortExpr::from) + .map(PhysicalSortRequirement::from) + .collect_vec() + }) + .unwrap_or_default(), ) ) } diff --git a/src/rewrite/normal_form.rs b/src/rewrite/normal_form.rs index 0d6f7d0..b5be586 100644 --- a/src/rewrite/normal_form.rs +++ b/src/rewrite/normal_form.rs @@ -233,22 +233,11 @@ impl SpjNormalForm { .map(|expr| predicate.normalize_expr(expr)) .collect(); - let mut 
referenced_tables = vec![]; - original_plan - .apply(|plan| { - if let LogicalPlan::TableScan(scan) = plan { - referenced_tables.push(scan.table_name.clone()); - } - - Ok(TreeNodeRecursion::Continue) - }) - // No chance of error since we never return Err -- this unwrap is safe - .unwrap(); - Ok(Self { output_schema: Arc::clone(original_plan.schema()), output_exprs, - referenced_tables, + // Reuse referenced_tables collected during Predicate::new to avoid extra traversal + referenced_tables: predicate.referenced_tables.clone(), predicate, }) } @@ -258,32 +247,50 @@ impl SpjNormalForm { /// This is useful for rewriting queries to use materialized views. pub fn rewrite_from( &self, - mut other: &Self, + other: &Self, qualifier: TableReference, source: Arc, ) -> Result> { log::trace!("rewriting from {qualifier}"); + + // Cache columns() result to avoid repeated Vec allocation in the loop. + // DFSchema::columns() creates a new Vec on each call. + let output_columns = self.output_schema.columns(); + let mut new_output_exprs = Vec::with_capacity(self.output_exprs.len()); // check that our output exprs are sub-expressions of the other one's output exprs for (i, output_expr) in self.output_exprs.iter().enumerate() { - let new_output_expr = other - .predicate - .normalize_expr(output_expr.clone()) - .rewrite(&mut other)? - .data; - - // Check that all references to the original tables have been replaced. - // All remaining column expressions should be unqualified, which indicates - // that they refer to the output of the sub-plan (in this case the view) - if new_output_expr - .column_refs() - .iter() - .any(|c| c.relation.is_some()) - { - return Ok(None); - } + // Fast path for simple Column expressions (most common case). + // This avoids the expensive normalize_expr transform for columns. 
+ let new_output_expr = if let Expr::Column(col) = output_expr { + let normalized_col = other.predicate.normalize_column(col); + match other.find_output_column(&normalized_col) { + Some(rewritten) => rewritten, + None => return Ok(None), // Column not found, can't rewrite + } + } else { + // Slow path: complex expressions need full transform + let new_output_expr = other + .predicate + .normalize_expr(output_expr.clone()) + .rewrite(&mut &*other)? + .data; + + // Check that all references to the original tables have been replaced. + // All remaining column expressions should be unqualified, which indicates + // that they refer to the output of the sub-plan (in this case the view) + if new_output_expr + .column_refs() + .iter() + .any(|c| c.relation.is_some()) + { + return Ok(None); + } + new_output_expr + }; - let column = &self.output_schema.columns()[i]; + // Use cached columns instead of calling .columns() on each iteration + let column = &output_columns[i]; new_output_exprs.push( new_output_expr.alias_qualified(column.relation.clone(), column.name.clone()), ); @@ -310,7 +317,7 @@ impl SpjNormalForm { .into_iter() .chain(range_filters) .chain(residual_filters) - .map(|expr| expr.rewrite(&mut other).unwrap().data) + .map(|expr| expr.rewrite(&mut &*other).unwrap().data) .reduce(|a, b| a.and(b)); if all_filters @@ -329,6 +336,20 @@ impl SpjNormalForm { builder.project(new_output_exprs)?.build().map(Some) } + + /// Fast path: find a column in output_exprs and return rewritten expression. + /// This avoids full tree traversal for simple column lookups. + #[inline] + fn find_output_column(&self, col: &Column) -> Option { + self.output_exprs + .iter() + .position(|e| matches!(e, Expr::Column(c) if c == col)) + .map(|idx| { + Expr::Column(Column::new_unqualified( + self.output_schema.field(idx).name().clone(), + )) + }) + } } /// Stores information on filters from a Select-Project-Join plan. 
@@ -344,84 +365,95 @@ struct Predicate { ranges_by_equivalence_class: Vec>, /// Filter expressions that aren't column equality predicates or range filters. residuals: HashSet, + /// Tables referenced in this plan (collected during single-pass traversal) + referenced_tables: Vec, } impl Predicate { + /// Create a new Predicate by analyzing the given logical plan. + /// Uses single-pass traversal to collect schema, columns, filters, and referenced tables. fn new(plan: &LogicalPlan) -> Result { let mut schema = DFSchema::empty(); - plan.apply(|plan| { - if let LogicalPlan::TableScan(scan) = plan { - let new_schema = DFSchema::try_from_qualified_schema( - scan.table_name.clone(), - scan.source.schema().as_ref(), - )?; - schema = if schema.fields().is_empty() { - new_schema - } else { - schema.join(&new_schema)? - } - } + let mut columns_info: Vec<(Column, arrow::datatypes::DataType)> = Vec::new(); + let mut filters: Vec = Vec::new(); + let mut referenced_tables: Vec = Vec::new(); + + // Single traversal to collect everything + plan.apply(|node| { + match node { + LogicalPlan::TableScan(scan) => { + // Collect referenced table + referenced_tables.push(scan.table_name.clone()); - Ok(TreeNodeRecursion::Continue) - })?; + // Build schema + let new_schema = DFSchema::try_from_qualified_schema( + scan.table_name.clone(), + scan.source.schema().as_ref(), + )?; + + // Collect columns with their data types + for (table_ref, field) in new_schema.iter() { + columns_info.push(( + Column::new(table_ref.cloned(), field.name()), + field.data_type().clone(), + )); + } - let mut new = Self { - schema, - eq_classes: vec![], - eq_class_idx_by_column: HashMap::default(), - ranges_by_equivalence_class: vec![], - residuals: HashSet::new(), - }; + // Merge schema + schema = if schema.fields().is_empty() { + new_schema + } else { + schema.join(&new_schema)? 
+ }; - // Collect all referenced columns - plan.apply(|plan| { - if let LogicalPlan::TableScan(scan) = plan { - for (i, (table_ref, field)) in DFSchema::try_from_qualified_schema( - scan.table_name.clone(), - scan.source.schema().as_ref(), - )? - .iter() - .enumerate() - { - let column = Column::new(table_ref.cloned(), field.name()); - let data_type = field.data_type(); - new.eq_classes - .push(ColumnEquivalenceClass::new_singleton(column.clone())); - new.eq_class_idx_by_column.insert(column, i); - new.ranges_by_equivalence_class - .push(Some(Interval::make_unbounded(data_type)?)); + // Collect filters from TableScan + filters.extend(scan.filters.iter().cloned()); + } + LogicalPlan::Filter(filter) => { + filters.push(filter.predicate.clone()); } - } - - Ok(TreeNodeRecursion::Continue) - })?; - - // Collect any filters - plan.apply(|plan| { - let filters = match plan { - LogicalPlan::TableScan(scan) => scan.filters.as_slice(), - LogicalPlan::Filter(filter) => core::slice::from_ref(&filter.predicate), LogicalPlan::Join(_join) => { return Err(DataFusionError::Internal( "joins are not supported yet".to_string(), - )) + )); } - LogicalPlan::Projection(_) => &[], + LogicalPlan::Projection(_) => {} _ => { return Err(DataFusionError::Plan(format!( "unsupported logical plan: {}", - plan.display() - ))) + node.display() + ))); } - }; - - for expr in filters.iter().flat_map(split_conjunction) { - new.insert_conjuct(expr)?; } - Ok(TreeNodeRecursion::Continue) })?; + // Initialize data structures with known capacity + let num_columns = columns_info.len(); + let mut eq_classes = Vec::with_capacity(num_columns); + let mut eq_class_idx_by_column = HashMap::with_capacity(num_columns); + let mut ranges_by_equivalence_class = Vec::with_capacity(num_columns); + + for (i, (column, data_type)) in columns_info.into_iter().enumerate() { + eq_classes.push(ColumnEquivalenceClass::new_singleton(column.clone())); + eq_class_idx_by_column.insert(column, i); + 
ranges_by_equivalence_class.push(Some(Interval::make_unbounded(&data_type)?)); + } + + let mut new = Self { + schema, + eq_classes, + eq_class_idx_by_column, + ranges_by_equivalence_class, + residuals: HashSet::new(), + referenced_tables, + }; + + // Process all collected filters + for expr in filters.iter().flat_map(split_conjunction) { + new.insert_conjuct(expr)?; + } + Ok(new) } @@ -431,6 +463,17 @@ impl Predicate { .and_then(|&idx| self.eq_classes.get(idx)) } + /// Fast path: normalize a single Column without full tree traversal. + /// This is O(1) lookup instead of O(n) transform. + #[inline] + fn normalize_column(&self, col: &Column) -> Column { + if let Some(eq_class) = self.class_for_column(col) { + eq_class.columns.first().unwrap().clone() + } else { + col.clone() + } + } + /// Add a new column equivalence fn add_equivalence(&mut self, c1: &Column, c2: &Column) -> Result<()> { match ( @@ -455,6 +498,10 @@ impl Predicate { self.eq_classes[idx].columns.insert(c2.clone()); } (Some(&i), Some(&j)) => { + if i == j { + // The two columns are already in the same equivalence class. + return Ok(()); + } // We need to merge two existing column eq classes. // Delete the eq class with a larger index, @@ -523,8 +570,8 @@ impl Predicate { // so handling of open intervals is done by adding/subtracting the smallest increment. // However, there is not really a public API to do this, // other than the satisfy_greater method. - Operator::Lt => Ok( - match satisfy_greater( + Operator::Lt => { + let range_val = match satisfy_greater( &Interval::try_new(value.clone(), value.clone())?, &Interval::make_unbounded(&value.data_type())?, true, @@ -534,11 +581,19 @@ impl Predicate { *range = None; return Ok(()); } - }, - ), - // Same thing as above. - Operator::Gt => Ok( - match satisfy_greater( + }; + // If the type is not discrete (e.g. Utf8), satisfy_greater may return an unchanged value. 
+ // This means the interval could not be tightened and it is unsafe to produce a closed interval + if range_val.upper() == &value { + Err(DataFusionError::Plan( + "cannot represent strict inequality as closed interval for non-discrete types".to_string(), + )) + } else { + Ok(range_val) + } + } + Operator::Gt => { + let range_val = match satisfy_greater( &Interval::make_unbounded(&value.data_type())?, &Interval::try_new(value.clone(), value.clone())?, true, @@ -548,8 +603,15 @@ impl Predicate { *range = None; return Ok(()); } - }, - ), + }; + if range_val.lower() == &value { + Err(DataFusionError::Plan( + "cannot represent strict inequality as closed interval for non-discrete types".to_string(), + )) + } else { + Ok(range_val) + } + } _ => Err(DataFusionError::Plan( "unsupported binary expression".to_string(), )), @@ -593,7 +655,28 @@ impl Predicate { /// Add a binary expression to our collection of filters. fn insert_binary_expr(&mut self, left: &Expr, op: Operator, right: &Expr) -> Result<()> { match (left, op, right) { - (Expr::Column(c), op, Expr::Literal(v)) => { + (Expr::Column(c), op, Expr::Literal(v, _)) => { + // Normalize boolean expressions to canonical form: + // col = false -> NOT col + // col != true -> NOT col + // col = true -> col + // col != false -> col + // This ensures semantic equivalence matching (e.g., "active = false" matches "NOT active") + if let ScalarValue::Boolean(Some(b)) = v { + match (op, b) { + (Operator::Eq, false) | (Operator::NotEq, true) => { + self.residuals + .insert(Expr::Not(Box::new(Expr::Column(c.clone())))); + return Ok(()); + } + (Operator::Eq, true) | (Operator::NotEq, false) => { + self.residuals.insert(Expr::Column(c.clone())); + return Ok(()); + } + _ => {} + } + } + if let Err(e) = self.add_range(c, &op, v) { // Add a range can fail in some cases, so just fallthrough log::debug!("failed to add range filter: {e}"); @@ -601,7 +684,7 @@ impl Predicate { return Ok(()); } } - (Expr::Literal(_), op, Expr::Column(_)) => { 
+ (Expr::Literal(_, _), op, Expr::Column(_)) => { if let Some(swapped) = op.swap() { return self.insert_binary_expr(right, swapped, left); } @@ -714,14 +797,14 @@ impl Predicate { extra_range_filters.push(Expr::BinaryExpr(BinaryExpr { left: Box::new(Expr::Column(other_column.clone())), op: Operator::Eq, - right: Box::new(Expr::Literal(range.lower().clone())), + right: Box::new(Expr::Literal(range.lower().clone(), None)), })) } else { if !range.lower().is_null() { extra_range_filters.push(Expr::BinaryExpr(BinaryExpr { left: Box::new(Expr::Column(other_column.clone())), op: Operator::GtEq, - right: Box::new(Expr::Literal(range.lower().clone())), + right: Box::new(Expr::Literal(range.lower().clone(), None)), })) } @@ -729,7 +812,7 @@ impl Predicate { extra_range_filters.push(Expr::BinaryExpr(BinaryExpr { left: Box::new(Expr::Column(other_column.clone())), op: Operator::LtEq, - right: Box::new(Expr::Literal(range.upper().clone())), + right: Box::new(Expr::Literal(range.upper().clone(), None)), })) } } @@ -773,6 +856,11 @@ impl Predicate { /// Rewrite all expressions in terms of their normal representatives /// with respect to this predicate's equivalence classes. fn normalize_expr(&self, e: Expr) -> Expr { + // Fast path: if it's a simple Column, avoid full transform traversal + if let Expr::Column(ref c) = e { + return Expr::Column(self.normalize_column(c)); + } + e.transform(&|e| { let c = match e { Expr::Column(c) => c, @@ -996,11 +1084,11 @@ mod test { ctx.sql(&format!( " CREATE EXTERNAL TABLE t1 ( - column1 VARCHAR, - column2 BIGINT, + column1 VARCHAR, + column2 BIGINT, column3 CHAR ) - STORED AS PARQUET + STORED AS PARQUET LOCATION '{}'", t1_path.path().to_string_lossy() )) @@ -1095,7 +1183,7 @@ mod test { assert_eq!(rewritten.schema().as_ref(), query_plan.schema().as_ref()); let expected = concat_batches( - &query_plan.schema().as_ref().clone().into(), + &query_plan.schema().inner().clone(), &context .execute_logical_plan(query_plan) .await? 
@@ -1104,7 +1192,7 @@ mod test { )?; let result = concat_batches( - &rewritten.schema().as_ref().clone().into(), + &rewritten.schema().inner().clone(), &context .execute_logical_plan(rewritten) .await? @@ -1144,11 +1232,16 @@ mod test { TestCase { name: "range filter + equality predicate", base: - "SELECT column1, column2 FROM t1 WHERE column1 = column3 AND column1 >= '2022'", + "SELECT column1, column2 FROM t1 WHERE column1 = column3 AND column1 >= '2022'", query: // Since column1 = column3 in the original view, // we are allowed to substitute column1 for column3 and vice versa. - "SELECT column2, column3 FROM t1 WHERE column1 = column3 AND column3 >= '2023'", + "SELECT column2, column3 FROM t1 WHERE column1 = column3 AND column3 >= '2023'", + }, + TestCase { + name: "range filter with inequality on non-discrete type", + base: "SELECT * FROM t1", + query: "SELECT column1 FROM t1 WHERE column1 < '2022'", }, TestCase { name: "duplicate expressions (X-209)", @@ -1205,4 +1298,282 @@ mod test { Ok(()) } + + #[tokio::test] + async fn test_predicate_new_collects_expected_data() -> Result<()> { + let ctx = SessionContext::new(); + + // Create a table with known schema + ctx.sql( + "CREATE TABLE test_table ( + col1 INT, + col2 VARCHAR, + col3 DOUBLE + )", + ) + .await? + .collect() + .await?; + + // Create a plan with filters + let plan = ctx + .sql("SELECT col1, col2 FROM test_table WHERE col1 >= 10 AND col2 = col3") + .await? 
+ .into_optimized_plan()?; + + let normal_form = SpjNormalForm::new(&plan)?; + + // Verify referenced_tables is collected + assert_eq!(normal_form.referenced_tables().len(), 1); + assert_eq!(normal_form.referenced_tables()[0].to_string(), "test_table"); + + // Verify output_exprs matches the projection (2 columns) + assert_eq!(normal_form.output_exprs().len(), 2); + + // Verify schema is preserved + assert_eq!(normal_form.output_schema().fields().len(), 2); + + Ok(()) + } + + #[tokio::test] + async fn test_predicate_new_with_join_returns_error() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.sql("CREATE TABLE t1 (a INT, b INT)") + .await? + .collect() + .await?; + ctx.sql("CREATE TABLE t2 (c INT, d INT)") + .await? + .collect() + .await?; + + // Test that join returns an error as it's not supported yet + let plan = ctx + .sql("SELECT t1.a, t2.d FROM t1 JOIN t2 ON t1.b = t2.c WHERE t1.a >= 0 AND t2.d <= 100") + .await? + .into_optimized_plan()?; + + let result = SpjNormalForm::new(&plan); + + // Verify that join returns an error + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("joins are not supported yet")); + + Ok(()) + } + + #[tokio::test] + async fn test_predicate_new_with_range_filters() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.sql("CREATE TABLE range_test (x INT, y INT, z VARCHAR)") + .await? + .collect() + .await?; + + let plan = ctx + .sql("SELECT * FROM range_test WHERE x >= 10 AND x <= 100 AND y = 50") + .await? 
+ .into_optimized_plan()?; + + let normal_form = SpjNormalForm::new(&plan)?; + + // Verify all columns are in output + assert_eq!(normal_form.output_exprs().len(), 3); + assert_eq!(normal_form.referenced_tables().len(), 1); + + Ok(()) + } + + #[tokio::test] + async fn test_boolean_expression_normalization() -> Result<()> { + let _ = env_logger::builder().is_test(true).try_init(); + + let ctx = SessionContext::new(); + + // Create table with boolean column + ctx.sql( + "CREATE TABLE bool_test ( + id INT, + active BOOLEAN, + name VARCHAR + )", + ) + .await? + .collect() + .await?; + + ctx.sql("INSERT INTO bool_test VALUES (1, true, 'a'), (2, false, 'b')") + .await? + .collect() + .await?; + + // MV: uses "active = false" + let mv_plan = ctx + .sql("SELECT * FROM bool_test WHERE active = false") + .await? + .into_optimized_plan()?; + let mv_normal_form = SpjNormalForm::new(&mv_plan)?; + + ctx.sql("CREATE TABLE mv AS SELECT * FROM bool_test WHERE active = false") + .await? + .collect() + .await?; + + // Query: uses "NOT active" (semantically equivalent to "active = false") + let query_plan = ctx + .sql("SELECT id, name FROM bool_test WHERE NOT active") + .await? + .into_optimized_plan()?; + let query_normal_form = SpjNormalForm::new(&query_plan)?; + + let table_ref = TableReference::bare("mv"); + let rewritten = query_normal_form.rewrite_from( + &mv_normal_form, + table_ref.clone(), + provider_as_source(ctx.table_provider(table_ref).await?), + )?; + + assert!( + rewritten.is_some(), + "Expected MV with 'active = false' to match query with 'NOT active'" + ); + + // Also test the reverse: MV with "NOT active", query with "active = false" + let mv_plan2 = ctx + .sql("SELECT * FROM bool_test WHERE NOT active") + .await? + .into_optimized_plan()?; + let mv_normal_form2 = SpjNormalForm::new(&mv_plan2)?; + + ctx.sql("CREATE TABLE mv2 AS SELECT * FROM bool_test WHERE NOT active") + .await? 
+ .collect() + .await?; + + let query_plan2 = ctx + .sql("SELECT id FROM bool_test WHERE active = false") + .await? + .into_optimized_plan()?; + let query_normal_form2 = SpjNormalForm::new(&query_plan2)?; + + let table_ref2 = TableReference::bare("mv2"); + let rewritten2 = query_normal_form2.rewrite_from( + &mv_normal_form2, + table_ref2.clone(), + provider_as_source(ctx.table_provider(table_ref2).await?), + )?; + + assert!( + rewritten2.is_some(), + "Expected MV with 'NOT active' to match query with 'active = false'" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_boolean_column_normalization() -> Result<()> { + let _ = env_logger::builder().is_test(true).try_init(); + + let ctx = SessionContext::new(); + + ctx.sql( + "CREATE TABLE bool_test ( + id INT, + active BOOLEAN, + name VARCHAR + )", + ) + .await? + .collect() + .await?; + + // Test: MV with "active = false" should match query with "NOT active" + let mv_plan = ctx + .sql("SELECT * FROM bool_test WHERE active = false") + .await? + .into_optimized_plan()?; + let mv_normal_form = SpjNormalForm::new(&mv_plan)?; + + ctx.sql("CREATE TABLE mv AS SELECT * FROM bool_test WHERE active = false") + .await? + .collect() + .await?; + + let query_plan = ctx + .sql("SELECT id, name FROM bool_test WHERE NOT active") + .await? 
+ .into_optimized_plan()?; + let query_normal_form = SpjNormalForm::new(&query_plan)?; + + let table_ref = TableReference::bare("mv"); + let rewritten = query_normal_form.rewrite_from( + &mv_normal_form, + table_ref.clone(), + provider_as_source(ctx.table_provider(table_ref).await?), + )?; + + // Should successfully rewrite + assert!( + rewritten.is_some(), + "Expected MV with 'active = false' to match query with 'NOT active'" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_boolean_true_normalization() -> Result<()> { + let _ = env_logger::builder().is_test(true).try_init(); + + let ctx = SessionContext::new(); + + ctx.sql( + "CREATE TABLE bool_test2 ( + id INT, + enabled BOOLEAN + )", + ) + .await? + .collect() + .await?; + + // Test: MV with "enabled = true" should match query with just "enabled" + let mv_plan = ctx + .sql("SELECT * FROM bool_test2 WHERE enabled = true") + .await? + .into_optimized_plan()?; + let mv_normal_form = SpjNormalForm::new(&mv_plan)?; + + ctx.sql("CREATE TABLE mv2 AS SELECT * FROM bool_test2 WHERE enabled = true") + .await? + .collect() + .await?; + + let query_plan = ctx + .sql("SELECT id FROM bool_test2 WHERE enabled") + .await? 
+ .into_optimized_plan()?; + let query_normal_form = SpjNormalForm::new(&query_plan)?; + + let table_ref = TableReference::bare("mv2"); + let rewritten = query_normal_form.rewrite_from( + &mv_normal_form, + table_ref.clone(), + provider_as_source(ctx.table_provider(table_ref).await?), + )?; + + assert!( + rewritten.is_some(), + "Expected MV with 'enabled = true' to match query with 'enabled'" + ); + + Ok(()) + } } diff --git a/tests/materialized_listing_table.rs b/tests/materialized_listing_table.rs index 5ad9d25..00e4886 100644 --- a/tests/materialized_listing_table.rs +++ b/tests/materialized_listing_table.rs @@ -32,13 +32,15 @@ use datafusion::{ }, prelude::{SessionConfig, SessionContext}, }; -use datafusion_common::{Constraints, DataFusionError, ParamValues, ScalarValue, Statistics}; +use datafusion_common::{ + metadata::ScalarAndMetadata, Constraints, DataFusionError, ParamValues, ScalarValue, Statistics, +}; use datafusion_expr::{ col, dml::InsertOp, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, SortExpr, TableProviderFilterPushDown, TableType, }; use datafusion_materialized_views::materialized::{ - dependencies::{mv_dependencies, stale_files}, + dependencies::{mv_dependencies, stale_files, DependencyOptions}, file_metadata::{DefaultFileMetadataProvider, FileMetadata}, register_materialized, row_metadata::RowMetadataRegistry, @@ -158,6 +160,7 @@ async fn setup() -> Result { Arc::clone(ctx.state().catalog_list()), Arc::clone(&row_metadata_registry), ctx.state().config_options(), + DependencyOptions::default(), ), ); ctx.register_udtf( @@ -167,6 +170,7 @@ async fn setup() -> Result { Arc::clone(&row_metadata_registry), Arc::clone(&file_metadata) as Arc, ctx.state().config_options(), + DependencyOptions::default(), ), ); @@ -185,7 +189,7 @@ async fn setup() -> Result { .await?; ctx.sql( - "INSERT INTO t1 VALUES + "INSERT INTO t1 VALUES (1, '2023-01-01', 'A'), (2, '2023-01-02', 'B'), (3, '2023-01-03', 'C'), @@ -251,7 +255,7 @@ async fn 
test_materialized_listing_table_incremental_maintenance() -> Result<()> // Insert another row into the source table ctx.sql( - "INSERT INTO t1 VALUES + "INSERT INTO t1 VALUES (7, '2024-12-07', 'W')", ) .await? @@ -352,12 +356,13 @@ impl MaterializedListingTable { file_sort_order: opts.file_sort_order, }); + let mut listing_table_config = ListingTableConfig::new(config.table_path); + if let Some(options) = options { + listing_table_config = listing_table_config.with_listing_options(options); + } + listing_table_config = listing_table_config.with_schema(Arc::new(file_schema)); Ok(MaterializedListingTable { - inner: ListingTable::try_new(ListingTableConfig { - table_paths: vec![config.table_path], - file_schema: Some(Arc::new(file_schema)), - options, - })?, + inner: ListingTable::try_new(listing_table_config)?, query: normalized_query, schema: normalized_schema, }) @@ -503,7 +508,7 @@ impl TableProvider for MaterializedListingTable { self.inner.get_table_definition() } - fn get_logical_plan(&self) -> Option> { + fn get_logical_plan(&self) -> Option> { // We _could_ return the LogicalPlan here, // but it will cause this table to be treated like a regular view // and the materialized results will not be used. @@ -549,7 +554,7 @@ impl TableProvider for MaterializedListingTable { fn parse_partition_values( path: &ObjectPath, partition_columns: &[(String, DataType)], -) -> Result, DataFusionError> { +) -> Result, DataFusionError> { let parts = path.parts().map(|part| part.to_owned()).collect::>(); let pairs = parts @@ -561,7 +566,7 @@ fn parse_partition_values( .iter() .map(|(column, datatype)| { let value = pairs.get(column.as_str()).copied().map(String::from); - ScalarValue::Utf8(value).cast_to(datatype) + ScalarAndMetadata::from(ScalarValue::Utf8(value)).cast_storage_to(datatype) }) .collect::, _>>()?;