Add support for async call hooks (#3876)

* Instead of simply panicking, return an error when we attempt to resume on a dying fiber.

This situation should never occur in the existing code base, but can be
triggered if support for running outside async code in a call hook.

* Shift `async_cx()` to return an `Option`, reflecting if the fiber is dying.

This should never happen in the existing code base, but is a nice
forward-looking guard. The current implementations simply lift the
trap that would eventually be produced by such an operation into
a `Trap` (or similar) at the invocation of `async_cx()`.

* Add support for using `async` call hooks.

This retains the ability to do non-async hooks. Hooks end up being
implemented as an async trait with a handler call, to get around some
issues passing around async closures. This change requires some of
the prior changes to handle picking up blocked tasks during fiber
shutdown, to avoid some panics during timeouts and other such events.

* More fully specify a doc link, to avoid a doc-building error.

* Revert the use of catchable traps on cancellation of a fiber; turn them into expect()/unwrap().

The justification for this revert is that (a) these events shouldn't
happen, and (b) they wouldn't be catchable by wasm anyways.

* Replace a duplicated check in `async` hook evaluation with a single check.

This also moves the checks inside of their respective Async variants,
meaning that if you're using an async-enabled version of wasmtime but
using the synchronous versions of the callbacks, you won't pay any
penalty for validating the async context.

* Use `match &mut ...` insead of `ref mut`.

* Add some documentation on why/when `async_cx` can return None.

* Add two simple test cases for async call hooks.

* Fix async_cx() to check both the box and the value for current_poll_cx.

In the prior version, we only checked that the box had not been cleared,
but had not ensured that there was an actual context for us to use. This
updates the check to validate both, returning None if the inner context
is missing. This allows us to skip a validation check inside `block_on`,
since all callers will have run through the `async_cx` check prior to
arrival.

* Tweak the timeout test to address PR suggestions.

* Add a test about dropping async hooks while suspended

Should help exercise that the check for `None` is properly handled in a
few more locations.

Co-authored-by: Alex Crichton <alex@alexcrichton.com>
This commit is contained in:
Adam Wick
2022-03-23 08:43:34 -07:00
committed by GitHub
parent 923faaff4f
commit 6a60e8363f
5 changed files with 369 additions and 27 deletions

View File

@@ -272,8 +272,9 @@ macro_rules! generate_wrap_async_func {
{ {
assert!(store.as_context().async_support(), concat!("cannot use `wrap", $num, "_async` without enabling async support on the config")); assert!(store.as_context().async_support(), concat!("cannot use `wrap", $num, "_async` without enabling async support on the config"));
Func::wrap(store, move |mut caller: Caller<'_, T>, $($args: $args),*| { Func::wrap(store, move |mut caller: Caller<'_, T>, $($args: $args),*| {
let async_cx = caller.store.as_context_mut().0.async_cx(); let async_cx = caller.store.as_context_mut().0.async_cx().expect("Attempt to start async function on dying fiber");
let mut future = Pin::from(func(caller, $($args),*)); let mut future = Pin::from(func(caller, $($args),*));
match unsafe { async_cx.block_on(future.as_mut()) } { match unsafe { async_cx.block_on(future.as_mut()) } {
Ok(ret) => ret.into_fallible(), Ok(ret) => ret.into_fallible(),
Err(e) => R::fallible_from_trap(e), Err(e) => R::fallible_from_trap(e),
@@ -439,7 +440,12 @@ impl Func {
"cannot use `new_async` without enabling async support in the config" "cannot use `new_async` without enabling async support in the config"
); );
Func::new(store, ty, move |mut caller, params, results| { Func::new(store, ty, move |mut caller, params, results| {
let async_cx = caller.store.as_context_mut().0.async_cx(); let async_cx = caller
.store
.as_context_mut()
.0
.async_cx()
.expect("Attempt to spawn new action on dying fiber");
let mut future = Pin::from(func(caller, params, results)); let mut future = Pin::from(func(caller, params, results));
match unsafe { async_cx.block_on(future.as_mut()) } { match unsafe { async_cx.block_on(future.as_mut()) } {
Ok(Ok(())) => Ok(()), Ok(Ok(())) => Ok(()),

View File

@@ -423,6 +423,8 @@ pub use crate::linker::*;
pub use crate::memory::*; pub use crate::memory::*;
pub use crate::module::{FrameInfo, FrameSymbol, Module}; pub use crate::module::{FrameInfo, FrameSymbol, Module};
pub use crate::r#ref::ExternRef; pub use crate::r#ref::ExternRef;
#[cfg(feature = "async")]
pub use crate::store::CallHookHandler;
pub use crate::store::{AsContext, AsContextMut, CallHook, Store, StoreContext, StoreContextMut}; pub use crate::store::{AsContext, AsContextMut, CallHook, Store, StoreContext, StoreContextMut};
pub use crate::trap::*; pub use crate::trap::*;
pub use crate::types::*; pub use crate::types::*;

View File

@@ -146,7 +146,7 @@ macro_rules! generate_wrap_async_func {
), ),
); );
self.func_wrap(module, name, move |mut caller: Caller<'_, T>, $($args: $args),*| { self.func_wrap(module, name, move |mut caller: Caller<'_, T>, $($args: $args),*| {
let async_cx = caller.store.as_context_mut().0.async_cx(); let async_cx = caller.store.as_context_mut().0.async_cx().expect("Attempt to start async function on dying fiber");
let mut future = Pin::from(func(caller, $($args),*)); let mut future = Pin::from(func(caller, $($args),*));
match unsafe { async_cx.block_on(future.as_mut()) } { match unsafe { async_cx.block_on(future.as_mut()) } {
Ok(ret) => ret.into_fallible(), Ok(ret) => ret.into_fallible(),
@@ -360,7 +360,12 @@ impl<T> Linker<T> {
"cannot use `func_new_async` without enabling async support in the config" "cannot use `func_new_async` without enabling async support in the config"
); );
self.func_new(module, name, ty, move |mut caller, params, results| { self.func_new(module, name, ty, move |mut caller, params, results| {
let async_cx = caller.store.as_context_mut().0.async_cx(); let async_cx = caller
.store
.as_context_mut()
.0
.async_cx()
.expect("Attempt to spawn new function on dying fiber");
let mut future = Pin::from(func(caller, params, results)); let mut future = Pin::from(func(caller, params, results));
match unsafe { async_cx.block_on(future.as_mut()) } { match unsafe { async_cx.block_on(future.as_mut()) } {
Ok(Ok(())) => Ok(()), Ok(Ok(())) => Ok(()),

View File

@@ -200,7 +200,7 @@ pub struct StoreInner<T> {
inner: StoreOpaque, inner: StoreOpaque,
limiter: Option<ResourceLimiterInner<T>>, limiter: Option<ResourceLimiterInner<T>>,
call_hook: Option<Box<dyn FnMut(&mut T, CallHook) -> Result<(), crate::Trap> + Send + Sync>>, call_hook: Option<CallHookInner<T>>,
// for comments about `ManuallyDrop`, see `Store::into_data` // for comments about `ManuallyDrop`, see `Store::into_data`
data: ManuallyDrop<T>, data: ManuallyDrop<T>,
} }
@@ -211,6 +211,21 @@ enum ResourceLimiterInner<T> {
Async(Box<dyn FnMut(&mut T) -> &mut (dyn crate::ResourceLimiterAsync) + Send + Sync>), Async(Box<dyn FnMut(&mut T) -> &mut (dyn crate::ResourceLimiterAsync) + Send + Sync>),
} }
/// An object that can take callbacks when the runtime enters or exits hostcalls.
#[cfg(feature = "async")]
#[async_trait::async_trait]
pub trait CallHookHandler<T>: Send {
/// A callback to run when wasmtime is about to enter a host call, or when about to
/// exit the hostcall.
async fn handle_call_event(&self, t: &mut T, ch: CallHook) -> Result<(), crate::Trap>;
}
enum CallHookInner<T> {
Sync(Box<dyn FnMut(&mut T, CallHook) -> Result<(), crate::Trap> + Send + Sync>),
#[cfg(feature = "async")]
Async(Box<dyn CallHookHandler<T> + Send + Sync>),
}
// Forward methods on `StoreOpaque` to also being on `StoreInner<T>` // Forward methods on `StoreOpaque` to also being on `StoreInner<T>`
impl<T> Deref for StoreInner<T> { impl<T> Deref for StoreInner<T> {
type Target = StoreOpaque; type Target = StoreOpaque;
@@ -603,6 +618,27 @@ impl<T> Store<T> {
inner.limiter = Some(ResourceLimiterInner::Async(Box::new(limiter))); inner.limiter = Some(ResourceLimiterInner::Async(Box::new(limiter)));
} }
#[cfg_attr(nightlydoc, doc(cfg(feature = "async")))]
/// Configures an async function that runs on calls and returns between
/// WebAssembly and host code. For the non-async equivalent of this method,
/// see [`Store::call_hook`].
///
/// The function is passed a [`CallHook`] argument, which indicates which
/// state transition the VM is making.
///
/// This function's future may return a [`Trap`]. If a trap is returned
/// when an import was called, it is immediately raised as-if the host
/// import had returned the trap. If a trap is returned after wasm returns
/// to the host then the wasm function's result is ignored and this trap is
/// returned instead.
///
/// After this function returns a trap, it may be called for subsequent
/// returns to host or wasm code as the trap propagates to the root call.
#[cfg(feature = "async")]
pub fn call_hook_async(&mut self, hook: impl CallHookHandler<T> + Send + Sync + 'static) {
self.inner.call_hook = Some(CallHookInner::Async(Box::new(hook)));
}
/// Configure a function that runs on calls and returns between WebAssembly /// Configure a function that runs on calls and returns between WebAssembly
/// and host code. /// and host code.
/// ///
@@ -616,12 +652,12 @@ impl<T> Store<T> {
/// instead. /// instead.
/// ///
/// After this function returns a trap, it may be called for subsequent returns /// After this function returns a trap, it may be called for subsequent returns
/// to host or wasm code as the trap propogates to the root call. /// to host or wasm code as the trap propagates to the root call.
pub fn call_hook( pub fn call_hook(
&mut self, &mut self,
hook: impl FnMut(&mut T, CallHook) -> Result<(), Trap> + Send + Sync + 'static, hook: impl FnMut(&mut T, CallHook) -> Result<(), Trap> + Send + Sync + 'static,
) { ) {
self.inner.call_hook = Some(Box::new(hook)); self.inner.call_hook = Some(CallHookInner::Sync(Box::new(hook)));
} }
/// Returns the [`Engine`] that this store is associated with. /// Returns the [`Engine`] that this store is associated with.
@@ -956,10 +992,19 @@ impl<T> StoreInner<T> {
} }
pub fn call_hook(&mut self, s: CallHook) -> Result<(), Trap> { pub fn call_hook(&mut self, s: CallHook) -> Result<(), Trap> {
if let Some(hook) = &mut self.call_hook { match &mut self.call_hook {
hook(&mut self.data, s) Some(CallHookInner::Sync(hook)) => hook(&mut self.data, s),
} else {
Ok(()) #[cfg(feature = "async")]
Some(CallHookInner::Async(handler)) => unsafe {
Ok(self
.inner
.async_cx()
.ok_or(Trap::new("couldn't grab async_cx for call hook"))?
.block_on(handler.handle_call_event(&mut self.data, s).as_mut())??)
},
None => Ok(()),
} }
} }
} }
@@ -1143,14 +1188,29 @@ impl StoreOpaque {
panic!("trampoline missing") panic!("trampoline missing")
} }
/// Yields the async context, assuming that we are executing on a fiber and
/// that fiber is not in the process of dying. This function will return
/// None in the latter case (the fiber is dying), and panic if
/// `async_support()` is false.
#[cfg(feature = "async")] #[cfg(feature = "async")]
#[inline] #[inline]
pub fn async_cx(&self) -> AsyncCx { pub fn async_cx(&self) -> Option<AsyncCx> {
debug_assert!(self.async_support()); debug_assert!(self.async_support());
AsyncCx {
current_suspend: self.async_state.current_suspend.get(), let poll_cx_box_ptr = self.async_state.current_poll_cx.get();
current_poll_cx: self.async_state.current_poll_cx.get(), if poll_cx_box_ptr.is_null() {
return None;
} }
let poll_cx_inner_ptr = unsafe { *poll_cx_box_ptr };
if poll_cx_inner_ptr.is_null() {
return None;
}
Some(AsyncCx {
current_suspend: self.async_state.current_suspend.get(),
current_poll_cx: poll_cx_box_ptr,
})
} }
pub fn fuel_consumed(&self) -> Option<u64> { pub fn fuel_consumed(&self) -> Option<u64> {
@@ -1214,7 +1274,11 @@ impl StoreOpaque {
// to clean up this fiber. Do so by raising a trap which will // to clean up this fiber. Do so by raising a trap which will
// abort all wasm and get caught on the other side to clean // abort all wasm and get caught on the other side to clean
// things up. // things up.
unsafe { self.async_cx().block_on(Pin::new_unchecked(&mut future)) } unsafe {
self.async_cx()
.expect("attempted to pull async context during shutdown")
.block_on(Pin::new_unchecked(&mut future))
}
} }
fn add_fuel(&mut self, fuel: u64) -> Result<()> { fn add_fuel(&mut self, fuel: u64) -> Result<()> {
@@ -1649,22 +1713,15 @@ unsafe impl<T> wasmtime_runtime::Store for StoreInner<T> {
desired: usize, desired: usize,
maximum: Option<usize>, maximum: Option<usize>,
) -> Result<bool, anyhow::Error> { ) -> Result<bool, anyhow::Error> {
// Need to borrow async_cx before the mut borrow of the limiter.
// self.async_cx() panicks when used with a non-async store, so
// wrap this in an option.
#[cfg(feature = "async")]
let async_cx = if self.async_support() {
Some(self.async_cx())
} else {
None
};
match self.limiter { match self.limiter {
Some(ResourceLimiterInner::Sync(ref mut limiter)) => { Some(ResourceLimiterInner::Sync(ref mut limiter)) => {
Ok(limiter(&mut self.data).memory_growing(current, desired, maximum)) Ok(limiter(&mut self.data).memory_growing(current, desired, maximum))
} }
#[cfg(feature = "async")] #[cfg(feature = "async")]
Some(ResourceLimiterInner::Async(ref mut limiter)) => unsafe { Some(ResourceLimiterInner::Async(ref mut limiter)) => unsafe {
Ok(async_cx Ok(self
.inner
.async_cx()
.expect("ResourceLimiterAsync requires async Store") .expect("ResourceLimiterAsync requires async Store")
.block_on( .block_on(
limiter(&mut self.data) limiter(&mut self.data)
@@ -1700,7 +1757,7 @@ unsafe impl<T> wasmtime_runtime::Store for StoreInner<T> {
// wrap this in an option. // wrap this in an option.
#[cfg(feature = "async")] #[cfg(feature = "async")]
let async_cx = if self.async_support() { let async_cx = if self.async_support() {
Some(self.async_cx()) Some(self.async_cx().unwrap())
} else { } else {
None None
}; };

View File

@@ -1,4 +1,7 @@
use anyhow::Error; use anyhow::Error;
use std::future::Future;
use std::pin::Pin;
use std::task::{self, Poll};
use wasmtime::*; use wasmtime::*;
// Crate a synchronous Func, call it directly: // Crate a synchronous Func, call it directly:
@@ -551,6 +554,275 @@ fn trapping() -> Result<(), Error> {
Ok(()) Ok(())
} }
#[tokio::test]
async fn basic_async_hook() -> Result<(), Error> {
struct HandlerR;
#[async_trait::async_trait]
impl CallHookHandler<State> for HandlerR {
async fn handle_call_event(
&self,
obj: &mut State,
ch: CallHook,
) -> Result<(), wasmtime::Trap> {
State::call_hook(obj, ch)
}
}
let mut config = Config::new();
config.async_support(true);
let engine = Engine::new(&config)?;
let mut store = Store::new(&engine, State::default());
store.call_hook_async(HandlerR {});
assert_eq!(store.data().calls_into_host, 0);
assert_eq!(store.data().returns_from_host, 0);
assert_eq!(store.data().calls_into_wasm, 0);
assert_eq!(store.data().returns_from_wasm, 0);
let mut linker = Linker::new(&engine);
linker.func_wrap(
"host",
"f",
|caller: Caller<State>, a: i32, b: i64, c: f32, d: f64| {
// Calling this func will switch context into wasm, then back to host:
assert_eq!(caller.data().context, vec![Context::Wasm, Context::Host]);
assert_eq!(
caller.data().calls_into_host,
caller.data().returns_from_host + 1
);
assert_eq!(
caller.data().calls_into_wasm,
caller.data().returns_from_wasm + 1
);
assert_eq!(a, 1);
assert_eq!(b, 2);
assert_eq!(c, 3.0);
assert_eq!(d, 4.0);
},
)?;
let wat = r#"
(module
(import "host" "f"
(func $f (param i32) (param i64) (param f32) (param f64)))
(func (export "export")
(call $f (i32.const 1) (i64.const 2) (f32.const 3.0) (f64.const 4.0)))
)
"#;
let module = Module::new(&engine, wat)?;
let inst = linker.instantiate(&mut store, &module)?;
let export = inst
.get_export(&mut store, "export")
.expect("get export")
.into_func()
.expect("export is func");
export.call_async(&mut store, &[], &mut []).await?;
// One switch from vm to host to call f, another in return from f.
assert_eq!(store.data().calls_into_host, 1);
assert_eq!(store.data().returns_from_host, 1);
assert_eq!(store.data().calls_into_wasm, 1);
assert_eq!(store.data().returns_from_wasm, 1);
Ok(())
}
#[tokio::test]
async fn timeout_async_hook() -> Result<(), Error> {
struct HandlerR;
#[async_trait::async_trait]
impl CallHookHandler<State> for HandlerR {
async fn handle_call_event(
&self,
obj: &mut State,
ch: CallHook,
) -> Result<(), wasmtime::Trap> {
if obj.calls_into_host > 200 {
return Err(wasmtime::Trap::new("timeout"));
}
match ch {
CallHook::CallingHost => obj.calls_into_host += 1,
CallHook::CallingWasm => obj.calls_into_wasm += 1,
CallHook::ReturningFromHost => obj.returns_from_host += 1,
CallHook::ReturningFromWasm => obj.returns_from_wasm += 1,
}
Ok(())
}
}
let mut config = Config::new();
config.async_support(true);
let engine = Engine::new(&config)?;
let mut store = Store::new(&engine, State::default());
store.call_hook_async(HandlerR {});
assert_eq!(store.data().calls_into_host, 0);
assert_eq!(store.data().returns_from_host, 0);
assert_eq!(store.data().calls_into_wasm, 0);
assert_eq!(store.data().returns_from_wasm, 0);
let mut linker = Linker::new(&engine);
linker.func_wrap(
"host",
"f",
|_caller: Caller<State>, a: i32, b: i64, c: f32, d: f64| {
assert_eq!(a, 1);
assert_eq!(b, 2);
assert_eq!(c, 3.0);
assert_eq!(d, 4.0);
},
)?;
let wat = r#"
(module
(import "host" "f"
(func $f (param i32) (param i64) (param f32) (param f64)))
(func (export "export")
(loop $start
(call $f (i32.const 1) (i64.const 2) (f32.const 3.0) (f64.const 4.0))
(br $start)))
)
"#;
let module = Module::new(&engine, wat)?;
let inst = linker.instantiate(&mut store, &module)?;
let export = inst
.get_typed_func::<(), (), _>(&mut store, "export")
.expect("export is func");
store.set_epoch_deadline(1);
store.epoch_deadline_async_yield_and_update(1);
assert!(export.call_async(&mut store, ()).await.is_err());
// One switch from vm to host to call f, another in return from f.
assert!(store.data().calls_into_host > 1);
assert!(store.data().returns_from_host > 1);
assert_eq!(store.data().calls_into_wasm, 1);
assert_eq!(store.data().returns_from_wasm, 0);
Ok(())
}
#[tokio::test]
async fn drop_suspended_async_hook() -> Result<(), Error> {
struct Handler;
#[async_trait::async_trait]
impl CallHookHandler<u32> for Handler {
async fn handle_call_event(
&self,
state: &mut u32,
_ch: CallHook,
) -> Result<(), wasmtime::Trap> {
assert_eq!(*state, 0);
*state += 1;
let _dec = Decrement(state);
// Simulate some sort of event which takes a number of yields
for _ in 0..500 {
tokio::task::yield_now().await;
}
Ok(())
}
}
let mut config = Config::new();
config.async_support(true);
let engine = Engine::new(&config)?;
let mut store = Store::new(&engine, 0);
store.call_hook_async(Handler);
let mut linker = Linker::new(&engine);
// Simulate a host function that has lots of yields with an infinite loop.
linker.func_wrap0_async("host", "f", |mut cx| {
Box::new(async move {
let state = cx.data_mut();
assert_eq!(*state, 0);
*state += 1;
let _dec = Decrement(state);
loop {
tokio::task::yield_now().await;
}
})
})?;
let wat = r#"
(module
(import "host" "f" (func $f))
(func (export "") call $f)
)
"#;
let module = Module::new(&engine, wat)?;
let inst = linker.instantiate(&mut store, &module)?;
assert_eq!(*store.data(), 0);
let export = inst
.get_typed_func::<(), (), _>(&mut store, "")
.expect("export is func");
// First test that if we drop in the middle of an async hook that everything
// is alright.
PollNTimes {
future: Box::pin(export.call_async(&mut store, ())),
times: 200,
}
.await;
assert_eq!(*store.data(), 0); // double-check user dtors ran
// Next test that if we drop while in a host async function that everything
// is also alright.
PollNTimes {
future: Box::pin(export.call_async(&mut store, ())),
times: 1_000,
}
.await;
assert_eq!(*store.data(), 0); // double-check user dtors ran
return Ok(());
// A helper struct to poll an inner `future` N `times` and then resolve.
// This is used above to test that when futures are dropped while they're
// pending everything works and is cleaned up on the Wasmtime side of
// things.
struct PollNTimes<F> {
future: F,
times: u32,
}
impl<F: Future + Unpin> Future for PollNTimes<F> {
type Output = ();
fn poll(mut self: Pin<&mut Self>, task: &mut task::Context<'_>) -> Poll<()> {
for _ in 0..self.times {
match Pin::new(&mut self.future).poll(task) {
Poll::Ready(_) => panic!("future should not be ready"),
Poll::Pending => {}
}
}
Poll::Ready(())
}
}
// helper struct to decrement a counter on drop
struct Decrement<'a>(&'a mut u32);
impl Drop for Decrement<'_> {
fn drop(&mut self) {
*self.0 -= 1;
}
}
}
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
enum Context { enum Context {
Host, Host,