|
| 1 | +// Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +// or more contributor license agreements. See the NOTICE file |
| 3 | +// distributed with this work for additional information |
| 4 | +// regarding copyright ownership. The ASF licenses this file |
| 5 | +// to you under the Apache License, Version 2.0 (the |
| 6 | +// "License"); you may not use this file except in compliance |
| 7 | +// with the License. You may obtain a copy of the License at |
| 8 | +// |
| 9 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +// |
| 11 | +// Unless required by applicable law or agreed to in writing, |
| 12 | +// software distributed under the License is distributed on an |
| 13 | +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +// KIND, either express or implied. See the License for the |
| 15 | +// specific language governing permissions and limitations |
| 16 | +// under the License. |
| 17 | + |
| 18 | +use std::collections::HashMap; |
| 19 | +use std::sync::Arc; |
| 20 | + |
| 21 | +use arrow::datatypes::{DataType, Field, FieldRef}; |
| 22 | + |
| 23 | +use datafusion_common::tree_node::{Transformed, TreeNode}; |
| 24 | +use datafusion_common::{ |
| 25 | + Column, DFSchema, Result, TableReference, internal_err, plan_datafusion_err, plan_err, |
| 26 | +}; |
| 27 | +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; |
| 28 | +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; |
| 29 | +use datafusion_expr::{ |
| 30 | + ColumnarValue, CreateFunction, Expr, ExprSchemable, ReturnFieldArgs, |
| 31 | + ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, |
| 32 | +}; |
| 33 | + |
| 34 | +use super::context::{FunctionFactory, RegisterFunction}; |
| 35 | +use crate::execution::session_state::SessionState; |
| 36 | + |
| 37 | +/// Built-in [`FunctionFactory`] that handles `CREATE MACRO` statements |
| 38 | +/// by creating scalar UDFs that expand macro bodies via [`ScalarUDFImpl::simplify`]. |
| 39 | +#[derive(Debug)] |
| 40 | +pub struct MacroFunctionFactory; |
| 41 | + |
| 42 | +#[async_trait::async_trait] |
| 43 | +impl FunctionFactory for MacroFunctionFactory { |
| 44 | + async fn create( |
| 45 | + &self, |
| 46 | + _state: &SessionState, |
| 47 | + statement: CreateFunction, |
| 48 | + ) -> Result<RegisterFunction> { |
| 49 | + let wrapper = ScalarMacroWrapper::try_from(statement)?; |
| 50 | + Ok(RegisterFunction::Scalar(Arc::new(ScalarUDF::from(wrapper)))) |
| 51 | + } |
| 52 | +} |
| 53 | + |
| 54 | +#[derive(Debug, PartialEq, Eq, Hash)] |
| 55 | +struct ScalarMacroWrapper { |
| 56 | + name: String, |
| 57 | + body: Expr, |
| 58 | + arg_names: Vec<String>, |
| 59 | + arg_defaults: Vec<Option<Expr>>, |
| 60 | + signature: Signature, |
| 61 | +} |
| 62 | + |
| 63 | +impl TryFrom<CreateFunction> for ScalarMacroWrapper { |
| 64 | + type Error = datafusion_common::DataFusionError; |
| 65 | + |
| 66 | + fn try_from(def: CreateFunction) -> Result<Self> { |
| 67 | + let body = def |
| 68 | + .params |
| 69 | + .function_body |
| 70 | + .ok_or_else(|| plan_datafusion_err!("Macro body is required"))?; |
| 71 | + |
| 72 | + let (arg_names, arg_defaults) = match def.args { |
| 73 | + Some(args) => { |
| 74 | + let mut names = Vec::with_capacity(args.len()); |
| 75 | + let mut defaults = Vec::with_capacity(args.len()); |
| 76 | + for arg in args { |
| 77 | + let name = arg |
| 78 | + .name |
| 79 | + .ok_or_else(|| { |
| 80 | + plan_datafusion_err!("Macro arguments must be named") |
| 81 | + })? |
| 82 | + .value; |
| 83 | + names.push(name); |
| 84 | + defaults.push(arg.default_expr); |
| 85 | + } |
| 86 | + (names, defaults) |
| 87 | + } |
| 88 | + None => (vec![], vec![]), |
| 89 | + }; |
| 90 | + |
| 91 | + let signature = |
| 92 | + Signature::variadic_any(def.params.behavior.unwrap_or(Volatility::Volatile)); |
| 93 | + |
| 94 | + Ok(Self { |
| 95 | + name: def.name, |
| 96 | + body, |
| 97 | + arg_names, |
| 98 | + arg_defaults, |
| 99 | + signature, |
| 100 | + }) |
| 101 | + } |
| 102 | +} |
| 103 | + |
| 104 | +impl ScalarUDFImpl for ScalarMacroWrapper { |
| 105 | + fn name(&self) -> &str { |
| 106 | + &self.name |
| 107 | + } |
| 108 | + |
| 109 | + fn signature(&self) -> &Signature { |
| 110 | + &self.signature |
| 111 | + } |
| 112 | + |
| 113 | + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { |
| 114 | + Ok(DataType::Null) |
| 115 | + } |
| 116 | + |
| 117 | + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> { |
| 118 | + let mut qualified_fields: Vec<(Option<TableReference>, Arc<Field>)> = |
| 119 | + Vec::with_capacity(self.arg_names.len()); |
| 120 | + for (i, name) in self.arg_names.iter().enumerate() { |
| 121 | + let data_type = if i < args.arg_fields.len() { |
| 122 | + args.arg_fields[i].data_type().clone() |
| 123 | + } else { |
| 124 | + DataType::Null |
| 125 | + }; |
| 126 | + let nullable = i >= args.arg_fields.len() || args.arg_fields[i].is_nullable(); |
| 127 | + qualified_fields.push(( |
| 128 | + None, |
| 129 | + Arc::new(Field::new(name.clone(), data_type, nullable)), |
| 130 | + )); |
| 131 | + } |
| 132 | + let schema = DFSchema::new_with_metadata(qualified_fields, HashMap::new())?; |
| 133 | + let data_type = self.body.get_type(&schema)?; |
| 134 | + let nullable = self.body.nullable(&schema)?; |
| 135 | + Ok(Arc::new(Field::new(self.name(), data_type, nullable))) |
| 136 | + } |
| 137 | + |
| 138 | + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> { |
| 139 | + internal_err!("Macros should be simplified before execution") |
| 140 | + } |
| 141 | + |
| 142 | + fn simplify( |
| 143 | + &self, |
| 144 | + args: Vec<Expr>, |
| 145 | + _info: &SimplifyContext, |
| 146 | + ) -> Result<ExprSimplifyResult> { |
| 147 | + let required_count = self.arg_defaults.iter().take_while(|d| d.is_none()).count(); |
| 148 | + let max_count = self.arg_names.len(); |
| 149 | + |
| 150 | + if args.len() < required_count || args.len() > max_count { |
| 151 | + return plan_err!( |
| 152 | + "Macro '{}' expects {} to {} arguments, got {}", |
| 153 | + self.name, |
| 154 | + required_count, |
| 155 | + max_count, |
| 156 | + args.len() |
| 157 | + ); |
| 158 | + } |
| 159 | + |
| 160 | + let replacement = self.body.clone().transform(|e| { |
| 161 | + if let Expr::Column(Column { |
| 162 | + relation: None, |
| 163 | + ref name, |
| 164 | + .. |
| 165 | + }) = e |
| 166 | + && let Some(pos) = self.arg_names.iter().position(|n| n == name) |
| 167 | + { |
| 168 | + let replacement_expr = if pos < args.len() { |
| 169 | + args[pos].clone() |
| 170 | + } else if let Some(default) = &self.arg_defaults[pos] { |
| 171 | + default.clone() |
| 172 | + } else { |
| 173 | + return plan_err!( |
| 174 | + "Missing argument '{}' for macro '{}'", |
| 175 | + name, |
| 176 | + self.name |
| 177 | + ); |
| 178 | + }; |
| 179 | + return Ok(Transformed::yes(replacement_expr)); |
| 180 | + } |
| 181 | + Ok(Transformed::no(e)) |
| 182 | + })?; |
| 183 | + |
| 184 | + Ok(ExprSimplifyResult::Simplified(replacement.data)) |
| 185 | + } |
| 186 | + |
| 187 | + fn output_ordering(&self, _input: &[ExprProperties]) -> Result<SortProperties> { |
| 188 | + Ok(SortProperties::Unordered) |
| 189 | + } |
| 190 | +} |
| 191 | + |
| 192 | +#[cfg(test)] |
| 193 | +mod tests { |
| 194 | + use super::*; |
| 195 | + use crate::prelude::SessionContext; |
| 196 | + |
| 197 | + #[tokio::test] |
| 198 | + async fn test_scalar_macro_basic() -> Result<()> { |
| 199 | + let ctx = SessionContext::new(); |
| 200 | + ctx.sql("CREATE MACRO add_macro(a, b) AS a + b") |
| 201 | + .await? |
| 202 | + .show() |
| 203 | + .await?; |
| 204 | + |
| 205 | + let result = ctx.sql("SELECT add_macro(1, 2)").await?.collect().await?; |
| 206 | + assert_eq!(result[0].num_rows(), 1); |
| 207 | + let val = result[0] |
| 208 | + .column(0) |
| 209 | + .as_any() |
| 210 | + .downcast_ref::<arrow::array::Int64Array>() |
| 211 | + .unwrap() |
| 212 | + .value(0); |
| 213 | + assert_eq!(val, 3); |
| 214 | + Ok(()) |
| 215 | + } |
| 216 | + |
| 217 | + #[tokio::test] |
| 218 | + async fn test_scalar_macro_with_defaults() -> Result<()> { |
| 219 | + let ctx = SessionContext::new(); |
| 220 | + ctx.sql("CREATE MACRO add_default(a, b := 5) AS a + b") |
| 221 | + .await? |
| 222 | + .show() |
| 223 | + .await?; |
| 224 | + |
| 225 | + let result = ctx.sql("SELECT add_default(10)").await?.collect().await?; |
| 226 | + let val = result[0] |
| 227 | + .column(0) |
| 228 | + .as_any() |
| 229 | + .downcast_ref::<arrow::array::Int64Array>() |
| 230 | + .unwrap() |
| 231 | + .value(0); |
| 232 | + assert_eq!(val, 15); |
| 233 | + Ok(()) |
| 234 | + } |
| 235 | + |
| 236 | + #[tokio::test] |
| 237 | + async fn test_scalar_macro_or_replace() -> Result<()> { |
| 238 | + let ctx = SessionContext::new(); |
| 239 | + ctx.sql("CREATE MACRO my_add(a, b) AS a + b") |
| 240 | + .await? |
| 241 | + .show() |
| 242 | + .await?; |
| 243 | + ctx.sql("CREATE OR REPLACE MACRO my_add(a, b) AS a + b + 1") |
| 244 | + .await? |
| 245 | + .show() |
| 246 | + .await?; |
| 247 | + |
| 248 | + let result = ctx.sql("SELECT my_add(1, 2)").await?.collect().await?; |
| 249 | + let val = result[0] |
| 250 | + .column(0) |
| 251 | + .as_any() |
| 252 | + .downcast_ref::<arrow::array::Int64Array>() |
| 253 | + .unwrap() |
| 254 | + .value(0); |
| 255 | + assert_eq!(val, 4); |
| 256 | + Ok(()) |
| 257 | + } |
| 258 | + |
| 259 | + #[tokio::test] |
| 260 | + async fn test_scalar_macro_drop() -> Result<()> { |
| 261 | + let ctx = SessionContext::new(); |
| 262 | + ctx.sql("CREATE MACRO drop_me(a) AS a + 1") |
| 263 | + .await? |
| 264 | + .show() |
| 265 | + .await?; |
| 266 | + ctx.sql("DROP FUNCTION drop_me").await?.show().await?; |
| 267 | + let result = ctx.sql("SELECT drop_me(1)").await; |
| 268 | + assert!(result.is_err()); |
| 269 | + Ok(()) |
| 270 | + } |
| 271 | +} |
0 commit comments