diff --git a/crates/wasmtime/src/func.rs b/crates/wasmtime/src/func.rs index f832bbb117..12d1d07ec5 100644 --- a/crates/wasmtime/src/func.rs +++ b/crates/wasmtime/src/func.rs @@ -272,8 +272,9 @@ macro_rules! generate_wrap_async_func { { assert!(store.as_context().async_support(), concat!("cannot use `wrap", $num, "_async` without enabling async support on the config")); Func::wrap(store, move |mut caller: Caller<'_, T>, $($args: $args),*| { - let async_cx = caller.store.as_context_mut().0.async_cx(); + let async_cx = caller.store.as_context_mut().0.async_cx().expect("Attempt to start async function on dying fiber"); let mut future = Pin::from(func(caller, $($args),*)); + match unsafe { async_cx.block_on(future.as_mut()) } { Ok(ret) => ret.into_fallible(), Err(e) => R::fallible_from_trap(e), @@ -439,7 +440,12 @@ impl Func { "cannot use `new_async` without enabling async support in the config" ); Func::new(store, ty, move |mut caller, params, results| { - let async_cx = caller.store.as_context_mut().0.async_cx(); + let async_cx = caller + .store + .as_context_mut() + .0 + .async_cx() + .expect("Attempt to spawn new action on dying fiber"); let mut future = Pin::from(func(caller, params, results)); match unsafe { async_cx.block_on(future.as_mut()) } { Ok(Ok(())) => Ok(()), diff --git a/crates/wasmtime/src/lib.rs b/crates/wasmtime/src/lib.rs index df110a2938..ac7dc96cb9 100644 --- a/crates/wasmtime/src/lib.rs +++ b/crates/wasmtime/src/lib.rs @@ -423,6 +423,8 @@ pub use crate::linker::*; pub use crate::memory::*; pub use crate::module::{FrameInfo, FrameSymbol, Module}; pub use crate::r#ref::ExternRef; +#[cfg(feature = "async")] +pub use crate::store::CallHookHandler; pub use crate::store::{AsContext, AsContextMut, CallHook, Store, StoreContext, StoreContextMut}; pub use crate::trap::*; pub use crate::types::*; diff --git a/crates/wasmtime/src/linker.rs b/crates/wasmtime/src/linker.rs index 0ea313dd07..d4df9d7332 100644 --- a/crates/wasmtime/src/linker.rs +++ b/crates/wasmtime/src/linker.rs @@ -146,7 +146,7 @@ macro_rules! generate_wrap_async_func { ), ); self.func_wrap(module, name, move |mut caller: Caller<'_, T>, $($args: $args),*| { - let async_cx = caller.store.as_context_mut().0.async_cx(); + let async_cx = caller.store.as_context_mut().0.async_cx().expect("Attempt to start async function on dying fiber"); let mut future = Pin::from(func(caller, $($args),*)); match unsafe { async_cx.block_on(future.as_mut()) } { Ok(ret) => ret.into_fallible(), @@ -360,7 +360,12 @@ impl Linker { "cannot use `func_new_async` without enabling async support in the config" ); self.func_new(module, name, ty, move |mut caller, params, results| { - let async_cx = caller.store.as_context_mut().0.async_cx(); + let async_cx = caller + .store + .as_context_mut() + .0 + .async_cx() + .expect("Attempt to spawn new function on dying fiber"); let mut future = Pin::from(func(caller, params, results)); match unsafe { async_cx.block_on(future.as_mut()) } { Ok(Ok(())) => Ok(()), diff --git a/crates/wasmtime/src/store.rs b/crates/wasmtime/src/store.rs index c6c5d3ceae..1357f703f7 100644 --- a/crates/wasmtime/src/store.rs +++ b/crates/wasmtime/src/store.rs @@ -200,7 +200,7 @@ pub struct StoreInner { inner: StoreOpaque, limiter: Option>, - call_hook: Option Result<(), crate::Trap> + Send + Sync>>, + call_hook: Option>, // for comments about `ManuallyDrop`, see `Store::into_data` data: ManuallyDrop, } @@ -211,6 +211,21 @@ enum ResourceLimiterInner { Async(Box &mut (dyn crate::ResourceLimiterAsync) + Send + Sync>), } +/// An object that can take callbacks when the runtime enters or exits hostcalls. +#[cfg(feature = "async")] +#[async_trait::async_trait] +pub trait CallHookHandler: Send { + /// A callback to run when wasmtime is about to enter a host call, or when about to + /// exit the hostcall. + async fn handle_call_event(&self, t: &mut T, ch: CallHook) -> Result<(), crate::Trap>; +} + +enum CallHookInner { + Sync(Box Result<(), crate::Trap> + Send + Sync>), + #[cfg(feature = "async")] + Async(Box + Send + Sync>), +} + // Forward methods on `StoreOpaque` to also being on `StoreInner` impl Deref for StoreInner { type Target = StoreOpaque; @@ -603,6 +618,27 @@ impl Store { inner.limiter = Some(ResourceLimiterInner::Async(Box::new(limiter))); } + #[cfg_attr(nightlydoc, doc(cfg(feature = "async")))] + /// Configures an async function that runs on calls and returns between + /// WebAssembly and host code. For the non-async equivalent of this method, + /// see [`Store::call_hook`]. + /// + /// The function is passed a [`CallHook`] argument, which indicates which + /// state transition the VM is making. + /// + /// This function's future may return a [`Trap`]. If a trap is returned + /// when an import was called, it is immediately raised as-if the host + /// import had returned the trap. If a trap is returned after wasm returns + /// to the host then the wasm function's result is ignored and this trap is + /// returned instead. + /// + /// After this function returns a trap, it may be called for subsequent + /// returns to host or wasm code as the trap propagates to the root call. + #[cfg(feature = "async")] + pub fn call_hook_async(&mut self, hook: impl CallHookHandler + Send + Sync + 'static) { + self.inner.call_hook = Some(CallHookInner::Async(Box::new(hook))); + } + /// Configure a function that runs on calls and returns between WebAssembly /// and host code. /// @@ -616,12 +652,12 @@ impl Store { /// instead. /// /// After this function returns a trap, it may be called for subsequent returns - /// to host or wasm code as the trap propogates to the root call. + /// to host or wasm code as the trap propagates to the root call. pub fn call_hook( &mut self, hook: impl FnMut(&mut T, CallHook) -> Result<(), Trap> + Send + Sync + 'static, ) { - self.inner.call_hook = Some(Box::new(hook)); + self.inner.call_hook = Some(CallHookInner::Sync(Box::new(hook))); } /// Returns the [`Engine`] that this store is associated with. @@ -956,10 +992,19 @@ impl StoreInner { } pub fn call_hook(&mut self, s: CallHook) -> Result<(), Trap> { - if let Some(hook) = &mut self.call_hook { - hook(&mut self.data, s) - } else { - Ok(()) + match &mut self.call_hook { + Some(CallHookInner::Sync(hook)) => hook(&mut self.data, s), + + #[cfg(feature = "async")] + Some(CallHookInner::Async(handler)) => unsafe { + Ok(self + .inner + .async_cx() + .ok_or(Trap::new("couldn't grab async_cx for call hook"))? + .block_on(handler.handle_call_event(&mut self.data, s).as_mut())??) + }, + + None => Ok(()), } } } @@ -1143,14 +1188,29 @@ impl StoreOpaque { panic!("trampoline missing") } + /// Yields the async context, assuming that we are executing on a fiber and + /// that fiber is not in the process of dying. This function will return + /// None in the latter case (the fiber is dying), and panic if + /// `async_support()` is false. #[cfg(feature = "async")] #[inline] - pub fn async_cx(&self) -> AsyncCx { + pub fn async_cx(&self) -> Option { debug_assert!(self.async_support()); - AsyncCx { - current_suspend: self.async_state.current_suspend.get(), - current_poll_cx: self.async_state.current_poll_cx.get(), + + let poll_cx_box_ptr = self.async_state.current_poll_cx.get(); + if poll_cx_box_ptr.is_null() { + return None; } + + let poll_cx_inner_ptr = unsafe { *poll_cx_box_ptr }; + if poll_cx_inner_ptr.is_null() { + return None; + } + + Some(AsyncCx { + current_suspend: self.async_state.current_suspend.get(), + current_poll_cx: poll_cx_box_ptr, + }) } pub fn fuel_consumed(&self) -> Option { @@ -1214,7 +1274,11 @@ impl StoreOpaque { // to clean up this fiber. Do so by raising a trap which will // abort all wasm and get caught on the other side to clean // things up. - unsafe { self.async_cx().block_on(Pin::new_unchecked(&mut future)) } + unsafe { + self.async_cx() + .expect("attempted to pull async context during shutdown") + .block_on(Pin::new_unchecked(&mut future)) + } } fn add_fuel(&mut self, fuel: u64) -> Result<()> { @@ -1649,22 +1713,15 @@ unsafe impl wasmtime_runtime::Store for StoreInner { desired: usize, maximum: Option, ) -> Result { - // Need to borrow async_cx before the mut borrow of the limiter. - // self.async_cx() panicks when used with a non-async store, so - // wrap this in an option. - #[cfg(feature = "async")] - let async_cx = if self.async_support() { - Some(self.async_cx()) - } else { - None - }; match self.limiter { Some(ResourceLimiterInner::Sync(ref mut limiter)) => { Ok(limiter(&mut self.data).memory_growing(current, desired, maximum)) } #[cfg(feature = "async")] Some(ResourceLimiterInner::Async(ref mut limiter)) => unsafe { - Ok(async_cx + Ok(self + .inner + .async_cx() .expect("ResourceLimiterAsync requires async Store") .block_on( limiter(&mut self.data) @@ -1700,7 +1757,7 @@ unsafe impl wasmtime_runtime::Store for StoreInner { // wrap this in an option. #[cfg(feature = "async")] let async_cx = if self.async_support() { - Some(self.async_cx()) + Some(self.async_cx().unwrap()) } else { None }; diff --git a/tests/all/call_hook.rs b/tests/all/call_hook.rs index d1bc78e8a1..6e06379077 100644 --- a/tests/all/call_hook.rs +++ b/tests/all/call_hook.rs @@ -1,4 +1,7 @@ use anyhow::Error; +use std::future::Future; +use std::pin::Pin; +use std::task::{self, Poll}; use wasmtime::*; // Crate a synchronous Func, call it directly: @@ -551,6 +554,275 @@ fn trapping() -> Result<(), Error> { Ok(()) } +#[tokio::test] +async fn basic_async_hook() -> Result<(), Error> { + struct HandlerR; + + #[async_trait::async_trait] + impl CallHookHandler for HandlerR { + async fn handle_call_event( + &self, + obj: &mut State, + ch: CallHook, + ) -> Result<(), wasmtime::Trap> { + State::call_hook(obj, ch) + } + } + let mut config = Config::new(); + config.async_support(true); + let engine = Engine::new(&config)?; + let mut store = Store::new(&engine, State::default()); + store.call_hook_async(HandlerR {}); + + assert_eq!(store.data().calls_into_host, 0); + assert_eq!(store.data().returns_from_host, 0); + assert_eq!(store.data().calls_into_wasm, 0); + assert_eq!(store.data().returns_from_wasm, 0); + + let mut linker = Linker::new(&engine); + + linker.func_wrap( + "host", + "f", + |caller: Caller, a: i32, b: i64, c: f32, d: f64| { + // Calling this func will switch context into wasm, then back to host: + assert_eq!(caller.data().context, vec![Context::Wasm, Context::Host]); + + assert_eq!( + caller.data().calls_into_host, + caller.data().returns_from_host + 1 + ); + assert_eq!( + caller.data().calls_into_wasm, + caller.data().returns_from_wasm + 1 + ); + + assert_eq!(a, 1); + assert_eq!(b, 2); + assert_eq!(c, 3.0); + assert_eq!(d, 4.0); + }, + )?; + + let wat = r#" + (module + (import "host" "f" + (func $f (param i32) (param i64) (param f32) (param f64))) + (func (export "export") + (call $f (i32.const 1) (i64.const 2) (f32.const 3.0) (f64.const 4.0))) + ) + "#; + let module = Module::new(&engine, wat)?; + + let inst = linker.instantiate(&mut store, &module)?; + let export = inst + .get_export(&mut store, "export") + .expect("get export") + .into_func() + .expect("export is func"); + + export.call_async(&mut store, &[], &mut []).await?; + + // One switch from vm to host to call f, another in return from f. + assert_eq!(store.data().calls_into_host, 1); + assert_eq!(store.data().returns_from_host, 1); + assert_eq!(store.data().calls_into_wasm, 1); + assert_eq!(store.data().returns_from_wasm, 1); + + Ok(()) +} + +#[tokio::test] +async fn timeout_async_hook() -> Result<(), Error> { + struct HandlerR; + + #[async_trait::async_trait] + impl CallHookHandler for HandlerR { + async fn handle_call_event( + &self, + obj: &mut State, + ch: CallHook, + ) -> Result<(), wasmtime::Trap> { + if obj.calls_into_host > 200 { + return Err(wasmtime::Trap::new("timeout")); + } + + match ch { + CallHook::CallingHost => obj.calls_into_host += 1, + CallHook::CallingWasm => obj.calls_into_wasm += 1, + CallHook::ReturningFromHost => obj.returns_from_host += 1, + CallHook::ReturningFromWasm => obj.returns_from_wasm += 1, + } + + Ok(()) + } + } + + let mut config = Config::new(); + config.async_support(true); + let engine = Engine::new(&config)?; + let mut store = Store::new(&engine, State::default()); + store.call_hook_async(HandlerR {}); + + assert_eq!(store.data().calls_into_host, 0); + assert_eq!(store.data().returns_from_host, 0); + assert_eq!(store.data().calls_into_wasm, 0); + assert_eq!(store.data().returns_from_wasm, 0); + + let mut linker = Linker::new(&engine); + + linker.func_wrap( + "host", + "f", + |_caller: Caller, a: i32, b: i64, c: f32, d: f64| { + assert_eq!(a, 1); + assert_eq!(b, 2); + assert_eq!(c, 3.0); + assert_eq!(d, 4.0); + }, + )?; + + let wat = r#" + (module + (import "host" "f" + (func $f (param i32) (param i64) (param f32) (param f64))) + (func (export "export") + (loop $start + (call $f (i32.const 1) (i64.const 2) (f32.const 3.0) (f64.const 4.0)) + (br $start))) + ) + "#; + let module = Module::new(&engine, wat)?; + + let inst = linker.instantiate(&mut store, &module)?; + let export = inst + .get_typed_func::<(), (), _>(&mut store, "export") + .expect("export is func"); + + store.set_epoch_deadline(1); + store.epoch_deadline_async_yield_and_update(1); + assert!(export.call_async(&mut store, ()).await.is_err()); + + // One switch from vm to host to call f, another in return from f. + assert!(store.data().calls_into_host > 1); + assert!(store.data().returns_from_host > 1); + assert_eq!(store.data().calls_into_wasm, 1); + assert_eq!(store.data().returns_from_wasm, 0); + + Ok(()) +} + +#[tokio::test] +async fn drop_suspended_async_hook() -> Result<(), Error> { + struct Handler; + + #[async_trait::async_trait] + impl CallHookHandler for Handler { + async fn handle_call_event( + &self, + state: &mut u32, + _ch: CallHook, + ) -> Result<(), wasmtime::Trap> { + assert_eq!(*state, 0); + *state += 1; + let _dec = Decrement(state); + + // Simulate some sort of event which takes a number of yields + for _ in 0..500 { + tokio::task::yield_now().await; + } + Ok(()) + } + } + + let mut config = Config::new(); + config.async_support(true); + let engine = Engine::new(&config)?; + let mut store = Store::new(&engine, 0); + store.call_hook_async(Handler); + + let mut linker = Linker::new(&engine); + + // Simulate a host function that has lots of yields with an infinite loop. + linker.func_wrap0_async("host", "f", |mut cx| { + Box::new(async move { + let state = cx.data_mut(); + assert_eq!(*state, 0); + *state += 1; + let _dec = Decrement(state); + loop { + tokio::task::yield_now().await; + } + }) + })?; + + let wat = r#" + (module + (import "host" "f" (func $f)) + (func (export "") call $f) + ) + "#; + let module = Module::new(&engine, wat)?; + + let inst = linker.instantiate(&mut store, &module)?; + assert_eq!(*store.data(), 0); + let export = inst + .get_typed_func::<(), (), _>(&mut store, "") + .expect("export is func"); + + // First test that if we drop in the middle of an async hook that everything + // is alright. + PollNTimes { + future: Box::pin(export.call_async(&mut store, ())), + times: 200, + } + .await; + assert_eq!(*store.data(), 0); // double-check user dtors ran + + // Next test that if we drop while in a host async function that everything + // is also alright. + PollNTimes { + future: Box::pin(export.call_async(&mut store, ())), + times: 1_000, + } + .await; + assert_eq!(*store.data(), 0); // double-check user dtors ran + + return Ok(()); + + // A helper struct to poll an inner `future` N `times` and then resolve. + // This is used above to test that when futures are dropped while they're + // pending everything works and is cleaned up on the Wasmtime side of + // things. + struct PollNTimes { + future: F, + times: u32, + } + + impl Future for PollNTimes { + type Output = (); + fn poll(mut self: Pin<&mut Self>, task: &mut task::Context<'_>) -> Poll<()> { + for _ in 0..self.times { + match Pin::new(&mut self.future).poll(task) { + Poll::Ready(_) => panic!("future should not be ready"), + Poll::Pending => {} + } + } + + Poll::Ready(()) + } + } + + // helper struct to decrement a counter on drop + struct Decrement<'a>(&'a mut u32); + + impl Drop for Decrement<'_> { + fn drop(&mut self) { + *self.0 -= 1; + } + } +} + #[derive(Debug, PartialEq, Eq)] enum Context { Host,