Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ fn main() {
}

fn run() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rustc-check-cfg=cfg(error_generic_member_access)");
println!("cargo:rerun-if-changed=Cargo.toml");
println!("cargo:rerun-if-changed=README.template.md");
println!("cargo:rerun-if-changed=build/readme.rs");
Expand Down
130 changes: 130 additions & 0 deletions masterror-derive/src/error_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ pub fn expand(input: &ErrorInput) -> Result<TokenStream, Error> {

fn expand_struct(input: &ErrorInput, data: &StructData) -> Result<TokenStream, Error> {
let body = struct_source_body(&data.fields, &data.display);
let backtrace_method = struct_backtrace_method(&data.fields);
let has_backtrace = backtrace_method.is_some();
let backtrace_method = backtrace_method.unwrap_or_default();
let provide_method = if has_backtrace {
provide_method_tokens()
} else {
TokenStream::new()
};

let ident = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
Expand All @@ -24,6 +32,8 @@ fn expand_struct(input: &ErrorInput, data: &StructData) -> Result<TokenStream, E
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
#body
}
#backtrace_method
#provide_method
}
})
}
Expand All @@ -34,6 +44,15 @@ fn expand_enum(input: &ErrorInput, variants: &[VariantData]) -> Result<TokenStre
arms.push(variant_source_arm(variant));
}

let backtrace_method = enum_backtrace_method(variants);
let has_backtrace = backtrace_method.is_some();
let backtrace_method = backtrace_method.unwrap_or_default();
let provide_method = if has_backtrace {
provide_method_tokens()
} else {
TokenStream::new()
};

let ident = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

Expand All @@ -44,6 +63,8 @@ fn expand_enum(input: &ErrorInput, variants: &[VariantData]) -> Result<TokenStre
#(#arms),*
}
}
#backtrace_method
#provide_method
}
})
}
Expand Down Expand Up @@ -179,6 +200,115 @@ fn field_source_expr(
}
}

fn struct_backtrace_method(fields: &Fields) -> Option<TokenStream> {
let field = fields.backtrace_field()?;
let member = &field.member;
let body = field_backtrace_expr(quote!(self.#member), quote!(&self.#member), &field.ty);
Some(quote! {
#[cfg(error_generic_member_access)]
fn backtrace(&self) -> Option<&std::backtrace::Backtrace> {
#body
}
})
}

fn enum_backtrace_method(variants: &[VariantData]) -> Option<TokenStream> {
let mut has_backtrace = false;
let mut arms = Vec::new();
for variant in variants {
if variant.fields.backtrace_field().is_some() {
has_backtrace = true;
}
arms.push(variant_backtrace_arm(variant));
}

if has_backtrace {
Some(quote! {
#[cfg(error_generic_member_access)]
fn backtrace(&self) -> Option<&std::backtrace::Backtrace> {
match self {
#(#arms),*
}
}
})
} else {
None
}
}

fn variant_backtrace_arm(variant: &VariantData) -> TokenStream {
let variant_ident = &variant.ident;
let backtrace_field = variant.fields.backtrace_field();

match (&variant.fields, backtrace_field) {
(Fields::Unit, _) => quote! { Self::#variant_ident => None },
(Fields::Named(fields), Some(field)) => {
let field_ident = field.ident.clone().expect("named field");
let binding = binding_ident(field);
let pattern = if fields.len() == 1 {
quote!(Self::#variant_ident { #field_ident: #binding })
} else {
quote!(Self::#variant_ident { #field_ident: #binding, .. })
};
let body = field_backtrace_expr(quote!(#binding), quote!(#binding), &field.ty);
quote! {
#pattern => { #body }
}
}
(Fields::Unnamed(fields), Some(field)) => {
let index = field.index;
let binding = binding_ident(field);
let pattern_elements: Vec<_> = fields
.iter()
.enumerate()
.map(|(idx, _)| {
if idx == index {
quote!(#binding)
} else {
quote!(_)
}
})
.collect();
let body = field_backtrace_expr(quote!(#binding), quote!(#binding), &field.ty);
quote! {
Self::#variant_ident(#(#pattern_elements),*) => { #body }
}
}
(Fields::Named(_), None) => quote! { Self::#variant_ident { .. } => None },
(Fields::Unnamed(fields), None) => {
if fields.is_empty() {
quote! { Self::#variant_ident() => None }
} else {
let placeholders = vec![quote!(_); fields.len()];
quote! { Self::#variant_ident(#(#placeholders),*) => None }
}
}
}
}

fn field_backtrace_expr(
owned_expr: TokenStream,
referenced_expr: TokenStream,
ty: &syn::Type
) -> TokenStream {
if is_option_type(ty) {
quote! { #owned_expr.as_ref() }
} else {
quote! { Some(#referenced_expr) }
}
}

fn provide_method_tokens() -> TokenStream {
quote! {
#[cfg(error_generic_member_access)]
fn provide<'a>(&'a self, request: &mut core::error::Request<'a>) {
if let Some(backtrace) = std::error::Error::backtrace(self) {
request.provide_ref::<std::backtrace::Backtrace>(backtrace);
}
}
}
}

fn binding_ident(field: &Field) -> Ident {
field
.ident
Expand Down
85 changes: 84 additions & 1 deletion masterror-derive/src/input.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use proc_macro2::Span;
use syn::{
Attribute, Data, DataEnum, DataStruct, DeriveInput, Error, Field as SynField,
Fields as SynFields, Ident, LitStr, spanned::Spanned
Fields as SynFields, GenericArgument, Ident, LitStr, spanned::Spanned
};

use crate::template_support::{DisplayTemplate, TemplateIdentifierSpec, parse_display_template};
Expand Down Expand Up @@ -94,6 +94,10 @@ impl Fields {
pub fn first_from_field(&self) -> Option<&Field> {
self.iter().find(|field| field.attrs.from.is_some())
}

pub fn backtrace_field(&self) -> Option<&Field> {
self.iter().find(|field| field.attrs.backtrace.is_some())
}
}

pub enum FieldIter<'a> {
Expand Down Expand Up @@ -250,6 +254,7 @@ fn parse_struct(
let fields = Fields::from_syn(&data.fields, errors);

validate_from_usage(&fields, &display, errors);
validate_backtrace_usage(&fields, errors);
validate_transparent(&fields, &display, errors, None);

Ok(ErrorData::Struct(Box::new(StructData {
Expand Down Expand Up @@ -295,6 +300,7 @@ fn parse_variant(variant: syn::Variant, errors: &mut Vec<Error>) -> Result<Varia
let fields = Fields::from_syn(&variant.fields, errors);

validate_from_usage(&fields, &display, errors);
validate_backtrace_usage(&fields, errors);
validate_transparent(&fields, &display, errors, Some(&variant));

Ok(VariantData {
Expand Down Expand Up @@ -433,6 +439,50 @@ fn validate_from_usage(fields: &Fields, display: &DisplaySpec, errors: &mut Vec<
}
}

fn validate_backtrace_usage(fields: &Fields, errors: &mut Vec<Error>) {
let backtrace_fields: Vec<_> = fields
.iter()
.filter(|field| field.attrs.backtrace.is_some())
.collect();

for field in &backtrace_fields {
validate_backtrace_field_type(field, errors);
}

if backtrace_fields.len() <= 1 {
return;
}

for field in backtrace_fields.iter().skip(1) {
if let Some(attr) = &field.attrs.backtrace {
errors.push(Error::new_spanned(
attr,
"multiple #[backtrace] fields are not supported"
));
}
}
}

fn validate_backtrace_field_type(field: &Field, errors: &mut Vec<Error>) {
let Some(attr) = &field.attrs.backtrace else {
return;
};

let ty = &field.ty;
if is_option_type(ty) {
if option_inner_type(ty).is_some_and(is_backtrace_type) {
return;
}
} else if is_backtrace_type(ty) {
return;
}

errors.push(Error::new_spanned(
attr,
"fields with #[backtrace] must be std::backtrace::Backtrace or Option<std::backtrace::Backtrace>"
));
}

fn validate_transparent(
fields: &Fields,
display: &DisplaySpec,
Expand Down Expand Up @@ -493,6 +543,39 @@ pub fn is_option_type(ty: &syn::Type) -> bool {
false
}

fn option_inner_type(ty: &syn::Type) -> Option<&syn::Type> {
let syn::Type::Path(path) = ty else {
return None;
};
if path.qself.is_some() {
return None;
}
let last = path.path.segments.last()?;
if last.ident != "Option" {
return None;
}
let syn::PathArguments::AngleBracketed(arguments) = &last.arguments else {
return None;
};
arguments.args.iter().find_map(|argument| match argument {
GenericArgument::Type(inner) => Some(inner),
_ => None
})
}

fn is_backtrace_type(ty: &syn::Type) -> bool {
let syn::Type::Path(path) = ty else {
return false;
};
if path.qself.is_some() {
return false;
}
let Some(last) = path.path.segments.last() else {
return false;
};
last.ident == "Backtrace" && matches!(last.arguments, syn::PathArguments::None)
}

pub fn placeholder_error(span: Span, identifier: &TemplateIdentifierSpec) -> Error {
match identifier {
TemplateIdentifierSpec::Named(name) => {
Expand Down
Loading
Loading