From 98319acf35572c6f51ba52d9e7376a070d1e4ee5 Mon Sep 17 00:00:00 2001 From: SimonThormeyer Date: Thu, 10 Apr 2025 18:12:43 +0200 Subject: [PATCH 1/3] feat: implement `#[call_generic]` --- macros/Cargo.toml | 2 +- macros/src/call_generic.rs | 114 +++++++++++++++++++++++++++++++++++++ macros/src/lib.rs | 3 + 3 files changed, 118 insertions(+), 1 deletion(-) create mode 100644 macros/src/call_generic.rs diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 756dbf4..333402f 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -21,7 +21,7 @@ development = ["async-std", "async-trait"] [dependencies] proc-macro2 = "1.0" quote = "1.0" -syn = { version = "2.0", features = ["visit-mut", "full"] } +syn = { version = "2.0", features = ["visit", "visit-mut", "full"] } [dev-dependencies] async-std = { version = "1.0", features = ["attributes"] } diff --git a/macros/src/call_generic.rs b/macros/src/call_generic.rs new file mode 100644 index 0000000..9a95427 --- /dev/null +++ b/macros/src/call_generic.rs @@ -0,0 +1,114 @@ +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote}; +use syn::{ + parse_quote, + spanned::Spanned, + visit::{self, Visit}, + visit_mut::{self, VisitMut}, + Attribute, Error, Expr, ExprCall, ExprMethodCall, +}; + +const ATTRIBUTE_NAME: &str = "call_generic"; + +pub(super) fn parse_quote_call_generic(input: TokenStream) -> TokenStream { + let mut syntax_tree = syn::parse::(input.into()).unwrap(); + GenericCallVisitor.visit_file_mut(&mut syntax_tree); + let mut visitor = OrphanAttributeVisitor::default(); + visitor.visit_file(&syntax_tree); + if let Some(error_span) = visitor.error_span { + return Error::new( + error_span, + format!("{ATTRIBUTE_NAME} must be used on a method call or a function call."), + ) + .into_compile_error(); + } + quote! {#syntax_tree} +} + +/// Finds an orphan attribute that hasn't been handled by [GenericCallVisitor]. +#[derive(Default)] +struct OrphanAttributeVisitor { + error_span: Option, +} + +impl Visit<'_> for OrphanAttributeVisitor { + fn visit_attribute(&mut self, node: &Attribute) { + if node + .path() + .segments + .last() + .map(|path| path.ident.to_string()) + == Some(ATTRIBUTE_NAME.to_string()) + { + self.error_span = Some(node.span()); + } + visit::visit_attribute(self, node); + } +} + +/// Finds and replaces a method call or function call annotated with `#[call_generic]`. +/// This enables `#[call_generic]` to be used as a shorthand for +/// ```rust,ignore +/// if _sync { +/// do_stuff(); +/// } else { +/// do_stuff_async().await; +/// } +/// ``` +struct GenericCallVisitor; + +impl VisitMut for GenericCallVisitor { + fn visit_expr_mut(&mut self, node: &mut Expr) { + let async_call = match node { + Expr::MethodCall(expr) => construct_async_method_call(expr), + Expr::Call(expr) => construct_async_function_call(expr), + _ => None, + }; + + if let Some(async_call) = async_call { + let original_call = node.clone(); + *node = parse_quote! { + if _async { + #async_call.await + } else { + #original_call + } + }; + return; + } + + // Delegate to the default impl to visit nested expressions. + visit_mut::visit_expr_mut(self, node); + } +} + +fn construct_async_method_call(expr: &mut ExprMethodCall) -> Option { + find_and_remove_generic_call_attr(&mut expr.attrs)?; + let mut async_expr = expr.clone(); + async_expr.method = format_ident!("{}_async", async_expr.method); + Some(Expr::MethodCall(async_expr)) +} + +fn construct_async_function_call(expr: &mut ExprCall) -> Option { + find_and_remove_generic_call_attr(&mut expr.attrs)?; + let mut async_expr = expr.clone(); + let mut func = *async_expr.func; + let Expr::Path(ref mut path_expr) = &mut func else { + return None; + }; + let last_segment = path_expr.path.segments.last_mut()?; + last_segment.ident = format_ident!("{}_async", last_segment.ident); + async_expr.func = Box::new(func); + Some(Expr::Call(async_expr)) +} + +fn find_and_remove_generic_call_attr(attributes: &mut Vec) -> Option<()> { + let length_before_removal = attributes.len(); + attributes.retain(|attr| { + let Some(last_segment) = attr.path().segments.last() else { + return true; + }; + last_segment.ident != ATTRIBUTE_NAME + }); + (length_before_removal > attributes.len()).then_some(()) +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 305f1fd..51e77dc 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,6 +1,7 @@ #![deny(warnings)] #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg, doc_cfg_hide))] +use call_generic::parse_quote_call_generic; use proc_macro::{TokenStream, TokenTree}; use proc_macro2::{Ident, Span, TokenStream as TokenStream2, TokenTree as TokenTree2}; use quote::quote; @@ -11,6 +12,7 @@ use syn::{ use crate::desugar_if_async::DesugarIfAsync; +mod call_generic; mod desugar_if_async; fn convert_sync_async( @@ -54,6 +56,7 @@ fn convert_sync_async( tokens }; + let tokens = parse_quote_call_generic(tokens); DesugarIfAsync { is_async }.desugar_if_async(tokens) } From eeddf4a3d2b97caa3d9860dbb8d23d1cfdfa09a0 Mon Sep 17 00:00:00 2001 From: SimonThormeyer Date: Mon, 14 Apr 2025 16:06:26 +0200 Subject: [PATCH 2/3] test: add tests for `#[call_generic]` --- tests/src/lib.rs | 2 + .../tests/fail/generic-call-invalid-usage.rs | 9 ++++ .../fail/generic-call-invalid-usage.stderr | 5 ++ tests/src/tests/pass/generic-call.rs | 47 +++++++++++++++++++ 4 files changed, 63 insertions(+) create mode 100644 tests/src/tests/fail/generic-call-invalid-usage.rs create mode 100644 tests/src/tests/fail/generic-call-invalid-usage.stderr create mode 100644 tests/src/tests/pass/generic-call.rs diff --git a/tests/src/lib.rs b/tests/src/lib.rs index bd0c596..bb90063 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -2,10 +2,12 @@ fn tests() { let t = trybuild::TestCases::new(); t.pass("src/tests/pass/fun-with-types.rs"); + t.pass("src/tests/pass/generic-call.rs"); t.pass("src/tests/pass/generic-fn.rs"); t.pass("src/tests/pass/generic-fn-with-visibility.rs"); t.pass("src/tests/pass/struct-method-generic.rs"); + t.compile_fail("src/tests/fail/generic-call-invalid-usage.rs"); t.compile_fail("src/tests/fail/misuse-of-underscore-async.rs"); t.compile_fail("src/tests/fail/no-async-fn.rs"); t.compile_fail("src/tests/fail/no-impl.rs"); diff --git a/tests/src/tests/fail/generic-call-invalid-usage.rs b/tests/src/tests/fail/generic-call-invalid-usage.rs new file mode 100644 index 0000000..388f9a9 --- /dev/null +++ b/tests/src/tests/fail/generic-call-invalid-usage.rs @@ -0,0 +1,9 @@ +use async_generic::async_generic; + +#[async_generic] +fn do_stuff() -> String { + #[call_generic] + "not a function call or method call" +} + +fn main() {} diff --git a/tests/src/tests/fail/generic-call-invalid-usage.stderr b/tests/src/tests/fail/generic-call-invalid-usage.stderr new file mode 100644 index 0000000..5e7e350 --- /dev/null +++ b/tests/src/tests/fail/generic-call-invalid-usage.stderr @@ -0,0 +1,5 @@ +error: call_generic must be used on a method call or a function call. + --> src/tests/fail/generic-call-invalid-usage.rs:5:5 + | +5 | #[call_generic] + | ^ diff --git a/tests/src/tests/pass/generic-call.rs b/tests/src/tests/pass/generic-call.rs new file mode 100644 index 0000000..6b27b36 --- /dev/null +++ b/tests/src/tests/pass/generic-call.rs @@ -0,0 +1,47 @@ +use async_generic::async_generic; + +#[async_generic] +fn do_stuff() -> String { + #[call_generic] + do_nested_stuff() +} + +#[async_generic] +fn do_nested_stuff() -> String { + if _async { + my_async_nested_stuff().await + } else { + "not async".to_owned() + } +} + +async fn my_async_nested_stuff() -> String { + "async".to_owned() +} + +struct Do; + +impl Do { + #[async_generic] + fn stuff(&self) -> String { + #[call_generic] + self.nested_stuff() + } + + #[async_generic] + fn nested_stuff(&self) -> String { + if _async { + my_async_nested_stuff().await + } else { + "not async".to_owned() + } + } +} + +#[async_std::main] +async fn main() { + println!("sync => {}", do_stuff()); + println!("async => {}", do_stuff_async().await); + println!("sync method => {}", Do.stuff()); + println!("async method => {}", Do.stuff_async().await); +} From b1acf66a1cc9375f713d907292795fad7f0fc06c Mon Sep 17 00:00:00 2001 From: SimonThormeyer Date: Thu, 17 Apr 2025 15:26:46 +0200 Subject: [PATCH 3/3] docs: add `#[call_generic]` examples --- README.md | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/README.md b/README.md index 6fe67b6..b53ebd2 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,58 @@ async fn main() { } ``` +Reduce boilerplate by annotating nested generic calls with `#[call_generic]`: + +```rust +use async_generic::async_generic; + +#[async_generic] +fn do_stuff() -> String { + #[call_generic] + do_nested_stuff() +} + +#[async_generic] +fn do_nested_stuff() -> String { + if _async { + my_async_nested_stuff().await + } else { + "not async".to_owned() + } +} + +async fn my_async_nested_stuff() -> String { + "async".to_owned() +} + +struct Do; + +impl Do { + #[async_generic] + fn stuff(&self) -> String { + #[call_generic] + self.nested_stuff() + } + + #[async_generic] + fn nested_stuff(&self) -> String { + if _async { + my_async_nested_stuff().await + } else { + "not async".to_owned() + } + } +} + +#[async_std::main] +async fn main() { + println!("sync => {}", do_stuff()); + println!("async => {}", do_stuff_async().await); + println!("sync method => {}", Do.stuff()); + println!("async method => {}", Do.stuff_async().await); +} +``` + ## Why not use `maybe-async`? This crate is loosely derived from the excellent work of the [`maybe-async`](https://crates.io/crates/maybe-async) crate, but is intended to solve a subtly different problem.