Skip to content

Commit 4ed808a

Browse files
authored
feat(spark): add trunc, date_trunc and time_trunc functions (#19829)
## Which issue does this PR close? - Closes #19828. - Part of #15914 ## Rationale for this change Implement the Spark functions: - https://spark.apache.org/docs/latest/api/sql/index.html#trunc - https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc - https://spark.apache.org/docs/latest/api/sql/index.html#time_trunc ## What changes are included in this PR? Add Spark-compatible wrappers around the DataFusion `date_trunc` function to handle Spark-specific behavior. ## Are these changes tested? Yes, in SLT. ## Are there any user-facing changes? Yes
1 parent 936f959 commit 4ed808a

File tree

8 files changed

+782
-60
lines changed

8 files changed

+782
-60
lines changed

datafusion/functions/src/datetime/date_trunc.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,16 @@ use arrow::array::types::{
3434
use arrow::array::{Array, ArrayRef, PrimitiveArray};
3535
use arrow::datatypes::DataType::{self, Time32, Time64, Timestamp};
3636
use arrow::datatypes::TimeUnit::{self, Microsecond, Millisecond, Nanosecond, Second};
37+
use arrow::datatypes::{Field, FieldRef};
3738
use datafusion_common::cast::as_primitive_array;
3839
use datafusion_common::types::{NativeType, logical_date, logical_string};
3940
use datafusion_common::{
4041
DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err,
4142
};
4243
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
4344
use datafusion_expr::{
44-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility,
45+
ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature,
46+
TypeSignature, Volatility,
4547
};
4648
use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
4749
use datafusion_macros::user_doc;
@@ -221,6 +223,7 @@ impl ScalarUDFImpl for DateTruncFunc {
221223
&self.signature
222224
}
223225

226+
// keep return_type implementation for information schema generation
224227
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
225228
if arg_types[1].is_null() {
226229
Ok(Timestamp(Nanosecond, None))
@@ -229,6 +232,21 @@ impl ScalarUDFImpl for DateTruncFunc {
229232
}
230233
}
231234

235+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
236+
let data_types = args
237+
.arg_fields
238+
.iter()
239+
.map(|f| f.data_type())
240+
.cloned()
241+
.collect::<Vec<_>>();
242+
let return_type = self.return_type(&data_types)?;
243+
Ok(Arc::new(Field::new(
244+
self.name(),
245+
return_type,
246+
args.arg_fields[1].is_nullable(),
247+
)))
248+
}
249+
232250
fn invoke_with_args(
233251
&self,
234252
args: datafusion_expr::ScalarFunctionArgs,
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::any::Any;
19+
use std::sync::Arc;
20+
21+
use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit};
22+
use datafusion_common::types::{NativeType, logical_string};
23+
use datafusion_common::utils::take_function_args;
24+
use datafusion_common::{Result, ScalarValue, internal_err, plan_err};
25+
use datafusion_expr::expr::ScalarFunction;
26+
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
27+
use datafusion_expr::{
28+
Coercion, ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs,
29+
ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
30+
};
31+
32+
/// Spark date_trunc supports extra format aliases.
33+
/// It also handles timestamps with timezones by converting to session timezone first.
34+
/// <https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc>
35+
#[derive(Debug, PartialEq, Eq, Hash)]
36+
pub struct SparkDateTrunc {
37+
signature: Signature,
38+
}
39+
40+
impl Default for SparkDateTrunc {
41+
fn default() -> Self {
42+
Self::new()
43+
}
44+
}
45+
46+
impl SparkDateTrunc {
47+
pub fn new() -> Self {
48+
Self {
49+
signature: Signature::coercible(
50+
vec![
51+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
52+
Coercion::new_implicit(
53+
TypeSignatureClass::Timestamp,
54+
vec![TypeSignatureClass::Native(logical_string())],
55+
NativeType::Timestamp(TimeUnit::Microsecond, None),
56+
),
57+
],
58+
Volatility::Immutable,
59+
),
60+
}
61+
}
62+
}
63+
64+
impl ScalarUDFImpl for SparkDateTrunc {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "date_trunc"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    // Never called in practice: `return_field_from_args` below takes
    // precedence, so this is an internal-error guard.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        internal_err!("return_field_from_args should be used instead")
    }

    /// Output field: same data type as the timestamp argument (index 1),
    /// nullable if any input field is nullable.
    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());

        Ok(Arc::new(Field::new(
            self.name(),
            args.arg_fields[1].data_type().clone(),
            nullable,
        )))
    }

    // This UDF is a planning-time wrapper only: `simplify` rewrites every
    // call into the standard `date_trunc`, so reaching invoke is a bug.
    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        internal_err!(
            "spark date_trunc should have been simplified to standard date_trunc"
        )
    }

    /// Rewrites `date_trunc(fmt, ts)` into DataFusion's standard
    /// `date_trunc`, after translating Spark's format aliases and applying
    /// Spark's session-timezone semantics.
    ///
    /// Errors with a plan error if `fmt` is not a non-null Utf8 literal or
    /// if the second argument is not a timestamp.
    fn simplify(
        &self,
        args: Vec<Expr>,
        info: &SimplifyContext,
    ) -> Result<ExprSimplifyResult> {
        let [fmt_expr, ts_expr] = take_function_args(self.name(), args)?;

        // The format must be a string literal so it can be normalized and
        // rewritten at plan time.
        let fmt = match fmt_expr.as_literal() {
            Some(ScalarValue::Utf8(Some(v)))
            | Some(ScalarValue::Utf8View(Some(v)))
            | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(),
            _ => {
                return plan_err!(
                    "First argument of `DATE_TRUNC` must be non-null scalar Utf8"
                );
            }
        };

        // Map Spark-specific fmt aliases to datafusion ones
        let fmt = match fmt.as_str() {
            "yy" | "yyyy" => "year",
            "mm" | "mon" => "month",
            "dd" => "day",
            other => other,
        };

        let session_tz = info.config_options().execution.time_zone.clone();
        let ts_type = ts_expr.get_type(info.schema())?;

        // Spark interprets timestamps in the session timezone before truncating,
        // then returns a timestamp at microsecond precision.
        // See: https://github.com/apache/spark/blob/f310f4fcc95580a6824bc7d22b76006f79b8804a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala#L492
        //
        // For sub-second truncations (second, millisecond, microsecond), timezone
        // adjustment is unnecessary since timezone offsets are whole seconds.
        let ts_expr = match (&ts_type, fmt) {
            // Sub-second truncations don't need timezone adjustment
            (_, "second" | "millisecond" | "microsecond") => ts_expr,

            // convert to session timezone, strip timezone and convert back to original timezone
            (DataType::Timestamp(unit, tz), _) => {
                let ts_expr = match &session_tz {
                    // Cast to microsecond precision in the session timezone
                    // so to_local_time strips the correct offset.
                    Some(session_tz) => ts_expr.cast_to(
                        &DataType::Timestamp(
                            TimeUnit::Microsecond,
                            Some(Arc::from(session_tz.as_str())),
                        ),
                        info.schema(),
                    )?,
                    None => ts_expr,
                };
                // to_local_time drops the timezone; the final cast restores
                // the original unit and timezone expected by callers.
                Expr::ScalarFunction(ScalarFunction::new_udf(
                    datafusion_functions::datetime::to_local_time(),
                    vec![ts_expr],
                ))
                .cast_to(&DataType::Timestamp(*unit, tz.clone()), info.schema())?
            }

            _ => {
                return plan_err!(
                    "Second argument of `DATE_TRUNC` must be Timestamp, got {}",
                    ts_type
                );
            }
        };

        // Rebuild the format literal from the normalized alias.
        let fmt_expr = Expr::Literal(ScalarValue::new_utf8(fmt), None);

        // Delegate the actual truncation to the standard date_trunc UDF.
        Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
            ScalarFunction::new_udf(
                datafusion_functions::datetime::date_trunc(),
                vec![fmt_expr, ts_expr],
            ),
        )))
    }
}

datafusion/spark/src/function/datetime/mod.rs

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,32 @@
1818
pub mod date_add;
1919
pub mod date_part;
2020
pub mod date_sub;
21+
pub mod date_trunc;
2122
pub mod extract;
2223
pub mod last_day;
2324
pub mod make_dt_interval;
2425
pub mod make_interval;
2526
pub mod next_day;
27+
pub mod time_trunc;
28+
pub mod trunc;
2629

2730
use datafusion_expr::ScalarUDF;
2831
use datafusion_functions::make_udf_function;
2932
use std::sync::Arc;
3033

3134
make_udf_function!(date_add::SparkDateAdd, date_add);
35+
make_udf_function!(date_part::SparkDatePart, date_part);
3236
make_udf_function!(date_sub::SparkDateSub, date_sub);
37+
make_udf_function!(date_trunc::SparkDateTrunc, date_trunc);
3338
make_udf_function!(extract::SparkHour, hour);
3439
make_udf_function!(extract::SparkMinute, minute);
3540
make_udf_function!(extract::SparkSecond, second);
3641
make_udf_function!(last_day::SparkLastDay, last_day);
3742
make_udf_function!(make_dt_interval::SparkMakeDtInterval, make_dt_interval);
3843
make_udf_function!(make_interval::SparkMakeInterval, make_interval);
3944
make_udf_function!(next_day::SparkNextDay, next_day);
40-
make_udf_function!(date_part::SparkDatePart, date_part);
45+
make_udf_function!(time_trunc::SparkTimeTrunc, time_trunc);
46+
make_udf_function!(trunc::SparkTrunc, trunc);
4147

4248
pub mod expr_fn {
4349
use datafusion_functions::export_functions;
@@ -85,24 +91,43 @@ pub mod expr_fn {
8591
"Returns the first date which is later than start_date and named as indicated. The function returns NULL if at least one of the input parameters is NULL.",
8692
arg1 arg2
8793
));
94+
export_functions!((
95+
date_trunc,
96+
"Truncates a timestamp `ts` to the unit specified by the format `fmt`.",
97+
fmt ts
98+
));
99+
export_functions!((
100+
time_trunc,
101+
"Truncates a time `t` to the unit specified by the format `fmt`.",
102+
fmt t
103+
));
104+
export_functions!((
105+
trunc,
106+
"Truncates a date `dt` to the unit specified by the format `fmt`.",
107+
dt fmt
108+
));
88109
export_functions!((
89110
date_part,
90111
"Extracts a part of the date or time from a date, time, or timestamp expression.",
91112
arg1 arg2
113+
92114
));
93115
}
94116

95117
pub fn functions() -> Vec<Arc<ScalarUDF>> {
96118
vec![
97119
date_add(),
120+
date_part(),
98121
date_sub(),
122+
date_trunc(),
99123
hour(),
100-
minute(),
101-
second(),
102124
last_day(),
103125
make_dt_interval(),
104126
make_interval(),
127+
minute(),
105128
next_day(),
106-
date_part(),
129+
second(),
130+
time_trunc(),
131+
trunc(),
107132
]
108133
}

0 commit comments

Comments
 (0)