Implement the post-return attribute (#4297)

This commit implements the `post-return` feature of the canonical ABI in
the component model. This attribute is an optionally-specified function
which is to be executed after the return value has been processed by the
caller to optionally clean-up the return value. This enables, for
example, returning an allocated string and the host then knows how to
clean it up to prevent memory leaks in the original module.

The API exposed in this PR changes the prior `TypedFunc::call` API in
behavior but not in its signature. Previously the `TypedFunc::call`
method would set the `may_enter` flag on the way out, but now that
operation is deferred until a new `TypedFunc::post_return` method is
called. This means that once a method on an instance is invoked then
nothing else can be done on the instance until the `post_return` method
is called. Note that the method must be called irrespective of whether
the `post-return` canonical ABI option was specified or not. Internally
wasm will be invoked if necessary.

This is a pretty wonky and unergonomic API to work with. For now I
couldn't think of a better alternative that improved on the ergonomics.
In the theory that the raw Wasmtime bindings for a component may not be
used all that heavily (instead `wit-bindgen` would largely be used) I'm
hoping that this isn't too much of an issue in the future.

cc #4185
This commit is contained in:
Alex Crichton
2022-06-23 14:36:21 -05:00
committed by GitHub
parent fa36e86f2c
commit 3339dd1f01
12 changed files with 787 additions and 112 deletions

View File

@@ -54,6 +54,7 @@ impl ComponentCompiler for Compiler {
let CanonicalOptions {
memory,
realloc,
post_return,
string_encoding,
} = lowering.options;
@@ -94,6 +95,11 @@ impl ComponentCompiler for Compiler {
None => builder.ins().iconst(pointer_type, 0),
});
// A post-return option is only valid on `canon.lift`'d functions so no
// valid component should have this specified for a lowering which this
// trampoline compiler is interested in.
assert!(post_return.is_none());
// string_encoding: StringEncoding
host_sig.params.push(ir::AbiParam::new(ir::types::I8));
callee_args.push(

View File

@@ -129,6 +129,9 @@ pub struct Component {
/// `VMComponentContext`.
pub num_runtime_reallocs: u32,
/// Same as `num_runtime_reallocs`, but for post-return functions.
pub num_runtime_post_returns: u32,
/// The number of lowered host functions (maximum `LoweredIndex`) needed to
/// instantiate this component.
pub num_lowerings: u32,
@@ -180,6 +183,10 @@ pub enum GlobalInitializer {
/// used as a `realloc` function.
ExtractRealloc(ExtractRealloc),
/// Same as `ExtractMemory`, except it's extracting a function pointer to be
/// used as a `post-return` function.
ExtractPostReturn(ExtractPostReturn),
/// The `module` specified is saved into the runtime state at the next
/// `RuntimeModuleIndex`, referred to later by `Export` definitions.
SaveStaticModule(StaticModuleIndex),
@@ -207,6 +214,15 @@ pub struct ExtractRealloc {
pub def: CoreDef,
}
/// Same as `ExtractMemory` but for the `post-return` canonical option.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExtractPostReturn {
/// The index of the post-return being defined.
pub index: RuntimePostReturnIndex,
/// Where this post-return is being extracted from.
pub def: CoreDef,
}
/// Different methods of instantiating a core wasm module.
#[derive(Debug, Serialize, Deserialize)]
pub enum InstantiateModule {
@@ -361,7 +377,9 @@ pub struct CanonicalOptions {
/// The realloc function used by these options, if specified.
pub realloc: Option<RuntimeReallocIndex>,
// TODO: need to represent post-return here as well
/// The post-return function used by these options, if specified.
pub post_return: Option<RuntimePostReturnIndex>,
}
impl Default for CanonicalOptions {
@@ -370,6 +388,7 @@ impl Default for CanonicalOptions {
string_encoding: StringEncoding::Utf8,
memory: None,
realloc: None,
post_return: None,
}
}
}

View File

@@ -62,6 +62,7 @@ pub(super) fn run(
result: Component::default(),
import_path_interner: Default::default(),
runtime_realloc_interner: Default::default(),
runtime_post_return_interner: Default::default(),
runtime_memory_interner: Default::default(),
};
@@ -182,6 +183,7 @@ struct Inliner<'a> {
// runtime instead of multiple times.
import_path_interner: HashMap<ImportPath<'a>, RuntimeImportIndex>,
runtime_realloc_interner: HashMap<CoreDef, RuntimeReallocIndex>,
runtime_post_return_interner: HashMap<CoreDef, RuntimePostReturnIndex>,
runtime_memory_interner: HashMap<CoreExport<MemoryIndex>, RuntimeMemoryIndex>,
}
@@ -851,13 +853,29 @@ impl<'a> Inliner<'a> {
index
})
});
if options.post_return.is_some() {
unimplemented!("post-return handling");
}
let post_return = options.post_return.map(|i| {
let def = frame.funcs[i].clone();
*self
.runtime_post_return_interner
.entry(def.clone())
.or_insert_with(|| {
let index =
RuntimePostReturnIndex::from_u32(self.result.num_runtime_post_returns);
self.result.num_runtime_post_returns += 1;
self.result
.initializers
.push(GlobalInitializer::ExtractPostReturn(ExtractPostReturn {
index,
def,
}));
index
})
});
CanonicalOptions {
string_encoding: options.string_encoding,
memory,
realloc,
post_return,
}
}
}

View File

@@ -2,16 +2,18 @@
//
// struct VMComponentContext {
// magic: u32,
// may_enter: u8,
// may_leave: u8,
// flags: u8,
// store: *mut dyn Store,
// lowering_anyfuncs: [VMCallerCheckedAnyfunc; component.num_lowerings],
// lowerings: [VMLowering; component.num_lowerings],
// memories: [*mut VMMemoryDefinition; component.num_memories],
// reallocs: [*mut VMCallerCheckedAnyfunc; component.num_reallocs],
// post_returns: [*mut VMCallerCheckedAnyfunc; component.num_post_returns],
// }
use crate::component::{Component, LoweredIndex, RuntimeMemoryIndex, RuntimeReallocIndex};
use crate::component::{
Component, LoweredIndex, RuntimeMemoryIndex, RuntimePostReturnIndex, RuntimeReallocIndex,
};
use crate::PtrSize;
/// Equivalent of `VMCONTEXT_MAGIC` except for components.
@@ -20,6 +22,18 @@ use crate::PtrSize;
/// double-checked on `VMComponentContext::from_opaque`.
pub const VMCOMPONENT_MAGIC: u32 = u32::from_le_bytes(*b"comp");
/// Flag for the `VMComponentContext::flags` field which corresponds to the
/// canonical ABI flag `may_leave`
pub const VMCOMPONENT_FLAG_MAY_LEAVE: u8 = 1 << 0;
/// Flag for the `VMComponentContext::flags` field which corresponds to the
/// canonical ABI flag `may_enter`
pub const VMCOMPONENT_FLAG_MAY_ENTER: u8 = 1 << 1;
/// Flag for the `VMComponentContext::flags` field which is set whenever a
/// function is called to indicate that `post_return` must be called next.
pub const VMCOMPONENT_FLAG_NEEDS_POST_RETURN: u8 = 1 << 2;
/// Runtime offsets within a `VMComponentContext` for a specific component.
#[derive(Debug, Clone, Copy)]
pub struct VMComponentOffsets<P> {
@@ -32,16 +46,18 @@ pub struct VMComponentOffsets<P> {
pub num_runtime_memories: u32,
/// The number of reallocs which are recorded in this component for options.
pub num_runtime_reallocs: u32,
/// The number of post-returns which are recorded in this component for options.
pub num_runtime_post_returns: u32,
// precalculated offsets of various member fields
magic: u32,
may_enter: u32,
may_leave: u32,
flags: u32,
store: u32,
lowering_anyfuncs: u32,
lowerings: u32,
memories: u32,
reallocs: u32,
post_returns: u32,
size: u32,
}
@@ -60,14 +76,15 @@ impl<P: PtrSize> VMComponentOffsets<P> {
num_lowerings: component.num_lowerings.try_into().unwrap(),
num_runtime_memories: component.num_runtime_memories.try_into().unwrap(),
num_runtime_reallocs: component.num_runtime_reallocs.try_into().unwrap(),
num_runtime_post_returns: component.num_runtime_post_returns.try_into().unwrap(),
magic: 0,
may_enter: 0,
may_leave: 0,
flags: 0,
store: 0,
lowering_anyfuncs: 0,
lowerings: 0,
memories: 0,
reallocs: 0,
post_returns: 0,
size: 0,
};
@@ -97,14 +114,14 @@ impl<P: PtrSize> VMComponentOffsets<P> {
fields! {
size(magic) = 4u32,
size(may_enter) = 1u32,
size(may_leave) = 1u32,
size(flags) = 1u32,
align(u32::from(ret.ptr.size())),
size(store) = cmul(2, ret.ptr.size()),
size(lowering_anyfuncs) = cmul(ret.num_lowerings, ret.ptr.size_of_vmcaller_checked_anyfunc()),
size(lowerings) = cmul(ret.num_lowerings, ret.ptr.size() * 2),
size(memories) = cmul(ret.num_runtime_memories, ret.ptr.size()),
size(reallocs) = cmul(ret.num_runtime_reallocs, ret.ptr.size()),
size(post_returns) = cmul(ret.num_runtime_post_returns, ret.ptr.size()),
}
ret.size = next_field_offset;
@@ -129,16 +146,10 @@ impl<P: PtrSize> VMComponentOffsets<P> {
self.magic
}
/// The offset of the `may_leave` field.
/// The offset of the `flags` field.
#[inline]
pub fn may_leave(&self) -> u32 {
self.may_leave
}
/// The offset of the `may_enter` field.
#[inline]
pub fn may_enter(&self) -> u32 {
self.may_enter
pub fn flags(&self) -> u32 {
self.flags
}
/// The offset of the `store` field.
@@ -232,6 +243,20 @@ impl<P: PtrSize> VMComponentOffsets<P> {
self.runtime_reallocs() + index.as_u32() * u32::from(self.ptr.size())
}
/// The offset of the base of the `runtime_post_returns` field
#[inline]
pub fn runtime_post_returns(&self) -> u32 {
self.post_returns
}
/// The offset of the `*mut VMCallerCheckedAnyfunc` for the runtime index
/// provided.
#[inline]
pub fn runtime_post_return(&self, index: RuntimePostReturnIndex) -> u32 {
assert!(index.as_u32() < self.num_runtime_post_returns);
self.runtime_post_returns() + index.as_u32() * u32::from(self.ptr.size())
}
/// Return the size of the `VMComponentContext` allocation.
#[inline]
pub fn size_of_vmctx(&self) -> u32 {

View File

@@ -17,8 +17,9 @@ use std::mem;
use std::ops::Deref;
use std::ptr::{self, NonNull};
use wasmtime_environ::component::{
Component, LoweredIndex, RuntimeMemoryIndex, RuntimeReallocIndex, StringEncoding,
VMComponentOffsets, VMCOMPONENT_MAGIC,
Component, LoweredIndex, RuntimeMemoryIndex, RuntimePostReturnIndex, RuntimeReallocIndex,
StringEncoding, VMComponentOffsets, VMCOMPONENT_FLAG_MAY_ENTER, VMCOMPONENT_FLAG_MAY_LEAVE,
VMCOMPONENT_FLAG_NEEDS_POST_RETURN, VMCOMPONENT_MAGIC,
};
use wasmtime_environ::HostPtr;
@@ -63,6 +64,11 @@ pub struct ComponentInstance {
/// signature that this callee corresponds to.
/// * `nargs_and_results` - the size, in units of `ValRaw`, of
/// `args_and_results`.
//
// FIXME: 7 arguments is probably too many. The `data` through `string-encoding`
// parameters should probably get packaged up into the `VMComponentContext`.
// Needs benchmarking one way or another though to figure out what the best
// balance is here.
pub type VMLoweringCallee = extern "C" fn(
vmctx: *mut VMOpaqueContext,
data: *mut u8,
@@ -104,6 +110,11 @@ pub struct VMComponentContext {
_marker: marker::PhantomPinned,
}
/// Flags stored in a `VMComponentContext` with values defined by
/// `VMCOMPONENT_FLAG_*`
#[repr(transparent)]
pub struct VMComponentFlags(u8);
impl ComponentInstance {
/// Returns the layout corresponding to what would be an allocation of a
/// `ComponentInstance` for the `offsets` provided.
@@ -159,14 +170,8 @@ impl ComponentInstance {
/// Returns a pointer to the "may leave" flag for this instance specified
/// for canonical lowering and lifting operations.
pub fn may_leave(&self) -> *mut bool {
unsafe { self.vmctx_plus_offset(self.offsets.may_leave()) }
}
/// Returns a pointer to the "may enter" flag for this instance specified
/// for canonical lowering and lifting operations.
pub fn may_enter(&self) -> *mut bool {
unsafe { self.vmctx_plus_offset(self.offsets.may_enter()) }
pub fn flags(&self) -> *mut VMComponentFlags {
unsafe { self.vmctx_plus_offset(self.offsets.flags()) }
}
/// Returns the store that this component was created with.
@@ -202,6 +207,22 @@ impl ComponentInstance {
ret
}
}
/// Returns the post-return pointer corresponding to the index provided.
///
/// This can only be called after `idx` has been initialized at runtime
/// during the instantiation process of a component.
pub fn runtime_post_return(
&self,
idx: RuntimePostReturnIndex,
) -> NonNull<VMCallerCheckedAnyfunc> {
unsafe {
let ret = *self.vmctx_plus_offset::<NonNull<_>>(self.offsets.runtime_post_return(idx));
debug_assert!(ret.as_ptr() as usize != INVALID_PTR);
ret
}
}
/// Returns the host information for the lowered function at the index
/// specified.
///
@@ -264,6 +285,19 @@ impl ComponentInstance {
}
}
/// Same as `set_runtime_memory` but for post-return function pointers.
pub fn set_runtime_post_return(
&mut self,
idx: RuntimePostReturnIndex,
ptr: NonNull<VMCallerCheckedAnyfunc>,
) {
unsafe {
let storage = self.vmctx_plus_offset(self.offsets.runtime_post_return(idx));
debug_assert!(*storage as usize == INVALID_PTR);
*storage = ptr.as_ptr();
}
}
/// Configures a lowered host function with all the pieces necessary.
///
/// * `idx` - the index that's being configured
@@ -304,8 +338,7 @@ impl ComponentInstance {
unsafe fn initialize_vmctx(&mut self, store: *mut dyn Store) {
*self.vmctx_plus_offset(self.offsets.magic()) = VMCOMPONENT_MAGIC;
*self.may_leave() = true;
*self.may_enter() = true;
*self.flags() = VMComponentFlags::new();
*self.vmctx_plus_offset(self.offsets.store()) = store;
// In debug mode set non-null bad values to all "pointer looking" bits
@@ -332,6 +365,11 @@ impl ComponentInstance {
let offset = self.offsets.runtime_realloc(i);
*self.vmctx_plus_offset(offset) = INVALID_PTR;
}
for i in 0..self.offsets.num_runtime_post_returns {
let i = RuntimePostReturnIndex::from_u32(i);
let offset = self.offsets.runtime_post_return(i);
*self.vmctx_plus_offset(offset) = INVALID_PTR;
}
}
}
}
@@ -409,6 +447,15 @@ impl OwnedComponentInstance {
unsafe { self.instance_mut().set_runtime_realloc(idx, ptr) }
}
/// See `ComponentInstance::set_runtime_post_return`
pub fn set_runtime_post_return(
&mut self,
idx: RuntimePostReturnIndex,
ptr: NonNull<VMCallerCheckedAnyfunc>,
) {
unsafe { self.instance_mut().set_runtime_post_return(idx, ptr) }
}
/// See `ComponentInstance::set_lowering`
pub fn set_lowering(
&mut self,
@@ -459,3 +506,52 @@ impl VMOpaqueContext {
ptr.cast()
}
}
#[allow(missing_docs)]
impl VMComponentFlags {
fn new() -> VMComponentFlags {
VMComponentFlags(VMCOMPONENT_FLAG_MAY_LEAVE | VMCOMPONENT_FLAG_MAY_ENTER)
}
#[inline]
pub fn may_leave(&self) -> bool {
self.0 & VMCOMPONENT_FLAG_MAY_LEAVE != 0
}
#[inline]
pub fn set_may_leave(&mut self, val: bool) {
if val {
self.0 |= VMCOMPONENT_FLAG_MAY_LEAVE;
} else {
self.0 &= !VMCOMPONENT_FLAG_MAY_LEAVE;
}
}
#[inline]
pub fn may_enter(&self) -> bool {
self.0 & VMCOMPONENT_FLAG_MAY_ENTER != 0
}
#[inline]
pub fn set_may_enter(&mut self, val: bool) {
if val {
self.0 |= VMCOMPONENT_FLAG_MAY_ENTER;
} else {
self.0 &= !VMCOMPONENT_FLAG_MAY_ENTER;
}
}
#[inline]
pub fn needs_post_return(&self) -> bool {
self.0 & VMCOMPONENT_FLAG_NEEDS_POST_RETURN != 0
}
#[inline]
pub fn set_needs_post_return(&mut self, val: bool) {
if val {
self.0 |= VMCOMPONENT_FLAG_NEEDS_POST_RETURN;
} else {
self.0 &= !VMCOMPONENT_FLAG_NEEDS_POST_RETURN;
}
}
}

View File

@@ -1,6 +1,6 @@
use crate::component::instance::{Instance, InstanceData};
use crate::store::{StoreOpaque, Stored};
use crate::AsContext;
use crate::{AsContext, ValRaw};
use anyhow::{Context, Result};
use std::mem::MaybeUninit;
use std::ptr::NonNull;
@@ -82,6 +82,8 @@ pub struct FuncData {
types: Arc<ComponentTypes>,
options: Options,
instance: Instance,
post_return: Option<(ExportFunction, VMTrampoline)>,
post_return_arg: Option<ValRaw>,
}
impl Func {
@@ -102,6 +104,11 @@ impl Func {
.memory
.map(|i| NonNull::new(data.instance().runtime_memory(i)).unwrap());
let realloc = options.realloc.map(|i| data.instance().runtime_realloc(i));
let post_return = options.post_return.map(|i| {
let anyfunc = data.instance().runtime_post_return(i);
let trampoline = store.lookup_trampoline(unsafe { anyfunc.as_ref() });
(ExportFunction { anyfunc }, trampoline)
});
let options = unsafe { Options::new(store.id(), memory, realloc, options.string_encoding) };
Func(store.store_data_mut().insert(FuncData {
trampoline,
@@ -110,6 +117,8 @@ impl Func {
ty,
types: data.component_types().clone(),
instance: *instance,
post_return,
post_return_arg: None,
}))
}

View File

@@ -8,7 +8,9 @@ use std::panic::{self, AssertUnwindSafe};
use std::ptr::NonNull;
use std::sync::Arc;
use wasmtime_environ::component::{ComponentTypes, StringEncoding, TypeFuncIndex};
use wasmtime_runtime::component::{VMComponentContext, VMLowering, VMLoweringCallee};
use wasmtime_runtime::component::{
VMComponentContext, VMComponentFlags, VMLowering, VMLoweringCallee,
};
use wasmtime_runtime::{VMCallerCheckedAnyfunc, VMMemoryDefinition, VMOpaqueContext};
/// Trait representing host-defined functions that can be imported into a wasm
@@ -134,8 +136,7 @@ where
let cx = VMComponentContext::from_opaque(cx);
let instance = (*cx).instance();
let may_leave = (*instance).may_leave();
let may_enter = (*instance).may_enter();
let flags = (*instance).flags();
let mut cx = StoreContextMut::from_raw((*instance).store());
let options = Options::new(
@@ -148,13 +149,13 @@ where
// Perform a dynamic check that this instance can indeed be left. Exiting
// the component is disallowed, for example, when the `realloc` function
// calls a canonical import.
if !*may_leave {
if !(*flags).may_leave() {
bail!("cannot leave component instance");
}
// While we're lifting and lowering this instance cannot be reentered, so
// unset the flag here. This is also reset back to `true` on exit.
let _reset_may_enter = unset_and_reset_on_drop(may_enter);
let _reset_may_enter = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_enter);
// There's a 2x2 matrix of whether parameters and results are stored on the
// stack or on the heap. Each of the 4 branches here have a different
@@ -172,7 +173,7 @@ where
let storage = cast_storage::<ReturnStack<Params::Lower, Return::Lower>>(storage);
let params = Params::lift(cx.0, &options, &storage.assume_init_ref().args)?;
let ret = closure(cx.as_context_mut(), params)?;
reset_may_leave = unset_and_reset_on_drop(may_leave);
reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave);
ret.lower(&mut cx, &options, map_maybe_uninit!(storage.ret))?;
} else {
let storage = cast_storage::<ReturnPointer<Params::Lower>>(storage).assume_init_ref();
@@ -180,7 +181,7 @@ where
let ret = closure(cx.as_context_mut(), params)?;
let mut memory = MemoryMut::new(cx.as_context_mut(), &options);
let ptr = validate_inbounds::<Return>(memory.as_slice_mut(), &storage.retptr)?;
reset_may_leave = unset_and_reset_on_drop(may_leave);
reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave);
ret.store(&mut memory, ptr)?;
}
} else {
@@ -191,7 +192,7 @@ where
validate_inbounds::<Params>(memory.as_slice(), &storage.assume_init_ref().args)?;
let params = Params::load(&memory, &memory.as_slice()[ptr..][..Params::size()])?;
let ret = closure(cx.as_context_mut(), params)?;
reset_may_leave = unset_and_reset_on_drop(may_leave);
reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave);
ret.lower(&mut cx, &options, map_maybe_uninit!(storage.ret))?;
} else {
let storage = cast_storage::<ReturnPointer<ValRaw>>(storage).assume_init_ref();
@@ -200,27 +201,28 @@ where
let ret = closure(cx.as_context_mut(), params)?;
let mut memory = MemoryMut::new(cx.as_context_mut(), &options);
let ptr = validate_inbounds::<Return>(memory.as_slice_mut(), &storage.retptr)?;
reset_may_leave = unset_and_reset_on_drop(may_leave);
reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave);
ret.store(&mut memory, ptr)?;
}
}
// TODO: need to call `post-return` before this `drop`
drop(reset_may_leave);
return Ok(());
unsafe fn unset_and_reset_on_drop(slot: *mut bool) -> impl Drop {
debug_assert!(*slot);
*slot = false;
return Reset(slot);
unsafe fn unset_and_reset_on_drop(
slot: *mut VMComponentFlags,
set: fn(&mut VMComponentFlags, bool),
) -> impl Drop {
set(&mut *slot, false);
return Reset(slot, set);
struct Reset(*mut bool);
struct Reset(*mut VMComponentFlags, fn(&mut VMComponentFlags, bool));
impl Drop for Reset {
fn drop(&mut self) {
unsafe {
(*self.0) = true;
(self.1)(&mut *self.0, true);
}
}
}

View File

@@ -90,6 +90,31 @@ where
/// the `store` provided. The `params` are copied into WebAssembly memory
/// as appropriate and a core wasm function is invoked.
///
/// # Post-return
///
/// In the component model each function can have a "post return" specified
/// which allows cleaning up the arguments returned to the host. For example
/// if WebAssembly returns a string to the host then it might be a uniquely
/// allocated string which, after the host finishes processing it, needs to
/// be deallocated in the wasm instance's own linear memory to prevent
/// memory leaks in wasm itself. The `post-return` canonical abi option is
/// used to configured this.
///
/// To accommodate this feature of the component model after invoking a
/// function via [`TypedFunc::call`] you must next invoke
/// [`TypedFunc::post_return`]. Note that the return value of the function
/// should be processed between these two function calls. The return value
/// continues to be usable from an embedder's perspective after
/// `post_return` is called, but after `post_return` is invoked it may no
/// longer retain the same value that the wasm module originally returned.
///
/// Also note that [`TypedFunc::post_return`] must be invoked irrespective
/// of whether the canonical ABI option `post-return` was configured or not.
/// This means that embedders must unconditionally call
/// [`TypedFunc::post_return`] when a function returns. If this function
/// call returns an error, however, then [`TypedFunc::post_return`] is not
/// required.
///
/// # Errors
///
/// This function can return an error for a number of reasons:
@@ -99,6 +124,10 @@ where
/// * If the wasm provides bad allocation pointers when copying arguments
/// into memory.
/// * If the wasm returns a value which violates the canonical ABI.
/// * If this function's instances cannot be entered, for example if the
/// instance is currently calling a host function.
/// * If a previous function call occurred and the corresponding
/// `post_return` hasn't been invoked yet.
///
/// In general there are many ways that things could go wrong when copying
/// types in and out of a wasm module with the canonical ABI, and certain
@@ -300,18 +329,17 @@ where
assert!(mem::align_of_val(map_maybe_uninit!(space.ret)) == val_align);
let instance = store.0[instance.0].as_ref().unwrap().instance();
let may_enter = instance.may_enter();
let may_leave = instance.may_leave();
let flags = instance.flags();
unsafe {
if !*may_enter {
if !(*flags).may_enter() {
bail!("cannot reenter component instance");
}
debug_assert!(*may_leave);
debug_assert!((*flags).may_leave());
*may_leave = false;
(*flags).set_may_leave(false);
let result = lower(store, &options, params, map_maybe_uninit!(space.params));
*may_leave = true;
(*flags).set_may_leave(true);
result?;
// This is unsafe as we are providing the guarantee that all the
@@ -336,18 +364,139 @@ where
// `[ValRaw]`, and additionally they should have the correct types
// for the function we just called (which filled in the return
// values).
*may_enter = false;
let result = lift(
store.0,
&options,
map_maybe_uninit!(space.ret).assume_init_ref(),
);
let ret = map_maybe_uninit!(space.ret).assume_init_ref();
// TODO: this technically needs to happen only after the
// `post-return` is called.
*may_enter = true;
return result;
// Lift the result into the host while managing post-return state
// here as well.
//
// Initially the `may_enter` flag is set to `false` for this
// component instance and additionally we set a flag indicating that
// a post-return is required. This isn't specified by the component
// model itself but is used for our implementation of the API of
// `post_return` as a separate function call.
//
// FIXME(WebAssembly/component-model#55) it's not really clear what
// the semantics should be in the face of a lift error/trap. For now
// the flags are reset so the instance can continue to be reused in
// tests but that probably isn't what's desired.
//
// Otherwise though after a successful lift the return value of the
// function, which is currently required to be 0 or 1 values
// according to the canonical ABI, is saved within the `Store`'s
// `FuncData`. This'll later get used in post-return.
(*flags).set_may_enter(false);
(*flags).set_needs_post_return(true);
match lift(store.0, &options, ret) {
Ok(val) => {
let ret_slice = cast_storage(ret);
let data = &mut store.0[self.func.0];
assert!(data.post_return_arg.is_none());
match ret_slice.len() {
0 => data.post_return_arg = Some(ValRaw::i32(0)),
1 => data.post_return_arg = Some(ret_slice[0]),
_ => unreachable!(),
}
return Ok(val);
}
Err(err) => {
(*flags).set_may_enter(true);
(*flags).set_needs_post_return(false);
return Err(err);
}
}
}
unsafe fn cast_storage<T>(storage: &T) -> &[ValRaw] {
assert!(std::mem::size_of_val(storage) % std::mem::size_of::<ValRaw>() == 0);
assert!(std::mem::align_of_val(storage) == std::mem::align_of::<ValRaw>());
std::slice::from_raw_parts(
(storage as *const T).cast(),
mem::size_of_val(storage) / mem::size_of::<ValRaw>(),
)
}
}
/// Invokes the `post-return` canonical ABI option, if specified, after a
/// [`TypedFunc::call`] has finished.
///
/// For some more information on when to use this function see the
/// documentation for post-return in the [`TypedFunc::call`] method.
/// Otherwise though this function is a required method call after a
/// [`TypedFunc::call`] completes successfully. After the embedder has
/// finished processing the return value then this function must be invoked.
///
/// # Errors
///
/// This function will return an error in the case of a WebAssembly trap
/// happening during the execution of the `post-return` function, if
/// specified.
///
/// # Panics
///
/// This function will panic if it's not called under the correct
/// conditions. This can only be called after a previous invocation of
/// [`TypedFunc::call`] completes successfully, and this function can only
/// be called for the same [`TypedFunc`] that was `call`'d.
///
/// If this function is called when [`TypedFunc::call`] was not previously
/// called, then it will panic. If a different [`TypedFunc`] for the same
/// component instance was invoked then this function will also panic
/// because the `post-return` needs to happen for the other function.
pub fn post_return(&self, mut store: impl AsContextMut) -> Result<()> {
let mut store = store.as_context_mut();
let data = &mut store.0[self.func.0];
let instance = data.instance;
let post_return = data.post_return;
let post_return_arg = data.post_return_arg.take();
let instance = store.0[instance.0].as_ref().unwrap().instance();
let flags = instance.flags();
unsafe {
// First assert that the instance is in a "needs post return" state.
// This will ensure that the previous action on the instance was a
// function call above. This flag is only set after a component
// function returns so this also can't be called (as expected)
// during a host import for example.
//
// Note, though, that this assert is not sufficient because it just
// means some function on this instance needs its post-return
// called. We need a precise post-return for a particular function
// which is the second assert here (the `.expect`). That will assert
// that this function itself needs to have its post-return called.
//
// The theory at least is that these two asserts ensure component
// model semantics are upheld where the host properly calls
// `post_return` on the right function despite the call being a
// separate step in the API.
assert!(
(*flags).needs_post_return(),
"post_return can only be called after a function has previously been called",
);
let post_return_arg = post_return_arg.expect("calling post_return on wrong function");
// This is a sanity-check assert which shouldn't ever trip.
assert!(!(*flags).may_enter());
// With the state of the world validated these flags are updated to
// their component-model-defined states.
(*flags).set_may_enter(true);
(*flags).set_needs_post_return(false);
// And finally if the function actually had a `post-return`
// configured in its canonical options that's executed here.
let (func, trampoline) = match post_return {
Some(pair) => pair,
None => return Ok(()),
};
crate::Func::call_unchecked_raw(
&mut store,
func.anyfunc,
trampoline,
&post_return_arg as *const ValRaw as *mut ValRaw,
)?;
}
Ok(())
}
}

View File

@@ -7,9 +7,9 @@ use anyhow::{anyhow, Context, Result};
use std::marker;
use std::sync::Arc;
use wasmtime_environ::component::{
ComponentTypes, CoreDef, CoreExport, Export, ExportItem, ExtractMemory, ExtractRealloc,
GlobalInitializer, InstantiateModule, LowerImport, RuntimeImportIndex, RuntimeInstanceIndex,
RuntimeModuleIndex,
ComponentTypes, CoreDef, CoreExport, Export, ExportItem, ExtractMemory, ExtractPostReturn,
ExtractRealloc, GlobalInitializer, InstantiateModule, LowerImport, RuntimeImportIndex,
RuntimeInstanceIndex, RuntimeModuleIndex,
};
use wasmtime_environ::{EntityIndex, PrimaryMap};
use wasmtime_runtime::component::{ComponentInstance, OwnedComponentInstance};
@@ -278,6 +278,10 @@ impl<'a> Instantiator<'a> {
self.extract_realloc(store.0, realloc)
}
GlobalInitializer::ExtractPostReturn(post_return) => {
self.extract_post_return(store.0, post_return)
}
GlobalInitializer::SaveStaticModule(idx) => {
self.data
.exported_modules
@@ -338,6 +342,16 @@ impl<'a> Instantiator<'a> {
self.data.state.set_runtime_realloc(realloc.index, anyfunc);
}
fn extract_post_return(&mut self, store: &mut StoreOpaque, post_return: &ExtractPostReturn) {
let anyfunc = match self.data.lookup_def(store, &post_return.def) {
wasmtime_runtime::Export::Function(f) => f.anyfunc,
_ => unreachable!(),
};
self.data
.state
.set_runtime_post_return(post_return.index, anyfunc);
}
fn build_imports<'b>(
&mut self,
store: &mut StoreOpaque,

View File

@@ -5,6 +5,7 @@ use wasmtime::{Config, Engine};
mod func;
mod import;
mod nested;
mod post_return;
// A simple bump allocator which can be used with modules
const REALLOC_AND_FREE: &str = r#"

View File

@@ -3,7 +3,7 @@ use anyhow::Result;
use std::rc::Rc;
use std::sync::Arc;
use wasmtime::component::*;
use wasmtime::{Store, StoreContextMut, Trap, TrapCode};
use wasmtime::{AsContextMut, Store, StoreContextMut, Trap, TrapCode};
const CANON_32BIT_NAN: u32 = 0b01111111110000000000000000000000;
const CANON_64BIT_NAN: u64 = 0b0111111111111000000000000000000000000000000000000000000000000000;
@@ -32,7 +32,7 @@ fn thunks() -> Result<()> {
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
instance
.get_typed_func::<(), (), _>(&mut store, "thunk")?
.call(&mut store, ())?;
.call_and_post_return(&mut store, ())?;
let err = instance
.get_typed_func::<(), (), _>(&mut store, "thunk-trap")?
.call(&mut store, ())
@@ -193,28 +193,28 @@ fn integers() -> Result<()> {
// Passing in 100 is valid for all primitives
instance
.get_typed_func::<(u8,), (), _>(&mut store, "take-u8")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
instance
.get_typed_func::<(i8,), (), _>(&mut store, "take-s8")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
instance
.get_typed_func::<(u16,), (), _>(&mut store, "take-u16")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
instance
.get_typed_func::<(i16,), (), _>(&mut store, "take-s16")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
instance
.get_typed_func::<(u32,), (), _>(&mut store, "take-u32")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
instance
.get_typed_func::<(i32,), (), _>(&mut store, "take-s32")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
instance
.get_typed_func::<(u64,), (), _>(&mut store, "take-u64")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
instance
.get_typed_func::<(i64,), (), _>(&mut store, "take-s64")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
// This specific wasm instance traps if any value other than 100 is passed
instance
@@ -262,49 +262,49 @@ fn integers() -> Result<()> {
assert_eq!(
instance
.get_typed_func::<(), u8, _>(&mut store, "ret-u8")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
assert_eq!(
instance
.get_typed_func::<(), i8, _>(&mut store, "ret-s8")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
assert_eq!(
instance
.get_typed_func::<(), u16, _>(&mut store, "ret-u16")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
assert_eq!(
instance
.get_typed_func::<(), i16, _>(&mut store, "ret-s16")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
assert_eq!(
instance
.get_typed_func::<(), u32, _>(&mut store, "ret-u32")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
assert_eq!(
instance
.get_typed_func::<(), i32, _>(&mut store, "ret-s32")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
assert_eq!(
instance
.get_typed_func::<(), u64, _>(&mut store, "ret-u64")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
assert_eq!(
instance
.get_typed_func::<(), i64, _>(&mut store, "ret-s64")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
@@ -312,49 +312,49 @@ fn integers() -> Result<()> {
assert_eq!(
instance
.get_typed_func::<(), u8, _>(&mut store, "retm1-u8")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0xff
);
assert_eq!(
instance
.get_typed_func::<(), i8, _>(&mut store, "retm1-s8")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
-1
);
assert_eq!(
instance
.get_typed_func::<(), u16, _>(&mut store, "retm1-u16")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0xffff
);
assert_eq!(
instance
.get_typed_func::<(), i16, _>(&mut store, "retm1-s16")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
-1
);
assert_eq!(
instance
.get_typed_func::<(), u32, _>(&mut store, "retm1-u32")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0xffffffff
);
assert_eq!(
instance
.get_typed_func::<(), i32, _>(&mut store, "retm1-s32")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
-1
);
assert_eq!(
instance
.get_typed_func::<(), u64, _>(&mut store, "retm1-u64")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0xffffffff_ffffffff
);
assert_eq!(
instance
.get_typed_func::<(), i64, _>(&mut store, "retm1-s64")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
-1
);
@@ -363,43 +363,59 @@ fn integers() -> Result<()> {
assert_eq!(
instance
.get_typed_func::<(), u8, _>(&mut store, "retbig-u8")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
ret as u8,
);
assert_eq!(
instance
.get_typed_func::<(), i8, _>(&mut store, "retbig-s8")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
ret as i8,
);
assert_eq!(
instance
.get_typed_func::<(), u16, _>(&mut store, "retbig-u16")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
ret as u16,
);
assert_eq!(
instance
.get_typed_func::<(), i16, _>(&mut store, "retbig-s16")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
ret as i16,
);
assert_eq!(
instance
.get_typed_func::<(), u32, _>(&mut store, "retbig-u32")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
ret,
);
assert_eq!(
instance
.get_typed_func::<(), i32, _>(&mut store, "retbig-s32")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
ret as i32,
);
Ok(())
}
trait TypedFuncExt<P, R> {
fn call_and_post_return(&self, store: impl AsContextMut, params: P) -> Result<R>;
}
impl<P, R> TypedFuncExt<P, R> for TypedFunc<P, R>
where
P: ComponentParams + Lower,
R: Lift,
{
fn call_and_post_return(&self, mut store: impl AsContextMut, params: P) -> Result<R> {
let result = self.call(&mut store, params)?;
self.post_return(&mut store)?;
Ok(result)
}
}
#[test]
fn type_layers() -> Result<()> {
let component = r#"
@@ -425,19 +441,19 @@ fn type_layers() -> Result<()> {
instance
.get_typed_func::<(Box<u32>,), (), _>(&mut store, "take-u32")?
.call(&mut store, (Box::new(2),))?;
.call_and_post_return(&mut store, (Box::new(2),))?;
instance
.get_typed_func::<(&u32,), (), _>(&mut store, "take-u32")?
.call(&mut store, (&2,))?;
.call_and_post_return(&mut store, (&2,))?;
instance
.get_typed_func::<(Rc<u32>,), (), _>(&mut store, "take-u32")?
.call(&mut store, (Rc::new(2),))?;
.call_and_post_return(&mut store, (Rc::new(2),))?;
instance
.get_typed_func::<(Arc<u32>,), (), _>(&mut store, "take-u32")?
.call(&mut store, (Arc::new(2),))?;
.call_and_post_return(&mut store, (Arc::new(2),))?;
instance
.get_typed_func::<(&Box<Arc<Rc<u32>>>,), (), _>(&mut store, "take-u32")?
.call(&mut store, (&Box::new(Arc::new(Rc::new(2))),))?;
.call_and_post_return(&mut store, (&Box::new(Arc::new(Rc::new(2))),))?;
Ok(())
}
@@ -491,9 +507,13 @@ fn floats() -> Result<()> {
let u64_to_f64 = instance.get_typed_func::<(u64,), f64, _>(&mut store, "u64-to-f64")?;
assert_eq!(f32_to_u32.call(&mut store, (1.0,))?, 1.0f32.to_bits());
f32_to_u32.post_return(&mut store)?;
assert_eq!(f64_to_u64.call(&mut store, (2.0,))?, 2.0f64.to_bits());
f64_to_u64.post_return(&mut store)?;
assert_eq!(u32_to_f32.call(&mut store, (3.0f32.to_bits(),))?, 3.0);
u32_to_f32.post_return(&mut store)?;
assert_eq!(u64_to_f64.call(&mut store, (4.0f64.to_bits(),))?, 4.0);
u64_to_f64.post_return(&mut store)?;
assert_eq!(
u32_to_f32
@@ -501,21 +521,25 @@ fn floats() -> Result<()> {
.to_bits(),
CANON_32BIT_NAN
);
u32_to_f32.post_return(&mut store)?;
assert_eq!(
u64_to_f64
.call(&mut store, (CANON_64BIT_NAN | 1,))?
.to_bits(),
CANON_64BIT_NAN
);
u64_to_f64.post_return(&mut store)?;
assert_eq!(
f32_to_u32.call(&mut store, (f32::from_bits(CANON_32BIT_NAN | 1),))?,
CANON_32BIT_NAN
);
f32_to_u32.post_return(&mut store)?;
assert_eq!(
f64_to_u64.call(&mut store, (f64::from_bits(CANON_64BIT_NAN | 1),))?,
CANON_64BIT_NAN
);
f64_to_u64.post_return(&mut store)?;
Ok(())
}
@@ -546,10 +570,15 @@ fn bools() -> Result<()> {
let bool_to_u32 = instance.get_typed_func::<(bool,), u32, _>(&mut store, "bool-to-u32")?;
assert_eq!(bool_to_u32.call(&mut store, (false,))?, 0);
bool_to_u32.post_return(&mut store)?;
assert_eq!(bool_to_u32.call(&mut store, (true,))?, 1);
bool_to_u32.post_return(&mut store)?;
assert_eq!(u32_to_bool.call(&mut store, (0,))?, false);
u32_to_bool.post_return(&mut store)?;
assert_eq!(u32_to_bool.call(&mut store, (1,))?, true);
u32_to_bool.post_return(&mut store)?;
assert_eq!(u32_to_bool.call(&mut store, (2,))?, true);
u32_to_bool.post_return(&mut store)?;
Ok(())
}
@@ -581,7 +610,9 @@ fn chars() -> Result<()> {
let mut roundtrip = |x: char| -> Result<()> {
assert_eq!(char_to_u32.call(&mut store, (x,))?, x as u32);
char_to_u32.post_return(&mut store)?;
assert_eq!(u32_to_char.call(&mut store, (x as u32,))?, x);
u32_to_char.post_return(&mut store)?;
Ok(())
};
@@ -644,7 +675,7 @@ fn tuple_result() -> Result<()> {
let input = (-1, 100, 3.0, 100.0);
let output = instance
.get_typed_func::<(i8, u16, f32, f64), (i8, u16, f32, f64), _>(&mut store, "tuple")?
.call(&mut store, input)?;
.call_and_post_return(&mut store, input)?;
assert_eq!(input, output);
let invalid_func =
@@ -735,16 +766,20 @@ fn strings() -> Result<()> {
let mut roundtrip = |x: &str| -> Result<()> {
let ret = list8_to_str.call(&mut store, (x.as_bytes(),))?;
assert_eq!(ret.to_str(&store)?, x);
list8_to_str.post_return(&mut store)?;
let utf16 = x.encode_utf16().collect::<Vec<_>>();
let ret = list16_to_str.call(&mut store, (&utf16[..],))?;
assert_eq!(ret.to_str(&store)?, x);
list16_to_str.post_return(&mut store)?;
let ret = str_to_list8.call(&mut store, (x,))?;
assert_eq!(ret.iter(&store).collect::<Result<Vec<_>>>()?, x.as_bytes());
str_to_list8.post_return(&mut store)?;
let ret = str_to_list16.call(&mut store, (x,))?;
assert_eq!(ret.iter(&store).collect::<Result<Vec<_>>>()?, utf16,);
str_to_list16.post_return(&mut store)?;
Ok(())
};
@@ -758,22 +793,27 @@ fn strings() -> Result<()> {
let ret = list8_to_str.call(&mut store, (b"\xff",))?;
let err = ret.to_str(&store).unwrap_err();
assert!(err.to_string().contains("invalid utf-8"), "{}", err);
list8_to_str.post_return(&mut store)?;
let ret = list8_to_str.call(&mut store, (b"hello there \xff invalid",))?;
let err = ret.to_str(&store).unwrap_err();
assert!(err.to_string().contains("invalid utf-8"), "{}", err);
list8_to_str.post_return(&mut store)?;
let ret = list16_to_str.call(&mut store, (&[0xd800],))?;
let err = ret.to_str(&store).unwrap_err();
assert!(err.to_string().contains("unpaired surrogate"), "{}", err);
list16_to_str.post_return(&mut store)?;
let ret = list16_to_str.call(&mut store, (&[0xdfff],))?;
let err = ret.to_str(&store).unwrap_err();
assert!(err.to_string().contains("unpaired surrogate"), "{}", err);
list16_to_str.post_return(&mut store)?;
let ret = list16_to_str.call(&mut store, (&[0xd800, 0xff00],))?;
let err = ret.to_str(&store).unwrap_err();
assert!(err.to_string().contains("unpaired surrogate"), "{}", err);
list16_to_str.post_return(&mut store)?;
Ok(())
}
@@ -1123,10 +1163,10 @@ fn some_traps() -> Result<()> {
instance
.get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-end-oob")?
.call(&mut store, (&[],))?;
.call_and_post_return(&mut store, (&[],))?;
instance
.get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-end-oob")?
.call(&mut store, (&[1, 2, 3, 4],))?;
.call_and_post_return(&mut store, (&[1, 2, 3, 4],))?;
let err = instance
.get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-end-oob")?
.call(&mut store, (&[1, 2, 3, 4, 5],))
@@ -1134,10 +1174,10 @@ fn some_traps() -> Result<()> {
assert_oob(&err);
instance
.get_typed_func::<(&str,), (), _>(&mut store, "take-string-end-oob")?
.call(&mut store, ("",))?;
.call_and_post_return(&mut store, ("",))?;
instance
.get_typed_func::<(&str,), (), _>(&mut store, "take-string-end-oob")?
.call(&mut store, ("abcd",))?;
.call_and_post_return(&mut store, ("abcd",))?;
let err = instance
.get_typed_func::<(&str,), (), _>(&mut store, "take-string-end-oob")?
.call(&mut store, ("abcde",))
@@ -1216,12 +1256,15 @@ fn char_bool_memory() -> Result<()> {
let ret = func.call(&mut store, (0, 'a' as u32))?;
assert_eq!(ret, (false, 'a'));
func.post_return(&mut store)?;
let ret = func.call(&mut store, (1, '🍰' as u32))?;
assert_eq!(ret, (true, '🍰'));
func.post_return(&mut store)?;
let ret = func.call(&mut store, (2, 'a' as u32))?;
assert_eq!(ret, (true, 'a'));
func.post_return(&mut store)?;
assert!(func.call(&mut store, (0, 0xd800)).is_err());
@@ -1437,22 +1480,30 @@ fn option() -> Result<()> {
let option_unit_to_u32 =
instance.get_typed_func::<(Option<()>,), u32, _>(&mut store, "option-unit-to-u32")?;
assert_eq!(option_unit_to_u32.call(&mut store, (None,))?, 0);
option_unit_to_u32.post_return(&mut store)?;
assert_eq!(option_unit_to_u32.call(&mut store, (Some(()),))?, 1);
option_unit_to_u32.post_return(&mut store)?;
let option_u8_to_tuple = instance
.get_typed_func::<(Option<u8>,), (u32, u32), _>(&mut store, "option-u8-to-tuple")?;
assert_eq!(option_u8_to_tuple.call(&mut store, (None,))?, (0, 0));
option_u8_to_tuple.post_return(&mut store)?;
assert_eq!(option_u8_to_tuple.call(&mut store, (Some(0),))?, (1, 0));
option_u8_to_tuple.post_return(&mut store)?;
assert_eq!(option_u8_to_tuple.call(&mut store, (Some(100),))?, (1, 100));
option_u8_to_tuple.post_return(&mut store)?;
let option_u32_to_tuple = instance
.get_typed_func::<(Option<u32>,), (u32, u32), _>(&mut store, "option-u32-to-tuple")?;
assert_eq!(option_u32_to_tuple.call(&mut store, (None,))?, (0, 0));
option_u32_to_tuple.post_return(&mut store)?;
assert_eq!(option_u32_to_tuple.call(&mut store, (Some(0),))?, (1, 0));
option_u32_to_tuple.post_return(&mut store)?;
assert_eq!(
option_u32_to_tuple.call(&mut store, (Some(100),))?,
(1, 100)
);
option_u32_to_tuple.post_return(&mut store)?;
let option_string_to_tuple = instance.get_typed_func::<(Option<&str>,), (u32, WasmStr), _>(
&mut store,
@@ -1461,45 +1512,59 @@ fn option() -> Result<()> {
let (a, b) = option_string_to_tuple.call(&mut store, (None,))?;
assert_eq!(a, 0);
assert_eq!(b.to_str(&store)?, "");
option_string_to_tuple.post_return(&mut store)?;
let (a, b) = option_string_to_tuple.call(&mut store, (Some(""),))?;
assert_eq!(a, 1);
assert_eq!(b.to_str(&store)?, "");
option_string_to_tuple.post_return(&mut store)?;
let (a, b) = option_string_to_tuple.call(&mut store, (Some("hello"),))?;
assert_eq!(a, 1);
assert_eq!(b.to_str(&store)?, "hello");
option_string_to_tuple.post_return(&mut store)?;
let to_option_unit =
instance.get_typed_func::<(u32,), Option<()>, _>(&mut store, "to-option-unit")?;
assert_eq!(to_option_unit.call(&mut store, (0,))?, None);
to_option_unit.post_return(&mut store)?;
assert_eq!(to_option_unit.call(&mut store, (1,))?, Some(()));
to_option_unit.post_return(&mut store)?;
let err = to_option_unit.call(&mut store, (2,)).unwrap_err();
assert!(err.to_string().contains("invalid option"), "{}", err);
let to_option_u8 =
instance.get_typed_func::<(u32, u32), Option<u8>, _>(&mut store, "to-option-u8")?;
assert_eq!(to_option_u8.call(&mut store, (0x00_00, 0))?, None);
to_option_u8.post_return(&mut store)?;
assert_eq!(to_option_u8.call(&mut store, (0x00_01, 0))?, Some(0));
to_option_u8.post_return(&mut store)?;
assert_eq!(to_option_u8.call(&mut store, (0xfd_01, 0))?, Some(0xfd));
to_option_u8.post_return(&mut store)?;
assert!(to_option_u8.call(&mut store, (0x00_02, 0)).is_err());
let to_option_u32 =
instance.get_typed_func::<(u32, u32), Option<u32>, _>(&mut store, "to-option-u32")?;
assert_eq!(to_option_u32.call(&mut store, (0, 0))?, None);
to_option_u32.post_return(&mut store)?;
assert_eq!(to_option_u32.call(&mut store, (1, 0))?, Some(0));
to_option_u32.post_return(&mut store)?;
assert_eq!(
to_option_u32.call(&mut store, (1, 0x1234fead))?,
Some(0x1234fead)
);
to_option_u32.post_return(&mut store)?;
assert!(to_option_u32.call(&mut store, (2, 0)).is_err());
let to_option_string = instance
.get_typed_func::<(u32, &str), Option<WasmStr>, _>(&mut store, "to-option-string")?;
let ret = to_option_string.call(&mut store, (0, ""))?;
assert!(ret.is_none());
to_option_string.post_return(&mut store)?;
let ret = to_option_string.call(&mut store, (1, ""))?;
assert_eq!(ret.unwrap().to_str(&store)?, "");
to_option_string.post_return(&mut store)?;
let ret = to_option_string.call(&mut store, (1, "cheesecake"))?;
assert_eq!(ret.unwrap().to_str(&store)?, "cheesecake");
to_option_string.post_return(&mut store)?;
assert!(to_option_string.call(&mut store, (2, "")).is_err());
Ok(())
@@ -1592,15 +1657,19 @@ fn expected() -> Result<()> {
let take_expected_unit =
instance.get_typed_func::<(Result<(), ()>,), u32, _>(&mut store, "take-expected-unit")?;
assert_eq!(take_expected_unit.call(&mut store, (Ok(()),))?, 0);
take_expected_unit.post_return(&mut store)?;
assert_eq!(take_expected_unit.call(&mut store, (Err(()),))?, 1);
take_expected_unit.post_return(&mut store)?;
let take_expected_u8_f32 = instance
.get_typed_func::<(Result<u8, f32>,), (u32, u32), _>(&mut store, "take-expected-u8-f32")?;
assert_eq!(take_expected_u8_f32.call(&mut store, (Ok(1),))?, (0, 1));
take_expected_u8_f32.post_return(&mut store)?;
assert_eq!(
take_expected_u8_f32.call(&mut store, (Err(2.0),))?,
(1, 2.0f32.to_bits())
);
take_expected_u8_f32.post_return(&mut store)?;
let take_expected_string = instance
.get_typed_func::<(Result<&str, &[u8]>,), (u32, WasmStr), _>(
@@ -1610,27 +1679,35 @@ fn expected() -> Result<()> {
let (a, b) = take_expected_string.call(&mut store, (Ok("hello"),))?;
assert_eq!(a, 0);
assert_eq!(b.to_str(&store)?, "hello");
take_expected_string.post_return(&mut store)?;
let (a, b) = take_expected_string.call(&mut store, (Err(b"goodbye"),))?;
assert_eq!(a, 1);
assert_eq!(b.to_str(&store)?, "goodbye");
take_expected_string.post_return(&mut store)?;
let to_expected_unit =
instance.get_typed_func::<(u32,), Result<(), ()>, _>(&mut store, "to-expected-unit")?;
assert_eq!(to_expected_unit.call(&mut store, (0,))?, Ok(()));
to_expected_unit.post_return(&mut store)?;
assert_eq!(to_expected_unit.call(&mut store, (1,))?, Err(()));
to_expected_unit.post_return(&mut store)?;
let err = to_expected_unit.call(&mut store, (2,)).unwrap_err();
assert!(err.to_string().contains("invalid expected"), "{}", err);
let to_expected_s16_f32 = instance
.get_typed_func::<(u32, u32), Result<i16, f32>, _>(&mut store, "to-expected-s16-f32")?;
assert_eq!(to_expected_s16_f32.call(&mut store, (0, 0))?, Ok(0));
to_expected_s16_f32.post_return(&mut store)?;
assert_eq!(to_expected_s16_f32.call(&mut store, (0, 100))?, Ok(100));
to_expected_s16_f32.post_return(&mut store)?;
assert_eq!(
to_expected_s16_f32.call(&mut store, (1, 1.0f32.to_bits()))?,
Err(1.0)
);
to_expected_s16_f32.post_return(&mut store)?;
let ret = to_expected_s16_f32.call(&mut store, (1, CANON_32BIT_NAN | 1))?;
assert_eq!(ret.unwrap_err().to_bits(), CANON_32BIT_NAN);
to_expected_s16_f32.post_return(&mut store)?;
assert!(to_expected_s16_f32.call(&mut store, (2, 0)).is_err());
Ok(())

View File

@@ -0,0 +1,259 @@
use anyhow::Result;
use wasmtime::component::*;
use wasmtime::{Store, StoreContextMut};
#[test]
fn invalid_api() -> Result<()> {
let component = r#"
(component
(core module $m
(func (export "thunk1"))
(func (export "thunk2"))
)
(core instance $i (instantiate $m))
(func (export "thunk1")
(canon lift (core func $i "thunk1"))
)
(func (export "thunk2")
(canon lift (core func $i "thunk2"))
)
)
"#;
let engine = super::engine();
let component = Component::new(&engine, component)?;
let mut store = Store::new(&engine, ());
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let thunk1 = instance.get_typed_func::<(), (), _>(&mut store, "thunk1")?;
let thunk2 = instance.get_typed_func::<(), (), _>(&mut store, "thunk2")?;
// Ensure that we can't call `post_return` before doing anything
let msg = "post_return can only be called after a function has previously been called";
assert_panics(|| drop(thunk1.post_return(&mut store)), msg);
assert_panics(|| drop(thunk2.post_return(&mut store)), msg);
// Schedule a "needs post return"
thunk1.call(&mut store, ())?;
// Ensure that we can't reenter the instance through either this function or
// another one.
let err = thunk1.call(&mut store, ()).unwrap_err();
assert!(
err.to_string()
.contains("cannot reenter component instance"),
"{}",
err
);
let err = thunk2.call(&mut store, ()).unwrap_err();
assert!(
err.to_string()
.contains("cannot reenter component instance"),
"{}",
err
);
// Calling post-return on the wrong function should panic
assert_panics(
|| drop(thunk2.post_return(&mut store)),
"calling post_return on wrong function",
);
// Actually execute the post-return
thunk1.post_return(&mut store)?;
// And now post-return should be invalid again.
assert_panics(|| drop(thunk1.post_return(&mut store)), msg);
assert_panics(|| drop(thunk2.post_return(&mut store)), msg);
Ok(())
}
#[track_caller]
fn assert_panics(f: impl FnOnce(), msg: &str) {
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
Ok(()) => panic!("expected closure to panic"),
Err(e) => match e.downcast::<String>() {
Ok(s) => {
assert!(s.contains(msg), "bad panic: {}", s);
}
Err(e) => match e.downcast::<&'static str>() {
Ok(s) => assert!(s.contains(msg), "bad panic: {}", s),
Err(_) => panic!("bad panic"),
},
},
}
}
#[test]
fn invoke_post_return() -> Result<()> {
let component = r#"
(component
(import "f" (func $f))
(core func $f_lower
(canon lower (func $f))
)
(core module $m
(import "" "" (func $f))
(func (export "thunk"))
(func $post_return
call $f)
(export "post-return" (func $post_return))
)
(core instance $i (instantiate $m
(with "" (instance
(export "" (func $f_lower))
))
))
(func (export "thunk")
(canon lift
(core func $i "thunk")
(post-return (func $i "post-return"))
)
)
)
"#;
let engine = super::engine();
let component = Component::new(&engine, component)?;
let mut store = Store::new(&engine, false);
let mut linker = Linker::new(&engine);
linker
.root()
.func_wrap("f", |mut store: StoreContextMut<'_, bool>| -> Result<()> {
assert!(!*store.data());
*store.data_mut() = true;
Ok(())
})?;
let instance = linker.instantiate(&mut store, &component)?;
let thunk = instance.get_typed_func::<(), (), _>(&mut store, "thunk")?;
assert!(!*store.data());
thunk.call(&mut store, ())?;
assert!(!*store.data());
thunk.post_return(&mut store)?;
assert!(*store.data());
Ok(())
}
#[test]
fn post_return_all_types() -> Result<()> {
let component = r#"
(component
(core module $m
(func (export "i32") (result i32)
i32.const 1)
(func (export "i64") (result i64)
i64.const 2)
(func (export "f32") (result f32)
f32.const 3)
(func (export "f64") (result f64)
f64.const 4)
(func (export "post-i32") (param i32)
local.get 0
i32.const 1
i32.ne
if unreachable end)
(func (export "post-i64") (param i64)
local.get 0
i64.const 2
i64.ne
if unreachable end)
(func (export "post-f32") (param f32)
local.get 0
f32.const 3
f32.ne
if unreachable end)
(func (export "post-f64") (param f64)
local.get 0
f64.const 4
f64.ne
if unreachable end)
)
(core instance $i (instantiate $m))
(func (export "i32") (result u32)
(canon lift (core func $i "i32") (post-return (func $i "post-i32")))
)
(func (export "i64") (result u64)
(canon lift (core func $i "i64") (post-return (func $i "post-i64")))
)
(func (export "f32") (result float32)
(canon lift (core func $i "f32") (post-return (func $i "post-f32")))
)
(func (export "f64") (result float64)
(canon lift (core func $i "f64") (post-return (func $i "post-f64")))
)
)
"#;
let engine = super::engine();
let component = Component::new(&engine, component)?;
let mut store = Store::new(&engine, false);
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let i32 = instance.get_typed_func::<(), u32, _>(&mut store, "i32")?;
let i64 = instance.get_typed_func::<(), u64, _>(&mut store, "i64")?;
let f32 = instance.get_typed_func::<(), f32, _>(&mut store, "f32")?;
let f64 = instance.get_typed_func::<(), f64, _>(&mut store, "f64")?;
assert_eq!(i32.call(&mut store, ())?, 1);
i32.post_return(&mut store)?;
assert_eq!(i64.call(&mut store, ())?, 2);
i64.post_return(&mut store)?;
assert_eq!(f32.call(&mut store, ())?, 3.);
f32.post_return(&mut store)?;
assert_eq!(f64.call(&mut store, ())?, 4.);
f64.post_return(&mut store)?;
Ok(())
}
#[test]
fn post_return_string() -> Result<()> {
let component = r#"
(component
(core module $m
(memory (export "memory") 1)
(func (export "get") (result i32)
(i32.store offset=0 (i32.const 8) (i32.const 100))
(i32.store offset=4 (i32.const 8) (i32.const 11))
i32.const 8
)
(func (export "post") (param i32)
local.get 0
i32.const 8
i32.ne
if unreachable end)
(data (i32.const 100) "hello world")
)
(core instance $i (instantiate $m))
(func (export "get") (result string)
(canon lift
(core func $i "get")
(post-return (func $i "post"))
(memory $i "memory")
)
)
)
"#;
let engine = super::engine();
let component = Component::new(&engine, component)?;
let mut store = Store::new(&engine, false);
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let get = instance.get_typed_func::<(), WasmStr, _>(&mut store, "get")?;
let s = get.call(&mut store, ())?;
assert_eq!(s.to_str(&store)?, "hello world");
get.post_return(&mut store)?;
Ok(())
}