diff --git a/avro/src/error.rs b/avro/src/error.rs index 50a09afc..12bee1eb 100644 --- a/avro/src/error.rs +++ b/avro/src/error.rs @@ -300,8 +300,18 @@ pub enum Details { #[error("Unions may not directly contain a union")] GetNestedUnion, - #[error("Unions cannot contain duplicate types")] - GetUnionDuplicate, + #[error( + "Found two different maps while building Union: Schema::Map({0:?}), Schema::Map({1:?})" + )] + GetUnionDuplicateMap(Schema, Schema), + + #[error( + "Found two different arrays while building Union: Schema::Array({0:?}), Schema::Array({1:?})" + )] + GetUnionDuplicateArray(Schema, Schema), + + #[error("Unions cannot contain duplicate types, found at least two {0:?}")] + GetUnionDuplicate(SchemaKind), #[error("Unions cannot contain more than one named schema with the same name: {0}")] GetUnionDuplicateNamedSchemas(String), diff --git a/avro/src/schema/mod.rs b/avro/src/schema/mod.rs index 5f5d81c1..70d30663 100644 --- a/avro/src/schema/mod.rs +++ b/avro/src/schema/mod.rs @@ -33,7 +33,7 @@ pub use crate::schema::{ RecordField, RecordFieldBuilder, RecordFieldOrder, RecordSchema, RecordSchemaBuilder, }, resolve::ResolvedSchema, - union::UnionSchema, + union::{UnionSchema, UnionSchemaBuilder}, }; use crate::{ AvroResult, diff --git a/avro/src/schema/name.rs b/avro/src/schema/name.rs index b551584c..22af2e75 100644 --- a/avro/src/schema/name.rs +++ b/avro/src/schema/name.rs @@ -38,7 +38,7 @@ use crate::{ /// /// More information about schema names can be found in the /// [Avro specification](https://avro.apache.org/docs/++version++/specification/#names) -#[derive(Clone, Debug, Hash, PartialEq, Eq)] +#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct Name { pub name: String, pub namespace: Namespace, diff --git a/avro/src/schema/union.rs b/avro/src/schema/union.rs index bca79aa8..c1595cd1 100644 --- a/avro/src/schema/union.rs +++ b/avro/src/schema/union.rs @@ -15,24 +15,32 @@ // specific language governing permissions and limitations 
// under the License. -use crate::AvroResult; use crate::error::Details; -use crate::schema::{Name, Namespace, ResolvedSchema, Schema, SchemaKind}; +use crate::schema::{ + DecimalSchema, InnerDecimalSchema, Name, Namespace, Schema, SchemaKind, UuidSchema, +}; use crate::types; +use crate::{AvroResult, Error}; use std::borrow::Borrow; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{BTreeMap, HashMap}; use std::fmt::{Debug, Formatter}; +use strum::IntoDiscriminant; /// A description of a Union schema #[derive(Clone)] pub struct UnionSchema { /// The schemas that make up this union pub(crate) schemas: Vec, - // Used to ensure uniqueness of schema inputs, and provide constant time finding of the - // schema index given a value. - // **NOTE** that this approach does not work for named types, and will have to be modified - // to support that. A simple solution is to also keep a mapping of the names used. + /// The indexes of unnamed types. + /// + /// Logical types have been reduced to their inner type. + /// Used to look up the schema index of an unnamed type without scanning + /// all variants. Must only contain unnamed types. variant_index: BTreeMap, + /// The indexes of named types. + /// + /// The names themselves aren't saved as they aren't used. + named_index: Vec, } impl Debug for UnionSchema { @@ -51,25 +59,16 @@ impl UnionSchema { /// Will return an error if `schemas` has duplicate unnamed schemas or if `schemas` /// contains a union.
pub fn new(schemas: Vec) -> AvroResult { - let mut named_schemas: HashSet<&Name> = HashSet::default(); - let mut vindex = BTreeMap::new(); - for (i, schema) in schemas.iter().enumerate() { - if let Schema::Union(_) = schema { - return Err(Details::GetNestedUnion.into()); - } else if !schema.is_named() && vindex.insert(SchemaKind::from(schema), i).is_some() { - return Err(Details::GetUnionDuplicate.into()); - } else if schema.is_named() { - let name = schema.name().unwrap(); - if !named_schemas.insert(name) { - return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into()); - } - vindex.insert(SchemaKind::from(schema), i); - } + let mut builder = Self::builder(); + for schema in schemas { + builder.variant(schema)?; } - Ok(UnionSchema { - schemas, - variant_index: vindex, - }) + Ok(builder.build()) + } + + /// Build a `UnionSchema` piece-by-piece. + pub fn builder() -> UnionSchemaBuilder { + UnionSchemaBuilder::new() } /// Returns a slice to all variants of this schema. @@ -79,7 +78,7 @@ impl UnionSchema { /// Returns true if the any of the variants of this `UnionSchema` is `Null`. 
pub fn is_nullable(&self) -> bool { - self.schemas.iter().any(|x| matches!(x, Schema::Null)) + self.variant_index.contains_key(&SchemaKind::Null) } /// Optionally returns a reference to the schema matched by this value, as well as its position @@ -93,45 +92,132 @@ impl UnionSchema { known_schemata: Option<&HashMap>, enclosing_namespace: &Namespace, ) -> Option<(usize, &Schema)> { - let schema_kind = SchemaKind::from(value); - if let Some(&i) = self.variant_index.get(&schema_kind) { - // fast path - Some((i, &self.schemas[i])) - } else { - // slow path (required for matching logical or named types) - - // first collect what schemas we already know - let mut collected_names: HashMap = known_schemata - .map(|names| { - names - .iter() - .map(|(name, schema)| (name.clone(), schema.borrow())) - .collect() + let ValueSchemaKind { unnamed, named } = Self::value_to_base_schemakind(value); + // Unnamed schema types can be looked up directly using the variant_index + let unnamed = unnamed + .and_then(|kind| self.variant_index.get(&kind).copied()) + .map(|index| (index, &self.schemas[index])) + .and_then(|(index, schema)| { + let kind = schema.discriminant(); + // Maps and arrays need to be checked if they actually match the value + if kind == SchemaKind::Map || kind == SchemaKind::Array { + let known_schemata_if_none = HashMap::new(); + let known_schemata = known_schemata.unwrap_or(&known_schemata_if_none); + let namespace = if schema.namespace().is_some() { + &schema.namespace() + } else { + enclosing_namespace + }; + + // TODO: Do this without the clone + value + .clone() + .resolve_internal(schema, known_schemata, namespace, &None) + .ok() + .map(|_| (index, schema)) + } else { + Some((index, schema)) + } + }); + let named = named.and_then(|kind| { + // Every named type needs to be checked against a value until one matches + + let known_schemata_if_none = HashMap::new(); + let known_schemata = known_schemata.unwrap_or(&known_schemata_if_none); + + self.named_index +
.iter() + .copied() + .map(|i| (i, &self.schemas[i])) + .filter(|(_i, s)| { + // Named logical types (fixed-backed decimal or uuid, duration) carry their own + // discriminant, so compare the base kind. The resolve check below still + // decides whether the value really fits the schema. + let base = UnionSchemaBuilder::schema_to_base_schemakind(s); + base == kind || base == SchemaKind::Ref + }) + .find(|(_i, schema)| { + let namespace = if schema.namespace().is_some() { + &schema.namespace() + } else { + enclosing_namespace + }; + + // TODO: Do this without the clone + value + .clone() + .resolve_internal(schema, known_schemata, namespace, &None) + .is_ok() }) - .unwrap_or_default(); - - self.schemas.iter().enumerate().find(|(_, schema)| { - let resolved_schema = ResolvedSchema::new_with_known_schemata( - vec![*schema], - enclosing_namespace, - &collected_names, - ) - .expect("Schema didn't successfully parse"); - let resolved_names = resolved_schema.names_ref; - - // extend known schemas with just resolved names - collected_names.extend(resolved_names); - let namespace = &schema.namespace().or_else(|| enclosing_namespace.clone()); - - value - .clone() - .resolve_internal(schema, &collected_names, namespace, &None) - .is_ok() - }) + }); + + match (unnamed, named) { + (Some((u_i, _)), Some((n_i, _))) if u_i < n_i => unnamed, + (Some(_), Some(_)) => named, + (Some(_), None) => unnamed, + (None, Some(_)) => named, + (None, None) => None, + } + } + + /// Convert a value to a [`SchemaKind`] stripping logical types to their base type.
+ fn value_to_base_schemakind(value: &types::Value) -> ValueSchemaKind { + let schemakind = SchemaKind::from(value); + match schemakind { + SchemaKind::Decimal => ValueSchemaKind { + unnamed: Some(SchemaKind::Bytes), + named: Some(SchemaKind::Fixed), + }, + SchemaKind::BigDecimal => ValueSchemaKind { + unnamed: Some(SchemaKind::Bytes), + named: None, + }, + SchemaKind::Uuid => ValueSchemaKind { + unnamed: Some(SchemaKind::String), + named: Some(SchemaKind::Fixed), + }, + SchemaKind::Date | SchemaKind::TimeMillis => ValueSchemaKind { + unnamed: Some(SchemaKind::Int), + named: None, + }, + SchemaKind::TimeMicros + | SchemaKind::TimestampMillis + | SchemaKind::TimestampMicros + | SchemaKind::TimestampNanos + | SchemaKind::LocalTimestampMillis + | SchemaKind::LocalTimestampMicros + | SchemaKind::LocalTimestampNanos => ValueSchemaKind { + unnamed: Some(SchemaKind::Long), + named: None, + }, + SchemaKind::Duration => ValueSchemaKind { + unnamed: None, + named: Some(SchemaKind::Fixed), + }, + SchemaKind::Record | SchemaKind::Enum | SchemaKind::Fixed => ValueSchemaKind { + unnamed: None, + named: Some(schemakind), + }, + // When a `serde_json::Value` is converted to a `types::Value`, an object always becomes a map, + // so a `types::Value::Map` can also be a record. + SchemaKind::Map => ValueSchemaKind { + unnamed: Some(SchemaKind::Map), + named: Some(SchemaKind::Record), + }, + _ => ValueSchemaKind { + unnamed: Some(schemakind), + named: None, + }, } } } +/// The schema kinds matching a specific value. +struct ValueSchemaKind { + unnamed: Option, + named: Option, +} + // No need to compare variant_index, it is derivative of schemas.
impl PartialEq for UnionSchema { fn eq(&self, other: &UnionSchema) -> bool { @@ -139,11 +219,189 @@ impl PartialEq for UnionSchema { } } +/// A builder for [`UnionSchema`] +#[derive(Default, Debug)] +pub struct UnionSchemaBuilder { + /// The variants, in the order they were added. + schemas: Vec<Schema>, + /// Position in `schemas` of every named variant, keyed by its name. + names: BTreeMap<Name, usize>, + /// Position in `schemas` of every unnamed variant, keyed by its base kind. + variant_index: BTreeMap<SchemaKind, usize>, +} + +impl UnionSchemaBuilder { + /// Create a builder. + /// + /// See also [`UnionSchema::builder`]. + pub fn new() -> Self { + Self::default() + } + + #[doc(hidden)] + /// This is not a public API, it should only be used by `avro_derive` + /// + /// Add a variant to this union, if it already exists ignore it. + /// + /// # Errors + /// Will return a [`Details::GetUnionDuplicateMap`] or [`Details::GetUnionDuplicateArray`] if + /// duplicate maps or arrays are encountered with different subtypes, + /// [`Details::GetUnionDuplicateNamedSchemas`] if a different schema with the same name was + /// already added, and [`Details::GetNestedUnion`] if `schema` is a union. + /// The builder is left unchanged when an error is returned. + pub fn variant_ignore_duplicates(&mut self, schema: Schema) -> Result<&mut Self, Error> { + if let Some(name) = schema.name() { + if let Some(current) = self.names.get(name).copied() { + if self.schemas[current] != schema { + return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into()); + } + } else { + self.names.insert(name.clone(), self.schemas.len()); + self.schemas.push(schema); + } + } else if let Schema::Map(_) = &schema { + if let Some(index) = self.variant_index.get(&SchemaKind::Map).copied() { + if self.schemas[index] != schema { + // Clone instead of removing: removal would shift the later schemas and + // invalidate every index stored in `names` and `variant_index`. + return Err( + Details::GetUnionDuplicateMap(self.schemas[index].clone(), schema).into(), + ); + } + } else { + self.variant_index + .insert(SchemaKind::Map, self.schemas.len()); + self.schemas.push(schema); + } + } else if let Schema::Array(_) = &schema { + if let Some(index) = self.variant_index.get(&SchemaKind::Array).copied() { + if self.schemas[index] != schema { + // Report the array error (not the map one) and keep the builder intact. + return Err( + Details::GetUnionDuplicateArray(self.schemas[index].clone(), schema).into(), + ); + } + } else { + self.variant_index + .insert(SchemaKind::Array, self.schemas.len()); + self.schemas.push(schema); + } + } else { + let discriminant =
Self::schema_to_base_schemakind(&schema); + if discriminant == SchemaKind::Union { + return Err(Details::GetNestedUnion.into()); + } + if !self.variant_index.contains_key(&discriminant) { + self.variant_index.insert(discriminant, self.schemas.len()); + self.schemas.push(schema); + } + } + Ok(self) + } + + /// Add a variant to this union. + /// + /// # Errors + /// Will return a [`Details::GetUnionDuplicateNamedSchemas`] or [`Details::GetUnionDuplicate`] if + /// duplicate names or schema kinds are found. + pub fn variant(&mut self, schema: Schema) -> Result<&mut Self, Error> { + if let Some(name) = schema.name() { + if self.names.contains_key(name) { + return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into()); + } else { + self.names.insert(name.clone(), self.schemas.len()); + self.schemas.push(schema); + } + } else { + let discriminant = Self::schema_to_base_schemakind(&schema); + if discriminant == SchemaKind::Union { + return Err(Details::GetNestedUnion.into()); + } + if self.variant_index.contains_key(&discriminant) { + return Err(Details::GetUnionDuplicate(discriminant).into()); + } else { + self.variant_index.insert(discriminant, self.schemas.len()); + self.schemas.push(schema); + } + } + Ok(self) + } + + /// Check if a schema already exists in this union. + pub fn contains(&self, schema: &Schema) -> bool { + if let Some(name) = schema.name() { + if let Some(current) = self.names.get(name).copied() { + &self.schemas[current] == schema + } else { + false + } + } else { + let discriminant = Self::schema_to_base_schemakind(schema); + if let Some(index) = self.variant_index.get(&discriminant).copied() { + &self.schemas[index] == schema + } else { + false + } + } + } + + /// Create the `UnionSchema`.
+ pub fn build(mut self) -> UnionSchema { + self.schemas.shrink_to_fit(); + // `names` iterates in name order; sort the positions so that lookups try the named + // variants in the order they appear in the union. + let mut named_index: Vec<usize> = self.names.into_values().collect(); + named_index.sort_unstable(); + UnionSchema { + variant_index: self.variant_index, + named_index, + schemas: self.schemas, + } + } + + /// Get the [`SchemaKind`] of a [`Schema`] converting logical types to their base type. + fn schema_to_base_schemakind(schema: &Schema) -> SchemaKind { + let kind = schema.discriminant(); + match kind { + SchemaKind::Date | SchemaKind::TimeMillis => SchemaKind::Int, + SchemaKind::TimeMicros + | SchemaKind::TimestampMillis + | SchemaKind::TimestampMicros + | SchemaKind::TimestampNanos + | SchemaKind::LocalTimestampMillis + | SchemaKind::LocalTimestampMicros + | SchemaKind::LocalTimestampNanos => SchemaKind::Long, + // `BigDecimal` is bytes-backed and its values are looked up under `Bytes`. + SchemaKind::BigDecimal => SchemaKind::Bytes, + SchemaKind::Uuid => match schema { + Schema::Uuid(UuidSchema::Bytes) => SchemaKind::Bytes, + Schema::Uuid(UuidSchema::String) => SchemaKind::String, + Schema::Uuid(UuidSchema::Fixed(_)) => SchemaKind::Fixed, + _ => unreachable!(), + }, + SchemaKind::Decimal => match schema { + Schema::Decimal(DecimalSchema { + inner: InnerDecimalSchema::Bytes, + .. + }) => SchemaKind::Bytes, + Schema::Decimal(DecimalSchema { + inner: InnerDecimalSchema::Fixed(_), + ..
+ }) => SchemaKind::Fixed, + _ => unreachable!(), + }, + SchemaKind::Duration => SchemaKind::Fixed, + _ => kind, + } + } +} + #[cfg(test)] mod tests { use super::*; use crate::error::{Details, Error}; use crate::schema::RecordSchema; + use crate::types::Value; use apache_avro_test_helper::TestResult; #[test] @@ -174,4 +417,151 @@ mod tests { Ok(()) } + + #[test] + fn avro_rs_489_union_schema_builder_primitive_type() -> TestResult { + let mut builder = UnionSchema::builder(); + builder.variant(Schema::Null)?; + assert!(builder.variant(Schema::Null).is_err()); + builder.variant_ignore_duplicates(Schema::Null)?; + builder.variant(Schema::Int)?; + assert!(builder.variant(Schema::Int).is_err()); + builder.variant_ignore_duplicates(Schema::Int)?; + builder.variant(Schema::Long)?; + assert!(builder.variant(Schema::Long).is_err()); + builder.variant_ignore_duplicates(Schema::Long)?; + + let union = builder.build(); + assert_eq!(union.schemas, &[Schema::Null, Schema::Int, Schema::Long]); + + Ok(()) + } + + #[test] + fn avro_rs_489_union_schema_builder_complex_types() -> TestResult { + let enum_abc = Schema::parse_str( + r#"{ + "type": "enum", + "name": "ABC", + "symbols": ["A", "B", "C"] + }"#, + )?; + let enum_abc_with_extra_symbol = Schema::parse_str( + r#"{ + "type": "enum", + "name": "ABC", + "symbols": ["A", "B", "C", "D"] + }"#, + )?; + let enum_def = Schema::parse_str( + r#"{ + "type": "enum", + "name": "DEF", + "symbols": ["D", "E", "F"] + }"#, + )?; + let fixed_abc = Schema::parse_str( + r#"{ + "type": "fixed", + "name": "ABC", + "size": 1 + }"#, + )?; + let fixed_foo = Schema::parse_str( + r#"{ + "type": "fixed", + "name": "Foo", + "size": 1 + }"#, + )?; + + let mut builder = UnionSchema::builder(); + builder.variant(enum_abc.clone())?; + assert!(builder.variant(enum_abc.clone()).is_err()); + builder.variant_ignore_duplicates(enum_abc.clone())?; + // Name is the same but different schemas, so should always fail + 
assert!(builder.variant(fixed_abc.clone()).is_err()); + assert!( + builder + .variant_ignore_duplicates(fixed_abc.clone()) + .is_err() + ); + // Name and schema type are the same but symbols are different + assert!(builder.variant(enum_abc_with_extra_symbol.clone()).is_err()); + assert!( + builder + .variant_ignore_duplicates(enum_abc_with_extra_symbol.clone()) + .is_err() + ); + builder.variant(enum_def.clone())?; + assert!(builder.variant(enum_def.clone()).is_err()); + builder.variant_ignore_duplicates(enum_def.clone())?; + builder.variant(fixed_foo.clone())?; + assert!(builder.variant(fixed_foo.clone()).is_err()); + builder.variant_ignore_duplicates(fixed_foo.clone())?; + + let union = builder.build(); + assert_eq!(union.variants(), &[enum_abc, enum_def, fixed_foo]); + + Ok(()) + } + + #[test] + fn avro_rs_489_union_schema_builder_logical_types() -> TestResult { + let fixed_uuid = Schema::parse_str( + r#"{ + "type": "fixed", + "name": "Uuid", + "size": 16 + }"#, + )?; + let uuid = Schema::parse_str( + r#"{ + "type": "fixed", + "logicalType": "uuid", + "name": "Uuid", + "size": 16 + }"#, + )?; + + let mut builder = UnionSchema::builder(); + + builder.variant(Schema::Date)?; + assert!(builder.variant(Schema::Date).is_err()); + builder.variant_ignore_duplicates(Schema::Date)?; + assert!(builder.variant(Schema::Int).is_err()); + builder.variant_ignore_duplicates(Schema::Int)?; + builder.variant(uuid.clone())?; + assert!(builder.variant(uuid.clone()).is_err()); + builder.variant_ignore_duplicates(uuid.clone())?; + assert!(builder.variant(fixed_uuid.clone()).is_err()); + assert!( + builder + .variant_ignore_duplicates(fixed_uuid.clone()) + .is_err() + ); + + let union = builder.build(); + assert_eq!(union.schemas, &[Schema::Date, uuid]); + + Ok(()) + } + + #[test] + fn avro_rs_489_find_schema_with_known_schemata_wrong_map() -> TestResult { + let union = UnionSchema::new(vec![Schema::map(Schema::Int).build(), Schema::Null])?; + let value = Value::Map( + 
[("key".to_string(), Value::String("value".to_string()))] + .into_iter() + .collect(), + ); + + assert!( + union + .find_schema_with_known_schemata(&value, None::<&HashMap>, &None) + .is_none() + ); + + Ok(()) + } } diff --git a/avro/src/serde/derive.rs b/avro/src/serde/derive.rs index 7e3fb9e3..cada6e46 100644 --- a/avro/src/serde/derive.rs +++ b/avro/src/serde/derive.rs @@ -637,14 +637,17 @@ where named_schemas: &mut HashSet, enclosing_namespace: &Namespace, ) -> Schema { - let variants = vec![ - Schema::Null, - T::get_schema_in_ctxt(named_schemas, enclosing_namespace), - ]; + let schema = T::get_schema_in_ctxt(named_schemas, enclosing_namespace); + if let Schema::Null = schema { + Schema::Union(UnionSchema::new(vec![Schema::Null]).expect("This is a valid schema")) + } else { + let variants = vec![Schema::Null, schema]; - Schema::Union( - UnionSchema::new(variants).expect("Option must produce a valid (non-nested) union"), - ) + Schema::Union( + UnionSchema::new(variants) + .expect("Option must produce a valid (non-nested) union"), + ) + } } fn get_record_fields_in_ctxt( diff --git a/avro/src/types.rs b/avro/src/types.rs index b3b20358..8a9bdf6d 100644 --- a/avro/src/types.rs +++ b/avro/src/types.rs @@ -3140,15 +3140,10 @@ Field with name '"b"' is not a member of the map items"#, let main_schema = schemata.last().unwrap(); let other_schemata: Vec<&Schema> = schemata.iter().take(2).collect(); - let resolve_result = avro_value.resolve_schemata(main_schema, other_schemata); + let resolve_result = avro_value.resolve_schemata(main_schema, other_schemata)?; assert!( - resolve_result.is_ok(), - "result of resolving with schemata should be ok, got: {resolve_result:?}" - ); - - assert!( - resolve_result?.validate_schemata(schemata.iter().collect()), + resolve_result.validate_schemata(schemata.iter().collect()), "result of validation with schemata should be true" );