Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 69 additions & 56 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#![deny(warnings)]
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg, doc_cfg_hide))]

use proc_macro::{TokenStream, TokenTree};
use proc_macro2::{Ident, Span, TokenStream as TokenStream2, TokenTree as TokenTree2};
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
use quote::quote;
use syn::{
parenthesized,
parse::{Parse, ParseStream, Result},
parse_macro_input, Attribute, Error, ItemFn, Token,
};
Expand All @@ -16,7 +17,7 @@ mod desugar_if_async;
fn convert_sync_async(
input: &mut Item,
is_async: bool,
alt_sig: Option<TokenStream>,
async_signature: Option<Args>,
) -> TokenStream2 {
let item = &mut input.0;

Expand All @@ -25,68 +26,29 @@ fn convert_sync_async(
item.sig.ident = Ident::new(&format!("{}_async", item.sig.ident), Span::call_site());
}

let tokens = quote!(#item);

let tokens = if let Some(alt_sig) = alt_sig {
let mut found_fn = false;
let mut found_args = false;
if let Some(async_signature) = async_signature {
item.sig.inputs = async_signature.inputs;

let old_tokens = tokens.into_iter().map(|token| match &token {
TokenTree2::Ident(i) => {
found_fn = found_fn || &i.to_string() == "fn";
token
}
TokenTree2::Group(g) => {
if found_fn && !found_args && g.delimiter() == proc_macro2::Delimiter::Parenthesis {
found_args = true;
return TokenTree2::Group(proc_macro2::Group::new(
proc_macro2::Delimiter::Parenthesis,
alt_sig.clone().into(),
));
}
token
}
_ => token,
});
if let Some(generics) = async_signature.generics {
item.sig.generics = generics;
}

TokenStream2::from_iter(old_tokens)
} else {
tokens
if let Some(output) = async_signature.output {
item.sig.output = output;
}
};

let tokens = quote!(#item);

DesugarIfAsync { is_async }.desugar_if_async(tokens)
}

#[proc_macro_attribute]
pub fn async_generic(args: TokenStream, input: TokenStream) -> TokenStream {
let mut async_signature: Option<TokenStream> = None;

if !args.to_string().is_empty() {
let mut atokens = args.into_iter();
loop {
if let Some(TokenTree::Ident(i)) = atokens.next() {
if i.to_string() != *"async_signature" {
break;
}
} else {
break;
}

if let Some(TokenTree::Group(g)) = atokens.next() {
if atokens.next().is_none() && g.delimiter() == proc_macro::Delimiter::Parenthesis {
async_signature = Some(g.stream());
}
}
}

if async_signature.is_none() {
return syn::Error::new(
Span::call_site(),
"async_generic can only take a async_signature argument",
)
.to_compile_error()
.into();
}
let async_signature = if args.is_empty() {
None
} else {
Some(parse_macro_input!(args as Args))
};

let input_clone = input.clone();
Expand All @@ -101,6 +63,57 @@ pub fn async_generic(args: TokenStream, input: TokenStream) -> TokenStream {
tokens.into()
}

struct Args {
generics: Option<syn::Generics>,
inputs: syn::punctuated::Punctuated<syn::FnArg, Token![,]>,
output: Option<syn::ReturnType>,
}

impl Parse for Args {
fn parse(input: ParseStream) -> Result<Self> {
let async_signature: Ident = input.parse()?;
if async_signature != "async_signature" {
return Err(Error::new(
Span::call_site(),
"async_generic can only take a async_signature argument",
));
}

let mut generics: Option<syn::Generics> = if input.peek(Token![<]) {
Some(input.parse()?)
} else {
None
};

let args;
let _paren: syn::token::Paren = parenthesized!(args in input);
let inputs = args.parse_terminated(syn::FnArg::parse, Token![,])?;

let output = if input.peek(Token![->]) {
Some(input.parse()?)
} else {
None
};

if input.peek(Token![where]) {
if let Some(generics) = &mut generics {
generics.where_clause = Some(input.parse()?);
} else {
generics = Some(syn::Generics {
where_clause: Some(input.parse()?),
..Default::default()
});
}
}

Ok(Self {
generics,
inputs,
output,
})
}
}

struct Item(ItemFn);

impl Parse for Item {
Expand Down
65 changes: 65 additions & 0 deletions tests/src/tests/pass/fun-with-types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use async_generic::async_generic;
use std::future::Future;

#[async_generic(async_signature(thing: &AsyncThing))]
fn do_stuff(thing: &SyncThing) -> String {
Expand All @@ -9,12 +10,64 @@ fn do_stuff(thing: &SyncThing) -> String {
}
}

#[async_generic(async_signature(thing: &AsyncThing) -> i64)]
fn do_other_stuff(thing: &SyncThing) -> i32 {
if _async {
thing.do_other_stuff().await
} else {
thing.do_other_stuff()
}
}

#[async_generic(async_signature<G: AsyncGeneric>(thing: &AsyncThing, g: G) -> G::AsyncOutput)]
fn do_generic_stuff<G: SyncGeneric>(thing: &SyncThing, g: G) -> G::SyncOutput {
if _async {
thing.do_generic_stuff(g).await
} else {
thing.do_generic_stuff(g)
}
}

trait SyncGeneric {
type SyncOutput;

fn get(&self) -> Self::SyncOutput;
}

trait AsyncGeneric {
type AsyncOutput;

fn get(&self) -> impl Future<Output = Self::AsyncOutput>;
}

impl SyncGeneric for i32 {
type SyncOutput = u32;

fn get(&self) -> Self::SyncOutput {
*self as u32
}
}

impl AsyncGeneric for i64 {
type AsyncOutput = u64;

fn get(&self) -> impl Future<Output = Self::AsyncOutput> {
async { *self as u64 }
}
}

struct SyncThing {}

impl SyncThing {
fn do_stuff(&self) -> String {
"sync".to_owned()
}
fn do_other_stuff(&self) -> i32 {
42
}
fn do_generic_stuff<G: SyncGeneric>(&self, g: G) -> G::SyncOutput {
g.get()
}
}

struct AsyncThing {}
Expand All @@ -23,6 +76,12 @@ impl AsyncThing {
async fn do_stuff(&self) -> String {
"async".to_owned()
}
async fn do_other_stuff(&self) -> i64 {
24
}
async fn do_generic_stuff<G: AsyncGeneric>(&self, g: G) -> G::AsyncOutput {
g.get().await
}
}

#[async_std::main]
Expand All @@ -32,4 +91,10 @@ async fn main() {

println!("sync => {}", do_stuff(&st));
println!("async => {}", do_stuff_async(&at).await);

let _s: i32 = do_other_stuff(&st);
let _a: i64 = do_other_stuff_async(&at).await;

let _sg: u32 = do_generic_stuff(&st, 42_i32);
let _ag: u64 = do_generic_stuff_async(&at, 24_i64).await;
}