diff --git a/crates/misc/component-async-tests/http/src/lib.rs b/crates/misc/component-async-tests/http/src/lib.rs index 930fcded6c..3dbc04fdf6 100644 --- a/crates/misc/component-async-tests/http/src/lib.rs +++ b/crates/misc/component-async-tests/http/src/lib.rs @@ -29,9 +29,9 @@ use { wasi::http::types::{ErrorCode, HeaderError, Method, RequestOptionsError, Scheme}, wasmtime::{ component::{ - self, ErrorContext, FutureReader, Linker, Resource, ResourceTable, StreamReader, + Accessor, ErrorContext, FutureReader, Linker, Resource, ResourceTable, StreamReader, }, - AsContextMut, StoreContextMut, + AsContextMut, }, }; @@ -55,18 +55,9 @@ pub trait WasiHttpView: Send + Sized { fn table(&mut self) -> &mut ResourceTable; fn send_request( - store: StoreContextMut<'_, Self::Data>, + accessor: &mut Accessor, request: Resource, - ) -> impl Future< - Output = impl FnOnce( - StoreContextMut<'_, Self::Data>, - ) -> wasmtime::Result, ErrorCode>> - + Send - + Sync - + 'static, - > + Send - + Sync - + 'static; + ) -> impl Future, ErrorCode>>> + Send + Sync; } impl WasiHttpView for &mut T { @@ -77,19 +68,11 @@ impl WasiHttpView for &mut T { } fn send_request( - store: StoreContextMut<'_, Self::Data>, + accessor: &mut Accessor, request: Resource, - ) -> impl Future< - Output = impl FnOnce( - StoreContextMut<'_, Self::Data>, - ) -> wasmtime::Result, ErrorCode>> - + Send - + Sync - + 'static, - > + Send - + Sync - + 'static { - T::send_request(store, request) + ) -> impl Future, ErrorCode>>> + Send + Sync + { + T::send_request(accessor, request) } } @@ -103,19 +86,11 @@ impl WasiHttpView for WasiHttpImpl { } fn send_request( - store: StoreContextMut<'_, Self::Data>, + accessor: &mut Accessor, request: Resource, - ) -> impl Future< - Output = impl FnOnce( - StoreContextMut<'_, Self::Data>, - ) -> wasmtime::Result, ErrorCode>> - + Send - + Sync - + 'static, - > + Send - + Sync - + 'static { - T::send_request(store, request) + ) -> impl Future, ErrorCode>>> + Send + Sync + { + T::send_request(accessor, request) } } @@ -255,33 +230,22 @@ where Ok(Ok(stream)) } - fn finish( - mut store: StoreContextMut<'_, Self::BodyData>, + async fn finish( + accessor: &mut Accessor, this: Resource, - ) -> impl Future< - Output = impl FnOnce( - StoreContextMut<'_, Self::BodyData>, - ) - -> wasmtime::Result>, ErrorCode>> - + 'static, - > + Send - + Sync - + 'static { - let trailers = (|| { + ) -> wasmtime::Result>, ErrorCode>> { + let trailers = accessor.with(|mut store| { let trailers = store.data_mut().table().delete(this)?.trailers; trailers .map(|v| v.read(store.as_context_mut()).map(|v| v.into_future())) .transpose() - })(); - async move { - let trailers = match trailers { - Ok(Some(trailers)) => Ok(trailers.await), - Ok(None) => Ok(None), - Err(e) => Err(e), - }; + })?; - component::for_any(move |_| Ok(Ok(trailers?))) - } + Ok(Ok(if let Some(trailers) = trailers { + trailers.await + } else { + None + })) } fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { @@ -530,20 +494,11 @@ where impl wasi::http::handler::Host for WasiHttpImpl { type Data = T::Data; - fn handle( - store: StoreContextMut<'_, Self::Data>, + async fn handle( + accessor: &mut Accessor, request: Resource, - ) -> impl Future< - Output = impl FnOnce( - StoreContextMut<'_, Self::Data>, - ) -> wasmtime::Result, ErrorCode>> - + Send - + Sync - + 'static, - > + Send - + Sync - + 'static { - Self::send_request(store, request) + ) -> wasmtime::Result, ErrorCode>> { + Self::send_request(accessor, request).await } } diff --git a/crates/misc/component-async-tests/src/lib.rs b/crates/misc/component-async-tests/src/lib.rs index ecc16f8d44..cd3fed55e8 100644 --- a/crates/misc/component-async-tests/src/lib.rs +++ b/crates/misc/component-async-tests/src/lib.rs @@ -7,7 +7,6 @@ mod test { futures::future, round_trip_many::local::local::many::Stuff, std::{ - future::Future, iter, ops::DerefMut, sync::{Arc, Mutex, Once}, @@ -23,10 +22,10 @@ mod test { wasm_compose::composer::ComponentComposer, wasmtime::{ component::{ - self, Component, FutureReader, Instance, Linker, Promise, PromisesUnordered, - Resource, ResourceTable, StreamReader, StreamWriter, Val, + self, Accessor, Component, FutureReader, Instance, Linker, Promise, + PromisesUnordered, Resource, ResourceTable, StreamReader, StreamWriter, Val, }, - AsContextMut, Config, Engine, Store, StoreContextMut, + AsContextMut, Config, Engine, Store, }, wasmtime_wasi::{IoView, WasiCtx, WasiCtxBuilder, WasiView}, }; @@ -78,20 +77,9 @@ mod test { impl round_trip::local::local::baz::Host for Ctx { type Data = Ctx; - #[allow(clippy::manual_async_fn)] - fn foo( - _: StoreContextMut<'_, Self>, - s: String, - ) -> impl Future< - Output = impl FnOnce(StoreContextMut<'_, Self>) -> wasmtime::Result + 'static, - > + Send - + 'static { - async move { - tokio::time::sleep(Duration::from_millis(10)).await; - component::for_any(move |_: StoreContextMut<'_, Self>| { - Ok(format!("{s} - entered host - exited host")) - }) - } + async fn foo(_: &mut Accessor, s: String) -> wasmtime::Result { + tokio::time::sleep(Duration::from_millis(10)).await; + Ok(format!("{s} - entered host - exited host")) } } @@ -110,9 +98,8 @@ mod test { impl round_trip_many::local::local::many::Host for Ctx { type Data = Ctx; - #[allow(clippy::manual_async_fn)] - fn foo( - _: StoreContextMut<'_, Self>, + async fn foo( + _: &mut Accessor, a: String, b: u32, c: Vec, @@ -120,34 +107,25 @@ mod test { e: Stuff, f: Option, g: Result, - ) -> impl Future< - Output = impl FnOnce( - StoreContextMut<'_, Self>, - ) -> wasmtime::Result<( - String, - u32, - Vec, - (u64, u64), - Stuff, - Option, - Result, - )> + 'static, - > + Send - + 'static { - async move { - tokio::time::sleep(Duration::from_millis(10)).await; - component::for_any(move |_: StoreContextMut<'_, Self>| { - Ok(( - format!("{a} - entered host - exited host"), - b, - c, - d, - e, - f, - g, - )) - }) - } + ) -> wasmtime::Result<( + String, + u32, + Vec, + (u64, u64), + Stuff, + Option, + Result, + )> { + tokio::time::sleep(Duration::from_millis(10)).await; + Ok(( + format!("{a} - entered host - exited host"), + b, + c, + d, + e, + f, + g, + )) } } @@ -165,20 +143,9 @@ mod test { impl round_trip_direct::RoundTripDirectImports for Ctx { type Data = Ctx; - #[allow(clippy::manual_async_fn)] - fn foo( - _: StoreContextMut<'_, Self>, - s: String, - ) -> impl Future< - Output = impl FnOnce(StoreContextMut<'_, Self>) -> wasmtime::Result + 'static, - > + Send - + 'static { - async move { - tokio::time::sleep(Duration::from_millis(10)).await; - component::for_any(move |_: StoreContextMut<'_, Self>| { - Ok(format!("{s} - entered host - exited host")) - }) - } + async fn foo(_: &mut Accessor, s: String) -> wasmtime::Result { + tokio::time::sleep(Duration::from_millis(10)).await; + Ok(format!("{s} - entered host - exited host")) } } @@ -307,14 +274,12 @@ mod test { .instance("local:local/baz")? .func_new_concurrent("foo", |_, params| async move { tokio::time::sleep(Duration::from_millis(10)).await; - component::for_any(move |_: StoreContextMut<'_, Ctx>| { - let Some(Val::String(s)) = params.into_iter().next() else { - unreachable!() - }; - Ok(vec![Val::String(format!( - "{s} - entered host - exited host" - ))]) - }) + let Some(Val::String(s)) = params.into_iter().next() else { + unreachable!() + }; + Ok(vec![Val::String(format!( + "{s} - entered host - exited host" + ))]) })?; let mut store = make_store(); @@ -638,17 +603,15 @@ mod test { .instance("local:local/many")? .func_new_concurrent("foo", |_, params| async move { tokio::time::sleep(Duration::from_millis(10)).await; - component::for_any(move |_: StoreContextMut<'_, Ctx>| { - let mut params = params.into_iter(); - let Some(Val::String(s)) = params.next() else { - unreachable!() - }; - Ok(vec![Val::Tuple( - iter::once(Val::String(format!("{s} - entered host - exited host"))) - .chain(params) - .collect(), - )]) - }) + let mut params = params.into_iter(); + let Some(Val::String(s)) = params.next() else { + unreachable!() + }; + Ok(vec![Val::Tuple( + iter::once(Val::String(format!("{s} - entered host - exited host"))) + .chain(params) + .collect(), + )]) })?; let mut store = make_store(); @@ -961,14 +924,12 @@ mod test { .root() .func_new_concurrent("foo", |_, params| async move { tokio::time::sleep(Duration::from_millis(10)).await; - component::for_any(move |_: StoreContextMut<'_, Ctx>| { - let Some(Val::String(s)) = params.into_iter().next() else { - unreachable!() - }; - Ok(vec![Val::String(format!( - "{s} - entered host - exited host" - ))]) - }) + let Some(Val::String(s)) = params.into_iter().next() else { + unreachable!() + }; + Ok(vec![Val::String(format!( + "{s} - entered host - exited host" + ))]) })?; let mut store = make_store(); @@ -1058,22 +1019,18 @@ mod test { } } - fn when_ready( - store: StoreContextMut, - ) -> impl Future) + 'static> - + Send - + Sync - + 'static { - let wakers = store.data().wakers.clone(); + async fn when_ready(accessor: &mut Accessor) { + let wakers = accessor.with(|store| store.data().wakers.clone()); future::poll_fn(move |cx| { let mut wakers = wakers.lock().unwrap(); if let Some(wakers) = wakers.deref_mut() { wakers.push(cx.waker().clone()); Poll::Pending } else { - Poll::Ready(component::for_any(|_| ())) + Poll::Ready(()) } }) + .await } } @@ -1663,23 +1620,11 @@ mod test { &mut self.table } - #[allow(clippy::manual_async_fn)] - fn send_request( - _store: StoreContextMut<'_, Self::Data>, + async fn send_request( + _accessor: &mut Accessor, _request: Resource, - ) -> impl Future< - Output = impl FnOnce( - StoreContextMut<'_, Self::Data>, - ) - -> wasmtime::Result, ErrorCode>> - + 'static, - > + Send - + 'static { - async move { - move |_: StoreContextMut<'_, Self>| { - Err(anyhow!("no outbound request handler available")) - } - } + ) -> wasmtime::Result, ErrorCode>> { + Err(anyhow!("no outbound request handler available")) } } diff --git a/crates/wasmtime/src/runtime/component/concurrent.rs b/crates/wasmtime/src/runtime/component/concurrent.rs index 95b516a1f1..11575b3be3 100644 --- a/crates/wasmtime/src/runtime/component/concurrent.rs +++ b/crates/wasmtime/src/runtime/component/concurrent.rs @@ -148,6 +148,39 @@ impl PromisesUnordered { } } +/// Provides restricted mutable access to a `Store` in the context of a +/// concurrent host import function. +/// +/// This allows multiple host import futures to execute concurrently and access +/// the `Store` (and its data payload) between (but not across) `await` points. +pub struct Accessor { + store: *mut dyn VMStore, + _phantom: PhantomData)>, +} + +unsafe impl Send for Accessor {} +unsafe impl Sync for Accessor {} + +impl Accessor { + #[doc(hidden)] + pub unsafe fn new(store: *mut dyn VMStore) -> Self { + Self { + store, + _phantom: PhantomData, + } + } + + /// Run the specified closure, passing it a `StoreContextMut`. + /// + /// Note that the return value of the closure must be `'static`, meaning it + /// cannot borrow from the store or its data payload. If you need to return + /// a resource from the store, it must be cloned (using e.g. `Arc::clone` if + /// appropriate). + pub fn with(&mut self, fun: impl FnOnce(StoreContextMut<'_, T>) -> R) -> R { + fun(unsafe { StoreContextMut(&mut *self.store.cast()) }) + } +} + /// Trait representing component model ABI async intrinsics and fused adapter /// helper functions. pub unsafe trait VMComponentAsyncStore { @@ -1380,16 +1413,6 @@ fn dummy_waker() -> Waker { WAKER.clone().into() } -/// Provide a hint to Rust type inferencer that we're returning a compatible -/// closure from a `LinkerInstance::func_wrap_concurrent` future. -pub fn for_any(fun: F) -> F -where - F: FnOnce(StoreContextMut) -> R + 'static, - R: 'static, -{ - fun -} - fn for_any_lower< F: FnOnce(*mut dyn VMStore, &mut [MaybeUninit]) -> Result<()> + Send + Sync, >( @@ -1406,13 +1429,10 @@ fn for_any_lift< fun } -pub(crate) fn first_poll( +pub(crate) fn first_poll( instance: *mut ComponentInstance, mut store: StoreContextMut, - future: impl Future) -> Result + Send + Sync + 'static> - + Send - + Sync - + 'static, + future: impl Future> + Send + Sync + 'static, caller_instance: RuntimeComponentInstanceIndex, lower: impl FnOnce(StoreContextMut, R) -> Result<()> + Send + Sync + 'static, ) -> Result> { @@ -1422,13 +1442,12 @@ pub(crate) fn first_poll( .table .push_child(HostTask { caller_instance }, caller)?; log::trace!("new child of {}: {}", caller.rep(), task.rep()); - let mut future = Box::pin(future.map(move |fun| { + let mut future = Box::pin(future.map(move |result| { ( task.rep(), Box::new(move |store: *mut dyn VMStore| { - let mut store = unsafe { StoreContextMut(&mut *store.cast()) }; - let result = fun(store.as_context_mut())?; - lower(store, result)?; + let store = unsafe { StoreContextMut(&mut *store.cast()) }; + lower(store, result?)?; Ok(HostTaskResult { event: Event::Done, param: 0u32, @@ -1463,18 +1482,12 @@ pub(crate) fn first_poll( pub(crate) fn poll_and_block<'a, T, R: Send + Sync + 'static>( mut store: StoreContextMut<'a, T>, - future: impl Future) -> Result + Send + Sync + 'static> - + Send - + Sync - + 'static, + future: impl Future> + Send + Sync + 'static, caller_instance: RuntimeComponentInstanceIndex, ) -> Result<(R, StoreContextMut<'a, T>)> { let Some(caller) = store.concurrent_state().guest_task else { return match pin!(future).poll(&mut Context::from_waker(&dummy_waker())) { - Poll::Ready(fun) => { - let result = fun(store.as_context_mut())?; - Ok((result, store)) - } + Poll::Ready(result) => Ok((result?, store)), Poll::Pending => { unreachable!() } @@ -1492,14 +1505,13 @@ pub(crate) fn poll_and_block<'a, T, R: Send + Sync + 'static>( .table .push_child(HostTask { caller_instance }, caller)?; log::trace!("new child of {}: {}", caller.rep(), task.rep()); - let mut future = Box::pin(future.map(move |fun| { + let mut future = Box::pin(future.map(move |result| { ( task.rep(), Box::new(move |store: *mut dyn VMStore| { - let mut store = unsafe { StoreContextMut(&mut *store.cast()) }; - let result = fun(store.as_context_mut())?; + let mut store = unsafe { StoreContextMut::(&mut *store.cast()) }; store.concurrent_state().table.get_mut(caller)?.result = - Some(Box::new(result) as _); + Some(Box::new(result?) as _); Ok(HostTaskResult { event: Event::Done, param: 0u32, diff --git a/crates/wasmtime/src/runtime/component/func/host.rs b/crates/wasmtime/src/runtime/component/func/host.rs index bfc3e4007b..ebf9b4f54c 100644 --- a/crates/wasmtime/src/runtime/component/func/host.rs +++ b/crates/wasmtime/src/runtime/component/func/host.rs @@ -48,19 +48,18 @@ impl HostFunc { { Self::from_concurrent(move |store, params| { let result = func(store, params); - async move { concurrent::for_any(move |_| result) } + async move { result } }) } - pub(crate) fn from_concurrent(func: F) -> Arc + pub(crate) fn from_concurrent(func: F) -> Arc where - N: FnOnce(StoreContextMut) -> Result + Send + Sync + 'static, - FN: Future + Send + Sync + 'static, - F: Fn(StoreContextMut, P) -> FN + Send + Sync + 'static, + Fut: Future> + Send + Sync + 'static, + F: Fn(StoreContextMut, P) -> Fut + Send + Sync + 'static, P: ComponentNamedList + Lift + 'static, R: ComponentNamedList + Lower + Send + Sync + 'static, { - let entrypoint = Self::entrypoint::; + let entrypoint = Self::entrypoint::; Arc::new(HostFunc { entrypoint, typecheck: Box::new(typecheck::), @@ -68,7 +67,7 @@ impl HostFunc { }) } - extern "C" fn entrypoint( + extern "C" fn entrypoint( cx: NonNull, data: NonNull, ty: u32, @@ -82,9 +81,8 @@ impl HostFunc { storage_len: usize, ) -> bool where - N: FnOnce(StoreContextMut) -> Result + Send + Sync + 'static, - FN: Future + Send + Sync + 'static, - F: Fn(StoreContextMut, P) -> FN + Send + Sync + 'static, + Fut: Future> + Send + Sync + 'static, + F: Fn(StoreContextMut, P) -> Fut + Send + Sync + 'static, P: ComponentNamedList + Lift + 'static, R: ComponentNamedList + Lower + Send + Sync + 'static, { @@ -119,18 +117,17 @@ impl HostFunc { .collect::>(); let result = func(store, ¶ms, &mut results); let result = result.map(move |()| results); - async move { concurrent::for_any(move |_| result) } + async move { result } }) } - pub(crate) fn new_dynamic_concurrent(f: F) -> Arc + pub(crate) fn new_dynamic_concurrent(f: F) -> Arc where - N: FnOnce(StoreContextMut) -> Result> + Send + Sync + 'static, - FN: Future + Send + Sync + 'static, - F: Fn(StoreContextMut, Vec, usize) -> FN + Send + Sync + 'static, + Fut: Future>> + Send + Sync + 'static, + F: Fn(StoreContextMut, Vec, usize) -> Fut + Send + Sync + 'static, { Arc::new(HostFunc { - entrypoint: dynamic_entrypoint::, + entrypoint: dynamic_entrypoint::, // This function performs dynamic type checks and subsequently does // not need to perform up-front type checks. Instead everything is // dynamically managed at runtime. @@ -185,7 +182,7 @@ where /// This function is in general `unsafe` as the validity of all the parameters /// must be upheld. Generally that's done by ensuring this is only called from /// the select few places it's intended to be called from. -unsafe fn call_host( +unsafe fn call_host( instance: *mut ComponentInstance, types: &Arc, mut cx: StoreContextMut<'_, T>, @@ -200,9 +197,8 @@ unsafe fn call_host( closure: F, ) -> Result<()> where - N: FnOnce(StoreContextMut) -> Result + Send + Sync + 'static, - FN: Future + Send + Sync + 'static, - F: Fn(StoreContextMut, Params) -> FN + 'static, + Fut: Future> + Send + Sync + 'static, + F: Fn(StoreContextMut, Params) -> Fut + 'static, Params: Lift, Return: Lower + Send + Sync + 'static, { @@ -419,7 +415,7 @@ unsafe fn call_host_and_handle_result( }) } -unsafe fn call_host_dynamic( +unsafe fn call_host_dynamic( instance: *mut ComponentInstance, types: &Arc, mut store: StoreContextMut<'_, T>, @@ -434,9 +430,8 @@ unsafe fn call_host_dynamic( closure: F, ) -> Result<()> where - N: FnOnce(StoreContextMut) -> Result> + Send + Sync + 'static, - FN: Future + Send + Sync + 'static, - F: Fn(StoreContextMut, Vec, usize) -> FN + 'static, + Fut: Future>> + Send + Sync + 'static, + F: Fn(StoreContextMut, Vec, usize) -> Fut + 'static, { let options = Options::new( store.0.id(), @@ -621,7 +616,7 @@ pub(crate) fn validate_inbounds_dynamic( Ok(ptr) } -extern "C" fn dynamic_entrypoint( +extern "C" fn dynamic_entrypoint( cx: NonNull, data: NonNull, ty: u32, @@ -635,9 +630,8 @@ extern "C" fn dynamic_entrypoint( storage_len: usize, ) -> bool where - N: FnOnce(StoreContextMut) -> Result> + Send + Sync + 'static, - FN: Future + Send + Sync + 'static, - F: Fn(StoreContextMut, Vec, usize) -> FN + Send + Sync + 'static, + Fut: Future>> + Send + Sync + 'static, + F: Fn(StoreContextMut, Vec, usize) -> Fut + Send + Sync + 'static, { let data = Ptr(data.as_ptr() as *const F); unsafe { diff --git a/crates/wasmtime/src/runtime/component/linker.rs b/crates/wasmtime/src/runtime/component/linker.rs index 1960f50b6a..9255eb3c7e 100644 --- a/crates/wasmtime/src/runtime/component/linker.rs +++ b/crates/wasmtime/src/runtime/component/linker.rs @@ -457,22 +457,11 @@ impl LinkerInstance<'_, T> { /// method because it takes a function which returns a future that owns a /// unique reference to the Store, meaning the Store can't be used for /// anything else until the future resolves. - /// - /// Ideally, we'd have a way to thread a `StoreContextMut` through an - /// arbitrary `Future` such that it has access to the `Store` only while - /// being polled (i.e. between, but not across, await points). However, - /// there's currently no way to express that in async Rust, so we make do - /// with a more awkward scheme: each function registered using - /// `func_wrap_concurrent` gets access to the `Store` twice: once before - /// doing any concurrent operations (i.e. before awaiting) and once - /// afterward. This allows multiple calls to proceed concurrently without - /// any one of them monopolizing the store. #[cfg(feature = "component-model-async")] - pub fn func_wrap_concurrent(&mut self, name: &str, f: F) -> Result<()> + pub fn func_wrap_concurrent(&mut self, name: &str, f: F) -> Result<()> where - N: FnOnce(StoreContextMut) -> Result + Send + Sync + 'static, - FN: Future + Send + Sync + 'static, - F: Fn(StoreContextMut, Params) -> FN + Send + Sync + 'static, + Fut: Future> + Send + Sync + 'static, + F: Fn(StoreContextMut, Params) -> Fut + Send + Sync + 'static, Params: ComponentNamedList + Lift + 'static, Return: ComponentNamedList + Lower + Send + Sync + 'static, { @@ -648,11 +637,10 @@ impl LinkerInstance<'_, T> { /// afterward. This allows multiple calls to proceed concurrently without /// any one of them monopolizing the store. #[cfg(feature = "component-model-async")] - pub fn func_new_concurrent(&mut self, name: &str, f: F) -> Result<()> + pub fn func_new_concurrent(&mut self, name: &str, f: F) -> Result<()> where - N: FnOnce(StoreContextMut) -> Result> + Send + Sync + 'static, - FN: Future + Send + Sync + 'static, - F: Fn(StoreContextMut, Vec) -> FN + Send + Sync + 'static, + Fut: Future>> + Send + Sync + 'static, + F: Fn(StoreContextMut, Vec) -> Fut + Send + Sync + 'static, { assert!( self.engine.config().async_support, diff --git a/crates/wasmtime/src/runtime/component/mod.rs b/crates/wasmtime/src/runtime/component/mod.rs index 0e0db75501..4236eb3951 100644 --- a/crates/wasmtime/src/runtime/component/mod.rs +++ b/crates/wasmtime/src/runtime/component/mod.rs @@ -116,7 +116,7 @@ mod values; pub use self::component::{Component, ComponentExportIndex}; #[cfg(feature = "component-model-async")] pub use self::concurrent::{ - for_any, future, stream, ErrorContext, FutureReader, FutureWriter, Promise, PromisesUnordered, + future, stream, Accessor, ErrorContext, FutureReader, FutureWriter, Promise, PromisesUnordered, StreamReader, StreamWriter, VMComponentAsyncStore, }; pub use self::func::{ diff --git a/crates/wasmtime/src/runtime/store/context.rs b/crates/wasmtime/src/runtime/store/context.rs index 45d58c30ca..ac26509e9d 100644 --- a/crates/wasmtime/src/runtime/store/context.rs +++ b/crates/wasmtime/src/runtime/store/context.rs @@ -18,6 +18,13 @@ pub struct StoreContext<'a, T>(pub(crate) &'a StoreInner); #[repr(transparent)] pub struct StoreContextMut<'a, T>(pub(crate) &'a mut StoreInner); +impl<'a, T> StoreContextMut<'a, T> { + #[doc(hidden)] + pub fn traitobj(&self) -> std::ptr::NonNull { + self.0.traitobj() + } +} + /// A trait used to get shared access to a [`Store`] in Wasmtime. /// /// This trait is used as a bound on the first argument of many methods within diff --git a/crates/wit-bindgen/src/lib.rs b/crates/wit-bindgen/src/lib.rs index 7c4698607f..72b96c6fd7 100644 --- a/crates/wit-bindgen/src/lib.rs +++ b/crates/wit-bindgen/src/lib.rs @@ -82,7 +82,7 @@ struct Wasmtime { struct ImportFunction { func: Function, add_to_linker: String, - sig: Option, + sig: Option<(String, String)>, } #[derive(Default)] @@ -278,7 +278,7 @@ impl Opts { { anyhow::bail!( "must enable `component-model-async` feature when using WIT files \ - containing future, stream, or error types" + containing future, stream, or error-context types" ); } @@ -471,7 +471,7 @@ impl Wasmtime { bail!( "failed to locate a WIT error type corresponding to the \ - `trappable_error_type` name `{}` provided", + `trappable_error_type` name `{}` provided", te.wit_path ) } @@ -511,8 +511,11 @@ impl Wasmtime { // resource-related functions get their trait signatures // during `type_resource`. let sig = if let FunctionKind::Freestanding = func.kind { - generator.generate_function_trait_sig(func, "Data"); - Some(mem::take(&mut generator.src).into()) + generator.generate_function_trait_sig(func, "Data", false); + let without_sugar = mem::take(&mut generator.src).into(); + generator.generate_function_trait_sig(func, "Data", true); + let with_sugar = mem::take(&mut generator.src).into(); + Some((without_sugar, with_sugar)) } else { None }; @@ -876,8 +879,9 @@ fn _new( let world_name = &resolve.worlds[world].name; let camel = to_rust_upper_camel_case(&world_name); let (async_, async__, where_clause, await_) = match self.opts.call_style() { - CallStyle::Async => ("async", "_async", "where _T: Send", ".await"), - CallStyle::Concurrent => ("async", "_async", "where _T: Send + 'static", ".await"), + CallStyle::Async | CallStyle::Concurrent => { + ("async", "_async", "where _T: Send", ".await") + } CallStyle::Sync => ("", "", "", ""), }; uwriteln!( @@ -1431,15 +1435,25 @@ impl Wasmtime { let wt = self.wasmtime_path(); let world_camel = to_rust_upper_camel_case(&resolve.worlds[world].name); - if let CallStyle::Async = self.opts.call_style() { + + let has_concurrent_function = self.import_functions.iter().any(|func| { + matches!(func.func.kind, FunctionKind::Freestanding) + && matches!( + self.opts.import_call_style(None, &func.func.name), + CallStyle::Concurrent + ) + }); + + if let CallStyle::Async | CallStyle::Concurrent = self.opts.call_style() { uwriteln!( self.src, - "#[{wt}::component::__internal::trait_variant_make(::core::marker::Send)]" - ) + "#[{wt}::component::__internal::trait_variant_make(::core::marker::Send)]", + ); } + uwrite!(self.src, "pub trait {world_camel}Imports"); let mut supertraits = vec![]; - if let CallStyle::Async = self.opts.call_style() { + if let CallStyle::Async | CallStyle::Concurrent = self.opts.call_style() { supertraits.push("Send".to_string()); } for (_, name) in get_world_resources(resolve, world) { @@ -1450,20 +1464,12 @@ impl Wasmtime { } uwriteln!(self.src, " {{"); - let has_concurrent_function = self.import_functions.iter().any(|func| { - matches!(func.func.kind, FunctionKind::Freestanding) - && matches!( - self.opts.import_call_style(None, &func.func.name), - CallStyle::Concurrent - ) - }); - if has_concurrent_function { self.src.push_str("type Data;\n"); } for f in self.import_functions.iter() { - if let Some(sig) = &f.sig { + if let Some((sig, _)) = &f.sig { self.src.push_str(sig); self.src.push_str(";\n"); } @@ -1502,7 +1508,7 @@ impl Wasmtime { ); // Generate impl WorldImports for &mut WorldImports - let maybe_send = if let CallStyle::Async = self.opts.call_style() { + let maybe_send = if let CallStyle::Async | CallStyle::Concurrent = self.opts.call_style() { "+ Send" } else { "" @@ -1515,7 +1521,7 @@ impl Wasmtime { }; uwriteln!( self.src, - "impl<_T: {world_camel}Imports {maybe_maybe_sized} {maybe_send}> {world_camel}Imports for &mut _T {{" + "impl<_T: {world_camel}Imports {maybe_maybe_sized} {maybe_send}> {world_camel}Imports for &mut _T {{" ); let has_concurrent_function = self.import_functions.iter().any(|f| { matches!( @@ -1529,13 +1535,13 @@ impl Wasmtime { } // Forward each method call to &mut T for f in self.import_functions.iter() { - if let Some(sig) = &f.sig { + if let Some((_, sig)) = &f.sig { self.src.push_str(sig); let call_style = self.opts.import_call_style(None, &f.func.name); if let CallStyle::Concurrent = &call_style { uwrite!( self.src, - "{{ <_T as {world_camel}Imports>::{}(store,", + "{{ <_T as {world_camel}Imports>::{}(accessor,", rust_function_name(&f.func) ); } else { @@ -1549,7 +1555,7 @@ impl Wasmtime { uwrite!(self.src, "{},", to_rust_ident(name)); } uwrite!(self.src, ")"); - if let CallStyle::Async = &call_style { + if let CallStyle::Async | CallStyle::Concurrent = &call_style { uwrite!(self.src, ".await"); } uwriteln!(self.src, "}}"); @@ -1589,7 +1595,7 @@ impl Wasmtime { let world_camel = to_rust_upper_camel_case(&resolve.worlds[world].name); traits.push(format!("{world_camel}Imports")); } - if let CallStyle::Async = self.opts.call_style() { + if let CallStyle::Async | CallStyle::Concurrent = self.opts.call_style() { traits.push("Send".to_string()); } traits @@ -1618,6 +1624,7 @@ impl Wasmtime { } else { "" }; + let wt = self.wasmtime_path(); if has_world_imports_trait { let host_bounds = if let CallStyle::Concurrent = self.opts.call_style() { @@ -1775,7 +1782,7 @@ impl Wasmtime { ) { let gate = FeatureGate::open(src, stability); let camel = name.to_upper_camel_case(); - if let CallStyle::Async = opts.drop_call_style(qualifier, name) { + if let CallStyle::Async | CallStyle::Concurrent = opts.drop_call_style(qualifier, name) { uwriteln!( src, "{inst}.resource_async( @@ -1895,14 +1902,6 @@ impl<'a> InterfaceGenerator<'a> { } // Generate resource trait - if let CallStyle::Async = self.generator.opts.call_style() { - uwriteln!( - self.src, - "#[{wt}::component::__internal::trait_variant_make(::core::marker::Send)]" - ) - } - - uwriteln!(self.src, "pub trait Host{camel}: Sized {{"); let mut functions = match resource.owner { TypeOwner::World(id) => self.resolve.worlds[id] @@ -1938,16 +1937,25 @@ impl<'a> InterfaceGenerator<'a> { ) }); + if let CallStyle::Async | CallStyle::Concurrent = self.generator.opts.call_style() { + uwriteln!( + self.src, + "#[{wt}::component::__internal::trait_variant_make(::core::marker::Send)]", + ); + } + + uwriteln!(self.src, "pub trait Host{camel}: Sized {{"); + if has_concurrent_function { uwriteln!(self.src, "type {camel}Data;"); } for func in &functions { - self.generate_function_trait_sig(func, &format!("{camel}Data")); + self.generate_function_trait_sig(func, &format!("{camel}Data"), false); self.push_str(";\n"); } - if let CallStyle::Async = self + if let CallStyle::Async | CallStyle::Concurrent = self .generator .opts .drop_call_style(self.qualifier().as_deref(), name) @@ -1985,11 +1993,11 @@ impl<'a> InterfaceGenerator<'a> { .generator .opts .import_call_style(self.qualifier().as_deref(), &func.name); - self.generate_function_trait_sig(func, &format!("{camel}Data")); + self.generate_function_trait_sig(func, &format!("{camel}Data"), true); if let CallStyle::Concurrent = call_style { uwrite!( self.src, - "{{ <_T as Host{camel}>::{}(store,", + "{{ <_T as Host{camel}>::{}(accessor,", rust_function_name(func) ); } else { @@ -2003,12 +2011,12 @@ impl<'a> InterfaceGenerator<'a> { uwrite!(self.src, "{},", to_rust_ident(name)); } uwrite!(self.src, ")"); - if let CallStyle::Async = call_style { + if let CallStyle::Async | CallStyle::Concurrent = call_style { uwrite!(self.src, ".await"); } uwriteln!(self.src, "}}"); } - if let CallStyle::Async = self + if let CallStyle::Async | CallStyle::Concurrent = self .generator .opts .drop_call_style(self.qualifier().as_deref(), name) @@ -2592,24 +2600,6 @@ impl<'a> InterfaceGenerator<'a> { } } - fn print_result_ty_tuple(&mut self, results: &Results, mode: TypeMode) { - self.push_str("("); - match results { - Results::Named(rs) if rs.is_empty() => self.push_str(")"), - Results::Named(rs) => { - for (_, ty) in rs { - self.print_ty(ty, mode); - self.push_str(", "); - } - self.push_str(")"); - } - Results::Anon(ty) => { - self.print_ty(ty, mode); - self.push_str(",)"); - } - } - } - fn special_case_trappable_error( &mut self, func: &Function, @@ -2620,7 +2610,7 @@ impl<'a> InterfaceGenerator<'a> { .used_trappable_imports_opts .insert(func.name.clone()); - // We fillin a special trappable error type in the case when a function has just one + // We fill in a special trappable error type in the case when a function has just one // result, which is itself a `result`, and the `e` is *not* a primitive // (i.e. defined in std) type, and matches the typename given by the user. let mut i = results.iter_types(); @@ -2652,13 +2642,27 @@ impl<'a> InterfaceGenerator<'a> { let owner = TypeOwner::Interface(id); let wt = self.generator.wasmtime_path(); - let is_maybe_async = matches!(self.generator.opts.call_style(), CallStyle::Async); - if is_maybe_async { - uwriteln!( - self.src, - "#[{wt}::component::__internal::trait_variant_make(::core::marker::Send)]" - ) - } + let has_concurrent_function = iface.functions.iter().any(|(_, func)| { + matches!(func.kind, FunctionKind::Freestanding) + && matches!( + self.generator + .opts + .import_call_style(self.qualifier().as_deref(), &func.name), + CallStyle::Concurrent + ) + }); + + let is_maybe_async = + if let CallStyle::Async | CallStyle::Concurrent = self.generator.opts.call_style() { + uwriteln!( + self.src, + "#[{wt}::component::__internal::trait_variant_make(::core::marker::Send)]", + ); + true + } else { + false + }; + // Generate the `pub trait` which represents the host functionality for // this import which additionally inherits from all resource traits // for this interface defined by `type_resource`. @@ -2681,16 +2685,6 @@ impl<'a> InterfaceGenerator<'a> { } uwriteln!(self.src, " {{"); - let has_concurrent_function = iface.functions.iter().any(|(_, func)| { - matches!(func.kind, FunctionKind::Freestanding) - && matches!( - self.generator - .opts - .import_call_style(self.qualifier().as_deref(), &func.name), - CallStyle::Concurrent - ) - }); - if has_concurrent_function { self.push_str("type Data;\n"); } @@ -2700,7 +2694,7 @@ impl<'a> InterfaceGenerator<'a> { FunctionKind::Freestanding => {} _ => continue, } - self.generate_function_trait_sig(func, "Data"); + self.generate_function_trait_sig(func, "Data", false); self.push_str(";\n"); } @@ -2881,11 +2875,11 @@ impl<'a> InterfaceGenerator<'a> { .generator .opts .import_call_style(self.qualifier().as_deref(), &func.name); - self.generate_function_trait_sig(func, "Data"); + self.generate_function_trait_sig(func, "Data", true); if let CallStyle::Concurrent = call_style { uwrite!( self.src, - "{{ <_T as Host>::{}(store,", + "{{ <_T as Host>::{}(accessor,", rust_function_name(func) ); } else { @@ -2895,7 +2889,7 @@ impl<'a> InterfaceGenerator<'a> { uwrite!(self.src, "{},", to_rust_ident(name)); } uwrite!(self.src, ")"); - if let CallStyle::Async = call_style { + if let CallStyle::Async | CallStyle::Concurrent = call_style { uwrite!(self.src, ".await"); } uwriteln!(self.src, "}}"); @@ -2969,7 +2963,7 @@ impl<'a> InterfaceGenerator<'a> { .import_call_style(self.qualifier().as_deref(), &func.name); if self.generator.opts.tracing { - if let CallStyle::Async = style { + if let CallStyle::Async | CallStyle::Concurrent = style { self.src.push_str("use tracing::Instrument;\n"); } @@ -2995,17 +2989,29 @@ impl<'a> InterfaceGenerator<'a> { ); } - if let CallStyle::Async = &style { - uwriteln!( - self.src, - " {wt}::component::__internal::Box::new(async move {{ " - ); - } else { - // Only directly enter the span if the function is sync. Otherwise - // we use tracing::Instrument to ensure that the span is not entered - // across an await point. - if self.generator.opts.tracing { - self.push_str("let _enter = span.enter();\n"); + match &style { + CallStyle::Async => { + uwriteln!( + self.src, + "{wt}::component::__internal::Box::new(async move {{" + ); + } + CallStyle::Concurrent => { + uwriteln!( + self.src, + "let mut accessor = unsafe {{ + {wt}::component::Accessor::new(caller.traitobj().as_ptr()) + }}; + {wt}::component::__internal::Box::pin(async move {{" + ); + } + CallStyle::Sync => { + // Only directly enter the span if the function is sync. Otherwise + // we use tracing::Instrument to ensure that the span is not entered + // across an await point. + if self.generator.opts.tracing { + self.push_str("let _enter = span.enter();\n"); + } } } @@ -3027,11 +3033,10 @@ impl<'a> InterfaceGenerator<'a> { ); } - self.src.push_str(if let CallStyle::Concurrent = &style { - "let host = caller;\n" - } else { - "let host = &mut host_getter(caller.data_mut());\n" - }); + if !matches!(style, CallStyle::Concurrent) { + self.src + .push_str("let host = &mut host_getter(caller.data_mut());\n"); + } let func_name = rust_function_name(func); let host_trait = match func.kind { FunctionKind::Freestanding => match owner { @@ -3054,7 +3059,7 @@ impl<'a> InterfaceGenerator<'a> { if let CallStyle::Concurrent = &style { uwrite!( self.src, - "let r = ::{func_name}(host, " + "let r = ::{func_name}(&mut accessor, " ); } else { uwrite!(self.src, "let r = {host_trait}::{func_name}(host, "); @@ -3065,20 +3070,10 @@ impl<'a> InterfaceGenerator<'a> { } self.src.push_str(match &style { - CallStyle::Sync | CallStyle::Concurrent => ");\n", - CallStyle::Async => ").await;\n", + CallStyle::Sync => ");\n", + CallStyle::Async | CallStyle::Concurrent => ").await;\n", }); - if let CallStyle::Concurrent = &style { - self.src.push_str( - "Box::pin(async move { - let fun = r.await; - Box::new(move |mut caller: wasmtime::StoreContextMut<'_, T>| { - let r = fun(caller); - ", - ); - } - if self.generator.opts.tracing { uwrite!( self.src, @@ -3120,34 +3115,18 @@ impl<'a> InterfaceGenerator<'a> { match &style { CallStyle::Sync => (), - CallStyle::Async => { + CallStyle::Async | CallStyle::Concurrent => { if self.generator.opts.tracing { self.src.push_str("}.instrument(span))\n"); } else { self.src.push_str("})\n"); } } - CallStyle::Concurrent => { - let old_source = mem::take(&mut self.src); - self.print_result_ty_tuple(&func.results, TypeMode::Owned); - let result_type = String::from(mem::replace(&mut self.src, old_source)); - let box_fn = format!( - "Box) -> \ - wasmtime::Result<{result_type}> + Send + Sync>" - ); - uwriteln!( - self.src, - " }}) as {box_fn} - }}) as ::core::pin::Pin \ - + Send + Sync + 'static>> - " - ); - } } self.src.push_str("}\n"); } - fn generate_function_trait_sig(&mut self, func: &Function, data: &str) { + fn generate_function_trait_sig(&mut self, func: &Function, data: &str, async_sugar: bool) { let wt = self.generator.wasmtime_path(); self.rustdoc(&func.docs); @@ -3155,13 +3134,13 @@ impl<'a> InterfaceGenerator<'a> { .generator .opts .import_call_style(self.qualifier().as_deref(), &func.name); - if let CallStyle::Async = &style { + if let (CallStyle::Async, _) | (CallStyle::Concurrent, true) = (&style, async_sugar) { self.push_str("async "); } self.push_str("fn "); self.push_str(&rust_function_name(func)); self.push_str(&if let CallStyle::Concurrent = &style { - format!("(store: wasmtime::StoreContextMut<'_, Self::{data}>, ") + format!("(accessor: &mut {wt}::component::Accessor, ") } else { "(&mut self, ".to_string() }); @@ -3175,8 +3154,8 @@ impl<'a> InterfaceGenerator<'a> { self.push_str(")"); self.push_str(" -> "); - if let CallStyle::Concurrent = &style { - uwrite!(self.src, "impl ::core::future::Future) -> "); + if let (CallStyle::Concurrent, false) = (&style, async_sugar) { + uwrite!(self.src, "impl ::core::future::Future InterfaceGenerator<'a> { self.push_str(">"); } - if let CallStyle::Concurrent = &style { - self.push_str(" + Send + Sync + 'static> + Send + Sync + 'static where Self: Sized"); + if let (CallStyle::Concurrent, false) = (&style, async_sugar) { + self.push_str("> + Send + Sync where Self: Sized"); } } @@ -3277,16 +3256,11 @@ impl<'a> InterfaceGenerator<'a> { uwrite!(self.src, ">"); } - let maybe_static = if concurrent { " + 'static" } else { "" }; - - uwrite!( - self.src, - "> where ::Data: Send{maybe_static} {{\n" - ); + uwrite!(self.src, "> where ::Data: Send {{\n"); // TODO: support tracing concurrent calls if self.generator.opts.tracing && !concurrent { - if let CallStyle::Async = &style { + if let CallStyle::Async | CallStyle::Concurrent = &style { self.src.push_str("use tracing::Instrument;\n"); } @@ -3306,7 +3280,7 @@ impl<'a> InterfaceGenerator<'a> { func.name, )); - if !matches!(&style, CallStyle::Async) { + if !matches!(&style, CallStyle::Async | CallStyle::Concurrent) { self.src.push_str( " let _enter = span.enter(); @@ -3365,14 +3339,18 @@ impl<'a> InterfaceGenerator<'a> { uwrite!(self.src, "arg{}, ", i); } - let instrument = if matches!(&style, CallStyle::Async) && self.generator.opts.tracing { + let instrument = if matches!(&style, CallStyle::Async | CallStyle::Concurrent) + && self.generator.opts.tracing + { ".instrument(span.clone())" } else { "" }; uwriteln!(self.src, ")){instrument}{await_}?;"); - let instrument = if matches!(&style, CallStyle::Async) && self.generator.opts.tracing { + let instrument = if matches!(&style, CallStyle::Async | CallStyle::Concurrent) + && self.generator.opts.tracing + { ".instrument(span)" } else { "" diff --git a/tests/all/component_model/bindgen.rs b/tests/all/component_model/bindgen.rs index 344105a8f0..cb81b1bcc8 100644 --- a/tests/all/component_model/bindgen.rs +++ b/tests/all/component_model/bindgen.rs @@ -188,11 +188,7 @@ mod one_import { } mod one_import_concurrent { - use { - super::*, - std::future::Future, - wasmtime::{component, StoreContextMut}, - }; + use {super::*, wasmtime::component::Accessor}; wasmtime::component::bindgen!({ inline: " @@ -269,14 +265,8 @@ mod one_import_concurrent { impl foo::Host for MyImports { type Data = MyImports; - fn foo( - mut store: StoreContextMut<'_, Self::Data>, - ) -> impl Future) + 'static> - + Send - + Sync - + 'static { - store.data_mut().hit = true; - async { component::for_any(|_| ()) } + async fn foo(accessor: &mut Accessor) { + accessor.with(|mut store| store.data_mut().hit = true); } } diff --git a/tests/all/component_model/import.rs b/tests/all/component_model/import.rs index 285e173ff5..45409a6bdf 100644 --- a/tests/all/component_model/import.rs +++ b/tests/all/component_model/import.rs @@ -3,7 +3,6 @@ use super::REALLOC_AND_FREE; use anyhow::Result; use std::ops::Deref; -use wasmtime::component; use wasmtime::component::*; use wasmtime::{Config, Engine, Store, StoreContextMut, Trap, WasmBacktrace}; @@ -737,7 +736,7 @@ async fn test_stack_and_heap_args_and_rets(concurrent: bool) -> Result<()> { .root() .func_wrap_concurrent("f1", |_, (x,): (u32,)| { assert_eq!(x, 1); - async { component::for_any(|_| Ok((2u32,))) } + async { Ok((2u32,)) } })?; linker.root().func_wrap_concurrent( "f2", @@ -754,14 +753,14 @@ async fn test_stack_and_heap_args_and_rets(concurrent: bool) -> Result<()> { WasmStr, ),)| { assert_eq!(arg.0.to_str(&cx).unwrap(), "abc"); - async { component::for_any(|_| Ok((3u32,))) } + async { Ok((3u32,)) } }, )?; linker .root() .func_wrap_concurrent("f3", |_, (arg,): (u32,)| { assert_eq!(arg, 8); - async { component::for_any(|_| Ok(("xyz".to_string(),))) } + async { Ok(("xyz".to_string(),)) } })?; linker.root().func_wrap_concurrent( "f4", @@ -778,7 +777,7 @@ async fn test_stack_and_heap_args_and_rets(concurrent: bool) -> Result<()> { WasmStr, ),)| { assert_eq!(arg.0.to_str(&cx).unwrap(), "abc"); - async { component::for_any(|_| Ok(("xyz".to_string(),))) } + async { Ok(("xyz".to_string(),)) } }, )?; } else { @@ -851,7 +850,7 @@ async fn test_stack_and_heap_args_and_rets(concurrent: bool) -> Result<()> { linker.root().func_new_concurrent("f1", |_, args| { if let Val::U32(x) = &args[0] { assert_eq!(*x, 1); - async { component::for_any(|_| Ok(vec![Val::U32(2)])) } + async { Ok(vec![Val::U32(2)]) } } else { panic!() } @@ -860,7 +859,7 @@ async fn test_stack_and_heap_args_and_rets(concurrent: bool) -> Result<()> { if let Val::Tuple(tuple) = &args[0] { if let Val::String(s) = &tuple[0] { assert_eq!(s.deref(), "abc"); - async { component::for_any(|_| Ok(vec![Val::U32(3)])) } + async { Ok(vec![Val::U32(3)]) } } else { panic!() } @@ -871,7 +870,7 @@ async fn test_stack_and_heap_args_and_rets(concurrent: bool) -> Result<()> { linker.root().func_new_concurrent("f3", |_, args| { if let Val::U32(x) = &args[0] { assert_eq!(*x, 8); - async { component::for_any(|_| Ok(vec![Val::String("xyz".into())])) } + async { Ok(vec![Val::String("xyz".into())]) } } else { panic!(); } @@ -880,7 +879,7 @@ async fn test_stack_and_heap_args_and_rets(concurrent: bool) -> Result<()> { if let Val::Tuple(tuple) = &args[0] { if let Val::String(s) = &tuple[0] { assert_eq!(s.deref(), "abc"); - async { component::for_any(|_| Ok(vec![Val::String("xyz".into())])) } + async { Ok(vec![Val::String("xyz".into())]) } } else { panic!() }