diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 14bf19f..177fcb8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,6 +35,8 @@ jobs: steps: - uses: actions/checkout@v3 - uses: crate-ci/typos@v1.13.10 + with: + config: .typos.toml check: name: Check @@ -132,7 +134,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - uses: EmbarkStudios/cargo-deny-action@v1 + - uses: EmbarkStudios/cargo-deny-action@v2 with: command: check license diff --git a/.typos.toml b/.typos.toml new file mode 100644 index 0000000..0c8feb8 --- /dev/null +++ b/.typos.toml @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# typos configuration file +# Place this file in the project root (same level as Cargo.toml) +[files] +extend-exclude = ["Cargo.toml", "**/Cargo.toml"] diff --git a/Cargo.toml b/Cargo.toml index f2ee4fd..1f388ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,22 +25,23 @@ authors = ["Matthew Cramerus "] license = "Apache-2.0" description = "Materialized Views & Query Rewriting in DataFusion" keywords = ["arrow", "arrow-rs", "datafusion"] -rust-version = "1.80" +rust-version = "1.88.0" [dependencies] -arrow = "55" -arrow-schema = "55" -async-trait = "0.1" +aquamarine = "0.6.0" +arrow = "57.0.0" +arrow-schema = "57.0.0" +async-trait = "0.1.89" dashmap = "6" -datafusion = "47" -datafusion-common = "47" -datafusion-expr = "47" -datafusion-functions = "47" -datafusion-functions-aggregate = "47" -datafusion-optimizer = "47" -datafusion-physical-expr = "47" -datafusion-physical-plan = "47" -datafusion-sql = "47" +datafusion = { git = "https://github.com/massive-com/arrow-datafusion", rev = "da5d59f" } +datafusion-common = { git = "https://github.com/massive-com/arrow-datafusion", rev = "da5d59f" } +datafusion-expr = { git = "https://github.com/massive-com/arrow-datafusion", rev = "da5d59f" } +datafusion-functions = { git = "https://github.com/massive-com/arrow-datafusion", rev = "da5d59f" } +datafusion-functions-aggregate = { git = "https://github.com/massive-com/arrow-datafusion", rev = "da5d59f" } +datafusion-optimizer = { git = "https://github.com/massive-com/arrow-datafusion", rev = "da5d59f" } +datafusion-physical-expr = { git = "https://github.com/massive-com/arrow-datafusion", rev = "da5d59f" } +datafusion-physical-plan = { git = "https://github.com/massive-com/arrow-datafusion", rev = "da5d59f" } +datafusion-sql = { git = "https://github.com/massive-com/arrow-datafusion", rev = "da5d59f" } futures = "0.3" itertools = "0.14" log = "0.4" @@ -49,7 +50,13 @@ ordered-float = "5.0.0" [dev-dependencies] anyhow = "1.0.95" +criterion = "0.4" env_logger = "0.11.6" tempfile = "3.14.0" tokio = "1.42.0" url = "2.5.4" + +[[bench]] +name = "materialized_views_benchmark" +harness = false +path = "benches/materialized_views_benchmark.rs" diff --git a/benches/materialized_views_benchmark.rs b/benches/materialized_views_benchmark.rs new file mode 100644 index 0000000..7bc2f52 --- /dev/null +++ b/benches/materialized_views_benchmark.rs @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use std::sync::Arc; +use std::time::Duration; + +use datafusion::datasource::provider_as_source; +use datafusion::datasource::TableProvider; +use datafusion::prelude::SessionContext; +use datafusion_common::Result as DfResult; +use datafusion_expr::LogicalPlan; +use datafusion_materialized_views::rewrite::normal_form::SpjNormalForm; +use datafusion_sql::TableReference; +use tokio::runtime::Builder; + +// Utility: generate CREATE TABLE SQL with n columns named c0..c{n-1} +fn make_create_table_sql(table_name: &str, ncols: usize) -> String { + let cols = (0..ncols) + .map(|i| format!("c{} INT", i)) + .collect::>() + .join(", "); + format!( + "CREATE TABLE {table} ({cols})", + table = table_name, + cols = cols + ) +} + +// Utility: generate a base SELECT that projects all columns and has a couple filters +fn make_base_sql(table_name: &str, ncols: usize) -> String { + let cols = (0..ncols) + .map(|i| format!("c{}", i)) + .collect::>() + .join(", "); + let mut where_clauses = vec![]; + if ncols > 0 { + where_clauses.push("c0 >= 0".to_string()); + } + if ncols > 1 { + where_clauses.push("c0 + c1 >= 0".to_string()); + } + let where_part = if where_clauses.is_empty() { + "".to_string() + } else { + format!(" WHERE {}", where_clauses.join(" AND ")) + }; + format!("SELECT {cols} FROM {table}{where}", cols = cols, table = table_name, where = where_part) +} + +// Utility: generate a query that is stricter and selects subset (so rewrite_from has a chance) +fn make_query_sql(table_name: &str, ncols: usize) -> String { + let take = std::cmp::max(1, ncols / 2); + let cols = (0..take) + .map(|i| format!("c{}", i)) + .collect::>() + .join(", "); + + let mut where_clauses = vec![]; + if ncols > 0 { + where_clauses.push("c0 >= 10".to_string()); + } + if ncols > 1 { + where_clauses.push("c0 * c1 > 100".to_string()); + } + if ncols > 10 { + where_clauses.push(format!("c{} >= 0", ncols - 1)); + } + + let where_part = if where_clauses.is_empty() { + "".to_string() + } else { + format!(" WHERE {}", where_clauses.join(" AND ")) + }; + + format!("SELECT {cols} FROM {table}{where}", cols = cols, table = table_name, where = where_part) +} + +// Build fixture: create SessionContext, the table, then return LogicalPlans for base & query and table provider +fn build_fixture_for_cols( + rt: &tokio::runtime::Runtime, + ncols: usize, +) -> DfResult<(LogicalPlan, LogicalPlan, Arc)> { + rt.block_on(async move { + let ctx = SessionContext::new(); + + // create table + let table_name = "t"; + let create_sql = make_create_table_sql(table_name, ncols); + ctx.sql(&create_sql).await?.collect().await?; // create table in catalog + + // base and query plans (optimize to normalize) + let base_sql = make_base_sql(table_name, ncols); + let query_sql = make_query_sql(table_name, ncols); + + let base_df = ctx.sql(&base_sql).await?; + let base_plan = base_df.into_optimized_plan()?; + + let query_df = ctx.sql(&query_sql).await?; + let query_plan = query_df.into_optimized_plan()?; + + // get table provider (Arc) + let table_ref = TableReference::bare(table_name); + let provider: Arc = ctx.table_provider(table_ref.clone()).await?; + + Ok((base_plan, query_plan, provider)) + }) +} + +// Criterion benchmark +fn criterion_benchmark(c: &mut Criterion) { + // columns to test + let col_cases = vec![1usize, 10, 20, 40, 80, 160, 320]; + + // build a tokio runtime that's broadly compatible + let rt = Builder::new_current_thread() + .enable_all() + .build() + .expect("tokio runtime"); + + let mut group = c.benchmark_group("view_matcher_spj"); + group.warm_up_time(Duration::from_secs(1)); + group.measurement_time(Duration::from_secs(5)); + group.sample_size(30); + + for &ncols in &col_cases { + // Build fixture + let (base_plan, query_plan, provider) = + build_fixture_for_cols(&rt, ncols).expect("fixture"); + + // Measure SpjNormalForm::new for base_plan and query_plan separately + let id_base = BenchmarkId::new("spj_normal_form_new", format!("cols={}", ncols)); + group.throughput(Throughput::Elements(1)); + group.bench_with_input(id_base, &base_plan, |b, plan| { + b.iter(|| { + let _nf = SpjNormalForm::new(plan).unwrap(); + }); + }); + + let id_query_nf = BenchmarkId::new("spj_normal_form_new_query", format!("cols={}", ncols)); + group.bench_with_input(id_query_nf, &query_plan, |b, plan| { + b.iter(|| { + let _nf = SpjNormalForm::new(plan).unwrap(); + }); + }); + + // Precompute normal forms once (to measure rewrite_from cost only) + let base_nf = SpjNormalForm::new(&base_plan).expect("base_nf"); + let query_nf = SpjNormalForm::new(&query_plan).expect("query_nf"); + + // qualifier for rewrite_from and a source created from the provider + let qualifier = TableReference::bare("mv"); + let source = provider_as_source(Arc::clone(&provider)); + + // Benchmark rewrite_from (this is the heavy check) + let id_rewrite = BenchmarkId::new("rewrite_from", format!("cols={}", ncols)); + group.bench_with_input(id_rewrite, &ncols, |b, &_n| { + b.iter(|| { + let _res = query_nf.rewrite_from(&base_nf, qualifier.clone(), Arc::clone(&source)); + }); + }); + } + + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/deny.toml b/deny.toml index a24420a..b516c5d 100644 --- a/deny.toml +++ b/deny.toml @@ -24,6 +24,8 @@ allow = [ "BSD-3-Clause", "CC0-1.0", "Unicode-3.0", - "Zlib" + "Zlib", + "ISC", + "bzip2-1.0.6" ] version = 2 diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..8cd75dd --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[toolchain] +channel = "1.91.0" +components = ["rust-analyzer", "rustfmt", "clippy"] \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 8b26d85..8e4ef48 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,16 +45,77 @@ pub mod materialized; pub mod rewrite; /// Configuration options for materialized view related features. +/// +/// # Materialized View Configuration +/// +/// Query rewriting uses two configuration options that work together: +/// +/// 1. **`use_in_query_rewrite` (on candidate MVs)**: Controls whether an MV is globally available +/// for query rewriting. MVs with `use_in_query_rewrite = false` are excluded from the +/// candidate pool entirely. +/// +/// 2. **`rewrite_targets` (on queried tables)**: When querying a table, this field filters which +/// MVs from the available pool should be considered as rewrite candidates for that specific table. +/// +/// The interaction works as follows: +/// - First, `use_in_query_rewrite` determines the global pool of available MVs +/// - Then, `rewrite_targets` on the queried table filters that pool for that specific query +/// - An MV must have `use_in_query_rewrite = true` **and** be in the `rewrite_targets` list +/// (or the list must be None) to be considered +/// +/// # Example +/// +/// ```ignore +/// // MV1: available for query rewriting +/// let mv1_config = MaterializedConfig { +/// use_in_query_rewrite: true, // MV1 is in the global pool +/// rewrite_targets: None, +/// }; +/// +/// // MV2: not available for query rewriting +/// let mv2_config = MaterializedConfig { +/// use_in_query_rewrite: false, // MV2 is excluded from the pool +/// rewrite_targets: None, +/// }; +/// +/// // Base table: only considers MV1 for rewrites +/// let base_config = MaterializedConfig { +/// use_in_query_rewrite: true, +/// rewrite_targets: Some(vec!["mv1".to_string()]), // Only MV1 is considered +/// }; +/// // When querying base_table: +/// // - MV1 will be considered (in pool + in targets list) +/// // - MV2 will NOT be considered (not in pool, even if added to targets list) +/// ``` #[derive(Debug, Clone)] pub struct MaterializedConfig { - /// Whether or not query rewriting should exploit this materialized view. + /// Whether or not this materialized view is available for query rewriting. + /// + /// If `false`, this MV will not be loaded into the query rewrite engine and cannot be used + /// as a rewrite candidate, regardless of any `rewrite_targets` settings on other tables. pub use_in_query_rewrite: bool, + + /// Optional candidate materialized views for query rewriting. + /// + /// When this table is queried, only the MVs listed here will be considered as rewrite candidates. + /// These should be full table names (e.g., `atlas.us_stocks_sip.trades_by_ticker`). + /// + /// - If `None` (default): all eligible MVs in the catalog (where `use_in_query_rewrite = true`) + /// are considered as rewrite candidates + /// - If `Some(vec![])`: no MVs are considered (effectively disables query rewriting for this table) + /// - If `Some(vec!["mv1", "mv2"])`: only mv1 and mv2 (if they have `use_in_query_rewrite = true`) + /// are considered as rewrite candidates + /// + /// Note: This field is typically set on the **queried table** (which may itself be an MV). + /// It acts as a whitelist that further filters the pool of available MVs for queries against this table. + pub rewrite_targets: Option>, } impl Default for MaterializedConfig { fn default() -> Self { Self { use_in_query_rewrite: true, + rewrite_targets: None, } } } diff --git a/src/materialized.rs b/src/materialized.rs index e089591..8e093c0 100644 --- a/src/materialized.rs +++ b/src/materialized.rs @@ -41,6 +41,7 @@ use datafusion::{ catalog::TableProvider, datasource::listing::{ListingTable, ListingTableUrl}, }; +use datafusion_common::DataFusionError; use datafusion_expr::LogicalPlan; use itertools::Itertools; @@ -110,6 +111,14 @@ pub trait Materialized: ListingTableLike { fn config(&self) -> MaterializedConfig { MaterializedConfig::default() } + + /// Which partition columns are 'static'. + /// Static partition columns are those that are used in incremental view maintenance. + /// These should be a prefix of the full set of partition columns returned by [`ListingTableLike::partition_columns`]. + /// The rest of the partition columns are 'dynamic' and their values will be generated at runtime during incremental refresh. + fn static_partition_columns(&self) -> Vec { + ::partition_columns(self) + } } /// Register a [`Materialized`] implementation in this registry. @@ -122,13 +131,38 @@ pub fn register_materialized() { } /// Attempt to cast the given TableProvider into a [`Materialized`]. -/// If the table's type has not been registered using [`register_materialized`], will return `None`. -pub fn cast_to_materialized(table: &dyn TableProvider) -> Option<&dyn Materialized> { - TABLE_TYPE_REGISTRY.cast_to_materialized(table).or_else(|| { - TABLE_TYPE_REGISTRY - .cast_to_decorator(table) - .and_then(|decorator| cast_to_materialized(decorator.base())) - }) +/// If the table's type has not been registered using [`register_materialized`], will return `Ok(None)`. +/// +/// Does a runtime check on some invariants of `Materialized` and returns an error if they are violated. +/// In particular, checks that the static partition columns are a prefix of all partition columns. +pub fn cast_to_materialized( + table: &dyn TableProvider, +) -> Result, DataFusionError> { + let materialized = match TABLE_TYPE_REGISTRY + .cast_to_materialized(table) + .map(Ok) + .or_else(|| { + TABLE_TYPE_REGISTRY + .cast_to_decorator(table) + .and_then(|decorator| cast_to_materialized(decorator.base()).transpose()) + }) + .transpose()? + { + None => return Ok(None), + Some(m) => m, + }; + + let static_partition_cols = materialized.static_partition_columns(); + let all_partition_cols = materialized.partition_columns(); + + if materialized.partition_columns()[..static_partition_cols.len()] != static_partition_cols[..] + { + return Err(DataFusionError::Plan(format!( + "Materialized view's static partition columns ({static_partition_cols:?}) must be a prefix of all partition columns ({all_partition_cols:?})" + ))); + } + + Ok(Some(materialized)) } /// A `TableProvider` that decorates other `TableProvider`s. diff --git a/src/materialized/dependencies.rs b/src/materialized/dependencies.rs index 5b42692..9cf8b8a 100644 --- a/src/materialized/dependencies.rs +++ b/src/materialized/dependencies.rs @@ -106,7 +106,7 @@ impl TableFunctionImpl for FileDependenciesUdtf { let table = util::get_table(self.catalog_list.as_ref(), &table_ref) .map_err(|e| DataFusionError::Plan(e.to_string()))?; - let mv = cast_to_materialized(table.as_ref()).ok_or(DataFusionError::Plan(format!( + let mv = cast_to_materialized(table.as_ref())?.ok_or(DataFusionError::Plan(format!( "mv_dependencies: table '{table_name} is not a materialized view. (Materialized TableProviders must be registered using register_materialized"), ))?; @@ -166,6 +166,15 @@ impl TableFunctionImpl for StaleFilesUdtf { &self.mv_dependencies.config_options.catalog.default_schema, ); + let table = util::get_table(self.mv_dependencies.catalog_list.as_ref(), &table_ref) + .map_err(|e| DataFusionError::Plan(e.to_string()))?; + let mv = cast_to_materialized(table.as_ref())?.ok_or(DataFusionError::Plan(format!( + "mv_dependencies: table '{table_name} is not a materialized view. (Materialized TableProviders must be registered using register_materialized"), + ))?; + + let url = mv.table_paths()[0].to_string(); + let num_static_partition_cols = mv.static_partition_columns().len(); + let logical_plan = LogicalPlanBuilder::scan_with_filters("dependencies", dependencies, None, vec![])? .aggregate( @@ -187,16 +196,21 @@ impl TableFunctionImpl for StaleFilesUdtf { )? .aggregate( vec![ - // We want to omit the file name along with any "special" partitions - // from the path before comparing it to the target partition. Special - // partitions must be leaf most nodes and are designated by a leading - // underscore. These are useful for adding additional information to a - // filename without affecting partitioning or staleness checks. - regexp_replace( - col("file_path"), - lit(r"(/_[^/=]+=[^/]+)*/[^/]*$"), - lit("/"), - None, + // Omit the file name along with any "special" partitions. + // This can include dynamic partition columns as well as some internal + // metadata columns that are not part of the schema + // + // We implement this by only taking the first N columns, + // where N is the number of static partition columns. + array_element( + regexp_match( + col("file_path"), + lit(format!( + "{url}(?:[^/=]+=[^/]+/){{{num_static_partition_cols}}}" + )), + None, + ), + lit(1), ) .alias("existing_target"), ], @@ -231,7 +245,7 @@ impl TableFunctionImpl for StaleFilesUdtf { /// Extract table name from args passed to TableFunctionImpl::call() fn get_table_name(args: &[Expr]) -> Result<&String> { match &args[0] { - Expr::Literal(ScalarValue::Utf8(Some(table_name))) => Ok(table_name), + Expr::Literal(ScalarValue::Utf8(Some(table_name)), _) => Ok(table_name), _ => Err(DataFusionError::Plan( "expected a single string literal argument to mv_dependencies".to_string(), )), @@ -249,16 +263,16 @@ pub fn mv_dependencies_plan( let plan = materialized_view.query().clone(); - let partition_cols = materialized_view.partition_columns(); - let partition_col_indices = plan + let static_partition_cols = materialized_view.static_partition_columns(); + let static_partition_col_indices = plan .schema() .fields() .iter() .enumerate() - .filter_map(|(i, f)| partition_cols.contains(f.name()).then_some(i)) + .filter_map(|(i, f)| static_partition_cols.contains(f.name()).then_some(i)) .collect(); - let pruned_plan_with_source_files = if partition_cols.is_empty() { + let pruned_plan_with_source_files = if static_partition_cols.is_empty() { get_source_files_all_partitions( materialized_view, &config_options.catalog, @@ -266,14 +280,14 @@ pub fn mv_dependencies_plan( ) } else { // Prune non-partition columns from all table scans - let pruned_plan = pushdown_projection_inexact(plan, &partition_col_indices)?; + let pruned_plan = pushdown_projection_inexact(plan, &static_partition_col_indices)?; // Now bubble up file metadata to the top of the plan push_up_file_metadata(pruned_plan, &config_options.catalog, row_metadata_registry) }?; // We now have data in the following form: - // (partition_col0, partition_col1, ..., __meta) + // (static_partition_col0, static_partition_col1, ..., __meta) // The last column is a list of structs containing the row metadata // We need to unnest it @@ -289,7 +303,7 @@ pub fn mv_dependencies_plan( LogicalPlanBuilder::from(pruned_plan_with_source_files) .unnest_column(files)? .project(vec![ - construct_target_path_from_partition_columns(materialized_view).alias("target"), + construct_target_path_from_static_partition_columns(materialized_view).alias("target"), get_field(files_col.clone(), "table_catalog").alias("source_table_catalog"), get_field(files_col.clone(), "table_schema").alias("source_table_schema"), get_field(files_col.clone(), "table_name").alias("source_table_name"), @@ -300,14 +314,16 @@ pub fn mv_dependencies_plan( .build() } -fn construct_target_path_from_partition_columns(materialized_view: &dyn Materialized) -> Expr { +fn construct_target_path_from_static_partition_columns( + materialized_view: &dyn Materialized, +) -> Expr { let table_path = lit(materialized_view.table_paths()[0] .as_str() // Trim the / (we'll add it back later if we need it) .trim_end_matches("/")); // Construct the paths for the build targets let mut hive_column_path_elements = materialized_view - .partition_columns() + .static_partition_columns() .iter() .map(|column_name| concat([lit(column_name.as_str()), lit("="), col(column_name)].to_vec())) .collect::>(); @@ -603,6 +619,21 @@ fn pushdown_projection_inexact(plan: LogicalPlan, indices: &HashSet) -> R .map(Expr::Column) .collect_vec(); + // GUARD: if after pushdown the set of relevant unnest columns is empty, + // avoid constructing an Unnest node with zero exec columns (which will + // later error in Unnest::try_new). Instead, simply project the + // desired output columns from the child plan (after pushing down the child projection). + // Related PR: https://github.com/apache/datafusion/pull/16632, after that we must + // also check for empty exec columns here. + if columns_to_unnest.is_empty() { + return LogicalPlanBuilder::from(pushdown_projection_inexact( + Arc::unwrap_or_clone(unnest.input), + &child_indices, + )?) + .project(columns_to_project)? + .build(); + } + LogicalPlanBuilder::from(pushdown_projection_inexact( Arc::unwrap_or_clone(unnest.input), &child_indices, @@ -922,7 +953,7 @@ mod test { use std::{any::Any, collections::HashSet, sync::Arc}; use arrow::util::pretty::pretty_format_batches; - use arrow_schema::SchemaRef; + use arrow_schema::{DataType, Field, FieldRef, Fields, SchemaRef}; use datafusion::{ assert_batches_eq, assert_batches_sorted_eq, catalog::{Session, TableProvider}, @@ -930,8 +961,9 @@ mod test { execution::session_state::SessionStateBuilder, prelude::{DataFrame, SessionConfig, SessionContext}, }; - use datafusion_common::{Column, Result, ScalarValue}; - use datafusion_expr::{Expr, JoinType, LogicalPlan, TableType}; + use datafusion_common::{Column, DFSchema, Result, ScalarValue}; + use datafusion_expr::builder::unnest; + use datafusion_expr::{EmptyRelation, Expr, JoinType, LogicalPlan, TableType}; use datafusion_physical_plan::ExecutionPlan; use itertools::Itertools; @@ -949,6 +981,7 @@ mod test { struct MockMaterializedView { table_path: ListingTableUrl, partition_columns: Vec, + static_partition_columns: Option>, // default = all partition columns query: LogicalPlan, file_ext: &'static str, } @@ -996,6 +1029,12 @@ mod test { fn query(&self) -> LogicalPlan { self.query.clone() } + + fn static_partition_columns(&self) -> Vec { + self.static_partition_columns + .clone() + .unwrap_or_else(|| self.partition_columns.clone()) + } } #[derive(Debug)] @@ -1165,12 +1204,14 @@ mod test { #[tokio::test] async fn test_deps() { + #[derive(Debug, Default)] struct TestCase { name: &'static str, query_to_analyze: &'static str, table_name: &'static str, - table_path: ListingTableUrl, + table_path: &'static str, partition_cols: Vec<&'static str>, + static_partition_cols: Option>, file_extension: &'static str, expected_output: Vec<&'static str>, file_metadata: &'static str, @@ -1182,7 +1223,7 @@ mod test { query_to_analyze: "SELECT column1 AS partition_column, concat(column2, column3) AS some_value FROM t1", table_name: "m1", - table_path: ListingTableUrl::parse("s3://m1/").unwrap(), + table_path: "s3://m1/", partition_cols: vec!["partition_column"], file_extension: ".parquet", expected_output: vec![ @@ -1209,12 +1250,13 @@ mod test { "| s3://m1/partition_column=2023/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:44 | false |", "+--------------------------------+----------------------+-----------------------+----------+", ], + ..Default::default() }, - TestCase { name: "omit 'special' partition columns", + TestCase { name: "omit internal metadata partition columns", query_to_analyze: "SELECT column1 AS partition_column, concat(column2, column3) AS some_value FROM t1", table_name: "m1", - table_path: ListingTableUrl::parse("s3://m1/").unwrap(), + table_path: "s3://m1/", partition_cols: vec!["partition_column"], file_extension: ".parquet", expected_output: vec![ @@ -1241,6 +1283,7 @@ mod test { "| s3://m1/partition_column=2023/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:44 | false |", "+--------------------------------+----------------------+-----------------------+----------+", ], + ..Default::default() }, TestCase { name: "transform year/month/day partition into timestamp partition", @@ -1250,7 +1293,7 @@ mod test { feed FROM t2", table_name: "m2", - table_path: ListingTableUrl::parse("s3://m2/").unwrap(), + table_path: "s3://m2/", partition_cols: vec!["timestamp", "feed"], file_extension: ".parquet", expected_output: vec![ @@ -1285,12 +1328,63 @@ mod test { "| s3://m2/timestamp=2024-12-06T00:00:00/feed=Z/ | 2023-07-10T16:00:00 | 2023-07-11T16:45:44 | true |", "+-----------------------------------------------+----------------------+-----------------------+----------+", ], + ..Default::default() + }, + TestCase { + name: "omit dynamic partition columns", + query_to_analyze: " + SELECT + year, + month, + day, + column2, + COUNT(*) AS ct + FROM t2 + GROUP BY year, month, day, column2 + ", + table_name: "m_dynamic", + table_path: "s3://m_dynamic/", + partition_cols: vec!["year", "month", "day", "column2"], + static_partition_cols: Some(vec!["year", "month", "day"]), + file_extension: ".parquet", + expected_output: vec![ + "+-------------------------------------------+----------------------+---------------------+-------------------+----------------------------------------------------------+----------------------+", + "| target | source_table_catalog | source_table_schema | source_table_name | source_uri | source_last_modified |", + "+-------------------------------------------+----------------------+---------------------+-------------------+----------------------------------------------------------+----------------------+", + "| s3://m_dynamic/year=2023/month=01/day=01/ | datafusion | test | t2 | s3://t2/year=2023/month=01/day=01/feed=A/data.01.parquet | 2023-07-11T16:29:26 |", + "| s3://m_dynamic/year=2023/month=01/day=02/ | datafusion | test | t2 | s3://t2/year=2023/month=01/day=02/feed=B/data.01.parquet | 2023-07-11T16:45:22 |", + "| s3://m_dynamic/year=2023/month=01/day=03/ | datafusion | test | t2 | s3://t2/year=2023/month=01/day=03/feed=C/data.01.parquet | 2023-07-11T16:45:44 |", + "| s3://m_dynamic/year=2024/month=12/day=04/ | datafusion | test | t2 | s3://t2/year=2024/month=12/day=04/feed=X/data.01.parquet | 2023-07-11T16:29:26 |", + "| s3://m_dynamic/year=2024/month=12/day=05/ | datafusion | test | t2 | s3://t2/year=2024/month=12/day=05/feed=Y/data.01.parquet | 2023-07-11T16:45:22 |", + "| s3://m_dynamic/year=2024/month=12/day=06/ | datafusion | test | t2 | s3://t2/year=2024/month=12/day=06/feed=Z/data.01.parquet | 2023-07-11T16:45:44 |", + "+-------------------------------------------+----------------------+---------------------+-------------------+----------------------------------------------------------+----------------------+", + ], + file_metadata: " + ('datafusion', 'test', 'm_dynamic', 's3://m_dynamic/year=2023/month=01/day=01/column2=1/data.01.parquet', '2023-07-12T16:00:00Z', 0), + ('datafusion', 'test', 'm_dynamic', 's3://m_dynamic/year=2023/month=01/day=02/column2=2/data.01.parquet', '2023-07-12T16:00:00Z', 0), + ('datafusion', 'test', 'm_dynamic', 's3://m_dynamic/year=2023/month=01/day=03/column2=3/data.01.parquet', '2023-07-10T16:00:00Z', 0), + ('datafusion', 'test', 'm_dynamic', 's3://m_dynamic/year=2024/month=12/day=04/column2=4/data.01.parquet', '2023-07-12T16:00:00Z', 0), + ('datafusion', 'test', 'm_dynamic', 's3://m_dynamic/year=2024/month=12/day=05/column2=5/data.01.parquet', '2023-07-12T16:00:00Z', 0), + ('datafusion', 'test', 'm_dynamic', 's3://m_dynamic/year=2024/month=12/day=06/column2=6/data.01.parquet', '2023-07-10T16:00:00Z', 0) + ", + expected_stale_files_output: vec![ + "+-------------------------------------------+----------------------+-----------------------+----------+", + "| target | target_last_modified | sources_last_modified | is_stale |", + "+-------------------------------------------+----------------------+-----------------------+----------+", + "| s3://m_dynamic/year=2023/month=01/day=01/ | 2023-07-12T16:00:00 | 2023-07-11T16:29:26 | false |", + "| s3://m_dynamic/year=2023/month=01/day=02/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:22 | false |", + "| s3://m_dynamic/year=2023/month=01/day=03/ | 2023-07-10T16:00:00 | 2023-07-11T16:45:44 | true |", + "| s3://m_dynamic/year=2024/month=12/day=04/ | 2023-07-12T16:00:00 | 2023-07-11T16:29:26 | false |", + "| s3://m_dynamic/year=2024/month=12/day=05/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:22 | false |", + "| s3://m_dynamic/year=2024/month=12/day=06/ | 2023-07-10T16:00:00 | 2023-07-11T16:45:44 | true |", + "+-------------------------------------------+----------------------+-----------------------+----------+", + ], }, TestCase { name: "materialized view has no partitions", query_to_analyze: "SELECT column1 AS output FROM t3", table_name: "m3", - table_path: ListingTableUrl::parse("s3://m3/").unwrap(), + table_path: "s3://m3/", partition_cols: vec![], file_extension: ".parquet", expected_output: vec![ @@ -1311,12 +1405,13 @@ mod test { "| s3://m3/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:44 | false |", "+----------+----------------------+-----------------------+----------+", ], + ..Default::default() }, TestCase { name: "simple equijoin on year", query_to_analyze: "SELECT * FROM t2 INNER JOIN t3 USING (year)", table_name: "m4", - table_path: ListingTableUrl::parse("s3://m4/").unwrap(), + table_path: "s3://m4/", partition_cols: vec!["year"], file_extension: ".parquet", expected_output: vec![ @@ -1345,6 +1440,7 @@ mod test { "| s3://m4/year=2024/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:44 | false |", "+--------------------+----------------------+-----------------------+----------+", ], + ..Default::default() }, TestCase { name: "triangular join on year", @@ -1357,7 +1453,7 @@ mod test { INNER JOIN t3 ON (t2.year <= t3.year)", table_name: "m4", - table_path: ListingTableUrl::parse("s3://m4/").unwrap(), + table_path: "s3://m4/", partition_cols: vec!["year"], file_extension: ".parquet", expected_output: vec![ @@ -1387,6 +1483,7 @@ mod test { "| s3://m4/year=2024/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:44 | false |", "+--------------------+----------------------+-----------------------+----------+", ], + ..Default::default() }, TestCase { name: "triangular left join, strict <", @@ -1399,7 +1496,7 @@ mod test { LEFT JOIN t3 ON (t2.year < t3.year)", table_name: "m4", - table_path: ListingTableUrl::parse("s3://m4/").unwrap(), + table_path: "s3://m4/", partition_cols: vec!["year"], file_extension: ".parquet", expected_output: vec![ @@ -1427,6 +1524,7 @@ mod test { "| s3://m4/year=2024/ | 2023-07-12T16:00:00 | 2023-07-11T16:45:44 | false |", "+--------------------+----------------------+-----------------------+----------+", ], + ..Default::default() }, ]; @@ -1447,7 +1545,7 @@ mod test { .enumerate() .filter_map(|(i, c)| case.partition_cols.contains(&c.name.as_str()).then_some(i)) .collect(); - println!("indices: {:?}", partition_col_indices); + println!("indices: {partition_col_indices:?}"); let analyzed = pushdown_projection_inexact(plan.clone(), &partition_col_indices)?; println!( "inexact projection pushdown:\n{}", @@ -1460,12 +1558,16 @@ mod test { // Register table with a decorator to exercise this functionality Arc::new(DecoratorTable { inner: Arc::new(MockMaterializedView { - table_path: case.table_path.clone(), + table_path: ListingTableUrl::parse(case.table_path).unwrap(), partition_columns: case .partition_cols .iter() .map(|s| s.to_string()) .collect(), + static_partition_columns: case + .static_partition_cols + .as_ref() + .map(|list| list.iter().map(|s| s.to_string()).collect()), query: plan, file_ext: case.file_extension, }), @@ -1482,6 +1584,15 @@ mod test { .collect() .await?; + context + .sql(&format!( + "SELECT * FROM file_metadata WHERE table_name = '{}'", + case.table_name + )) + .await? + .show() + .await?; + let df = context .sql(&format!( "SELECT * FROM mv_dependencies('{}', 'v2')", @@ -1720,19 +1831,19 @@ mod test { ", projection: &["year"], expected_plan: vec![ - "+--------------+--------------------------------------------------+", - "| plan_type | plan |", - "+--------------+--------------------------------------------------+", - "| logical_plan | Union |", - "| | Projection: coalesce(t1.year, t2.year) AS year |", - "| | Full Join: Using t1.year = t2.year |", - "| | SubqueryAlias: t1 |", - "| | Projection: t1.column1 AS year |", - "| | TableScan: t1 projection=[column1] |", - "| | SubqueryAlias: t2 |", - "| | TableScan: t2 projection=[year] |", - "| | TableScan: t3 projection=[year] |", - "+--------------+--------------------------------------------------+", + "+--------------+--------------------------------------------------------------------+", + "| plan_type | plan |", + "+--------------+--------------------------------------------------------------------+", + "| logical_plan | Union |", + "| | Projection: coalesce(CAST(t1.year AS Utf8View), t2.year) AS year |", + "| | Full Join: Using CAST(t1.year AS Utf8View) = t2.year |", + "| | SubqueryAlias: t1 |", + "| | Projection: t1.column1 AS year |", + "| | TableScan: t1 projection=[column1] |", + "| | SubqueryAlias: t2 |", + "| | TableScan: t2 projection=[year] |", + "| | TableScan: t3 projection=[year] |", + "+--------------+--------------------------------------------------------------------+", ], expected_output: vec![ "+------+", @@ -1837,4 +1948,86 @@ mod test { Ok(()) } + + #[test] + fn test_pushdown_unnest_guard_partition_date_only() -> Result<()> { + // This test simulates a simplified MV scenario: + // + // WITH events_structs AS ( + // SELECT id, date, unnest(events) AS evs + // FROM base_table + // ), + // flattened_events AS ( + // SELECT id, date, evs.event_type, evs.event_time + // FROM events_structs + // ), + // SELECT id, date, MAX(...) ... + // GROUP BY id, date + // + // The partition column is "date". During dependency plan + // building we only request "date" from this subtree, + // so pushdown_projection_inexact receives indices for + // the `date` column only. The guard must kick in: + // unnest(events) becomes unused, and the plan should + // collapse to just projecting `date` from the child. + + // 1. Build schema for base table + let id = Field::new("id", DataType::Utf8, true); + let date = Field::new("date", DataType::Utf8, true); + + // events: list> + let event_type = Field::new("event_type", DataType::Utf8, true); + let event_time = Field::new("event_time", DataType::Utf8, true); + let events_struct = Field::new( + "item", + DataType::Struct(Fields::from(vec![event_type, event_time])), + true, + ); + let events = Field::new( + "events", + DataType::List(FieldRef::from(Box::new(events_struct))), + true, + ); + + // Build DFSchema: (id, date, events) + let qualified_fields = vec![ + (None, Arc::new(id.clone())), + (None, Arc::new(date.clone())), + (None, Arc::new(events.clone())), + ]; + let df_schema = + DFSchema::new_with_metadata(qualified_fields, std::collections::HashMap::new())?; + + // 2. Build a dummy child plan (EmptyRelation with the schema) + let empty = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(df_schema), + }); + + // 3. Wrap it with an Unnest node on the "events" column + let events_col = Column::from_name("events"); + let unnest_plan = unnest(empty.clone(), vec![events_col.clone()])?; + + // 4. Partition column is "date". Look up its actual index dynamically. + let date_idx = unnest_plan + .schema() + .index_of_column(&Column::from_name("date"))?; + let mut indices: HashSet = HashSet::new(); + indices.insert(date_idx); + + // 5. Call pushdown_projection_inexact with {date} + let res = pushdown_projection_inexact(unnest_plan, &indices)?; + + // 6. Assert the result schema only contains `date` + let cols: Vec = res + .schema() + .fields() + .iter() + .map(|f| f.name().to_string()) + .collect(); + + assert_eq!(cols, vec!["date"]); + + Ok(()) + } } diff --git a/src/materialized/file_metadata.rs b/src/materialized/file_metadata.rs index d5a5c8e..85e5838 100644 --- a/src/materialized/file_metadata.rs +++ b/src/materialized/file_metadata.rs @@ -17,8 +17,9 @@ use arrow::array::{StringBuilder, TimestampNanosecondBuilder, UInt64Builder}; use arrow::record_batch::RecordBatch; -use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit}; +use arrow_schema::{DataType, Field, TimeUnit}; use async_trait::async_trait; +use datafusion::arrow::datatypes::{Schema, SchemaRef}; use datafusion::catalog::SchemaProvider; use datafusion::catalog::{CatalogProvider, Session}; use datafusion::datasource::listing::ListingTableUrl; @@ -35,7 +36,7 @@ use datafusion::physical_plan::{ use datafusion::{ catalog::CatalogProviderList, execution::TaskContext, physical_plan::SendableRecordBatchStream, }; -use datafusion_common::{DataFusionError, Result, ScalarValue, ToDFSchema}; +use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::{Expr, Operator, TableProviderFilterPushDown, TableType}; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use futures::stream::{self, BoxStream}; @@ -103,7 +104,7 @@ impl TableProvider for FileMetadata { filters: &[Expr], limit: Option, ) -> Result> { - let dfschema = self.table_schema.clone().to_dfschema()?; + let dfschema = DFSchema::try_from(self.table_schema.as_ref().clone())?; let filters = filters .iter() @@ -226,7 +227,7 @@ impl ExecutionPlan for FileMetadataExec { .map(|record_batch| { record_batch .project(&projection) - .map_err(|e| DataFusionError::ArrowError(e, None)) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) }) .collect::>(); } @@ -858,7 +859,7 @@ mod test { .await?; ctx.sql( - "INSERT INTO t1 VALUES + "INSERT INTO t1 VALUES (1, '2021'), (2, '2022'), (3, '2023'), @@ -882,7 +883,7 @@ mod test { .await?; ctx.sql( - "INSERT INTO private.t1 VALUES + "INSERT INTO private.t1 VALUES (1, '2021', '01'), (2, '2022', '02'), (3, '2023', '03'), @@ -906,7 +907,7 @@ mod test { .await?; ctx.sql( - "INSERT INTO datafusion_mv.public.t3 VALUES + "INSERT INTO datafusion_mv.public.t3 VALUES (1, '2021-01-01'), (2, '2022-02-02'), (3, '2023-03-03'), @@ -929,8 +930,8 @@ mod test { ctx.sql( // Remove timestamps and trim (randomly generated) file names since they're not stable in tests "CREATE VIEW file_metadata_test_view AS SELECT - * EXCLUDE(file_path, last_modified), - regexp_replace(file_path, '/[^/]*$', '/') AS file_path + * EXCLUDE(file_path, last_modified), + regexp_replace(file_path, '/[^/]*$', '/') AS file_path FROM file_metadata", ) .await diff --git a/src/materialized/hive_partition.rs b/src/materialized/hive_partition.rs index 43ebfde..ad381cb 100644 --- a/src/materialized/hive_partition.rs +++ b/src/materialized/hive_partition.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use arrow::array::{Array, StringArray, StringBuilder}; -use arrow_schema::DataType; +use datafusion::arrow::datatypes::DataType; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{ @@ -79,7 +79,7 @@ pub fn hive_partition_udf() -> ScalarUDF { ScalarUDF::new_from_impl(udf_impl) } -#[derive(Debug)] +#[derive(Debug, Hash, PartialEq, Eq)] struct HivePartitionUdf { pub signature: Signature, } diff --git a/src/materialized/row_metadata.rs b/src/materialized/row_metadata.rs index 6476f39..fa12cdf 100644 --- a/src/materialized/row_metadata.rs +++ b/src/materialized/row_metadata.rs @@ -98,7 +98,7 @@ impl RowMetadataRegistry { .get(&table.to_string()) .map(|o| Arc::clone(o.value())) .or_else(|| self.default_source.clone()) - .ok_or_else(|| DataFusionError::Internal(format!("No metadata source for {}", table))) + .ok_or_else(|| DataFusionError::Internal(format!("No metadata source for {table}"))) } } diff --git a/src/rewrite/exploitation.rs b/src/rewrite/exploitation.rs index b6fb761..a528220 100644 --- a/src/rewrite/exploitation.rs +++ b/src/rewrite/exploitation.rs @@ -23,7 +23,7 @@ use datafusion::catalog::TableProvider; use datafusion::datasource::provider_as_source; use datafusion::execution::context::SessionState; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; -use datafusion::physical_expr::{LexRequirement, PhysicalSortExpr, PhysicalSortRequirement}; +use datafusion::physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; use datafusion::physical_expr_common::sort_expr::format_physical_sort_requirement_list; use datafusion::physical_optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; @@ -32,6 +32,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, Tre use datafusion_common::{DataFusionError, Result, TableReference}; use datafusion_expr::{Extension, LogicalPlan, UserDefinedLogicalNode, UserDefinedLogicalNodeCore}; use datafusion_optimizer::OptimizerRule; +use datafusion_physical_expr::OrderingRequirements; use itertools::Itertools; use ordered_float::OrderedFloat; @@ -41,7 +42,9 @@ use super::normal_form::SpjNormalForm; use super::QueryRewriteOptions; /// A cost function. Used to evaluate the best physical plan among multiple equivalent choices. -pub type CostFn = Arc f64 + Send + Sync>; +pub type CostFn = Arc< + dyn for<'a> Fn(Box + 'a>) -> Vec + Send + Sync, +>; /// A logical optimizer that generates candidate logical plans in the form of [`OneOf`] nodes. #[derive(Debug)] @@ -56,10 +59,14 @@ impl ViewMatcher { for (resolved_table_ref, table) in super::util::list_tables(session_state.catalog_list().as_ref()).await? { - let Some(mv) = cast_to_materialized(table.as_ref()) else { + let Some(mv) = cast_to_materialized(table.as_ref())? else { continue; }; + // Filter the global pool of candidate MVs by their use_in_query_rewrite flag. + // MVs with use_in_query_rewrite = false are excluded from the pool entirely. + // Additional per-query filtering happens in ViewMatchingRewriter::f_down based on + // the queried table's rewrite_targets configuration. if !mv.config().use_in_query_rewrite { continue; } @@ -83,6 +90,59 @@ impl ViewMatcher { Ok(ViewMatcher { mv_plans }) } + + /// Returns the materialized views and their corresponding normal forms. + pub fn mv_plans(&self) -> &HashMap, SpjNormalForm)> { + &self.mv_plans + } + + /// Returns the list of MV candidates that would be considered for rewriting queries against the given table. + /// + /// This applies both stages of filtering: + /// 1. The global pool (already filtered by `use_in_query_rewrite` in `try_new_from_state`) + /// 2. The table's `rewrite_targets` configuration (if specified) + /// + /// # Arguments + /// * `table_reference` - The table being queried + /// + /// # Returns + /// A vector of table references for MVs that would be considered as rewrite candidates. + /// Returns all MVs from the global pool if the table doesn't specify `rewrite_targets`. + pub fn get_rewrite_candidates_for_table( + &self, + table_reference: &TableReference, + ) -> Result> { + // Check if the queried table has rewrite_targets specified + let rewrite_targets_filter: Option> = self + .mv_plans + .get(table_reference) + .and_then(|(table, _)| cast_to_materialized(table.as_ref()).ok().flatten()) + .and_then(|mv| { + mv.config().rewrite_targets.as_ref().map(|targets| { + targets + .iter() + .map(|s| TableReference::parse_str(s)) + .collect::>() + }) + }); + + // Filter MVs based on rewrite_targets + let candidates: Vec = self + .mv_plans + .keys() + .filter(|table_ref| { + // If rewrite_targets is specified, only include MVs in the list. + // If None, include all MVs from the global pool. + rewrite_targets_filter + .as_ref() + .map(|filter| filter.contains(table_ref)) + .unwrap_or(true) + }) + .cloned() + .collect(); + + Ok(candidates) + } } impl OptimizerRule for ViewMatcher { @@ -144,12 +204,22 @@ impl TreeNodeRewriter for ViewMatchingRewriter<'_> { Ok(form) => form, }; - // Generate candidate substitutions - let candidates = self + // Apply the second stage of MV filtering: get MVs that should be considered for this table. + // This uses get_rewrite_candidates_for_table which filters the global pool + // (already filtered by use_in_query_rewrite in try_new_from_state) based on + // the queried table's rewrite_targets config. + let candidate_mvs = self .parent - .mv_plans + .get_rewrite_candidates_for_table(&table_reference)?; + + // Generate candidate substitutions + let candidates = candidate_mvs .iter() - .filter_map(|(table_ref, (table, plan))| { + .filter_map(|table_ref| { + let (table, plan) = self.parent.mv_plans.get(table_ref)?; + Some((table_ref, table, plan)) + }) + .filter_map(|(table_ref, table, plan)| { // Only attempt rewrite if the view references our table in the first place plan.referenced_tables() .contains(&table_reference) @@ -272,6 +342,13 @@ pub struct OneOf { branches: Vec, } +impl OneOf { + /// Create a new OneOf node with the given branches. + pub fn new(branches: Vec) -> Self { + Self { branches } + } +} + impl UserDefinedLogicalNodeCore for OneOf { fn name(&self) -> &str { "OneOf" @@ -316,7 +393,7 @@ pub struct OneOfExec { // Optionally declare a required input ordering // This will inform DataFusion to add sorts to children, // which will improve cost estimation of candidates - required_input_ordering: Option, + required_input_ordering: Option, // Index of the candidate with the best cost best: usize, // Cost function to use in optimization @@ -337,7 +414,7 @@ impl OneOfExec { /// Create a new `OneOfExec` pub fn try_new( candidates: Vec>, - required_input_ordering: Option, + required_input_ordering: Option, cost: CostFn, ) -> Result { if candidates.is_empty() { @@ -345,9 +422,10 @@ impl OneOfExec { "can't create OneOfExec with empty children".to_string(), )); } - let best = candidates + + let best = cost(Box::new(candidates.iter().map(|c| c.as_ref()))) .iter() - .position_min_by_key(|candidate| OrderedFloat(cost(candidate.as_ref()))) + .position_min_by_key(|&cost| OrderedFloat(*cost)) .unwrap(); Ok(Self { @@ -366,7 +444,7 @@ impl OneOfExec { /// Modify this plan's required input ordering. /// Used for sort pushdown - pub fn with_required_input_ordering(self, requirement: Option) -> Self { + pub fn with_required_input_ordering(self, requirement: Option) -> Self { Self { required_input_ordering: requirement, ..self @@ -387,7 +465,7 @@ impl ExecutionPlan for OneOfExec { self.candidates[self.best].properties() } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { vec![self.required_input_ordering.clone(); self.children().len()] } @@ -427,17 +505,24 @@ impl ExecutionPlan for OneOfExec { } fn statistics(&self) -> Result { - self.candidates[self.best].statistics() + self.candidates[self.best].partition_statistics(None) + } + + fn partition_statistics( + &self, + partition: Option, + ) -> Result { + self.candidates[self.best].partition_statistics(partition) + } + + fn supports_limit_pushdown(&self) -> bool { + true } } impl DisplayAs for OneOfExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { - let costs = self - .children() - .iter() - .map(|c| (self.cost)(c.as_ref())) - .collect_vec(); + let costs = (self.cost)(Box::new(self.children().iter().map(|arc| arc.as_ref()))); match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!( @@ -448,12 +533,16 @@ impl DisplayAs for OneOfExec { format_physical_sort_requirement_list( &self .required_input_ordering - .clone() - .unwrap_or_default() - .into_iter() - .map(PhysicalSortExpr::from) - .map(PhysicalSortRequirement::from) - .collect_vec() + .as_ref() + .map(|req| { + req.clone() + .into_single() + .into_iter() + .map(PhysicalSortExpr::from) + .map(PhysicalSortRequirement::from) + .collect_vec() + }) + .unwrap_or_default(), ) ) } diff --git a/src/rewrite/normal_form.rs b/src/rewrite/normal_form.rs index 0d6f7d0..9cce9d8 100644 --- a/src/rewrite/normal_form.rs +++ b/src/rewrite/normal_form.rs @@ -233,22 +233,11 @@ impl SpjNormalForm { .map(|expr| predicate.normalize_expr(expr)) .collect(); - let mut referenced_tables = vec![]; - original_plan - .apply(|plan| { - if let LogicalPlan::TableScan(scan) = plan { - referenced_tables.push(scan.table_name.clone()); - } - - Ok(TreeNodeRecursion::Continue) - }) - // No chance of error since we never return Err -- this unwrap is safe - .unwrap(); - Ok(Self { output_schema: Arc::clone(original_plan.schema()), output_exprs, - referenced_tables, + // Reuse referenced_tables collected during Predicate::new to avoid extra traversal + referenced_tables: predicate.referenced_tables.clone(), predicate, }) } @@ -258,32 +247,50 @@ impl SpjNormalForm { /// This is useful for rewriting queries to use materialized views. pub fn rewrite_from( &self, - mut other: &Self, + other: &Self, qualifier: TableReference, source: Arc, ) -> Result> { log::trace!("rewriting from {qualifier}"); + + // Cache columns() result to avoid repeated Vec allocation in the loop. + // DFSchema::columns() creates a new Vec on each call. + let output_columns = self.output_schema.columns(); + let mut new_output_exprs = Vec::with_capacity(self.output_exprs.len()); // check that our output exprs are sub-expressions of the other one's output exprs for (i, output_expr) in self.output_exprs.iter().enumerate() { - let new_output_expr = other - .predicate - .normalize_expr(output_expr.clone()) - .rewrite(&mut other)? - .data; - - // Check that all references to the original tables have been replaced. - // All remaining column expressions should be unqualified, which indicates - // that they refer to the output of the sub-plan (in this case the view) - if new_output_expr - .column_refs() - .iter() - .any(|c| c.relation.is_some()) - { - return Ok(None); - } + // Fast path for simple Column expressions (most common case). + // This avoids the expensive normalize_expr transform for columns. + let new_output_expr = if let Expr::Column(col) = output_expr { + let normalized_col = other.predicate.normalize_column(col); + match other.find_output_column(&normalized_col) { + Some(rewritten) => rewritten, + None => return Ok(None), // Column not found, can't rewrite + } + } else { + // Slow path: complex expressions need full transform + let new_output_expr = other + .predicate + .normalize_expr(output_expr.clone()) + .rewrite(&mut &*other)? + .data; + + // Check that all references to the original tables have been replaced. + // All remaining column expressions should be unqualified, which indicates + // that they refer to the output of the sub-plan (in this case the view) + if new_output_expr + .column_refs() + .iter() + .any(|c| c.relation.is_some()) + { + return Ok(None); + } + new_output_expr + }; - let column = &self.output_schema.columns()[i]; + // Use cached columns instead of calling .columns() on each iteration + let column = &output_columns[i]; new_output_exprs.push( new_output_expr.alias_qualified(column.relation.clone(), column.name.clone()), ); @@ -310,7 +317,7 @@ impl SpjNormalForm { .into_iter() .chain(range_filters) .chain(residual_filters) - .map(|expr| expr.rewrite(&mut other).unwrap().data) + .map(|expr| expr.rewrite(&mut &*other).unwrap().data) .reduce(|a, b| a.and(b)); if all_filters @@ -329,6 +336,20 @@ impl SpjNormalForm { builder.project(new_output_exprs)?.build().map(Some) } + + /// Fast path: find a column in output_exprs and return rewritten expression. + /// This avoids full tree traversal for simple column lookups. + #[inline] + fn find_output_column(&self, col: &Column) -> Option { + self.output_exprs + .iter() + .position(|e| matches!(e, Expr::Column(c) if c == col)) + .map(|idx| { + Expr::Column(Column::new_unqualified( + self.output_schema.field(idx).name().clone(), + )) + }) + } } /// Stores information on filters from a Select-Project-Join plan. @@ -344,84 +365,95 @@ struct Predicate { ranges_by_equivalence_class: Vec>, /// Filter expressions that aren't column equality predicates or range filters. residuals: HashSet, + /// Tables referenced in this plan (collected during single-pass traversal) + referenced_tables: Vec, } impl Predicate { + /// Create a new Predicate by analyzing the given logical plan. + /// Uses single-pass traversal to collect schema, columns, filters, and referenced tables. fn new(plan: &LogicalPlan) -> Result { let mut schema = DFSchema::empty(); - plan.apply(|plan| { - if let LogicalPlan::TableScan(scan) = plan { - let new_schema = DFSchema::try_from_qualified_schema( - scan.table_name.clone(), - scan.source.schema().as_ref(), - )?; - schema = if schema.fields().is_empty() { - new_schema - } else { - schema.join(&new_schema)? - } - } + let mut columns_info: Vec<(Column, arrow::datatypes::DataType)> = Vec::new(); + let mut filters: Vec = Vec::new(); + let mut referenced_tables: Vec = Vec::new(); + + // Single traversal to collect everything + plan.apply(|node| { + match node { + LogicalPlan::TableScan(scan) => { + // Collect referenced table + referenced_tables.push(scan.table_name.clone()); - Ok(TreeNodeRecursion::Continue) - })?; + // Build schema + let new_schema = DFSchema::try_from_qualified_schema( + scan.table_name.clone(), + scan.source.schema().as_ref(), + )?; + + // Collect columns with their data types + for (table_ref, field) in new_schema.iter() { + columns_info.push(( + Column::new(table_ref.cloned(), field.name()), + field.data_type().clone(), + )); + } - let mut new = Self { - schema, - eq_classes: vec![], - eq_class_idx_by_column: HashMap::default(), - ranges_by_equivalence_class: vec![], - residuals: HashSet::new(), - }; + // Merge schema + schema = if schema.fields().is_empty() { + new_schema + } else { + schema.join(&new_schema)? + }; - // Collect all referenced columns - plan.apply(|plan| { - if let LogicalPlan::TableScan(scan) = plan { - for (i, (table_ref, field)) in DFSchema::try_from_qualified_schema( - scan.table_name.clone(), - scan.source.schema().as_ref(), - )? - .iter() - .enumerate() - { - let column = Column::new(table_ref.cloned(), field.name()); - let data_type = field.data_type(); - new.eq_classes - .push(ColumnEquivalenceClass::new_singleton(column.clone())); - new.eq_class_idx_by_column.insert(column, i); - new.ranges_by_equivalence_class - .push(Some(Interval::make_unbounded(data_type)?)); + // Collect filters from TableScan + filters.extend(scan.filters.iter().cloned()); + } + LogicalPlan::Filter(filter) => { + filters.push(filter.predicate.clone()); } - } - - Ok(TreeNodeRecursion::Continue) - })?; - - // Collect any filters - plan.apply(|plan| { - let filters = match plan { - LogicalPlan::TableScan(scan) => scan.filters.as_slice(), - LogicalPlan::Filter(filter) => core::slice::from_ref(&filter.predicate), LogicalPlan::Join(_join) => { return Err(DataFusionError::Internal( "joins are not supported yet".to_string(), - )) + )); } - LogicalPlan::Projection(_) => &[], + LogicalPlan::Projection(_) => {} _ => { return Err(DataFusionError::Plan(format!( "unsupported logical plan: {}", - plan.display() - ))) + node.display() + ))); } - }; - - for expr in filters.iter().flat_map(split_conjunction) { - new.insert_conjuct(expr)?; } - Ok(TreeNodeRecursion::Continue) })?; + // Initialize data structures with known capacity + let num_columns = columns_info.len(); + let mut eq_classes = Vec::with_capacity(num_columns); + let mut eq_class_idx_by_column = HashMap::with_capacity(num_columns); + let mut ranges_by_equivalence_class = Vec::with_capacity(num_columns); + + for (i, (column, data_type)) in columns_info.into_iter().enumerate() { + eq_classes.push(ColumnEquivalenceClass::new_singleton(column.clone())); + eq_class_idx_by_column.insert(column, i); + ranges_by_equivalence_class.push(Some(Interval::make_unbounded(&data_type)?)); + } + + let mut new = Self { + schema, + eq_classes, + eq_class_idx_by_column, + ranges_by_equivalence_class, + residuals: HashSet::new(), + referenced_tables, + }; + + // Process all collected filters + for expr in filters.iter().flat_map(split_conjunction) { + new.insert_conjuct(expr)?; + } + Ok(new) } @@ -431,6 +463,17 @@ impl Predicate { .and_then(|&idx| self.eq_classes.get(idx)) } + /// Fast path: normalize a single Column without full tree traversal. + /// This is O(1) lookup instead of O(n) transform. + #[inline] + fn normalize_column(&self, col: &Column) -> Column { + if let Some(eq_class) = self.class_for_column(col) { + eq_class.columns.first().unwrap().clone() + } else { + col.clone() + } + } + /// Add a new column equivalence fn add_equivalence(&mut self, c1: &Column, c2: &Column) -> Result<()> { match ( @@ -455,6 +498,10 @@ impl Predicate { self.eq_classes[idx].columns.insert(c2.clone()); } (Some(&i), Some(&j)) => { + if i == j { + // The two columns are already in the same equivalence class. + return Ok(()); + } // We need to merge two existing column eq classes. // Delete the eq class with a larger index, @@ -523,8 +570,8 @@ impl Predicate { // so handling of open intervals is done by adding/subtracting the smallest increment. // However, there is not really a public API to do this, // other than the satisfy_greater method. - Operator::Lt => Ok( - match satisfy_greater( + Operator::Lt => { + let range_val = match satisfy_greater( &Interval::try_new(value.clone(), value.clone())?, &Interval::make_unbounded(&value.data_type())?, true, @@ -534,11 +581,19 @@ impl Predicate { *range = None; return Ok(()); } - }, - ), - // Same thing as above. - Operator::Gt => Ok( - match satisfy_greater( + }; + // If the type is not discrete (e.g. Utf8), satisfy_greater may return an unchanged value. + // This means the interval could not be tightened and it is unsafe to produce a closed interval + if range_val.upper() == &value { + Err(DataFusionError::Plan( + "cannot represent strict inequality as closed interval for non-discrete types".to_string(), + )) + } else { + Ok(range_val) + } + } + Operator::Gt => { + let range_val = match satisfy_greater( &Interval::make_unbounded(&value.data_type())?, &Interval::try_new(value.clone(), value.clone())?, true, @@ -548,8 +603,15 @@ impl Predicate { *range = None; return Ok(()); } - }, - ), + }; + if range_val.lower() == &value { + Err(DataFusionError::Plan( + "cannot represent strict inequality as closed interval for non-discrete types".to_string(), + )) + } else { + Ok(range_val) + } + } _ => Err(DataFusionError::Plan( "unsupported binary expression".to_string(), )), @@ -593,7 +655,7 @@ impl Predicate { /// Add a binary expression to our collection of filters. fn insert_binary_expr(&mut self, left: &Expr, op: Operator, right: &Expr) -> Result<()> { match (left, op, right) { - (Expr::Column(c), op, Expr::Literal(v)) => { + (Expr::Column(c), op, Expr::Literal(v, _)) => { if let Err(e) = self.add_range(c, &op, v) { // Add a range can fail in some cases, so just fallthrough log::debug!("failed to add range filter: {e}"); @@ -601,7 +663,7 @@ impl Predicate { return Ok(()); } } - (Expr::Literal(_), op, Expr::Column(_)) => { + (Expr::Literal(_, _), op, Expr::Column(_)) => { if let Some(swapped) = op.swap() { return self.insert_binary_expr(right, swapped, left); } @@ -714,14 +776,14 @@ impl Predicate { extra_range_filters.push(Expr::BinaryExpr(BinaryExpr { left: Box::new(Expr::Column(other_column.clone())), op: Operator::Eq, - right: Box::new(Expr::Literal(range.lower().clone())), + right: Box::new(Expr::Literal(range.lower().clone(), None)), })) } else { if !range.lower().is_null() { extra_range_filters.push(Expr::BinaryExpr(BinaryExpr { left: Box::new(Expr::Column(other_column.clone())), op: Operator::GtEq, - right: Box::new(Expr::Literal(range.lower().clone())), + right: Box::new(Expr::Literal(range.lower().clone(), None)), })) } @@ -729,7 +791,7 @@ impl Predicate { extra_range_filters.push(Expr::BinaryExpr(BinaryExpr { left: Box::new(Expr::Column(other_column.clone())), op: Operator::LtEq, - right: Box::new(Expr::Literal(range.upper().clone())), + right: Box::new(Expr::Literal(range.upper().clone(), None)), })) } } @@ -773,6 +835,11 @@ impl Predicate { /// Rewrite all expressions in terms of their normal representatives /// with respect to this predicate's equivalence classes. fn normalize_expr(&self, e: Expr) -> Expr { + // Fast path: if it's a simple Column, avoid full transform traversal + if let Expr::Column(ref c) = e { + return Expr::Column(self.normalize_column(c)); + } + e.transform(&|e| { let c = match e { Expr::Column(c) => c, @@ -996,11 +1063,11 @@ mod test { ctx.sql(&format!( " CREATE EXTERNAL TABLE t1 ( - column1 VARCHAR, - column2 BIGINT, + column1 VARCHAR, + column2 BIGINT, column3 CHAR ) - STORED AS PARQUET + STORED AS PARQUET LOCATION '{}'", t1_path.path().to_string_lossy() )) @@ -1095,7 +1162,7 @@ mod test { assert_eq!(rewritten.schema().as_ref(), query_plan.schema().as_ref()); let expected = concat_batches( - &query_plan.schema().as_ref().clone().into(), + &query_plan.schema().inner().clone(), &context .execute_logical_plan(query_plan) .await? @@ -1104,7 +1171,7 @@ mod test { )?; let result = concat_batches( - &rewritten.schema().as_ref().clone().into(), + &rewritten.schema().inner().clone(), &context .execute_logical_plan(rewritten) .await? @@ -1144,11 +1211,16 @@ mod test { TestCase { name: "range filter + equality predicate", base: - "SELECT column1, column2 FROM t1 WHERE column1 = column3 AND column1 >= '2022'", + "SELECT column1, column2 FROM t1 WHERE column1 = column3 AND column1 >= '2022'", query: // Since column1 = column3 in the original view, // we are allowed to substitute column1 for column3 and vice versa. - "SELECT column2, column3 FROM t1 WHERE column1 = column3 AND column3 >= '2023'", + "SELECT column2, column3 FROM t1 WHERE column1 = column3 AND column3 >= '2023'", + }, + TestCase { + name: "range filter with inequality on non-discrete type", + base: "SELECT * FROM t1", + query: "SELECT column1 FROM t1 WHERE column1 < '2022'", }, TestCase { name: "duplicate expressions (X-209)", @@ -1205,4 +1277,95 @@ mod test { Ok(()) } + + #[tokio::test] + async fn test_predicate_new_collects_expected_data() -> Result<()> { + let ctx = SessionContext::new(); + + // Create a table with known schema + ctx.sql( + "CREATE TABLE test_table ( + col1 INT, + col2 VARCHAR, + col3 DOUBLE + )", + ) + .await? + .collect() + .await?; + + // Create a plan with filters + let plan = ctx + .sql("SELECT col1, col2 FROM test_table WHERE col1 >= 10 AND col2 = col3") + .await? + .into_optimized_plan()?; + + let normal_form = SpjNormalForm::new(&plan)?; + + // Verify referenced_tables is collected + assert_eq!(normal_form.referenced_tables().len(), 1); + assert_eq!(normal_form.referenced_tables()[0].to_string(), "test_table"); + + // Verify output_exprs matches the projection (2 columns) + assert_eq!(normal_form.output_exprs().len(), 2); + + // Verify schema is preserved + assert_eq!(normal_form.output_schema().fields().len(), 2); + + Ok(()) + } + + #[tokio::test] + async fn test_predicate_new_with_join_returns_error() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.sql("CREATE TABLE t1 (a INT, b INT)") + .await? + .collect() + .await?; + ctx.sql("CREATE TABLE t2 (c INT, d INT)") + .await? + .collect() + .await?; + + // Test that join returns an error as it's not supported yet + let plan = ctx + .sql("SELECT t1.a, t2.d FROM t1 JOIN t2 ON t1.b = t2.c WHERE t1.a >= 0 AND t2.d <= 100") + .await? + .into_optimized_plan()?; + + let result = SpjNormalForm::new(&plan); + + // Verify that join returns an error + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("joins are not supported yet")); + + Ok(()) + } + + #[tokio::test] + async fn test_predicate_new_with_range_filters() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.sql("CREATE TABLE range_test (x INT, y INT, z VARCHAR)") + .await? + .collect() + .await?; + + let plan = ctx + .sql("SELECT * FROM range_test WHERE x >= 10 AND x <= 100 AND y = 50") + .await? + .into_optimized_plan()?; + + let normal_form = SpjNormalForm::new(&plan)?; + + // Verify all columns are in output + assert_eq!(normal_form.output_exprs().len(), 3); + assert_eq!(normal_form.referenced_tables().len(), 1); + + Ok(()) + } } diff --git a/tests/materialized_listing_table.rs b/tests/materialized_listing_table.rs index 5ad9d25..7bba358 100644 --- a/tests/materialized_listing_table.rs +++ b/tests/materialized_listing_table.rs @@ -32,7 +32,9 @@ use datafusion::{ }, prelude::{SessionConfig, SessionContext}, }; -use datafusion_common::{Constraints, DataFusionError, ParamValues, ScalarValue, Statistics}; +use datafusion_common::{ + metadata::ScalarAndMetadata, Constraints, DataFusionError, ParamValues, ScalarValue, Statistics, +}; use datafusion_expr::{ col, dml::InsertOp, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, SortExpr, TableProviderFilterPushDown, TableType, @@ -185,7 +187,7 @@ async fn setup() -> Result { .await?; ctx.sql( - "INSERT INTO t1 VALUES + "INSERT INTO t1 VALUES (1, '2023-01-01', 'A'), (2, '2023-01-02', 'B'), (3, '2023-01-03', 'C'), @@ -251,7 +253,7 @@ async fn test_materialized_listing_table_incremental_maintenance() -> Result<()> // Insert another row into the source table ctx.sql( - "INSERT INTO t1 VALUES + "INSERT INTO t1 VALUES (7, '2024-12-07', 'W')", ) .await? @@ -352,12 +354,13 @@ impl MaterializedListingTable { file_sort_order: opts.file_sort_order, }); + let mut listing_table_config = ListingTableConfig::new(config.table_path); + if let Some(options) = options { + listing_table_config = listing_table_config.with_listing_options(options); + } + listing_table_config = listing_table_config.with_schema(Arc::new(file_schema)); Ok(MaterializedListingTable { - inner: ListingTable::try_new(ListingTableConfig { - table_paths: vec![config.table_path], - file_schema: Some(Arc::new(file_schema)), - options, - })?, + inner: ListingTable::try_new(listing_table_config)?, query: normalized_query, schema: normalized_schema, }) @@ -503,7 +506,7 @@ impl TableProvider for MaterializedListingTable { self.inner.get_table_definition() } - fn get_logical_plan(&self) -> Option> { + fn get_logical_plan(&self) -> Option> { // We _could_ return the LogicalPlan here, // but it will cause this table to be treated like a regular view // and the materialized results will not be used. @@ -549,7 +552,7 @@ impl TableProvider for MaterializedListingTable { fn parse_partition_values( path: &ObjectPath, partition_columns: &[(String, DataType)], -) -> Result, DataFusionError> { +) -> Result, DataFusionError> { let parts = path.parts().map(|part| part.to_owned()).collect::>(); let pairs = parts @@ -561,7 +564,7 @@ fn parse_partition_values( .iter() .map(|(column, datatype)| { let value = pairs.get(column.as_str()).copied().map(String::from); - ScalarValue::Utf8(value).cast_to(datatype) + ScalarAndMetadata::from(ScalarValue::Utf8(value)).cast_storage_to(datatype) }) .collect::, _>>()?; diff --git a/tests/query_rewrite_targets.rs b/tests/query_rewrite_targets.rs new file mode 100644 index 0000000..b8842e7 --- /dev/null +++ b/tests/query_rewrite_targets.rs @@ -0,0 +1,659 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for the query rewrite targets feature +//! +//! This tests the two-stage filtering for query rewriting: +//! 1. First stage: `use_in_query_rewrite` filters the global pool of candidate MVs +//! 2. Second stage: `rewrite_targets` filters candidates for specific queries + +use std::any::Any; +use std::sync::Arc; + +use arrow_schema::SchemaRef; +use datafusion::catalog::{Session, TableProvider}; +use datafusion::datasource::listing::ListingTableUrl; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; +use datafusion_common::Result; +use datafusion_expr::{Expr, LogicalPlan, TableType}; +use datafusion_materialized_views::materialized::{ + cast_to_materialized, register_materialized, ListingTableLike, Materialized, +}; +use datafusion_materialized_views::rewrite::exploitation::ViewMatcher; +use datafusion_materialized_views::MaterializedConfig; +use datafusion_sql::TableReference; + +/// A mock materialized view for testing rewrite targets +#[derive(Debug)] +struct MockMaterializedView { + table_path: ListingTableUrl, + partition_columns: Vec, + query: LogicalPlan, + file_ext: &'static str, + config: MaterializedConfig, +} + +#[async_trait::async_trait] +impl TableProvider for MockMaterializedView { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::new(self.query.schema().as_arrow().clone()) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + unimplemented!() + } +} + +impl ListingTableLike for MockMaterializedView { + fn table_paths(&self) -> Vec { + vec![self.table_path.clone()] + } + + fn partition_columns(&self) -> Vec { + self.partition_columns.clone() + } + + fn file_ext(&self) -> String { + self.file_ext.to_string() + } +} + +impl Materialized for MockMaterializedView { + fn query(&self) -> LogicalPlan { + self.query.clone() + } + + fn config(&self) -> MaterializedConfig { + self.config.clone() + } +} + +async fn setup() -> Result { + let _ = env_logger::builder().is_test(true).try_init(); + + register_materialized::(); + + let ctx = SessionContext::new(); + + // Create a base table + ctx.sql("CREATE TABLE base_table (id INT, value INT)") + .await? + .collect() + .await?; + + Ok(ctx) +} + +/// Helper to create a fully qualified table reference +/// DataFusion creates tables in the default datafusion.public schema +fn table_ref(name: &str) -> TableReference { + TableReference::full("datafusion", "public", name) +} + +#[tokio::test] +async fn test_use_in_query_rewrite_filters_global_pool() -> Result<()> { + // Test first stage filtering: use_in_query_rewrite filters the global pool + let ctx = setup().await?; + + let mv1_query = ctx + .sql("SELECT id, value FROM base_table WHERE value > 0") + .await? + .into_optimized_plan()?; + + let mv2_query = ctx + .sql("SELECT id, value FROM base_table WHERE value > 10") + .await? + .into_optimized_plan()?; + + // MV1: available for rewriting + let mv1 = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///mv1/")?, + partition_columns: vec![], + query: mv1_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: None, + }, + }); + + // MV2: NOT available for rewriting (filtered out in first stage) + let mv2 = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///mv2/")?, + partition_columns: vec![], + query: mv2_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: false, // Excluded from global pool + rewrite_targets: None, + }, + }); + + ctx.register_table("mv1", mv1 as Arc)?; + ctx.register_table("mv2", mv2 as Arc)?; + + // Create ViewMatcher - this applies first stage filtering + let view_matcher = ViewMatcher::try_new_from_state(&ctx.state()).await?; + let mv_plans = view_matcher.mv_plans(); + + // Only MV1 should be in the global pool + assert_eq!( + mv_plans.len(), + 1, + "Expected only 1 MV (mv1) in the pool; mv2 should be excluded by use_in_query_rewrite = false" + ); + + assert!( + mv_plans.contains_key(&table_ref("mv1")), + "mv1 should be in the global pool" + ); + assert!( + !mv_plans.contains_key(&table_ref("mv2")), + "mv2 should NOT be in the pool (use_in_query_rewrite = false)" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_rewrite_targets_none_uses_all_available_mvs() -> Result<()> { + // When rewrite_targets = None, all MVs in the pool are considered + let ctx = setup().await?; + + let base_mv_query = ctx + .sql("SELECT id, value FROM base_table") + .await? + .into_optimized_plan()?; + + let mv1_query = ctx + .sql("SELECT id, value FROM base_table WHERE value > 0") + .await? + .into_optimized_plan()?; + + let mv2_query = ctx + .sql("SELECT id, value FROM base_table WHERE value > 10") + .await? + .into_optimized_plan()?; + + // Base MV with rewrite_targets = None + let base_mv = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///base_mv/")?, + partition_columns: vec![], + query: base_mv_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: None, // All MVs from pool are considered + }, + }); + + let mv1 = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///mv1/")?, + partition_columns: vec![], + query: mv1_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: None, + }, + }); + + let mv2 = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///mv2/")?, + partition_columns: vec![], + query: mv2_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: None, + }, + }); + + ctx.register_table("base_mv", base_mv as Arc)?; + ctx.register_table("mv1", mv1 as Arc)?; + ctx.register_table("mv2", mv2 as Arc)?; + + let view_matcher = ViewMatcher::try_new_from_state(&ctx.state()).await?; + + // Get rewrite candidates for base_mv + let candidates = view_matcher.get_rewrite_candidates_for_table(&table_ref("base_mv"))?; + + // Should include all 3 MVs (base_mv itself, mv1, and mv2) + assert_eq!( + candidates.len(), + 3, + "Expected all 3 MVs to be candidates when rewrite_targets = None" + ); + assert!(candidates.contains(&table_ref("base_mv"))); + assert!(candidates.contains(&table_ref("mv1"))); + assert!(candidates.contains(&table_ref("mv2"))); + + Ok(()) +} + +#[tokio::test] +async fn test_rewrite_targets_empty_list() -> Result<()> { + // When rewrite_targets = Some(vec![]), no MVs are considered + let ctx = setup().await?; + + let base_mv_query = ctx + .sql("SELECT id, value FROM base_table") + .await? + .into_optimized_plan()?; + + let mv1_query = ctx + .sql("SELECT id, value FROM base_table WHERE value > 0") + .await? + .into_optimized_plan()?; + + // Base MV with empty rewrite_targets + let base_mv = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///base_mv/")?, + partition_columns: vec![], + query: base_mv_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: Some(vec![]), // No MVs considered + }, + }); + + let mv1 = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///mv1/")?, + partition_columns: vec![], + query: mv1_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: None, + }, + }); + + ctx.register_table("base_mv", base_mv as Arc)?; + ctx.register_table("mv1", mv1 as Arc)?; + + let view_matcher = ViewMatcher::try_new_from_state(&ctx.state()).await?; + + // Get rewrite candidates for base_mv + let candidates = view_matcher.get_rewrite_candidates_for_table(&table_ref("base_mv"))?; + + // Should be empty because rewrite_targets = Some(vec![]) + assert_eq!( + candidates.len(), + 0, + "Expected no candidates when rewrite_targets = Some(vec![])" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_rewrite_targets_filters_to_specific_mvs() -> Result<()> { + // When rewrite_targets specifies MVs, only those are considered + let ctx = setup().await?; + + let base_mv_query = ctx + .sql("SELECT id, value FROM base_table") + .await? + .into_optimized_plan()?; + + let mv1_query = ctx + .sql("SELECT id, value FROM base_table WHERE value > 0") + .await? + .into_optimized_plan()?; + + let mv2_query = ctx + .sql("SELECT id, value FROM base_table WHERE value > 10") + .await? + .into_optimized_plan()?; + + let mv3_query = ctx + .sql("SELECT id, value FROM base_table WHERE value > 20") + .await? + .into_optimized_plan()?; + + // Base MV that only wants mv1 and mv3 + // Use fully qualified names in rewrite_targets + let base_mv = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///base_mv/")?, + partition_columns: vec![], + query: base_mv_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: Some(vec![ + "datafusion.public.mv1".to_string(), + "datafusion.public.mv3".to_string(), + ]), + }, + }); + + let mv1 = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///mv1/")?, + partition_columns: vec![], + query: mv1_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: None, + }, + }); + + let mv2 = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///mv2/")?, + partition_columns: vec![], + query: mv2_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: None, + }, + }); + + let mv3 = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///mv3/")?, + partition_columns: vec![], + query: mv3_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: None, + }, + }); + + ctx.register_table("base_mv", base_mv as Arc)?; + ctx.register_table("mv1", mv1 as Arc)?; + ctx.register_table("mv2", mv2 as Arc)?; + ctx.register_table("mv3", mv3 as Arc)?; + + let view_matcher = ViewMatcher::try_new_from_state(&ctx.state()).await?; + + // Get rewrite candidates for base_mv + let candidates = view_matcher.get_rewrite_candidates_for_table(&table_ref("base_mv"))?; + + // Should only include mv1 and mv3 (not base_mv itself, not mv2) + assert_eq!( + candidates.len(), + 2, + "Expected only mv1 and mv3 as candidates" + ); + assert!(candidates.contains(&table_ref("mv1"))); + assert!(candidates.contains(&table_ref("mv3"))); + assert!( + !candidates.contains(&table_ref("mv2")), + "mv2 should not be included" + ); + assert!( + !candidates.contains(&table_ref("base_mv")), + "base_mv should not include itself" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_excluded_mv_not_candidate_even_if_in_targets() -> Result<()> { + // Test that an MV with use_in_query_rewrite = false is not a candidate, + // even if listed in rewrite_targets + let ctx = setup().await?; + + let base_mv_query = ctx + .sql("SELECT id, value FROM base_table") + .await? + .into_optimized_plan()?; + + let mv1_query = ctx + .sql("SELECT id, value FROM base_table WHERE value > 0") + .await? + .into_optimized_plan()?; + + let mv2_query = ctx + .sql("SELECT id, value FROM base_table WHERE value > 10") + .await? + .into_optimized_plan()?; + + // Base MV that lists both mv1 and mv2 in rewrite_targets + let base_mv = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///base_mv/")?, + partition_columns: vec![], + query: base_mv_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: Some(vec![ + "datafusion.public.mv1".to_string(), + "datafusion.public.mv2".to_string(), + ]), + }, + }); + + // MV1: in global pool + let mv1 = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///mv1/")?, + partition_columns: vec![], + query: mv1_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: None, + }, + }); + + // MV2: NOT in global pool (excluded by use_in_query_rewrite = false) + let mv2 = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///mv2/")?, + partition_columns: vec![], + query: mv2_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: false, // Excluded from global pool + rewrite_targets: None, + }, + }); + + ctx.register_table("base_mv", base_mv as Arc)?; + ctx.register_table("mv1", mv1 as Arc)?; + ctx.register_table("mv2", mv2 as Arc)?; + + let view_matcher = ViewMatcher::try_new_from_state(&ctx.state()).await?; + + // Verify mv2 is not in the global pool + let mv_plans = view_matcher.mv_plans(); + assert!(!mv_plans.contains_key(&table_ref("mv2"))); + + // Get rewrite candidates for base_mv + let candidates = view_matcher.get_rewrite_candidates_for_table(&table_ref("base_mv"))?; + + // Should only include mv1 (mv2 is excluded from pool) + assert_eq!( + candidates.len(), + 1, + "Expected only mv1 as candidate; mv2 should be excluded even though it's in rewrite_targets" + ); + assert!(candidates.contains(&table_ref("mv1"))); + assert!( + !candidates.contains(&table_ref("mv2")), + "mv2 should not be a candidate (use_in_query_rewrite = false)" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_config_default_values() -> Result<()> { + // Test default configuration values + let default_config = MaterializedConfig::default(); + + assert!( + default_config.use_in_query_rewrite, + "Default use_in_query_rewrite should be true" + ); + assert_eq!( + default_config.rewrite_targets, None, + "Default rewrite_targets should be None (consider all available MVs)" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_different_tables_different_rewrite_targets() -> Result<()> { + // Test that different tables can have different rewrite_targets + let ctx = setup().await?; + + let table1_query = ctx + .sql("SELECT id, value FROM base_table WHERE value > 0") + .await? + .into_optimized_plan()?; + + let table2_query = ctx + .sql("SELECT id, value FROM base_table WHERE value > 10") + .await? + .into_optimized_plan()?; + + let mv1_query = ctx + .sql("SELECT id, value FROM base_table WHERE value > 5") + .await? + .into_optimized_plan()?; + + let mv2_query = ctx + .sql("SELECT id, value FROM base_table WHERE value > 15") + .await? + .into_optimized_plan()?; + + // Table1: only uses mv1 + let table1 = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///table1/")?, + partition_columns: vec![], + query: table1_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: Some(vec!["datafusion.public.mv1".to_string()]), + }, + }); + + // Table2: only uses mv2 + let table2 = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///table2/")?, + partition_columns: vec![], + query: table2_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: Some(vec!["datafusion.public.mv2".to_string()]), + }, + }); + + let mv1 = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///mv1/")?, + partition_columns: vec![], + query: mv1_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: None, + }, + }); + + let mv2 = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///mv2/")?, + partition_columns: vec![], + query: mv2_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: None, + }, + }); + + ctx.register_table("table1", table1 as Arc)?; + ctx.register_table("table2", table2 as Arc)?; + ctx.register_table("mv1", mv1 as Arc)?; + ctx.register_table("mv2", mv2 as Arc)?; + + let view_matcher = ViewMatcher::try_new_from_state(&ctx.state()).await?; + + // Check candidates for table1 + let table1_candidates = view_matcher.get_rewrite_candidates_for_table(&table_ref("table1"))?; + assert_eq!(table1_candidates.len(), 1); + assert!(table1_candidates.contains(&table_ref("mv1"))); + + // Check candidates for table2 + let table2_candidates = view_matcher.get_rewrite_candidates_for_table(&table_ref("table2"))?; + assert_eq!(table2_candidates.len(), 1); + assert!(table2_candidates.contains(&table_ref("mv2"))); + + Ok(()) +} + +#[tokio::test] +async fn test_rewrite_targets_config_storage() -> Result<()> { + // Test that rewrite_targets config is correctly stored and retrieved + let ctx = setup().await?; + + let targets = vec![ + "datafusion.public.mv1".to_string(), + "datafusion.public.mv2".to_string(), + ]; + + let mv_query = ctx + .sql("SELECT id, value FROM base_table") + .await? + .into_optimized_plan()?; + + let mv = Arc::new(MockMaterializedView { + table_path: ListingTableUrl::parse("file:///test_mv/")?, + partition_columns: vec![], + query: mv_query, + file_ext: ".parquet", + config: MaterializedConfig { + use_in_query_rewrite: true, + rewrite_targets: Some(targets.clone()), + }, + }); + + ctx.register_table("test_mv", mv as Arc)?; + + // Verify the config is properly stored and retrievable + let mv_provider = ctx.table_provider(table_ref("test_mv")).await?; + let mv_materialized = cast_to_materialized(mv_provider.as_ref())?.unwrap(); + + assert!(mv_materialized.config().use_in_query_rewrite); + assert_eq!( + mv_materialized.config().rewrite_targets, + Some(targets), + "rewrite_targets should be stored and retrievable" + ); + + Ok(()) +}