diff --git a/crates/wiggle/generate/src/config.rs b/crates/wiggle/generate/src/config.rs index a430a8547a..2ab5e9119b 100644 --- a/crates/wiggle/generate/src/config.rs +++ b/crates/wiggle/generate/src/config.rs @@ -280,12 +280,14 @@ impl Parse for ErrorConfField { } #[derive(Clone, Default, Debug)] -/// Modules and funcs that should be async -pub struct AsyncConf(HashMap>); +/// Modules and funcs that have async signatures +pub struct AsyncConf { + functions: HashMap>, +} impl AsyncConf { pub fn is_async(&self, module: &str, function: &str) -> bool { - self.0 + self.functions .get(module) .and_then(|fs| fs.iter().find(|f| *f == function)) .is_some() @@ -298,7 +300,7 @@ impl Parse for AsyncConf { let _ = braced!(content in input); let items: Punctuated = content.parse_terminated(Parse::parse)?; - let mut m: HashMap> = HashMap::new(); + let mut functions: HashMap> = HashMap::new(); use std::collections::hash_map::Entry; for i in items { let function_names = i @@ -306,14 +308,14 @@ impl Parse for AsyncConf { .iter() .map(|i| i.to_string()) .collect::>(); - match m.entry(i.module_name.to_string()) { + match functions.entry(i.module_name.to_string()) { Entry::Occupied(o) => o.into_mut().extend(function_names), Entry::Vacant(v) => { v.insert(function_names); } } } - Ok(AsyncConf(m)) + Ok(AsyncConf { functions }) } } diff --git a/crates/wiggle/wasmtime/macro/src/config.rs b/crates/wiggle/wasmtime/macro/src/config.rs index e37d3492ea..0aa9adcce5 100644 --- a/crates/wiggle/wasmtime/macro/src/config.rs +++ b/crates/wiggle/wasmtime/macro/src/config.rs @@ -1,4 +1,4 @@ -pub use wiggle_generate::config::AsyncConf; +use wiggle_generate::config::AsyncConfField; use { proc_macro2::Span, std::collections::HashMap, @@ -37,6 +37,7 @@ mod kw { syn::custom_keyword!(name); syn::custom_keyword!(docs); syn::custom_keyword!(function_override); + syn::custom_keyword!(block_on); } impl Parse for ConfigField { @@ -66,6 +67,12 @@ impl Parse for ConfigField { input.parse::()?; input.parse::()?; Ok(ConfigField::Async(input.parse()?)) + } else if lookahead.peek(kw::block_on) { + input.parse::()?; + input.parse::()?; + let mut async_conf: AsyncConf = input.parse()?; + async_conf.blocking = true; + Ok(ConfigField::Async(async_conf)) } else { Err(lookahead.error()) } @@ -261,3 +268,76 @@ impl Parse for ModulesConf { }) } } + +#[derive(Clone, Default, Debug)] +/// Modules and funcs that have async signatures +pub struct AsyncConf { + blocking: bool, + functions: HashMap>, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum Asyncness { + /// Wiggle function is synchronous, wasmtime Func is synchronous + Sync, + /// Wiggle function is asynchronous, but wasmtime Func is synchronous + Blocking, + /// Wiggle function and wasmtime Func are asynchronous. + Async, +} + +impl Asyncness { + pub fn is_sync(&self) -> bool { + match self { + Asyncness::Sync => true, + _ => false, + } + } +} + +impl AsyncConf { + pub fn is_async(&self, module: &str, function: &str) -> Asyncness { + if self + .functions + .get(module) + .and_then(|fs| fs.iter().find(|f| *f == function)) + .is_some() + { + if self.blocking { + Asyncness::Blocking + } else { + Asyncness::Async + } + } else { + Asyncness::Sync + } + } +} + +impl Parse for AsyncConf { + fn parse(input: ParseStream) -> Result { + let content; + let _ = braced!(content in input); + let items: Punctuated = + content.parse_terminated(Parse::parse)?; + let mut functions: HashMap> = HashMap::new(); + use std::collections::hash_map::Entry; + for i in items { + let function_names = i + .function_names + .iter() + .map(|i| i.to_string()) + .collect::>(); + match functions.entry(i.module_name.to_string()) { + Entry::Occupied(o) => o.into_mut().extend(function_names), + Entry::Vacant(v) => { + v.insert(function_names); + } + } + } + Ok(AsyncConf { + functions, + blocking: false, + }) + } +} diff --git a/crates/wiggle/wasmtime/macro/src/lib.rs b/crates/wiggle/wasmtime/macro/src/lib.rs index b7bf5543c9..4764d217c1 100644 --- a/crates/wiggle/wasmtime/macro/src/lib.rs +++ b/crates/wiggle/wasmtime/macro/src/lib.rs @@ -6,7 +6,7 @@ use wiggle_generate::Names; mod config; -use config::{AsyncConf, ModuleConf, TargetConf}; +use config::{AsyncConf, Asyncness, ModuleConf, TargetConf}; /// Define the structs required to integrate a Wiggle implementation with Wasmtime. /// @@ -101,14 +101,19 @@ fn generate_module( let mut ctor_externs = Vec::new(); let mut host_funcs = Vec::new(); - #[cfg(not(feature = "async"))] let mut requires_dummy_executor = false; for f in module.funcs() { - let is_async = async_conf.is_async(module.name.as_str(), f.name.as_str()); - #[cfg(not(feature = "async"))] - if is_async { - requires_dummy_executor = true; + let asyncness = async_conf.is_async(module.name.as_str(), f.name.as_str()); + match asyncness { + Asyncness::Blocking => requires_dummy_executor = true, + Asyncness::Async => { + assert!( + cfg!(feature = "async"), + "generating async wasmtime Funcs requires cargo feature \"async\"" + ); + } + _ => {} } generate_func( &module_id, @@ -116,7 +121,7 @@ fn generate_module( names, &target_module, ctx_type, - is_async, + asyncness, &mut fns, &mut ctor_externs, &mut host_funcs, @@ -160,14 +165,11 @@ contained in the `cx` parameter.", } }); - #[cfg(not(feature = "async"))] let dummy_executor = if requires_dummy_executor { dummy_executor() } else { quote!() }; - #[cfg(feature = "async")] - let dummy_executor = quote!(); quote! { #type_docs @@ -243,7 +245,7 @@ fn generate_func( names: &Names, target_module: &TokenStream2, ctx_type: &syn::Type, - is_async: bool, + asyncness: Asyncness, fns: &mut Vec, ctors: &mut Vec, host_funcs: &mut Vec<(witx::Id, TokenStream2)>, @@ -271,8 +273,16 @@ fn generate_func( _ => unimplemented!(), }; - let async_ = if is_async { quote!(async) } else { quote!() }; - let await_ = if is_async { quote!(.await) } else { quote!() }; + let async_ = if asyncness.is_sync() { + quote!() + } else { + quote!(async) + }; + let await_ = if asyncness.is_sync() { + quote!() + } else { + quote!(.await) + }; let runtime = names.runtime_mod(); let fn_ident = format_ident!("{}_{}", module_ident, name_ident); @@ -296,9 +306,8 @@ fn generate_func( } }); - if is_async { - #[cfg(feature = "async")] - { + match asyncness { + Asyncness::Async => { let wrapper = format_ident!("wrap{}_async", params.len()); ctors.push(quote! { let #name_ident = wasmtime::Func::#wrapper( @@ -309,11 +318,9 @@ fn generate_func( Box::new(async move { Self::#fn_ident(&caller, &mut my_ctx.borrow_mut() #(, #arg_names)*).await }) } ); - }); + }); } - - #[cfg(not(feature = "async"))] - { + Asyncness::Blocking => { // Emit a synchronous function. Self::#fn_ident returns a Future, so we need to // use a dummy executor to let any synchronous code inside there execute correctly. If // the future ends up Pending, this func will Trap. @@ -327,8 +334,8 @@ fn generate_func( ); }); } - } else { - ctors.push(quote! { + Asyncness::Sync => { + ctors.push(quote! { let my_ctx = ctx.clone(); let #name_ident = wasmtime::Func::wrap( store, @@ -337,11 +344,11 @@ fn generate_func( } ); }); + } } - let host_wrapper = if is_async { - #[cfg(feature = "async")] - { + let host_wrapper = match asyncness { + Asyncness::Async => { let wrapper = format_ident!("wrap{}_host_func_async", params.len()); quote! { config.#wrapper( @@ -361,8 +368,7 @@ fn generate_func( } } - #[cfg(not(feature = "async"))] - { + Asyncness::Blocking => { // Emit a synchronous host function. Self::#fn_ident returns a Future, so we need to // use a dummy executor to let any synchronous code inside there execute correctly. If // the future ends up Pending, this func will Trap. @@ -380,24 +386,25 @@ fn generate_func( ); } } - } else { - quote! { - config.wrap_host_func( - module, - field, - move |caller: wasmtime::Caller #(, #arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> { - let ctx = caller - .store() - .get::>>() - .ok_or_else(|| wasmtime::Trap::new("context is missing in the store"))?; - Self::#fn_ident(&caller, &mut ctx.borrow_mut() #(, #arg_names)*) - }, - ); + Asyncness::Sync => { + quote! { + config.wrap_host_func( + module, + field, + move |caller: wasmtime::Caller #(, #arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> { + let ctx = caller + .store() + .get::>>() + .ok_or_else(|| wasmtime::Trap::new("context is missing in the store"))?; + Self::#fn_ident(&caller, &mut ctx.borrow_mut() #(, #arg_names)*) + }, + ); + } } }; host_funcs.push((func.name.clone(), host_wrapper)); } -#[cfg(not(feature = "async"))] + fn dummy_executor() -> TokenStream2 { quote! { fn run_in_dummy_executor(future: F) -> F::Output { diff --git a/crates/wiggle/wasmtime/tests/atoms_sync.rs b/crates/wiggle/wasmtime/tests/atoms_sync.rs index 057ebbaa6e..eee48f5338 100644 --- a/crates/wiggle/wasmtime/tests/atoms_sync.rs +++ b/crates/wiggle/wasmtime/tests/atoms_sync.rs @@ -1,11 +1,3 @@ -#![allow(unused)] -// These tests are designed to check the behavior with the crate's async feature (& wasmtimes async -// feature) disabled. Run with: -// `cargo test --no-default-features --features wasmtime/wat --test atoms_sync` -#[cfg(feature = "async")] -#[test] -fn these_tests_require_async_feature_disabled() {} - use std::cell::RefCell; use std::rc::Rc; @@ -21,7 +13,7 @@ wasmtime_wiggle::wasmtime_integration!({ witx: ["$CARGO_MANIFEST_DIR/tests/atoms.witx"], ctx: Ctx, modules: { atoms => { name: Atoms } }, - async: { + block_on: { atoms::double_int_return_float } }); @@ -94,7 +86,6 @@ fn run_double_int_return_float(linker: &wasmtime::Linker) { assert_eq!((input * 2) as f32, result); } -#[cfg(not(feature = "async"))] #[test] fn test_sync_host_func() { let store = store(); @@ -108,7 +99,6 @@ fn test_sync_host_func() { run_int_float_args(&linker); } -#[cfg(not(feature = "async"))] #[test] fn test_async_host_func() { let store = store(); @@ -122,7 +112,6 @@ fn test_async_host_func() { run_double_int_return_float(&linker); } -#[cfg(not(feature = "async"))] #[test] fn test_sync_config_host_func() { let mut config = wasmtime::Config::new(); @@ -137,7 +126,6 @@ fn test_sync_config_host_func() { run_int_float_args(&linker); } -#[cfg(not(feature = "async"))] #[test] fn test_async_config_host_func() { let mut config = wasmtime::Config::new();