wasmtime_wiggle: support for async, and add an integration test

This commit is contained in:
Pat Hickey
2021-03-04 17:27:44 -08:00
parent c4d8e2323a
commit ff59797ad0
6 changed files with 331 additions and 34 deletions

View File

@@ -17,10 +17,22 @@ witx = { version = "0.9", path = "../../wasi-common/WASI/tools/witx", optional =
wiggle = { path = "..", version = "0.23.0" } wiggle = { path = "..", version = "0.23.0" }
wiggle-borrow = { path = "../borrow", version = "0.23.0" } wiggle-borrow = { path = "../borrow", version = "0.23.0" }
[dev-dependencies]
anyhow = "1"
proptest = "0.10"
[[test]]
name = "atoms_async"
path = "tests/atoms_async.rs"
required-features = ["async", "wasmtime/wat"]
[badges] [badges]
maintenance = { status = "actively-developed" } maintenance = { status = "actively-developed" }
[features] [features]
# Async support for wasmtime
async = [ 'wasmtime/async', 'wasmtime-wiggle-macro/async' ]
# The wiggle proc-macro emits some code (inside `pub mod metadata`) guarded # The wiggle proc-macro emits some code (inside `pub mod metadata`) guarded
# by the `wiggle_metadata` feature flag. We use this feature flag so that # by the `wiggle_metadata` feature flag. We use this feature flag so that
# users of wiggle are not forced to take a direct dependency on the `witx` # users of wiggle are not forced to take a direct dependency on the `witx`
@@ -33,4 +45,4 @@ wiggle_metadata = ['witx', "wiggle/wiggle_metadata"]
# the logs out of wiggle-generated libraries. # the logs out of wiggle-generated libraries.
tracing_log = [ "wiggle/tracing_log" ] tracing_log = [ "wiggle/tracing_log" ]
default = ["wiggle_metadata" ] default = ["wiggle_metadata", "async"]

View File

@@ -24,3 +24,6 @@ proc-macro2 = "1.0"
[badges] [badges]
maintenance = { status = "actively-developed" } maintenance = { status = "actively-developed" }
[features]
async = []
default = []

View File

@@ -1,3 +1,4 @@
pub use wiggle_generate::config::AsyncConf;
use { use {
proc_macro2::Span, proc_macro2::Span,
std::collections::HashMap, std::collections::HashMap,
@@ -7,15 +8,16 @@ use {
punctuated::Punctuated, punctuated::Punctuated,
Error, Ident, Path, Result, Token, Error, Ident, Path, Result, Token,
}, },
wiggle_generate::config::{CtxConf, WitxConf}, wiggle_generate::config::WitxConf,
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Config { pub struct Config {
pub target: TargetConf, pub target: TargetConf,
pub witx: WitxConf, pub witx: WitxConf,
pub ctx: CtxConf, pub ctx: CtxConf,
pub modules: ModulesConf, pub modules: ModulesConf,
#[cfg(feature = "async")]
pub async_: AsyncConf,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -24,6 +26,8 @@ pub enum ConfigField {
Witx(WitxConf), Witx(WitxConf),
Ctx(CtxConf), Ctx(CtxConf),
Modules(ModulesConf), Modules(ModulesConf),
#[cfg(feature = "async")]
Async(AsyncConf),
} }
mod kw { mod kw {
@@ -35,6 +39,7 @@ mod kw {
syn::custom_keyword!(name); syn::custom_keyword!(name);
syn::custom_keyword!(docs); syn::custom_keyword!(docs);
syn::custom_keyword!(function_override); syn::custom_keyword!(function_override);
syn::custom_keyword!(async_);
} }
impl Parse for ConfigField { impl Parse for ConfigField {
@@ -60,6 +65,20 @@ impl Parse for ConfigField {
input.parse::<kw::modules>()?; input.parse::<kw::modules>()?;
input.parse::<Token![:]>()?; input.parse::<Token![:]>()?;
Ok(ConfigField::Modules(input.parse()?)) Ok(ConfigField::Modules(input.parse()?))
} else if lookahead.peek(kw::async_) {
input.parse::<kw::async_>()?;
input.parse::<Token![:]>()?;
#[cfg(feature = "async")]
{
Ok(ConfigField::Async(input.parse()?))
}
#[cfg(not(feature = "async"))]
{
Err(syn::Error::new(
input.span(),
"async_ not supported, enable cargo feature \"async\"",
))
}
} else { } else {
Err(lookahead.error()) Err(lookahead.error())
} }
@@ -72,6 +91,8 @@ impl Config {
let mut witx = None; let mut witx = None;
let mut ctx = None; let mut ctx = None;
let mut modules = None; let mut modules = None;
#[cfg(feature = "async")]
let mut async_ = None;
for f in fields { for f in fields {
match f { match f {
ConfigField::Target(c) => { ConfigField::Target(c) => {
@@ -98,6 +119,13 @@ impl Config {
} }
modules = Some(c); modules = Some(c);
} }
#[cfg(feature = "async")]
ConfigField::Async(c) => {
if async_.is_some() {
return Err(Error::new(err_loc, "duplicate `async_` field"));
}
async_ = Some(c);
}
} }
} }
Ok(Config { Ok(Config {
@@ -105,6 +133,8 @@ impl Config {
witx: witx.ok_or_else(|| Error::new(err_loc, "`witx` field required"))?, witx: witx.ok_or_else(|| Error::new(err_loc, "`witx` field required"))?,
ctx: ctx.ok_or_else(|| Error::new(err_loc, "`ctx` field required"))?, ctx: ctx.ok_or_else(|| Error::new(err_loc, "`ctx` field required"))?,
modules: modules.ok_or_else(|| Error::new(err_loc, "`modules` field required"))?, modules: modules.ok_or_else(|| Error::new(err_loc, "`modules` field required"))?,
#[cfg(feature = "async")]
async_: async_.unwrap_or_default(),
}) })
} }
@@ -128,6 +158,19 @@ impl Parse for Config {
} }
} }
#[derive(Debug, Clone)]
pub struct CtxConf {
pub name: syn::Type,
}
impl Parse for CtxConf {
fn parse(input: ParseStream) -> Result<Self> {
Ok(CtxConf {
name: input.parse()?,
})
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct TargetConf { pub struct TargetConf {
pub path: Path, pub path: Path,

View File

@@ -6,7 +6,7 @@ use wiggle_generate::Names;
mod config; mod config;
use config::{ModuleConf, TargetConf}; use config::{AsyncConf, ModuleConf, TargetConf};
/// Define the structs required to integrate a Wiggle implementation with Wasmtime. /// Define the structs required to integrate a Wiggle implementation with Wasmtime.
/// ///
@@ -46,13 +46,25 @@ use config::{ModuleConf, TargetConf};
pub fn wasmtime_integration(args: TokenStream) -> TokenStream { pub fn wasmtime_integration(args: TokenStream) -> TokenStream {
let config = parse_macro_input!(args as config::Config); let config = parse_macro_input!(args as config::Config);
let doc = config.load_document(); let doc = config.load_document();
let names = Names::new(&config.ctx.name, quote!(wasmtime_wiggle)); let names = Names::new(quote!(wasmtime_wiggle));
#[cfg(feature = "async")]
let async_config = config.async_.clone();
#[cfg(not(feature = "async"))]
let async_config = AsyncConf::default();
let modules = config.modules.iter().map(|(name, module_conf)| { let modules = config.modules.iter().map(|(name, module_conf)| {
let module = doc let module = doc
.module(&witx::Id::new(name)) .module(&witx::Id::new(name))
.unwrap_or_else(|| panic!("witx document did not contain module named '{}'", name)); .unwrap_or_else(|| panic!("witx document did not contain module named '{}'", name));
generate_module(&module, &module_conf, &names, &config.target) generate_module(
&module,
&module_conf,
&names,
&config.target,
&config.ctx.name,
&async_config,
)
}); });
quote!( #(#modules)* ).into() quote!( #(#modules)* ).into()
} }
@@ -62,6 +74,8 @@ fn generate_module(
module_conf: &ModuleConf, module_conf: &ModuleConf,
names: &Names, names: &Names,
target_conf: &TargetConf, target_conf: &TargetConf,
ctx_type: &syn::Type,
async_conf: &AsyncConf,
) -> TokenStream2 { ) -> TokenStream2 {
let fields = module.funcs().map(|f| { let fields = module.funcs().map(|f| {
let name_ident = names.func(&f.name); let name_ident = names.func(&f.name);
@@ -88,9 +102,14 @@ fn generate_module(
let module_id = names.module(&module.name); let module_id = names.module(&module.name);
let target_module = quote! { #target_path::#module_id }; let target_module = quote! { #target_path::#module_id };
let ctor_externs = module let ctor_externs = module.funcs().map(|f| {
.funcs() generate_func(
.map(|f| generate_func(&f, names, &target_module)); &f,
names,
&target_module,
async_conf.is_async(module.name.as_str(), f.name.as_str()),
)
});
let type_name = module_conf.name.clone(); let type_name = module_conf.name.clone();
let type_docs = module_conf let type_docs = module_conf
@@ -107,8 +126,6 @@ contained in the `cx` parameter.",
module_conf.name.to_string() module_conf.name.to_string()
); );
let ctx_type = names.ctx_type();
quote! { quote! {
#type_docs #type_docs
pub struct #type_name { pub struct #type_name {
@@ -150,6 +167,7 @@ fn generate_func(
func: &witx::InterfaceFunc, func: &witx::InterfaceFunc,
names: &Names, names: &Names,
target_module: &TokenStream2, target_module: &TokenStream2,
is_async: bool,
) -> TokenStream2 { ) -> TokenStream2 {
let name_ident = names.func(&func.name); let name_ident = names.func(&func.name);
@@ -172,31 +190,52 @@ fn generate_func(
let runtime = names.runtime_mod(); let runtime = names.runtime_mod();
quote! { let await_ = if is_async { quote!(.await) } else { quote!() };
let my_cx = cx.clone();
let #name_ident = wasmtime::Func::wrap( let closure_body = quote! {
store, unsafe {
move |caller: wasmtime::Caller<'_> #(,#arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> { let mem = match caller.get_export("memory") {
unsafe { Some(wasmtime::Extern::Memory(m)) => m,
let mem = match caller.get_export("memory") { _ => {
Some(wasmtime::Extern::Memory(m)) => m, return Err(wasmtime::Trap::new("missing required memory export"));
_ => {
return Err(wasmtime::Trap::new("missing required memory export"));
}
};
let mem = #runtime::WasmtimeGuestMemory::new(mem);
let result = #target_module::#name_ident(
&mut my_cx.borrow_mut(),
&mem,
#(#arg_names),*
);
match result {
Ok(r) => Ok(r.into()),
Err(wasmtime_wiggle::Trap::String(err)) => Err(wasmtime::Trap::new(err)),
Err(wasmtime_wiggle::Trap::I32Exit(err)) => Err(wasmtime::Trap::i32_exit(err)),
}
} }
};
let mem = #runtime::WasmtimeGuestMemory::new(mem);
let result = #target_module::#name_ident(
&mut *my_cx.borrow_mut(),
&mem,
#(#arg_names),*
) #await_;
match result {
Ok(r) => Ok(r.into()),
Err(wasmtime_wiggle::Trap::String(err)) => Err(wasmtime::Trap::new(err)),
Err(wasmtime_wiggle::Trap::I32Exit(err)) => Err(wasmtime::Trap::i32_exit(err)),
}
}
};
if is_async {
let wrapper = quote::format_ident!("wrap{}_async", params.len());
quote! {
let #name_ident = wasmtime::Func::#wrapper(
store,
cx.clone(),
move |caller: wasmtime::Caller<'_>, my_cx: &Rc<RefCell<_>> #(,#arg_decls)*|
-> Box<dyn std::future::Future<Output = Result<#ret_ty, wasmtime::Trap>>>
{
Box::new(async move { #closure_body })
} }
); );
}
} else {
quote! {
let my_cx = cx.clone();
let #name_ident = wasmtime::Func::wrap(
store,
move |caller: wasmtime::Caller<'_> #(,#arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> {
#closure_body
}
);
}
} }
} }

View File

@@ -0,0 +1,25 @@
(typename $errno
(enum (@witx tag u32)
;;; Success
$ok
;;; Invalid argument
$invalid_arg
;;; I really don't want to
$dont_want_to
;;; I am physically unable to
$physically_unable
;;; Well, that's a picket line alright!
$picket_line))
(typename $alias_to_float f32)
(module $atoms
(@interface func (export "int_float_args")
(param $an_int u32)
(param $an_float f32)
(result $error (expected (error $errno))))
(@interface func (export "double_int_return_float")
(param $an_int u32)
(result $error (expected $alias_to_float (error $errno))))
)

View File

@@ -0,0 +1,175 @@
use std::cell::RefCell;
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
wasmtime_wiggle::from_witx!({
witx: ["$CARGO_MANIFEST_DIR/tests/atoms.witx"],
async_: {
atoms::{double_int_return_float}
}
});
wasmtime_wiggle::wasmtime_integration!({
target: crate,
witx: ["$CARGO_MANIFEST_DIR/tests/atoms.witx"],
ctx: Ctx,
modules: { atoms => { name: Atoms } },
async_: {
atoms::double_int_return_float
}
});
pub struct Ctx;
impl wiggle::GuestErrorType for types::Errno {
fn success() -> Self {
types::Errno::Ok
}
}
#[wasmtime_wiggle::async_trait(?Send)]
impl atoms::Atoms for Ctx {
fn int_float_args(&self, an_int: u32, an_float: f32) -> Result<(), types::Errno> {
println!("INT FLOAT ARGS: {} {}", an_int, an_float);
Ok(())
}
async fn double_int_return_float(
&self,
an_int: u32,
) -> Result<types::AliasToFloat, types::Errno> {
Ok((an_int as f32) * 2.0)
}
}
#[test]
fn test_sync_host_func() {
let store = async_store();
let ctx = Rc::new(RefCell::new(Ctx));
let atoms = Atoms::new(&store, ctx.clone());
let shim_mod = shim_module(&store);
let mut linker = wasmtime::Linker::new(&store);
atoms.add_to_linker(&mut linker).unwrap();
let shim_inst = run(linker.instantiate_async(&shim_mod)).unwrap();
let results = run(shim_inst
.get_func("int_float_args_shim")
.unwrap()
.call_async(&[0i32.into(), 123.45f32.into()]))
.unwrap();
assert_eq!(results.len(), 1, "one return value");
assert_eq!(
results[0].unwrap_i32(),
types::Errno::Ok as i32,
"int_float_args errno"
);
}
#[test]
fn test_async_host_func() {
let store = async_store();
let ctx = Rc::new(RefCell::new(Ctx));
let atoms = Atoms::new(&store, ctx.clone());
let shim_mod = shim_module(&store);
let mut linker = wasmtime::Linker::new(&store);
atoms.add_to_linker(&mut linker).unwrap();
let shim_inst = run(linker.instantiate_async(&shim_mod)).unwrap();
let input: i32 = 123;
let result_location: i32 = 0;
let results = run(shim_inst
.get_func("double_int_return_float_shim")
.unwrap()
.call_async(&[input.into(), result_location.into()]))
.unwrap();
assert_eq!(results.len(), 1, "one return value");
assert_eq!(
results[0].unwrap_i32(),
types::Errno::Ok as i32,
"double_int_return_float errno"
);
// The actual result is in memory:
let mem = shim_inst.get_memory("memory").unwrap();
let mut result_bytes: [u8; 4] = [0, 0, 0, 0];
mem.read(result_location as usize, &mut result_bytes)
.unwrap();
let result = f32::from_le_bytes(result_bytes);
assert_eq!((input * 2) as f32, result);
}
fn run<F: Future>(future: F) -> F::Output {
let mut f = Pin::from(Box::new(future));
let waker = dummy_waker();
let mut cx = Context::from_waker(&waker);
loop {
match f.as_mut().poll(&mut cx) {
Poll::Ready(val) => break val,
Poll::Pending => {}
}
}
}
fn dummy_waker() -> Waker {
return unsafe { Waker::from_raw(clone(5 as *const _)) };
unsafe fn clone(ptr: *const ()) -> RawWaker {
assert_eq!(ptr as usize, 5);
const VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);
RawWaker::new(ptr, &VTABLE)
}
unsafe fn wake(ptr: *const ()) {
assert_eq!(ptr as usize, 5);
}
unsafe fn wake_by_ref(ptr: *const ()) {
assert_eq!(ptr as usize, 5);
}
unsafe fn drop(ptr: *const ()) {
assert_eq!(ptr as usize, 5);
}
}
fn async_store() -> wasmtime::Store {
let engine = wasmtime::Engine::default();
wasmtime::Store::new_async(&engine)
}
// Wiggle expects the caller to have an exported memory. Wasmtime can only
// provide this if the caller is a WebAssembly module, so we need to write
// a shim module:
fn shim_module(store: &wasmtime::Store) -> wasmtime::Module {
wasmtime::Module::new(
store.engine(),
r#"
(module
(memory 1)
(export "memory" (memory 0))
(import "atoms" "int_float_args" (func $int_float_args (param i32 f32) (result i32)))
(import "atoms" "double_int_return_float" (func $double_int_return_float (param i32 i32) (result i32)))
(func $int_float_args_shim (param i32 f32) (result i32)
local.get 0
local.get 1
call $int_float_args
)
(func $double_int_return_float_shim (param i32 i32) (result i32)
local.get 0
local.get 1
call $double_int_return_float
)
(export "int_float_args_shim" (func $int_float_args_shim))
(export "double_int_return_float_shim" (func $double_int_return_float_shim))
)
"#,
)
.unwrap()
}