diff --git a/src/generator.rs b/src/generator.rs index 1b57e1e..100d4d8 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, VecDeque}; +use std::collections::{HashMap, HashSet, VecDeque}; use std::fs; use std::io::prelude::*; @@ -42,10 +42,12 @@ impl Generator { /// Generates Rust code from an Avro schema [`Source`](Source). /// Writes all generated types to the output. pub fn generate(&self, source: &Source, output: &mut impl Write) -> Result<()> { + let mut iso_datetime_fields: HashSet<(Name, String)> = HashSet::new(); + match source { Source::Schema(schema) => { let mut deps = deps_stack(schema, vec![]); - self.gen_in_order(&mut deps, output)?; + self.gen_in_order(&mut deps, iso_datetime_fields, output)?; } Source::Schemas(schemas) => { @@ -53,13 +55,16 @@ impl Generator { .iter() .fold(vec![], |deps, schema| deps_stack(schema, deps)); - self.gen_in_order(&mut deps, output)?; + self.gen_in_order(&mut deps, iso_datetime_fields, output)?; } Source::SchemaStr(raw_schema) => { + if self.templater.use_iso_datetime { + iso_datetime_fields = scan_iso_datetime_fields(raw_schema)?; + } let schema = Schema::parse_str(raw_schema)?; let mut deps = deps_stack(&schema, vec![]); - self.gen_in_order(&mut deps, output)?; + self.gen_in_order(&mut deps, iso_datetime_fields, output)?; } Source::GlobPattern(pattern) => { @@ -78,9 +83,18 @@ impl Generator { } } - let schemas = &raw_schemas.iter().map(|s| s.as_str()).collect::>(); - let schemas = Schema::parse_list(schemas)?; - self.generate(&Source::Schemas(&schemas), output)?; + if self.templater.use_iso_datetime { + for raw in &raw_schemas { + iso_datetime_fields.extend(scan_iso_datetime_fields(raw)?); + } + } + + let schemas_strs = &raw_schemas.iter().map(|s| s.as_str()).collect::>(); + let schemas = Schema::parse_list(schemas_strs)?; + let mut deps = schemas + .iter() + .fold(vec![], |deps, schema| deps_stack(schema, deps)); + self.gen_in_order(&mut deps, iso_datetime_fields, output)?; } } @@ -92,8 +106,15 @@ impl Generator { /// * Pops sub-schemas and generate appropriate Rust types /// * Keeps tracks of nested schema->name with `GenState` mapping /// * Appends generated Rust types to the output - fn gen_in_order(&self, deps: &mut Vec, output: &mut impl Write) -> Result<()> { - let mut gs = GenState::new(deps)?.with_chrono_dates(self.templater.use_chrono_dates); + fn gen_in_order( + &self, + deps: &mut Vec, + iso_datetime_fields: HashSet<(Name, String)>, + output: &mut impl Write, + ) -> Result<()> { + let mut gs = GenState::new(deps)? + .with_chrono_dates(self.templater.use_chrono_dates) + .with_iso_datetime_fields(iso_datetime_fields); if !self.templater.field_overrides.is_empty() { // This rechecks no_eq for all schemas, so only do it if there are actually overrides. @@ -308,6 +329,7 @@ pub struct GeneratorBuilder { nullable: bool, use_avro_rs_unions: bool, use_chrono_dates: bool, + use_iso_datetime: bool, derive_builders: bool, impl_schemas: ImplementAvroSchema, extra_derives: Vec, @@ -321,6 +343,7 @@ impl Default for GeneratorBuilder { nullable: false, use_avro_rs_unions: false, use_chrono_dates: false, + use_iso_datetime: false, derive_builders: false, impl_schemas: ImplementAvroSchema::None, extra_derives: vec![], @@ -436,6 +459,28 @@ impl GeneratorBuilder { self } + /// Map `string` fields with `logicalType: "iso_datetime"` to + /// `chrono::DateTime`. + /// + /// `"iso_datetime"` is a non-standard logical type that `apache-avro` + /// silently drops while parsing, so this option only takes effect for + /// sources that expose the raw JSON to `rsgen-avro`: + /// [`Source::SchemaStr`] and [`Source::GlobPattern`]. It is a no-op for + /// [`Source::Schema`] and [`Source::Schemas`] because the annotation has + /// already been lost by the time a parsed `Schema` reaches the generator. + /// + /// Detected fields are emitted as `chrono::DateTime` when + /// required, `Option>` when nullable (i.e. + /// wrapped in a `["null", ...]` union), `Vec>` + /// when used as array items, and `Option>` when a nullable + /// array. No `#[serde(with = ...)]` helper is needed: `chrono`'s default + /// `Serialize`/`Deserialize` impl for `DateTime` uses RFC3339, which + /// matches the Avro wire format (a plain string). + pub fn use_iso_datetime(mut self, use_iso_datetime: bool) -> GeneratorBuilder { + self.use_iso_datetime = use_iso_datetime; + self + } + /// Adds support to derive builders using the `rust-derive-builder` crate. /// /// Applies to record structs. @@ -496,6 +541,7 @@ impl GeneratorBuilder { templater.nullable = self.nullable; templater.use_avro_rs_unions = self.use_avro_rs_unions; templater.use_chrono_dates = self.use_chrono_dates; + templater.use_iso_datetime = self.use_iso_datetime; templater.derive_builders = self.derive_builders; templater.derive_schemas = self.impl_schemas == ImplementAvroSchema::Derive; templater.impl_schemas = self.impl_schemas == ImplementAvroSchema::CopyBuildSchema; diff --git a/src/templates.rs b/src/templates.rs index 47166cb..6c763fe 100644 --- a/src/templates.rs +++ b/src/templates.rs @@ -336,6 +336,148 @@ struct GenUnionVisitor { serde_visitor: Option, } +/// Walks the raw JSON schema text and returns the set of +/// `(record name, field name)` pairs whose field type is (or wraps) a +/// `string` with `logicalType == "iso_datetime"`. +/// +/// `apache-avro` drops unknown logical types while parsing, so the only way +/// to detect this custom annotation is to inspect the JSON before it is +/// handed to the parser. The returned [`Name`]s mirror the namespace +/// inheritance rules of `apache_avro::schema::Name::parse`, so they compare +/// equal to the [`Name`]s later produced by `Schema::parse_*`. +pub(crate) fn scan_iso_datetime_fields( + raw_json: &str, +) -> Result> { + let value: Value = serde_json::from_str(raw_json) + .map_err(|e| Error::Template(format!("Failed to parse JSON schema: {e}")))?; + let mut out = HashSet::new(); + scan_value_for_iso_datetime(&value, &None, &mut out); + Ok(out) +} + +fn scan_value_for_iso_datetime( + value: &Value, + enclosing_namespace: &Option, + out: &mut HashSet<(Name, String)>, +) { + match value { + Value::Object(obj) => { + if obj.get("type").and_then(|t| t.as_str()) == Some("record") { + let record_name = resolve_record_name(obj, enclosing_namespace); + let child_ns = record_name + .namespace + .clone() + .or_else(|| enclosing_namespace.clone()); + if let Some(Value::Array(fields)) = obj.get("fields") { + for field in fields { + if let Value::Object(field_obj) = field { + let field_name = field_obj + .get("name") + .and_then(|n| n.as_str()) + .unwrap_or(""); + if let Some(field_type) = field_obj.get("type") { + if is_iso_datetime_type(field_type) { + out.insert(( + record_name.clone(), + field_name.to_string(), + )); + } + scan_value_for_iso_datetime(field_type, &child_ns, out); + } + } + } + } + } else { + for (_, v) in obj { + scan_value_for_iso_datetime(v, enclosing_namespace, out); + } + } + } + Value::Array(arr) => { + for v in arr { + scan_value_for_iso_datetime(v, enclosing_namespace, out); + } + } + _ => {} + } +} + +fn resolve_record_name( + obj: &serde_json::Map, + enclosing_namespace: &Option, +) -> Name { + let raw_name = obj.get("name").and_then(|n| n.as_str()).unwrap_or(""); + let (local_name, ns_from_name) = if let Some((ns, local)) = raw_name.rsplit_once('.') { + (local.to_string(), Some(ns.to_string())) + } else { + (raw_name.to_string(), None) + }; + let namespace = ns_from_name + .or_else(|| { + obj.get("namespace") + .and_then(|n| n.as_str()) + .map(|s| s.to_string()) + }) + .or_else(|| enclosing_namespace.clone()) + .filter(|ns| !ns.is_empty()); + Name { + name: local_name, + namespace, + } +} + +/// Maps a parsed `Schema` (where the `iso_datetime` annotation has already +/// been stripped by `apache-avro`) to the corresponding Rust type string. +/// Returns `None` for shapes the generator does not handle (e.g. nested +/// unions or maps). +fn iso_datetime_rust_type(schema: &Schema) -> Option { + const T: &str = "chrono::DateTime"; + match schema { + Schema::String => Some(T.to_string()), + Schema::Array(ArraySchema { items, .. }) if matches!(items.as_ref(), Schema::String) => { + Some(format!("Vec<{T}>")) + } + Schema::Union(union) + if union.is_nullable() + && union.variants().len() == 2 + && matches!(union.variants()[1], Schema::String) => + { + Some(format!("Option<{T}>")) + } + Schema::Union(union) if union.is_nullable() && union.variants().len() == 2 => { + match &union.variants()[1] { + Schema::Array(ArraySchema { items, .. }) + if matches!(items.as_ref(), Schema::String) => + { + Some(format!("Option>")) + } + _ => None, + } + } + _ => None, + } +} + +fn is_iso_datetime_type(type_field: &Value) -> bool { + match type_field { + Value::Object(obj) => { + if obj.get("type").and_then(|t| t.as_str()) == Some("string") + && obj.get("logicalType").and_then(|l| l.as_str()) == Some("iso_datetime") + { + return true; + } + if obj.get("type").and_then(|t| t.as_str()) == Some("array") + && let Some(items) = obj.get("items") + { + return is_iso_datetime_type(items); + } + false + } + Value::Array(variants) => variants.iter().any(is_iso_datetime_type), + _ => false, + } +} + /// A helper struct for nested schema generation. /// /// Used to store inner schema String type so that outer schema String type can be created. @@ -345,6 +487,7 @@ pub struct GenState { schemata_by_name: HashMap, not_eq: HashSet, use_chrono_dates: bool, + iso_datetime_fields: HashSet<(Name, String)>, } impl GenState { @@ -364,6 +507,7 @@ impl GenState { schemata_by_name, not_eq, use_chrono_dates: false, + iso_datetime_fields: HashSet::new(), }) } @@ -381,6 +525,16 @@ impl GenState { self } + pub fn with_iso_datetime_fields(mut self, fields: HashSet<(Name, String)>) -> Self { + self.iso_datetime_fields = fields; + self + } + + pub(crate) fn is_iso_datetime_field(&self, record: &Name, field: &str) -> bool { + self.iso_datetime_fields + .contains(&(record.clone(), field.to_string())) + } + pub(crate) fn get_schema(&self, name: &Name) -> Option<&Schema> { self.schemata_by_name.get(name) } @@ -562,6 +716,7 @@ pub struct Templater { pub nullable: bool, pub use_avro_rs_unions: bool, pub use_chrono_dates: bool, + pub use_iso_datetime: bool, pub derive_builders: bool, pub derive_schemas: bool, pub impl_schemas: bool, @@ -586,6 +741,7 @@ impl Templater { nullable: false, use_avro_rs_unions: false, use_chrono_dates: false, + use_iso_datetime: false, derive_builders: false, derive_schemas: false, impl_schemas: false, @@ -736,6 +892,31 @@ impl Templater { schema }; + if gen_state.is_iso_datetime_field(full_name, name) { + let iso_type = iso_datetime_rust_type(schema); + if let Some(type_str) = iso_type { + f.push(name_std.clone()); + t.insert(name_std.clone(), type_str); + if let Some(default) = default { + // For a literal RFC3339 string default we need + // to emit a chrono parse call rather than the + // String default produced by `parse_default`. + // Null and empty-array defaults still flow + // through the existing helpers. + let default_str = match (schema, default) { + (Schema::String, Value::String(s)) => format!( + "chrono::DateTime::parse_from_rfc3339(\"{s}\").unwrap().with_timezone(&chrono::Utc)" + ), + _ => self.parse_default(schema, gen_state, default)?, + }; + d.insert(name_std.clone(), default_str); + } + continue; + } + // Fall through to the regular match if the schema shape is + // unexpected (e.g. nested unions). The flag is best-effort. + } + match schema { Schema::Ref { .. } => unreachable!("already resolved above"), Schema::Boolean => { diff --git a/tests/generation.rs b/tests/generation.rs index bdf1785..f080ece 100644 --- a/tests/generation.rs +++ b/tests/generation.rs @@ -241,6 +241,29 @@ fn gen_chrono_logical_dates() { ); } +#[test] +fn gen_iso_datetime() { + validate_generation( + "iso_datetime", + Generator::builder().use_iso_datetime(true).build().unwrap(), + ); +} + +#[test] +fn gen_iso_datetime_noop_without_flag() { + // Same schema, flag off → regular `String` output. Guards against + // accidental behaviour change for users not opting in. + validate_generation("iso_datetime_plain", Generator::new().unwrap()); +} + +#[test] +fn gen_order_created() { + validate_generation( + "order_created", + Generator::builder().use_iso_datetime(true).build().unwrap(), + ); +} + #[test] fn gen_record() { validate_generation("record", Generator::new().unwrap()); diff --git a/tests/schemas/iso_datetime.avsc b/tests/schemas/iso_datetime.avsc new file mode 100644 index 0000000..fd710ea --- /dev/null +++ b/tests/schemas/iso_datetime.avsc @@ -0,0 +1,29 @@ +{ + "type": "record", + "name": "Event", + "namespace": "demo", + "doc": "Event with iso_datetime fields", + "fields": [ + { + "name": "created_at", + "type": {"type": "string", "logicalType": "iso_datetime"} + }, + { + "name": "updated_at", + "type": ["null", {"type": "string", "logicalType": "iso_datetime"}], + "default": null + }, + { + "name": "history", + "type": { + "type": "array", + "items": {"type": "string", "logicalType": "iso_datetime"} + }, + "default": [] + }, + { + "name": "label", + "type": "string" + } + ] +} diff --git a/tests/schemas/iso_datetime.rs b/tests/schemas/iso_datetime.rs new file mode 100644 index 0000000..89687d8 --- /dev/null +++ b/tests/schemas/iso_datetime.rs @@ -0,0 +1,17 @@ + +/// Event with iso_datetime fields +#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize)] +pub struct Event { + pub created_at: chrono::DateTime, + #[serde(default = "default_event_updated_at")] + pub updated_at: Option>, + #[serde(default = "default_event_history")] + pub history: Vec>, + pub label: String, +} + +#[inline(always)] +fn default_event_updated_at() -> Option> { None } + +#[inline(always)] +fn default_event_history() -> Vec> { vec![] } diff --git a/tests/schemas/iso_datetime_plain.avsc b/tests/schemas/iso_datetime_plain.avsc new file mode 100644 index 0000000..fd710ea --- /dev/null +++ b/tests/schemas/iso_datetime_plain.avsc @@ -0,0 +1,29 @@ +{ + "type": "record", + "name": "Event", + "namespace": "demo", + "doc": "Event with iso_datetime fields", + "fields": [ + { + "name": "created_at", + "type": {"type": "string", "logicalType": "iso_datetime"} + }, + { + "name": "updated_at", + "type": ["null", {"type": "string", "logicalType": "iso_datetime"}], + "default": null + }, + { + "name": "history", + "type": { + "type": "array", + "items": {"type": "string", "logicalType": "iso_datetime"} + }, + "default": [] + }, + { + "name": "label", + "type": "string" + } + ] +} diff --git a/tests/schemas/iso_datetime_plain.rs b/tests/schemas/iso_datetime_plain.rs new file mode 100644 index 0000000..c2c33cd --- /dev/null +++ b/tests/schemas/iso_datetime_plain.rs @@ -0,0 +1,17 @@ + +/// Event with iso_datetime fields +#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize)] +pub struct Event { + pub created_at: String, + #[serde(default = "default_event_updated_at")] + pub updated_at: Option, + #[serde(default = "default_event_history")] + pub history: Vec, + pub label: String, +} + +#[inline(always)] +fn default_event_updated_at() -> Option { None } + +#[inline(always)] +fn default_event_history() -> Vec { vec![] } diff --git a/tests/schemas/mod.rs b/tests/schemas/mod.rs index 3759c76..a214e9a 100644 --- a/tests/schemas/mod.rs +++ b/tests/schemas/mod.rs @@ -2,6 +2,7 @@ #![allow(dead_code)] pub mod array_3d; +pub mod order_created; pub mod complex; pub mod decimals; pub mod enums; @@ -10,6 +11,8 @@ pub mod enums_multiline_doc; pub mod enums_sanitize; pub mod fixed; pub mod interop; +pub mod iso_datetime; +pub mod iso_datetime_plain; pub mod logical_dates; pub mod chrono_logical_dates; pub mod map_default; diff --git a/tests/schemas/order_created.avsc b/tests/schemas/order_created.avsc new file mode 100644 index 0000000..cf1809a --- /dev/null +++ b/tests/schemas/order_created.avsc @@ -0,0 +1,59 @@ +{ + "name": "OrderCreated", + "type": "record", + "doc": "An order has been placed", + "namespace": "com.example.orders", + "fields": [ + { + "doc": "Order ID", + "name": "order_id", + "type": {"type": "string", "logicalType": "uuid"} + }, + { + "doc": "Date and time the order was placed", + "name": "created_at", + "type": {"type": "string", "logicalType": "iso_datetime"} + }, + { + "doc": "Requested delivery date and time", + "name": "scheduled_at", + "type": {"type": "string", "logicalType": "iso_datetime"} + }, + { + "doc": "Customer who placed the order", + "name": "customer", + "type": { + "type": "record", + "name": "Customer", + "fields": [ + { + "doc": "Customer full name", + "name": "name", + "type": "string" + }, + { + "doc": "Customer email address", + "name": "email", + "type": ["null", "string"], + "default": null + } + ] + } + }, + { + "doc": "Current order status", + "name": "status", + "type": { + "type": "enum", + "name": "OrderStatus", + "symbols": ["PENDING", "CONFIRMED", "CANCELLED"] + } + }, + { + "doc": "Optional coupon code applied to the order", + "name": "coupon_code", + "type": ["null", "string"], + "default": null + } + ] +} diff --git a/tests/schemas/order_created.rs b/tests/schemas/order_created.rs new file mode 100644 index 0000000..700b6d6 --- /dev/null +++ b/tests/schemas/order_created.rs @@ -0,0 +1,43 @@ + +#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize)] +pub struct Customer { + /// Customer full name + pub name: String, + /// Customer email address + #[serde(default = "default_customer_email")] + pub email: Option, +} + +#[inline(always)] +fn default_customer_email() -> Option { None } + +#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Clone, Copy, serde::Deserialize, serde::Serialize)] +pub enum OrderStatus { + #[serde(rename = "PENDING")] + Pending, + #[serde(rename = "CONFIRMED")] + Confirmed, + #[serde(rename = "CANCELLED")] + Cancelled, +} + +/// An order has been placed +#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize)] +pub struct OrderCreated { + /// Order ID + pub order_id: apache_avro::Uuid, + /// Date and time the order was placed + pub created_at: chrono::DateTime, + /// Requested delivery date and time + pub scheduled_at: chrono::DateTime, + /// Customer who placed the order + pub customer: Customer, + /// Current order status + pub status: OrderStatus, + /// Optional coupon code applied to the order + #[serde(default = "default_ordercreated_coupon_code")] + pub coupon_code: Option, +} + +#[inline(always)] +fn default_ordercreated_coupon_code() -> Option { None } diff --git a/tests/serde.rs b/tests/serde.rs index 02fa332..466b1c0 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -3,8 +3,42 @@ mod schemas; use std::collections::HashMap; +use crate::schemas::iso_datetime::Event; use crate::schemas::multi_valued_union_with_avro_rs_unions::Contact; +#[test] +fn iso_datetime_roundtrip() { + use chrono::{TimeZone, Utc}; + + let original = Event { + created_at: Utc.with_ymd_and_hms(2024, 1, 2, 3, 4, 5).unwrap(), + updated_at: Some(Utc.with_ymd_and_hms(2024, 1, 2, 3, 4, 6).unwrap()), + history: vec![ + Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(), + Utc.with_ymd_and_hms(2024, 1, 1, 1, 0, 0).unwrap(), + ], + label: "open".into(), + }; + + // serde_json round-trip: confirms chrono's default RFC3339 serde impl. + let json = serde_json::to_string(&original).unwrap(); + let decoded: Event = serde_json::from_str(&json).unwrap(); + assert_eq!(original, decoded); + + // apache_avro round-trip: the wire format for the logical type is a + // plain `string`, and chrono reads/writes the same RFC3339 form. + let schema = + apache_avro::Schema::parse_str(include_str!("schemas/iso_datetime.avsc")).unwrap(); + let value = apache_avro::to_value(original.clone()).unwrap(); + let value = value.resolve(&schema).unwrap(); + let decoded: Event = apache_avro::from_value(&value).unwrap(); + assert_eq!(original, decoded); + + // Invalid datetime strings should fail deserialisation cleanly. + let bad_json = r#"{"created_at":"not-a-date","updated_at":null,"history":[],"label":"x"}"#; + assert!(serde_json::from_str::(bad_json).is_err()); +} + #[test] fn multi_valued_union_serde() { let expected = Contact {