Skip to content

Commit cec14fc

Browse files
committed
initial
1 parent 2d5f016 commit cec14fc

5 files changed

Lines changed: 497 additions & 7 deletions

File tree

datafusion/core/src/execution/context/mod.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,10 +1434,21 @@ impl SessionContext {
14341434

14351435
match function_factory {
14361436
Some(f) => f.create(&state, stmt).await?,
1437-
_ => {
1438-
return Err(DataFusionError::Configuration(
1439-
"Function factory has not been configured".to_string(),
1440-
));
1437+
None => {
1438+
let is_macro = stmt
1439+
.params
1440+
.language
1441+
.as_ref()
1442+
.is_some_and(|l| l.value.eq_ignore_ascii_case("macro"));
1443+
if is_macro {
1444+
crate::execution::macro_factory::MacroFunctionFactory
1445+
.create(&state, stmt)
1446+
.await?
1447+
} else {
1448+
return Err(DataFusionError::Configuration(
1449+
"Function factory has not been configured".to_string(),
1450+
));
1451+
}
14411452
}
14421453
}
14431454
};
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::collections::HashMap;
19+
use std::sync::Arc;
20+
21+
use arrow::datatypes::{DataType, Field, FieldRef};
22+
23+
use datafusion_common::tree_node::{Transformed, TreeNode};
24+
use datafusion_common::{
25+
Column, DFSchema, Result, TableReference, internal_err, plan_datafusion_err, plan_err,
26+
};
27+
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
28+
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
29+
use datafusion_expr::{
30+
ColumnarValue, CreateFunction, Expr, ExprSchemable, ReturnFieldArgs,
31+
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
32+
};
33+
34+
use super::context::{FunctionFactory, RegisterFunction};
35+
use crate::execution::session_state::SessionState;
36+
37+
/// Built-in [`FunctionFactory`] that handles `CREATE MACRO` statements
38+
/// by creating scalar UDFs that expand macro bodies via [`ScalarUDFImpl::simplify`].
39+
#[derive(Debug)]
40+
pub struct MacroFunctionFactory;
41+
42+
#[async_trait::async_trait]
43+
impl FunctionFactory for MacroFunctionFactory {
44+
async fn create(
45+
&self,
46+
_state: &SessionState,
47+
statement: CreateFunction,
48+
) -> Result<RegisterFunction> {
49+
let wrapper = ScalarMacroWrapper::try_from(statement)?;
50+
Ok(RegisterFunction::Scalar(Arc::new(ScalarUDF::from(wrapper))))
51+
}
52+
}
53+
54+
#[derive(Debug, PartialEq, Eq, Hash)]
55+
struct ScalarMacroWrapper {
56+
name: String,
57+
body: Expr,
58+
arg_names: Vec<String>,
59+
arg_defaults: Vec<Option<Expr>>,
60+
signature: Signature,
61+
}
62+
63+
impl TryFrom<CreateFunction> for ScalarMacroWrapper {
64+
type Error = datafusion_common::DataFusionError;
65+
66+
fn try_from(def: CreateFunction) -> Result<Self> {
67+
let body = def
68+
.params
69+
.function_body
70+
.ok_or_else(|| plan_datafusion_err!("Macro body is required"))?;
71+
72+
let (arg_names, arg_defaults) = match def.args {
73+
Some(args) => {
74+
let mut names = Vec::with_capacity(args.len());
75+
let mut defaults = Vec::with_capacity(args.len());
76+
for arg in args {
77+
let name = arg
78+
.name
79+
.ok_or_else(|| {
80+
plan_datafusion_err!("Macro arguments must be named")
81+
})?
82+
.value;
83+
names.push(name);
84+
defaults.push(arg.default_expr);
85+
}
86+
(names, defaults)
87+
}
88+
None => (vec![], vec![]),
89+
};
90+
91+
let signature =
92+
Signature::variadic_any(def.params.behavior.unwrap_or(Volatility::Volatile));
93+
94+
Ok(Self {
95+
name: def.name,
96+
body,
97+
arg_names,
98+
arg_defaults,
99+
signature,
100+
})
101+
}
102+
}
103+
104+
impl ScalarUDFImpl for ScalarMacroWrapper {
105+
fn name(&self) -> &str {
106+
&self.name
107+
}
108+
109+
fn signature(&self) -> &Signature {
110+
&self.signature
111+
}
112+
113+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
114+
Ok(DataType::Null)
115+
}
116+
117+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
118+
let mut qualified_fields: Vec<(Option<TableReference>, Arc<Field>)> =
119+
Vec::with_capacity(self.arg_names.len());
120+
for (i, name) in self.arg_names.iter().enumerate() {
121+
let data_type = if i < args.arg_fields.len() {
122+
args.arg_fields[i].data_type().clone()
123+
} else {
124+
DataType::Null
125+
};
126+
let nullable = i >= args.arg_fields.len() || args.arg_fields[i].is_nullable();
127+
qualified_fields.push((
128+
None,
129+
Arc::new(Field::new(name.clone(), data_type, nullable)),
130+
));
131+
}
132+
let schema = DFSchema::new_with_metadata(qualified_fields, HashMap::new())?;
133+
let data_type = self.body.get_type(&schema)?;
134+
let nullable = self.body.nullable(&schema)?;
135+
Ok(Arc::new(Field::new(self.name(), data_type, nullable)))
136+
}
137+
138+
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
139+
internal_err!("Macros should be simplified before execution")
140+
}
141+
142+
fn simplify(
143+
&self,
144+
args: Vec<Expr>,
145+
_info: &SimplifyContext,
146+
) -> Result<ExprSimplifyResult> {
147+
let required_count = self.arg_defaults.iter().take_while(|d| d.is_none()).count();
148+
let max_count = self.arg_names.len();
149+
150+
if args.len() < required_count || args.len() > max_count {
151+
return plan_err!(
152+
"Macro '{}' expects {} to {} arguments, got {}",
153+
self.name,
154+
required_count,
155+
max_count,
156+
args.len()
157+
);
158+
}
159+
160+
let replacement = self.body.clone().transform(|e| {
161+
if let Expr::Column(Column {
162+
relation: None,
163+
ref name,
164+
..
165+
}) = e
166+
&& let Some(pos) = self.arg_names.iter().position(|n| n == name)
167+
{
168+
let replacement_expr = if pos < args.len() {
169+
args[pos].clone()
170+
} else if let Some(default) = &self.arg_defaults[pos] {
171+
default.clone()
172+
} else {
173+
return plan_err!(
174+
"Missing argument '{}' for macro '{}'",
175+
name,
176+
self.name
177+
);
178+
};
179+
return Ok(Transformed::yes(replacement_expr));
180+
}
181+
Ok(Transformed::no(e))
182+
})?;
183+
184+
Ok(ExprSimplifyResult::Simplified(replacement.data))
185+
}
186+
187+
fn output_ordering(&self, _input: &[ExprProperties]) -> Result<SortProperties> {
188+
Ok(SortProperties::Unordered)
189+
}
190+
}
191+
192+
#[cfg(test)]
193+
mod tests {
194+
use super::*;
195+
use crate::prelude::SessionContext;
196+
197+
#[tokio::test]
198+
async fn test_scalar_macro_basic() -> Result<()> {
199+
let ctx = SessionContext::new();
200+
ctx.sql("CREATE MACRO add_macro(a, b) AS a + b")
201+
.await?
202+
.show()
203+
.await?;
204+
205+
let result = ctx.sql("SELECT add_macro(1, 2)").await?.collect().await?;
206+
assert_eq!(result[0].num_rows(), 1);
207+
let val = result[0]
208+
.column(0)
209+
.as_any()
210+
.downcast_ref::<arrow::array::Int64Array>()
211+
.unwrap()
212+
.value(0);
213+
assert_eq!(val, 3);
214+
Ok(())
215+
}
216+
217+
#[tokio::test]
218+
async fn test_scalar_macro_with_defaults() -> Result<()> {
219+
let ctx = SessionContext::new();
220+
ctx.sql("CREATE MACRO add_default(a, b := 5) AS a + b")
221+
.await?
222+
.show()
223+
.await?;
224+
225+
let result = ctx.sql("SELECT add_default(10)").await?.collect().await?;
226+
let val = result[0]
227+
.column(0)
228+
.as_any()
229+
.downcast_ref::<arrow::array::Int64Array>()
230+
.unwrap()
231+
.value(0);
232+
assert_eq!(val, 15);
233+
Ok(())
234+
}
235+
236+
#[tokio::test]
237+
async fn test_scalar_macro_or_replace() -> Result<()> {
238+
let ctx = SessionContext::new();
239+
ctx.sql("CREATE MACRO my_add(a, b) AS a + b")
240+
.await?
241+
.show()
242+
.await?;
243+
ctx.sql("CREATE OR REPLACE MACRO my_add(a, b) AS a + b + 1")
244+
.await?
245+
.show()
246+
.await?;
247+
248+
let result = ctx.sql("SELECT my_add(1, 2)").await?.collect().await?;
249+
let val = result[0]
250+
.column(0)
251+
.as_any()
252+
.downcast_ref::<arrow::array::Int64Array>()
253+
.unwrap()
254+
.value(0);
255+
assert_eq!(val, 4);
256+
Ok(())
257+
}
258+
259+
#[tokio::test]
260+
async fn test_scalar_macro_drop() -> Result<()> {
261+
let ctx = SessionContext::new();
262+
ctx.sql("CREATE MACRO drop_me(a) AS a + 1")
263+
.await?
264+
.show()
265+
.await?;
266+
ctx.sql("DROP FUNCTION drop_me").await?.show().await?;
267+
let result = ctx.sql("SELECT drop_me(1)").await;
268+
assert!(result.is_err());
269+
Ok(())
270+
}
271+
}

datafusion/core/src/execution/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
//! Shared state for query planning and execution.
1919
2020
pub mod context;
21+
/// Built-in support for DuckDB-compatible `CREATE MACRO` statements.
22+
pub mod macro_factory;
2123
pub mod session_state;
2224
pub use session_state::{SessionState, SessionStateBuilder};
2325

0 commit comments

Comments
 (0)