From f96b566d2399de9c384003ebe406c4b39aac59ef Mon Sep 17 00:00:00 2001 From: Albert Skalt Date: Sun, 8 Feb 2026 19:39:01 +0300 Subject: [PATCH] Provide session to the udtf call This patch adds the passing of the current session to the UDTF call. This helps implement session-dependent functions, for example, a function that returns the list of registered tables. --- datafusion-cli/src/functions.rs | 14 +- datafusion-examples/README.md | 21 +-- datafusion-examples/examples/udf/main.rs | 8 +- .../examples/udf/simple_udtf.rs | 5 +- .../examples/udf/table_list_udtf.rs | 128 ++++++++++++++++++ datafusion/catalog/src/table.rs | 37 ++++- .../core/src/execution/session_state.rs | 7 +- .../user_defined_table_functions.rs | 7 +- datafusion/ffi/src/udtf.rs | 98 +++++++++++++- .../functions-table/src/generate_series.rs | 13 +- .../functions/adding-udfs.md | 15 +- 11 files changed, 311 insertions(+), 42 deletions(-) create mode 100644 datafusion-examples/examples/udf/table_list_udtf.rs diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index 67f3dc28269ef..bda5f1491be2c 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -31,7 +31,7 @@ use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer}; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef, TimeUnit}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use datafusion::catalog::{Session, TableFunctionImpl}; +use datafusion::catalog::{Session, TableFunctionArgs, TableFunctionImpl}; use datafusion::common::{Column, plan_err}; use datafusion::datasource::TableProvider; use datafusion::datasource::memory::MemorySourceConfig; @@ -326,7 +326,8 @@ fn fixed_len_byte_array_to_string(val: Option<&FixedLenByteArray>) -> Option Result> { + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { + let exprs = args.args; let filename = match exprs.first() { Some(Expr::Literal(ScalarValue::Utf8(Some(s)), _)) => s, // single 
quote: parquet_metadata('x.parquet') Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet") @@ -517,7 +518,8 @@ impl MetadataCacheFunc { } impl TableFunctionImpl for MetadataCacheFunc { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { + let exprs = args.args; if !exprs.is_empty() { return plan_err!("metadata_cache should have no arguments"); } @@ -635,7 +637,8 @@ impl StatisticsCacheFunc { } impl TableFunctionImpl for StatisticsCacheFunc { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { + let exprs = args.args; if !exprs.is_empty() { return plan_err!("statistics_cache should have no arguments"); } @@ -770,7 +773,8 @@ impl ListFilesCacheFunc { } impl TableFunctionImpl for ListFilesCacheFunc { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { + let exprs = args.args; if !exprs.is_empty() { return plan_err!("list_files_cache should have no arguments"); } diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index ebb1bc5fa6fca..d9724de1e5baf 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -208,13 +208,14 @@ cargo run --example dataframe -- dataframe #### Category: Single Process -| Subcommand | File Path | Description | -| ---------- | ------------------------------------------------------- | ----------------------------------------------- | -| adv_udaf | [`udf/advanced_udaf.rs`](examples/udf/advanced_udaf.rs) | Advanced User Defined Aggregate Function (UDAF) | -| adv_udf | [`udf/advanced_udf.rs`](examples/udf/advanced_udf.rs) | Advanced User Defined Scalar Function (UDF) | -| adv_udwf | [`udf/advanced_udwf.rs`](examples/udf/advanced_udwf.rs) | Advanced User Defined Window Function (UDWF) | -| async_udf | [`udf/async_udf.rs`](examples/udf/async_udf.rs) | Asynchronous User Defined Scalar 
Function | -| udaf | [`udf/simple_udaf.rs`](examples/udf/simple_udaf.rs) | Simple UDAF example | -| udf | [`udf/simple_udf.rs`](examples/udf/simple_udf.rs) | Simple UDF example | -| udtf | [`udf/simple_udtf.rs`](examples/udf/simple_udtf.rs) | Simple UDTF example | -| udwf | [`udf/simple_udwf.rs`](examples/udf/simple_udwf.rs) | Simple UDWF example | +| Subcommand | File Path | Description | +| --------------- | ----------------------------------------------------------- | ----------------------------------------------- | +| adv_udaf | [`udf/advanced_udaf.rs`](examples/udf/advanced_udaf.rs) | Advanced User Defined Aggregate Function (UDAF) | +| adv_udf | [`udf/advanced_udf.rs`](examples/udf/advanced_udf.rs) | Advanced User Defined Scalar Function (UDF) | +| adv_udwf | [`udf/advanced_udwf.rs`](examples/udf/advanced_udwf.rs) | Advanced User Defined Window Function (UDWF) | +| async_udf | [`udf/async_udf.rs`](examples/udf/async_udf.rs) | Asynchronous User Defined Scalar Function | +| udaf | [`udf/simple_udaf.rs`](examples/udf/simple_udaf.rs) | Simple UDAF example | +| udf | [`udf/simple_udf.rs`](examples/udf/simple_udf.rs) | Simple UDF example | +| udtf | [`udf/simple_udtf.rs`](examples/udf/simple_udtf.rs) | Simple UDTF example | +| udwf | [`udf/simple_udwf.rs`](examples/udf/simple_udwf.rs) | Simple UDWF example | +| table_list_udtf | [`udf/table_list_udtf.rs`](examples/udf/table_list_udtf.rs) | Session-aware UDTF table list example | diff --git a/datafusion-examples/examples/udf/main.rs b/datafusion-examples/examples/udf/main.rs index e024e466ab07e..89f3fd801deec 100644 --- a/datafusion-examples/examples/udf/main.rs +++ b/datafusion-examples/examples/udf/main.rs @@ -21,7 +21,7 @@ //! //! ## Usage //! ```bash -//! cargo run --example udf -- [all|adv_udaf|adv_udf|adv_udwf|async_udf|udaf|udf|udtf|udwf] +//! cargo run --example udf -- [all|adv_udaf|adv_udf|adv_udwf|async_udf|udaf|udf|udtf|udwf|table_list_udtf] //! ``` //! //! 
Each subcommand runs a corresponding example: @@ -50,6 +50,9 @@ //! //! - `udwf` //! (file: simple_udwf.rs, desc: Simple UDWF example) +//! +//! - `table_list_udtf` +//! (file: table_list_udtf.rs, desc: Session-aware UDTF table list example) mod advanced_udaf; mod advanced_udf; @@ -59,6 +62,7 @@ mod simple_udaf; mod simple_udf; mod simple_udtf; mod simple_udwf; +mod table_list_udtf; use datafusion::error::{DataFusionError, Result}; use strum::{IntoEnumIterator, VariantNames}; @@ -76,6 +80,7 @@ enum ExampleKind { Udaf, Udwf, Udtf, + TableListUdtf, } impl ExampleKind { @@ -101,6 +106,7 @@ impl ExampleKind { ExampleKind::Udf => simple_udf::simple_udf().await?, ExampleKind::Udtf => simple_udtf::simple_udtf().await?, ExampleKind::Udwf => simple_udwf::simple_udwf().await?, + ExampleKind::TableListUdtf => table_list_udtf::table_list_udtf().await?, } Ok(()) diff --git a/datafusion-examples/examples/udf/simple_udtf.rs b/datafusion-examples/examples/udf/simple_udtf.rs index ee2615c4a5ac1..df91fbd34a011 100644 --- a/datafusion-examples/examples/udf/simple_udtf.rs +++ b/datafusion-examples/examples/udf/simple_udtf.rs @@ -27,7 +27,7 @@ use arrow::csv::reader::Format; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::catalog::{Session, TableFunctionImpl}; +use datafusion::catalog::{Session, TableFunctionArgs, TableFunctionImpl}; use datafusion::common::{ScalarValue, plan_err}; use datafusion::datasource::TableProvider; use datafusion::datasource::memory::MemorySourceConfig; @@ -135,7 +135,8 @@ impl TableProvider for LocalCsvTable { struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { + let exprs = args.args; let Some(Expr::Literal(ScalarValue::Utf8(Some(path)), _)) = exprs.first() else { return plan_err!("read_csv requires at least one string argument"); }; 
diff --git a/datafusion-examples/examples/udf/table_list_udtf.rs b/datafusion-examples/examples/udf/table_list_udtf.rs new file mode 100644 index 0000000000000..c14efb808a921 --- /dev/null +++ b/datafusion-examples/examples/udf/table_list_udtf.rs @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. + +use std::sync::{Arc, LazyLock}; + +use arrow::array::{RecordBatch, StringBuilder}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion::{ + catalog::{MemTable, TableFunctionArgs, TableFunctionImpl, TableProvider}, + common::Result, + execution::SessionState, + prelude::SessionContext, +}; +use datafusion_common::{DataFusionError, plan_err}; +use tokio::{runtime::Handle, task::block_in_place}; + +const FUNCTION_NAME: &str = "table_list"; + +// This example shows how to create a UDTF that depends on the session state. +// It defines a `table_list` UDTF that returns a list of tables within the provided session. + +pub async fn table_list_udtf() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_udtf(FUNCTION_NAME, Arc::new(TableListUdtf)); + + // Register different kinds of tables. + ctx.sql("create view v as select 1") + .await? 
+ .collect() + .await?; + ctx.sql("create table t(a int)").await?.collect().await?; + + // Print results. + ctx.sql("select * from table_list()").await?.show().await?; + + Ok(()) +} + +#[derive(Debug, Default)] +struct TableListUdtf; + +static SCHEMA: LazyLock = LazyLock::new(|| { + SchemaRef::new(Schema::new(vec![ + Field::new("catalog", DataType::Utf8, false), + Field::new("schema", DataType::Utf8, false), + Field::new("table", DataType::Utf8, false), + Field::new("type", DataType::Utf8, false), + ])) +}); + +impl TableFunctionImpl for TableListUdtf { + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { + if !args.args.is_empty() { + return plan_err!( + "{}: unexpected number of arguments: {}, expected: 0", + FUNCTION_NAME, + args.args.len() + ); + } + let state = args + .session + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal("failed to downcast state".into()) + })?; + + let mut catalogs = StringBuilder::new(); + let mut schemas = StringBuilder::new(); + let mut tables = StringBuilder::new(); + let mut types = StringBuilder::new(); + + let catalog_list = state.catalog_list(); + for catalog_name in catalog_list.catalog_names() { + let Some(catalog) = catalog_list.catalog(&catalog_name) else { + continue; + }; + for schema_name in catalog.schema_names() { + let Some(schema) = catalog.schema(&schema_name) else { + continue; + }; + for table_name in schema.table_names() { + let Some(provider) = block_in_place(|| { + Handle::current().block_on(schema.table(&table_name)) + })? 
+ else { + continue; + }; + catalogs.append_value(catalog_name.clone()); + schemas.append_value(schema_name.clone()); + tables.append_value(table_name.clone()); + types.append_value(provider.table_type().to_string()) + } + } + } + + let batch = RecordBatch::try_new( + Arc::clone(&SCHEMA), + vec![ + Arc::new(catalogs.finish()), + Arc::new(schemas.finish()), + Arc::new(tables.finish()), + Arc::new(types.finish()), + ], + )?; + + Ok(Arc::new(MemTable::try_new( + batch.schema(), + vec![vec![batch]], + )?)) + } +} diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index 361589c5b6c1c..87a7b0c0fdc2a 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -23,8 +23,8 @@ use std::sync::Arc; use crate::session::Session; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use datafusion_common::Result; use datafusion_common::{Constraints, Statistics, not_impl_err}; +use datafusion_common::{Result, internal_err}; use datafusion_expr::Expr; use datafusion_expr::dml::InsertOp; @@ -507,10 +507,30 @@ pub trait TableProviderFactory: Debug + Sync + Send { ) -> Result>; } +/// Describes arguments provided to the table function call. +pub struct TableFunctionArgs<'a> { + /// Call arguments. + pub args: &'a [Expr], + /// Session within which the function is called. 
+ pub session: &'a dyn Session, +} + /// A trait for table function implementations pub trait TableFunctionImpl: Debug + Sync + Send + Any { /// Create a table provider - fn call(&self, args: &[Expr]) -> Result>; + #[deprecated( + since = "53.0.0", + note = "Implement `TableFunctionImpl::call_with_args` instead" + )] + fn call(&self, _args: &[Expr]) -> Result> { + internal_err!("unimplemented") + } + + /// Create a table provider + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { + #[expect(deprecated)] + self.call(args.args) + } } /// A table that uses a function to generate data @@ -539,7 +559,20 @@ impl TableFunction { } /// Get the function implementation and generate a table + #[deprecated( + since = "53.0.0", + note = "Use `TableFunction::create_table_provider_with_args` instead" + )] pub fn create_table_provider(&self, args: &[Expr]) -> Result> { + #[expect(deprecated)] self.fun.call(args) } + + /// Get the function implementation and generate a table + pub fn create_table_provider_with_args( + &self, + args: TableFunctionArgs, + ) -> Result> { + self.fun.call_with_args(args) + } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 347c8eb3d25a4..be24e3a13a324 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -1830,6 +1830,8 @@ impl ContextProvider for SessionContextProvider<'_> { name: &str, args: Vec, ) -> datafusion_common::Result> { + use datafusion_catalog::TableFunctionArgs; + let tbl_func = self .state .table_functions @@ -1852,7 +1854,10 @@ impl ContextProvider for SessionContextProvider<'_> { .and_then(|e| simplifier.simplify(e)) }) .collect::>>()?; - let provider = tbl_func.create_table_provider(&args)?; + let provider = tbl_func.create_table_provider_with_args(TableFunctionArgs { + args: &args, + session: self.state, + })?; Ok(provider_as_source(provider)) } diff --git 
a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index 95694d00a6c30..70338142fecfb 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -33,8 +33,8 @@ use datafusion::error::Result; use datafusion::execution::TaskContext; use datafusion::physical_plan::{ExecutionPlan, collect}; use datafusion::prelude::SessionContext; -use datafusion_catalog::Session; use datafusion_catalog::TableFunctionImpl; +use datafusion_catalog::{Session, TableFunctionArgs}; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType}; @@ -200,7 +200,8 @@ impl SimpleCsvTable { struct SimpleCsvTableFunc {} impl TableFunctionImpl for SimpleCsvTableFunc { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { + let exprs = args.args; let mut new_exprs = vec![]; let mut filepath = String::new(); for expr in exprs { @@ -231,7 +232,7 @@ async fn test_udtf_type_coercion() -> Result<()> { struct NoOpTableFunc; impl TableFunctionImpl for NoOpTableFunc { - fn call(&self, _: &[Expr]) -> Result> { + fn call_with_args(&self, _: TableFunctionArgs) -> Result> { let schema = Arc::new(arrow::datatypes::Schema::empty()); Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?)) } diff --git a/datafusion/ffi/src/udtf.rs b/datafusion/ffi/src/udtf.rs index 35c13c1169c72..17c590e0632c8 100644 --- a/datafusion/ffi/src/udtf.rs +++ b/datafusion/ffi/src/udtf.rs @@ -21,21 +21,23 @@ use std::sync::Arc; use abi_stable::StableAbi; use abi_stable::std_types::{RResult, RVec}; -use datafusion_catalog::{TableFunctionImpl, TableProvider}; +use datafusion_catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider}; +use datafusion_common::DataFusionError; use datafusion_common::error::Result; use 
datafusion_execution::TaskContext; -use datafusion_expr::Expr; use datafusion_proto::logical_plan::from_proto::parse_exprs; use datafusion_proto::logical_plan::to_proto::serialize_exprs; use datafusion_proto::logical_plan::{ DefaultLogicalExtensionCodec, LogicalExtensionCodec, }; use datafusion_proto::protobuf::LogicalExprList; +use datafusion_session::Session; use prost::Message; use tokio::runtime::Handle; use crate::execution::FFI_TaskContextProvider; use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec; +use crate::session::{FFI_SessionRef, ForeignSession}; use crate::table_provider::FFI_TableProvider; use crate::util::FFIResult; use crate::{df_result, rresult_return}; @@ -44,11 +46,22 @@ use crate::{df_result, rresult_return}; #[repr(C)] #[derive(Debug, StableAbi)] pub struct FFI_TableFunction { - /// Equivalent to the `call` function of the TableFunctionImpl. + /// Equivalent to the [`TableFunctionImpl::call`]. /// The arguments are Expr passed as protobuf encoded bytes. + #[deprecated( + since = "53.0.0", + note = "See TableFunctionImpl::call deprecation note" + )] pub call: unsafe extern "C" fn(udtf: &Self, args: RVec) -> FFIResult, + /// Equivalent to the [`TableFunctionImpl::call_with_args`]. + call_with_args: unsafe extern "C" fn( + udtf: &Self, + args: RVec, + session: FFI_SessionRef, + ) -> FFIResult, + pub logical_codec: FFI_LogicalExtensionCodec, /// Used to create a clone on the provider of the udtf. 
This should @@ -107,6 +120,7 @@ unsafe extern "C" fn call_fn_wrapper( codec.as_ref() )); + #[expect(deprecated)] let table_provider = rresult_return!(udtf_inner.call(&args)); RResult::ROk(FFI_TableProvider::new_with_ffi_codec( table_provider, @@ -116,6 +130,48 @@ unsafe extern "C" fn call_fn_wrapper( )) } +unsafe extern "C" fn call_with_args_wrapper( + udtf: &FFI_TableFunction, + args: RVec, + session: FFI_SessionRef, +) -> FFIResult { + let runtime = udtf.runtime(); + let udtf_inner = udtf.inner(); + + let ctx: Arc = + rresult_return!((&udtf.logical_codec.task_ctx_provider).try_into()); + let codec: Arc = (&udtf.logical_codec).into(); + + let proto_filters = rresult_return!(LogicalExprList::decode(args.as_ref())); + + let args = rresult_return!(parse_exprs( + proto_filters.expr.iter(), + ctx.as_ref(), + codec.as_ref() + )); + + let mut foreign_session = None; + let session = rresult_return!( + session + .as_local() + .map(Ok::<&(dyn Session + Send + Sync), DataFusionError>) + .unwrap_or_else(|| { + foreign_session = Some(ForeignSession::try_from(&session)?); + Ok(foreign_session.as_ref().unwrap()) + }) + ); + let table_provider = rresult_return!(udtf_inner.call_with_args(TableFunctionArgs { + args: &args, + session + })); + RResult::ROk(FFI_TableProvider::new_with_ffi_codec( + table_provider, + false, + runtime, + udtf.logical_codec.clone(), + )) +} + unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) { unsafe { debug_assert!(!udtf.private_data.is_null()); @@ -176,7 +232,9 @@ impl FFI_TableFunction { let private_data = Box::new(TableFunctionPrivateData { udtf, runtime }); Self { + #[expect(deprecated)] call: call_fn_wrapper, + call_with_args: call_with_args_wrapper, logical_codec, clone: clone_fn_wrapper, release: release_fn_wrapper, @@ -215,13 +273,32 @@ impl From for Arc { } impl TableFunctionImpl for ForeignTableFunction { - fn call(&self, args: &[Expr]) -> Result> { + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { + let 
session = + FFI_SessionRef::new(args.session, None, self.0.logical_codec.clone()); + let codec: Arc = (&self.0.logical_codec).into(); + let expr_list = LogicalExprList { + expr: serialize_exprs(args.args, codec.as_ref())?, + }; + let filters_serialized = expr_list.encode_to_vec().into(); + + let table_provider = + unsafe { (self.0.call_with_args)(&self.0, filters_serialized, session) }; + + let table_provider = df_result!(table_provider)?; + let table_provider: Arc = (&table_provider).into(); + + Ok(table_provider) + } + + fn call(&self, args: &[datafusion_expr::Expr]) -> Result> { let codec: Arc = (&self.0.logical_codec).into(); let expr_list = LogicalExprList { expr: serialize_exprs(args, codec.as_ref())?, }; let filters_serialized = expr_list.encode_to_vec().into(); + #[expect(deprecated)] let table_provider = unsafe { (self.0.call)(&self.0, filters_serialized) }; let table_provider = df_result!(table_provider)?; @@ -242,7 +319,9 @@ mod tests { use datafusion::logical_expr::ptr_eq::arc_ptr_eq; use datafusion::prelude::{SessionContext, lit}; use datafusion::scalar::ScalarValue; + use datafusion_catalog::TableFunctionArgs; use datafusion_execution::TaskContextProvider; + use datafusion_expr::Expr; use super::*; @@ -250,8 +329,12 @@ mod tests { struct TestUDTF {} impl TableFunctionImpl for TestUDTF { - fn call(&self, args: &[Expr]) -> Result> { + fn call_with_args( + &self, + args: TableFunctionArgs, + ) -> Result> { let args = args + .args .iter() .map(|arg| { if let Expr::Literal(scalar, _) = arg { @@ -341,7 +424,10 @@ mod tests { let foreign_udf: Arc = local_udtf.into(); - let table = foreign_udf.call(&[lit(6_u64), lit("one"), lit(2.0), lit(3_u64)])?; + let table = foreign_udf.call_with_args(TableFunctionArgs { + args: &[lit(6_u64), lit("one"), lit(2.0), lit(3_u64)], + session: &ctx.state(), + })?; let _ = ctx.register_table("test-table", table)?; diff --git a/datafusion/functions-table/src/generate_series.rs b/datafusion/functions-table/src/generate_series.rs 
index 342269fbc2996..a53255c5d90ec 100644 --- a/datafusion/functions-table/src/generate_series.rs +++ b/datafusion/functions-table/src/generate_series.rs @@ -23,9 +23,9 @@ use arrow::datatypes::{ }; use arrow::record_batch::RecordBatch; use async_trait::async_trait; -use datafusion_catalog::Session; use datafusion_catalog::TableFunctionImpl; use datafusion_catalog::TableProvider; +use datafusion_catalog::{Session, TableFunctionArgs}; use datafusion_common::{Result, ScalarValue, plan_err}; use datafusion_expr::{Expr, TableType}; use datafusion_physical_plan::ExecutionPlan; @@ -479,7 +479,8 @@ struct GenerateSeriesFuncImpl { } impl TableFunctionImpl for GenerateSeriesFuncImpl { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { + let exprs = args.args; if exprs.is_empty() || exprs.len() > 3 { return plan_err!("{} function requires 1 to 3 arguments", self.name); } @@ -737,12 +738,12 @@ impl GenerateSeriesFuncImpl { pub struct GenerateSeriesFunc {} impl TableFunctionImpl for GenerateSeriesFunc { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { let impl_func = GenerateSeriesFuncImpl { name: "generate_series", include_end: true, }; - impl_func.call(exprs) + impl_func.call_with_args(args) } } @@ -750,12 +751,12 @@ impl TableFunctionImpl for GenerateSeriesFunc { pub struct RangeFunc {} impl TableFunctionImpl for RangeFunc { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { let impl_func = GenerateSeriesFuncImpl { name: "range", include_end: false, }; - impl_func.call(exprs) + impl_func.call_with_args(args) } } diff --git a/docs/source/library-user-guide/functions/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md index 15db731ca3746..bdaa8792360e0 100644 --- a/docs/source/library-user-guide/functions/adding-udfs.md +++ 
b/docs/source/library-user-guide/functions/adding-udfs.md @@ -1377,15 +1377,16 @@ in the CLI to read the metadata from a Parquet file. The simple UDTF used here takes a single `Int64` argument and returns a table with a single column with the value of the argument. To create a function in DataFusion, you need to implement the `TableFunctionImpl` trait. This trait has a -single method, `call`, that takes a slice of `Expr`s and returns a `Result<Arc<dyn TableProvider>>`. +single method, `call_with_args`, that takes a `TableFunctionArgs` struct and returns a `Result<Arc<dyn TableProvider>>`. +The passed struct includes the function arguments as a slice of `Expr`s, together with the session in which the function is called. -In the `call` method, you parse the input `Expr`s and return a `TableProvider`. You might also want to do some +In the `call_with_args` method, you parse the input `Expr`s and return a `TableProvider`. You might also want to do some validation of the input `Expr`s, e.g. checking that the number of arguments is correct. ```rust use std::sync::Arc; use datafusion::common::{plan_err, ScalarValue, Result}; -use datafusion::catalog::{TableFunctionImpl, TableProvider}; +use datafusion::catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider}; use datafusion::arrow::array::{ArrayRef, Int64Array}; use datafusion::datasource::memory::MemTable; use arrow::record_batch::RecordBatch; @@ -1397,7 +1398,8 @@ use datafusion_expr::Expr; pub struct EchoFunction {} impl TableFunctionImpl for EchoFunction { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { + let exprs = args.args; let Some(Expr::Literal(ScalarValue::Int64(Some(value)), _)) = exprs.get(0) else { return plan_err!("First argument must be an integer"); }; @@ -1426,7 +1428,7 @@ With the UDTF implemented, you can register it with the `SessionContext`: ```rust # use std::sync::Arc; # use datafusion::common::{plan_err, ScalarValue, Result}; -# use datafusion::catalog::{TableFunctionImpl, TableProvider}; +# use datafusion::catalog::{TableFunctionArgs, 
TableFunctionImpl, TableProvider}; # use datafusion::arrow::array::{ArrayRef, Int64Array}; # use datafusion::datasource::memory::MemTable; # use arrow::record_batch::RecordBatch; @@ -1438,7 +1440,8 @@ With the UDTF implemented, you can register it with the `SessionContext`: # pub struct EchoFunction {} # # impl TableFunctionImpl for EchoFunction { -# fn call(&self, exprs: &[Expr]) -> Result> { +# fn call_with_args(&self, args: TableFunctionArgs) -> Result> { +# let exprs = args.args; # let Some(Expr::Literal(ScalarValue::Int64(Some(value)), _)) = exprs.get(0) else { # return plan_err!("First argument must be an integer"); # };