diff --git a/Cargo.lock b/Cargo.lock index 171a391..23efdfa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -992,6 +992,12 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + [[package]] name = "hashbrown" version = "0.12.3" @@ -1510,6 +1516,7 @@ dependencies = [ "tokio", "toml 0.8.23", "tracing", + "trybuild", "utoipa", "validator", "wasm-bindgen", @@ -2709,6 +2716,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20f34339676cdcab560c9a82300c4c2581f68b9369aedf0fae86f2ff9565ff3e" +[[package]] +name = "target-triple" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ac9aa371f599d22256307c24a9d748c041e548cbf599f35d890f9d365361790" + [[package]] name = "telegram-webapp-sdk" version = "0.1.1" @@ -2765,6 +2778,15 @@ dependencies = [ "uuid", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + [[package]] name = "thiserror" version = "2.0.16" @@ -2911,10 +2933,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75129e1dc5000bfbaa9fee9d1b21f974f9fbad9daec557a521ee6e080825f6e8" dependencies = [ + "indexmap 2.11.1", "serde", "serde_spanned 1.0.0", "toml_datetime 0.7.0", "toml_parser", + "toml_writer", "winnow", ] @@ -3048,6 +3072,21 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "trybuild" +version = "1.0.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ded9fdb81f30a5708920310bfcd9ea7482ff9cba5f54601f7a19a877d5c2392" +dependencies = [ + "glob", + "serde", + "serde_derive", + "serde_json", + "target-triple", + "termcolor", + "toml", +] + [[package]] name = "typeid" version = "1.0.3" @@ -3349,6 +3388,15 @@ dependencies = [ "wasite", ] +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "windows-core" version = "0.61.2" diff --git a/Cargo.toml b/Cargo.toml index 84afc47..2a4e0cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,6 +76,7 @@ tokio = { version = "1", features = [ "net", "time", ], default-features = false } +trybuild = "1" toml = "0.8" diff --git a/masterror-derive/src/lib.rs b/masterror-derive/src/lib.rs index 2d00622..7a7c3e8 100644 --- a/masterror-derive/src/lib.rs +++ b/masterror-derive/src/lib.rs @@ -4,8 +4,9 @@ //! Derive macro implementing [`std::error::Error`] with `Display` formatting. //! //! The macro mirrors the essential functionality relied upon by `masterror` and -//! consumers of the crate: display strings with named or positional fields and -//! a configurable error source via `#[source]` field attributes. +//! consumers of the crate: display strings with named or positional fields, +//! `#[from]` conversions for wrapper types, and a configurable error source via +//! `#[source]` field attributes. use std::collections::BTreeSet; @@ -20,8 +21,6 @@ use syn::{ /// Derive [`std::error::Error`] and [`core::fmt::Display`] for structs and /// enums. /// -/// ```ignore -/// use masterror::Error; /// /// #[derive(Debug, Error)] /// #[error("{code}: {message}")] @@ -35,9 +34,9 @@ use syn::{ /// message: "boom" /// }; /// assert_eq!(err.to_string(), "500: boom"); -/// assert!(err.source().is_none()); +/// assert!(std::error::Error::source(&err).is_none()); /// ``` -#[proc_macro_derive(Error, attributes(error, source))] +#[proc_macro_derive(Error, attributes(error, source, from))] pub fn derive_error(input: TokenStream) -> TokenStream { match derive_error_impl(syn::parse_macro_input!(input as DeriveInput)) { Ok(tokens) => tokens.into(), @@ -51,6 +50,7 @@ fn derive_error_impl(input: DeriveInput) -> syn::Result { let display_impl; let error_impl; + let from_impls; match input.data { Data::Struct(data) => { @@ -58,11 +58,13 @@ fn derive_error_impl(input: DeriveInput) -> syn::Result { let display_attr = parse_display_attr(&input.attrs)?; display_impl = build_struct_display(&ident, &generics, &fields, &display_attr)?; error_impl = build_struct_error(&ident, &generics, &fields)?; + from_impls = build_struct_from_impl(&ident, &generics, &fields)?; } Data::Enum(data) => { let variants = parse_enum(&data)?; display_impl = build_enum_display(&ident, &generics, &variants)?; error_impl = build_enum_error(&ident, &generics, &variants)?; + from_impls = build_enum_from_impls(&ident, &generics, &variants)?; } Data::Union(_) => { return Err(syn::Error::new( @@ -75,6 +77,7 @@ fn derive_error_impl(input: DeriveInput) -> syn::Result { Ok(quote! { #display_impl #error_impl + #from_impls }) } @@ -89,7 +92,45 @@ enum FieldsStyle { struct FieldSpec { member: Member, ident: Option, - binding: Ident + binding: Ident, + ty: Type, + attrs: FieldAttributes +} + +#[derive(Clone, Default)] +struct FieldAttributes { + from: Option, + source: Option +} + +impl FieldAttributes { + fn mark_from(&mut self, span: Span) -> syn::Result<()> { + if self.from.is_some() { + return Err(syn::Error::new(span, "duplicate #[from] attribute")); + } + self.from = Some(span); + Ok(()) + } + + fn mark_source(&mut self, span: Span) -> syn::Result<()> { + if self.source.is_some() { + return Err(syn::Error::new(span, "duplicate #[source] attribute")); + } + self.source = Some(span); + Ok(()) + } + + fn has_source(&self) -> bool { + self.source.is_some() + } + + fn span_of_from_attribute(&self) -> Option { + self.from + } + + fn source_span(&self) -> Option { + self.source + } } #[derive(Clone, Copy)] @@ -116,6 +157,11 @@ struct VariantInfo { display: LitStr } +struct FromFieldInfo<'a> { + field: &'a FieldSpec, + span: Span +} + struct RewriteResult { literal: LitStr, positional_indices: BTreeSet @@ -129,7 +175,10 @@ fn parse_enum(data: &DataEnum) -> syn::Result> { let mut variants = Vec::with_capacity(data.variants.len()); for variant in &data.variants { let display = parse_display_attr(&variant.attrs)?; - let fields = parse_fields_internal(&variant.fields)?; + let mut fields = parse_fields_internal(&variant.fields)?; + if let Some(span) = parse_variant_from_attr(&variant.attrs)? { + apply_variant_from_attr(&mut fields, span, &variant.ident)?; + } variants.push(VariantInfo { ident: variant.ident.clone(), fields, @@ -150,16 +199,17 @@ fn parse_fields_internal(fields: &Fields) -> syn::Result { let mut specs = Vec::with_capacity(named.named.len()); let mut source = None; for (index, field) in named.named.iter().enumerate() { + let attrs = parse_field_attributes(field)?; let ident = field.ident.clone().ok_or_else(|| { syn::Error::new(field.span(), "named field missing identifier") })?; let member = Member::Named(ident.clone()); let binding = ident.clone(); - if has_source_attr(field)? { + if attrs.has_source() { let kind = detect_source_kind(&field.ty)?; if source.is_some() { return Err(syn::Error::new( - field.span(), + attrs.source_span().unwrap_or_else(|| field.span()), "only a single #[source] field is supported" )); } @@ -171,7 +221,9 @@ fn parse_fields_internal(fields: &Fields) -> syn::Result { specs.push(FieldSpec { member, ident: Some(ident), - binding + binding, + ty: field.ty.clone(), + attrs }); } Ok(ParsedFields { @@ -184,13 +236,14 @@ fn parse_fields_internal(fields: &Fields) -> syn::Result { let mut specs = Vec::with_capacity(unnamed.unnamed.len()); let mut source = None; for (index, field) in unnamed.unnamed.iter().enumerate() { + let attrs = parse_field_attributes(field)?; let member = Member::Unnamed(index.into()); let binding = format_ident!("__masterror_{index}"); - if has_source_attr(field)? { + if attrs.has_source() { let kind = detect_source_kind(&field.ty)?; if source.is_some() { return Err(syn::Error::new( - field.span(), + attrs.source_span().unwrap_or_else(|| field.span()), "only a single #[source] field is supported" )); } @@ -202,7 +255,9 @@ fn parse_fields_internal(fields: &Fields) -> syn::Result { specs.push(FieldSpec { member, ident: None, - binding + binding, + ty: field.ty.clone(), + attrs }); } Ok(ParsedFields { @@ -240,20 +295,101 @@ fn parse_display_attr(attrs: &[Attribute]) -> syn::Result { .ok_or_else(|| syn::Error::new(Span::call_site(), r#"missing #[error("...")] attribute"#)) } -fn has_source_attr(field: &Field) -> syn::Result { - let mut found = false; +fn parse_field_attributes(field: &Field) -> syn::Result { + let mut attrs = FieldAttributes::default(); for attr in &field.attrs { if attr.path().is_ident("source") { - if found { + ensure_path_only(attr, "source")?; + attrs.mark_source(attr.span())?; + } else if attr.path().is_ident("from") { + ensure_path_only(attr, "from")?; + attrs.mark_from(attr.span())?; + } + } + Ok(attrs) +} + +fn ensure_path_only(attr: &Attribute, name: &str) -> syn::Result<()> { + if !matches!(&attr.meta, Meta::Path(_)) { + return Err(syn::Error::new( + attr.span(), + format!("#[{name}] attribute does not accept arguments") + )); + } + Ok(()) +} + +fn parse_variant_from_attr(attrs: &[Attribute]) -> syn::Result> { + let mut span = None; + for attr in attrs.iter().filter(|attr| attr.path().is_ident("from")) { + ensure_path_only(attr, "from")?; + if span.is_some() { + return Err(syn::Error::new(attr.span(), "duplicate #[from] attribute")); + } + span = Some(attr.span()); + } + Ok(span) +} + +fn apply_variant_from_attr( + fields: &mut ParsedFields, + span: Span, + variant_ident: &Ident +) -> syn::Result<()> { + if fields.fields.is_empty() { + return Err(syn::Error::new( + span, + format!( + "variant `{variant_ident}` marked with #[from] must contain exactly one field" + ) + )); + } + if fields.fields.len() != 1 { + return Err(syn::Error::new( + span, + format!( + "variant `{variant_ident}` marked with #[from] must contain exactly one field" + ) + )); + } + let field = fields + .fields + .get_mut(0) + .ok_or_else(|| syn::Error::new(span, "invalid #[from] field index"))?; + field.attrs.mark_from(span) +} + +fn find_from_field<'a>( + fields: &'a ParsedFields, + context: &str +) -> syn::Result>> { + let mut info = None; + for field in &fields.fields { + if let Some(span) = field.attrs.span_of_from_attribute() { + if info.is_some() { return Err(syn::Error::new( - attr.span(), - "duplicate #[source] attribute" + span, + format!( + "multiple #[from] attributes in {context}; only one field may use #[from]" + ) )); } - found = true; + info = Some(FromFieldInfo { + field, + span + }); } } - Ok(found) + let Some(info) = info else { + return Ok(None); + }; + if fields.fields.len() != 1 { + return Err(syn::Error::new( + info.span, + format!("using #[from] in {context} requires exactly one field") + )); + } + Ok(Some(info)) } fn detect_source_kind(ty: &Type) -> syn::Result { @@ -403,6 +539,52 @@ fn build_struct_error( }) } +fn build_struct_from_impl( + ident: &Ident, + generics: &Generics, + fields: &ParsedFields +) -> syn::Result { + let context = format!("struct `{ident}`"); + let Some(from_info) = find_from_field(fields, &context)? else { + return Ok(TokenStream2::new()); + }; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let field_ty = &from_info.field.ty; + let arg_ident = format_ident!("__masterror_from_value"); + + let construct = match fields.style { + FieldsStyle::Named => { + let field_ident = from_info.field.ident.clone().ok_or_else(|| { + syn::Error::new(from_info.span, "named field missing identifier") + })?; + quote! { Self { #field_ident: #arg_ident } } + } + FieldsStyle::Unnamed => { + if fields.fields.len() != 1 { + return Err(syn::Error::new( + from_info.span, + format!("using #[from] in {context} requires exactly one field") + )); + } + quote! { Self(#arg_ident) } + } + FieldsStyle::Unit => { + return Err(syn::Error::new( + from_info.span, + format!("using #[from] in {context} requires at least one field") + )); + } + }; + + Ok(quote! { + impl #impl_generics ::core::convert::From<#field_ty> for #ident #ty_generics #where_clause { + fn from(#arg_ident: #field_ty) -> Self { + #construct + } + } + }) +} + fn build_enum_display( ident: &Ident, generics: &Generics, @@ -598,6 +780,58 @@ fn build_enum_error( }) } +fn build_enum_from_impls( + ident: &Ident, + generics: &Generics, + variants: &[VariantInfo] +) -> syn::Result { + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let mut impls = Vec::new(); + + for variant in variants { + let context = format!("variant `{}`", variant.ident); + if let Some(from_info) = find_from_field(&variant.fields, &context)? { + let field_ty = &from_info.field.ty; + let variant_ident = &variant.ident; + let arg_ident = format_ident!("__masterror_from_value"); + + let body = match variant.fields.style { + FieldsStyle::Named => { + let field_ident = from_info.field.ident.clone().ok_or_else(|| { + syn::Error::new(from_info.span, "named field missing identifier") + })?; + quote! { Self::#variant_ident { #field_ident: #arg_ident } } + } + FieldsStyle::Unnamed => { + if variant.fields.fields.len() != 1 { + return Err(syn::Error::new( + from_info.span, + format!("using #[from] in {context} requires exactly one field") + )); + } + quote! { Self::#variant_ident(#arg_ident) } + } + FieldsStyle::Unit => { + return Err(syn::Error::new( + from_info.span, + format!("{context} cannot be unit-like when using #[from]") + )); + } + }; + + impls.push(quote! { + impl #impl_generics ::core::convert::From<#field_ty> for #ident #ty_generics #where_clause { + fn from(#arg_ident: #field_ty) -> Self { + #body + } + } + }); + } + } + + Ok(quote! { #(#impls)* }) +} + fn rewrite_format_string(original: &LitStr, field_count: usize) -> syn::Result { let src = original.value(); let mut result = String::with_capacity(src.len()); diff --git a/tests/error_derive.rs b/tests/error_derive.rs index b1171d0..3df72ab 100644 --- a/tests/error_derive.rs +++ b/tests/error_derive.rs @@ -33,6 +33,48 @@ enum EnumError { Pair(String, #[source] LeafError) } +#[derive(Debug, Error)] +#[error("primary failure")] +struct PrimaryError; + +#[derive(Debug, Error)] +#[error("secondary failure")] +struct SecondaryError; + +#[derive(Debug, Error)] +#[error("tuple wrapper -> {0}")] +struct TupleWrapper( + #[from] + #[source] + LeafError +); + +#[derive(Debug, Error)] +#[error("message: {message}")] +struct MessageWrapper { + #[from] + message: String +} + +#[derive(Debug, Error)] +enum MixedFromError { + #[error("tuple variant {0}")] + Tuple( + #[from] + #[source] + LeafError + ), + #[error("variant attr {0}")] + #[from] + VariantAttr(#[source] PrimaryError), + #[error("named variant {source:?}")] + Named { + #[from] + #[source] + source: SecondaryError + } +} + #[test] fn named_struct_display_and_source() { let err = NamedError { @@ -77,3 +119,47 @@ fn tuple_variant_with_source() { assert!(err.to_string().starts_with("left")); assert_eq!(StdError::source(&err).unwrap().to_string(), "leaf failure"); } + +#[test] +fn tuple_struct_from_wraps_source() { + let err = TupleWrapper::from(LeafError); + assert_eq!(err.to_string(), "tuple wrapper -> leaf failure"); + let source = StdError::source(&err).expect("source present"); + assert_eq!(source.to_string(), "leaf failure"); +} + +#[test] +fn named_struct_from_without_source() { + let err = MessageWrapper::from(String::from("payload")); + assert_eq!(err.to_string(), "message: payload"); + assert!(StdError::source(&err).is_none()); +} + +#[test] +fn enum_from_variants_generate_impls() { + let tuple = MixedFromError::from(LeafError); + assert!(matches!(&tuple, MixedFromError::Tuple(_))); + assert_eq!( + StdError::source(&tuple).unwrap().to_string(), + "leaf failure" + ); + + let variant_attr = MixedFromError::from(PrimaryError); + assert!(matches!(&variant_attr, MixedFromError::VariantAttr(_))); + assert_eq!( + StdError::source(&variant_attr).unwrap().to_string(), + "primary failure" + ); + + let named = MixedFromError::from(SecondaryError); + assert!(matches!( + &named, + MixedFromError::Named { + source: SecondaryError + } + )); + assert_eq!( + StdError::source(&named).unwrap().to_string(), + "secondary failure" + ); +} diff --git a/tests/error_derive_from_trybuild.rs b/tests/error_derive_from_trybuild.rs new file mode 100644 index 0000000..a878025 --- /dev/null +++ b/tests/error_derive_from_trybuild.rs @@ -0,0 +1,7 @@ +use trybuild::TestCases; + +#[test] +fn from_attribute_compile_failures() { + let t = TestCases::new(); + t.compile_fail("tests/ui/from/*.rs"); +} diff --git a/tests/ui/from/struct_multiple_fields.rs b/tests/ui/from/struct_multiple_fields.rs new file mode 100644 index 0000000..a545ce7 --- /dev/null +++ b/tests/ui/from/struct_multiple_fields.rs @@ -0,0 +1,15 @@ +use masterror::Error; + +#[derive(Debug, Error)] +#[error("{left:?} - {right:?}")] +struct BadStruct { + #[from] + left: DummyError, + right: DummyError, +} + +#[derive(Debug, Error)] +#[error("dummy")] +struct DummyError; + +fn main() {} diff --git a/tests/ui/from/struct_multiple_fields.stderr b/tests/ui/from/struct_multiple_fields.stderr new file mode 100644 index 0000000..c529748 --- /dev/null +++ b/tests/ui/from/struct_multiple_fields.stderr @@ -0,0 +1,5 @@ +error: using #[from] in struct `BadStruct` requires exactly one field + --> tests/ui/from/struct_multiple_fields.rs:6:5 + | +6 | #[from] + | ^ diff --git a/tests/ui/from/variant_multiple_fields.rs b/tests/ui/from/variant_multiple_fields.rs new file mode 100644 index 0000000..3d35295 --- /dev/null +++ b/tests/ui/from/variant_multiple_fields.rs @@ -0,0 +1,14 @@ +use masterror::Error; + +#[derive(Debug, Error)] +enum BadEnum { + #[error("{0} - {1}")] + #[from] + Two(#[source] DummyError, DummyError), +} + +#[derive(Debug, Error)] +#[error("dummy")] +struct DummyError; + +fn main() {} diff --git a/tests/ui/from/variant_multiple_fields.stderr b/tests/ui/from/variant_multiple_fields.stderr new file mode 100644 index 0000000..0f07dd2 --- /dev/null +++ b/tests/ui/from/variant_multiple_fields.stderr @@ -0,0 +1,5 @@ +error: variant `Two` marked with #[from] must contain exactly one field + --> tests/ui/from/variant_multiple_fields.rs:6:5 + | +6 | #[from] + | ^