Implement the post-return attribute (#4297)

This commit implements the `post-return` feature of the canonical ABI in
the component model. This attribute is an optionally-specified function
which is to be executed after the return value has been processed by the
caller to optionally clean-up the return value. This enables, for
example, returning an allocated string and the host then knows how to
clean it up to prevent memory leaks in the original module.

The API exposed in this PR changes the prior `TypedFunc::call` API in
behavior but not in its signature. Previously the `TypedFunc::call`
method would set the `may_enter` flag on the way out, but now that
operation is deferred until a new `TypedFunc::post_return` method is
called. This means that once a method on an instance is invoked then
nothing else can be done on the instance until the `post_return` method
is called. Note that the method must be called irrespective of whether
the `post-return` canonical ABI option was specified or not. Internally
wasm will be invoked if necessary.

This is a pretty wonky and unergonomic API to work with. For now I
couldn't think of a better alternative that improved on the ergonomics.
In the theory that the raw Wasmtime bindings for a component may not be
used all that heavily (instead `wit-bindgen` would largely be used) I'm
hoping that this isn't too much of an issue in the future.

cc #4185
This commit is contained in:
Alex Crichton
2022-06-23 14:36:21 -05:00
committed by GitHub
parent fa36e86f2c
commit 3339dd1f01
12 changed files with 787 additions and 112 deletions

View File

@@ -54,6 +54,7 @@ impl ComponentCompiler for Compiler {
let CanonicalOptions { let CanonicalOptions {
memory, memory,
realloc, realloc,
post_return,
string_encoding, string_encoding,
} = lowering.options; } = lowering.options;
@@ -94,6 +95,11 @@ impl ComponentCompiler for Compiler {
None => builder.ins().iconst(pointer_type, 0), None => builder.ins().iconst(pointer_type, 0),
}); });
// A post-return option is only valid on `canon.lift`'d functions so no
// valid component should have this specified for a lowering which this
// trampoline compiler is interested in.
assert!(post_return.is_none());
// string_encoding: StringEncoding // string_encoding: StringEncoding
host_sig.params.push(ir::AbiParam::new(ir::types::I8)); host_sig.params.push(ir::AbiParam::new(ir::types::I8));
callee_args.push( callee_args.push(

View File

@@ -129,6 +129,9 @@ pub struct Component {
/// `VMComponentContext`. /// `VMComponentContext`.
pub num_runtime_reallocs: u32, pub num_runtime_reallocs: u32,
/// Same as `num_runtime_reallocs`, but for post-return functions.
pub num_runtime_post_returns: u32,
/// The number of lowered host functions (maximum `LoweredIndex`) needed to /// The number of lowered host functions (maximum `LoweredIndex`) needed to
/// instantiate this component. /// instantiate this component.
pub num_lowerings: u32, pub num_lowerings: u32,
@@ -180,6 +183,10 @@ pub enum GlobalInitializer {
/// used as a `realloc` function. /// used as a `realloc` function.
ExtractRealloc(ExtractRealloc), ExtractRealloc(ExtractRealloc),
/// Same as `ExtractMemory`, except it's extracting a function pointer to be
/// used as a `post-return` function.
ExtractPostReturn(ExtractPostReturn),
/// The `module` specified is saved into the runtime state at the next /// The `module` specified is saved into the runtime state at the next
/// `RuntimeModuleIndex`, referred to later by `Export` definitions. /// `RuntimeModuleIndex`, referred to later by `Export` definitions.
SaveStaticModule(StaticModuleIndex), SaveStaticModule(StaticModuleIndex),
@@ -207,6 +214,15 @@ pub struct ExtractRealloc {
pub def: CoreDef, pub def: CoreDef,
} }
/// Same as `ExtractMemory` but for the `post-return` canonical option.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExtractPostReturn {
/// The index of the post-return being defined.
pub index: RuntimePostReturnIndex,
/// Where this post-return is being extracted from.
pub def: CoreDef,
}
/// Different methods of instantiating a core wasm module. /// Different methods of instantiating a core wasm module.
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub enum InstantiateModule { pub enum InstantiateModule {
@@ -361,7 +377,9 @@ pub struct CanonicalOptions {
/// The realloc function used by these options, if specified. /// The realloc function used by these options, if specified.
pub realloc: Option<RuntimeReallocIndex>, pub realloc: Option<RuntimeReallocIndex>,
// TODO: need to represent post-return here as well
/// The post-return function used by these options, if specified.
pub post_return: Option<RuntimePostReturnIndex>,
} }
impl Default for CanonicalOptions { impl Default for CanonicalOptions {
@@ -370,6 +388,7 @@ impl Default for CanonicalOptions {
string_encoding: StringEncoding::Utf8, string_encoding: StringEncoding::Utf8,
memory: None, memory: None,
realloc: None, realloc: None,
post_return: None,
} }
} }
} }

View File

@@ -62,6 +62,7 @@ pub(super) fn run(
result: Component::default(), result: Component::default(),
import_path_interner: Default::default(), import_path_interner: Default::default(),
runtime_realloc_interner: Default::default(), runtime_realloc_interner: Default::default(),
runtime_post_return_interner: Default::default(),
runtime_memory_interner: Default::default(), runtime_memory_interner: Default::default(),
}; };
@@ -182,6 +183,7 @@ struct Inliner<'a> {
// runtime instead of multiple times. // runtime instead of multiple times.
import_path_interner: HashMap<ImportPath<'a>, RuntimeImportIndex>, import_path_interner: HashMap<ImportPath<'a>, RuntimeImportIndex>,
runtime_realloc_interner: HashMap<CoreDef, RuntimeReallocIndex>, runtime_realloc_interner: HashMap<CoreDef, RuntimeReallocIndex>,
runtime_post_return_interner: HashMap<CoreDef, RuntimePostReturnIndex>,
runtime_memory_interner: HashMap<CoreExport<MemoryIndex>, RuntimeMemoryIndex>, runtime_memory_interner: HashMap<CoreExport<MemoryIndex>, RuntimeMemoryIndex>,
} }
@@ -851,13 +853,29 @@ impl<'a> Inliner<'a> {
index index
}) })
}); });
if options.post_return.is_some() { let post_return = options.post_return.map(|i| {
unimplemented!("post-return handling"); let def = frame.funcs[i].clone();
} *self
.runtime_post_return_interner
.entry(def.clone())
.or_insert_with(|| {
let index =
RuntimePostReturnIndex::from_u32(self.result.num_runtime_post_returns);
self.result.num_runtime_post_returns += 1;
self.result
.initializers
.push(GlobalInitializer::ExtractPostReturn(ExtractPostReturn {
index,
def,
}));
index
})
});
CanonicalOptions { CanonicalOptions {
string_encoding: options.string_encoding, string_encoding: options.string_encoding,
memory, memory,
realloc, realloc,
post_return,
} }
} }
} }

View File

@@ -2,16 +2,18 @@
// //
// struct VMComponentContext { // struct VMComponentContext {
// magic: u32, // magic: u32,
// may_enter: u8, // flags: u8,
// may_leave: u8,
// store: *mut dyn Store, // store: *mut dyn Store,
// lowering_anyfuncs: [VMCallerCheckedAnyfunc; component.num_lowerings], // lowering_anyfuncs: [VMCallerCheckedAnyfunc; component.num_lowerings],
// lowerings: [VMLowering; component.num_lowerings], // lowerings: [VMLowering; component.num_lowerings],
// memories: [*mut VMMemoryDefinition; component.num_memories], // memories: [*mut VMMemoryDefinition; component.num_memories],
// reallocs: [*mut VMCallerCheckedAnyfunc; component.num_reallocs], // reallocs: [*mut VMCallerCheckedAnyfunc; component.num_reallocs],
// post_returns: [*mut VMCallerCheckedAnyfunc; component.num_post_returns],
// } // }
use crate::component::{Component, LoweredIndex, RuntimeMemoryIndex, RuntimeReallocIndex}; use crate::component::{
Component, LoweredIndex, RuntimeMemoryIndex, RuntimePostReturnIndex, RuntimeReallocIndex,
};
use crate::PtrSize; use crate::PtrSize;
/// Equivalent of `VMCONTEXT_MAGIC` except for components. /// Equivalent of `VMCONTEXT_MAGIC` except for components.
@@ -20,6 +22,18 @@ use crate::PtrSize;
/// double-checked on `VMComponentContext::from_opaque`. /// double-checked on `VMComponentContext::from_opaque`.
pub const VMCOMPONENT_MAGIC: u32 = u32::from_le_bytes(*b"comp"); pub const VMCOMPONENT_MAGIC: u32 = u32::from_le_bytes(*b"comp");
/// Flag for the `VMComponentContext::flags` field which corresponds to the
/// canonical ABI flag `may_leave`
pub const VMCOMPONENT_FLAG_MAY_LEAVE: u8 = 1 << 0;
/// Flag for the `VMComponentContext::flags` field which corresponds to the
/// canonical ABI flag `may_enter`
pub const VMCOMPONENT_FLAG_MAY_ENTER: u8 = 1 << 1;
/// Flag for the `VMComponentContext::flags` field which is set whenever a
/// function is called to indicate that `post_return` must be called next.
pub const VMCOMPONENT_FLAG_NEEDS_POST_RETURN: u8 = 1 << 2;
/// Runtime offsets within a `VMComponentContext` for a specific component. /// Runtime offsets within a `VMComponentContext` for a specific component.
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct VMComponentOffsets<P> { pub struct VMComponentOffsets<P> {
@@ -32,16 +46,18 @@ pub struct VMComponentOffsets<P> {
pub num_runtime_memories: u32, pub num_runtime_memories: u32,
/// The number of reallocs which are recorded in this component for options. /// The number of reallocs which are recorded in this component for options.
pub num_runtime_reallocs: u32, pub num_runtime_reallocs: u32,
/// The number of post-returns which are recorded in this component for options.
pub num_runtime_post_returns: u32,
// precalculated offsets of various member fields // precalculated offsets of various member fields
magic: u32, magic: u32,
may_enter: u32, flags: u32,
may_leave: u32,
store: u32, store: u32,
lowering_anyfuncs: u32, lowering_anyfuncs: u32,
lowerings: u32, lowerings: u32,
memories: u32, memories: u32,
reallocs: u32, reallocs: u32,
post_returns: u32,
size: u32, size: u32,
} }
@@ -60,14 +76,15 @@ impl<P: PtrSize> VMComponentOffsets<P> {
num_lowerings: component.num_lowerings.try_into().unwrap(), num_lowerings: component.num_lowerings.try_into().unwrap(),
num_runtime_memories: component.num_runtime_memories.try_into().unwrap(), num_runtime_memories: component.num_runtime_memories.try_into().unwrap(),
num_runtime_reallocs: component.num_runtime_reallocs.try_into().unwrap(), num_runtime_reallocs: component.num_runtime_reallocs.try_into().unwrap(),
num_runtime_post_returns: component.num_runtime_post_returns.try_into().unwrap(),
magic: 0, magic: 0,
may_enter: 0, flags: 0,
may_leave: 0,
store: 0, store: 0,
lowering_anyfuncs: 0, lowering_anyfuncs: 0,
lowerings: 0, lowerings: 0,
memories: 0, memories: 0,
reallocs: 0, reallocs: 0,
post_returns: 0,
size: 0, size: 0,
}; };
@@ -97,14 +114,14 @@ impl<P: PtrSize> VMComponentOffsets<P> {
fields! { fields! {
size(magic) = 4u32, size(magic) = 4u32,
size(may_enter) = 1u32, size(flags) = 1u32,
size(may_leave) = 1u32,
align(u32::from(ret.ptr.size())), align(u32::from(ret.ptr.size())),
size(store) = cmul(2, ret.ptr.size()), size(store) = cmul(2, ret.ptr.size()),
size(lowering_anyfuncs) = cmul(ret.num_lowerings, ret.ptr.size_of_vmcaller_checked_anyfunc()), size(lowering_anyfuncs) = cmul(ret.num_lowerings, ret.ptr.size_of_vmcaller_checked_anyfunc()),
size(lowerings) = cmul(ret.num_lowerings, ret.ptr.size() * 2), size(lowerings) = cmul(ret.num_lowerings, ret.ptr.size() * 2),
size(memories) = cmul(ret.num_runtime_memories, ret.ptr.size()), size(memories) = cmul(ret.num_runtime_memories, ret.ptr.size()),
size(reallocs) = cmul(ret.num_runtime_reallocs, ret.ptr.size()), size(reallocs) = cmul(ret.num_runtime_reallocs, ret.ptr.size()),
size(post_returns) = cmul(ret.num_runtime_post_returns, ret.ptr.size()),
} }
ret.size = next_field_offset; ret.size = next_field_offset;
@@ -129,16 +146,10 @@ impl<P: PtrSize> VMComponentOffsets<P> {
self.magic self.magic
} }
/// The offset of the `may_leave` field. /// The offset of the `flags` field.
#[inline] #[inline]
pub fn may_leave(&self) -> u32 { pub fn flags(&self) -> u32 {
self.may_leave self.flags
}
/// The offset of the `may_enter` field.
#[inline]
pub fn may_enter(&self) -> u32 {
self.may_enter
} }
/// The offset of the `store` field. /// The offset of the `store` field.
@@ -232,6 +243,20 @@ impl<P: PtrSize> VMComponentOffsets<P> {
self.runtime_reallocs() + index.as_u32() * u32::from(self.ptr.size()) self.runtime_reallocs() + index.as_u32() * u32::from(self.ptr.size())
} }
/// The offset of the base of the `runtime_post_returns` field
#[inline]
pub fn runtime_post_returns(&self) -> u32 {
self.post_returns
}
/// The offset of the `*mut VMCallerCheckedAnyfunc` for the runtime index
/// provided.
#[inline]
pub fn runtime_post_return(&self, index: RuntimePostReturnIndex) -> u32 {
assert!(index.as_u32() < self.num_runtime_post_returns);
self.runtime_post_returns() + index.as_u32() * u32::from(self.ptr.size())
}
/// Return the size of the `VMComponentContext` allocation. /// Return the size of the `VMComponentContext` allocation.
#[inline] #[inline]
pub fn size_of_vmctx(&self) -> u32 { pub fn size_of_vmctx(&self) -> u32 {

View File

@@ -17,8 +17,9 @@ use std::mem;
use std::ops::Deref; use std::ops::Deref;
use std::ptr::{self, NonNull}; use std::ptr::{self, NonNull};
use wasmtime_environ::component::{ use wasmtime_environ::component::{
Component, LoweredIndex, RuntimeMemoryIndex, RuntimeReallocIndex, StringEncoding, Component, LoweredIndex, RuntimeMemoryIndex, RuntimePostReturnIndex, RuntimeReallocIndex,
VMComponentOffsets, VMCOMPONENT_MAGIC, StringEncoding, VMComponentOffsets, VMCOMPONENT_FLAG_MAY_ENTER, VMCOMPONENT_FLAG_MAY_LEAVE,
VMCOMPONENT_FLAG_NEEDS_POST_RETURN, VMCOMPONENT_MAGIC,
}; };
use wasmtime_environ::HostPtr; use wasmtime_environ::HostPtr;
@@ -63,6 +64,11 @@ pub struct ComponentInstance {
/// signature that this callee corresponds to. /// signature that this callee corresponds to.
/// * `nargs_and_results` - the size, in units of `ValRaw`, of /// * `nargs_and_results` - the size, in units of `ValRaw`, of
/// `args_and_results`. /// `args_and_results`.
//
// FIXME: 7 arguments is probably too many. The `data` through `string-encoding`
// parameters should probably get packaged up into the `VMComponentContext`.
// Needs benchmarking one way or another though to figure out what the best
// balance is here.
pub type VMLoweringCallee = extern "C" fn( pub type VMLoweringCallee = extern "C" fn(
vmctx: *mut VMOpaqueContext, vmctx: *mut VMOpaqueContext,
data: *mut u8, data: *mut u8,
@@ -104,6 +110,11 @@ pub struct VMComponentContext {
_marker: marker::PhantomPinned, _marker: marker::PhantomPinned,
} }
/// Flags stored in a `VMComponentContext` with values defined by
/// `VMCOMPONENT_FLAG_*`
#[repr(transparent)]
pub struct VMComponentFlags(u8);
impl ComponentInstance { impl ComponentInstance {
/// Returns the layout corresponding to what would be an allocation of a /// Returns the layout corresponding to what would be an allocation of a
/// `ComponentInstance` for the `offsets` provided. /// `ComponentInstance` for the `offsets` provided.
@@ -159,14 +170,8 @@ impl ComponentInstance {
/// Returns a pointer to the "may leave" flag for this instance specified /// Returns a pointer to the "may leave" flag for this instance specified
/// for canonical lowering and lifting operations. /// for canonical lowering and lifting operations.
pub fn may_leave(&self) -> *mut bool { pub fn flags(&self) -> *mut VMComponentFlags {
unsafe { self.vmctx_plus_offset(self.offsets.may_leave()) } unsafe { self.vmctx_plus_offset(self.offsets.flags()) }
}
/// Returns a pointer to the "may enter" flag for this instance specified
/// for canonical lowering and lifting operations.
pub fn may_enter(&self) -> *mut bool {
unsafe { self.vmctx_plus_offset(self.offsets.may_enter()) }
} }
/// Returns the store that this component was created with. /// Returns the store that this component was created with.
@@ -202,6 +207,22 @@ impl ComponentInstance {
ret ret
} }
} }
/// Returns the post-return pointer corresponding to the index provided.
///
/// This can only be called after `idx` has been initialized at runtime
/// during the instantiation process of a component.
pub fn runtime_post_return(
&self,
idx: RuntimePostReturnIndex,
) -> NonNull<VMCallerCheckedAnyfunc> {
unsafe {
let ret = *self.vmctx_plus_offset::<NonNull<_>>(self.offsets.runtime_post_return(idx));
debug_assert!(ret.as_ptr() as usize != INVALID_PTR);
ret
}
}
/// Returns the host information for the lowered function at the index /// Returns the host information for the lowered function at the index
/// specified. /// specified.
/// ///
@@ -264,6 +285,19 @@ impl ComponentInstance {
} }
} }
/// Same as `set_runtime_memory` but for post-return function pointers.
pub fn set_runtime_post_return(
&mut self,
idx: RuntimePostReturnIndex,
ptr: NonNull<VMCallerCheckedAnyfunc>,
) {
unsafe {
let storage = self.vmctx_plus_offset(self.offsets.runtime_post_return(idx));
debug_assert!(*storage as usize == INVALID_PTR);
*storage = ptr.as_ptr();
}
}
/// Configures a lowered host function with all the pieces necessary. /// Configures a lowered host function with all the pieces necessary.
/// ///
/// * `idx` - the index that's being configured /// * `idx` - the index that's being configured
@@ -304,8 +338,7 @@ impl ComponentInstance {
unsafe fn initialize_vmctx(&mut self, store: *mut dyn Store) { unsafe fn initialize_vmctx(&mut self, store: *mut dyn Store) {
*self.vmctx_plus_offset(self.offsets.magic()) = VMCOMPONENT_MAGIC; *self.vmctx_plus_offset(self.offsets.magic()) = VMCOMPONENT_MAGIC;
*self.may_leave() = true; *self.flags() = VMComponentFlags::new();
*self.may_enter() = true;
*self.vmctx_plus_offset(self.offsets.store()) = store; *self.vmctx_plus_offset(self.offsets.store()) = store;
// In debug mode set non-null bad values to all "pointer looking" bits // In debug mode set non-null bad values to all "pointer looking" bits
@@ -332,6 +365,11 @@ impl ComponentInstance {
let offset = self.offsets.runtime_realloc(i); let offset = self.offsets.runtime_realloc(i);
*self.vmctx_plus_offset(offset) = INVALID_PTR; *self.vmctx_plus_offset(offset) = INVALID_PTR;
} }
for i in 0..self.offsets.num_runtime_post_returns {
let i = RuntimePostReturnIndex::from_u32(i);
let offset = self.offsets.runtime_post_return(i);
*self.vmctx_plus_offset(offset) = INVALID_PTR;
}
} }
} }
} }
@@ -409,6 +447,15 @@ impl OwnedComponentInstance {
unsafe { self.instance_mut().set_runtime_realloc(idx, ptr) } unsafe { self.instance_mut().set_runtime_realloc(idx, ptr) }
} }
/// See `ComponentInstance::set_runtime_post_return`
pub fn set_runtime_post_return(
&mut self,
idx: RuntimePostReturnIndex,
ptr: NonNull<VMCallerCheckedAnyfunc>,
) {
unsafe { self.instance_mut().set_runtime_post_return(idx, ptr) }
}
/// See `ComponentInstance::set_lowering` /// See `ComponentInstance::set_lowering`
pub fn set_lowering( pub fn set_lowering(
&mut self, &mut self,
@@ -459,3 +506,52 @@ impl VMOpaqueContext {
ptr.cast() ptr.cast()
} }
} }
#[allow(missing_docs)]
impl VMComponentFlags {
fn new() -> VMComponentFlags {
VMComponentFlags(VMCOMPONENT_FLAG_MAY_LEAVE | VMCOMPONENT_FLAG_MAY_ENTER)
}
#[inline]
pub fn may_leave(&self) -> bool {
self.0 & VMCOMPONENT_FLAG_MAY_LEAVE != 0
}
#[inline]
pub fn set_may_leave(&mut self, val: bool) {
if val {
self.0 |= VMCOMPONENT_FLAG_MAY_LEAVE;
} else {
self.0 &= !VMCOMPONENT_FLAG_MAY_LEAVE;
}
}
#[inline]
pub fn may_enter(&self) -> bool {
self.0 & VMCOMPONENT_FLAG_MAY_ENTER != 0
}
#[inline]
pub fn set_may_enter(&mut self, val: bool) {
if val {
self.0 |= VMCOMPONENT_FLAG_MAY_ENTER;
} else {
self.0 &= !VMCOMPONENT_FLAG_MAY_ENTER;
}
}
#[inline]
pub fn needs_post_return(&self) -> bool {
self.0 & VMCOMPONENT_FLAG_NEEDS_POST_RETURN != 0
}
#[inline]
pub fn set_needs_post_return(&mut self, val: bool) {
if val {
self.0 |= VMCOMPONENT_FLAG_NEEDS_POST_RETURN;
} else {
self.0 &= !VMCOMPONENT_FLAG_NEEDS_POST_RETURN;
}
}
}

View File

@@ -1,6 +1,6 @@
use crate::component::instance::{Instance, InstanceData}; use crate::component::instance::{Instance, InstanceData};
use crate::store::{StoreOpaque, Stored}; use crate::store::{StoreOpaque, Stored};
use crate::AsContext; use crate::{AsContext, ValRaw};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::ptr::NonNull; use std::ptr::NonNull;
@@ -82,6 +82,8 @@ pub struct FuncData {
types: Arc<ComponentTypes>, types: Arc<ComponentTypes>,
options: Options, options: Options,
instance: Instance, instance: Instance,
post_return: Option<(ExportFunction, VMTrampoline)>,
post_return_arg: Option<ValRaw>,
} }
impl Func { impl Func {
@@ -102,6 +104,11 @@ impl Func {
.memory .memory
.map(|i| NonNull::new(data.instance().runtime_memory(i)).unwrap()); .map(|i| NonNull::new(data.instance().runtime_memory(i)).unwrap());
let realloc = options.realloc.map(|i| data.instance().runtime_realloc(i)); let realloc = options.realloc.map(|i| data.instance().runtime_realloc(i));
let post_return = options.post_return.map(|i| {
let anyfunc = data.instance().runtime_post_return(i);
let trampoline = store.lookup_trampoline(unsafe { anyfunc.as_ref() });
(ExportFunction { anyfunc }, trampoline)
});
let options = unsafe { Options::new(store.id(), memory, realloc, options.string_encoding) }; let options = unsafe { Options::new(store.id(), memory, realloc, options.string_encoding) };
Func(store.store_data_mut().insert(FuncData { Func(store.store_data_mut().insert(FuncData {
trampoline, trampoline,
@@ -110,6 +117,8 @@ impl Func {
ty, ty,
types: data.component_types().clone(), types: data.component_types().clone(),
instance: *instance, instance: *instance,
post_return,
post_return_arg: None,
})) }))
} }

View File

@@ -8,7 +8,9 @@ use std::panic::{self, AssertUnwindSafe};
use std::ptr::NonNull; use std::ptr::NonNull;
use std::sync::Arc; use std::sync::Arc;
use wasmtime_environ::component::{ComponentTypes, StringEncoding, TypeFuncIndex}; use wasmtime_environ::component::{ComponentTypes, StringEncoding, TypeFuncIndex};
use wasmtime_runtime::component::{VMComponentContext, VMLowering, VMLoweringCallee}; use wasmtime_runtime::component::{
VMComponentContext, VMComponentFlags, VMLowering, VMLoweringCallee,
};
use wasmtime_runtime::{VMCallerCheckedAnyfunc, VMMemoryDefinition, VMOpaqueContext}; use wasmtime_runtime::{VMCallerCheckedAnyfunc, VMMemoryDefinition, VMOpaqueContext};
/// Trait representing host-defined functions that can be imported into a wasm /// Trait representing host-defined functions that can be imported into a wasm
@@ -134,8 +136,7 @@ where
let cx = VMComponentContext::from_opaque(cx); let cx = VMComponentContext::from_opaque(cx);
let instance = (*cx).instance(); let instance = (*cx).instance();
let may_leave = (*instance).may_leave(); let flags = (*instance).flags();
let may_enter = (*instance).may_enter();
let mut cx = StoreContextMut::from_raw((*instance).store()); let mut cx = StoreContextMut::from_raw((*instance).store());
let options = Options::new( let options = Options::new(
@@ -148,13 +149,13 @@ where
// Perform a dynamic check that this instance can indeed be left. Exiting // Perform a dynamic check that this instance can indeed be left. Exiting
// the component is disallowed, for example, when the `realloc` function // the component is disallowed, for example, when the `realloc` function
// calls a canonical import. // calls a canonical import.
if !*may_leave { if !(*flags).may_leave() {
bail!("cannot leave component instance"); bail!("cannot leave component instance");
} }
// While we're lifting and lowering this instance cannot be reentered, so // While we're lifting and lowering this instance cannot be reentered, so
// unset the flag here. This is also reset back to `true` on exit. // unset the flag here. This is also reset back to `true` on exit.
let _reset_may_enter = unset_and_reset_on_drop(may_enter); let _reset_may_enter = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_enter);
// There's a 2x2 matrix of whether parameters and results are stored on the // There's a 2x2 matrix of whether parameters and results are stored on the
// stack or on the heap. Each of the 4 branches here have a different // stack or on the heap. Each of the 4 branches here have a different
@@ -172,7 +173,7 @@ where
let storage = cast_storage::<ReturnStack<Params::Lower, Return::Lower>>(storage); let storage = cast_storage::<ReturnStack<Params::Lower, Return::Lower>>(storage);
let params = Params::lift(cx.0, &options, &storage.assume_init_ref().args)?; let params = Params::lift(cx.0, &options, &storage.assume_init_ref().args)?;
let ret = closure(cx.as_context_mut(), params)?; let ret = closure(cx.as_context_mut(), params)?;
reset_may_leave = unset_and_reset_on_drop(may_leave); reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave);
ret.lower(&mut cx, &options, map_maybe_uninit!(storage.ret))?; ret.lower(&mut cx, &options, map_maybe_uninit!(storage.ret))?;
} else { } else {
let storage = cast_storage::<ReturnPointer<Params::Lower>>(storage).assume_init_ref(); let storage = cast_storage::<ReturnPointer<Params::Lower>>(storage).assume_init_ref();
@@ -180,7 +181,7 @@ where
let ret = closure(cx.as_context_mut(), params)?; let ret = closure(cx.as_context_mut(), params)?;
let mut memory = MemoryMut::new(cx.as_context_mut(), &options); let mut memory = MemoryMut::new(cx.as_context_mut(), &options);
let ptr = validate_inbounds::<Return>(memory.as_slice_mut(), &storage.retptr)?; let ptr = validate_inbounds::<Return>(memory.as_slice_mut(), &storage.retptr)?;
reset_may_leave = unset_and_reset_on_drop(may_leave); reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave);
ret.store(&mut memory, ptr)?; ret.store(&mut memory, ptr)?;
} }
} else { } else {
@@ -191,7 +192,7 @@ where
validate_inbounds::<Params>(memory.as_slice(), &storage.assume_init_ref().args)?; validate_inbounds::<Params>(memory.as_slice(), &storage.assume_init_ref().args)?;
let params = Params::load(&memory, &memory.as_slice()[ptr..][..Params::size()])?; let params = Params::load(&memory, &memory.as_slice()[ptr..][..Params::size()])?;
let ret = closure(cx.as_context_mut(), params)?; let ret = closure(cx.as_context_mut(), params)?;
reset_may_leave = unset_and_reset_on_drop(may_leave); reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave);
ret.lower(&mut cx, &options, map_maybe_uninit!(storage.ret))?; ret.lower(&mut cx, &options, map_maybe_uninit!(storage.ret))?;
} else { } else {
let storage = cast_storage::<ReturnPointer<ValRaw>>(storage).assume_init_ref(); let storage = cast_storage::<ReturnPointer<ValRaw>>(storage).assume_init_ref();
@@ -200,27 +201,28 @@ where
let ret = closure(cx.as_context_mut(), params)?; let ret = closure(cx.as_context_mut(), params)?;
let mut memory = MemoryMut::new(cx.as_context_mut(), &options); let mut memory = MemoryMut::new(cx.as_context_mut(), &options);
let ptr = validate_inbounds::<Return>(memory.as_slice_mut(), &storage.retptr)?; let ptr = validate_inbounds::<Return>(memory.as_slice_mut(), &storage.retptr)?;
reset_may_leave = unset_and_reset_on_drop(may_leave); reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave);
ret.store(&mut memory, ptr)?; ret.store(&mut memory, ptr)?;
} }
} }
// TODO: need to call `post-return` before this `drop`
drop(reset_may_leave); drop(reset_may_leave);
return Ok(()); return Ok(());
unsafe fn unset_and_reset_on_drop(slot: *mut bool) -> impl Drop { unsafe fn unset_and_reset_on_drop(
debug_assert!(*slot); slot: *mut VMComponentFlags,
*slot = false; set: fn(&mut VMComponentFlags, bool),
return Reset(slot); ) -> impl Drop {
set(&mut *slot, false);
return Reset(slot, set);
struct Reset(*mut bool); struct Reset(*mut VMComponentFlags, fn(&mut VMComponentFlags, bool));
impl Drop for Reset { impl Drop for Reset {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { unsafe {
(*self.0) = true; (self.1)(&mut *self.0, true);
} }
} }
} }

View File

@@ -90,6 +90,31 @@ where
/// the `store` provided. The `params` are copied into WebAssembly memory /// the `store` provided. The `params` are copied into WebAssembly memory
/// as appropriate and a core wasm function is invoked. /// as appropriate and a core wasm function is invoked.
/// ///
/// # Post-return
///
/// In the component model each function can have a "post return" specified
/// which allows cleaning up the arguments returned to the host. For example
/// if WebAssembly returns a string to the host then it might be a uniquely
/// allocated string which, after the host finishes processing it, needs to
/// be deallocated in the wasm instance's own linear memory to prevent
/// memory leaks in wasm itself. The `post-return` canonical abi option is
/// used to configured this.
///
/// To accommodate this feature of the component model after invoking a
/// function via [`TypedFunc::call`] you must next invoke
/// [`TypedFunc::post_return`]. Note that the return value of the function
/// should be processed between these two function calls. The return value
/// continues to be usable from an embedder's perspective after
/// `post_return` is called, but after `post_return` is invoked it may no
/// longer retain the same value that the wasm module originally returned.
///
/// Also note that [`TypedFunc::post_return`] must be invoked irrespective
/// of whether the canonical ABI option `post-return` was configured or not.
/// This means that embedders must unconditionally call
/// [`TypedFunc::post_return`] when a function returns. If this function
/// call returns an error, however, then [`TypedFunc::post_return`] is not
/// required.
///
/// # Errors /// # Errors
/// ///
/// This function can return an error for a number of reasons: /// This function can return an error for a number of reasons:
@@ -99,6 +124,10 @@ where
/// * If the wasm provides bad allocation pointers when copying arguments /// * If the wasm provides bad allocation pointers when copying arguments
/// into memory. /// into memory.
/// * If the wasm returns a value which violates the canonical ABI. /// * If the wasm returns a value which violates the canonical ABI.
/// * If this function's instances cannot be entered, for example if the
/// instance is currently calling a host function.
/// * If a previous function call occurred and the corresponding
/// `post_return` hasn't been invoked yet.
/// ///
/// In general there are many ways that things could go wrong when copying /// In general there are many ways that things could go wrong when copying
/// types in and out of a wasm module with the canonical ABI, and certain /// types in and out of a wasm module with the canonical ABI, and certain
@@ -300,18 +329,17 @@ where
assert!(mem::align_of_val(map_maybe_uninit!(space.ret)) == val_align); assert!(mem::align_of_val(map_maybe_uninit!(space.ret)) == val_align);
let instance = store.0[instance.0].as_ref().unwrap().instance(); let instance = store.0[instance.0].as_ref().unwrap().instance();
let may_enter = instance.may_enter(); let flags = instance.flags();
let may_leave = instance.may_leave();
unsafe { unsafe {
if !*may_enter { if !(*flags).may_enter() {
bail!("cannot reenter component instance"); bail!("cannot reenter component instance");
} }
debug_assert!(*may_leave); debug_assert!((*flags).may_leave());
*may_leave = false; (*flags).set_may_leave(false);
let result = lower(store, &options, params, map_maybe_uninit!(space.params)); let result = lower(store, &options, params, map_maybe_uninit!(space.params));
*may_leave = true; (*flags).set_may_leave(true);
result?; result?;
// This is unsafe as we are providing the guarantee that all the // This is unsafe as we are providing the guarantee that all the
@@ -336,18 +364,139 @@ where
// `[ValRaw]`, and additionally they should have the correct types // `[ValRaw]`, and additionally they should have the correct types
// for the function we just called (which filled in the return // for the function we just called (which filled in the return
// values). // values).
*may_enter = false; let ret = map_maybe_uninit!(space.ret).assume_init_ref();
let result = lift(
store.0,
&options,
map_maybe_uninit!(space.ret).assume_init_ref(),
);
// TODO: this technically needs to happen only after the // Lift the result into the host while managing post-return state
// `post-return` is called. // here as well.
*may_enter = true; //
return result; // Initially the `may_enter` flag is set to `false` for this
// component instance and additionally we set a flag indicating that
// a post-return is required. This isn't specified by the component
// model itself but is used for our implementation of the API of
// `post_return` as a separate function call.
//
// FIXME(WebAssembly/component-model#55) it's not really clear what
// the semantics should be in the face of a lift error/trap. For now
// the flags are reset so the instance can continue to be reused in
// tests but that probably isn't what's desired.
//
// Otherwise though after a successful lift the return value of the
// function, which is currently required to be 0 or 1 values
// according to the canonical ABI, is saved within the `Store`'s
// `FuncData`. This'll later get used in post-return.
(*flags).set_may_enter(false);
(*flags).set_needs_post_return(true);
match lift(store.0, &options, ret) {
Ok(val) => {
let ret_slice = cast_storage(ret);
let data = &mut store.0[self.func.0];
assert!(data.post_return_arg.is_none());
match ret_slice.len() {
0 => data.post_return_arg = Some(ValRaw::i32(0)),
1 => data.post_return_arg = Some(ret_slice[0]),
_ => unreachable!(),
}
return Ok(val);
}
Err(err) => {
(*flags).set_may_enter(true);
(*flags).set_needs_post_return(false);
return Err(err);
}
}
} }
unsafe fn cast_storage<T>(storage: &T) -> &[ValRaw] {
assert!(std::mem::size_of_val(storage) % std::mem::size_of::<ValRaw>() == 0);
assert!(std::mem::align_of_val(storage) == std::mem::align_of::<ValRaw>());
std::slice::from_raw_parts(
(storage as *const T).cast(),
mem::size_of_val(storage) / mem::size_of::<ValRaw>(),
)
}
}
/// Invokes the `post-return` canonical ABI option, if specified, after a
/// [`TypedFunc::call`] has finished.
///
/// For some more information on when to use this function see the
/// documentation for post-return in the [`TypedFunc::call`] method.
/// Otherwise though this function is a required method call after a
/// [`TypedFunc::call`] completes successfully. After the embedder has
/// finished processing the return value then this function must be invoked.
///
/// # Errors
///
/// This function will return an error in the case of a WebAssembly trap
/// happening during the execution of the `post-return` function, if
/// specified.
///
/// # Panics
///
/// This function will panic if it's not called under the correct
/// conditions. This can only be called after a previous invocation of
/// [`TypedFunc::call`] completes successfully, and this function can only
/// be called for the same [`TypedFunc`] that was `call`'d.
///
/// If this function is called when [`TypedFunc::call`] was not previously
/// called, then it will panic. If a different [`TypedFunc`] for the same
/// component instance was invoked then this function will also panic
/// because the `post-return` needs to happen for the other function.
pub fn post_return(&self, mut store: impl AsContextMut) -> Result<()> {
let mut store = store.as_context_mut();
let data = &mut store.0[self.func.0];
let instance = data.instance;
let post_return = data.post_return;
let post_return_arg = data.post_return_arg.take();
let instance = store.0[instance.0].as_ref().unwrap().instance();
let flags = instance.flags();
unsafe {
// First assert that the instance is in a "needs post return" state.
// This will ensure that the previous action on the instance was a
// function call above. This flag is only set after a component
// function returns so this also can't be called (as expected)
// during a host import for example.
//
// Note, though, that this assert is not sufficient because it just
// means some function on this instance needs its post-return
// called. We need a precise post-return for a particular function
// which is the second assert here (the `.expect`). That will assert
// that this function itself needs to have its post-return called.
//
// The theory at least is that these two asserts ensure component
// model semantics are upheld where the host properly calls
// `post_return` on the right function despite the call being a
// separate step in the API.
assert!(
(*flags).needs_post_return(),
"post_return can only be called after a function has previously been called",
);
let post_return_arg = post_return_arg.expect("calling post_return on wrong function");
// This is a sanity-check assert which shouldn't ever trip.
assert!(!(*flags).may_enter());
// With the state of the world validated these flags are updated to
// their component-model-defined states.
(*flags).set_may_enter(true);
(*flags).set_needs_post_return(false);
// And finally if the function actually had a `post-return`
// configured in its canonical options that's executed here.
let (func, trampoline) = match post_return {
Some(pair) => pair,
None => return Ok(()),
};
crate::Func::call_unchecked_raw(
&mut store,
func.anyfunc,
trampoline,
&post_return_arg as *const ValRaw as *mut ValRaw,
)?;
}
Ok(())
} }
} }

View File

@@ -7,9 +7,9 @@ use anyhow::{anyhow, Context, Result};
use std::marker; use std::marker;
use std::sync::Arc; use std::sync::Arc;
use wasmtime_environ::component::{ use wasmtime_environ::component::{
ComponentTypes, CoreDef, CoreExport, Export, ExportItem, ExtractMemory, ExtractRealloc, ComponentTypes, CoreDef, CoreExport, Export, ExportItem, ExtractMemory, ExtractPostReturn,
GlobalInitializer, InstantiateModule, LowerImport, RuntimeImportIndex, RuntimeInstanceIndex, ExtractRealloc, GlobalInitializer, InstantiateModule, LowerImport, RuntimeImportIndex,
RuntimeModuleIndex, RuntimeInstanceIndex, RuntimeModuleIndex,
}; };
use wasmtime_environ::{EntityIndex, PrimaryMap}; use wasmtime_environ::{EntityIndex, PrimaryMap};
use wasmtime_runtime::component::{ComponentInstance, OwnedComponentInstance}; use wasmtime_runtime::component::{ComponentInstance, OwnedComponentInstance};
@@ -278,6 +278,10 @@ impl<'a> Instantiator<'a> {
self.extract_realloc(store.0, realloc) self.extract_realloc(store.0, realloc)
} }
GlobalInitializer::ExtractPostReturn(post_return) => {
self.extract_post_return(store.0, post_return)
}
GlobalInitializer::SaveStaticModule(idx) => { GlobalInitializer::SaveStaticModule(idx) => {
self.data self.data
.exported_modules .exported_modules
@@ -338,6 +342,16 @@ impl<'a> Instantiator<'a> {
self.data.state.set_runtime_realloc(realloc.index, anyfunc); self.data.state.set_runtime_realloc(realloc.index, anyfunc);
} }
fn extract_post_return(&mut self, store: &mut StoreOpaque, post_return: &ExtractPostReturn) {
let anyfunc = match self.data.lookup_def(store, &post_return.def) {
wasmtime_runtime::Export::Function(f) => f.anyfunc,
_ => unreachable!(),
};
self.data
.state
.set_runtime_post_return(post_return.index, anyfunc);
}
fn build_imports<'b>( fn build_imports<'b>(
&mut self, &mut self,
store: &mut StoreOpaque, store: &mut StoreOpaque,

View File

@@ -5,6 +5,7 @@ use wasmtime::{Config, Engine};
mod func; mod func;
mod import; mod import;
mod nested; mod nested;
mod post_return;
// A simple bump allocator which can be used with modules // A simple bump allocator which can be used with modules
const REALLOC_AND_FREE: &str = r#" const REALLOC_AND_FREE: &str = r#"

View File

@@ -3,7 +3,7 @@ use anyhow::Result;
use std::rc::Rc; use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use wasmtime::component::*; use wasmtime::component::*;
use wasmtime::{Store, StoreContextMut, Trap, TrapCode}; use wasmtime::{AsContextMut, Store, StoreContextMut, Trap, TrapCode};
const CANON_32BIT_NAN: u32 = 0b01111111110000000000000000000000; const CANON_32BIT_NAN: u32 = 0b01111111110000000000000000000000;
const CANON_64BIT_NAN: u64 = 0b0111111111111000000000000000000000000000000000000000000000000000; const CANON_64BIT_NAN: u64 = 0b0111111111111000000000000000000000000000000000000000000000000000;
@@ -32,7 +32,7 @@ fn thunks() -> Result<()> {
let instance = Linker::new(&engine).instantiate(&mut store, &component)?; let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
instance instance
.get_typed_func::<(), (), _>(&mut store, "thunk")? .get_typed_func::<(), (), _>(&mut store, "thunk")?
.call(&mut store, ())?; .call_and_post_return(&mut store, ())?;
let err = instance let err = instance
.get_typed_func::<(), (), _>(&mut store, "thunk-trap")? .get_typed_func::<(), (), _>(&mut store, "thunk-trap")?
.call(&mut store, ()) .call(&mut store, ())
@@ -193,28 +193,28 @@ fn integers() -> Result<()> {
// Passing in 100 is valid for all primitives // Passing in 100 is valid for all primitives
instance instance
.get_typed_func::<(u8,), (), _>(&mut store, "take-u8")? .get_typed_func::<(u8,), (), _>(&mut store, "take-u8")?
.call(&mut store, (100,))?; .call_and_post_return(&mut store, (100,))?;
instance instance
.get_typed_func::<(i8,), (), _>(&mut store, "take-s8")? .get_typed_func::<(i8,), (), _>(&mut store, "take-s8")?
.call(&mut store, (100,))?; .call_and_post_return(&mut store, (100,))?;
instance instance
.get_typed_func::<(u16,), (), _>(&mut store, "take-u16")? .get_typed_func::<(u16,), (), _>(&mut store, "take-u16")?
.call(&mut store, (100,))?; .call_and_post_return(&mut store, (100,))?;
instance instance
.get_typed_func::<(i16,), (), _>(&mut store, "take-s16")? .get_typed_func::<(i16,), (), _>(&mut store, "take-s16")?
.call(&mut store, (100,))?; .call_and_post_return(&mut store, (100,))?;
instance instance
.get_typed_func::<(u32,), (), _>(&mut store, "take-u32")? .get_typed_func::<(u32,), (), _>(&mut store, "take-u32")?
.call(&mut store, (100,))?; .call_and_post_return(&mut store, (100,))?;
instance instance
.get_typed_func::<(i32,), (), _>(&mut store, "take-s32")? .get_typed_func::<(i32,), (), _>(&mut store, "take-s32")?
.call(&mut store, (100,))?; .call_and_post_return(&mut store, (100,))?;
instance instance
.get_typed_func::<(u64,), (), _>(&mut store, "take-u64")? .get_typed_func::<(u64,), (), _>(&mut store, "take-u64")?
.call(&mut store, (100,))?; .call_and_post_return(&mut store, (100,))?;
instance instance
.get_typed_func::<(i64,), (), _>(&mut store, "take-s64")? .get_typed_func::<(i64,), (), _>(&mut store, "take-s64")?
.call(&mut store, (100,))?; .call_and_post_return(&mut store, (100,))?;
// This specific wasm instance traps if any value other than 100 is passed // This specific wasm instance traps if any value other than 100 is passed
instance instance
@@ -262,49 +262,49 @@ fn integers() -> Result<()> {
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), u8, _>(&mut store, "ret-u8")? .get_typed_func::<(), u8, _>(&mut store, "ret-u8")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
0 0
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), i8, _>(&mut store, "ret-s8")? .get_typed_func::<(), i8, _>(&mut store, "ret-s8")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
0 0
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), u16, _>(&mut store, "ret-u16")? .get_typed_func::<(), u16, _>(&mut store, "ret-u16")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
0 0
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), i16, _>(&mut store, "ret-s16")? .get_typed_func::<(), i16, _>(&mut store, "ret-s16")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
0 0
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), u32, _>(&mut store, "ret-u32")? .get_typed_func::<(), u32, _>(&mut store, "ret-u32")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
0 0
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), i32, _>(&mut store, "ret-s32")? .get_typed_func::<(), i32, _>(&mut store, "ret-s32")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
0 0
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), u64, _>(&mut store, "ret-u64")? .get_typed_func::<(), u64, _>(&mut store, "ret-u64")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
0 0
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), i64, _>(&mut store, "ret-s64")? .get_typed_func::<(), i64, _>(&mut store, "ret-s64")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
0 0
); );
@@ -312,49 +312,49 @@ fn integers() -> Result<()> {
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), u8, _>(&mut store, "retm1-u8")? .get_typed_func::<(), u8, _>(&mut store, "retm1-u8")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
0xff 0xff
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), i8, _>(&mut store, "retm1-s8")? .get_typed_func::<(), i8, _>(&mut store, "retm1-s8")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
-1 -1
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), u16, _>(&mut store, "retm1-u16")? .get_typed_func::<(), u16, _>(&mut store, "retm1-u16")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
0xffff 0xffff
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), i16, _>(&mut store, "retm1-s16")? .get_typed_func::<(), i16, _>(&mut store, "retm1-s16")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
-1 -1
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), u32, _>(&mut store, "retm1-u32")? .get_typed_func::<(), u32, _>(&mut store, "retm1-u32")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
0xffffffff 0xffffffff
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), i32, _>(&mut store, "retm1-s32")? .get_typed_func::<(), i32, _>(&mut store, "retm1-s32")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
-1 -1
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), u64, _>(&mut store, "retm1-u64")? .get_typed_func::<(), u64, _>(&mut store, "retm1-u64")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
0xffffffff_ffffffff 0xffffffff_ffffffff
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), i64, _>(&mut store, "retm1-s64")? .get_typed_func::<(), i64, _>(&mut store, "retm1-s64")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
-1 -1
); );
@@ -363,43 +363,59 @@ fn integers() -> Result<()> {
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), u8, _>(&mut store, "retbig-u8")? .get_typed_func::<(), u8, _>(&mut store, "retbig-u8")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
ret as u8, ret as u8,
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), i8, _>(&mut store, "retbig-s8")? .get_typed_func::<(), i8, _>(&mut store, "retbig-s8")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
ret as i8, ret as i8,
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), u16, _>(&mut store, "retbig-u16")? .get_typed_func::<(), u16, _>(&mut store, "retbig-u16")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
ret as u16, ret as u16,
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), i16, _>(&mut store, "retbig-s16")? .get_typed_func::<(), i16, _>(&mut store, "retbig-s16")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
ret as i16, ret as i16,
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), u32, _>(&mut store, "retbig-u32")? .get_typed_func::<(), u32, _>(&mut store, "retbig-u32")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
ret, ret,
); );
assert_eq!( assert_eq!(
instance instance
.get_typed_func::<(), i32, _>(&mut store, "retbig-s32")? .get_typed_func::<(), i32, _>(&mut store, "retbig-s32")?
.call(&mut store, ())?, .call_and_post_return(&mut store, ())?,
ret as i32, ret as i32,
); );
Ok(()) Ok(())
} }
trait TypedFuncExt<P, R> {
fn call_and_post_return(&self, store: impl AsContextMut, params: P) -> Result<R>;
}
impl<P, R> TypedFuncExt<P, R> for TypedFunc<P, R>
where
P: ComponentParams + Lower,
R: Lift,
{
fn call_and_post_return(&self, mut store: impl AsContextMut, params: P) -> Result<R> {
let result = self.call(&mut store, params)?;
self.post_return(&mut store)?;
Ok(result)
}
}
#[test] #[test]
fn type_layers() -> Result<()> { fn type_layers() -> Result<()> {
let component = r#" let component = r#"
@@ -425,19 +441,19 @@ fn type_layers() -> Result<()> {
instance instance
.get_typed_func::<(Box<u32>,), (), _>(&mut store, "take-u32")? .get_typed_func::<(Box<u32>,), (), _>(&mut store, "take-u32")?
.call(&mut store, (Box::new(2),))?; .call_and_post_return(&mut store, (Box::new(2),))?;
instance instance
.get_typed_func::<(&u32,), (), _>(&mut store, "take-u32")? .get_typed_func::<(&u32,), (), _>(&mut store, "take-u32")?
.call(&mut store, (&2,))?; .call_and_post_return(&mut store, (&2,))?;
instance instance
.get_typed_func::<(Rc<u32>,), (), _>(&mut store, "take-u32")? .get_typed_func::<(Rc<u32>,), (), _>(&mut store, "take-u32")?
.call(&mut store, (Rc::new(2),))?; .call_and_post_return(&mut store, (Rc::new(2),))?;
instance instance
.get_typed_func::<(Arc<u32>,), (), _>(&mut store, "take-u32")? .get_typed_func::<(Arc<u32>,), (), _>(&mut store, "take-u32")?
.call(&mut store, (Arc::new(2),))?; .call_and_post_return(&mut store, (Arc::new(2),))?;
instance instance
.get_typed_func::<(&Box<Arc<Rc<u32>>>,), (), _>(&mut store, "take-u32")? .get_typed_func::<(&Box<Arc<Rc<u32>>>,), (), _>(&mut store, "take-u32")?
.call(&mut store, (&Box::new(Arc::new(Rc::new(2))),))?; .call_and_post_return(&mut store, (&Box::new(Arc::new(Rc::new(2))),))?;
Ok(()) Ok(())
} }
@@ -491,9 +507,13 @@ fn floats() -> Result<()> {
let u64_to_f64 = instance.get_typed_func::<(u64,), f64, _>(&mut store, "u64-to-f64")?; let u64_to_f64 = instance.get_typed_func::<(u64,), f64, _>(&mut store, "u64-to-f64")?;
assert_eq!(f32_to_u32.call(&mut store, (1.0,))?, 1.0f32.to_bits()); assert_eq!(f32_to_u32.call(&mut store, (1.0,))?, 1.0f32.to_bits());
f32_to_u32.post_return(&mut store)?;
assert_eq!(f64_to_u64.call(&mut store, (2.0,))?, 2.0f64.to_bits()); assert_eq!(f64_to_u64.call(&mut store, (2.0,))?, 2.0f64.to_bits());
f64_to_u64.post_return(&mut store)?;
assert_eq!(u32_to_f32.call(&mut store, (3.0f32.to_bits(),))?, 3.0); assert_eq!(u32_to_f32.call(&mut store, (3.0f32.to_bits(),))?, 3.0);
u32_to_f32.post_return(&mut store)?;
assert_eq!(u64_to_f64.call(&mut store, (4.0f64.to_bits(),))?, 4.0); assert_eq!(u64_to_f64.call(&mut store, (4.0f64.to_bits(),))?, 4.0);
u64_to_f64.post_return(&mut store)?;
assert_eq!( assert_eq!(
u32_to_f32 u32_to_f32
@@ -501,21 +521,25 @@ fn floats() -> Result<()> {
.to_bits(), .to_bits(),
CANON_32BIT_NAN CANON_32BIT_NAN
); );
u32_to_f32.post_return(&mut store)?;
assert_eq!( assert_eq!(
u64_to_f64 u64_to_f64
.call(&mut store, (CANON_64BIT_NAN | 1,))? .call(&mut store, (CANON_64BIT_NAN | 1,))?
.to_bits(), .to_bits(),
CANON_64BIT_NAN CANON_64BIT_NAN
); );
u64_to_f64.post_return(&mut store)?;
assert_eq!( assert_eq!(
f32_to_u32.call(&mut store, (f32::from_bits(CANON_32BIT_NAN | 1),))?, f32_to_u32.call(&mut store, (f32::from_bits(CANON_32BIT_NAN | 1),))?,
CANON_32BIT_NAN CANON_32BIT_NAN
); );
f32_to_u32.post_return(&mut store)?;
assert_eq!( assert_eq!(
f64_to_u64.call(&mut store, (f64::from_bits(CANON_64BIT_NAN | 1),))?, f64_to_u64.call(&mut store, (f64::from_bits(CANON_64BIT_NAN | 1),))?,
CANON_64BIT_NAN CANON_64BIT_NAN
); );
f64_to_u64.post_return(&mut store)?;
Ok(()) Ok(())
} }
@@ -546,10 +570,15 @@ fn bools() -> Result<()> {
let bool_to_u32 = instance.get_typed_func::<(bool,), u32, _>(&mut store, "bool-to-u32")?; let bool_to_u32 = instance.get_typed_func::<(bool,), u32, _>(&mut store, "bool-to-u32")?;
assert_eq!(bool_to_u32.call(&mut store, (false,))?, 0); assert_eq!(bool_to_u32.call(&mut store, (false,))?, 0);
bool_to_u32.post_return(&mut store)?;
assert_eq!(bool_to_u32.call(&mut store, (true,))?, 1); assert_eq!(bool_to_u32.call(&mut store, (true,))?, 1);
bool_to_u32.post_return(&mut store)?;
assert_eq!(u32_to_bool.call(&mut store, (0,))?, false); assert_eq!(u32_to_bool.call(&mut store, (0,))?, false);
u32_to_bool.post_return(&mut store)?;
assert_eq!(u32_to_bool.call(&mut store, (1,))?, true); assert_eq!(u32_to_bool.call(&mut store, (1,))?, true);
u32_to_bool.post_return(&mut store)?;
assert_eq!(u32_to_bool.call(&mut store, (2,))?, true); assert_eq!(u32_to_bool.call(&mut store, (2,))?, true);
u32_to_bool.post_return(&mut store)?;
Ok(()) Ok(())
} }
@@ -581,7 +610,9 @@ fn chars() -> Result<()> {
let mut roundtrip = |x: char| -> Result<()> { let mut roundtrip = |x: char| -> Result<()> {
assert_eq!(char_to_u32.call(&mut store, (x,))?, x as u32); assert_eq!(char_to_u32.call(&mut store, (x,))?, x as u32);
char_to_u32.post_return(&mut store)?;
assert_eq!(u32_to_char.call(&mut store, (x as u32,))?, x); assert_eq!(u32_to_char.call(&mut store, (x as u32,))?, x);
u32_to_char.post_return(&mut store)?;
Ok(()) Ok(())
}; };
@@ -644,7 +675,7 @@ fn tuple_result() -> Result<()> {
let input = (-1, 100, 3.0, 100.0); let input = (-1, 100, 3.0, 100.0);
let output = instance let output = instance
.get_typed_func::<(i8, u16, f32, f64), (i8, u16, f32, f64), _>(&mut store, "tuple")? .get_typed_func::<(i8, u16, f32, f64), (i8, u16, f32, f64), _>(&mut store, "tuple")?
.call(&mut store, input)?; .call_and_post_return(&mut store, input)?;
assert_eq!(input, output); assert_eq!(input, output);
let invalid_func = let invalid_func =
@@ -735,16 +766,20 @@ fn strings() -> Result<()> {
let mut roundtrip = |x: &str| -> Result<()> { let mut roundtrip = |x: &str| -> Result<()> {
let ret = list8_to_str.call(&mut store, (x.as_bytes(),))?; let ret = list8_to_str.call(&mut store, (x.as_bytes(),))?;
assert_eq!(ret.to_str(&store)?, x); assert_eq!(ret.to_str(&store)?, x);
list8_to_str.post_return(&mut store)?;
let utf16 = x.encode_utf16().collect::<Vec<_>>(); let utf16 = x.encode_utf16().collect::<Vec<_>>();
let ret = list16_to_str.call(&mut store, (&utf16[..],))?; let ret = list16_to_str.call(&mut store, (&utf16[..],))?;
assert_eq!(ret.to_str(&store)?, x); assert_eq!(ret.to_str(&store)?, x);
list16_to_str.post_return(&mut store)?;
let ret = str_to_list8.call(&mut store, (x,))?; let ret = str_to_list8.call(&mut store, (x,))?;
assert_eq!(ret.iter(&store).collect::<Result<Vec<_>>>()?, x.as_bytes()); assert_eq!(ret.iter(&store).collect::<Result<Vec<_>>>()?, x.as_bytes());
str_to_list8.post_return(&mut store)?;
let ret = str_to_list16.call(&mut store, (x,))?; let ret = str_to_list16.call(&mut store, (x,))?;
assert_eq!(ret.iter(&store).collect::<Result<Vec<_>>>()?, utf16,); assert_eq!(ret.iter(&store).collect::<Result<Vec<_>>>()?, utf16,);
str_to_list16.post_return(&mut store)?;
Ok(()) Ok(())
}; };
@@ -758,22 +793,27 @@ fn strings() -> Result<()> {
let ret = list8_to_str.call(&mut store, (b"\xff",))?; let ret = list8_to_str.call(&mut store, (b"\xff",))?;
let err = ret.to_str(&store).unwrap_err(); let err = ret.to_str(&store).unwrap_err();
assert!(err.to_string().contains("invalid utf-8"), "{}", err); assert!(err.to_string().contains("invalid utf-8"), "{}", err);
list8_to_str.post_return(&mut store)?;
let ret = list8_to_str.call(&mut store, (b"hello there \xff invalid",))?; let ret = list8_to_str.call(&mut store, (b"hello there \xff invalid",))?;
let err = ret.to_str(&store).unwrap_err(); let err = ret.to_str(&store).unwrap_err();
assert!(err.to_string().contains("invalid utf-8"), "{}", err); assert!(err.to_string().contains("invalid utf-8"), "{}", err);
list8_to_str.post_return(&mut store)?;
let ret = list16_to_str.call(&mut store, (&[0xd800],))?; let ret = list16_to_str.call(&mut store, (&[0xd800],))?;
let err = ret.to_str(&store).unwrap_err(); let err = ret.to_str(&store).unwrap_err();
assert!(err.to_string().contains("unpaired surrogate"), "{}", err); assert!(err.to_string().contains("unpaired surrogate"), "{}", err);
list16_to_str.post_return(&mut store)?;
let ret = list16_to_str.call(&mut store, (&[0xdfff],))?; let ret = list16_to_str.call(&mut store, (&[0xdfff],))?;
let err = ret.to_str(&store).unwrap_err(); let err = ret.to_str(&store).unwrap_err();
assert!(err.to_string().contains("unpaired surrogate"), "{}", err); assert!(err.to_string().contains("unpaired surrogate"), "{}", err);
list16_to_str.post_return(&mut store)?;
let ret = list16_to_str.call(&mut store, (&[0xd800, 0xff00],))?; let ret = list16_to_str.call(&mut store, (&[0xd800, 0xff00],))?;
let err = ret.to_str(&store).unwrap_err(); let err = ret.to_str(&store).unwrap_err();
assert!(err.to_string().contains("unpaired surrogate"), "{}", err); assert!(err.to_string().contains("unpaired surrogate"), "{}", err);
list16_to_str.post_return(&mut store)?;
Ok(()) Ok(())
} }
@@ -1123,10 +1163,10 @@ fn some_traps() -> Result<()> {
instance instance
.get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-end-oob")? .get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-end-oob")?
.call(&mut store, (&[],))?; .call_and_post_return(&mut store, (&[],))?;
instance instance
.get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-end-oob")? .get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-end-oob")?
.call(&mut store, (&[1, 2, 3, 4],))?; .call_and_post_return(&mut store, (&[1, 2, 3, 4],))?;
let err = instance let err = instance
.get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-end-oob")? .get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-end-oob")?
.call(&mut store, (&[1, 2, 3, 4, 5],)) .call(&mut store, (&[1, 2, 3, 4, 5],))
@@ -1134,10 +1174,10 @@ fn some_traps() -> Result<()> {
assert_oob(&err); assert_oob(&err);
instance instance
.get_typed_func::<(&str,), (), _>(&mut store, "take-string-end-oob")? .get_typed_func::<(&str,), (), _>(&mut store, "take-string-end-oob")?
.call(&mut store, ("",))?; .call_and_post_return(&mut store, ("",))?;
instance instance
.get_typed_func::<(&str,), (), _>(&mut store, "take-string-end-oob")? .get_typed_func::<(&str,), (), _>(&mut store, "take-string-end-oob")?
.call(&mut store, ("abcd",))?; .call_and_post_return(&mut store, ("abcd",))?;
let err = instance let err = instance
.get_typed_func::<(&str,), (), _>(&mut store, "take-string-end-oob")? .get_typed_func::<(&str,), (), _>(&mut store, "take-string-end-oob")?
.call(&mut store, ("abcde",)) .call(&mut store, ("abcde",))
@@ -1216,12 +1256,15 @@ fn char_bool_memory() -> Result<()> {
let ret = func.call(&mut store, (0, 'a' as u32))?; let ret = func.call(&mut store, (0, 'a' as u32))?;
assert_eq!(ret, (false, 'a')); assert_eq!(ret, (false, 'a'));
func.post_return(&mut store)?;
let ret = func.call(&mut store, (1, '🍰' as u32))?; let ret = func.call(&mut store, (1, '🍰' as u32))?;
assert_eq!(ret, (true, '🍰')); assert_eq!(ret, (true, '🍰'));
func.post_return(&mut store)?;
let ret = func.call(&mut store, (2, 'a' as u32))?; let ret = func.call(&mut store, (2, 'a' as u32))?;
assert_eq!(ret, (true, 'a')); assert_eq!(ret, (true, 'a'));
func.post_return(&mut store)?;
assert!(func.call(&mut store, (0, 0xd800)).is_err()); assert!(func.call(&mut store, (0, 0xd800)).is_err());
@@ -1437,22 +1480,30 @@ fn option() -> Result<()> {
let option_unit_to_u32 = let option_unit_to_u32 =
instance.get_typed_func::<(Option<()>,), u32, _>(&mut store, "option-unit-to-u32")?; instance.get_typed_func::<(Option<()>,), u32, _>(&mut store, "option-unit-to-u32")?;
assert_eq!(option_unit_to_u32.call(&mut store, (None,))?, 0); assert_eq!(option_unit_to_u32.call(&mut store, (None,))?, 0);
option_unit_to_u32.post_return(&mut store)?;
assert_eq!(option_unit_to_u32.call(&mut store, (Some(()),))?, 1); assert_eq!(option_unit_to_u32.call(&mut store, (Some(()),))?, 1);
option_unit_to_u32.post_return(&mut store)?;
let option_u8_to_tuple = instance let option_u8_to_tuple = instance
.get_typed_func::<(Option<u8>,), (u32, u32), _>(&mut store, "option-u8-to-tuple")?; .get_typed_func::<(Option<u8>,), (u32, u32), _>(&mut store, "option-u8-to-tuple")?;
assert_eq!(option_u8_to_tuple.call(&mut store, (None,))?, (0, 0)); assert_eq!(option_u8_to_tuple.call(&mut store, (None,))?, (0, 0));
option_u8_to_tuple.post_return(&mut store)?;
assert_eq!(option_u8_to_tuple.call(&mut store, (Some(0),))?, (1, 0)); assert_eq!(option_u8_to_tuple.call(&mut store, (Some(0),))?, (1, 0));
option_u8_to_tuple.post_return(&mut store)?;
assert_eq!(option_u8_to_tuple.call(&mut store, (Some(100),))?, (1, 100)); assert_eq!(option_u8_to_tuple.call(&mut store, (Some(100),))?, (1, 100));
option_u8_to_tuple.post_return(&mut store)?;
let option_u32_to_tuple = instance let option_u32_to_tuple = instance
.get_typed_func::<(Option<u32>,), (u32, u32), _>(&mut store, "option-u32-to-tuple")?; .get_typed_func::<(Option<u32>,), (u32, u32), _>(&mut store, "option-u32-to-tuple")?;
assert_eq!(option_u32_to_tuple.call(&mut store, (None,))?, (0, 0)); assert_eq!(option_u32_to_tuple.call(&mut store, (None,))?, (0, 0));
option_u32_to_tuple.post_return(&mut store)?;
assert_eq!(option_u32_to_tuple.call(&mut store, (Some(0),))?, (1, 0)); assert_eq!(option_u32_to_tuple.call(&mut store, (Some(0),))?, (1, 0));
option_u32_to_tuple.post_return(&mut store)?;
assert_eq!( assert_eq!(
option_u32_to_tuple.call(&mut store, (Some(100),))?, option_u32_to_tuple.call(&mut store, (Some(100),))?,
(1, 100) (1, 100)
); );
option_u32_to_tuple.post_return(&mut store)?;
let option_string_to_tuple = instance.get_typed_func::<(Option<&str>,), (u32, WasmStr), _>( let option_string_to_tuple = instance.get_typed_func::<(Option<&str>,), (u32, WasmStr), _>(
&mut store, &mut store,
@@ -1461,45 +1512,59 @@ fn option() -> Result<()> {
let (a, b) = option_string_to_tuple.call(&mut store, (None,))?; let (a, b) = option_string_to_tuple.call(&mut store, (None,))?;
assert_eq!(a, 0); assert_eq!(a, 0);
assert_eq!(b.to_str(&store)?, ""); assert_eq!(b.to_str(&store)?, "");
option_string_to_tuple.post_return(&mut store)?;
let (a, b) = option_string_to_tuple.call(&mut store, (Some(""),))?; let (a, b) = option_string_to_tuple.call(&mut store, (Some(""),))?;
assert_eq!(a, 1); assert_eq!(a, 1);
assert_eq!(b.to_str(&store)?, ""); assert_eq!(b.to_str(&store)?, "");
option_string_to_tuple.post_return(&mut store)?;
let (a, b) = option_string_to_tuple.call(&mut store, (Some("hello"),))?; let (a, b) = option_string_to_tuple.call(&mut store, (Some("hello"),))?;
assert_eq!(a, 1); assert_eq!(a, 1);
assert_eq!(b.to_str(&store)?, "hello"); assert_eq!(b.to_str(&store)?, "hello");
option_string_to_tuple.post_return(&mut store)?;
let to_option_unit = let to_option_unit =
instance.get_typed_func::<(u32,), Option<()>, _>(&mut store, "to-option-unit")?; instance.get_typed_func::<(u32,), Option<()>, _>(&mut store, "to-option-unit")?;
assert_eq!(to_option_unit.call(&mut store, (0,))?, None); assert_eq!(to_option_unit.call(&mut store, (0,))?, None);
to_option_unit.post_return(&mut store)?;
assert_eq!(to_option_unit.call(&mut store, (1,))?, Some(())); assert_eq!(to_option_unit.call(&mut store, (1,))?, Some(()));
to_option_unit.post_return(&mut store)?;
let err = to_option_unit.call(&mut store, (2,)).unwrap_err(); let err = to_option_unit.call(&mut store, (2,)).unwrap_err();
assert!(err.to_string().contains("invalid option"), "{}", err); assert!(err.to_string().contains("invalid option"), "{}", err);
let to_option_u8 = let to_option_u8 =
instance.get_typed_func::<(u32, u32), Option<u8>, _>(&mut store, "to-option-u8")?; instance.get_typed_func::<(u32, u32), Option<u8>, _>(&mut store, "to-option-u8")?;
assert_eq!(to_option_u8.call(&mut store, (0x00_00, 0))?, None); assert_eq!(to_option_u8.call(&mut store, (0x00_00, 0))?, None);
to_option_u8.post_return(&mut store)?;
assert_eq!(to_option_u8.call(&mut store, (0x00_01, 0))?, Some(0)); assert_eq!(to_option_u8.call(&mut store, (0x00_01, 0))?, Some(0));
to_option_u8.post_return(&mut store)?;
assert_eq!(to_option_u8.call(&mut store, (0xfd_01, 0))?, Some(0xfd)); assert_eq!(to_option_u8.call(&mut store, (0xfd_01, 0))?, Some(0xfd));
to_option_u8.post_return(&mut store)?;
assert!(to_option_u8.call(&mut store, (0x00_02, 0)).is_err()); assert!(to_option_u8.call(&mut store, (0x00_02, 0)).is_err());
let to_option_u32 = let to_option_u32 =
instance.get_typed_func::<(u32, u32), Option<u32>, _>(&mut store, "to-option-u32")?; instance.get_typed_func::<(u32, u32), Option<u32>, _>(&mut store, "to-option-u32")?;
assert_eq!(to_option_u32.call(&mut store, (0, 0))?, None); assert_eq!(to_option_u32.call(&mut store, (0, 0))?, None);
to_option_u32.post_return(&mut store)?;
assert_eq!(to_option_u32.call(&mut store, (1, 0))?, Some(0)); assert_eq!(to_option_u32.call(&mut store, (1, 0))?, Some(0));
to_option_u32.post_return(&mut store)?;
assert_eq!( assert_eq!(
to_option_u32.call(&mut store, (1, 0x1234fead))?, to_option_u32.call(&mut store, (1, 0x1234fead))?,
Some(0x1234fead) Some(0x1234fead)
); );
to_option_u32.post_return(&mut store)?;
assert!(to_option_u32.call(&mut store, (2, 0)).is_err()); assert!(to_option_u32.call(&mut store, (2, 0)).is_err());
let to_option_string = instance let to_option_string = instance
.get_typed_func::<(u32, &str), Option<WasmStr>, _>(&mut store, "to-option-string")?; .get_typed_func::<(u32, &str), Option<WasmStr>, _>(&mut store, "to-option-string")?;
let ret = to_option_string.call(&mut store, (0, ""))?; let ret = to_option_string.call(&mut store, (0, ""))?;
assert!(ret.is_none()); assert!(ret.is_none());
to_option_string.post_return(&mut store)?;
let ret = to_option_string.call(&mut store, (1, ""))?; let ret = to_option_string.call(&mut store, (1, ""))?;
assert_eq!(ret.unwrap().to_str(&store)?, ""); assert_eq!(ret.unwrap().to_str(&store)?, "");
to_option_string.post_return(&mut store)?;
let ret = to_option_string.call(&mut store, (1, "cheesecake"))?; let ret = to_option_string.call(&mut store, (1, "cheesecake"))?;
assert_eq!(ret.unwrap().to_str(&store)?, "cheesecake"); assert_eq!(ret.unwrap().to_str(&store)?, "cheesecake");
to_option_string.post_return(&mut store)?;
assert!(to_option_string.call(&mut store, (2, "")).is_err()); assert!(to_option_string.call(&mut store, (2, "")).is_err());
Ok(()) Ok(())
@@ -1592,15 +1657,19 @@ fn expected() -> Result<()> {
let take_expected_unit = let take_expected_unit =
instance.get_typed_func::<(Result<(), ()>,), u32, _>(&mut store, "take-expected-unit")?; instance.get_typed_func::<(Result<(), ()>,), u32, _>(&mut store, "take-expected-unit")?;
assert_eq!(take_expected_unit.call(&mut store, (Ok(()),))?, 0); assert_eq!(take_expected_unit.call(&mut store, (Ok(()),))?, 0);
take_expected_unit.post_return(&mut store)?;
assert_eq!(take_expected_unit.call(&mut store, (Err(()),))?, 1); assert_eq!(take_expected_unit.call(&mut store, (Err(()),))?, 1);
take_expected_unit.post_return(&mut store)?;
let take_expected_u8_f32 = instance let take_expected_u8_f32 = instance
.get_typed_func::<(Result<u8, f32>,), (u32, u32), _>(&mut store, "take-expected-u8-f32")?; .get_typed_func::<(Result<u8, f32>,), (u32, u32), _>(&mut store, "take-expected-u8-f32")?;
assert_eq!(take_expected_u8_f32.call(&mut store, (Ok(1),))?, (0, 1)); assert_eq!(take_expected_u8_f32.call(&mut store, (Ok(1),))?, (0, 1));
take_expected_u8_f32.post_return(&mut store)?;
assert_eq!( assert_eq!(
take_expected_u8_f32.call(&mut store, (Err(2.0),))?, take_expected_u8_f32.call(&mut store, (Err(2.0),))?,
(1, 2.0f32.to_bits()) (1, 2.0f32.to_bits())
); );
take_expected_u8_f32.post_return(&mut store)?;
let take_expected_string = instance let take_expected_string = instance
.get_typed_func::<(Result<&str, &[u8]>,), (u32, WasmStr), _>( .get_typed_func::<(Result<&str, &[u8]>,), (u32, WasmStr), _>(
@@ -1610,27 +1679,35 @@ fn expected() -> Result<()> {
let (a, b) = take_expected_string.call(&mut store, (Ok("hello"),))?; let (a, b) = take_expected_string.call(&mut store, (Ok("hello"),))?;
assert_eq!(a, 0); assert_eq!(a, 0);
assert_eq!(b.to_str(&store)?, "hello"); assert_eq!(b.to_str(&store)?, "hello");
take_expected_string.post_return(&mut store)?;
let (a, b) = take_expected_string.call(&mut store, (Err(b"goodbye"),))?; let (a, b) = take_expected_string.call(&mut store, (Err(b"goodbye"),))?;
assert_eq!(a, 1); assert_eq!(a, 1);
assert_eq!(b.to_str(&store)?, "goodbye"); assert_eq!(b.to_str(&store)?, "goodbye");
take_expected_string.post_return(&mut store)?;
let to_expected_unit = let to_expected_unit =
instance.get_typed_func::<(u32,), Result<(), ()>, _>(&mut store, "to-expected-unit")?; instance.get_typed_func::<(u32,), Result<(), ()>, _>(&mut store, "to-expected-unit")?;
assert_eq!(to_expected_unit.call(&mut store, (0,))?, Ok(())); assert_eq!(to_expected_unit.call(&mut store, (0,))?, Ok(()));
to_expected_unit.post_return(&mut store)?;
assert_eq!(to_expected_unit.call(&mut store, (1,))?, Err(())); assert_eq!(to_expected_unit.call(&mut store, (1,))?, Err(()));
to_expected_unit.post_return(&mut store)?;
let err = to_expected_unit.call(&mut store, (2,)).unwrap_err(); let err = to_expected_unit.call(&mut store, (2,)).unwrap_err();
assert!(err.to_string().contains("invalid expected"), "{}", err); assert!(err.to_string().contains("invalid expected"), "{}", err);
let to_expected_s16_f32 = instance let to_expected_s16_f32 = instance
.get_typed_func::<(u32, u32), Result<i16, f32>, _>(&mut store, "to-expected-s16-f32")?; .get_typed_func::<(u32, u32), Result<i16, f32>, _>(&mut store, "to-expected-s16-f32")?;
assert_eq!(to_expected_s16_f32.call(&mut store, (0, 0))?, Ok(0)); assert_eq!(to_expected_s16_f32.call(&mut store, (0, 0))?, Ok(0));
to_expected_s16_f32.post_return(&mut store)?;
assert_eq!(to_expected_s16_f32.call(&mut store, (0, 100))?, Ok(100)); assert_eq!(to_expected_s16_f32.call(&mut store, (0, 100))?, Ok(100));
to_expected_s16_f32.post_return(&mut store)?;
assert_eq!( assert_eq!(
to_expected_s16_f32.call(&mut store, (1, 1.0f32.to_bits()))?, to_expected_s16_f32.call(&mut store, (1, 1.0f32.to_bits()))?,
Err(1.0) Err(1.0)
); );
to_expected_s16_f32.post_return(&mut store)?;
let ret = to_expected_s16_f32.call(&mut store, (1, CANON_32BIT_NAN | 1))?; let ret = to_expected_s16_f32.call(&mut store, (1, CANON_32BIT_NAN | 1))?;
assert_eq!(ret.unwrap_err().to_bits(), CANON_32BIT_NAN); assert_eq!(ret.unwrap_err().to_bits(), CANON_32BIT_NAN);
to_expected_s16_f32.post_return(&mut store)?;
assert!(to_expected_s16_f32.call(&mut store, (2, 0)).is_err()); assert!(to_expected_s16_f32.call(&mut store, (2, 0)).is_err());
Ok(()) Ok(())

View File

@@ -0,0 +1,259 @@
use anyhow::Result;
use wasmtime::component::*;
use wasmtime::{Store, StoreContextMut};
#[test]
fn invalid_api() -> Result<()> {
let component = r#"
(component
(core module $m
(func (export "thunk1"))
(func (export "thunk2"))
)
(core instance $i (instantiate $m))
(func (export "thunk1")
(canon lift (core func $i "thunk1"))
)
(func (export "thunk2")
(canon lift (core func $i "thunk2"))
)
)
"#;
let engine = super::engine();
let component = Component::new(&engine, component)?;
let mut store = Store::new(&engine, ());
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let thunk1 = instance.get_typed_func::<(), (), _>(&mut store, "thunk1")?;
let thunk2 = instance.get_typed_func::<(), (), _>(&mut store, "thunk2")?;
// Ensure that we can't call `post_return` before doing anything
let msg = "post_return can only be called after a function has previously been called";
assert_panics(|| drop(thunk1.post_return(&mut store)), msg);
assert_panics(|| drop(thunk2.post_return(&mut store)), msg);
// Schedule a "needs post return"
thunk1.call(&mut store, ())?;
// Ensure that we can't reenter the instance through either this function or
// another one.
let err = thunk1.call(&mut store, ()).unwrap_err();
assert!(
err.to_string()
.contains("cannot reenter component instance"),
"{}",
err
);
let err = thunk2.call(&mut store, ()).unwrap_err();
assert!(
err.to_string()
.contains("cannot reenter component instance"),
"{}",
err
);
// Calling post-return on the wrong function should panic
assert_panics(
|| drop(thunk2.post_return(&mut store)),
"calling post_return on wrong function",
);
// Actually execute the post-return
thunk1.post_return(&mut store)?;
// And now post-return should be invalid again.
assert_panics(|| drop(thunk1.post_return(&mut store)), msg);
assert_panics(|| drop(thunk2.post_return(&mut store)), msg);
Ok(())
}
#[track_caller]
fn assert_panics(f: impl FnOnce(), msg: &str) {
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
Ok(()) => panic!("expected closure to panic"),
Err(e) => match e.downcast::<String>() {
Ok(s) => {
assert!(s.contains(msg), "bad panic: {}", s);
}
Err(e) => match e.downcast::<&'static str>() {
Ok(s) => assert!(s.contains(msg), "bad panic: {}", s),
Err(_) => panic!("bad panic"),
},
},
}
}
#[test]
fn invoke_post_return() -> Result<()> {
let component = r#"
(component
(import "f" (func $f))
(core func $f_lower
(canon lower (func $f))
)
(core module $m
(import "" "" (func $f))
(func (export "thunk"))
(func $post_return
call $f)
(export "post-return" (func $post_return))
)
(core instance $i (instantiate $m
(with "" (instance
(export "" (func $f_lower))
))
))
(func (export "thunk")
(canon lift
(core func $i "thunk")
(post-return (func $i "post-return"))
)
)
)
"#;
let engine = super::engine();
let component = Component::new(&engine, component)?;
let mut store = Store::new(&engine, false);
let mut linker = Linker::new(&engine);
linker
.root()
.func_wrap("f", |mut store: StoreContextMut<'_, bool>| -> Result<()> {
assert!(!*store.data());
*store.data_mut() = true;
Ok(())
})?;
let instance = linker.instantiate(&mut store, &component)?;
let thunk = instance.get_typed_func::<(), (), _>(&mut store, "thunk")?;
assert!(!*store.data());
thunk.call(&mut store, ())?;
assert!(!*store.data());
thunk.post_return(&mut store)?;
assert!(*store.data());
Ok(())
}
#[test]
fn post_return_all_types() -> Result<()> {
let component = r#"
(component
(core module $m
(func (export "i32") (result i32)
i32.const 1)
(func (export "i64") (result i64)
i64.const 2)
(func (export "f32") (result f32)
f32.const 3)
(func (export "f64") (result f64)
f64.const 4)
(func (export "post-i32") (param i32)
local.get 0
i32.const 1
i32.ne
if unreachable end)
(func (export "post-i64") (param i64)
local.get 0
i64.const 2
i64.ne
if unreachable end)
(func (export "post-f32") (param f32)
local.get 0
f32.const 3
f32.ne
if unreachable end)
(func (export "post-f64") (param f64)
local.get 0
f64.const 4
f64.ne
if unreachable end)
)
(core instance $i (instantiate $m))
(func (export "i32") (result u32)
(canon lift (core func $i "i32") (post-return (func $i "post-i32")))
)
(func (export "i64") (result u64)
(canon lift (core func $i "i64") (post-return (func $i "post-i64")))
)
(func (export "f32") (result float32)
(canon lift (core func $i "f32") (post-return (func $i "post-f32")))
)
(func (export "f64") (result float64)
(canon lift (core func $i "f64") (post-return (func $i "post-f64")))
)
)
"#;
let engine = super::engine();
let component = Component::new(&engine, component)?;
let mut store = Store::new(&engine, false);
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let i32 = instance.get_typed_func::<(), u32, _>(&mut store, "i32")?;
let i64 = instance.get_typed_func::<(), u64, _>(&mut store, "i64")?;
let f32 = instance.get_typed_func::<(), f32, _>(&mut store, "f32")?;
let f64 = instance.get_typed_func::<(), f64, _>(&mut store, "f64")?;
assert_eq!(i32.call(&mut store, ())?, 1);
i32.post_return(&mut store)?;
assert_eq!(i64.call(&mut store, ())?, 2);
i64.post_return(&mut store)?;
assert_eq!(f32.call(&mut store, ())?, 3.);
f32.post_return(&mut store)?;
assert_eq!(f64.call(&mut store, ())?, 4.);
f64.post_return(&mut store)?;
Ok(())
}
#[test]
fn post_return_string() -> Result<()> {
let component = r#"
(component
(core module $m
(memory (export "memory") 1)
(func (export "get") (result i32)
(i32.store offset=0 (i32.const 8) (i32.const 100))
(i32.store offset=4 (i32.const 8) (i32.const 11))
i32.const 8
)
(func (export "post") (param i32)
local.get 0
i32.const 8
i32.ne
if unreachable end)
(data (i32.const 100) "hello world")
)
(core instance $i (instantiate $m))
(func (export "get") (result string)
(canon lift
(core func $i "get")
(post-return (func $i "post"))
(memory $i "memory")
)
)
)
"#;
let engine = super::engine();
let component = Component::new(&engine, component)?;
let mut store = Store::new(&engine, false);
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let get = instance.get_typed_func::<(), WasmStr, _>(&mut store, "get")?;
let s = get.call(&mut store, ())?;
assert_eq!(s.to_str(&store)?, "hello world");
get.post_return(&mut store)?;
Ok(())
}