Skip to content

Commit 1a0ef18

Browse files
committed
arrow-cast: Bring back in-order casting
1 parent c2bd7d9 commit 1a0ef18

File tree

1 file changed

+47
-28
lines changed

1 file changed

+47
-28
lines changed

arrow-cast/src/cast/mod.rs

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
254254
}
255255

256256
// slow path, we match the fields by name
257-
to_fields.iter().all(|to_field| {
257+
if to_fields.iter().all(|to_field| {
258258
from_fields
259259
.iter()
260260
.find(|from_field| from_field.name() == to_field.name())
@@ -263,7 +263,15 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
263263
// cast kernel will return error.
264264
can_cast_types(from_field.data_type(), to_field.data_type())
265265
})
266-
})
266+
}) {
267+
return true;
268+
}
269+
270+
// if we couldn't match by name, we try to see if they can be matched by position
271+
from_fields
272+
.iter()
273+
.zip(to_fields.iter())
274+
.all(|(f1, f2)| can_cast_types(f1.data_type(), f2.data_type()))
267275
}
268276
(Struct(_), _) => false,
269277
(_, Struct(_)) => false,
@@ -1239,23 +1247,34 @@ pub fn cast_with_options(
12391247
})
12401248
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
12411249
} else {
1242-
// Slow path: match fields by name and reorder
1243-
to_fields
1244-
.iter()
1245-
.map(|to_field| {
1246-
let from_field_idx = from_fields
1247-
.iter()
1248-
.position(|from_field| from_field.name() == to_field.name())
1249-
.ok_or_else(|| {
1250-
ArrowError::CastError(format!(
1251-
"Field '{}' not found in source struct",
1252-
to_field.name()
1253-
))
1254-
})?;
1255-
let column = array.column(from_field_idx);
1256-
cast_with_options(column, to_field.data_type(), cast_options)
1257-
})
1258-
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
1250+
let all_fields_match_by_name = to_fields.iter().all(|to_field| {
1251+
from_fields
1252+
.iter()
1253+
.any(|from_field| from_field.name() == to_field.name())
1254+
});
1255+
1256+
if all_fields_match_by_name {
1257+
// Slow path: match fields by name and reorder
1258+
to_fields
1259+
.iter()
1260+
.map(|to_field| {
1261+
let from_field_idx = from_fields
1262+
.iter()
1263+
.position(|from_field| from_field.name() == to_field.name())
1264+
.unwrap(); // safe because we checked above
1265+
let column = array.column(from_field_idx);
1266+
cast_with_options(column, to_field.data_type(), cast_options)
1267+
})
1268+
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
1269+
} else {
1270+
// Fallback: cast field by field in order
1271+
array
1272+
.columns()
1273+
.iter()
1274+
.zip(to_fields.iter())
1275+
.map(|(l, field)| cast_with_options(l, field.data_type(), cast_options))
1276+
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
1277+
}
12591278
};
12601279

12611280
let array = StructArray::try_new(to_fields.clone(), fields, array.nulls().cloned())?;
@@ -10917,11 +10936,11 @@ mod tests {
1091710936
let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
1091810937
let struct_array = StructArray::from(vec![
1091910938
(
10920-
Arc::new(Field::new("a", DataType::Boolean, false)),
10939+
Arc::new(Field::new("b", DataType::Boolean, false)),
1092110940
boolean.clone() as ArrayRef,
1092210941
),
1092310942
(
10924-
Arc::new(Field::new("b", DataType::Int32, false)),
10943+
Arc::new(Field::new("c", DataType::Int32, false)),
1092510944
int.clone() as ArrayRef,
1092610945
),
1092710946
]);
@@ -10965,11 +10984,11 @@ mod tests {
1096510984
let int = Arc::new(Int32Array::from(vec![Some(42), None, Some(19), None]));
1096610985
let struct_array = StructArray::from(vec![
1096710986
(
10968-
Arc::new(Field::new("a", DataType::Boolean, false)),
10987+
Arc::new(Field::new("b", DataType::Boolean, false)),
1096910988
boolean.clone() as ArrayRef,
1097010989
),
1097110990
(
10972-
Arc::new(Field::new("b", DataType::Int32, true)),
10991+
Arc::new(Field::new("c", DataType::Int32, true)),
1097310992
int.clone() as ArrayRef,
1097410993
),
1097510994
]);
@@ -10999,11 +11018,11 @@ mod tests {
1099911018
let int = Arc::new(Int32Array::from(vec![i32::MAX, 25, 1, 100]));
1100011019
let struct_array = StructArray::from(vec![
1100111020
(
11002-
Arc::new(Field::new("a", DataType::Boolean, false)),
11021+
Arc::new(Field::new("b", DataType::Boolean, false)),
1100311022
boolean.clone() as ArrayRef,
1100411023
),
1100511024
(
11006-
Arc::new(Field::new("b", DataType::Int32, false)),
11025+
Arc::new(Field::new("c", DataType::Int32, false)),
1100711026
int.clone() as ArrayRef,
1100811027
),
1100911028
]);
@@ -11139,7 +11158,7 @@ mod tests {
1113911158
assert!(result.is_err());
1114011159
assert_eq!(
1114111160
result.unwrap_err().to_string(),
11142-
"Cast error: Field 'b' not found in source struct"
11161+
"Invalid argument error: Incorrect number of arrays for StructArray fields, expected 2 got 1"
1114311162
);
1114411163
}
1114511164

@@ -11196,7 +11215,7 @@ mod tests {
1119611215
}
1119711216

1119811217
#[test]
11199-
fn test_can_cast_struct_with_missing_field() {
11218+
fn test_can_cast_struct_rename_field() {
1120011219
// Test that can_cast_types returns false when target has a field not in source
1120111220
let from_type = DataType::Struct(
1120211221
vec![
@@ -11214,7 +11233,7 @@ mod tests {
1121411233
.into(),
1121511234
);
1121611235

11217-
assert!(!can_cast_types(&from_type, &to_type));
11236+
assert!(can_cast_types(&from_type, &to_type));
1121811237
}
1121911238

1122011239
fn run_decimal_cast_test_case_between_multiple_types(t: DecimalCastTestConfig) {

0 commit comments

Comments
 (0)