Skip to content

Commit d2a8ae2

Browse files
committed
arrow-cast: Bring back in-order casting
1 parent c2bd7d9 commit d2a8ae2

File tree

1 file changed

+39
-13
lines changed

1 file changed

+39
-13
lines changed

arrow-cast/src/cast/mod.rs

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
254254
}
255255

256256
// slow path, we match the fields by name
257-
to_fields.iter().all(|to_field| {
257+
if to_fields.iter().all(|to_field| {
258258
from_fields
259259
.iter()
260260
.find(|from_field| from_field.name() == to_field.name())
@@ -263,7 +263,15 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
263263
// cast kernel will return error.
264264
can_cast_types(from_field.data_type(), to_field.data_type())
265265
})
266-
})
266+
}) {
267+
return true;
268+
}
269+
270+
// if we couldn't match by name, we try to see if they can be matched by position
271+
from_fields
272+
.iter()
273+
.zip(to_fields.iter())
274+
.all(|(f1, f2)| can_cast_types(f1.data_type(), f2.data_type()))
267275
}
268276
(Struct(_), _) => false,
269277
(_, Struct(_)) => false,
@@ -1240,7 +1248,7 @@ pub fn cast_with_options(
12401248
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
12411249
} else {
12421250
// Slow path: match fields by name and reorder
1243-
to_fields
1251+
match to_fields
12441252
.iter()
12451253
.map(|to_field| {
12461254
let from_field_idx = from_fields
@@ -1255,7 +1263,25 @@ pub fn cast_with_options(
12551263
let column = array.column(from_field_idx);
12561264
cast_with_options(column, to_field.data_type(), cast_options)
12571265
})
1258-
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
1266+
.collect::<Result<Vec<ArrayRef>, ArrowError>>()
1267+
{
1268+
Ok(casted_fields) => casted_fields,
1269+
Err(e) => {
1270+
// If it's Field not found, we cast field by field
1271+
if !e.to_string().starts_with("Field '")
1272+
&& !e.to_string().ends_with("' not found in source struct")
1273+
{
1274+
return Err(e);
1275+
}
1276+
1277+
array
1278+
.columns()
1279+
.iter()
1280+
.zip(to_fields.iter())
1281+
.map(|(l, field)| cast_with_options(l, field.data_type(), cast_options))
1282+
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
1283+
}
1284+
}
12591285
};
12601286

12611287
let array = StructArray::try_new(to_fields.clone(), fields, array.nulls().cloned())?;
@@ -10917,11 +10943,11 @@ mod tests {
1091710943
let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
1091810944
let struct_array = StructArray::from(vec![
1091910945
(
10920-
Arc::new(Field::new("a", DataType::Boolean, false)),
10946+
Arc::new(Field::new("b", DataType::Boolean, false)),
1092110947
boolean.clone() as ArrayRef,
1092210948
),
1092310949
(
10924-
Arc::new(Field::new("b", DataType::Int32, false)),
10950+
Arc::new(Field::new("c", DataType::Int32, false)),
1092510951
int.clone() as ArrayRef,
1092610952
),
1092710953
]);
@@ -10965,11 +10991,11 @@ mod tests {
1096510991
let int = Arc::new(Int32Array::from(vec![Some(42), None, Some(19), None]));
1096610992
let struct_array = StructArray::from(vec![
1096710993
(
10968-
Arc::new(Field::new("a", DataType::Boolean, false)),
10994+
Arc::new(Field::new("b", DataType::Boolean, false)),
1096910995
boolean.clone() as ArrayRef,
1097010996
),
1097110997
(
10972-
Arc::new(Field::new("b", DataType::Int32, true)),
10998+
Arc::new(Field::new("c", DataType::Int32, true)),
1097310999
int.clone() as ArrayRef,
1097411000
),
1097511001
]);
@@ -10999,11 +11025,11 @@ mod tests {
1099911025
let int = Arc::new(Int32Array::from(vec![i32::MAX, 25, 1, 100]));
1100011026
let struct_array = StructArray::from(vec![
1100111027
(
11002-
Arc::new(Field::new("a", DataType::Boolean, false)),
11028+
Arc::new(Field::new("b", DataType::Boolean, false)),
1100311029
boolean.clone() as ArrayRef,
1100411030
),
1100511031
(
11006-
Arc::new(Field::new("b", DataType::Int32, false)),
11032+
Arc::new(Field::new("c", DataType::Int32, false)),
1100711033
int.clone() as ArrayRef,
1100811034
),
1100911035
]);
@@ -11139,7 +11165,7 @@ mod tests {
1113911165
assert!(result.is_err());
1114011166
assert_eq!(
1114111167
result.unwrap_err().to_string(),
11142-
"Cast error: Field 'b' not found in source struct"
11168+
"Invalid argument error: Incorrect number of arrays for StructArray fields, expected 2 got 1"
1114311169
);
1114411170
}
1114511171

@@ -11196,7 +11222,7 @@ mod tests {
1119611222
}
1119711223

1119811224
#[test]
11199-
fn test_can_cast_struct_with_missing_field() {
11225+
fn test_can_cast_struct_rename_field() {
1120011226
// Test that can_cast_types returns false when target has a field not in source
1120111227
let from_type = DataType::Struct(
1120211228
vec![
@@ -11214,7 +11240,7 @@ mod tests {
1121411240
.into(),
1121511241
);
1121611242

11217-
assert!(!can_cast_types(&from_type, &to_type));
11243+
assert!(can_cast_types(&from_type, &to_type));
1121811244
}
1121911245

1122011246
fn run_decimal_cast_test_case_between_multiple_types(t: DecimalCastTestConfig) {

0 commit comments

Comments
 (0)