diff --git a/crates/wiggle/wasmtime/Cargo.toml b/crates/wiggle/wasmtime/Cargo.toml index 31f45eeaf3..c9174b9edc 100644 --- a/crates/wiggle/wasmtime/Cargo.toml +++ b/crates/wiggle/wasmtime/Cargo.toml @@ -17,10 +17,22 @@ witx = { version = "0.9", path = "../../wasi-common/WASI/tools/witx", optional = wiggle = { path = "..", version = "0.23.0" } wiggle-borrow = { path = "../borrow", version = "0.23.0" } +[dev-dependencies] +anyhow = "1" +proptest = "0.10" + +[[test]] +name = "atoms_async" +path = "tests/atoms_async.rs" +required-features = ["async", "wasmtime/wat"] + [badges] maintenance = { status = "actively-developed" } [features] +# Async support for wasmtime +async = [ 'wasmtime/async', 'wasmtime-wiggle-macro/async' ] + # The wiggle proc-macro emits some code (inside `pub mod metadata`) guarded # by the `wiggle_metadata` feature flag. We use this feature flag so that # users of wiggle are not forced to take a direct dependency on the `witx` @@ -33,4 +45,4 @@ wiggle_metadata = ['witx', "wiggle/wiggle_metadata"] # the logs out of wiggle-generated libraries. tracing_log = [ "wiggle/tracing_log" ] -default = ["wiggle_metadata" ] +default = ["wiggle_metadata", "async"] diff --git a/crates/wiggle/wasmtime/macro/Cargo.toml b/crates/wiggle/wasmtime/macro/Cargo.toml index fd4b365f0d..e1f5f6d8de 100644 --- a/crates/wiggle/wasmtime/macro/Cargo.toml +++ b/crates/wiggle/wasmtime/macro/Cargo.toml @@ -24,3 +24,6 @@ proc-macro2 = "1.0" [badges] maintenance = { status = "actively-developed" } +[features] +async = [] +default = [] diff --git a/crates/wiggle/wasmtime/macro/src/config.rs b/crates/wiggle/wasmtime/macro/src/config.rs index 5e95bad957..9c5ee7b776 100644 --- a/crates/wiggle/wasmtime/macro/src/config.rs +++ b/crates/wiggle/wasmtime/macro/src/config.rs @@ -1,3 +1,4 @@ +pub use wiggle_generate::config::AsyncConf; use { proc_macro2::Span, std::collections::HashMap, @@ -7,15 +8,16 @@ use { punctuated::Punctuated, Error, Ident, Path, Result, Token, }, - wiggle_generate::config::{CtxConf, WitxConf}, + wiggle_generate::config::WitxConf, }; - #[derive(Debug, Clone)] pub struct Config { pub target: TargetConf, pub witx: WitxConf, pub ctx: CtxConf, pub modules: ModulesConf, + #[cfg(feature = "async")] + pub async_: AsyncConf, } #[derive(Debug, Clone)] @@ -24,6 +26,8 @@ pub enum ConfigField { Witx(WitxConf), Ctx(CtxConf), Modules(ModulesConf), + #[cfg(feature = "async")] + Async(AsyncConf), } mod kw { @@ -35,6 +39,7 @@ mod kw { syn::custom_keyword!(name); syn::custom_keyword!(docs); syn::custom_keyword!(function_override); + syn::custom_keyword!(async_); } impl Parse for ConfigField { @@ -60,6 +65,20 @@ impl Parse for ConfigField { input.parse::()?; input.parse::()?; Ok(ConfigField::Modules(input.parse()?)) + } else if lookahead.peek(kw::async_) { + input.parse::()?; + input.parse::()?; + #[cfg(feature = "async")] + { + Ok(ConfigField::Async(input.parse()?)) + } + #[cfg(not(feature = "async"))] + { + Err(syn::Error::new( + input.span(), + "async_ not supported, enable cargo feature \"async\"", + )) + } } else { Err(lookahead.error()) } @@ -72,6 +91,8 @@ impl Config { let mut witx = None; let mut ctx = None; let mut modules = None; + #[cfg(feature = "async")] + let mut async_ = None; for f in fields { match f { ConfigField::Target(c) => { @@ -98,6 +119,13 @@ impl Config { } modules = Some(c); } + #[cfg(feature = "async")] + ConfigField::Async(c) => { + if async_.is_some() { + return Err(Error::new(err_loc, "duplicate `async_` field")); + } + async_ = Some(c); + } } } Ok(Config { @@ -105,6 +133,8 @@ impl Config { witx: witx.ok_or_else(|| Error::new(err_loc, "`witx` field required"))?, ctx: ctx.ok_or_else(|| Error::new(err_loc, "`ctx` field required"))?, modules: modules.ok_or_else(|| Error::new(err_loc, "`modules` field required"))?, + #[cfg(feature = "async")] + async_: async_.unwrap_or_default(), }) } @@ -128,6 +158,19 @@ impl Parse for Config { } } +#[derive(Debug, Clone)] +pub struct CtxConf { + pub name: syn::Type, +} + +impl Parse for CtxConf { + fn parse(input: ParseStream) -> Result { + Ok(CtxConf { + name: input.parse()?, + }) + } +} + #[derive(Debug, Clone)] pub struct TargetConf { pub path: Path, diff --git a/crates/wiggle/wasmtime/macro/src/lib.rs b/crates/wiggle/wasmtime/macro/src/lib.rs index a2d05f729d..199036e79c 100644 --- a/crates/wiggle/wasmtime/macro/src/lib.rs +++ b/crates/wiggle/wasmtime/macro/src/lib.rs @@ -6,7 +6,7 @@ use wiggle_generate::Names; mod config; -use config::{ModuleConf, TargetConf}; +use config::{AsyncConf, ModuleConf, TargetConf}; /// Define the structs required to integrate a Wiggle implementation with Wasmtime. /// @@ -46,13 +46,25 @@ use config::{ModuleConf, TargetConf}; pub fn wasmtime_integration(args: TokenStream) -> TokenStream { let config = parse_macro_input!(args as config::Config); let doc = config.load_document(); - let names = Names::new(&config.ctx.name, quote!(wasmtime_wiggle)); + let names = Names::new(quote!(wasmtime_wiggle)); + + #[cfg(feature = "async")] + let async_config = config.async_.clone(); + #[cfg(not(feature = "async"))] + let async_config = AsyncConf::default(); let modules = config.modules.iter().map(|(name, module_conf)| { let module = doc .module(&witx::Id::new(name)) .unwrap_or_else(|| panic!("witx document did not contain module named '{}'", name)); - generate_module(&module, &module_conf, &names, &config.target) + generate_module( + &module, + &module_conf, + &names, + &config.target, + &config.ctx.name, + &async_config, + ) }); quote!( #(#modules)* ).into() } @@ -62,6 +74,8 @@ fn generate_module( module_conf: &ModuleConf, names: &Names, target_conf: &TargetConf, + ctx_type: &syn::Type, + async_conf: &AsyncConf, ) -> TokenStream2 { let fields = module.funcs().map(|f| { let name_ident = names.func(&f.name); @@ -88,9 +102,14 @@ fn generate_module( let module_id = names.module(&module.name); let target_module = quote! { #target_path::#module_id }; - let ctor_externs = module - .funcs() - .map(|f| generate_func(&f, names, &target_module)); + let ctor_externs = module.funcs().map(|f| { + generate_func( + &f, + names, + &target_module, + async_conf.is_async(module.name.as_str(), f.name.as_str()), + ) + }); let type_name = module_conf.name.clone(); let type_docs = module_conf @@ -107,8 +126,6 @@ contained in the `cx` parameter.", module_conf.name.to_string() ); - let ctx_type = names.ctx_type(); - quote! { #type_docs pub struct #type_name { @@ -150,6 +167,7 @@ fn generate_func( func: &witx::InterfaceFunc, names: &Names, target_module: &TokenStream2, + is_async: bool, ) -> TokenStream2 { let name_ident = names.func(&func.name); @@ -172,31 +190,52 @@ fn generate_func( let runtime = names.runtime_mod(); - quote! { - let my_cx = cx.clone(); - let #name_ident = wasmtime::Func::wrap( - store, - move |caller: wasmtime::Caller<'_> #(,#arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> { - unsafe { - let mem = match caller.get_export("memory") { - Some(wasmtime::Extern::Memory(m)) => m, - _ => { - return Err(wasmtime::Trap::new("missing required memory export")); - } - }; - let mem = #runtime::WasmtimeGuestMemory::new(mem); - let result = #target_module::#name_ident( - &mut my_cx.borrow_mut(), - &mem, - #(#arg_names),* - ); - match result { - Ok(r) => Ok(r.into()), - Err(wasmtime_wiggle::Trap::String(err)) => Err(wasmtime::Trap::new(err)), - Err(wasmtime_wiggle::Trap::I32Exit(err)) => Err(wasmtime::Trap::i32_exit(err)), - } + let await_ = if is_async { quote!(.await) } else { quote!() }; + + let closure_body = quote! { + unsafe { + let mem = match caller.get_export("memory") { + Some(wasmtime::Extern::Memory(m)) => m, + _ => { + return Err(wasmtime::Trap::new("missing required memory export")); } + }; + let mem = #runtime::WasmtimeGuestMemory::new(mem); + let result = #target_module::#name_ident( + &mut *my_cx.borrow_mut(), + &mem, + #(#arg_names),* + ) #await_; + match result { + Ok(r) => Ok(r.into()), + Err(wasmtime_wiggle::Trap::String(err)) => Err(wasmtime::Trap::new(err)), + Err(wasmtime_wiggle::Trap::I32Exit(err)) => Err(wasmtime::Trap::i32_exit(err)), + } + } + + }; + if is_async { + let wrapper = quote::format_ident!("wrap{}_async", params.len()); + quote! { + let #name_ident = wasmtime::Func::#wrapper( + store, + cx.clone(), + move |caller: wasmtime::Caller<'_>, my_cx: &Rc> #(,#arg_decls)*| + -> Box>> + { + Box::new(async move { #closure_body }) } ); + } + } else { + quote! { + let my_cx = cx.clone(); + let #name_ident = wasmtime::Func::wrap( + store, + move |caller: wasmtime::Caller<'_> #(,#arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> { + #closure_body + } + ); + } } } diff --git a/crates/wiggle/wasmtime/tests/atoms.witx b/crates/wiggle/wasmtime/tests/atoms.witx new file mode 100644 index 0000000000..af5955dbf0 --- /dev/null +++ b/crates/wiggle/wasmtime/tests/atoms.witx @@ -0,0 +1,25 @@ + +(typename $errno + (enum (@witx tag u32) + ;;; Success + $ok + ;;; Invalid argument + $invalid_arg + ;;; I really don't want to + $dont_want_to + ;;; I am physically unable to + $physically_unable + ;;; Well, that's a picket line alright! + $picket_line)) + +(typename $alias_to_float f32) + +(module $atoms + (@interface func (export "int_float_args") + (param $an_int u32) + (param $an_float f32) + (result $error (expected (error $errno)))) + (@interface func (export "double_int_return_float") + (param $an_int u32) + (result $error (expected $alias_to_float (error $errno)))) +) diff --git a/crates/wiggle/wasmtime/tests/atoms_async.rs b/crates/wiggle/wasmtime/tests/atoms_async.rs new file mode 100644 index 0000000000..f31da11ecb --- /dev/null +++ b/crates/wiggle/wasmtime/tests/atoms_async.rs @@ -0,0 +1,175 @@ +use std::cell::RefCell; +use std::future::Future; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; + +wasmtime_wiggle::from_witx!({ + witx: ["$CARGO_MANIFEST_DIR/tests/atoms.witx"], + async_: { + atoms::{double_int_return_float} + } +}); + +wasmtime_wiggle::wasmtime_integration!({ + target: crate, + witx: ["$CARGO_MANIFEST_DIR/tests/atoms.witx"], + ctx: Ctx, + modules: { atoms => { name: Atoms } }, + async_: { + atoms::double_int_return_float + } +}); + +pub struct Ctx; +impl wiggle::GuestErrorType for types::Errno { + fn success() -> Self { + types::Errno::Ok + } +} + +#[wasmtime_wiggle::async_trait(?Send)] +impl atoms::Atoms for Ctx { + fn int_float_args(&self, an_int: u32, an_float: f32) -> Result<(), types::Errno> { + println!("INT FLOAT ARGS: {} {}", an_int, an_float); + Ok(()) + } + async fn double_int_return_float( + &self, + an_int: u32, + ) -> Result { + Ok((an_int as f32) * 2.0) + } +} + +#[test] +fn test_sync_host_func() { + let store = async_store(); + + let ctx = Rc::new(RefCell::new(Ctx)); + let atoms = Atoms::new(&store, ctx.clone()); + + let shim_mod = shim_module(&store); + let mut linker = wasmtime::Linker::new(&store); + atoms.add_to_linker(&mut linker).unwrap(); + let shim_inst = run(linker.instantiate_async(&shim_mod)).unwrap(); + + let results = run(shim_inst + .get_func("int_float_args_shim") + .unwrap() + .call_async(&[0i32.into(), 123.45f32.into()])) + .unwrap(); + + assert_eq!(results.len(), 1, "one return value"); + assert_eq!( + results[0].unwrap_i32(), + types::Errno::Ok as i32, + "int_float_args errno" + ); +} + +#[test] +fn test_async_host_func() { + let store = async_store(); + + let ctx = Rc::new(RefCell::new(Ctx)); + let atoms = Atoms::new(&store, ctx.clone()); + + let shim_mod = shim_module(&store); + let mut linker = wasmtime::Linker::new(&store); + atoms.add_to_linker(&mut linker).unwrap(); + let shim_inst = run(linker.instantiate_async(&shim_mod)).unwrap(); + + let input: i32 = 123; + let result_location: i32 = 0; + + let results = run(shim_inst + .get_func("double_int_return_float_shim") + .unwrap() + .call_async(&[input.into(), result_location.into()])) + .unwrap(); + + assert_eq!(results.len(), 1, "one return value"); + assert_eq!( + results[0].unwrap_i32(), + types::Errno::Ok as i32, + "double_int_return_float errno" + ); + + // The actual result is in memory: + let mem = shim_inst.get_memory("memory").unwrap(); + let mut result_bytes: [u8; 4] = [0, 0, 0, 0]; + mem.read(result_location as usize, &mut result_bytes) + .unwrap(); + let result = f32::from_le_bytes(result_bytes); + assert_eq!((input * 2) as f32, result); +} + +fn run(future: F) -> F::Output { + let mut f = Pin::from(Box::new(future)); + let waker = dummy_waker(); + let mut cx = Context::from_waker(&waker); + loop { + match f.as_mut().poll(&mut cx) { + Poll::Ready(val) => break val, + Poll::Pending => {} + } + } +} + +fn dummy_waker() -> Waker { + return unsafe { Waker::from_raw(clone(5 as *const _)) }; + + unsafe fn clone(ptr: *const ()) -> RawWaker { + assert_eq!(ptr as usize, 5); + const VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop); + RawWaker::new(ptr, &VTABLE) + } + + unsafe fn wake(ptr: *const ()) { + assert_eq!(ptr as usize, 5); + } + + unsafe fn wake_by_ref(ptr: *const ()) { + assert_eq!(ptr as usize, 5); + } + + unsafe fn drop(ptr: *const ()) { + assert_eq!(ptr as usize, 5); + } +} +fn async_store() -> wasmtime::Store { + let engine = wasmtime::Engine::default(); + wasmtime::Store::new_async(&engine) +} + +// Wiggle expects the caller to have an exported memory. Wasmtime can only +// provide this if the caller is a WebAssembly module, so we need to write +// a shim module: +fn shim_module(store: &wasmtime::Store) -> wasmtime::Module { + wasmtime::Module::new( + store.engine(), + r#" + (module + (memory 1) + (export "memory" (memory 0)) + (import "atoms" "int_float_args" (func $int_float_args (param i32 f32) (result i32))) + (import "atoms" "double_int_return_float" (func $double_int_return_float (param i32 i32) (result i32))) + + (func $int_float_args_shim (param i32 f32) (result i32) + local.get 0 + local.get 1 + call $int_float_args + ) + (func $double_int_return_float_shim (param i32 i32) (result i32) + local.get 0 + local.get 1 + call $double_int_return_float + ) + (export "int_float_args_shim" (func $int_float_args_shim)) + (export "double_int_return_float_shim" (func $double_int_return_float_shim)) + ) + "#, + ) + .unwrap() +}