Enable passing host functions to components (#4219)

* Enable passing host functions to components

This commit implements the ability to pass a host function into a
component. The `wasmtime::component::Linker` type now has a `func_wrap`
method allowing it to take a host function which is exposed internally
to the component and available for lowering.

This is currently mostly a "let's get at least the bare minimum working"
implementation. That involves plumbing around lots of various bits of
the canonical ABI and getting all the previous PRs to line up in this
one to get a test where we call a function where the host takes a
string. This PR also additionally starts reading and using the
`may_{enter,leave}` flags since this is the first time they're actually
relevant.

Overall while this is the bare bones of working this is not a final spot
we should end up at. One of the major downsides is that host functions
are represented as:

    F: Fn(StoreContextMut<'_, T>, Arg1, Arg2, ...) -> Result<Return>

While this naively seems reasonable, it critically doesn't allow
`Return` to actually close over any of its arguments. This means that if
you want to return a string to wasm then it has to be `String` or
`Rc<str>` or some other owned type. In the case of `String` this means
that to return a string to wasm you first have to copy it from the host
to a temporary `String` allocation, then to wasm. This extra copy for
all strings/lists is expected to be prohibitive. Unfortunately I don't
think Rust is able to solve this, at least on stable, today.

Nevertheless I wanted to at least post this to get some feedback on it
since it's the final step in implementing host imports to see how others
feel about it.

* Fix a typo in an assertion

* Fix some typos

* Review comments
This commit is contained in:
Alex Crichton
2022-06-07 09:39:02 -05:00
committed by GitHub
parent 3f152273d3
commit 20f510671d
8 changed files with 726 additions and 24 deletions

View File

@@ -1,4 +1,4 @@
use crate::component::instance::InstanceData; use crate::component::instance::{Instance, InstanceData};
use crate::store::{StoreOpaque, Stored}; use crate::store::{StoreOpaque, Stored};
use crate::AsContext; use crate::AsContext;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
@@ -61,8 +61,10 @@ impl<T> MaybeUninitExt<T> for MaybeUninit<T> {
} }
} }
mod host;
mod options; mod options;
mod typed; mod typed;
pub use self::host::*;
pub use self::options::*; pub use self::options::*;
pub use self::typed::*; pub use self::typed::*;
@@ -79,11 +81,13 @@ pub struct FuncData {
ty: FuncTypeIndex, ty: FuncTypeIndex,
types: Arc<ComponentTypes>, types: Arc<ComponentTypes>,
options: Options, options: Options,
instance: Instance,
} }
impl Func { impl Func {
pub(crate) fn from_lifted_func( pub(crate) fn from_lifted_func(
store: &mut StoreOpaque, store: &mut StoreOpaque,
instance: &Instance,
data: &InstanceData, data: &InstanceData,
ty: FuncTypeIndex, ty: FuncTypeIndex,
func: &CoreExport<FuncIndex>, func: &CoreExport<FuncIndex>,
@@ -105,6 +109,7 @@ impl Func {
options, options,
ty, ty,
types: data.component_types().clone(), types: data.component_types().clone(),
instance: *instance,
})) }))
} }

View File

@@ -0,0 +1,343 @@
use crate::component::func::{MAX_STACK_PARAMS, MAX_STACK_RESULTS};
use crate::component::{ComponentParams, ComponentValue, Memory, MemoryMut, Op, Options};
use crate::{AsContextMut, StoreContextMut, ValRaw};
use anyhow::{bail, Context, Result};
use std::any::Any;
use std::mem::MaybeUninit;
use std::panic::{self, AssertUnwindSafe};
use std::ptr::NonNull;
use std::sync::Arc;
use wasmtime_environ::component::{ComponentTypes, FuncTypeIndex, StringEncoding};
use wasmtime_runtime::component::{VMComponentContext, VMLowering, VMLoweringCallee};
use wasmtime_runtime::{VMCallerCheckedAnyfunc, VMMemoryDefinition, VMOpaqueContext};
/// Trait representing host-defined functions that can be imported into a wasm
/// component.
///
/// For more information see the
/// [`Linker::func_wrap`](crate::component::Linker::func_wrap) documentation.
///
/// Type parameters:
///
/// * `T` - the store data type the function works with.
/// * `Params` - the function's arguments viewed as a tuple; the
///   implementations in this file optionally include a leading
///   `StoreContextMut<'_, T>` in that tuple.
/// * `Return` - the `Ok` payload of the function's `Result`.
pub trait IntoComponentFunc<T, Params, Return> {
/// Host entrypoint from a cranelift-generated trampoline.
///
/// This function has type `VMLoweringCallee` and delegates to the shared
/// `call_host` function below.
///
/// * `cx` - opaque context that `call_host` converts into a
///   `VMComponentContext`.
/// * `data` - type-erased pointer which implementations cast back to
///   `*const Self` to invoke the host closure.
/// * `memory`, `realloc`, `string_encoding` - forwarded to `call_host`,
///   which uses them to build the `Options` for lifting and lowering.
/// * `storage`, `storage_len` - pointer and length of the `ValRaw` array
///   holding the arguments and receiving the results.
#[doc(hidden)]
extern "C" fn entrypoint(
cx: *mut VMOpaqueContext,
data: *mut u8,
memory: *mut VMMemoryDefinition,
realloc: *mut VMCallerCheckedAnyfunc,
string_encoding: StringEncoding,
storage: *mut ValRaw,
storage_len: usize,
);
/// Consumes this function and packages it, together with its `entrypoint`,
/// into a shared [`HostFunc`].
#[doc(hidden)]
fn into_host_func(self) -> Arc<HostFunc>;
}
/// A type-erased host function that can be supplied as a component import.
///
/// The concrete closure type is erased behind `Box<dyn Any>`; the
/// `entrypoint` and `typecheck` function pointers were monomorphized for that
/// closure type when this value was created.
pub struct HostFunc {
/// Trampoline target handed out as `VMLowering::callee` by `lowering`.
entrypoint: VMLoweringCallee,
/// Checks a component function type against the parameter and result types
/// this host function was created with.
typecheck: fn(FuncTypeIndex, &ComponentTypes) -> Result<()>,
/// The boxed host closure; `lowering` hands out a raw pointer to it.
func: Box<dyn Any + Send + Sync>,
}
impl HostFunc {
    /// Boxes `func` and pairs it with the `entrypoint` that knows how to call
    /// it, recording a type checker specialized to `P` (parameters) and `R`
    /// (result).
    fn new<F, P, R>(func: F, entrypoint: VMLoweringCallee) -> Arc<HostFunc>
    where
        F: Send + Sync + 'static,
        P: ComponentParams,
        R: ComponentValue,
    {
        let erased: Box<dyn Any + Send + Sync> = Box::new(func);
        let host = HostFunc {
            entrypoint,
            typecheck: typecheck::<P, R>,
            func: erased,
        };
        Arc::new(host)
    }

    /// Verifies that the component function type `ty` matches the parameter
    /// and result types this host function was created with.
    pub fn typecheck(&self, ty: FuncTypeIndex, types: &ComponentTypes) -> Result<()> {
        let check = self.typecheck;
        check(ty, types)
    }

    /// Produces the `VMLowering` describing this function: the entrypoint to
    /// call plus a raw pointer to the boxed closure.
    ///
    /// The returned `data` pointer points into the allocation owned by
    /// `self.func`, so it is only valid while this `HostFunc` is alive.
    pub fn lowering(&self) -> VMLowering {
        let closure: &(dyn Any + Send + Sync) = &*self.func;
        VMLowering {
            callee: self.entrypoint,
            data: closure as *const (dyn Any + Send + Sync) as *mut u8,
        }
    }
}
/// Type-checks the component function type at index `ty` against the host
/// signature described by `P` (parameters, lifted from wasm) and `R` (result,
/// lowered back into wasm).
///
/// Returns an error annotated with which half of the signature mismatched.
fn typecheck<P, R>(ty: FuncTypeIndex, types: &ComponentTypes) -> Result<()>
where
    P: ComponentParams,
    R: ComponentValue,
{
    let signature = &types[ty];

    let params_check = P::typecheck(&signature.params, types, Op::Lift);
    params_check.context("type mismatch with parameters")?;

    let result_check = R::typecheck(&signature.result, types, Op::Lower);
    result_check.context("type mismatch with result")?;

    Ok(())
}
/// The "meat" of calling a host function from wasm.
///
/// This function is delegated to from implementations of `IntoComponentFunc`
/// generated in the macro below. Most of the arguments from the `entrypoint`
/// are forwarded here except for the `data` pointer which is encapsulated in
/// the `closure` argument here.
///
/// This function is parameterized over:
///
/// * `T` - the type of store this function works with (an unsafe assertion)
/// * `Params` - the parameters to the host function, viewed as a tuple
/// * `Return` - the result of the host function
/// * `F` - the `closure` to actually receive the `Params` and return the
/// `Return`
///
/// It's expected that `F` will "un-tuple" the arguments to pass to a host
/// closure.
///
/// This function is in general `unsafe` as the validity of all the parameters
/// must be upheld. Generally that's done by ensuring this is only called from
/// the select few places it's intended to be called from.
///
/// # Errors
///
/// Returns an error if the instance's `may_leave` flag is unset, if lifting
/// the parameters or lowering/storing the result fails, if a parameter or
/// return pointer fails `validate_inbounds`, or if `closure` itself returns an
/// error.
unsafe fn call_host<T, Params, Return, F>(
cx: *mut VMOpaqueContext,
memory: *mut VMMemoryDefinition,
realloc: *mut VMCallerCheckedAnyfunc,
string_encoding: StringEncoding,
storage: &mut [ValRaw],
closure: F,
) -> Result<()>
where
Params: ComponentValue,
Return: ComponentValue,
F: FnOnce(StoreContextMut<'_, T>, Params) -> Result<Return>,
{
/// Representation of arguments to this function when a return pointer is in
/// use, namely the argument list is followed by a single value which is the
/// return pointer.
#[repr(C)]
struct ReturnPointer<T> {
args: T,
retptr: ValRaw,
}
/// Representation of arguments to this function when the return value is
/// returned directly, namely the arguments and return value all start from
/// the beginning (aka this is a `union`, not a `struct`).
#[repr(C)]
union ReturnStack<T: Copy, U: Copy> {
args: T,
ret: U,
}
// Recover the component instance from the opaque context, then fetch its
// `may_leave` / `may_enter` flag pointers and the owning store.
let cx = VMComponentContext::from_opaque(cx);
let instance = (*cx).instance();
let may_leave = (*instance).may_leave();
let may_enter = (*instance).may_enter();
let mut cx = StoreContextMut::from_raw((*instance).store());
let options = Options::new(
cx.0.id(),
NonNull::new(memory),
NonNull::new(realloc),
string_encoding,
);
// Perform a dynamic check that this instance can indeed be left. Exiting
// the component is disallowed, for example, when the `realloc` function
// calls a canonical import.
if !*may_leave {
bail!("cannot leave component instance");
}
// While we're lifting and lowering this instance cannot be reentered, so
// unset the flag here. This is also reset back to `true` on exit.
let _reset_may_enter = unset_and_reset_on_drop(may_enter);
// There's a 2x2 matrix of whether parameters and results are stored on the
// stack or on the heap. Each of the 4 branches here have a different
// representation of the storage of arguments/returns which is represented
// by the type parameter that we pass to `cast_storage`.
//
// Also note that while four branches are listed here only one is taken for
// any particular `Params` and `Return` combination. This should be
// trivially DCE'd by LLVM. Perhaps one day with enough const programming in
// Rust we can make monomorphizations of this function codegen only one
// branch, but today is not that day.
//
// In every branch the host closure runs while `may_leave` is still set;
// the flag is only cleared (via `reset_may_leave`) right before the result
// is lowered or stored, and is restored when that guard is dropped below.
let reset_may_leave;
if Params::flatten_count() <= MAX_STACK_PARAMS {
if Return::flatten_count() <= MAX_STACK_RESULTS {
// Params passed directly in `storage`, result written directly back.
let storage = cast_storage::<ReturnStack<Params::Lower, Return::Lower>>(storage);
let params = Params::lift(cx.0, &options, &storage.assume_init_ref().args)?;
let ret = closure(cx.as_context_mut(), params)?;
reset_may_leave = unset_and_reset_on_drop(may_leave);
ret.lower(&mut cx, &options, map_maybe_uninit!(storage.ret))?;
} else {
// Params passed directly in `storage`, result stored to memory at the
// trailing return pointer.
let storage = cast_storage::<ReturnPointer<Params::Lower>>(storage).assume_init_ref();
let params = Params::lift(cx.0, &options, &storage.args)?;
let ret = closure(cx.as_context_mut(), params)?;
let mut memory = MemoryMut::new(cx.as_context_mut(), &options);
let ptr = validate_inbounds::<Return>(memory.as_slice_mut(), &storage.retptr)?;
reset_may_leave = unset_and_reset_on_drop(may_leave);
ret.store(&mut memory, ptr)?;
}
} else {
let memory = Memory::new(cx.0, &options);
if Return::flatten_count() <= MAX_STACK_RESULTS {
// Params loaded from memory via a single pointer, result written
// directly back into `storage`.
let storage = cast_storage::<ReturnStack<ValRaw, Return::Lower>>(storage);
let ptr =
validate_inbounds::<Params>(memory.as_slice(), &storage.assume_init_ref().args)?;
let params = Params::load(&memory, &memory.as_slice()[ptr..][..Params::size()])?;
let ret = closure(cx.as_context_mut(), params)?;
reset_may_leave = unset_and_reset_on_drop(may_leave);
ret.lower(&mut cx, &options, map_maybe_uninit!(storage.ret))?;
} else {
// Params loaded from memory via a single pointer, result stored to
// memory at the trailing return pointer.
let storage = cast_storage::<ReturnPointer<ValRaw>>(storage).assume_init_ref();
let ptr = validate_inbounds::<Params>(memory.as_slice(), &storage.args)?;
let params = Params::load(&memory, &memory.as_slice()[ptr..][..Params::size()])?;
let ret = closure(cx.as_context_mut(), params)?;
let mut memory = MemoryMut::new(cx.as_context_mut(), &options);
let ptr = validate_inbounds::<Return>(memory.as_slice_mut(), &storage.retptr)?;
reset_may_leave = unset_and_reset_on_drop(may_leave);
ret.store(&mut memory, ptr)?;
}
}
// TODO: need to call `post-return` before this `drop`
drop(reset_may_leave);
return Ok(());
/// Sets `*slot` to `false` and returns a guard whose `Drop` implementation
/// sets it back to `true`.
///
/// In debug builds this asserts that the flag was `true` on entry.
unsafe fn unset_and_reset_on_drop(slot: *mut bool) -> impl Drop {
debug_assert!(*slot);
*slot = false;
return Reset(slot);
struct Reset(*mut bool);
impl Drop for Reset {
fn drop(&mut self) {
unsafe {
(*self.0) = true;
}
}
}
}
}
/// Interprets `ptr` as a 32-bit offset into `memory` and checks that a value
/// of `T::size()` bytes starting there lies entirely within `memory`.
///
/// Returns the offset as a `usize` on success. Fails if the offset does not
/// fit in `usize`, if adding `T::size()` overflows, or if the end of the range
/// is past `memory.len()`.
///
/// Note: `call_host` uses this for parameter pointers as well as return
/// pointers, even though the error messages always say "return pointer".
fn validate_inbounds<T: ComponentValue>(memory: &[u8], ptr: &ValRaw) -> Result<usize> {
    // FIXME: needs memory64 support
    let start = usize::try_from(ptr.get_u32())?;
    match start.checked_add(T::size()) {
        None => bail!("return pointer size overflow"),
        Some(end) if end > memory.len() => bail!("return pointer out of bounds"),
        Some(_) => Ok(start),
    }
}
/// Reinterprets a slice of `ValRaw` as a (possibly uninitialized) `T`.
///
/// Panics if `T`'s size is not a multiple of `ValRaw`'s size, if the
/// alignments of `T`, `ValRaw` and `storage` differ, or if `storage` is
/// smaller than `T`.
///
/// This is `unsafe` because the caller chooses `T`; the checks below only
/// cover size and alignment, not whether the bytes form a valid `T`.
unsafe fn cast_storage<T>(storage: &mut [ValRaw]) -> &mut MaybeUninit<T> {
    // Layout sanity checks; these depend only on types and so should be easy
    // for LLVM to optimize away.
    assert!(std::mem::size_of::<T>() % std::mem::size_of::<ValRaw>() == 0);
    assert!(std::mem::align_of::<T>() == std::mem::align_of::<ValRaw>());
    assert!(std::mem::align_of_val(storage) == std::mem::align_of::<T>());

    // A genuine runtime check, which may need relaxing to a debug assertion if
    // performance calls for it. It guards against reading past the end of the
    // values actually provided, and should only trip on a bug in Wasmtime.
    assert!(std::mem::size_of_val(storage) >= std::mem::size_of::<T>());

    let target: *mut MaybeUninit<T> = storage.as_mut_ptr().cast();
    &mut *target
}
unsafe fn handle_result(func: impl FnOnce() -> Result<()>) {
match panic::catch_unwind(AssertUnwindSafe(func)) {
Ok(Ok(())) => {}
Ok(Err(e)) => wasmtime_runtime::raise_user_trap(e),
Err(e) => wasmtime_runtime::resume_panic(e),
}
}
// Generates the `IntoComponentFunc` implementations for host closures taking
// the argument types `$args...`, once without and once with a leading
// `StoreContextMut<'_, T>` parameter.
//
// The macro is applied via `for_each_function_signature!` at the bottom of
// this block. The `$num` token is accepted but not used by this expansion.
macro_rules! impl_into_component_func {
($num:tt $($args:ident)*) => {
// Implement for functions without a leading `StoreContextMut` parameter
#[allow(non_snake_case)]
impl<T, F, $($args,)* R> IntoComponentFunc<T, ($($args,)*), R> for F
where
F: Fn($($args),*) -> Result<R> + Send + Sync + 'static,
($($args,)*): ComponentParams + ComponentValue,
R: ComponentValue,
{
extern "C" fn entrypoint(
cx: *mut VMOpaqueContext,
data: *mut u8,
memory: *mut VMMemoryDefinition,
realloc: *mut VMCallerCheckedAnyfunc,
string_encoding: StringEncoding,
storage: *mut ValRaw,
storage_len: usize,
) {
// `data` is treated as a pointer to the host closure itself
// (presumably the pointer produced by `HostFunc::lowering`).
let data = data as *const Self;
unsafe {
// The store argument is ignored and the lifted tuple is
// un-tupled into the closure's positional arguments.
handle_result(|| call_host::<T, _, _, _>(
cx,
memory,
realloc,
string_encoding,
std::slice::from_raw_parts_mut(storage, storage_len),
|_, ($($args,)*)| (*data)($($args),*),
))
}
}
fn into_host_func(self) -> Arc<HostFunc> {
// Name the impl explicitly so the matching `entrypoint` is selected.
let entrypoint = <Self as IntoComponentFunc<T, ($($args,)*), R>>::entrypoint;
HostFunc::new::<_, ($($args,)*), R>(self, entrypoint)
}
}
// Implement for functions with a leading `StoreContextMut` parameter
#[allow(non_snake_case)]
impl<T, F, $($args,)* R> IntoComponentFunc<T, (StoreContextMut<'_, T>, $($args,)*), R> for F
where
F: Fn(StoreContextMut<'_, T>, $($args),*) -> Result<R> + Send + Sync + 'static,
($($args,)*): ComponentParams + ComponentValue,
R: ComponentValue,
{
extern "C" fn entrypoint(
cx: *mut VMOpaqueContext,
data: *mut u8,
memory: *mut VMMemoryDefinition,
realloc: *mut VMCallerCheckedAnyfunc,
string_encoding: StringEncoding,
storage: *mut ValRaw,
storage_len: usize,
) {
// `data` is treated as a pointer to the host closure itself
// (presumably the pointer produced by `HostFunc::lowering`).
let data = data as *const Self;
unsafe {
// Same as above, except the store context is forwarded as the
// closure's first argument.
handle_result(|| call_host::<T, _, _, _>(
cx,
memory,
realloc,
string_encoding,
std::slice::from_raw_parts_mut(storage, storage_len),
|store, ($($args,)*)| (*data)(store, $($args),*),
))
}
}
fn into_host_func(self) -> Arc<HostFunc> {
let entrypoint = <Self as IntoComponentFunc<T, (StoreContextMut<'_, T>, $($args,)*), R>>::entrypoint;
// The wasm-visible parameter tuple excludes the `StoreContextMut`, so
// type checking is set up over `($($args,)*)` only.
HostFunc::new::<_, ($($args,)*), R>(self, entrypoint)
}
}
}
}
for_each_function_signature!(impl_into_component_func);

View File

@@ -274,6 +274,7 @@ where
trampoline, trampoline,
export, export,
options, options,
instance,
.. ..
} = store.0[self.func.0]; } = store.0[self.func.0];
@@ -294,9 +295,21 @@ where
assert!(mem::align_of_val(map_maybe_uninit!(space.params)) == val_align); assert!(mem::align_of_val(map_maybe_uninit!(space.params)) == val_align);
assert!(mem::align_of_val(map_maybe_uninit!(space.ret)) == val_align); assert!(mem::align_of_val(map_maybe_uninit!(space.ret)) == val_align);
lower(store, &options, params, map_maybe_uninit!(space.params))?; let instance = store.0[instance.0].as_ref().unwrap().instance();
let may_enter = instance.may_enter();
let may_leave = instance.may_leave();
unsafe { unsafe {
if !*may_enter {
bail!("cannot reenter component instance");
}
debug_assert!(*may_leave);
*may_leave = false;
let result = lower(store, &options, params, map_maybe_uninit!(space.params));
*may_leave = true;
result?;
// This is unsafe as we are providing the guarantee that all the // This is unsafe as we are providing the guarantee that all the
// inputs are valid. The various pointers passed in for the function // inputs are valid. The various pointers passed in for the function
// are all valid since they're coming from our store, and the // are all valid since they're coming from our store, and the
@@ -319,11 +332,17 @@ where
// `[ValRaw]`, and additionally they should have the correct types // `[ValRaw]`, and additionally they should have the correct types
// for the function we just called (which filled in the return // for the function we just called (which filled in the return
// values). // values).
lift( *may_enter = false;
let result = lift(
store.0, store.0,
&options, &options,
map_maybe_uninit!(space.ret).assume_init_ref(), map_maybe_uninit!(space.ret).assume_init_ref(),
) );
// TODO: this technically needs to happen only after the
// `post-return` is called.
*may_enter = true;
return result;
} }
} }
} }

View File

@@ -1,3 +1,4 @@
use crate::component::func::HostFunc;
use crate::component::{Component, ComponentParams, ComponentValue, Func, TypedFunc}; use crate::component::{Component, ComponentParams, ComponentValue, Func, TypedFunc};
use crate::instance::OwnedImports; use crate::instance::OwnedImports;
use crate::store::{StoreOpaque, Stored}; use crate::store::{StoreOpaque, Stored};
@@ -21,17 +22,24 @@ use wasmtime_runtime::component::{ComponentInstance, OwnedComponentInstance};
// //
// FIXME: need to write more docs here. // FIXME: need to write more docs here.
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct Instance(Stored<Option<Box<InstanceData>>>); pub struct Instance(pub(crate) Stored<Option<Box<InstanceData>>>);
pub(crate) struct InstanceData { pub(crate) struct InstanceData {
instances: PrimaryMap<RuntimeInstanceIndex, crate::Instance>, instances: PrimaryMap<RuntimeInstanceIndex, crate::Instance>,
// FIXME: shouldn't store the entire component here which keeps upvars // FIXME: shouldn't store the entire component here which keeps upvars
// alive and things like that, instead only the bare minimum necessary // alive and things like that, instead only the bare minimum necessary
// should be kept alive here (mostly just `wasmtime_environ::Component`. // should be kept alive here (mostly just `wasmtime_environ::Component`).
component: Component, component: Component,
exported_modules: PrimaryMap<RuntimeModuleIndex, Module>, exported_modules: PrimaryMap<RuntimeModuleIndex, Module>,
state: OwnedComponentInstance, state: OwnedComponentInstance,
/// Functions that this instance used during instantiation.
///
/// Strong references are stored to these functions since pointers are saved
/// into the functions within the `OwnedComponentInstance` but it's our job
/// to keep them alive.
funcs: Vec<Arc<HostFunc>>,
} }
impl Instance { impl Instance {
@@ -56,7 +64,7 @@ impl Instance {
// By moving it out we appease the borrow-checker but take a runtime // By moving it out we appease the borrow-checker but take a runtime
// hit. // hit.
let data = store[self.0].take().unwrap(); let data = store[self.0].take().unwrap();
let result = data.get_func(store, name); let result = data.get_func(store, self, name);
store[self.0] = Some(data); store[self.0] = Some(data);
return result; return result;
} }
@@ -126,18 +134,22 @@ impl Instance {
} }
impl InstanceData { impl InstanceData {
fn get_func(&self, store: &mut StoreOpaque, name: &str) -> Option<Func> { fn get_func(&self, store: &mut StoreOpaque, instance: &Instance, name: &str) -> Option<Func> {
match self.component.env_component().exports.get(name)? { match self.component.env_component().exports.get(name)? {
Export::LiftedFunction { ty, func, options } => { Export::LiftedFunction { ty, func, options } => Some(Func::from_lifted_func(
Some(Func::from_lifted_func(store, self, *ty, func, options)) store, instance, self, *ty, func, options,
} )),
Export::Module(_) => None, Export::Module(_) => None,
} }
} }
fn lookup_def(&self, store: &mut StoreOpaque, item: &CoreDef) -> wasmtime_runtime::Export { fn lookup_def(&self, store: &mut StoreOpaque, def: &CoreDef) -> wasmtime_runtime::Export {
match item { match def {
CoreDef::Lowered(_) => unimplemented!(), CoreDef::Lowered(idx) => {
wasmtime_runtime::Export::Function(wasmtime_runtime::ExportFunction {
anyfunc: self.state.lowering_anyfunc(*idx),
})
}
CoreDef::Export(e) => self.lookup_export(store, e), CoreDef::Export(e) => self.lookup_export(store, e),
} }
} }
@@ -177,6 +189,7 @@ struct Instantiator<'a> {
} }
pub enum RuntimeImport { pub enum RuntimeImport {
Func(Arc<HostFunc>),
Module(Module), Module(Module),
} }
@@ -198,6 +211,7 @@ impl<'a> Instantiator<'a> {
env_component.num_runtime_modules as usize, env_component.num_runtime_modules as usize,
), ),
state: OwnedComponentInstance::new(env_component, store.traitobj()), state: OwnedComponentInstance::new(env_component, store.traitobj()),
funcs: Vec::new(),
}, },
} }
} }
@@ -215,6 +229,7 @@ impl<'a> Instantiator<'a> {
module = self.component.upvar(*idx); module = self.component.upvar(*idx);
self.build_imports(store.0, module, args.iter()) self.build_imports(store.0, module, args.iter())
} }
// With imports, unlike upvars, we need to do runtime // With imports, unlike upvars, we need to do runtime
// lookups with strings to determine the order of the // lookups with strings to determine the order of the
// imports since it's whatever the actual module // imports since it's whatever the actual module
@@ -222,6 +237,7 @@ impl<'a> Instantiator<'a> {
InstantiateModule::Import(idx, args) => { InstantiateModule::Import(idx, args) => {
module = match &self.imports[*idx] { module = match &self.imports[*idx] {
RuntimeImport::Module(m) => m, RuntimeImport::Module(m) => m,
_ => unreachable!(),
}; };
let args = module let args = module
.imports() .imports()
@@ -238,6 +254,7 @@ impl<'a> Instantiator<'a> {
unsafe { crate::Instance::new_started(store, module, imports.as_ref())? }; unsafe { crate::Instance::new_started(store, module, imports.as_ref())? };
self.data.instances.push(i); self.data.instances.push(i);
} }
Initializer::LowerImport(import) => self.lower_import(import), Initializer::LowerImport(import) => self.lower_import(import),
Initializer::ExtractMemory { index, export } => { Initializer::ExtractMemory { index, export } => {
@@ -257,6 +274,7 @@ impl<'a> Instantiator<'a> {
Initializer::SaveModuleImport(idx) => { Initializer::SaveModuleImport(idx) => {
self.data.exported_modules.push(match &self.imports[*idx] { self.data.exported_modules.push(match &self.imports[*idx] {
RuntimeImport::Module(m) => m.clone(), RuntimeImport::Module(m) => m.clone(),
_ => unreachable!(),
}); });
} }
} }
@@ -265,14 +283,28 @@ impl<'a> Instantiator<'a> {
} }
fn lower_import(&mut self, import: &LowerImport) { fn lower_import(&mut self, import: &LowerImport) {
drop(self.component.trampoline_ptr(import.index)); let func = match &self.imports[import.import] {
drop( RuntimeImport::Func(func) => func,
_ => unreachable!(),
};
self.data.state.set_lowering(
import.index,
func.lowering(),
self.component.trampoline_ptr(import.index),
self.component self.component
.signatures() .signatures()
.shared_signature(import.canonical_abi) .shared_signature(import.canonical_abi)
.unwrap(), .expect("found unregistered signature"),
); );
unimplemented!()
// The `func` provided here must be retained within the `Store` itself
// after instantiation. Otherwise it might be possible to drop the
// `Arc<HostFunc>` and possibly result in a use-after-free. This comes
// about because the `.lowering()` method returns a structure that
// points to an interior pointer within the `func`. By saving the list
// of host functions used we can ensure that the function lives long
// enough for the whole duration of this instance.
self.data.funcs.push(func.clone());
} }
fn extract_memory( fn extract_memory(

View File

@@ -1,6 +1,7 @@
use crate::component::func::HostFunc;
use crate::component::instance::RuntimeImport; use crate::component::instance::RuntimeImport;
use crate::component::matching::TypeChecker; use crate::component::matching::TypeChecker;
use crate::component::{Component, Instance, InstancePre}; use crate::component::{Component, Instance, InstancePre, IntoComponentFunc};
use crate::{AsContextMut, Engine, Module}; use crate::{AsContextMut, Engine, Module};
use anyhow::{anyhow, bail, Context, Result}; use anyhow::{anyhow, bail, Context, Result};
use std::collections::hash_map::{Entry, HashMap}; use std::collections::hash_map::{Entry, HashMap};
@@ -45,6 +46,7 @@ pub type NameMap = HashMap<usize, Definition>;
#[derive(Clone)] #[derive(Clone)]
pub enum Definition { pub enum Definition {
Instance(NameMap), Instance(NameMap),
Func(Arc<HostFunc>),
Module(Module), Module(Module),
} }
@@ -155,6 +157,7 @@ impl<T> Linker<T> {
} }
let import = match cur { let import = match cur {
Definition::Module(m) => RuntimeImport::Module(m.clone()), Definition::Module(m) => RuntimeImport::Module(m.clone()),
Definition::Func(f) => RuntimeImport::Func(f.clone()),
// This is guaranteed by the compilation process that "leaf" // This is guaranteed by the compilation process that "leaf"
// runtime imports are never instances. // runtime imports are never instances.
@@ -197,6 +200,36 @@ impl<T> LinkerInstance<'_, T> {
} }
} }
/// Defines a new host-provided function into this [`Linker`].
///
/// This method is used to give host functions to wasm components. The
/// `func` provided will be callable from linked components with the type
/// signature dictated by `Params` and `Return`. The `Params` is a tuple of
/// types that will come from wasm and `Return` is a value coming from the
/// host going back to wasm.
///
/// The [`IntoComponentFunc`] trait is implemented for functions whose
/// arguments and return values implement the
/// [`ComponentValue`](crate::component::ComponentValue) trait. Additionally
/// the `func` may take a [`StoreContextMut`](crate::StoreContextMut) as its
/// first parameter.
///
/// Note that `func` must be an `Fn` and must also be `Send + Sync +
/// 'static`. Shared state within a func is typically accessed with the `T`
/// type parameter from [`Store<T>`](crate::Store) which is accessible
/// through the leading [`StoreContextMut<'_, T>`](crate::StoreContextMut)
/// argument which can be provided to the `func` given here.
///
/// # Errors
///
/// Returns whatever error the internal `insert` of the new definition
/// reports.
//
// TODO: needs more words and examples
pub fn func_wrap<Params, Return>(
&mut self,
name: &str,
func: impl IntoComponentFunc<T, Params, Return>,
) -> Result<()> {
// Intern the name, then register the type-erased host function under it.
let name = self.strings.intern(name);
self.insert(name, Definition::Func(func.into_host_func()))
}
/// Defines a [`Module`] within this instance. /// Defines a [`Module`] within this instance.
/// ///
/// This can be used to provide a core wasm [`Module`] as an import to a /// This can be used to provide a core wasm [`Module`] as an import to a

View File

@@ -1,8 +1,11 @@
use crate::component::func::HostFunc;
use crate::component::linker::{Definition, NameMap, Strings}; use crate::component::linker::{Definition, NameMap, Strings};
use crate::types::matching; use crate::types::matching;
use crate::Module; use crate::Module;
use anyhow::{anyhow, bail, Context, Result}; use anyhow::{anyhow, bail, Context, Result};
use wasmtime_environ::component::{ComponentInstanceType, ComponentTypes, ModuleType, TypeDef}; use wasmtime_environ::component::{
ComponentInstanceType, ComponentTypes, FuncTypeIndex, ModuleType, TypeDef,
};
pub struct TypeChecker<'a> { pub struct TypeChecker<'a> {
pub types: &'a ComponentTypes, pub types: &'a ComponentTypes,
@@ -20,7 +23,10 @@ impl TypeChecker<'_> {
Definition::Instance(actual) => self.instance(&self.types[t], actual), Definition::Instance(actual) => self.instance(&self.types[t], actual),
_ => bail!("expected instance found {}", actual.desc()), _ => bail!("expected instance found {}", actual.desc()),
}, },
TypeDef::Func(_) => bail!("expected func found {}", actual.desc()), TypeDef::Func(t) => match actual {
Definition::Func(actual) => self.func(t, actual),
_ => bail!("expected func found {}", actual.desc()),
},
TypeDef::Component(_) => bail!("expected component found {}", actual.desc()), TypeDef::Component(_) => bail!("expected component found {}", actual.desc()),
TypeDef::Interface(_) => bail!("expected type found {}", actual.desc()), TypeDef::Interface(_) => bail!("expected type found {}", actual.desc()),
} }
@@ -72,12 +78,17 @@ impl TypeChecker<'_> {
} }
Ok(()) Ok(())
} }
/// Checks a host-provided function against the function type the component
/// expects, by delegating to `HostFunc::typecheck` with this checker's types.
fn func(&self, expected: FuncTypeIndex, actual: &HostFunc) -> Result<()> {
    let types = self.types;
    actual.typecheck(expected, types)
}
} }
impl Definition { impl Definition {
fn desc(&self) -> &'static str { fn desc(&self) -> &'static str {
match self { match self {
Definition::Module(_) => "module", Definition::Module(_) => "module",
Definition::Func(_) => "func",
Definition::Instance(_) => "instance", Definition::Instance(_) => "instance",
} }
} }

View File

@@ -10,7 +10,9 @@ mod linker;
mod matching; mod matching;
mod store; mod store;
pub use self::component::Component; pub use self::component::Component;
pub use self::func::{ComponentParams, ComponentValue, Func, Op, TypedFunc, WasmList, WasmStr}; pub use self::func::{
ComponentParams, ComponentValue, Func, IntoComponentFunc, Op, TypedFunc, WasmList, WasmStr,
};
pub use self::instance::{Instance, InstancePre}; pub use self::instance::{Instance, InstancePre};
pub use self::linker::Linker; pub use self::linker::Linker;
@@ -18,6 +20,9 @@ pub use self::linker::Linker;
// `#[derive(ComponentValue)]`, they are not part of Wasmtime's API stability // `#[derive(ComponentValue)]`, they are not part of Wasmtime's API stability
// guarantees // guarantees
#[doc(hidden)] #[doc(hidden)]
pub use {self::func::Memory, wasmtime_environ}; pub use {
self::func::{Memory, MemoryMut, Options},
wasmtime_environ,
};
pub(crate) use self::store::ComponentStoreData; pub(crate) use self::store::ComponentStoreData;

View File

@@ -1,5 +1,6 @@
use anyhow::Result; use anyhow::Result;
use wasmtime::component::Component; use wasmtime::component::*;
use wasmtime::{Store, StoreContextMut, Trap};
#[test] #[test]
fn can_compile() -> Result<()> { fn can_compile() -> Result<()> {
@@ -78,3 +79,256 @@ fn can_compile() -> Result<()> {
)?; )?;
Ok(()) Ok(())
} }
// Smoke test for host imports: a component imports a host function taking a
// string, lowers it into a core module, and that module calls it with
// "hello world". The host stores the received string in the store's data.
#[test]
fn simple() -> Result<()> {
let component = r#"
(component
(import "" (func $log (param string)))
(module $libc
(memory (export "memory") 1)
(func (export "canonical_abi_realloc") (param i32 i32 i32 i32) (result i32)
unreachable)
(func (export "canonical_abi_free") (param i32 i32 i32)
unreachable)
)
(instance $libc (instantiate (module $libc)))
(func $log_lower
(canon.lower (into $libc) (func $log))
)
(module $m
(import "libc" "memory" (memory 1))
(import "host" "log" (func $log (param i32 i32)))
(func (export "call")
i32.const 5
i32.const 11
call $log)
(data (i32.const 5) "hello world")
)
(instance $i (instantiate (module $m)
(with "libc" (instance $libc))
(with "host" (instance (export "log" (func $log_lower))))
))
(func (export "call")
(canon.lift (func) (func $i "call"))
)
)
"#;
let engine = super::engine();
let mut linker = Linker::new(&engine);
// Host import: copy the string argument out of wasm and stash it in the
// store data, asserting it has not been set before.
linker.root().func_wrap(
"",
|mut store: StoreContextMut<'_, Option<String>>, arg: WasmStr| -> Result<_> {
let s = arg.to_str(&store)?.to_string();
assert!(store.data().is_none());
*store.data_mut() = Some(s);
Ok(())
},
)?;
let component = Component::new(&engine, component)?;
let mut store = Store::new(&engine, None);
let instance = linker.instantiate(&mut store, &component)?;
// Instantiation alone must not have invoked the host function.
assert!(store.data().is_none());
instance
.get_typed_func::<(), (), _>(&mut store, "call")?
.call(&mut store, ())?;
// After the call the host has observed the string passed from wasm.
assert_eq!(store.data().as_ref().unwrap(), "hello world");
Ok(())
}
// Checks that a component instance cannot be exited while the runtime is in
// the middle of calling that instance's `canonical_abi_realloc`, both when a
// host import returns a value to wasm and when the host passes an argument
// into a lifted export.
#[test]
fn attempt_to_leave_during_malloc() -> Result<()> {
    // `$m`'s realloc immediately calls the imported `thunk`, so every
    // allocation request made of `$m` becomes an attempt to call back out of
    // it. `$host_shim` only forwards through a table, which is populated with
    // the lowered host functions after `$m` has been instantiated.
    let component = r#"
(component
(import "thunk" (func $thunk))
(import "ret-string" (func $ret_string (result string)))
(module $host_shim
(table (export "table") 2 funcref)
(func $shim_thunk (export "thunk")
i32.const 0
call_indirect)
(func $shim_ret_string (export "ret-string") (param i32)
local.get 0
i32.const 1
call_indirect (param i32))
)
(instance $host_shim (instantiate (module $host_shim)))
(module $m
(import "host" "thunk" (func $thunk))
(import "host" "ret-string" (func $ret_string (param i32)))
(memory (export "memory") 1)
(func $realloc (export "canonical_abi_realloc") (param i32 i32 i32 i32) (result i32)
call $thunk
unreachable)
(func (export "canonical_abi_free") (param i32 i32 i32)
unreachable)
(func $run (export "run")
i32.const 8
call $ret_string)
(func (export "take-string") (param i32 i32)
unreachable)
)
(instance $m (instantiate (module $m) (with "host" (instance $host_shim))))
(module $host_shim_filler_inner
(import "shim" "table" (table 2 funcref))
(import "host" "thunk" (func $thunk))
(import "host" "ret-string" (func $ret_string (param i32)))
(elem (i32.const 0) $thunk $ret_string)
)
(func $thunk_lower
(canon.lower (into $m) (func $thunk))
)
(func $ret_string_lower
(canon.lower (into $m) (func $ret_string))
)
(instance (instantiate (module $host_shim_filler_inner)
(with "shim" (instance $host_shim))
(with "host" (instance
(export "thunk" (func $thunk_lower))
(export "ret-string" (func $ret_string_lower))
))
))
(func (export "run")
(canon.lift (func) (func $m "run"))
)
(func (export "take-string")
(canon.lift (func (param string)) (into $m) (func $m "take-string"))
)
)
"#;

    let engine = super::engine();
    let mut linker = Linker::new(&engine);
    // The trap is expected to fire before this host function is ever entered.
    linker
        .root()
        .func_wrap("thunk", || -> Result<()> { panic!("should not get here") })?;
    // Returning a string forces the runtime to ask `$m` for memory.
    linker
        .root()
        .func_wrap("ret-string", || -> Result<String> {
            Ok(String::from("hello"))
        })?;

    let component = Component::new(&engine, component)?;
    let mut store = Store::new(&engine, ());
    let instance = linker.instantiate(&mut store, &component)?;

    // Assert that during a host import if we return values to wasm that a trap
    // happens if we try to leave the instance.
    let run = instance.get_typed_func::<(), (), _>(&mut store, "run")?;
    let leave_on_return = run
        .call(&mut store, ())
        .unwrap_err()
        .downcast::<Trap>()?;
    assert!(
        leave_on_return
            .to_string()
            .contains("cannot leave component instance"),
        "bad trap: {}",
        leave_on_return,
    );

    let frames = leave_on_return.trace().unwrap();
    assert_eq!(frames.len(), 4);
    // Expected `(module, function)` pairs, listed from the entry point down to
    // the frame that trapped; the recorded trace stores them in the opposite
    // order, hence the `rev()` below.
    let expected_outermost_first = [
        // This was our entry point...
        ("m", "run"),
        // ... which called an imported function which ends up being originally
        // defined by the shim instance. The shim instance then does an indirect
        // call through a table which goes to the `canon.lower`'d host function
        ("host_shim", "shim_ret_string"),
        // ... and the lowered host function will call realloc to allocate
        // space for the result
        ("m", "realloc"),
        // ... but realloc calls the shim instance and tries to exit the
        // component, triggering a dynamic trap
        ("host_shim", "shim_thunk"),
    ];
    for (frame, (module, func)) in frames.iter().rev().zip(expected_outermost_first) {
        assert_eq!(frame.module_name(), Some(module));
        assert_eq!(frame.func_name(), Some(func));
    }

    // In addition to the above trap also ensure that when we enter a wasm
    // component if we try to leave while lowering then that's also a dynamic
    // trap.
    let take_string = instance.get_typed_func::<(&str,), (), _>(&mut store, "take-string")?;
    let leave_on_lower = take_string
        .call(&mut store, ("x",))
        .unwrap_err()
        .downcast::<Trap>()?;
    assert!(
        leave_on_lower
            .to_string()
            .contains("cannot leave component instance"),
        "bad trap: {}",
        leave_on_lower,
    );

    Ok(())
}
// Checks that a host import cannot call back into the component instance that
// is currently calling it: the nested call must fail with a "cannot reenter"
// trap while the outer call still completes successfully.
#[test]
fn attempt_to_reenter_during_host() -> Result<()> {
    // The component's only export, `run`, does nothing but call the lowered
    // host import `thunk`.
    let component = r#"
(component
(import "thunk" (func $thunk))
(func $thunk_lower (canon.lower (func $thunk)))
(module $m
(import "host" "thunk" (func $thunk))
(func $run (export "run")
call $thunk)
)
(instance $m (instantiate (module $m)
(with "host" (instance (export "thunk" (func $thunk_lower))))
))
(func (export "run")
(canon.lift (func) (func $m "run"))
)
)
"#;

    // Store data used to hand the `run` export to the host import. The import
    // `take()`s it, so a `None` after the call proves the import executed.
    struct State {
        func: Option<TypedFunc<(), ()>>,
    }

    let engine = super::engine();
    let mut linker = Linker::new(&engine);
    linker.root().func_wrap(
        "thunk",
        |mut store: StoreContextMut<'_, State>| -> Result<()> {
            let func = store
                .data_mut()
                .func
                .take()
                .expect("`run` must be stashed in the store before it is called");
            // Calling `run` again while it is still on the stack must be
            // rejected rather than recursing into the instance.
            let trap = func.call(&mut store, ()).unwrap_err();
            assert!(
                trap.to_string()
                    .contains("cannot reenter component instance"),
                "bad trap: {}",
                trap,
            );
            Ok(())
        },
    )?;

    let component = Component::new(&engine, component)?;
    let mut store = Store::new(&engine, State { func: None });
    let instance = linker.instantiate(&mut store, &component)?;
    let func = instance.get_typed_func::<(), (), _>(&mut store, "run")?;
    store.data_mut().func = Some(func);
    func.call(&mut store, ())?;

    // Guard against a vacuous pass: the reentrance assertion lives inside the
    // host import, so make sure that import was really invoked.
    assert!(
        store.data().func.is_none(),
        "host import `thunk` was never invoked, reentrance was not exercised",
    );
    Ok(())
}