diff --git a/crates/cranelift/src/compiler/component.rs b/crates/cranelift/src/compiler/component.rs index ffd3ed4abe..7d09427973 100644 --- a/crates/cranelift/src/compiler/component.rs +++ b/crates/cranelift/src/compiler/component.rs @@ -54,6 +54,7 @@ impl ComponentCompiler for Compiler { let CanonicalOptions { memory, realloc, + post_return, string_encoding, } = lowering.options; @@ -94,6 +95,11 @@ impl ComponentCompiler for Compiler { None => builder.ins().iconst(pointer_type, 0), }); + // A post-return option is only valid on `canon.lift`'d functions so no + // valid component should have this specified for a lowering which this + // trampoline compiler is interested in. + assert!(post_return.is_none()); + // string_encoding: StringEncoding host_sig.params.push(ir::AbiParam::new(ir::types::I8)); callee_args.push( diff --git a/crates/environ/src/component/info.rs b/crates/environ/src/component/info.rs index 1e54ffd7dd..1746518913 100644 --- a/crates/environ/src/component/info.rs +++ b/crates/environ/src/component/info.rs @@ -129,6 +129,9 @@ pub struct Component { /// `VMComponentContext`. pub num_runtime_reallocs: u32, + /// Same as `num_runtime_reallocs`, but for post-return functions. + pub num_runtime_post_returns: u32, + /// The number of lowered host functions (maximum `LoweredIndex`) needed to /// instantiate this component. pub num_lowerings: u32, @@ -180,6 +183,10 @@ pub enum GlobalInitializer { /// used as a `realloc` function. ExtractRealloc(ExtractRealloc), + /// Same as `ExtractMemory`, except it's extracting a function pointer to be + /// used as a `post-return` function. + ExtractPostReturn(ExtractPostReturn), + /// The `module` specified is saved into the runtime state at the next /// `RuntimeModuleIndex`, referred to later by `Export` definitions. SaveStaticModule(StaticModuleIndex), @@ -207,6 +214,15 @@ pub struct ExtractRealloc { pub def: CoreDef, } +/// Same as `ExtractMemory` but for the `post-return` canonical option. +#[derive(Debug, Serialize, Deserialize)] +pub struct ExtractPostReturn { + /// The index of the post-return being defined. + pub index: RuntimePostReturnIndex, + /// Where this post-return is being extracted from. + pub def: CoreDef, +} + /// Different methods of instantiating a core wasm module. #[derive(Debug, Serialize, Deserialize)] pub enum InstantiateModule { @@ -361,7 +377,9 @@ pub struct CanonicalOptions { /// The realloc function used by these options, if specified. pub realloc: Option, - // TODO: need to represent post-return here as well + + /// The post-return function used by these options, if specified. + pub post_return: Option, } impl Default for CanonicalOptions { @@ -370,6 +388,7 @@ impl Default for CanonicalOptions { string_encoding: StringEncoding::Utf8, memory: None, realloc: None, + post_return: None, } } } diff --git a/crates/environ/src/component/translate/inline.rs b/crates/environ/src/component/translate/inline.rs index bac194ea5c..70b0a87337 100644 --- a/crates/environ/src/component/translate/inline.rs +++ b/crates/environ/src/component/translate/inline.rs @@ -62,6 +62,7 @@ pub(super) fn run( result: Component::default(), import_path_interner: Default::default(), runtime_realloc_interner: Default::default(), + runtime_post_return_interner: Default::default(), runtime_memory_interner: Default::default(), }; @@ -182,6 +183,7 @@ struct Inliner<'a> { // runtime instead of multiple times. import_path_interner: HashMap, RuntimeImportIndex>, runtime_realloc_interner: HashMap, + runtime_post_return_interner: HashMap, runtime_memory_interner: HashMap, RuntimeMemoryIndex>, } @@ -851,13 +853,29 @@ impl<'a> Inliner<'a> { index }) }); - if options.post_return.is_some() { - unimplemented!("post-return handling"); - } + let post_return = options.post_return.map(|i| { + let def = frame.funcs[i].clone(); + *self + .runtime_post_return_interner + .entry(def.clone()) + .or_insert_with(|| { + let index = + RuntimePostReturnIndex::from_u32(self.result.num_runtime_post_returns); + self.result.num_runtime_post_returns += 1; + self.result + .initializers + .push(GlobalInitializer::ExtractPostReturn(ExtractPostReturn { + index, + def, + })); + index + }) + }); CanonicalOptions { string_encoding: options.string_encoding, memory, realloc, + post_return, } } } diff --git a/crates/environ/src/component/vmcomponent_offsets.rs b/crates/environ/src/component/vmcomponent_offsets.rs index 02be59d3c6..47773bc628 100644 --- a/crates/environ/src/component/vmcomponent_offsets.rs +++ b/crates/environ/src/component/vmcomponent_offsets.rs @@ -2,16 +2,18 @@ // // struct VMComponentContext { // magic: u32, -// may_enter: u8, -// may_leave: u8, +// flags: u8, // store: *mut dyn Store, // lowering_anyfuncs: [VMCallerCheckedAnyfunc; component.num_lowerings], // lowerings: [VMLowering; component.num_lowerings], // memories: [*mut VMMemoryDefinition; component.num_memories], // reallocs: [*mut VMCallerCheckedAnyfunc; component.num_reallocs], +// post_returns: [*mut VMCallerCheckedAnyfunc; component.num_post_returns], // } -use crate::component::{Component, LoweredIndex, RuntimeMemoryIndex, RuntimeReallocIndex}; +use crate::component::{ + Component, LoweredIndex, RuntimeMemoryIndex, RuntimePostReturnIndex, RuntimeReallocIndex, +}; use crate::PtrSize; /// Equivalent of `VMCONTEXT_MAGIC` except for components. @@ -20,6 +22,18 @@ use crate::PtrSize; /// double-checked on `VMComponentContext::from_opaque`. pub const VMCOMPONENT_MAGIC: u32 = u32::from_le_bytes(*b"comp"); +/// Flag for the `VMComponentContext::flags` field which corresponds to the +/// canonical ABI flag `may_leave` +pub const VMCOMPONENT_FLAG_MAY_LEAVE: u8 = 1 << 0; + +/// Flag for the `VMComponentContext::flags` field which corresponds to the +/// canonical ABI flag `may_enter` +pub const VMCOMPONENT_FLAG_MAY_ENTER: u8 = 1 << 1; + +/// Flag for the `VMComponentContext::flags` field which is set whenever a +/// function is called to indicate that `post_return` must be called next. +pub const VMCOMPONENT_FLAG_NEEDS_POST_RETURN: u8 = 1 << 2; + /// Runtime offsets within a `VMComponentContext` for a specific component. #[derive(Debug, Clone, Copy)] pub struct VMComponentOffsets

{ @@ -32,16 +46,18 @@ pub struct VMComponentOffsets

{ pub num_runtime_memories: u32, /// The number of reallocs which are recorded in this component for options. pub num_runtime_reallocs: u32, + /// The number of post-returns which are recorded in this component for options. + pub num_runtime_post_returns: u32, // precalculated offsets of various member fields magic: u32, - may_enter: u32, - may_leave: u32, + flags: u32, store: u32, lowering_anyfuncs: u32, lowerings: u32, memories: u32, reallocs: u32, + post_returns: u32, size: u32, } @@ -60,14 +76,15 @@ impl VMComponentOffsets

{ num_lowerings: component.num_lowerings.try_into().unwrap(), num_runtime_memories: component.num_runtime_memories.try_into().unwrap(), num_runtime_reallocs: component.num_runtime_reallocs.try_into().unwrap(), + num_runtime_post_returns: component.num_runtime_post_returns.try_into().unwrap(), magic: 0, - may_enter: 0, - may_leave: 0, + flags: 0, store: 0, lowering_anyfuncs: 0, lowerings: 0, memories: 0, reallocs: 0, + post_returns: 0, size: 0, }; @@ -97,14 +114,14 @@ impl VMComponentOffsets

{ fields! { size(magic) = 4u32, - size(may_enter) = 1u32, - size(may_leave) = 1u32, + size(flags) = 1u32, align(u32::from(ret.ptr.size())), size(store) = cmul(2, ret.ptr.size()), size(lowering_anyfuncs) = cmul(ret.num_lowerings, ret.ptr.size_of_vmcaller_checked_anyfunc()), size(lowerings) = cmul(ret.num_lowerings, ret.ptr.size() * 2), size(memories) = cmul(ret.num_runtime_memories, ret.ptr.size()), size(reallocs) = cmul(ret.num_runtime_reallocs, ret.ptr.size()), + size(post_returns) = cmul(ret.num_runtime_post_returns, ret.ptr.size()), } ret.size = next_field_offset; @@ -129,16 +146,10 @@ impl VMComponentOffsets

{ self.magic } - /// The offset of the `may_leave` field. + /// The offset of the `flags` field. #[inline] - pub fn may_leave(&self) -> u32 { - self.may_leave - } - - /// The offset of the `may_enter` field. - #[inline] - pub fn may_enter(&self) -> u32 { - self.may_enter + pub fn flags(&self) -> u32 { + self.flags } /// The offset of the `store` field. @@ -232,6 +243,20 @@ impl VMComponentOffsets

{ self.runtime_reallocs() + index.as_u32() * u32::from(self.ptr.size()) } + /// The offset of the base of the `runtime_post_returns` field + #[inline] + pub fn runtime_post_returns(&self) -> u32 { + self.post_returns + } + + /// The offset of the `*mut VMCallerCheckedAnyfunc` for the runtime index + /// provided. + #[inline] + pub fn runtime_post_return(&self, index: RuntimePostReturnIndex) -> u32 { + assert!(index.as_u32() < self.num_runtime_post_returns); + self.runtime_post_returns() + index.as_u32() * u32::from(self.ptr.size()) + } + /// Return the size of the `VMComponentContext` allocation. #[inline] pub fn size_of_vmctx(&self) -> u32 { diff --git a/crates/runtime/src/component.rs b/crates/runtime/src/component.rs index baf1ad98a3..738a3a8a6c 100644 --- a/crates/runtime/src/component.rs +++ b/crates/runtime/src/component.rs @@ -17,8 +17,9 @@ use std::mem; use std::ops::Deref; use std::ptr::{self, NonNull}; use wasmtime_environ::component::{ - Component, LoweredIndex, RuntimeMemoryIndex, RuntimeReallocIndex, StringEncoding, - VMComponentOffsets, VMCOMPONENT_MAGIC, + Component, LoweredIndex, RuntimeMemoryIndex, RuntimePostReturnIndex, RuntimeReallocIndex, + StringEncoding, VMComponentOffsets, VMCOMPONENT_FLAG_MAY_ENTER, VMCOMPONENT_FLAG_MAY_LEAVE, + VMCOMPONENT_FLAG_NEEDS_POST_RETURN, VMCOMPONENT_MAGIC, }; use wasmtime_environ::HostPtr; @@ -63,6 +64,11 @@ pub struct ComponentInstance { /// signature that this callee corresponds to. /// * `nargs_and_results` - the size, in units of `ValRaw`, of /// `args_and_results`. +// +// FIXME: 7 arguments is probably too many. The `data` through `string-encoding` +// parameters should probably get packaged up into the `VMComponentContext`. +// Needs benchmarking one way or another though to figure out what the best +// balance is here. pub type VMLoweringCallee = extern "C" fn( vmctx: *mut VMOpaqueContext, data: *mut u8, @@ -104,6 +110,11 @@ pub struct VMComponentContext { _marker: marker::PhantomPinned, } +/// Flags stored in a `VMComponentContext` with values defined by +/// `VMCOMPONENT_FLAG_*` +#[repr(transparent)] +pub struct VMComponentFlags(u8); + impl ComponentInstance { /// Returns the layout corresponding to what would be an allocation of a /// `ComponentInstance` for the `offsets` provided. @@ -159,14 +170,8 @@ impl ComponentInstance { /// Returns a pointer to the "may leave" flag for this instance specified /// for canonical lowering and lifting operations. - pub fn may_leave(&self) -> *mut bool { - unsafe { self.vmctx_plus_offset(self.offsets.may_leave()) } - } - - /// Returns a pointer to the "may enter" flag for this instance specified - /// for canonical lowering and lifting operations. - pub fn may_enter(&self) -> *mut bool { - unsafe { self.vmctx_plus_offset(self.offsets.may_enter()) } + pub fn flags(&self) -> *mut VMComponentFlags { + unsafe { self.vmctx_plus_offset(self.offsets.flags()) } } /// Returns the store that this component was created with. @@ -202,6 +207,22 @@ impl ComponentInstance { ret } } + + /// Returns the post-return pointer corresponding to the index provided. + /// + /// This can only be called after `idx` has been initialized at runtime + /// during the instantiation process of a component. + pub fn runtime_post_return( + &self, + idx: RuntimePostReturnIndex, + ) -> NonNull { + unsafe { + let ret = *self.vmctx_plus_offset::>(self.offsets.runtime_post_return(idx)); + debug_assert!(ret.as_ptr() as usize != INVALID_PTR); + ret + } + } + /// Returns the host information for the lowered function at the index /// specified. /// @@ -264,6 +285,19 @@ impl ComponentInstance { } } + /// Same as `set_runtime_memory` but for post-return function pointers. + pub fn set_runtime_post_return( + &mut self, + idx: RuntimePostReturnIndex, + ptr: NonNull, + ) { + unsafe { + let storage = self.vmctx_plus_offset(self.offsets.runtime_post_return(idx)); + debug_assert!(*storage as usize == INVALID_PTR); + *storage = ptr.as_ptr(); + } + } + /// Configures a lowered host function with all the pieces necessary. /// /// * `idx` - the index that's being configured @@ -304,8 +338,7 @@ impl ComponentInstance { unsafe fn initialize_vmctx(&mut self, store: *mut dyn Store) { *self.vmctx_plus_offset(self.offsets.magic()) = VMCOMPONENT_MAGIC; - *self.may_leave() = true; - *self.may_enter() = true; + *self.flags() = VMComponentFlags::new(); *self.vmctx_plus_offset(self.offsets.store()) = store; // In debug mode set non-null bad values to all "pointer looking" bits @@ -332,6 +365,11 @@ impl ComponentInstance { let offset = self.offsets.runtime_realloc(i); *self.vmctx_plus_offset(offset) = INVALID_PTR; } + for i in 0..self.offsets.num_runtime_post_returns { + let i = RuntimePostReturnIndex::from_u32(i); + let offset = self.offsets.runtime_post_return(i); + *self.vmctx_plus_offset(offset) = INVALID_PTR; + } } } } @@ -409,6 +447,15 @@ impl OwnedComponentInstance { unsafe { self.instance_mut().set_runtime_realloc(idx, ptr) } } + /// See `ComponentInstance::set_runtime_post_return` + pub fn set_runtime_post_return( + &mut self, + idx: RuntimePostReturnIndex, + ptr: NonNull, + ) { + unsafe { self.instance_mut().set_runtime_post_return(idx, ptr) } + } + /// See `ComponentInstance::set_lowering` pub fn set_lowering( &mut self, @@ -459,3 +506,52 @@ impl VMOpaqueContext { ptr.cast() } } + +#[allow(missing_docs)] +impl VMComponentFlags { + fn new() -> VMComponentFlags { + VMComponentFlags(VMCOMPONENT_FLAG_MAY_LEAVE | VMCOMPONENT_FLAG_MAY_ENTER) + } + + #[inline] + pub fn may_leave(&self) -> bool { + self.0 & VMCOMPONENT_FLAG_MAY_LEAVE != 0 + } + + #[inline] + pub fn set_may_leave(&mut self, val: bool) { + if val { + self.0 |= VMCOMPONENT_FLAG_MAY_LEAVE; + } else { + self.0 &= !VMCOMPONENT_FLAG_MAY_LEAVE; + } + } + + #[inline] + pub fn may_enter(&self) -> bool { + self.0 & VMCOMPONENT_FLAG_MAY_ENTER != 0 + } + + #[inline] + pub fn set_may_enter(&mut self, val: bool) { + if val { + self.0 |= VMCOMPONENT_FLAG_MAY_ENTER; + } else { + self.0 &= !VMCOMPONENT_FLAG_MAY_ENTER; + } + } + + #[inline] + pub fn needs_post_return(&self) -> bool { + self.0 & VMCOMPONENT_FLAG_NEEDS_POST_RETURN != 0 + } + + #[inline] + pub fn set_needs_post_return(&mut self, val: bool) { + if val { + self.0 |= VMCOMPONENT_FLAG_NEEDS_POST_RETURN; + } else { + self.0 &= !VMCOMPONENT_FLAG_NEEDS_POST_RETURN; + } + } +} diff --git a/crates/wasmtime/src/component/func.rs b/crates/wasmtime/src/component/func.rs index ccf8fb996b..1dd152e339 100644 --- a/crates/wasmtime/src/component/func.rs +++ b/crates/wasmtime/src/component/func.rs @@ -1,6 +1,6 @@ use crate::component::instance::{Instance, InstanceData}; use crate::store::{StoreOpaque, Stored}; -use crate::AsContext; +use crate::{AsContext, ValRaw}; use anyhow::{Context, Result}; use std::mem::MaybeUninit; use std::ptr::NonNull; @@ -82,6 +82,8 @@ pub struct FuncData { types: Arc, options: Options, instance: Instance, + post_return: Option<(ExportFunction, VMTrampoline)>, + post_return_arg: Option, } impl Func { @@ -102,6 +104,11 @@ impl Func { .memory .map(|i| NonNull::new(data.instance().runtime_memory(i)).unwrap()); let realloc = options.realloc.map(|i| data.instance().runtime_realloc(i)); + let post_return = options.post_return.map(|i| { + let anyfunc = data.instance().runtime_post_return(i); + let trampoline = store.lookup_trampoline(unsafe { anyfunc.as_ref() }); + (ExportFunction { anyfunc }, trampoline) + }); let options = unsafe { Options::new(store.id(), memory, realloc, options.string_encoding) }; Func(store.store_data_mut().insert(FuncData { trampoline, @@ -110,6 +117,8 @@ impl Func { ty, types: data.component_types().clone(), instance: *instance, + post_return, + post_return_arg: None, })) } diff --git a/crates/wasmtime/src/component/func/host.rs b/crates/wasmtime/src/component/func/host.rs index 04c0c37b81..d4a68f2e72 100644 --- a/crates/wasmtime/src/component/func/host.rs +++ b/crates/wasmtime/src/component/func/host.rs @@ -8,7 +8,9 @@ use std::panic::{self, AssertUnwindSafe}; use std::ptr::NonNull; use std::sync::Arc; use wasmtime_environ::component::{ComponentTypes, StringEncoding, TypeFuncIndex}; -use wasmtime_runtime::component::{VMComponentContext, VMLowering, VMLoweringCallee}; +use wasmtime_runtime::component::{ + VMComponentContext, VMComponentFlags, VMLowering, VMLoweringCallee, +}; use wasmtime_runtime::{VMCallerCheckedAnyfunc, VMMemoryDefinition, VMOpaqueContext}; /// Trait representing host-defined functions that can be imported into a wasm @@ -134,8 +136,7 @@ where let cx = VMComponentContext::from_opaque(cx); let instance = (*cx).instance(); - let may_leave = (*instance).may_leave(); - let may_enter = (*instance).may_enter(); + let flags = (*instance).flags(); let mut cx = StoreContextMut::from_raw((*instance).store()); let options = Options::new( @@ -148,13 +149,13 @@ where // Perform a dynamic check that this instance can indeed be left. Exiting // the component is disallowed, for example, when the `realloc` function // calls a canonical import. - if !*may_leave { + if !(*flags).may_leave() { bail!("cannot leave component instance"); } // While we're lifting and lowering this instance cannot be reentered, so // unset the flag here. This is also reset back to `true` on exit. - let _reset_may_enter = unset_and_reset_on_drop(may_enter); + let _reset_may_enter = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_enter); // There's a 2x2 matrix of whether parameters and results are stored on the // stack or on the heap. Each of the 4 branches here have a different @@ -172,7 +173,7 @@ where let storage = cast_storage::>(storage); let params = Params::lift(cx.0, &options, &storage.assume_init_ref().args)?; let ret = closure(cx.as_context_mut(), params)?; - reset_may_leave = unset_and_reset_on_drop(may_leave); + reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave); ret.lower(&mut cx, &options, map_maybe_uninit!(storage.ret))?; } else { let storage = cast_storage::>(storage).assume_init_ref(); @@ -180,7 +181,7 @@ where let ret = closure(cx.as_context_mut(), params)?; let mut memory = MemoryMut::new(cx.as_context_mut(), &options); let ptr = validate_inbounds::(memory.as_slice_mut(), &storage.retptr)?; - reset_may_leave = unset_and_reset_on_drop(may_leave); + reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave); ret.store(&mut memory, ptr)?; } } else { @@ -191,7 +192,7 @@ where validate_inbounds::(memory.as_slice(), &storage.assume_init_ref().args)?; let params = Params::load(&memory, &memory.as_slice()[ptr..][..Params::size()])?; let ret = closure(cx.as_context_mut(), params)?; - reset_may_leave = unset_and_reset_on_drop(may_leave); + reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave); ret.lower(&mut cx, &options, map_maybe_uninit!(storage.ret))?; } else { let storage = cast_storage::>(storage).assume_init_ref(); @@ -200,27 +201,28 @@ where let ret = closure(cx.as_context_mut(), params)?; let mut memory = MemoryMut::new(cx.as_context_mut(), &options); let ptr = validate_inbounds::(memory.as_slice_mut(), &storage.retptr)?; - reset_may_leave = unset_and_reset_on_drop(may_leave); + reset_may_leave = unset_and_reset_on_drop(flags, VMComponentFlags::set_may_leave); ret.store(&mut memory, ptr)?; } } - // TODO: need to call `post-return` before this `drop` drop(reset_may_leave); return Ok(()); - unsafe fn unset_and_reset_on_drop(slot: *mut bool) -> impl Drop { - debug_assert!(*slot); - *slot = false; - return Reset(slot); + unsafe fn unset_and_reset_on_drop( + slot: *mut VMComponentFlags, + set: fn(&mut VMComponentFlags, bool), + ) -> impl Drop { + set(&mut *slot, false); + return Reset(slot, set); - struct Reset(*mut bool); + struct Reset(*mut VMComponentFlags, fn(&mut VMComponentFlags, bool)); impl Drop for Reset { fn drop(&mut self) { unsafe { - (*self.0) = true; + (self.1)(&mut *self.0, true); } } } diff --git a/crates/wasmtime/src/component/func/typed.rs b/crates/wasmtime/src/component/func/typed.rs index c25520b3b8..36064ba849 100644 --- a/crates/wasmtime/src/component/func/typed.rs +++ b/crates/wasmtime/src/component/func/typed.rs @@ -90,6 +90,31 @@ where /// the `store` provided. The `params` are copied into WebAssembly memory /// as appropriate and a core wasm function is invoked. /// + /// # Post-return + /// + /// In the component model each function can have a "post return" specified + /// which allows cleaning up the arguments returned to the host. For example + /// if WebAssembly returns a string to the host then it might be a uniquely + /// allocated string which, after the host finishes processing it, needs to + /// be deallocated in the wasm instance's own linear memory to prevent + /// memory leaks in wasm itself. The `post-return` canonical abi option is + /// used to configured this. + /// + /// To accommodate this feature of the component model after invoking a + /// function via [`TypedFunc::call`] you must next invoke + /// [`TypedFunc::post_return`]. Note that the return value of the function + /// should be processed between these two function calls. The return value + /// continues to be usable from an embedder's perspective after + /// `post_return` is called, but after `post_return` is invoked it may no + /// longer retain the same value that the wasm module originally returned. + /// + /// Also note that [`TypedFunc::post_return`] must be invoked irrespective + /// of whether the canonical ABI option `post-return` was configured or not. + /// This means that embedders must unconditionally call + /// [`TypedFunc::post_return`] when a function returns. If this function + /// call returns an error, however, then [`TypedFunc::post_return`] is not + /// required. + /// /// # Errors /// /// This function can return an error for a number of reasons: @@ -99,6 +124,10 @@ where /// * If the wasm provides bad allocation pointers when copying arguments /// into memory. /// * If the wasm returns a value which violates the canonical ABI. + /// * If this function's instances cannot be entered, for example if the + /// instance is currently calling a host function. + /// * If a previous function call occurred and the corresponding + /// `post_return` hasn't been invoked yet. /// /// In general there are many ways that things could go wrong when copying /// types in and out of a wasm module with the canonical ABI, and certain @@ -300,18 +329,17 @@ where assert!(mem::align_of_val(map_maybe_uninit!(space.ret)) == val_align); let instance = store.0[instance.0].as_ref().unwrap().instance(); - let may_enter = instance.may_enter(); - let may_leave = instance.may_leave(); + let flags = instance.flags(); unsafe { - if !*may_enter { + if !(*flags).may_enter() { bail!("cannot reenter component instance"); } - debug_assert!(*may_leave); + debug_assert!((*flags).may_leave()); - *may_leave = false; + (*flags).set_may_leave(false); let result = lower(store, &options, params, map_maybe_uninit!(space.params)); - *may_leave = true; + (*flags).set_may_leave(true); result?; // This is unsafe as we are providing the guarantee that all the @@ -336,18 +364,139 @@ where // `[ValRaw]`, and additionally they should have the correct types // for the function we just called (which filled in the return // values). - *may_enter = false; - let result = lift( - store.0, - &options, - map_maybe_uninit!(space.ret).assume_init_ref(), - ); + let ret = map_maybe_uninit!(space.ret).assume_init_ref(); - // TODO: this technically needs to happen only after the - // `post-return` is called. - *may_enter = true; - return result; + // Lift the result into the host while managing post-return state + // here as well. + // + // Initially the `may_enter` flag is set to `false` for this + // component instance and additionally we set a flag indicating that + // a post-return is required. This isn't specified by the component + // model itself but is used for our implementation of the API of + // `post_return` as a separate function call. + // + // FIXME(WebAssembly/component-model#55) it's not really clear what + // the semantics should be in the face of a lift error/trap. For now + // the flags are reset so the instance can continue to be reused in + // tests but that probably isn't what's desired. + // + // Otherwise though after a successful lift the return value of the + // function, which is currently required to be 0 or 1 values + // according to the canonical ABI, is saved within the `Store`'s + // `FuncData`. This'll later get used in post-return. + (*flags).set_may_enter(false); + (*flags).set_needs_post_return(true); + match lift(store.0, &options, ret) { + Ok(val) => { + let ret_slice = cast_storage(ret); + let data = &mut store.0[self.func.0]; + assert!(data.post_return_arg.is_none()); + match ret_slice.len() { + 0 => data.post_return_arg = Some(ValRaw::i32(0)), + 1 => data.post_return_arg = Some(ret_slice[0]), + _ => unreachable!(), + } + return Ok(val); + } + Err(err) => { + (*flags).set_may_enter(true); + (*flags).set_needs_post_return(false); + return Err(err); + } + } } + + unsafe fn cast_storage(storage: &T) -> &[ValRaw] { + assert!(std::mem::size_of_val(storage) % std::mem::size_of::() == 0); + assert!(std::mem::align_of_val(storage) == std::mem::align_of::()); + + std::slice::from_raw_parts( + (storage as *const T).cast(), + mem::size_of_val(storage) / mem::size_of::(), + ) + } + } + + /// Invokes the `post-return` canonical ABI option, if specified, after a + /// [`TypedFunc::call`] has finished. + /// + /// For some more information on when to use this function see the + /// documentation for post-return in the [`TypedFunc::call`] method. + /// Otherwise though this function is a required method call after a + /// [`TypedFunc::call`] completes successfully. After the embedder has + /// finished processing the return value then this function must be invoked. + /// + /// # Errors + /// + /// This function will return an error in the case of a WebAssembly trap + /// happening during the execution of the `post-return` function, if + /// specified. + /// + /// # Panics + /// + /// This function will panic if it's not called under the correct + /// conditions. This can only be called after a previous invocation of + /// [`TypedFunc::call`] completes successfully, and this function can only + /// be called for the same [`TypedFunc`] that was `call`'d. + /// + /// If this function is called when [`TypedFunc::call`] was not previously + /// called, then it will panic. If a different [`TypedFunc`] for the same + /// component instance was invoked then this function will also panic + /// because the `post-return` needs to happen for the other function. + pub fn post_return(&self, mut store: impl AsContextMut) -> Result<()> { + let mut store = store.as_context_mut(); + let data = &mut store.0[self.func.0]; + let instance = data.instance; + let post_return = data.post_return; + let post_return_arg = data.post_return_arg.take(); + let instance = store.0[instance.0].as_ref().unwrap().instance(); + let flags = instance.flags(); + + unsafe { + // First assert that the instance is in a "needs post return" state. + // This will ensure that the previous action on the instance was a + // function call above. This flag is only set after a component + // function returns so this also can't be called (as expected) + // during a host import for example. + // + // Note, though, that this assert is not sufficient because it just + // means some function on this instance needs its post-return + // called. We need a precise post-return for a particular function + // which is the second assert here (the `.expect`). That will assert + // that this function itself needs to have its post-return called. + // + // The theory at least is that these two asserts ensure component + // model semantics are upheld where the host properly calls + // `post_return` on the right function despite the call being a + // separate step in the API. + assert!( + (*flags).needs_post_return(), + "post_return can only be called after a function has previously been called", + ); + let post_return_arg = post_return_arg.expect("calling post_return on wrong function"); + + // This is a sanity-check assert which shouldn't ever trip. + assert!(!(*flags).may_enter()); + + // With the state of the world validated these flags are updated to + // their component-model-defined states. + (*flags).set_may_enter(true); + (*flags).set_needs_post_return(false); + + // And finally if the function actually had a `post-return` + // configured in its canonical options that's executed here. + let (func, trampoline) = match post_return { + Some(pair) => pair, + None => return Ok(()), + }; + crate::Func::call_unchecked_raw( + &mut store, + func.anyfunc, + trampoline, + &post_return_arg as *const ValRaw as *mut ValRaw, + )?; + } + Ok(()) } } diff --git a/crates/wasmtime/src/component/instance.rs b/crates/wasmtime/src/component/instance.rs index 313b68fba9..8d83984898 100644 --- a/crates/wasmtime/src/component/instance.rs +++ b/crates/wasmtime/src/component/instance.rs @@ -7,9 +7,9 @@ use anyhow::{anyhow, Context, Result}; use std::marker; use std::sync::Arc; use wasmtime_environ::component::{ - ComponentTypes, CoreDef, CoreExport, Export, ExportItem, ExtractMemory, ExtractRealloc, - GlobalInitializer, InstantiateModule, LowerImport, RuntimeImportIndex, RuntimeInstanceIndex, - RuntimeModuleIndex, + ComponentTypes, CoreDef, CoreExport, Export, ExportItem, ExtractMemory, ExtractPostReturn, + ExtractRealloc, GlobalInitializer, InstantiateModule, LowerImport, RuntimeImportIndex, + RuntimeInstanceIndex, RuntimeModuleIndex, }; use wasmtime_environ::{EntityIndex, PrimaryMap}; use wasmtime_runtime::component::{ComponentInstance, OwnedComponentInstance}; @@ -278,6 +278,10 @@ impl<'a> Instantiator<'a> { self.extract_realloc(store.0, realloc) } + GlobalInitializer::ExtractPostReturn(post_return) => { + self.extract_post_return(store.0, post_return) + } + GlobalInitializer::SaveStaticModule(idx) => { self.data .exported_modules @@ -338,6 +342,16 @@ impl<'a> Instantiator<'a> { self.data.state.set_runtime_realloc(realloc.index, anyfunc); } + fn extract_post_return(&mut self, store: &mut StoreOpaque, post_return: &ExtractPostReturn) { + let anyfunc = match self.data.lookup_def(store, &post_return.def) { + wasmtime_runtime::Export::Function(f) => f.anyfunc, + _ => unreachable!(), + }; + self.data + .state + .set_runtime_post_return(post_return.index, anyfunc); + } + fn build_imports<'b>( &mut self, store: &mut StoreOpaque, diff --git a/tests/all/component_model.rs b/tests/all/component_model.rs index 3ea3021483..4049d9ddcd 100644 --- a/tests/all/component_model.rs +++ b/tests/all/component_model.rs @@ -5,6 +5,7 @@ use wasmtime::{Config, Engine}; mod func; mod import; mod nested; +mod post_return; // A simple bump allocator which can be used with modules const REALLOC_AND_FREE: &str = r#" diff --git a/tests/all/component_model/func.rs b/tests/all/component_model/func.rs index 23024a468a..6c89d8c591 100644 --- a/tests/all/component_model/func.rs +++ b/tests/all/component_model/func.rs @@ -3,7 +3,7 @@ use anyhow::Result; use std::rc::Rc; use std::sync::Arc; use wasmtime::component::*; -use wasmtime::{Store, StoreContextMut, Trap, TrapCode}; +use wasmtime::{AsContextMut, Store, StoreContextMut, Trap, TrapCode}; const CANON_32BIT_NAN: u32 = 0b01111111110000000000000000000000; const CANON_64BIT_NAN: u64 = 0b0111111111111000000000000000000000000000000000000000000000000000; @@ -32,7 +32,7 @@ fn thunks() -> Result<()> { let instance = Linker::new(&engine).instantiate(&mut store, &component)?; instance .get_typed_func::<(), (), _>(&mut store, "thunk")? - .call(&mut store, ())?; + .call_and_post_return(&mut store, ())?; let err = instance .get_typed_func::<(), (), _>(&mut store, "thunk-trap")? .call(&mut store, ()) @@ -193,28 +193,28 @@ fn integers() -> Result<()> { // Passing in 100 is valid for all primitives instance .get_typed_func::<(u8,), (), _>(&mut store, "take-u8")? - .call(&mut store, (100,))?; + .call_and_post_return(&mut store, (100,))?; instance .get_typed_func::<(i8,), (), _>(&mut store, "take-s8")? - .call(&mut store, (100,))?; + .call_and_post_return(&mut store, (100,))?; instance .get_typed_func::<(u16,), (), _>(&mut store, "take-u16")? - .call(&mut store, (100,))?; + .call_and_post_return(&mut store, (100,))?; instance .get_typed_func::<(i16,), (), _>(&mut store, "take-s16")? - .call(&mut store, (100,))?; + .call_and_post_return(&mut store, (100,))?; instance .get_typed_func::<(u32,), (), _>(&mut store, "take-u32")? - .call(&mut store, (100,))?; + .call_and_post_return(&mut store, (100,))?; instance .get_typed_func::<(i32,), (), _>(&mut store, "take-s32")? - .call(&mut store, (100,))?; + .call_and_post_return(&mut store, (100,))?; instance .get_typed_func::<(u64,), (), _>(&mut store, "take-u64")? - .call(&mut store, (100,))?; + .call_and_post_return(&mut store, (100,))?; instance .get_typed_func::<(i64,), (), _>(&mut store, "take-s64")? - .call(&mut store, (100,))?; + .call_and_post_return(&mut store, (100,))?; // This specific wasm instance traps if any value other than 100 is passed instance @@ -262,49 +262,49 @@ fn integers() -> Result<()> { assert_eq!( instance .get_typed_func::<(), u8, _>(&mut store, "ret-u8")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, 0 ); assert_eq!( instance .get_typed_func::<(), i8, _>(&mut store, "ret-s8")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, 0 ); assert_eq!( instance .get_typed_func::<(), u16, _>(&mut store, "ret-u16")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, 0 ); assert_eq!( instance .get_typed_func::<(), i16, _>(&mut store, "ret-s16")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, 0 ); assert_eq!( instance .get_typed_func::<(), u32, _>(&mut store, "ret-u32")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, 0 ); assert_eq!( instance .get_typed_func::<(), i32, _>(&mut store, "ret-s32")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, 0 ); assert_eq!( instance .get_typed_func::<(), u64, _>(&mut store, "ret-u64")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, 0 ); assert_eq!( instance .get_typed_func::<(), i64, _>(&mut store, "ret-s64")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, 0 ); @@ -312,49 +312,49 @@ fn integers() -> Result<()> { assert_eq!( instance .get_typed_func::<(), u8, _>(&mut store, "retm1-u8")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, 0xff ); assert_eq!( instance .get_typed_func::<(), i8, _>(&mut store, "retm1-s8")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, -1 ); assert_eq!( instance .get_typed_func::<(), u16, _>(&mut store, "retm1-u16")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, 0xffff ); assert_eq!( instance .get_typed_func::<(), i16, _>(&mut store, "retm1-s16")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, -1 ); assert_eq!( instance .get_typed_func::<(), u32, _>(&mut store, "retm1-u32")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, 0xffffffff ); assert_eq!( instance .get_typed_func::<(), i32, _>(&mut store, "retm1-s32")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, -1 ); assert_eq!( instance .get_typed_func::<(), u64, _>(&mut store, "retm1-u64")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, 0xffffffff_ffffffff ); assert_eq!( instance .get_typed_func::<(), i64, _>(&mut store, "retm1-s64")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, -1 ); @@ -363,43 +363,59 @@ fn integers() -> Result<()> { assert_eq!( instance .get_typed_func::<(), u8, _>(&mut store, "retbig-u8")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, ret as u8, ); assert_eq!( instance .get_typed_func::<(), i8, _>(&mut store, "retbig-s8")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, ret as i8, ); assert_eq!( instance .get_typed_func::<(), u16, _>(&mut store, "retbig-u16")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, ret as u16, ); assert_eq!( instance .get_typed_func::<(), i16, _>(&mut store, "retbig-s16")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, ret as i16, ); assert_eq!( instance .get_typed_func::<(), u32, _>(&mut store, "retbig-u32")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, ret, ); assert_eq!( instance .get_typed_func::<(), i32, _>(&mut store, "retbig-s32")? - .call(&mut store, ())?, + .call_and_post_return(&mut store, ())?, ret as i32, ); Ok(()) } +trait TypedFuncExt { + fn call_and_post_return(&self, store: impl AsContextMut, params: P) -> Result; +} + +impl TypedFuncExt for TypedFunc +where + P: ComponentParams + Lower, + R: Lift, +{ + fn call_and_post_return(&self, mut store: impl AsContextMut, params: P) -> Result { + let result = self.call(&mut store, params)?; + self.post_return(&mut store)?; + Ok(result) + } +} + #[test] fn type_layers() -> Result<()> { let component = r#" @@ -425,19 +441,19 @@ fn type_layers() -> Result<()> { instance .get_typed_func::<(Box,), (), _>(&mut store, "take-u32")? - .call(&mut store, (Box::new(2),))?; + .call_and_post_return(&mut store, (Box::new(2),))?; instance .get_typed_func::<(&u32,), (), _>(&mut store, "take-u32")? - .call(&mut store, (&2,))?; + .call_and_post_return(&mut store, (&2,))?; instance .get_typed_func::<(Rc,), (), _>(&mut store, "take-u32")? - .call(&mut store, (Rc::new(2),))?; + .call_and_post_return(&mut store, (Rc::new(2),))?; instance .get_typed_func::<(Arc,), (), _>(&mut store, "take-u32")? - .call(&mut store, (Arc::new(2),))?; + .call_and_post_return(&mut store, (Arc::new(2),))?; instance .get_typed_func::<(&Box>>,), (), _>(&mut store, "take-u32")? - .call(&mut store, (&Box::new(Arc::new(Rc::new(2))),))?; + .call_and_post_return(&mut store, (&Box::new(Arc::new(Rc::new(2))),))?; Ok(()) } @@ -491,9 +507,13 @@ fn floats() -> Result<()> { let u64_to_f64 = instance.get_typed_func::<(u64,), f64, _>(&mut store, "u64-to-f64")?; assert_eq!(f32_to_u32.call(&mut store, (1.0,))?, 1.0f32.to_bits()); + f32_to_u32.post_return(&mut store)?; assert_eq!(f64_to_u64.call(&mut store, (2.0,))?, 2.0f64.to_bits()); + f64_to_u64.post_return(&mut store)?; assert_eq!(u32_to_f32.call(&mut store, (3.0f32.to_bits(),))?, 3.0); + u32_to_f32.post_return(&mut store)?; assert_eq!(u64_to_f64.call(&mut store, (4.0f64.to_bits(),))?, 4.0); + u64_to_f64.post_return(&mut store)?; assert_eq!( u32_to_f32 @@ -501,21 +521,25 @@ fn floats() -> Result<()> { .to_bits(), CANON_32BIT_NAN ); + u32_to_f32.post_return(&mut store)?; assert_eq!( u64_to_f64 .call(&mut store, (CANON_64BIT_NAN | 1,))? .to_bits(), CANON_64BIT_NAN ); + u64_to_f64.post_return(&mut store)?; assert_eq!( f32_to_u32.call(&mut store, (f32::from_bits(CANON_32BIT_NAN | 1),))?, CANON_32BIT_NAN ); + f32_to_u32.post_return(&mut store)?; assert_eq!( f64_to_u64.call(&mut store, (f64::from_bits(CANON_64BIT_NAN | 1),))?, CANON_64BIT_NAN ); + f64_to_u64.post_return(&mut store)?; Ok(()) } @@ -546,10 +570,15 @@ fn bools() -> Result<()> { let bool_to_u32 = instance.get_typed_func::<(bool,), u32, _>(&mut store, "bool-to-u32")?; assert_eq!(bool_to_u32.call(&mut store, (false,))?, 0); + bool_to_u32.post_return(&mut store)?; assert_eq!(bool_to_u32.call(&mut store, (true,))?, 1); + bool_to_u32.post_return(&mut store)?; assert_eq!(u32_to_bool.call(&mut store, (0,))?, false); + u32_to_bool.post_return(&mut store)?; assert_eq!(u32_to_bool.call(&mut store, (1,))?, true); + u32_to_bool.post_return(&mut store)?; assert_eq!(u32_to_bool.call(&mut store, (2,))?, true); + u32_to_bool.post_return(&mut store)?; Ok(()) } @@ -581,7 +610,9 @@ fn chars() -> Result<()> { let mut roundtrip = |x: char| -> Result<()> { assert_eq!(char_to_u32.call(&mut store, (x,))?, x as u32); + char_to_u32.post_return(&mut store)?; assert_eq!(u32_to_char.call(&mut store, (x as u32,))?, x); + u32_to_char.post_return(&mut store)?; Ok(()) }; @@ -644,7 +675,7 @@ fn tuple_result() -> Result<()> { let input = (-1, 100, 3.0, 100.0); let output = instance .get_typed_func::<(i8, u16, f32, f64), (i8, u16, f32, f64), _>(&mut store, "tuple")? - .call(&mut store, input)?; + .call_and_post_return(&mut store, input)?; assert_eq!(input, output); let invalid_func = @@ -735,16 +766,20 @@ fn strings() -> Result<()> { let mut roundtrip = |x: &str| -> Result<()> { let ret = list8_to_str.call(&mut store, (x.as_bytes(),))?; assert_eq!(ret.to_str(&store)?, x); + list8_to_str.post_return(&mut store)?; let utf16 = x.encode_utf16().collect::>(); let ret = list16_to_str.call(&mut store, (&utf16[..],))?; assert_eq!(ret.to_str(&store)?, x); + list16_to_str.post_return(&mut store)?; let ret = str_to_list8.call(&mut store, (x,))?; assert_eq!(ret.iter(&store).collect::>>()?, x.as_bytes()); + str_to_list8.post_return(&mut store)?; let ret = str_to_list16.call(&mut store, (x,))?; assert_eq!(ret.iter(&store).collect::>>()?, utf16,); + str_to_list16.post_return(&mut store)?; Ok(()) }; @@ -758,22 +793,27 @@ fn strings() -> Result<()> { let ret = list8_to_str.call(&mut store, (b"\xff",))?; let err = ret.to_str(&store).unwrap_err(); assert!(err.to_string().contains("invalid utf-8"), "{}", err); + list8_to_str.post_return(&mut store)?; let ret = list8_to_str.call(&mut store, (b"hello there \xff invalid",))?; let err = ret.to_str(&store).unwrap_err(); assert!(err.to_string().contains("invalid utf-8"), "{}", err); + list8_to_str.post_return(&mut store)?; let ret = list16_to_str.call(&mut store, (&[0xd800],))?; let err = ret.to_str(&store).unwrap_err(); assert!(err.to_string().contains("unpaired surrogate"), "{}", err); + list16_to_str.post_return(&mut store)?; let ret = list16_to_str.call(&mut store, (&[0xdfff],))?; let err = ret.to_str(&store).unwrap_err(); assert!(err.to_string().contains("unpaired surrogate"), "{}", err); + list16_to_str.post_return(&mut store)?; let ret = list16_to_str.call(&mut store, (&[0xd800, 0xff00],))?; let err = ret.to_str(&store).unwrap_err(); assert!(err.to_string().contains("unpaired surrogate"), "{}", err); + list16_to_str.post_return(&mut store)?; Ok(()) } @@ -1123,10 +1163,10 @@ fn some_traps() -> Result<()> { instance .get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-end-oob")? - .call(&mut store, (&[],))?; + .call_and_post_return(&mut store, (&[],))?; instance .get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-end-oob")? - .call(&mut store, (&[1, 2, 3, 4],))?; + .call_and_post_return(&mut store, (&[1, 2, 3, 4],))?; let err = instance .get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-end-oob")? .call(&mut store, (&[1, 2, 3, 4, 5],)) @@ -1134,10 +1174,10 @@ fn some_traps() -> Result<()> { assert_oob(&err); instance .get_typed_func::<(&str,), (), _>(&mut store, "take-string-end-oob")? - .call(&mut store, ("",))?; + .call_and_post_return(&mut store, ("",))?; instance .get_typed_func::<(&str,), (), _>(&mut store, "take-string-end-oob")? - .call(&mut store, ("abcd",))?; + .call_and_post_return(&mut store, ("abcd",))?; let err = instance .get_typed_func::<(&str,), (), _>(&mut store, "take-string-end-oob")? .call(&mut store, ("abcde",)) @@ -1216,12 +1256,15 @@ fn char_bool_memory() -> Result<()> { let ret = func.call(&mut store, (0, 'a' as u32))?; assert_eq!(ret, (false, 'a')); + func.post_return(&mut store)?; let ret = func.call(&mut store, (1, '🍰' as u32))?; assert_eq!(ret, (true, '🍰')); + func.post_return(&mut store)?; let ret = func.call(&mut store, (2, 'a' as u32))?; assert_eq!(ret, (true, 'a')); + func.post_return(&mut store)?; assert!(func.call(&mut store, (0, 0xd800)).is_err()); @@ -1437,22 +1480,30 @@ fn option() -> Result<()> { let option_unit_to_u32 = instance.get_typed_func::<(Option<()>,), u32, _>(&mut store, "option-unit-to-u32")?; assert_eq!(option_unit_to_u32.call(&mut store, (None,))?, 0); + option_unit_to_u32.post_return(&mut store)?; assert_eq!(option_unit_to_u32.call(&mut store, (Some(()),))?, 1); + option_unit_to_u32.post_return(&mut store)?; let option_u8_to_tuple = instance .get_typed_func::<(Option,), (u32, u32), _>(&mut store, "option-u8-to-tuple")?; assert_eq!(option_u8_to_tuple.call(&mut store, (None,))?, (0, 0)); + option_u8_to_tuple.post_return(&mut store)?; assert_eq!(option_u8_to_tuple.call(&mut store, (Some(0),))?, (1, 0)); + option_u8_to_tuple.post_return(&mut store)?; assert_eq!(option_u8_to_tuple.call(&mut store, (Some(100),))?, (1, 100)); + option_u8_to_tuple.post_return(&mut store)?; let option_u32_to_tuple = instance .get_typed_func::<(Option,), (u32, u32), _>(&mut store, "option-u32-to-tuple")?; assert_eq!(option_u32_to_tuple.call(&mut store, (None,))?, (0, 0)); + option_u32_to_tuple.post_return(&mut store)?; assert_eq!(option_u32_to_tuple.call(&mut store, (Some(0),))?, (1, 0)); + option_u32_to_tuple.post_return(&mut store)?; assert_eq!( option_u32_to_tuple.call(&mut store, (Some(100),))?, (1, 100) ); + option_u32_to_tuple.post_return(&mut store)?; let option_string_to_tuple = instance.get_typed_func::<(Option<&str>,), (u32, WasmStr), _>( &mut store, @@ -1461,45 +1512,59 @@ fn option() -> Result<()> { let (a, b) = option_string_to_tuple.call(&mut store, (None,))?; assert_eq!(a, 0); assert_eq!(b.to_str(&store)?, ""); + option_string_to_tuple.post_return(&mut store)?; let (a, b) = option_string_to_tuple.call(&mut store, (Some(""),))?; assert_eq!(a, 1); assert_eq!(b.to_str(&store)?, ""); + option_string_to_tuple.post_return(&mut store)?; let (a, b) = option_string_to_tuple.call(&mut store, (Some("hello"),))?; assert_eq!(a, 1); assert_eq!(b.to_str(&store)?, "hello"); + option_string_to_tuple.post_return(&mut store)?; let to_option_unit = instance.get_typed_func::<(u32,), Option<()>, _>(&mut store, "to-option-unit")?; assert_eq!(to_option_unit.call(&mut store, (0,))?, None); + to_option_unit.post_return(&mut store)?; assert_eq!(to_option_unit.call(&mut store, (1,))?, Some(())); + to_option_unit.post_return(&mut store)?; let err = to_option_unit.call(&mut store, (2,)).unwrap_err(); assert!(err.to_string().contains("invalid option"), "{}", err); let to_option_u8 = instance.get_typed_func::<(u32, u32), Option, _>(&mut store, "to-option-u8")?; assert_eq!(to_option_u8.call(&mut store, (0x00_00, 0))?, None); + to_option_u8.post_return(&mut store)?; assert_eq!(to_option_u8.call(&mut store, (0x00_01, 0))?, Some(0)); + to_option_u8.post_return(&mut store)?; assert_eq!(to_option_u8.call(&mut store, (0xfd_01, 0))?, Some(0xfd)); + to_option_u8.post_return(&mut store)?; assert!(to_option_u8.call(&mut store, (0x00_02, 0)).is_err()); let to_option_u32 = instance.get_typed_func::<(u32, u32), Option, _>(&mut store, "to-option-u32")?; assert_eq!(to_option_u32.call(&mut store, (0, 0))?, None); + to_option_u32.post_return(&mut store)?; assert_eq!(to_option_u32.call(&mut store, (1, 0))?, Some(0)); + to_option_u32.post_return(&mut store)?; assert_eq!( to_option_u32.call(&mut store, (1, 0x1234fead))?, Some(0x1234fead) ); + to_option_u32.post_return(&mut store)?; assert!(to_option_u32.call(&mut store, (2, 0)).is_err()); let to_option_string = instance .get_typed_func::<(u32, &str), Option, _>(&mut store, "to-option-string")?; let ret = to_option_string.call(&mut store, (0, ""))?; assert!(ret.is_none()); + to_option_string.post_return(&mut store)?; let ret = to_option_string.call(&mut store, (1, ""))?; assert_eq!(ret.unwrap().to_str(&store)?, ""); + to_option_string.post_return(&mut store)?; let ret = to_option_string.call(&mut store, (1, "cheesecake"))?; assert_eq!(ret.unwrap().to_str(&store)?, "cheesecake"); + to_option_string.post_return(&mut store)?; assert!(to_option_string.call(&mut store, (2, "")).is_err()); Ok(()) @@ -1592,15 +1657,19 @@ fn expected() -> Result<()> { let take_expected_unit = instance.get_typed_func::<(Result<(), ()>,), u32, _>(&mut store, "take-expected-unit")?; assert_eq!(take_expected_unit.call(&mut store, (Ok(()),))?, 0); + take_expected_unit.post_return(&mut store)?; assert_eq!(take_expected_unit.call(&mut store, (Err(()),))?, 1); + take_expected_unit.post_return(&mut store)?; let take_expected_u8_f32 = instance .get_typed_func::<(Result,), (u32, u32), _>(&mut store, "take-expected-u8-f32")?; assert_eq!(take_expected_u8_f32.call(&mut store, (Ok(1),))?, (0, 1)); + take_expected_u8_f32.post_return(&mut store)?; assert_eq!( take_expected_u8_f32.call(&mut store, (Err(2.0),))?, (1, 2.0f32.to_bits()) ); + take_expected_u8_f32.post_return(&mut store)?; let take_expected_string = instance .get_typed_func::<(Result<&str, &[u8]>,), (u32, WasmStr), _>( @@ -1610,27 +1679,35 @@ fn expected() -> Result<()> { let (a, b) = take_expected_string.call(&mut store, (Ok("hello"),))?; assert_eq!(a, 0); assert_eq!(b.to_str(&store)?, "hello"); + take_expected_string.post_return(&mut store)?; let (a, b) = take_expected_string.call(&mut store, (Err(b"goodbye"),))?; assert_eq!(a, 1); assert_eq!(b.to_str(&store)?, "goodbye"); + take_expected_string.post_return(&mut store)?; let to_expected_unit = instance.get_typed_func::<(u32,), Result<(), ()>, _>(&mut store, "to-expected-unit")?; assert_eq!(to_expected_unit.call(&mut store, (0,))?, Ok(())); + to_expected_unit.post_return(&mut store)?; assert_eq!(to_expected_unit.call(&mut store, (1,))?, Err(())); + to_expected_unit.post_return(&mut store)?; let err = to_expected_unit.call(&mut store, (2,)).unwrap_err(); assert!(err.to_string().contains("invalid expected"), "{}", err); let to_expected_s16_f32 = instance .get_typed_func::<(u32, u32), Result, _>(&mut store, "to-expected-s16-f32")?; assert_eq!(to_expected_s16_f32.call(&mut store, (0, 0))?, Ok(0)); + to_expected_s16_f32.post_return(&mut store)?; assert_eq!(to_expected_s16_f32.call(&mut store, (0, 100))?, Ok(100)); + to_expected_s16_f32.post_return(&mut store)?; assert_eq!( to_expected_s16_f32.call(&mut store, (1, 1.0f32.to_bits()))?, Err(1.0) ); + to_expected_s16_f32.post_return(&mut store)?; let ret = to_expected_s16_f32.call(&mut store, (1, CANON_32BIT_NAN | 1))?; assert_eq!(ret.unwrap_err().to_bits(), CANON_32BIT_NAN); + to_expected_s16_f32.post_return(&mut store)?; assert!(to_expected_s16_f32.call(&mut store, (2, 0)).is_err()); Ok(()) diff --git a/tests/all/component_model/post_return.rs b/tests/all/component_model/post_return.rs new file mode 100644 index 0000000000..20e0de8973 --- /dev/null +++ b/tests/all/component_model/post_return.rs @@ -0,0 +1,259 @@ +use anyhow::Result; +use wasmtime::component::*; +use wasmtime::{Store, StoreContextMut}; + +#[test] +fn invalid_api() -> Result<()> { + let component = r#" + (component + (core module $m + (func (export "thunk1")) + (func (export "thunk2")) + ) + (core instance $i (instantiate $m)) + (func (export "thunk1") + (canon lift (core func $i "thunk1")) + ) + (func (export "thunk2") + (canon lift (core func $i "thunk2")) + ) + ) + "#; + + let engine = super::engine(); + let component = Component::new(&engine, component)?; + let mut store = Store::new(&engine, ()); + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let thunk1 = instance.get_typed_func::<(), (), _>(&mut store, "thunk1")?; + let thunk2 = instance.get_typed_func::<(), (), _>(&mut store, "thunk2")?; + + // Ensure that we can't call `post_return` before doing anything + let msg = "post_return can only be called after a function has previously been called"; + assert_panics(|| drop(thunk1.post_return(&mut store)), msg); + assert_panics(|| drop(thunk2.post_return(&mut store)), msg); + + // Schedule a "needs post return" + thunk1.call(&mut store, ())?; + + // Ensure that we can't reenter the instance through either this function or + // another one. + let err = thunk1.call(&mut store, ()).unwrap_err(); + assert!( + err.to_string() + .contains("cannot reenter component instance"), + "{}", + err + ); + let err = thunk2.call(&mut store, ()).unwrap_err(); + assert!( + err.to_string() + .contains("cannot reenter component instance"), + "{}", + err + ); + + // Calling post-return on the wrong function should panic + assert_panics( + || drop(thunk2.post_return(&mut store)), + "calling post_return on wrong function", + ); + + // Actually execute the post-return + thunk1.post_return(&mut store)?; + + // And now post-return should be invalid again. + assert_panics(|| drop(thunk1.post_return(&mut store)), msg); + assert_panics(|| drop(thunk2.post_return(&mut store)), msg); + + Ok(()) +} + +#[track_caller] +fn assert_panics(f: impl FnOnce(), msg: &str) { + match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) { + Ok(()) => panic!("expected closure to panic"), + Err(e) => match e.downcast::() { + Ok(s) => { + assert!(s.contains(msg), "bad panic: {}", s); + } + Err(e) => match e.downcast::<&'static str>() { + Ok(s) => assert!(s.contains(msg), "bad panic: {}", s), + Err(_) => panic!("bad panic"), + }, + }, + } +} + +#[test] +fn invoke_post_return() -> Result<()> { + let component = r#" + (component + (import "f" (func $f)) + + (core func $f_lower + (canon lower (func $f)) + ) + (core module $m + (import "" "" (func $f)) + + (func (export "thunk")) + + (func $post_return + call $f) + (export "post-return" (func $post_return)) + ) + (core instance $i (instantiate $m + (with "" (instance + (export "" (func $f_lower)) + )) + )) + (func (export "thunk") + (canon lift + (core func $i "thunk") + (post-return (func $i "post-return")) + ) + ) + ) + "#; + + let engine = super::engine(); + let component = Component::new(&engine, component)?; + let mut store = Store::new(&engine, false); + let mut linker = Linker::new(&engine); + linker + .root() + .func_wrap("f", |mut store: StoreContextMut<'_, bool>| -> Result<()> { + assert!(!*store.data()); + *store.data_mut() = true; + Ok(()) + })?; + + let instance = linker.instantiate(&mut store, &component)?; + let thunk = instance.get_typed_func::<(), (), _>(&mut store, "thunk")?; + + assert!(!*store.data()); + thunk.call(&mut store, ())?; + assert!(!*store.data()); + thunk.post_return(&mut store)?; + assert!(*store.data()); + + Ok(()) +} + +#[test] +fn post_return_all_types() -> Result<()> { + let component = r#" + (component + (core module $m + (func (export "i32") (result i32) + i32.const 1) + (func (export "i64") (result i64) + i64.const 2) + (func (export "f32") (result f32) + f32.const 3) + (func (export "f64") (result f64) + f64.const 4) + + (func (export "post-i32") (param i32) + local.get 0 + i32.const 1 + i32.ne + if unreachable end) + (func (export "post-i64") (param i64) + local.get 0 + i64.const 2 + i64.ne + if unreachable end) + (func (export "post-f32") (param f32) + local.get 0 + f32.const 3 + f32.ne + if unreachable end) + (func (export "post-f64") (param f64) + local.get 0 + f64.const 4 + f64.ne + if unreachable end) + ) + (core instance $i (instantiate $m)) + (func (export "i32") (result u32) + (canon lift (core func $i "i32") (post-return (func $i "post-i32"))) + ) + (func (export "i64") (result u64) + (canon lift (core func $i "i64") (post-return (func $i "post-i64"))) + ) + (func (export "f32") (result float32) + (canon lift (core func $i "f32") (post-return (func $i "post-f32"))) + ) + (func (export "f64") (result float64) + (canon lift (core func $i "f64") (post-return (func $i "post-f64"))) + ) + ) + "#; + + let engine = super::engine(); + let component = Component::new(&engine, component)?; + let mut store = Store::new(&engine, false); + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let i32 = instance.get_typed_func::<(), u32, _>(&mut store, "i32")?; + let i64 = instance.get_typed_func::<(), u64, _>(&mut store, "i64")?; + let f32 = instance.get_typed_func::<(), f32, _>(&mut store, "f32")?; + let f64 = instance.get_typed_func::<(), f64, _>(&mut store, "f64")?; + + assert_eq!(i32.call(&mut store, ())?, 1); + i32.post_return(&mut store)?; + + assert_eq!(i64.call(&mut store, ())?, 2); + i64.post_return(&mut store)?; + + assert_eq!(f32.call(&mut store, ())?, 3.); + f32.post_return(&mut store)?; + + assert_eq!(f64.call(&mut store, ())?, 4.); + f64.post_return(&mut store)?; + + Ok(()) +} + +#[test] +fn post_return_string() -> Result<()> { + let component = r#" + (component + (core module $m + (memory (export "memory") 1) + (func (export "get") (result i32) + (i32.store offset=0 (i32.const 8) (i32.const 100)) + (i32.store offset=4 (i32.const 8) (i32.const 11)) + i32.const 8 + ) + + (func (export "post") (param i32) + local.get 0 + i32.const 8 + i32.ne + if unreachable end) + + (data (i32.const 100) "hello world") + ) + (core instance $i (instantiate $m)) + (func (export "get") (result string) + (canon lift + (core func $i "get") + (post-return (func $i "post")) + (memory $i "memory") + ) + ) + ) + "#; + + let engine = super::engine(); + let component = Component::new(&engine, component)?; + let mut store = Store::new(&engine, false); + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let get = instance.get_typed_func::<(), WasmStr, _>(&mut store, "get")?; + let s = get.call(&mut store, ())?; + assert_eq!(s.to_str(&store)?, "hello world"); + get.post_return(&mut store)?; + + Ok(()) +}