Skip to content

Commit 7abc92b

Browse files
committed
[Variant] Unify the CastOptions usage in parquet-variant-compute
1 parent f8796fd commit 7abc92b

File tree

4 files changed

+74
-39
lines changed

4 files changed

+74
-39
lines changed

parquet-variant-compute/src/arrow_to_variant.rs

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::type_conversion::CastOptions;
1918
use arrow::array::{
2019
Array, AsArray, FixedSizeListArray, GenericBinaryArray, GenericListArray, GenericListViewArray,
2120
GenericStringArray, OffsetSizeTrait, PrimitiveArray,
2221
};
23-
use arrow::compute::kernels::cast;
22+
use arrow::compute::{CastOptions, kernels::cast};
2423
use arrow::datatypes::{
2524
self as datatypes, ArrowNativeType, ArrowPrimitiveType, ArrowTemporalType, ArrowTimestampType,
2625
DecimalType, RunEndIndexType,
@@ -367,7 +366,7 @@ macro_rules! define_row_builder {
367366
$(
368367
// NOTE: The `?` macro expansion fails without the type annotation.
369368
let Some(value): Option<$option_ty> = value else {
370-
if self.options.strict {
369+
if !self.options.safe {
371370
return Err(ArrowError::ComputeError(format!(
372371
"Failed to convert value at index {index}: conversion failed",
373372
)));
@@ -404,7 +403,7 @@ define_row_builder!(
404403
where
405404
V: VariantDecimalType<Native = A::Native>,
406405
{
407-
options: &'a CastOptions,
406+
options: &'a CastOptions<'a>,
408407
scale: i8,
409408
},
410409
|array| -> PrimitiveArray<A> { array.as_primitive() },
@@ -414,7 +413,7 @@ define_row_builder!(
414413
// Decimal256 needs a two-stage conversion via i128
415414
define_row_builder!(
416415
struct Decimal256ArrowToVariantBuilder<'a> {
417-
options: &'a CastOptions,
416+
options: &'a CastOptions<'a>,
418417
scale: i8,
419418
},
420419
|array| -> arrow::array::Decimal256Array { array.as_primitive() },
@@ -426,7 +425,7 @@ define_row_builder!(
426425

427426
define_row_builder!(
428427
struct TimestampArrowToVariantBuilder<'a, T: ArrowTimestampType> {
429-
options: &'a CastOptions,
428+
options: &'a CastOptions<'a>,
430429
has_time_zone: bool,
431430
},
432431
|array| -> PrimitiveArray<T> { array.as_primitive() },
@@ -450,7 +449,7 @@ define_row_builder!(
450449
where
451450
i64: From<T::Native>,
452451
{
453-
options: &'a CastOptions,
452+
options: &'a CastOptions<'a>,
454453
},
455454
|array| -> PrimitiveArray<T> { array.as_primitive() },
456455
|value| -> Option<_> {
@@ -464,7 +463,7 @@ define_row_builder!(
464463
where
465464
i64: From<T::Native>,
466465
{
467-
options: &'a CastOptions,
466+
options: &'a CastOptions<'a>,
468467
},
469468
|array| -> PrimitiveArray<T> { array.as_primitive() },
470469
|value| -> Option<_> {
@@ -899,7 +898,13 @@ mod tests {
899898

900899
/// Builds a VariantArray from an Arrow array using the row builder.
901900
fn execute_row_builder_test(array: &dyn Array) -> VariantArray {
902-
execute_row_builder_test_with_options(array, CastOptions::default())
901+
execute_row_builder_test_with_options(
902+
array,
903+
CastOptions {
904+
safe: false,
905+
..Default::default()
906+
},
907+
)
903908
}
904909

905910
/// Variant of `execute_row_builder_test` that allows specifying options
@@ -925,7 +930,14 @@ mod tests {
925930
/// Generic helper function to test row builders with basic assertion patterns.
926931
/// Uses execute_row_builder_test and adds simple value comparison assertions.
927932
fn test_row_builder_basic(array: &dyn Array, expected_values: Vec<Option<Variant>>) {
928-
test_row_builder_basic_with_options(array, expected_values, CastOptions::default());
933+
test_row_builder_basic_with_options(
934+
array,
935+
expected_values,
936+
CastOptions {
937+
safe: false,
938+
..Default::default()
939+
},
940+
);
929941
}
930942

931943
/// Variant of `test_row_builder_basic` that allows specifying options
@@ -1058,7 +1070,10 @@ mod tests {
10581070
let run_ends = Int32Array::from(vec![2, 5, 6]);
10591071
let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();
10601072

1061-
let options = CastOptions::default();
1073+
let options = CastOptions {
1074+
safe: false,
1075+
..Default::default()
1076+
};
10621077
let mut row_builder =
10631078
make_arrow_to_variant_row_builder(run_array.data_type(), &run_array, &options).unwrap();
10641079

@@ -1084,7 +1099,10 @@ mod tests {
10841099
let run_ends = Int32Array::from(vec![2, 4, 5]);
10851100
let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();
10861101

1087-
let options = CastOptions::default();
1102+
let options = CastOptions {
1103+
safe: false,
1104+
..Default::default()
1105+
};
10881106
let mut row_builder =
10891107
make_arrow_to_variant_row_builder(run_array.data_type(), &run_array, &options).unwrap();
10901108
let mut array_builder = VariantArrayBuilder::new(5);
@@ -1135,7 +1153,10 @@ mod tests {
11351153
let keys = Int32Array::from(vec![Some(0), None, Some(1), None, Some(2)]);
11361154
let dict_array = DictionaryArray::<Int32Type>::try_new(keys, Arc::new(values)).unwrap();
11371155

1138-
let options = CastOptions::default();
1156+
let options = CastOptions {
1157+
safe: false,
1158+
..Default::default()
1159+
};
11391160
let mut row_builder =
11401161
make_arrow_to_variant_row_builder(dict_array.data_type(), &dict_array, &options)
11411162
.unwrap();
@@ -1167,7 +1188,10 @@ mod tests {
11671188
let keys = Int32Array::from(vec![0, 1, 2, 0, 1, 2]);
11681189
let dict_array = DictionaryArray::<Int32Type>::try_new(keys, Arc::new(values)).unwrap();
11691190

1170-
let options = CastOptions::default();
1191+
let options = CastOptions {
1192+
safe: false,
1193+
..Default::default()
1194+
};
11711195
let mut row_builder =
11721196
make_arrow_to_variant_row_builder(dict_array.data_type(), &dict_array, &options)
11731197
.unwrap();
@@ -1207,7 +1231,10 @@ mod tests {
12071231
let dict_array =
12081232
DictionaryArray::<Int32Type>::try_new(keys, Arc::new(struct_array)).unwrap();
12091233

1210-
let options = CastOptions::default();
1234+
let options = CastOptions {
1235+
safe: false,
1236+
..Default::default()
1237+
};
12111238
let mut row_builder =
12121239
make_arrow_to_variant_row_builder(dict_array.data_type(), &dict_array, &options)
12131240
.unwrap();
@@ -1302,7 +1329,10 @@ mod tests {
13021329
// Slice to get just the middle element: [[3, 4, 5]]
13031330
let sliced_array = list_array.slice(1, 1);
13041331

1305-
let options = CastOptions::default();
1332+
let options = CastOptions {
1333+
safe: false,
1334+
..Default::default()
1335+
};
13061336
let mut row_builder =
13071337
make_arrow_to_variant_row_builder(sliced_array.data_type(), &sliced_array, &options)
13081338
.unwrap();
@@ -1346,7 +1376,10 @@ mod tests {
13461376
Some(arrow::buffer::NullBuffer::from(vec![true, false])),
13471377
);
13481378

1349-
let options = CastOptions::default();
1379+
let options = CastOptions {
1380+
safe: false,
1381+
..Default::default()
1382+
};
13501383
let mut row_builder =
13511384
make_arrow_to_variant_row_builder(outer_list.data_type(), &outer_list, &options)
13521385
.unwrap();
@@ -1539,7 +1572,10 @@ mod tests {
15391572
.unwrap();
15401573

15411574
// Test the row builder
1542-
let options = CastOptions::default();
1575+
let options = CastOptions {
1576+
safe: false,
1577+
..Default::default()
1578+
};
15431579
let mut row_builder =
15441580
make_arrow_to_variant_row_builder(union_array.data_type(), &union_array, &options)
15451581
.unwrap();
@@ -1590,7 +1626,10 @@ mod tests {
15901626
.unwrap();
15911627

15921628
// Test the row builder
1593-
let options = CastOptions::default();
1629+
let options = CastOptions {
1630+
safe: false,
1631+
..Default::default()
1632+
};
15941633
let mut row_builder =
15951634
make_arrow_to_variant_row_builder(union_array.data_type(), &union_array, &options)
15961635
.unwrap();
@@ -1668,7 +1707,7 @@ mod tests {
16681707
Some(Variant::Null), // Overflow value becomes Variant::Null
16691708
Some(Variant::from(VariantDecimal16::try_new(123, 3).unwrap())),
16701709
],
1671-
CastOptions { strict: false },
1710+
CastOptions::default(),
16721711
);
16731712
}
16741713

parquet-variant-compute/src/cast_to_variant.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
// under the License.
1717

1818
use crate::arrow_to_variant::make_arrow_to_variant_row_builder;
19-
use crate::{CastOptions, VariantArray, VariantArrayBuilder};
19+
use crate::{VariantArray, VariantArrayBuilder};
2020
use arrow::array::Array;
21+
use arrow::compute::CastOptions;
2122
use arrow_schema::ArrowError;
2223

2324
/// Casts a typed arrow [`Array`] to a [`VariantArray`]. This is useful when you
@@ -75,9 +76,15 @@ pub fn cast_to_variant_with_options(
7576
/// failures).
7677
///
7778
/// This function provides backward compatibility. For non-strict behavior,
78-
/// use [`cast_to_variant_with_options`] with `CastOptions { strict: false }`.
79+
/// use [`cast_to_variant_with_options`] with `CastOptions { safe: true, ..Default::default() }`.
7980
pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
80-
cast_to_variant_with_options(input, &CastOptions::default())
81+
cast_to_variant_with_options(
82+
input,
83+
&CastOptions {
84+
safe: false,
85+
..Default::default()
86+
},
87+
)
8188
}
8289

8390
#[cfg(test)]
@@ -2261,14 +2268,17 @@ mod tests {
22612268
}
22622269

22632270
fn run_test(values: ArrayRef, expected: Vec<Option<Variant>>) {
2264-
run_test_with_options(values, expected, CastOptions { strict: false });
2271+
run_test_with_options(values, expected, CastOptions::default());
22652272
}
22662273

22672274
fn run_test_in_strict_mode(
22682275
values: ArrayRef,
22692276
expected: Result<Vec<Option<Variant>>, ArrowError>,
22702277
) {
2271-
let options = CastOptions { strict: true };
2278+
let options = CastOptions {
2279+
safe: false,
2280+
..Default::default()
2281+
};
22722282
match expected {
22732283
Ok(expected) => run_test_with_options(values, expected, options),
22742284
Err(_) => {

parquet-variant-compute/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,5 @@ pub use cast_to_variant::{cast_to_variant, cast_to_variant_with_options};
5858
pub use from_json::json_to_variant;
5959
pub use shred_variant::{IntoShreddingField, ShreddedSchemaBuilder, shred_variant};
6060
pub use to_json::variant_to_json;
61-
pub use type_conversion::CastOptions;
6261
pub use unshred_variant::unshred_variant;
6362
pub use variant_get::{GetOptions, variant_get};

parquet-variant-compute/src/type_conversion.rs

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,6 @@ use arrow::datatypes::{
2525
use chrono::Timelike;
2626
use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16};
2727

28-
/// Options for controlling the behavior of `cast_to_variant_with_options`.
29-
#[derive(Debug, Clone, PartialEq, Eq)]
30-
pub struct CastOptions {
31-
/// If true, return error on conversion failure. If false, insert null for failed conversions.
32-
pub strict: bool,
33-
}
34-
35-
impl Default for CastOptions {
36-
fn default() -> Self {
37-
Self { strict: true }
38-
}
39-
}
40-
4128
/// Extension trait for Arrow primitive types that can extract their native value from a Variant
4229
pub(crate) trait PrimitiveFromVariant: ArrowPrimitiveType {
4330
fn from_variant(variant: &Variant<'_, '_>) -> Option<Self::Native>;

0 commit comments

Comments
 (0)