From a82d4e6e2aaa9be0ea446f32110efed408ed2955 Mon Sep 17 00:00:00 2001 From: Luis Morgenstern Date: Fri, 7 Nov 2025 17:33:15 +0100 Subject: [PATCH] Allow async-specific generic parameters and return type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Allow `async_signature` to have: - Return type: `async_signature(...) -> ReturnType` - Generic parameters: `async_signature(...) where U: TraitB` If omitted, the synchronous implementation’s generic parameters and return type are used. --- macros/src/lib.rs | 125 ++++++++++++++----------- tests/src/tests/pass/fun-with-types.rs | 65 +++++++++++++ 2 files changed, 134 insertions(+), 56 deletions(-) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 305f1fd..f637441 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,10 +1,11 @@ #![deny(warnings)] #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg, doc_cfg_hide))] -use proc_macro::{TokenStream, TokenTree}; -use proc_macro2::{Ident, Span, TokenStream as TokenStream2, TokenTree as TokenTree2}; +use proc_macro::TokenStream; +use proc_macro2::{Ident, Span, TokenStream as TokenStream2}; use quote::quote; use syn::{ + parenthesized, parse::{Parse, ParseStream, Result}, parse_macro_input, Attribute, Error, ItemFn, Token, }; @@ -16,7 +17,7 @@ mod desugar_if_async; fn convert_sync_async( input: &mut Item, is_async: bool, - alt_sig: Option, + async_signature: Option, ) -> TokenStream2 { let item = &mut input.0; @@ -25,68 +26,29 @@ fn convert_sync_async( item.sig.ident = Ident::new(&format!("{}_async", item.sig.ident), Span::call_site()); } - let tokens = quote!(#item); - - let tokens = if let Some(alt_sig) = alt_sig { - let mut found_fn = false; - let mut found_args = false; + if let Some(async_signature) = async_signature { + item.sig.inputs = async_signature.inputs; - let old_tokens = tokens.into_iter().map(|token| match &token { - TokenTree2::Ident(i) => { - found_fn = found_fn || &i.to_string() == "fn"; - token - } - TokenTree2::Group(g) => { - if found_fn && !found_args && g.delimiter() == proc_macro2::Delimiter::Parenthesis { - found_args = true; - return TokenTree2::Group(proc_macro2::Group::new( - proc_macro2::Delimiter::Parenthesis, - alt_sig.clone().into(), - )); - } - token - } - _ => token, - }); + if let Some(generics) = async_signature.generics { + item.sig.generics = generics; + } - TokenStream2::from_iter(old_tokens) - } else { - tokens + if let Some(output) = async_signature.output { + item.sig.output = output; + } }; + let tokens = quote!(#item); + DesugarIfAsync { is_async }.desugar_if_async(tokens) } #[proc_macro_attribute] pub fn async_generic(args: TokenStream, input: TokenStream) -> TokenStream { - let mut async_signature: Option = None; - - if !args.to_string().is_empty() { - let mut atokens = args.into_iter(); - loop { - if let Some(TokenTree::Ident(i)) = atokens.next() { - if i.to_string() != *"async_signature" { - break; - } - } else { - break; - } - - if let Some(TokenTree::Group(g)) = atokens.next() { - if atokens.next().is_none() && g.delimiter() == proc_macro::Delimiter::Parenthesis { - async_signature = Some(g.stream()); - } - } - } - - if async_signature.is_none() { - return syn::Error::new( - Span::call_site(), - "async_generic can only take a async_signature argument", - ) - .to_compile_error() - .into(); - } + let async_signature = if args.is_empty() { + None + } else { + Some(parse_macro_input!(args as Args)) }; let input_clone = input.clone(); @@ -101,6 +63,57 @@ pub fn async_generic(args: TokenStream, input: TokenStream) -> TokenStream { tokens.into() } +struct Args { + generics: Option, + inputs: syn::punctuated::Punctuated, + output: Option, +} + +impl Parse for Args { + fn parse(input: ParseStream) -> Result { + let async_signature: Ident = input.parse()?; + if async_signature != "async_signature" { + return Err(Error::new( + Span::call_site(), + "async_generic can only take a async_signature argument", + )); + } + + let mut generics: Option = if input.peek(Token![<]) { + Some(input.parse()?) + } else { + None + }; + + let args; + let _paren: syn::token::Paren = parenthesized!(args in input); + let inputs = args.parse_terminated(syn::FnArg::parse, Token![,])?; + + let output = if input.peek(Token![->]) { + Some(input.parse()?) + } else { + None + }; + + if input.peek(Token![where]) { + if let Some(generics) = &mut generics { + generics.where_clause = Some(input.parse()?); + } else { + generics = Some(syn::Generics { + where_clause: Some(input.parse()?), + ..Default::default() + }); + } + } + + Ok(Self { + generics, + inputs, + output, + }) + } +} + struct Item(ItemFn); impl Parse for Item { diff --git a/tests/src/tests/pass/fun-with-types.rs b/tests/src/tests/pass/fun-with-types.rs index a9b84ec..5f0145d 100644 --- a/tests/src/tests/pass/fun-with-types.rs +++ b/tests/src/tests/pass/fun-with-types.rs @@ -1,4 +1,5 @@ use async_generic::async_generic; +use std::future::Future; #[async_generic(async_signature(thing: &AsyncThing))] fn do_stuff(thing: &SyncThing) -> String { @@ -9,12 +10,64 @@ fn do_stuff(thing: &SyncThing) -> String { } } +#[async_generic(async_signature(thing: &AsyncThing) -> i64)] +fn do_other_stuff(thing: &SyncThing) -> i32 { + if _async { + thing.do_other_stuff().await + } else { + thing.do_other_stuff() + } +} + +#[async_generic(async_signature(thing: &AsyncThing, g: G) -> G::AsyncOutput)] +fn do_generic_stuff(thing: &SyncThing, g: G) -> G::SyncOutput { + if _async { + thing.do_generic_stuff(g).await + } else { + thing.do_generic_stuff(g) + } +} + +trait SyncGeneric { + type SyncOutput; + + fn get(&self) -> Self::SyncOutput; +} + +trait AsyncGeneric { + type AsyncOutput; + + fn get(&self) -> impl Future; +} + +impl SyncGeneric for i32 { + type SyncOutput = u32; + + fn get(&self) -> Self::SyncOutput { + *self as u32 + } +} + +impl AsyncGeneric for i64 { + type AsyncOutput = u64; + + fn get(&self) -> impl Future { + async { *self as u64 } + } +} + struct SyncThing {} impl SyncThing { fn do_stuff(&self) -> String { "sync".to_owned() } + fn do_other_stuff(&self) -> i32 { + 42 + } + fn do_generic_stuff(&self, g: G) -> G::SyncOutput { + g.get() + } } struct AsyncThing {} @@ -23,6 +76,12 @@ impl AsyncThing { async fn do_stuff(&self) -> String { "async".to_owned() } + async fn do_other_stuff(&self) -> i64 { + 24 + } + async fn do_generic_stuff(&self, g: G) -> G::AsyncOutput { + g.get().await + } } #[async_std::main] @@ -32,4 +91,10 @@ async fn main() { println!("sync => {}", do_stuff(&st)); println!("async => {}", do_stuff_async(&at).await); + + let _s: i32 = do_other_stuff(&st); + let _a: i64 = do_other_stuff_async(&at).await; + + let _sg: u32 = do_generic_stuff(&st, 42_i32); + let _ag: u64 = do_generic_stuff_async(&at, 24_i64).await; }