Skip to content

Commit bd2af68

Browse files
kazantsev-maksimKazantsev Maksim
andauthored
Spark make_valid_utf8 function implementation (apache#20633)
## Which issue does this PR close? N/A ## Rationale for this change Add new spark function: https://spark.apache.org/docs/latest/api/sql/index.html#make_valid_utf8 ## What changes are included in this PR? - Implementation - SLT tests ## Are these changes tested? Yes, tests added as part of this PR. ## Are there any user-facing changes? No, this is a new function. --------- Co-authored-by: Kazantsev Maksim <mn.kazantsev@gmail.com>
1 parent 9eefb7c commit bd2af68

File tree

3 files changed

+222
-0
lines changed

3 files changed

+222
-0
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{ArrayRef, LargeStringArray, StringArray};
19+
use arrow::datatypes::{DataType, Field, FieldRef};
20+
use datafusion::logical_expr::{ColumnarValue, Signature, Volatility};
21+
use datafusion_common::cast::{
22+
as_binary_array, as_binary_view_array, as_large_binary_array,
23+
};
24+
use datafusion_common::utils::take_function_args;
25+
use datafusion_common::{Result, internal_err};
26+
use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl};
27+
use datafusion_functions::utils::make_scalar_function;
28+
use std::sync::Arc;
29+
30+
/// Spark-compatible `make_valid_utf8` expression
31+
/// <https://spark.apache.org/docs/latest/api/sql/index.html#make_valid_utf8>
32+
#[derive(Debug, PartialEq, Eq, Hash)]
33+
pub struct SparkMakeValidUtf8 {
34+
signature: Signature,
35+
}
36+
37+
impl Default for SparkMakeValidUtf8 {
38+
fn default() -> Self {
39+
Self::new()
40+
}
41+
}
42+
43+
impl SparkMakeValidUtf8 {
44+
pub fn new() -> Self {
45+
Self {
46+
signature: Signature::uniform(
47+
1,
48+
vec![
49+
DataType::Utf8,
50+
DataType::LargeUtf8,
51+
DataType::Utf8View,
52+
DataType::Binary,
53+
DataType::BinaryView,
54+
DataType::LargeBinary,
55+
],
56+
Volatility::Immutable,
57+
),
58+
}
59+
}
60+
}
61+
62+
impl ScalarUDFImpl for SparkMakeValidUtf8 {
63+
fn name(&self) -> &str {
64+
"make_valid_utf8"
65+
}
66+
67+
fn signature(&self) -> &Signature {
68+
&self.signature
69+
}
70+
71+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
72+
internal_err!("return_field_from_args should be used instead")
73+
}
74+
75+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
76+
let [make_valid_utf8] = take_function_args(self.name(), args.arg_fields)?;
77+
let return_type = match make_valid_utf8.data_type() {
78+
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => {
79+
Ok(make_valid_utf8.data_type().clone())
80+
}
81+
DataType::Binary | DataType::BinaryView => Ok(DataType::Utf8),
82+
DataType::LargeBinary => Ok(DataType::LargeUtf8),
83+
data_type => internal_err!("make_valid_utf8 does not support: {data_type}"),
84+
}?;
85+
Ok(Arc::new(Field::new(
86+
self.name(),
87+
return_type,
88+
make_valid_utf8.is_nullable(),
89+
)))
90+
}
91+
92+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
93+
make_scalar_function(spark_make_valid_utf8_inner, vec![])(&args.args)
94+
}
95+
}
96+
97+
fn spark_make_valid_utf8_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
98+
let array = &args[0];
99+
match &array.data_type() {
100+
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Ok(array.to_owned()),
101+
DataType::Binary => Ok(Arc::new(
102+
as_binary_array(&array)?
103+
.iter()
104+
.map(|x| x.map(String::from_utf8_lossy))
105+
.collect::<StringArray>(),
106+
)),
107+
DataType::BinaryView => Ok(Arc::new(
108+
as_binary_view_array(&array)?
109+
.iter()
110+
.map(|x| x.map(String::from_utf8_lossy))
111+
.collect::<StringArray>(),
112+
)),
113+
DataType::LargeBinary => Ok(Arc::new(
114+
as_large_binary_array(&array)?
115+
.iter()
116+
.map(|x| x.map(String::from_utf8_lossy))
117+
.collect::<LargeStringArray>(),
118+
)),
119+
data_type => {
120+
internal_err!("make_valid_utf8 does not support: {data_type}")
121+
}
122+
}
123+
}

datafusion/spark/src/function/string/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ pub mod ilike;
2525
pub mod length;
2626
pub mod like;
2727
pub mod luhn_check;
28+
pub mod make_valid_utf8;
2829
pub mod soundex;
2930
pub mod space;
3031
pub mod substring;
@@ -47,6 +48,7 @@ make_udf_function!(space::SparkSpace, space);
4748
make_udf_function!(substring::SparkSubstring, substring);
4849
make_udf_function!(base64::SparkUnBase64, unbase64);
4950
make_udf_function!(soundex::SparkSoundex, soundex);
51+
make_udf_function!(make_valid_utf8::SparkMakeValidUtf8, make_valid_utf8);
5052

5153
pub mod expr_fn {
5254
use datafusion_functions::export_functions;
@@ -113,6 +115,11 @@ pub mod expr_fn {
113115
str
114116
));
115117
export_functions!((soundex, "Returns Soundex code of the string.", str));
118+
export_functions!((
119+
make_valid_utf8,
120+
"Returns the original string if str is a valid UTF-8 string, otherwise returns a new string whose invalid UTF8 byte sequences are replaced using the UNICODE replacement character U+FFFD.",
121+
str
122+
));
116123
}
117124

118125
pub fn functions() -> Vec<Arc<ScalarUDF>> {
@@ -131,5 +138,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
131138
substring(),
132139
unbase64(),
133140
soundex(),
141+
make_valid_utf8(),
134142
]
135143
}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
query T
19+
SELECT make_valid_utf8('Spark'::string);
20+
----
21+
Spark
22+
23+
query T
24+
SELECT make_valid_utf8(''::string);
25+
----
26+
(empty)
27+
28+
query T
29+
SELECT make_valid_utf8(NULL::string);
30+
----
31+
NULL
32+
33+
query T
34+
SELECT make_valid_utf8(arrow_cast(x'C3A9', 'Binary'));
35+
----
36+
é
37+
38+
query T
39+
SELECT make_valid_utf8(arrow_cast(x'F0908C80', 'Binary'));
40+
----
41+
𐌀
42+
43+
query T
44+
SELECT make_valid_utf8(arrow_cast(x'ED9FBF', 'Binary'));
45+
----
46+
47+
48+
query T
49+
SELECT make_valid_utf8(arrow_cast(x'FF', 'Binary'));
50+
----
51+
52+
53+
query T
54+
SELECT make_valid_utf8(arrow_cast(x'C0AF', 'Binary'));
55+
----
56+
��
57+
58+
query T
59+
SELECT make_valid_utf8(arrow_cast(x'F4808080', 'Binary'));
60+
----
61+
􀀀
62+
63+
query T
64+
SELECT make_valid_utf8(arrow_cast(x'EDA0BDEDB2A9', 'Binary'));
65+
----
66+
������
67+
68+
query T
69+
SELECT make_valid_utf8(arrow_cast(x'F0', 'Binary'));
70+
----
71+
72+
73+
query T
74+
SELECT make_valid_utf8(arrow_cast(x'E0', 'Binary'));
75+
----
76+
77+
78+
query T
79+
SELECT make_valid_utf8(arrow_cast(x'F0808080', 'Binary'));
80+
----
81+
����
82+
83+
query T
84+
SELECT make_valid_utf8(arrow_cast(x'61', 'Binary'));
85+
----
86+
a
87+
88+
query T
89+
SELECT make_valid_utf8(arrow_cast(x'61C262', 'Binary'));
90+
----
91+
a�b

0 commit comments

Comments
 (0)