Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 164 additions & 15 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ use datafusion_expr::{
LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs,
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_expr_common::signature::TypeSignature;
use datafusion_functions_nested::range::range_udf;
use parking_lot::Mutex;
use regex::Regex;
Expand Down Expand Up @@ -945,6 +946,7 @@ struct ScalarFunctionWrapper {
expr: Expr,
signature: Signature,
return_type: DataType,
defaults: Vec<Option<Expr>>,
}

impl ScalarUDFImpl for ScalarFunctionWrapper {
Expand Down Expand Up @@ -973,15 +975,19 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
args: Vec<Expr>,
_info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
let replacement = Self::replacement(&self.expr, &args)?;
let replacement = Self::replacement(&self.expr, &args, &self.defaults)?;

Ok(ExprSimplifyResult::Simplified(replacement))
}
}

impl ScalarFunctionWrapper {
// replaces placeholders with actual arguments
fn replacement(expr: &Expr, args: &[Expr]) -> Result<Expr> {
fn replacement(
expr: &Expr,
args: &[Expr],
defaults: &[Option<Expr>],
) -> Result<Expr> {
let result = expr.clone().transform(|e| {
let r = match e {
Expr::Placeholder(placeholder) => {
Expand All @@ -990,10 +996,13 @@ impl ScalarFunctionWrapper {
if placeholder_position < args.len() {
Transformed::yes(args[placeholder_position].clone())
} else {
exec_err!(
"Function argument {} not provided, argument missing!",
placeholder.id
)?
match defaults[placeholder_position] {
Some(ref default) => Transformed::yes(default.clone()),
None => exec_err!(
"Function argument {} not provided, argument missing!",
placeholder.id
)?,
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Index Bounds in Placeholder Defaults Safety Trap

Potential index out of bounds when accessing defaults[placeholder_position]. If a placeholder references a parameter position that exceeds the number of function parameters (e.g., $5 when only 2 parameters are defined), placeholder_position will be >= defaults.len(), causing a panic. The code should check placeholder_position < defaults.len() before accessing the array.

Fix in Cursor Fix in Web

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value:useful; category:bug; feedback: The Cursor AI reviewer is correct that an index out of bounds error will be raised when the SQL DDL statement uses a positional placeholder, like $5, when there are fewer than 5 function parameters. The finding prevents a runtime panic and abort of the application.

Comment on lines 997 to +1005
Copy link
Copy Markdown

@coderabbitai coderabbitai bot Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Guard defaults lookup to avoid panic on invalid placeholders.

If the SQL body references $N where N exceeds the declared parameter count, placeholder_position indexes past defaults, causing an immediate panic instead of a user-facing error. Previously this returned a clean “argument missing” error; now it will crash. Please restore graceful handling by checking the bounds before indexing.

-                    if placeholder_position < args.len() {
-                        Transformed::yes(args[placeholder_position].clone())
-                    } else {
-                        match defaults[placeholder_position] {
-                            Some(ref default) => Transformed::yes(default.clone()),
-                            None => exec_err!(
-                                "Function argument {} not provided, argument missing!",
-                                placeholder.id
-                            )?,
-                        }
-                    }
+                    if placeholder_position < args.len() {
+                        Transformed::yes(args[placeholder_position].clone())
+                    } else if placeholder_position < defaults.len() {
+                        match defaults[placeholder_position] {
+                            Some(ref default) => Transformed::yes(default.clone()),
+                            None => exec_err!(
+                                "Function argument {} not provided, argument missing!",
+                                placeholder.id
+                            )?,
+                        }
+                    } else {
+                        exec_err!(
+                            "Function argument {} referenced in the body exceeds the declared parameter list ({} arguments)",
+                            placeholder.id,
+                            defaults.len()
+                        )?
+                    }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Transformed::yes(args[placeholder_position].clone())
} else {
exec_err!(
"Function argument {} not provided, argument missing!",
placeholder.id
)?
match defaults[placeholder_position] {
Some(ref default) => Transformed::yes(default.clone()),
None => exec_err!(
"Function argument {} not provided, argument missing!",
placeholder.id
)?,
}
Transformed::yes(args[placeholder_position].clone())
} else if placeholder_position < defaults.len() {
match defaults[placeholder_position] {
Some(ref default) => Transformed::yes(default.clone()),
None => exec_err!(
"Function argument {} not provided, argument missing!",
placeholder.id
)?,
}
} else {
exec_err!(
"Function argument {} referenced in the body exceeds the declared parameter list ({} arguments)",
placeholder.id,
defaults.len()
)?
}
🤖 Prompt for AI Agents
In datafusion/core/tests/user_defined/user_defined_scalar_functions.rs around
lines 997 to 1005, the code indexes defaults[placeholder_position] without
checking bounds which can panic when a SQL body references $N beyond the
declared parameter count; add a bounds check before accessing defaults: if
placeholder_position is within defaults.len() then use
defaults[placeholder_position] as before, otherwise return the same
exec_err!("Function argument {} not provided, argument missing!",
placeholder.id)? to preserve the previous graceful error path instead of
allowing a panic.

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value:useful; category:bug; feedback: The CodeRabbit AI reviewer is correct that an index out of bounds error will be raised when the SQL DDL statement uses a positional placeholder, like $5, when there are fewer than 5 function parameters. The finding prevents a runtime panic and abort of the application.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@martin-augment, thank you for confirming this finding! You're absolutely right that this would cause a runtime panic when a SQL function body references a placeholder beyond the declared parameter count (e.g., $5 when only 3 parameters are defined). The bounds check is critical to ensure users get a clear error message instead of crashing the application.

The suggested fix properly handles all three cases:

  1. Provided argument: Use the actual argument value
  2. Missing argument with default: Use the default expression
  3. Missing argument beyond parameter list: Return a clear error message

This is especially important for SQL UDFs where the function body is user-provided SQL text and typos or logic errors could easily reference non-existent parameters.


✏️ Learnings added
Learnt from: martin-augment
Repo: martin-augment/datafusion PR: 5
File: datafusion/core/tests/user_defined/user_defined_scalar_functions.rs:997-1005
Timestamp: 2025-11-06T13:37:45.993Z
Learning: In datafusion/core/tests/user_defined/user_defined_scalar_functions.rs, when implementing ScalarFunctionWrapper::replacement, always bounds-check placeholder_position against defaults.len() before indexing into the defaults array to prevent panics when SQL UDF bodies reference placeholders like $N where N exceeds the declared parameter count.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.

}
}
_ => Transformed::no(e),
Expand Down Expand Up @@ -1021,6 +1030,32 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
type Error = DataFusionError;

fn try_from(definition: CreateFunction) -> std::result::Result<Self, Self::Error> {
let args = definition.args.unwrap_or_default();
let defaults: Vec<Option<Expr>> =
args.iter().map(|a| a.default_expr.clone()).collect();
let signature: Signature = match defaults.iter().position(|v| v.is_some()) {
Some(pos) => {
let mut type_signatures: Vec<TypeSignature> = vec![];
// Generate all valid signatures
for n in pos..defaults.len() + 1 {
if n == 0 {
type_signatures.push(TypeSignature::Nullary)
} else {
type_signatures.push(TypeSignature::Exact(
args.iter().take(n).map(|a| a.data_type.clone()).collect(),
))
}
}
Signature::one_of(
type_signatures,
definition.params.behavior.unwrap_or(Volatility::Volatile),
)
}
None => Signature::exact(
args.iter().map(|a| a.data_type.clone()).collect(),
definition.params.behavior.unwrap_or(Volatility::Volatile),
),
};
Ok(Self {
name: definition.name,
expr: definition
Expand All @@ -1030,15 +1065,8 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
return_type: definition
.return_type
.expect("Return type has to be defined!"),
signature: Signature::exact(
definition
.args
.unwrap_or_default()
.into_iter()
.map(|a| a.data_type)
.collect(),
definition.params.behavior.unwrap_or(Volatility::Volatile),
),
signature,
defaults,
})
}
}
Expand Down Expand Up @@ -1112,6 +1140,127 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> {
Ok(())
}

/// Checks that a SQL `CREATE FUNCTION` statement may declare named parameters
/// and reference them in its body as `$a` / `$b`.
///
/// The test first registers `better_add(a DOUBLE, b DOUBLE)` and verifies that
/// calling it with two arguments yields their sum. It then verifies that a
/// declaration mixing an unnamed parameter with a named one is rejected with a
/// planning error.
#[tokio::test]
async fn create_scalar_function_from_sql_statement_named_arguments() -> Result<()> {
let function_factory = Arc::new(CustomFunctionFactory::default());
let ctx = SessionContext::new().with_function_factory(function_factory.clone());

// Both parameters are declared by name; the body refers to them by name too.
let sql = r#"
CREATE FUNCTION better_add(a DOUBLE, b DOUBLE)
RETURNS DOUBLE
RETURN $a + $b
"#;

// The CREATE FUNCTION statement itself must succeed.
assert!(ctx.sql(sql).await.is_ok());

// Call the newly created function with both arguments supplied.
let result = ctx
.sql("select better_add(2.0, 2.0)")
.await?
.collect()
.await?;

assert_batches_eq!(
&[
"+-----------------------------------+",
"| better_add(Float64(2),Float64(2)) |",
"+-----------------------------------+",
"| 4.0 |",
"+-----------------------------------+",
],
&result
);

// Negative case: the first parameter is unnamed and the second is named, and
// the body mixes `$1` with `$b`. Creating the function is expected to fail.
let bad_expression_sql = r#"
CREATE FUNCTION bad_expression_fun(DOUBLE, b DOUBLE)
RETURNS DOUBLE
RETURN $1 + $b
"#;
let err = ctx
.sql(bad_expression_sql)
.await
.expect_err("cannot mix named and positional style");
let expected = "Error during planning: All function arguments must use either named or positional style.";
// Note the direction: this passes when the backtrace-stripped error text is a
// prefix of `expected` (an exact match included).
assert!(expected.starts_with(&err.strip_backtrace()));

Ok(())
}

/// Checks that a SQL `CREATE FUNCTION` statement may declare `DEFAULT` values
/// for its parameters and that callers may then omit trailing arguments.
///
/// `better_add` is declared with two parameters that both default to `2.0`.
/// The test calls it with zero, one and two arguments (each expected to yield
/// `4.0`), verifies that a three-argument call is an error, and finally
/// verifies that declaring a non-default parameter after a default one is
/// rejected with a planning error.
#[tokio::test]
async fn create_scalar_function_from_sql_statement_default_arguments() -> Result<()> {
let function_factory = Arc::new(CustomFunctionFactory::default());
let ctx = SessionContext::new().with_function_factory(function_factory.clone());

// Both parameters carry a default, so every arity from 0 to 2 should be valid.
let sql = r#"
CREATE FUNCTION better_add(a DOUBLE DEFAULT 2.0, b DOUBLE DEFAULT 2.0)
RETURNS DOUBLE
RETURN $a + $b
"#;

// The CREATE FUNCTION statement itself must succeed.
assert!(ctx.sql(sql).await.is_ok());

// Zero arguments: both `a` and `b` fall back to their defaults.
let result = ctx.sql("select better_add()").await?.collect().await?;

assert_batches_eq!(
&[
"+--------------+",
"| better_add() |",
"+--------------+",
"| 4.0 |",
"+--------------+",
],
&result
);

// One argument: `a` is supplied, `b` falls back to its default.
let result = ctx.sql("select better_add(2.0)").await?.collect().await?;

assert_batches_eq!(
&[
"+------------------------+",
"| better_add(Float64(2)) |",
"+------------------------+",
"| 4.0 |",
"+------------------------+",
],
&result
);

// Two arguments: no defaults are needed.
let result = ctx
.sql("select better_add(2.0, 2.0)")
.await?
.collect()
.await?;

assert_batches_eq!(
&[
"+-----------------------------------+",
"| better_add(Float64(2),Float64(2)) |",
"+-----------------------------------+",
"| 4.0 |",
"+-----------------------------------+",
],
&result
);

// Three arguments exceed the declared parameter list and must be rejected.
assert!(ctx.sql("select better_add(2.0, 2.0, 2.0)").await.is_err());

// Negative case: a parameter without a default (`b`) follows one with a
// default (`a`). Creating the function is expected to fail.
let bad_expression_sql = r#"
CREATE FUNCTION bad_expression_fun(a DOUBLE DEFAULT 2.0, b DOUBLE)
RETURNS DOUBLE
RETURN $a + $b
"#;
let err = ctx
.sql(bad_expression_sql)
.await
.expect_err("non-default argument cannot follow default argument");
let expected =
"Error during planning: Non-default arguments cannot follow default arguments.";
// Note the direction: this passes when the backtrace-stripped error text is a
// prefix of `expected` (an exact match included).
assert!(expected.starts_with(&err.strip_backtrace()));
Ok(())
}

/// Saves whatever is passed to it as a scalar function
#[derive(Debug, Default)]
struct RecordingFunctionFactory {
Expand Down
21 changes: 16 additions & 5 deletions datafusion/sql/src/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}

/// Create a placeholder expression
/// This is the same as Postgres's prepare statement syntax in which a placeholder starts with `$` sign and then
/// number 1, 2, ... etc. For example, `$1` is the first placeholder; $2 is the second one and so on.
/// Both named (`$foo`) and positional (`$1`, `$2`, ...) placeholder styles are supported.
fn create_placeholder_expr(
param: String,
param_data_types: &[FieldRef],
) -> Result<Expr> {
// Parse the placeholder as a number because it is the only support from sqlparser and postgres
// Try to parse the placeholder as a number. If the placeholder does not have a valid
// positional value, assume we have a named placeholder.
let index = param[1..].parse::<usize>();
let idx = match index {
Ok(0) => {
Expand All @@ -123,8 +123,19 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
return if param_data_types.is_empty() {
Ok(Expr::Placeholder(Placeholder::new_with_field(param, None)))
} else {
// when PREPARE Statement, param_data_types length is always 0
plan_err!("Invalid placeholder, not a number: {param}")
// FIXME: This branch is shared by params from PREPARE and CREATE FUNCTION, but
// only CREATE FUNCTION currently supports named params. For now, we rewrite
// these to positional params.
let named_param_pos = param_data_types
.iter()
.position(|v| v.name() == &param[1..]);
match named_param_pos {
Some(pos) => Ok(Expr::Placeholder(Placeholder::new_with_field(
format!("${}", pos + 1),
param_data_types.get(pos).cloned(),
))),
None => plan_err!("Unknown placeholder: {param}"),
}
};
}
};
Expand Down
38 changes: 37 additions & 1 deletion datafusion/sql/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,28 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}
None => None,
};
// Validate default arguments
let first_default = match args.as_ref() {
Some(arg) => arg.iter().position(|t| t.default_expr.is_some()),
None => None,
};
let last_non_default = match args.as_ref() {
Some(arg) => arg
.iter()
.rev()
.position(|t| t.default_expr.is_none())
.map(|reverse_pos| arg.len() - reverse_pos - 1),
None => None,
};
if let (Some(pos_default), Some(pos_non_default)) =
(first_default, last_non_default)
{
if pos_non_default > pos_default {
return plan_err!(
"Non-default arguments cannot follow default arguments."
);
}
}
// At the moment functions can't be qualified `schema.name`
let name = match &name.0[..] {
[] => exec_err!("Function should have name")?,
Expand All @@ -1233,9 +1255,23 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
//
let arg_types = args.as_ref().map(|arg| {
arg.iter()
.map(|t| Arc::new(Field::new("", t.data_type.clone(), true)))
.map(|t| {
let name = match t.name.clone() {
Some(name) => name.value,
None => "".to_string(),
};
Arc::new(Field::new(name, t.data_type.clone(), true))
})
.collect::<Vec<_>>()
});
// Validate parameter style
if let Some(ref fields) = arg_types {
let count_positional =
fields.iter().filter(|f| f.name() == "").count();
if !(count_positional == 0 || count_positional == fields.len()) {
return plan_err!("All function arguments must use either named or positional style.");
}
}
let mut planner_context = PlannerContext::new()
.with_prepare_param_data_types(arg_types.unwrap_or_default());

Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/tests/cases/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ fn test_prepare_statement_to_plan_panic_param_format() {
assert_snapshot!(
logical_plan(sql).unwrap_err().strip_backtrace(),
@r###"
Error during planning: Invalid placeholder, not a number: $foo
Error during planning: Unknown placeholder: $foo
"###
);
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/prepare.slt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ statement error DataFusion error: SQL error: ParserError
PREPARE AS SELECT id, age FROM person WHERE age = $foo;

# param following a non-number, $foo, not supported
statement error Invalid placeholder, not a number: \$foo
statement error Unknown placeholder: \$foo
PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo;

# not specify table hence cannot specify columns
Expand Down
Loading