Implement the post-return attribute (#4297)

This commit implements the `post-return` feature of the canonical ABI in
the component model. This attribute is an optionally-specified function
which is executed after the return value has been processed by the
caller, giving the callee a chance to clean up the return value. This
enables, for example, returning an allocated string which the host then
knows how to clean up, preventing memory leaks in the original module.

The API exposed in this PR changes the prior `TypedFunc::call` API in
behavior but not in its signature. Previously the `TypedFunc::call`
method would set the `may_enter` flag on the way out, but now that
operation is deferred until a new `TypedFunc::post_return` method is
called. This means that once a method on an instance is invoked then
nothing else can be done on the instance until the `post_return` method
is called. Note that the method must be called irrespective of whether
the `post-return` canonical ABI option was specified or not. Internally
wasm will be invoked if necessary.

This is a pretty wonky and unergonomic API to work with. For now I
couldn't think of a better alternative that improved on the ergonomics.
On the theory that the raw Wasmtime bindings for a component may not be
used all that heavily (instead `wit-bindgen` would largely be used) I'm
hoping that this isn't too much of an issue in the future.

cc #4185
This commit is contained in:
Alex Crichton
2022-06-23 14:36:21 -05:00
committed by GitHub
parent fa36e86f2c
commit 3339dd1f01
12 changed files with 787 additions and 112 deletions

View File

@@ -1,6 +1,6 @@
use crate::component::instance::{Instance, InstanceData};
use crate::store::{StoreOpaque, Stored};
use crate::AsContext;
use crate::{AsContext, ValRaw};
use anyhow::{Context, Result};
use std::mem::MaybeUninit;
use std::ptr::NonNull;
@@ -82,6 +82,8 @@ pub struct FuncData {
types: Arc<ComponentTypes>,
options: Options,
instance: Instance,
post_return: Option<(ExportFunction, VMTrampoline)>,
post_return_arg: Option<ValRaw>,
}
impl Func {
@@ -102,6 +104,11 @@ impl Func {
.memory
.map(|i| NonNull::new(data.instance().runtime_memory(i)).unwrap());
let realloc = options.realloc.map(|i| data.instance().runtime_realloc(i));
let post_return = options.post_return.map(|i| {
let anyfunc = data.instance().runtime_post_return(i);
let trampoline = store.lookup_trampoline(unsafe { anyfunc.as_ref() });
(ExportFunction { anyfunc }, trampoline)
});
let options = unsafe { Options::new(store.id(), memory, realloc, options.string_encoding) };
Func(store.store_data_mut().insert(FuncData {
trampoline,
@@ -110,6 +117,8 @@ impl Func {
ty,
types: data.component_types().clone(),
instance: *instance,
post_return,
post_return_arg: None,
}))
}

View File

@@ -8,7 +8,9 @@ use std::panic::{self, AssertUnwindSafe};
use std::ptr::NonNull;
use std::sync::Arc;
use wasmtime_environ::component::{ComponentTypes, StringEncoding, TypeFuncIndex};
use wasmtime_runtime::component::{VMComponentContext, VMLowering, VMLoweringCallee};
use wasmtime_runtime::component::{
VMComponentContext, VMComponentFlags, VMLowering, VMLoweringCallee,
};
use wasmtime_runtime::{VMCallerCheckedAnyfunc, VMMemoryDefinition, VMOpaqueContext};
/// Trait representing host-defined functions that can be imported into a wasm
@@ -134,8 +136,7 @@ where
let cx = VMComponentContext::from_opaque(cx);
let instance = (*cx).instance();
let may_leave = (*instance).may_leave();
let may_enter = (*instance).may_enter();
let flags = (*instance).flags();
let mut cx = StoreContextMut::from_raw((*instance).store());
let options = Options::new(
@@ -148,13 +149,13 @@ where
// Perform a dynamic check that this instance can indeed be left. Exiting
// the component is disallowed, for example, when the `realloc` function
// calls a canonical import.
if !*may_leave {
if !(*flags).may_leave() {
bail!("cannot leave component instance");
}
// While we're lifting and lowering this instance cannot be reentered, so
// unset the flag here. This is also reset back to `true` on exit.
let _reset_may_enter = unset_and_reset_on_drop(may_enter);
let _reset_may_enter = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_enter);
// There's a 2x2 matrix of whether parameters and results are stored on the
// stack or on the heap. Each of the 4 branches here have a different
@@ -172,7 +173,7 @@ where
let storage = cast_storage::<ReturnStack<Params::Lower, Return::Lower>>(storage);
let params = Params::lift(cx.0, &options, &storage.assume_init_ref().args)?;
let ret = closure(cx.as_context_mut(), params)?;
reset_may_leave = unset_and_reset_on_drop(may_leave);
reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave);
ret.lower(&mut cx, &options, map_maybe_uninit!(storage.ret))?;
} else {
let storage = cast_storage::<ReturnPointer<Params::Lower>>(storage).assume_init_ref();
@@ -180,7 +181,7 @@ where
let ret = closure(cx.as_context_mut(), params)?;
let mut memory = MemoryMut::new(cx.as_context_mut(), &options);
let ptr = validate_inbounds::<Return>(memory.as_slice_mut(), &storage.retptr)?;
reset_may_leave = unset_and_reset_on_drop(may_leave);
reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave);
ret.store(&mut memory, ptr)?;
}
} else {
@@ -191,7 +192,7 @@ where
validate_inbounds::<Params>(memory.as_slice(), &storage.assume_init_ref().args)?;
let params = Params::load(&memory, &memory.as_slice()[ptr..][..Params::size()])?;
let ret = closure(cx.as_context_mut(), params)?;
reset_may_leave = unset_and_reset_on_drop(may_leave);
reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave);
ret.lower(&mut cx, &options, map_maybe_uninit!(storage.ret))?;
} else {
let storage = cast_storage::<ReturnPointer<ValRaw>>(storage).assume_init_ref();
@@ -200,27 +201,28 @@ where
let ret = closure(cx.as_context_mut(), params)?;
let mut memory = MemoryMut::new(cx.as_context_mut(), &options);
let ptr = validate_inbounds::<Return>(memory.as_slice_mut(), &storage.retptr)?;
reset_may_leave = unset_and_reset_on_drop(may_leave);
reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave);
ret.store(&mut memory, ptr)?;
}
}
// TODO: need to call `post-return` before this `drop`
drop(reset_may_leave);
return Ok(());
unsafe fn unset_and_reset_on_drop(slot: *mut bool) -> impl Drop {
debug_assert!(*slot);
*slot = false;
return Reset(slot);
unsafe fn unset_and_reset_on_drop(
slot: *mut VMComponentFlags,
set: fn(&mut VMComponentFlags, bool),
) -> impl Drop {
set(&mut *slot, false);
return Reset(slot, set);
struct Reset(*mut bool);
struct Reset(*mut VMComponentFlags, fn(&mut VMComponentFlags, bool));
impl Drop for Reset {
fn drop(&mut self) {
unsafe {
(*self.0) = true;
(self.1)(&mut *self.0, true);
}
}
}

View File

@@ -90,6 +90,31 @@ where
/// the `store` provided. The `params` are copied into WebAssembly memory
/// as appropriate and a core wasm function is invoked.
///
/// # Post-return
///
/// In the component model each function can have a "post return" specified
/// which allows cleaning up the arguments returned to the host. For example
/// if WebAssembly returns a string to the host then it might be a uniquely
/// allocated string which, after the host finishes processing it, needs to
/// be deallocated in the wasm instance's own linear memory to prevent
/// memory leaks in wasm itself. The `post-return` canonical ABI option is
/// used to configure this.
///
/// To accommodate this feature of the component model after invoking a
/// function via [`TypedFunc::call`] you must next invoke
/// [`TypedFunc::post_return`]. Note that the return value of the function
/// should be processed between these two function calls. The return value
/// continues to be usable from an embedder's perspective after
/// `post_return` is called, but after `post_return` is invoked it may no
/// longer retain the same value that the wasm module originally returned.
///
/// Also note that [`TypedFunc::post_return`] must be invoked irrespective
/// of whether the canonical ABI option `post-return` was configured or not.
/// This means that embedders must unconditionally call
/// [`TypedFunc::post_return`] when a function returns. If this function
/// call returns an error, however, then [`TypedFunc::post_return`] is not
/// required.
///
/// # Errors
///
/// This function can return an error for a number of reasons:
@@ -99,6 +124,10 @@ where
/// * If the wasm provides bad allocation pointers when copying arguments
/// into memory.
/// * If the wasm returns a value which violates the canonical ABI.
/// * If this function's instance cannot be entered, for example if the
/// instance is currently calling a host function.
/// * If a previous function call occurred and the corresponding
/// `post_return` hasn't been invoked yet.
///
/// In general there are many ways that things could go wrong when copying
/// types in and out of a wasm module with the canonical ABI, and certain
@@ -300,18 +329,17 @@ where
assert!(mem::align_of_val(map_maybe_uninit!(space.ret)) == val_align);
let instance = store.0[instance.0].as_ref().unwrap().instance();
let may_enter = instance.may_enter();
let may_leave = instance.may_leave();
let flags = instance.flags();
unsafe {
if !*may_enter {
if !(*flags).may_enter() {
bail!("cannot reenter component instance");
}
debug_assert!(*may_leave);
debug_assert!((*flags).may_leave());
*may_leave = false;
(*flags).set_may_leave(false);
let result = lower(store, &options, params, map_maybe_uninit!(space.params));
*may_leave = true;
(*flags).set_may_leave(true);
result?;
// This is unsafe as we are providing the guarantee that all the
@@ -336,18 +364,139 @@ where
// `[ValRaw]`, and additionally they should have the correct types
// for the function we just called (which filled in the return
// values).
*may_enter = false;
let result = lift(
store.0,
&options,
map_maybe_uninit!(space.ret).assume_init_ref(),
);
let ret = map_maybe_uninit!(space.ret).assume_init_ref();
// TODO: this technically needs to happen only after the
// `post-return` is called.
*may_enter = true;
return result;
// Lift the result into the host while managing post-return state
// here as well.
//
// Initially the `may_enter` flag is set to `false` for this
// component instance and additionally we set a flag indicating that
// a post-return is required. This isn't specified by the component
// model itself but is used for our implementation of the API of
// `post_return` as a separate function call.
//
// FIXME(WebAssembly/component-model#55) it's not really clear what
// the semantics should be in the face of a lift error/trap. For now
// the flags are reset so the instance can continue to be reused in
// tests but that probably isn't what's desired.
//
// Otherwise though after a successful lift the return value of the
// function, which is currently required to be 0 or 1 values
// according to the canonical ABI, is saved within the `Store`'s
// `FuncData`. This'll later get used in post-return.
(*flags).set_may_enter(false);
(*flags).set_needs_post_return(true);
match lift(store.0, &options, ret) {
Ok(val) => {
let ret_slice = cast_storage(ret);
let data = &mut store.0[self.func.0];
assert!(data.post_return_arg.is_none());
match ret_slice.len() {
0 => data.post_return_arg = Some(ValRaw::i32(0)),
1 => data.post_return_arg = Some(ret_slice[0]),
_ => unreachable!(),
}
return Ok(val);
}
Err(err) => {
(*flags).set_may_enter(true);
(*flags).set_needs_post_return(false);
return Err(err);
}
}
}
/// Reinterprets the memory behind `storage` as a slice of `ValRaw` values.
///
/// The returned slice starts at the address of `storage` and contains
/// `size_of_val(storage) / size_of::<ValRaw>()` elements.
///
/// # Safety
///
/// Only the size and alignment of `T` are checked here. The caller must
/// guarantee that the bytes of `storage` are valid to read as `ValRaw`
/// values for as long as the returned slice is alive.
///
/// # Panics
///
/// Panics if the size of `storage` is not a multiple of the size of
/// `ValRaw`, or if its alignment differs from that of `ValRaw`.
unsafe fn cast_storage<T>(storage: &T) -> &[ValRaw] {
    // Layout sanity checks: `T` must be laid out as a whole number of
    // `ValRaw` slots with exactly matching alignment.
    assert!(std::mem::size_of_val(storage) % std::mem::size_of::<ValRaw>() == 0);
    assert!(std::mem::align_of_val(storage) == std::mem::align_of::<ValRaw>());

    let slot_count = mem::size_of_val(storage) / mem::size_of::<ValRaw>();
    let first_slot = (storage as *const T).cast();
    std::slice::from_raw_parts(first_slot, slot_count)
}
}
/// Invokes the `post-return` canonical ABI option, if specified, after a
/// [`TypedFunc::call`] has finished.
///
/// For some more information on when to use this function see the
/// documentation for post-return in the [`TypedFunc::call`] method.
/// Otherwise though this function is a required method call after a
/// [`TypedFunc::call`] completes successfully. After the embedder has
/// finished processing the return value then this function must be invoked.
///
/// # Errors
///
/// This function will return an error in the case of a WebAssembly trap
/// happening during the execution of the `post-return` function, if
/// specified.
///
/// # Panics
///
/// This function will panic if it's not called under the correct
/// conditions. This can only be called after a previous invocation of
/// [`TypedFunc::call`] completes successfully, and this function can only
/// be called for the same [`TypedFunc`] that was `call`'d.
///
/// If this function is called when [`TypedFunc::call`] was not previously
/// called, then it will panic. If a different [`TypedFunc`] for the same
/// component instance was invoked then this function will also panic
/// because the `post-return` needs to happen for the other function.
pub fn post_return(&self, mut store: impl AsContextMut) -> Result<()> {
    let mut store = store.as_context_mut();
    let data = &mut store.0[self.func.0];
    let instance = data.instance;
    let post_return = data.post_return;
    // Taking the saved argument also clears it, so a second `post_return`
    // without an intervening `call` trips the `.expect` below.
    let post_return_arg = data.post_return_arg.take();
    let instance = store.0[instance.0].as_ref().unwrap().instance();
    let flags = instance.flags();

    unsafe {
        // First assert that the instance is in a "needs post return" state.
        // This will ensure that the previous action on the instance was a
        // function call above. This flag is only set after a component
        // function returns so this also can't be called (as expected)
        // during a host import for example.
        //
        // Note, though, that this assert is not sufficient because it just
        // means some function on this instance needs its post-return
        // called. We need a precise post-return for a particular function
        // which is the second assert here (the `.expect`). That will assert
        // that this function itself needs to have its post-return called.
        //
        // The theory at least is that these two asserts ensure component
        // model semantics are upheld where the host properly calls
        // `post_return` on the right function despite the call being a
        // separate step in the API.
        assert!(
            (*flags).needs_post_return(),
            "post_return can only be called after a function has previously been called",
        );
        // The binding is `mut` because its address is handed to the
        // trampoline below as a `*mut ValRaw`; deriving that pointer from a
        // mutable place avoids casting away constness of an immutable local.
        let mut post_return_arg =
            post_return_arg.expect("calling post_return on wrong function");

        // This is a sanity-check assert which shouldn't ever trip.
        assert!(!(*flags).may_enter());

        // With the state of the world validated these flags are updated to
        // their component-model-defined states.
        (*flags).set_may_enter(true);
        (*flags).set_needs_post_return(false);

        // And finally if the function actually had a `post-return`
        // configured in its canonical options that's executed here. When no
        // `post-return` was configured there is nothing left to do.
        if let Some((func, trampoline)) = post_return {
            crate::Func::call_unchecked_raw(
                &mut store,
                func.anyfunc,
                trampoline,
                &mut post_return_arg as *mut ValRaw,
            )?;
        }
    }
    Ok(())
}
}

View File

@@ -7,9 +7,9 @@ use anyhow::{anyhow, Context, Result};
use std::marker;
use std::sync::Arc;
use wasmtime_environ::component::{
ComponentTypes, CoreDef, CoreExport, Export, ExportItem, ExtractMemory, ExtractRealloc,
GlobalInitializer, InstantiateModule, LowerImport, RuntimeImportIndex, RuntimeInstanceIndex,
RuntimeModuleIndex,
ComponentTypes, CoreDef, CoreExport, Export, ExportItem, ExtractMemory, ExtractPostReturn,
ExtractRealloc, GlobalInitializer, InstantiateModule, LowerImport, RuntimeImportIndex,
RuntimeInstanceIndex, RuntimeModuleIndex,
};
use wasmtime_environ::{EntityIndex, PrimaryMap};
use wasmtime_runtime::component::{ComponentInstance, OwnedComponentInstance};
@@ -278,6 +278,10 @@ impl<'a> Instantiator<'a> {
self.extract_realloc(store.0, realloc)
}
GlobalInitializer::ExtractPostReturn(post_return) => {
self.extract_post_return(store.0, post_return)
}
GlobalInitializer::SaveStaticModule(idx) => {
self.data
.exported_modules
@@ -338,6 +342,16 @@ impl<'a> Instantiator<'a> {
self.data.state.set_runtime_realloc(realloc.index, anyfunc);
}
/// Resolves the core definition named by `post_return.def` and records the
/// resulting function in the instance state under `post_return.index`.
///
/// # Panics
///
/// Panics via `unreachable!()` if the definition does not resolve to a
/// function export.
fn extract_post_return(&mut self, store: &mut StoreOpaque, post_return: &ExtractPostReturn) {
    let anyfunc = if let wasmtime_runtime::Export::Function(func) =
        self.data.lookup_def(store, &post_return.def)
    {
        func.anyfunc
    } else {
        unreachable!()
    };
    let state = &mut self.data.state;
    state.set_runtime_post_return(post_return.index, anyfunc);
}
fn build_imports<'b>(
&mut self,
store: &mut StoreOpaque,