diff --git a/arrow-schema/src/extension/canonical/bool8.rs b/arrow-schema/src/extension/canonical/bool8.rs index c94c8217b8ff..95482bc54044 100644 --- a/arrow-schema/src/extension/canonical/bool8.rs +++ b/arrow-schema/src/extension/canonical/bool8.rs @@ -68,6 +68,10 @@ impl ExtensionType for Bool8 { fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result { Self.supports_data_type(data_type).map(|_| Self) } + + fn validate(data_type: &DataType, _metadata: Self::Metadata) -> Result<(), ArrowError> { + Self.supports_data_type(data_type) + } } #[cfg(test)] diff --git a/arrow-schema/src/extension/canonical/json.rs b/arrow-schema/src/extension/canonical/json.rs index d2a54b9189b7..99b78fd6ef52 100644 --- a/arrow-schema/src/extension/canonical/json.rs +++ b/arrow-schema/src/extension/canonical/json.rs @@ -173,6 +173,10 @@ impl ExtensionType for Json { json.supports_data_type(data_type)?; Ok(json) } + + fn validate(data_type: &DataType, _metadata: Self::Metadata) -> Result<(), ArrowError> { + Self::default().supports_data_type(data_type) + } } #[cfg(test)] diff --git a/arrow-schema/src/extension/canonical/opaque.rs b/arrow-schema/src/extension/canonical/opaque.rs index acfc1331a670..5a5eb1b66818 100644 --- a/arrow-schema/src/extension/canonical/opaque.rs +++ b/arrow-schema/src/extension/canonical/opaque.rs @@ -257,6 +257,10 @@ impl ExtensionType for Opaque { fn try_new(_data_type: &DataType, metadata: Self::Metadata) -> Result { Ok(Self::from(metadata)) } + + fn validate(_data_type: &DataType, _metadata: Self::Metadata) -> Result<(), ArrowError> { + Ok(()) + } } #[cfg(test)] diff --git a/arrow-schema/src/extension/canonical/timestamp_with_offset.rs b/arrow-schema/src/extension/canonical/timestamp_with_offset.rs index 20df20bad922..06ea98b3a0a1 100644 --- a/arrow-schema/src/extension/canonical/timestamp_with_offset.rs +++ b/arrow-schema/src/extension/canonical/timestamp_with_offset.rs @@ -139,6 +139,10 @@ impl ExtensionType for TimestampWithOffset { fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result { Self.supports_data_type(data_type).map(|_| Self) } + + fn validate(data_type: &DataType, _metadata: Self::Metadata) -> Result<(), ArrowError> { + Self.supports_data_type(data_type) + } } #[cfg(test)] diff --git a/arrow-schema/src/extension/canonical/uuid.rs b/arrow-schema/src/extension/canonical/uuid.rs index 3e897f47318d..affea04acb3c 100644 --- a/arrow-schema/src/extension/canonical/uuid.rs +++ b/arrow-schema/src/extension/canonical/uuid.rs @@ -73,6 +73,10 @@ impl ExtensionType for Uuid { fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result { Self.supports_data_type(data_type).map(|_| Self) } + + fn validate(data_type: &DataType, _metadata: Self::Metadata) -> Result<(), ArrowError> { + Self.supports_data_type(data_type) + } } #[cfg(test)] diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index aed560029db8..3dd1f8d3547f 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -257,6 +257,15 @@ pub trait ExtensionType: Sized { /// this extension type. fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result; + /// Validate this extension type for a field with the given data type and + /// metadata. + /// + /// The default implementation delegates to [`Self::try_new`]. Extension + /// types may override this to validate without constructing `Self`. + fn validate(data_type: &DataType, metadata: Self::Metadata) -> Result<(), ArrowError> { + Self::try_new(data_type, metadata).map(|_| ()) + } + /// Construct this extension type from field metadata and data type. /// /// This is a provided method that extracts extension type information from diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index 1f2b57564ded..0d12728ca23b 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -504,13 +504,39 @@ impl Field { .map(String::as_ref) } + /// Returns `true` if this [`Field`] has the given [`ExtensionType`] name + /// and can be successfully validated as that extension type. + /// + /// This first checks the extension type name and only calls + /// [`ExtensionType::validate`] when the name matches. + /// + /// This is useful when you only need a boolean validity check and do not + /// need to retrieve the extension type instance. + #[inline] + pub fn has_valid_extension_type(&self) -> bool { + if self.extension_type_name() != Some(E::NAME) { + return false; + } + + let ext_metadata = self + .metadata() + .get(EXTENSION_TYPE_METADATA_KEY) + .map(|s| s.as_str()); + + E::deserialize_metadata(ext_metadata) + .and_then(|metadata| E::validate(self.data_type(), metadata)) + .is_ok() + } + /// Returns an instance of the given [`ExtensionType`] of this [`Field`], /// if set in the [`Field::metadata`]. /// /// Note that using `try_extension_type` with an extension type that does /// not match the name in the metadata will return an `ArrowError` which can /// be slow due to string allocations. If you only want to check if a - /// [`Field`] has a specific [`ExtensionType`], see the example below. + /// [`Field`] has a specific [`ExtensionType`], first check + /// [`Field::extension_type_name`], or use [`Field::has_valid_extension_type`] + /// to also validate metadata and data type. /// /// # Errors /// @@ -524,7 +550,7 @@ impl Field { /// fail (for example when the [`Field::data_type`] is not supported by /// the extension type ([`ExtensionType::supports_data_type`])) /// - /// # Examples: Check and retrieve an extension type + /// # Example: Check and retrieve an extension type /// You can use this to check if a [`Field`] has a specific /// [`ExtensionType`] and retrieve it: /// ``` @@ -546,34 +572,6 @@ impl Field { /// // do something with extension_type /// } /// ``` - /// - /// # Example: Checking if a field has a specific extension type first - /// - /// Since `try_extension_type` returns an error, it is more - /// efficient to first check if the name matches before calling - /// `try_extension_type`: - /// ``` - /// # use arrow_schema::{DataType, Field, ArrowError}; - /// # use arrow_schema::extension::ExtensionType; - /// # struct MyExtensionType; - /// # impl ExtensionType for MyExtensionType { - /// # const NAME: &'static str = "my_extension"; - /// # type Metadata = String; - /// # fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> { Ok(()) } - /// # fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result { Ok(Self) } - /// # fn serialize_metadata(&self) -> Option { unimplemented!() } - /// # fn deserialize_metadata(s: Option<&str>) -> Result { unimplemented!() } - /// # fn metadata(&self) -> &::Metadata { todo!() } - /// # } - /// # fn get_field() -> Field { Field::new("field", DataType::Null, false) } - /// let field = get_field(); - /// // First check if the name matches before calling the potentially expensive `try_extension_type` - /// if field.extension_type_name() == Some(MyExtensionType::NAME) { - /// if let Ok(extension_type) = field.try_extension_type::() { - /// // do something with extension_type - /// } - /// } - /// ``` pub fn try_extension_type(&self) -> Result { E::try_new_from_field_metadata(self.data_type(), self.metadata()) } @@ -1013,6 +1011,80 @@ mod test { use super::*; use std::collections::hash_map::DefaultHasher; + #[derive(Debug, Clone, Copy)] + struct TestExtensionType; + + impl ExtensionType for TestExtensionType { + const NAME: &'static str = "test.extension"; + type Metadata = (); + + fn metadata(&self) -> &Self::Metadata { + &() + } + + fn serialize_metadata(&self) -> Option { + None + } + + fn deserialize_metadata(metadata: Option<&str>) -> Result { + metadata.map_or(Ok(()), |_| { + Err(ArrowError::InvalidArgumentError( + "TestExtensionType expects no metadata".to_owned(), + )) + }) + } + + fn supports_data_type(&self, _data_type: &DataType) -> Result<(), ArrowError> { + Ok(()) + } + + fn try_new(_data_type: &DataType, _metadata: Self::Metadata) -> Result { + Ok(Self) + } + } + + #[test] + fn test_has_valid_extension_type() { + let no_extension = Field::new("f", DataType::Null, false); + assert!(!no_extension.has_valid_extension_type::()); + + let matching_name = Field::new("f", DataType::Null, false).with_metadata( + [( + EXTENSION_TYPE_NAME_KEY.to_owned(), + TestExtensionType::NAME.to_owned(), + )] + .into_iter() + .collect(), + ); + assert!(matching_name.has_valid_extension_type::()); + + let matching_name_with_invalid_metadata = Field::new("f", DataType::Null, false) + .with_metadata( + [ + ( + EXTENSION_TYPE_NAME_KEY.to_owned(), + TestExtensionType::NAME.to_owned(), + ), + (EXTENSION_TYPE_METADATA_KEY.to_owned(), "invalid".to_owned()), + ] + .into_iter() + .collect(), + ); + assert!( + !matching_name_with_invalid_metadata.has_valid_extension_type::() + ); + + let different_name = Field::new("f", DataType::Null, false).with_metadata( + [( + EXTENSION_TYPE_NAME_KEY.to_owned(), + "some.other_extension".to_owned(), + )] + .into_iter() + .collect(), + ); + assert!(!different_name.has_valid_extension_type::()); + } + #[test] fn test_new_with_string() { // Fields should allow owned Strings to support reuse diff --git a/parquet-variant-compute/src/variant_array.rs b/parquet-variant-compute/src/variant_array.rs index 145de5edfb70..1aad3145a894 100644 --- a/parquet-variant-compute/src/variant_array.rs +++ b/parquet-variant-compute/src/variant_array.rs @@ -80,6 +80,10 @@ impl ExtensionType for VariantType { Self.supports_data_type(data_type)?; Ok(Self) } + + fn validate(data_type: &DataType, _metadata: Self::Metadata) -> Result<()> { + Self.supports_data_type(data_type) + } } /// An array of Parquet [`Variant`] values @@ -131,9 +135,9 @@ impl ExtensionType for VariantType { /// let schema = get_schema(); /// assert_eq!(schema.fields().len(), 2); /// // first field is not a Variant -/// assert!(schema.field(0).try_extension_type::().is_err()); +/// assert!(!schema.field(0).has_valid_extension_type::()); /// // second field is a Variant -/// assert!(schema.field(1).try_extension_type::().is_ok()); +/// assert!(schema.field(1).has_valid_extension_type::()); /// ``` /// /// # Example: Constructing the correct [`Field`] for a [`VariantArray`] diff --git a/parquet/src/arrow/schema/extension.rs b/parquet/src/arrow/schema/extension.rs index 119332ff6bcd..a2e9c32dee8c 100644 --- a/parquet/src/arrow/schema/extension.rs +++ b/parquet/src/arrow/schema/extension.rs @@ -110,19 +110,13 @@ pub(crate) fn has_extension_type(parquet_type: &Type) -> bool { /// Return the Parquet logical type to use for the specified Arrow Struct field, if any. #[cfg(feature = "variant_experimental")] pub(crate) fn logical_type_for_struct(field: &Field) -> Option { - use arrow_schema::extension::ExtensionType; use parquet_variant_compute::VariantType; - // Check the name (= quick and cheap) and only try_extension_type if the name matches - // to avoid unnecessary String allocations in ArrowError - if field.extension_type_name()? != VariantType::NAME { - return None; - } - match field.try_extension_type::() { - Ok(VariantType) => Some(LogicalType::Variant { + if field.has_valid_extension_type::() { + Some(LogicalType::Variant { specification_version: None, - }), - // Given check above, this should not error, but if it does ignore - Err(_e) => None, + }) + } else { + None } } @@ -137,9 +131,8 @@ pub(crate) fn logical_type_for_fixed_size_binary(field: &Field) -> Option() - .ok() - .map(|_| LogicalType::Uuid) + .has_valid_extension_type::() + .then_some(LogicalType::Uuid) } #[cfg(not(feature = "arrow_canonical_extension_types"))] @@ -153,9 +146,11 @@ pub(crate) fn logical_type_for_string(field: &Field) -> Option { use arrow_schema::extension::Json; // Use the Json logical type if the canonical Json // extension type is set on this field. - field - .try_extension_type::() - .map_or(Some(LogicalType::String), |_| Some(LogicalType::Json)) + Some(if field.has_valid_extension_type::() { + LogicalType::Json + } else { + LogicalType::String + }) } #[cfg(not(feature = "arrow_canonical_extension_types"))] diff --git a/parquet/src/arrow/schema/virtual_type.rs b/parquet/src/arrow/schema/virtual_type.rs index 657a76b73229..f66352ae1981 100644 --- a/parquet/src/arrow/schema/virtual_type.rs +++ b/parquet/src/arrow/schema/virtual_type.rs @@ -69,6 +69,10 @@ impl ExtensionType for RowGroupIndex { fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result { Self.supports_data_type(data_type).map(|_| Self) } + + fn validate(data_type: &DataType, _metadata: Self::Metadata) -> Result<(), ArrowError> { + Self.supports_data_type(data_type) + } } /// The extension type for row numbers. @@ -113,6 +117,10 @@ impl ExtensionType for RowNumber { fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result { Self.supports_data_type(data_type).map(|_| Self) } + + fn validate(data_type: &DataType, _metadata: Self::Metadata) -> Result<(), ArrowError> { + Self.supports_data_type(data_type) + } } /// Returns `true` if the field is a virtual column. diff --git a/parquet/src/variant.rs b/parquet/src/variant.rs index b7dde6b3c8a4..cdbdd849683d 100644 --- a/parquet/src/variant.rs +++ b/parquet/src/variant.rs @@ -123,7 +123,7 @@ //! // the VariantType extension type //! let schema = reader.schema(); //! let field = schema.field_with_name("var")?; -//! assert!(field.try_extension_type::().is_ok()); +//! assert!(field.has_valid_extension_type::()); //! //! // The reader will yield RecordBatches with a StructArray //! // to convert them to VariantArray, use VariantArray::try_new @@ -285,9 +285,7 @@ mod tests { assert_eq!(metadata_value, "arrow.parquet.variant"); // verify that `VariantType` also correctly finds the metadata - field - .try_extension_type::() - .expect("VariantExtensionType should be readable"); + assert!(field.has_valid_extension_type::()); } /// Read the specified test case filename from parquet-testing