From 826e08793262e9e37364407f8d2405b40225a44d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 4 Dec 2025 15:33:03 -0700 Subject: [PATCH 1/3] add new framework --- .../src/execution/expressions/arithmetic.rs | 232 ++++++++++++++++++ native/core/src/execution/expressions/mod.rs | 1 + native/core/src/execution/planner.rs | 124 +++------- .../execution/planner/expression_registry.rs | 183 ++++++++++++++ native/core/src/execution/planner/traits.rs | 143 +++++++++++ 5 files changed, 595 insertions(+), 88 deletions(-) create mode 100644 native/core/src/execution/expressions/arithmetic.rs create mode 100644 native/core/src/execution/planner/expression_registry.rs create mode 100644 native/core/src/execution/planner/traits.rs diff --git a/native/core/src/execution/expressions/arithmetic.rs b/native/core/src/execution/expressions/arithmetic.rs new file mode 100644 index 0000000000..53a42bbe8c --- /dev/null +++ b/native/core/src/execution/expressions/arithmetic.rs @@ -0,0 +1,232 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Arithmetic expression builders + +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use datafusion::logical_expr::Operator as DataFusionOperator; +use datafusion::physical_expr::PhysicalExpr; +use datafusion_comet_proto::spark_expression::{expr::ExprStruct, Expr}; +use datafusion_comet_spark_expr::{create_modulo_expr, create_negate_expr, EvalMode}; + +use crate::execution::operators::ExecutionError; +use crate::execution::planner::traits::ExpressionBuilder; +use crate::execution::planner::{from_protobuf_eval_mode, BinaryExprOptions}; + +/// Builder for Add expressions +pub struct AddBuilder; + +impl ExpressionBuilder for AddBuilder { + fn build( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &crate::execution::planner::PhysicalPlanner, + ) -> Result, ExecutionError> { + if let Some(ExprStruct::Add(expr)) = &spark_expr.expr_struct { + let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + planner.create_binary_expr( + expr.left.as_ref().unwrap(), + expr.right.as_ref().unwrap(), + expr.return_type.as_ref(), + DataFusionOperator::Plus, + input_schema, + eval_mode, + ) + } else { + Err(ExecutionError::GeneralError( + "Expected Add expression".to_string(), + )) + } + } +} + +/// Builder for Subtract expressions +pub struct SubtractBuilder; + +impl ExpressionBuilder for SubtractBuilder { + fn build( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &crate::execution::planner::PhysicalPlanner, + ) -> Result, ExecutionError> { + if let Some(ExprStruct::Subtract(expr)) = &spark_expr.expr_struct { + let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + planner.create_binary_expr( + expr.left.as_ref().unwrap(), + expr.right.as_ref().unwrap(), + expr.return_type.as_ref(), + DataFusionOperator::Minus, + input_schema, + eval_mode, + ) + } else { + Err(ExecutionError::GeneralError( + "Expected Subtract expression".to_string(), + )) + } + } +} + +/// Builder for Multiply expressions +pub struct MultiplyBuilder; + +impl ExpressionBuilder for MultiplyBuilder { + fn build( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &crate::execution::planner::PhysicalPlanner, + ) -> Result, ExecutionError> { + if let Some(ExprStruct::Multiply(expr)) = &spark_expr.expr_struct { + let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + planner.create_binary_expr( + expr.left.as_ref().unwrap(), + expr.right.as_ref().unwrap(), + expr.return_type.as_ref(), + DataFusionOperator::Multiply, + input_schema, + eval_mode, + ) + } else { + Err(ExecutionError::GeneralError( + "Expected Multiply expression".to_string(), + )) + } + } +} + +/// Builder for Divide expressions +pub struct DivideBuilder; + +impl ExpressionBuilder for DivideBuilder { + fn build( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &crate::execution::planner::PhysicalPlanner, + ) -> Result, ExecutionError> { + if let Some(ExprStruct::Divide(expr)) = &spark_expr.expr_struct { + let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + planner.create_binary_expr( + expr.left.as_ref().unwrap(), + expr.right.as_ref().unwrap(), + expr.return_type.as_ref(), + DataFusionOperator::Divide, + input_schema, + eval_mode, + ) + } else { + Err(ExecutionError::GeneralError( + "Expected Divide expression".to_string(), + )) + } + } +} + +/// Builder for IntegralDivide expressions +pub struct IntegralDivideBuilder; + +impl ExpressionBuilder for IntegralDivideBuilder { + fn build( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &crate::execution::planner::PhysicalPlanner, + ) -> Result, ExecutionError> { + if let Some(ExprStruct::IntegralDivide(expr)) = &spark_expr.expr_struct { + let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + planner.create_binary_expr_with_options( + expr.left.as_ref().unwrap(), + expr.right.as_ref().unwrap(), + expr.return_type.as_ref(), + DataFusionOperator::Divide, + input_schema, + BinaryExprOptions { + is_integral_div: true, + }, + eval_mode, + ) + } else { + Err(ExecutionError::GeneralError( + "Expected IntegralDivide expression".to_string(), + )) + } + } +} + +/// Builder for Remainder expressions +pub struct RemainderBuilder; + +impl ExpressionBuilder for RemainderBuilder { + fn build( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &crate::execution::planner::PhysicalPlanner, + ) -> Result, ExecutionError> { + if let Some(ExprStruct::Remainder(expr)) = &spark_expr.expr_struct { + let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + let left = + planner.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; + let right = + planner.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; + + let result = create_modulo_expr( + left, + right, + expr.return_type + .as_ref() + .map(crate::execution::serde::to_arrow_datatype) + .unwrap(), + input_schema, + eval_mode == EvalMode::Ansi, + &planner.session_ctx().state(), + ); + result.map_err(|e| ExecutionError::GeneralError(e.to_string())) + } else { + Err(ExecutionError::GeneralError( + "Expected Remainder expression".to_string(), + )) + } + } +} + +/// Builder for UnaryMinus expressions +pub struct UnaryMinusBuilder; + +impl ExpressionBuilder for UnaryMinusBuilder { + fn build( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &crate::execution::planner::PhysicalPlanner, + ) -> Result, ExecutionError> { + if let Some(ExprStruct::UnaryMinus(expr)) = &spark_expr.expr_struct { + let child = planner.create_expr(expr.child.as_ref().unwrap(), input_schema)?; + let result = create_negate_expr(child, expr.fail_on_error); + result.map_err(|e| ExecutionError::GeneralError(e.to_string())) + } else { + Err(ExecutionError::GeneralError( + "Expected UnaryMinus expression".to_string(), + )) + } + } +} diff --git a/native/core/src/execution/expressions/mod.rs b/native/core/src/execution/expressions/mod.rs index 9bb8fad456..84b930d059 100644 --- a/native/core/src/execution/expressions/mod.rs +++ b/native/core/src/execution/expressions/mod.rs @@ -17,6 +17,7 @@ //! Native DataFusion expressions +pub mod arithmetic; pub mod subquery; pub use datafusion_comet_spark_expr::EvalMode; diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index ccbd0b2508..38f92b5e68 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -17,12 +17,16 @@ //! Converts Spark physical plan to DataFusion physical plan +pub mod expression_registry; +pub mod traits; + use crate::execution::operators::IcebergScanExec; use crate::{ errors::ExpressionError, execution::{ expressions::subquery::Subquery, operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec}, + planner::expression_registry::ExpressionRegistry, serde::to_arrow_datatype, shuffle::ShuffleWriterExec, }, @@ -62,8 +66,8 @@ use datafusion::{ prelude::SessionContext, }; use datafusion_comet_spark_expr::{ - create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, create_modulo_expr, - create_negate_expr, BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, EvalMode, + create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, + BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, EvalMode, SparkHour, SparkMinute, SparkSecond, }; use iceberg::expr::Bind; @@ -142,7 +146,7 @@ struct JoinParameters { } #[derive(Default)] -struct BinaryExprOptions { +pub struct BinaryExprOptions { pub is_integral_div: bool, } @@ -154,6 +158,7 @@ pub struct PhysicalPlanner { exec_context_id: i64, partition: i32, session_ctx: Arc, + expression_registry: ExpressionRegistry, } impl Default for PhysicalPlanner { @@ -168,6 +173,7 @@ impl PhysicalPlanner { exec_context_id: TEST_EXEC_CONTEXT_ID, session_ctx, partition, + expression_registry: ExpressionRegistry::new(), } } @@ -176,6 +182,7 @@ impl PhysicalPlanner { exec_context_id, partition: self.partition, session_ctx: Arc::clone(&self.session_ctx), + expression_registry: self.expression_registry, } } @@ -184,6 +191,20 @@ impl PhysicalPlanner { &self.session_ctx } + /// Check if an expression is an arithmetic expression that should be handled by the registry + fn is_arithmetic_expression(expr_struct: &ExprStruct) -> bool { + matches!( + expr_struct, + ExprStruct::Add(_) + | ExprStruct::Subtract(_) + | ExprStruct::Multiply(_) + | ExprStruct::Divide(_) + | ExprStruct::IntegralDivide(_) + | ExprStruct::Remainder(_) + | ExprStruct::UnaryMinus(_) + ) + } + /// get DataFusion PartitionedFiles from a Spark FilePartition fn get_partitioned_files( &self, @@ -242,84 +263,17 @@ impl PhysicalPlanner { spark_expr: &Expr, input_schema: SchemaRef, ) -> Result, ExecutionError> { - match spark_expr.expr_struct.as_ref().unwrap() { - ExprStruct::Add(expr) => { - let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - self.create_binary_expr( - expr.left.as_ref().unwrap(), - expr.right.as_ref().unwrap(), - expr.return_type.as_ref(), - DataFusionOperator::Plus, - input_schema, - eval_mode, - ) + // Try to use the modular registry for arithmetic expressions first + if let Some(expr_struct) = spark_expr.expr_struct.as_ref() { + if Self::is_arithmetic_expression(expr_struct) { + return self + .expression_registry + .create_expr(spark_expr, input_schema, self); } - ExprStruct::Subtract(expr) => { - let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - self.create_binary_expr( - expr.left.as_ref().unwrap(), - expr.right.as_ref().unwrap(), - expr.return_type.as_ref(), - DataFusionOperator::Minus, - input_schema, - eval_mode, - ) - } - ExprStruct::Multiply(expr) => { - let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - self.create_binary_expr( - expr.left.as_ref().unwrap(), - expr.right.as_ref().unwrap(), - expr.return_type.as_ref(), - DataFusionOperator::Multiply, - input_schema, - eval_mode, - ) - } - ExprStruct::Divide(expr) => { - let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - self.create_binary_expr( - expr.left.as_ref().unwrap(), - expr.right.as_ref().unwrap(), - expr.return_type.as_ref(), - DataFusionOperator::Divide, - input_schema, - eval_mode, - ) - } - ExprStruct::IntegralDivide(expr) => { - let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - self.create_binary_expr_with_options( - expr.left.as_ref().unwrap(), - expr.right.as_ref().unwrap(), - expr.return_type.as_ref(), - DataFusionOperator::Divide, - input_schema, - BinaryExprOptions { - is_integral_div: true, - }, - eval_mode, - ) - } - ExprStruct::Remainder(expr) => { - let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - // TODO add support for EvalMode::TRY - // https://github.com/apache/datafusion-comet/issues/2021 - let left = - self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = - self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; + } - let result = create_modulo_expr( - left, - right, - expr.return_type.as_ref().map(to_arrow_datatype).unwrap(), - input_schema, - eval_mode == EvalMode::Ansi, - &self.session_ctx.state(), - ); - result.map_err(|e| GeneralError(e.to_string())) - } + // Fall back to the original monolithic match for other expressions + match spark_expr.expr_struct.as_ref().unwrap() { ExprStruct::Eq(expr) => { let left = self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; @@ -728,12 +682,6 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; Ok(Arc::new(NotExpr::new(child))) } - ExprStruct::UnaryMinus(expr) => { - let child: Arc = - self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?; - let result = create_negate_expr(child, expr.fail_on_error); - result.map_err(|e| GeneralError(e.to_string())) - } ExprStruct::NormalizeNanAndZero(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap()); @@ -892,7 +840,7 @@ impl PhysicalPlanner { } } - fn create_binary_expr( + pub fn create_binary_expr( &self, left: &Expr, right: &Expr, @@ -913,7 +861,7 @@ impl PhysicalPlanner { } #[allow(clippy::too_many_arguments)] - fn create_binary_expr_with_options( + pub fn create_binary_expr_with_options( &self, left: &Expr, right: &Expr, @@ -2710,7 +2658,7 @@ fn rewrite_physical_expr( Ok(expr.rewrite(&mut rewriter).data()?) } -fn from_protobuf_eval_mode(value: i32) -> Result { +pub fn from_protobuf_eval_mode(value: i32) -> Result { match spark_expression::EvalMode::try_from(value)? { spark_expression::EvalMode::Legacy => Ok(EvalMode::Legacy), spark_expression::EvalMode::Try => Ok(EvalMode::Try), diff --git a/native/core/src/execution/planner/expression_registry.rs b/native/core/src/execution/planner/expression_registry.rs new file mode 100644 index 0000000000..0c66cd6a88 --- /dev/null +++ b/native/core/src/execution/planner/expression_registry.rs @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression registry for dispatching expression creation + +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use datafusion::physical_expr::PhysicalExpr; +use datafusion_comet_proto::spark_expression::{expr::ExprStruct, Expr}; + +use crate::execution::operators::ExecutionError; +use crate::execution::planner::traits::{ExpressionBuilder, ExpressionType}; + +/// Registry for expression builders +pub struct ExpressionRegistry { + builders: HashMap>, +} + +impl ExpressionRegistry { + /// Create a new expression registry with all builders registered + pub fn new() -> Self { + let mut registry = Self { + builders: HashMap::new(), + }; + + registry.register_all_expressions(); + registry + } + + /// Create a physical expression from a Spark protobuf expression + pub fn create_expr( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &super::PhysicalPlanner, + ) -> Result, ExecutionError> { + let expr_type = Self::get_expression_type(spark_expr)?; + + if let Some(builder) = self.builders.get(&expr_type) { + builder.build(spark_expr, input_schema, planner) + } else { + Err(ExecutionError::GeneralError(format!( + "No builder registered for expression type: {:?}", + expr_type + ))) + } + } + + /// Register all expression builders + fn register_all_expressions(&mut self) { + // Register arithmetic expressions + self.register_arithmetic_expressions(); + + // TODO: Register other expression categories in future phases + // self.register_comparison_expressions(); + // self.register_string_expressions(); + // self.register_temporal_expressions(); + // etc. + } + + /// Register arithmetic expression builders + fn register_arithmetic_expressions(&mut self) { + use crate::execution::expressions::arithmetic::*; + + self.builders + .insert(ExpressionType::Add, Box::new(AddBuilder)); + self.builders + .insert(ExpressionType::Subtract, Box::new(SubtractBuilder)); + self.builders + .insert(ExpressionType::Multiply, Box::new(MultiplyBuilder)); + self.builders + .insert(ExpressionType::Divide, Box::new(DivideBuilder)); + self.builders.insert( + ExpressionType::IntegralDivide, + Box::new(IntegralDivideBuilder), + ); + self.builders + .insert(ExpressionType::Remainder, Box::new(RemainderBuilder)); + self.builders + .insert(ExpressionType::UnaryMinus, Box::new(UnaryMinusBuilder)); + } + + /// Extract expression type from Spark protobuf expression + fn get_expression_type(spark_expr: &Expr) -> Result { + match spark_expr.expr_struct.as_ref() { + Some(ExprStruct::Add(_)) => Ok(ExpressionType::Add), + Some(ExprStruct::Subtract(_)) => Ok(ExpressionType::Subtract), + Some(ExprStruct::Multiply(_)) => Ok(ExpressionType::Multiply), + Some(ExprStruct::Divide(_)) => Ok(ExpressionType::Divide), + Some(ExprStruct::IntegralDivide(_)) => Ok(ExpressionType::IntegralDivide), + Some(ExprStruct::Remainder(_)) => Ok(ExpressionType::Remainder), + Some(ExprStruct::UnaryMinus(_)) => Ok(ExpressionType::UnaryMinus), + + Some(ExprStruct::Eq(_)) => Ok(ExpressionType::Eq), + Some(ExprStruct::Neq(_)) => Ok(ExpressionType::Neq), + Some(ExprStruct::Lt(_)) => Ok(ExpressionType::Lt), + Some(ExprStruct::LtEq(_)) => Ok(ExpressionType::LtEq), + Some(ExprStruct::Gt(_)) => Ok(ExpressionType::Gt), + Some(ExprStruct::GtEq(_)) => Ok(ExpressionType::GtEq), + Some(ExprStruct::EqNullSafe(_)) => Ok(ExpressionType::EqNullSafe), + Some(ExprStruct::NeqNullSafe(_)) => Ok(ExpressionType::NeqNullSafe), + + Some(ExprStruct::And(_)) => Ok(ExpressionType::And), + Some(ExprStruct::Or(_)) => Ok(ExpressionType::Or), + Some(ExprStruct::Not(_)) => Ok(ExpressionType::Not), + + Some(ExprStruct::IsNull(_)) => Ok(ExpressionType::IsNull), + Some(ExprStruct::IsNotNull(_)) => Ok(ExpressionType::IsNotNull), + + Some(ExprStruct::BitwiseAnd(_)) => Ok(ExpressionType::BitwiseAnd), + Some(ExprStruct::BitwiseOr(_)) => Ok(ExpressionType::BitwiseOr), + Some(ExprStruct::BitwiseXor(_)) => Ok(ExpressionType::BitwiseXor), + Some(ExprStruct::BitwiseShiftLeft(_)) => Ok(ExpressionType::BitwiseShiftLeft), + Some(ExprStruct::BitwiseShiftRight(_)) => Ok(ExpressionType::BitwiseShiftRight), + + Some(ExprStruct::Bound(_)) => Ok(ExpressionType::Bound), + Some(ExprStruct::Unbound(_)) => Ok(ExpressionType::Unbound), + Some(ExprStruct::Literal(_)) => Ok(ExpressionType::Literal), + Some(ExprStruct::Cast(_)) => Ok(ExpressionType::Cast), + Some(ExprStruct::CaseWhen(_)) => Ok(ExpressionType::CaseWhen), + Some(ExprStruct::In(_)) => Ok(ExpressionType::In), + Some(ExprStruct::If(_)) => Ok(ExpressionType::If), + Some(ExprStruct::Substring(_)) => Ok(ExpressionType::Substring), + Some(ExprStruct::Like(_)) => Ok(ExpressionType::Like), + Some(ExprStruct::Rlike(_)) => Ok(ExpressionType::Rlike), + Some(ExprStruct::CheckOverflow(_)) => Ok(ExpressionType::CheckOverflow), + Some(ExprStruct::ScalarFunc(_)) => Ok(ExpressionType::ScalarFunc), + Some(ExprStruct::NormalizeNanAndZero(_)) => Ok(ExpressionType::NormalizeNanAndZero), + Some(ExprStruct::Subquery(_)) => Ok(ExpressionType::Subquery), + Some(ExprStruct::BloomFilterMightContain(_)) => { + Ok(ExpressionType::BloomFilterMightContain) + } + Some(ExprStruct::CreateNamedStruct(_)) => Ok(ExpressionType::CreateNamedStruct), + Some(ExprStruct::GetStructField(_)) => Ok(ExpressionType::GetStructField), + Some(ExprStruct::ToJson(_)) => Ok(ExpressionType::ToJson), + Some(ExprStruct::ToPrettyString(_)) => Ok(ExpressionType::ToPrettyString), + Some(ExprStruct::ListExtract(_)) => Ok(ExpressionType::ListExtract), + Some(ExprStruct::GetArrayStructFields(_)) => Ok(ExpressionType::GetArrayStructFields), + Some(ExprStruct::ArrayInsert(_)) => Ok(ExpressionType::ArrayInsert), + Some(ExprStruct::Rand(_)) => Ok(ExpressionType::Rand), + Some(ExprStruct::Randn(_)) => Ok(ExpressionType::Randn), + Some(ExprStruct::SparkPartitionId(_)) => Ok(ExpressionType::SparkPartitionId), + Some(ExprStruct::MonotonicallyIncreasingId(_)) => { + Ok(ExpressionType::MonotonicallyIncreasingId) + } + + Some(ExprStruct::Hour(_)) => Ok(ExpressionType::Hour), + Some(ExprStruct::Minute(_)) => Ok(ExpressionType::Minute), + Some(ExprStruct::Second(_)) => Ok(ExpressionType::Second), + Some(ExprStruct::TruncTimestamp(_)) => Ok(ExpressionType::TruncTimestamp), + + Some(other) => Err(ExecutionError::GeneralError(format!( + "Unsupported expression type: {:?}", + other + ))), + None => Err(ExecutionError::GeneralError( + "Expression struct is None".to_string(), + )), + } + } +} + +impl Default for ExpressionRegistry { + fn default() -> Self { + Self::new() + } +} diff --git a/native/core/src/execution/planner/traits.rs b/native/core/src/execution/planner/traits.rs new file mode 100644 index 0000000000..2fb1df4cf9 --- /dev/null +++ b/native/core/src/execution/planner/traits.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Core traits for the modular planner framework + +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use datafusion::physical_expr::PhysicalExpr; +use datafusion_comet_proto::spark_expression::Expr; +use jni::objects::GlobalRef; + +use crate::execution::operators::ScanExec; +use crate::execution::{operators::ExecutionError, spark_plan::SparkPlan}; + +/// Trait for building physical expressions from Spark protobuf expressions +pub trait ExpressionBuilder: Send + Sync { + /// Build a DataFusion physical expression from a Spark protobuf expression + fn build( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &super::PhysicalPlanner, + ) -> Result, ExecutionError>; +} + +/// Trait for building physical operators from Spark protobuf operators +pub trait OperatorBuilder: Send + Sync { + /// Build a Spark plan from a protobuf operator + fn build( + &self, + spark_plan: &datafusion_comet_proto::spark_operator::Operator, + inputs: &mut Vec>, + partition_count: usize, + planner: &super::PhysicalPlanner, + ) -> Result<(Vec, Arc), ExecutionError>; +} + +/// Enum to identify different expression types for registry dispatch +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ExpressionType { + // Arithmetic expressions + Add, + Subtract, + Multiply, + Divide, + IntegralDivide, + Remainder, + UnaryMinus, + + // Comparison expressions + Eq, + Neq, + Lt, + LtEq, + Gt, + GtEq, + EqNullSafe, + NeqNullSafe, + + // Logical expressions + And, + Or, + Not, + + // Null checks + IsNull, + IsNotNull, + + // Bitwise operations + BitwiseAnd, + BitwiseOr, + BitwiseXor, + BitwiseShiftLeft, + BitwiseShiftRight, + + // Other expressions + Bound, + Unbound, + Literal, + Cast, + CaseWhen, + In, + If, + Substring, + Like, + Rlike, + CheckOverflow, + ScalarFunc, + NormalizeNanAndZero, + Subquery, + BloomFilterMightContain, + CreateNamedStruct, + GetStructField, + ToJson, + ToPrettyString, + ListExtract, + GetArrayStructFields, + ArrayInsert, + Rand, + Randn, + SparkPartitionId, + MonotonicallyIncreasingId, + + // Time functions + Hour, + Minute, + Second, + TruncTimestamp, +} + +/// Enum to identify different operator types for registry dispatch +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum OperatorType { + Scan, + NativeScan, + IcebergScan, + Projection, + Filter, + HashAgg, + Limit, + Sort, + ShuffleWriter, + ParquetWriter, + Expand, + SortMergeJoin, + HashJoin, + Window, +} From 5dc8f645a53e5b5530f915bdb85b36eee5adc1f6 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 4 Dec 2025 15:35:21 -0700 Subject: [PATCH 2/3] format and clippy --- .../src/execution/expressions/arithmetic.rs | 23 +++++++++++-------- native/core/src/execution/planner.rs | 5 ++-- native/core/src/execution/planner/traits.rs | 2 ++ 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/native/core/src/execution/expressions/arithmetic.rs b/native/core/src/execution/expressions/arithmetic.rs index 53a42bbe8c..ecd8c97acb 100644 --- a/native/core/src/execution/expressions/arithmetic.rs +++ b/native/core/src/execution/expressions/arithmetic.rs @@ -25,9 +25,12 @@ use datafusion::physical_expr::PhysicalExpr; use datafusion_comet_proto::spark_expression::{expr::ExprStruct, Expr}; use datafusion_comet_spark_expr::{create_modulo_expr, create_negate_expr, EvalMode}; -use crate::execution::operators::ExecutionError; -use crate::execution::planner::traits::ExpressionBuilder; -use crate::execution::planner::{from_protobuf_eval_mode, BinaryExprOptions}; +use crate::execution::{ + operators::ExecutionError, + planner::{ + from_protobuf_eval_mode, traits::ExpressionBuilder, BinaryExprOptions, PhysicalPlanner, + }, +}; /// Builder for Add expressions pub struct AddBuilder; @@ -37,7 +40,7 @@ impl ExpressionBuilder for AddBuilder { &self, spark_expr: &Expr, input_schema: SchemaRef, - planner: &crate::execution::planner::PhysicalPlanner, + planner: &PhysicalPlanner, ) -> Result, ExecutionError> { if let Some(ExprStruct::Add(expr)) = &spark_expr.expr_struct { let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; @@ -65,7 +68,7 @@ impl ExpressionBuilder for SubtractBuilder { &self, spark_expr: &Expr, input_schema: SchemaRef, - planner: &crate::execution::planner::PhysicalPlanner, + planner: &PhysicalPlanner, ) -> Result, ExecutionError> { if let Some(ExprStruct::Subtract(expr)) = &spark_expr.expr_struct { let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; @@ -93,7 +96,7 @@ impl ExpressionBuilder for MultiplyBuilder { &self, spark_expr: &Expr, input_schema: SchemaRef, - planner: &crate::execution::planner::PhysicalPlanner, + planner: &PhysicalPlanner, ) -> Result, ExecutionError> { if let Some(ExprStruct::Multiply(expr)) = &spark_expr.expr_struct { let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; @@ -121,7 +124,7 @@ impl ExpressionBuilder for DivideBuilder { &self, spark_expr: &Expr, input_schema: SchemaRef, - planner: &crate::execution::planner::PhysicalPlanner, + planner: &PhysicalPlanner, ) -> Result, ExecutionError> { if let Some(ExprStruct::Divide(expr)) = &spark_expr.expr_struct { let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; @@ -149,7 +152,7 @@ impl ExpressionBuilder for IntegralDivideBuilder { &self, spark_expr: &Expr, input_schema: SchemaRef, - planner: &crate::execution::planner::PhysicalPlanner, + planner: &PhysicalPlanner, ) -> Result, ExecutionError> { if let Some(ExprStruct::IntegralDivide(expr)) = &spark_expr.expr_struct { let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; @@ -180,7 +183,7 @@ impl ExpressionBuilder for RemainderBuilder { &self, spark_expr: &Expr, input_schema: SchemaRef, - planner: &crate::execution::planner::PhysicalPlanner, + planner: &PhysicalPlanner, ) -> Result, ExecutionError> { if let Some(ExprStruct::Remainder(expr)) = &spark_expr.expr_struct { let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; @@ -217,7 +220,7 @@ impl ExpressionBuilder for UnaryMinusBuilder { &self, spark_expr: &Expr, input_schema: SchemaRef, - planner: &crate::execution::planner::PhysicalPlanner, + planner: &PhysicalPlanner, ) -> Result, ExecutionError> { if let Some(ExprStruct::UnaryMinus(expr)) = &spark_expr.expr_struct { let child = planner.create_expr(expr.child.as_ref().unwrap(), input_schema)?; diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 38f92b5e68..61a5738eef 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -66,9 +66,8 @@ use datafusion::{ prelude::SessionContext, }; use datafusion_comet_spark_expr::{ - create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, - BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, EvalMode, - SparkHour, SparkMinute, SparkSecond, + create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle, + BloomFilterAgg, BloomFilterMightContain, EvalMode, SparkHour, SparkMinute, SparkSecond, }; use iceberg::expr::Bind; diff --git a/native/core/src/execution/planner/traits.rs b/native/core/src/execution/planner/traits.rs index 2fb1df4cf9..a49ba1b95d 100644 --- a/native/core/src/execution/planner/traits.rs +++ b/native/core/src/execution/planner/traits.rs @@ -39,6 +39,7 @@ pub trait ExpressionBuilder: Send + Sync { } /// Trait for building physical operators from Spark protobuf operators +#[allow(dead_code)] pub trait OperatorBuilder: Send + Sync { /// Build a Spark plan from a protobuf operator fn build( @@ -125,6 +126,7 @@ pub enum ExpressionType { /// Enum to identify different operator types for registry dispatch #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[allow(dead_code)] pub enum OperatorType { Scan, NativeScan, From c0e4bf609214ec35ffb2d8df6a779dd2309d804d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 4 Dec 2025 15:39:31 -0700 Subject: [PATCH 3/3] use dynamic check for registered expression handlers --- native/core/src/execution/planner.rs | 26 ++++--------------- .../execution/planner/expression_registry.rs | 9 +++++++ 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 61a5738eef..b0a68c4219 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -190,20 +190,6 @@ impl PhysicalPlanner { &self.session_ctx } - /// Check if an expression is an arithmetic expression that should be handled by the registry - fn is_arithmetic_expression(expr_struct: &ExprStruct) -> bool { - matches!( - expr_struct, - ExprStruct::Add(_) - | ExprStruct::Subtract(_) - | ExprStruct::Multiply(_) - | ExprStruct::Divide(_) - | ExprStruct::IntegralDivide(_) - | ExprStruct::Remainder(_) - | ExprStruct::UnaryMinus(_) - ) - } - /// get DataFusion PartitionedFiles from a Spark FilePartition fn get_partitioned_files( &self, @@ -262,13 +248,11 @@ impl PhysicalPlanner { spark_expr: &Expr, input_schema: SchemaRef, ) -> Result, ExecutionError> { - // Try to use the modular registry for arithmetic expressions first - if let Some(expr_struct) = spark_expr.expr_struct.as_ref() { - if Self::is_arithmetic_expression(expr_struct) { - return self - .expression_registry - .create_expr(spark_expr, input_schema, self); - } + // Try to use the modular registry first - this automatically handles any registered expression types + if self.expression_registry.can_handle(spark_expr) { + return self + .expression_registry + .create_expr(spark_expr, input_schema, self); } // Fall back to the original monolithic match for other expressions diff --git a/native/core/src/execution/planner/expression_registry.rs b/native/core/src/execution/planner/expression_registry.rs index 0c66cd6a88..18d2040286 100644 --- a/native/core/src/execution/planner/expression_registry.rs +++ b/native/core/src/execution/planner/expression_registry.rs @@ -43,6 +43,15 @@ impl ExpressionRegistry { registry } + /// Check if the registry can handle a given expression type + pub fn can_handle(&self, spark_expr: &Expr) -> bool { + if let Ok(expr_type) = Self::get_expression_type(spark_expr) { + self.builders.contains_key(&expr_type) + } else { + false + } + } + /// Create a physical expression from a Spark protobuf expression pub fn create_expr( &self,