diff --git a/crates/wiggle/generate/src/config.rs b/crates/wiggle/generate/src/config.rs index 2ab5e9119b..bed0e5c834 100644 --- a/crates/wiggle/generate/src/config.rs +++ b/crates/wiggle/generate/src/config.rs @@ -47,7 +47,9 @@ impl Parse for ConfigField { } else if lookahead.peek(Token![async]) { input.parse::()?; input.parse::()?; - Ok(ConfigField::Async(input.parse()?)) + Ok(ConfigField::Async(AsyncConf { + functions: input.parse()?, + })) } else { Err(lookahead.error()) } @@ -282,40 +284,62 @@ impl Parse for ErrorConfField { #[derive(Clone, Default, Debug)] /// Modules and funcs that have async signatures pub struct AsyncConf { - functions: HashMap>, + functions: AsyncFunctions, +} + +#[derive(Clone, Debug)] +pub enum AsyncFunctions { + Some(HashMap>), + All, +} +impl Default for AsyncFunctions { + fn default() -> Self { + AsyncFunctions::Some(HashMap::default()) + } } impl AsyncConf { pub fn is_async(&self, module: &str, function: &str) -> bool { - self.functions - .get(module) - .and_then(|fs| fs.iter().find(|f| *f == function)) - .is_some() + match &self.functions { + AsyncFunctions::Some(fs) => fs + .get(module) + .and_then(|fs| fs.iter().find(|f| *f == function)) + .is_some(), + AsyncFunctions::All => true, + } } } -impl Parse for AsyncConf { +impl Parse for AsyncFunctions { fn parse(input: ParseStream) -> Result { let content; - let _ = braced!(content in input); - let items: Punctuated = - content.parse_terminated(Parse::parse)?; - let mut functions: HashMap> = HashMap::new(); - use std::collections::hash_map::Entry; - for i in items { - let function_names = i - .function_names - .iter() - .map(|i| i.to_string()) - .collect::>(); - match functions.entry(i.module_name.to_string()) { - Entry::Occupied(o) => o.into_mut().extend(function_names), - Entry::Vacant(v) => { - v.insert(function_names); + let lookahead = input.lookahead1(); + if lookahead.peek(syn::token::Brace) { + let _ = braced!(content in input); + let items: Punctuated = + content.parse_terminated(Parse::parse)?; + let mut functions: HashMap> = HashMap::new(); + use std::collections::hash_map::Entry; + for i in items { + let function_names = i + .function_names + .iter() + .map(|i| i.to_string()) + .collect::>(); + match functions.entry(i.module_name.to_string()) { + Entry::Occupied(o) => o.into_mut().extend(function_names), + Entry::Vacant(v) => { + v.insert(function_names); + } } } + Ok(AsyncFunctions::Some(functions)) + } else if lookahead.peek(Token![*]) { + let _: Token![*] = input.parse().unwrap(); + Ok(AsyncFunctions::All) + } else { + Err(lookahead.error()) } - Ok(AsyncConf { functions }) } } diff --git a/crates/wiggle/tests/atoms_async.rs b/crates/wiggle/tests/atoms_async.rs index 0057c9203a..7d41d3a273 100644 --- a/crates/wiggle/tests/atoms_async.rs +++ b/crates/wiggle/tests/atoms_async.rs @@ -7,9 +7,7 @@ use wiggle_test::{impl_errno, HostMemory, MemArea, WasiCtx}; wiggle::from_witx!({ witx: ["$CARGO_MANIFEST_DIR/tests/atoms.witx"], - async: { - atoms::{int_float_args, double_int_return_float} - } + async: *, }); impl_errno!(types::Errno); diff --git a/crates/wiggle/wasmtime/macro/src/config.rs b/crates/wiggle/wasmtime/macro/src/config.rs index 0aa9adcce5..30815817b2 100644 --- a/crates/wiggle/wasmtime/macro/src/config.rs +++ b/crates/wiggle/wasmtime/macro/src/config.rs @@ -1,4 +1,4 @@ -use wiggle_generate::config::AsyncConfField; +use wiggle_generate::config::AsyncFunctions; use { proc_macro2::Span, std::collections::HashMap, @@ -66,13 +66,17 @@ impl Parse for ConfigField { } else if lookahead.peek(Token![async]) { input.parse::()?; input.parse::()?; - Ok(ConfigField::Async(input.parse()?)) + Ok(ConfigField::Async(AsyncConf { + blocking: false, + functions: input.parse()?, + })) } else if lookahead.peek(kw::block_on) { input.parse::()?; input.parse::()?; - let mut async_conf: AsyncConf = input.parse()?; - async_conf.blocking = true; - Ok(ConfigField::Async(async_conf)) + Ok(ConfigField::Async(AsyncConf { + blocking: true, + functions: input.parse()?, + })) } else { Err(lookahead.error()) } @@ -273,7 +277,7 @@ impl Parse for ModulesConf { /// Modules and funcs that have async signatures pub struct AsyncConf { blocking: bool, - functions: HashMap>, + functions: AsyncFunctions, } #[derive(Copy, Clone, Debug, PartialEq, Eq)] @@ -297,47 +301,24 @@ impl Asyncness { impl AsyncConf { pub fn is_async(&self, module: &str, function: &str) -> Asyncness { - if self - .functions - .get(module) - .and_then(|fs| fs.iter().find(|f| *f == function)) - .is_some() - { - if self.blocking { - Asyncness::Blocking - } else { - Asyncness::Async - } + let a = if self.blocking { + Asyncness::Blocking } else { - Asyncness::Sync - } - } -} - -impl Parse for AsyncConf { - fn parse(input: ParseStream) -> Result { - let content; - let _ = braced!(content in input); - let items: Punctuated = - content.parse_terminated(Parse::parse)?; - let mut functions: HashMap> = HashMap::new(); - use std::collections::hash_map::Entry; - for i in items { - let function_names = i - .function_names - .iter() - .map(|i| i.to_string()) - .collect::>(); - match functions.entry(i.module_name.to_string()) { - Entry::Occupied(o) => o.into_mut().extend(function_names), - Entry::Vacant(v) => { - v.insert(function_names); + Asyncness::Async + }; + match &self.functions { + AsyncFunctions::Some(fs) => { + if fs + .get(module) + .and_then(|fs| fs.iter().find(|f| *f == function)) + .is_some() + { + a + } else { + Asyncness::Sync } } + AsyncFunctions::All => a, } - Ok(AsyncConf { - functions, - blocking: false, - }) } }