diff --git a/crates/wiggle/Cargo.toml b/crates/wiggle/Cargo.toml index 88c8315bdc..a5fdbf7cbd 100644 --- a/crates/wiggle/Cargo.toml +++ b/crates/wiggle/Cargo.toml @@ -16,6 +16,7 @@ witx = { path = "../wasi-common/WASI/tools/witx", version = "0.9", optional = tr wiggle-macro = { path = "macro", version = "0.23.0" } tracing = "0.1.15" bitflags = "1.2" +async-trait = "0.1.42" [badges] maintenance = { status = "actively-developed" } diff --git a/crates/wiggle/generate/src/error_transform.rs b/crates/wiggle/generate/src/codegen_settings.rs similarity index 80% rename from crates/wiggle/generate/src/error_transform.rs rename to crates/wiggle/generate/src/codegen_settings.rs index 56fa10eaf0..c51c87e9dd 100644 --- a/crates/wiggle/generate/src/error_transform.rs +++ b/crates/wiggle/generate/src/codegen_settings.rs @@ -1,10 +1,28 @@ -use crate::config::ErrorConf; +use crate::config::{AsyncConf, ErrorConf}; use anyhow::{anyhow, Error}; use proc_macro2::TokenStream; use quote::quote; use std::collections::HashMap; use std::rc::Rc; -use witx::{Document, Id, NamedType, TypeRef}; +use witx::{Document, Id, InterfaceFunc, Module, NamedType, TypeRef}; + +pub struct CodegenSettings { + pub errors: ErrorTransform, + async_: AsyncConf, +} +impl CodegenSettings { + pub fn new(error_conf: &ErrorConf, async_: &AsyncConf, doc: &Document) -> Result { + let errors = ErrorTransform::new(error_conf, doc)?; + Ok(Self { + errors, + async_: async_.clone(), + }) + } + pub fn is_async(&self, module: &Module, func: &InterfaceFunc) -> bool { + self.async_ + .is_async(module.name.as_str(), func.name.as_str()) + } +} pub struct ErrorTransform { m: Vec, diff --git a/crates/wiggle/generate/src/config.rs b/crates/wiggle/generate/src/config.rs index 520532a3ec..b6467da609 100644 --- a/crates/wiggle/generate/src/config.rs +++ b/crates/wiggle/generate/src/config.rs @@ -12,22 +12,22 @@ use { #[derive(Debug, Clone)] pub struct Config { pub witx: WitxConf, - pub ctx: CtxConf, pub errors: ErrorConf, + pub async_: AsyncConf, } #[derive(Debug, Clone)] pub enum ConfigField { Witx(WitxConf), - Ctx(CtxConf), Error(ErrorConf), + Async(AsyncConf), } mod kw { syn::custom_keyword!(witx); syn::custom_keyword!(witx_literal); - syn::custom_keyword!(ctx); syn::custom_keyword!(errors); + syn::custom_keyword!(async_); } impl Parse for ConfigField { @@ -41,14 +41,14 @@ impl Parse for ConfigField { input.parse::()?; input.parse::()?; Ok(ConfigField::Witx(WitxConf::Literal(input.parse()?))) - } else if lookahead.peek(kw::ctx) { - input.parse::()?; - input.parse::()?; - Ok(ConfigField::Ctx(input.parse()?)) } else if lookahead.peek(kw::errors) { input.parse::()?; input.parse::()?; Ok(ConfigField::Error(input.parse()?)) + } else if lookahead.peek(kw::async_) { + input.parse::()?; + input.parse::()?; + Ok(ConfigField::Async(input.parse()?)) } else { Err(lookahead.error()) } @@ -58,8 +58,8 @@ impl Parse for ConfigField { impl Config { pub fn build(fields: impl Iterator, err_loc: Span) -> Result { let mut witx = None; - let mut ctx = None; let mut errors = None; + let mut async_ = None; for f in fields { match f { ConfigField::Witx(c) => { @@ -68,28 +68,26 @@ impl Config { } witx = Some(c); } - ConfigField::Ctx(c) => { - if ctx.is_some() { - return Err(Error::new(err_loc, "duplicate `ctx` field")); - } - ctx = Some(c); - } ConfigField::Error(c) => { if errors.is_some() { return Err(Error::new(err_loc, "duplicate `errors` field")); } errors = Some(c); } + ConfigField::Async(c) => { + if async_.is_some() { + return Err(Error::new(err_loc, "duplicate `async` field")); + } + async_ = Some(c); + } } } Ok(Config { witx: witx .take() .ok_or_else(|| Error::new(err_loc, "`witx` field required"))?, - ctx: ctx - .take() - .ok_or_else(|| Error::new(err_loc, "`ctx` field required"))?, errors: errors.take().unwrap_or_default(), + async_: async_.take().unwrap_or_default(), }) } @@ -216,19 +214,6 @@ impl Parse for Literal { } } -#[derive(Debug, Clone)] -pub struct CtxConf { - pub name: Ident, -} - -impl Parse for CtxConf { - fn parse(input: ParseStream) -> Result { - Ok(CtxConf { - name: input.parse()?, - }) - } -} - #[derive(Clone, Default, Debug)] /// Map from abi error type to rich error type pub struct ErrorConf(HashMap); @@ -294,3 +279,77 @@ impl Parse for ErrorConfField { }) } } + +#[derive(Clone, Default, Debug)] +/// Modules and funcs that should be async +pub struct AsyncConf(HashMap>); + +impl AsyncConf { + pub fn is_async(&self, module: &str, function: &str) -> bool { + self.0 + .get(module) + .and_then(|fs| fs.iter().find(|f| *f == function)) + .is_some() + } +} + +impl Parse for AsyncConf { + fn parse(input: ParseStream) -> Result { + let content; + let _ = braced!(content in input); + let items: Punctuated = + content.parse_terminated(Parse::parse)?; + let mut m: HashMap> = HashMap::new(); + use std::collections::hash_map::Entry; + for i in items { + let function_names = i + .function_names + .iter() + .map(|i| i.to_string()) + .collect::>(); + match m.entry(i.module_name.to_string()) { + Entry::Occupied(o) => o.into_mut().extend(function_names), + Entry::Vacant(v) => { + v.insert(function_names); + } + } + } + Ok(AsyncConf(m)) + } +} + +#[derive(Clone)] +pub struct AsyncConfField { + pub module_name: Ident, + pub function_names: Vec, + pub err_loc: Span, +} + +impl Parse for AsyncConfField { + fn parse(input: ParseStream) -> Result { + let err_loc = input.span(); + let module_name = input.parse::()?; + let _doublecolon: Token![::] = input.parse()?; + let lookahead = input.lookahead1(); + if lookahead.peek(syn::token::Brace) { + let content; + let _ = braced!(content in input); + let function_names: Punctuated = + content.parse_terminated(Parse::parse)?; + Ok(AsyncConfField { + module_name, + function_names: function_names.iter().cloned().collect(), + err_loc, + }) + } else if lookahead.peek(Ident) { + let name = input.parse()?; + Ok(AsyncConfField { + module_name, + function_names: vec![name], + err_loc, + }) + } else { + Err(lookahead.error()) + } + } +} diff --git a/crates/wiggle/generate/src/funcs.rs b/crates/wiggle/generate/src/funcs.rs index 187178e8b6..af15709d0b 100644 --- a/crates/wiggle/generate/src/funcs.rs +++ b/crates/wiggle/generate/src/funcs.rs @@ -1,4 +1,4 @@ -use crate::error_transform::ErrorTransform; +use crate::codegen_settings::CodegenSettings; use crate::lifetimes::anon_lifetime; use crate::module_trait::passed_by_reference; use crate::names::Names; @@ -12,11 +12,10 @@ pub fn define_func( names: &Names, module: &witx::Module, func: &witx::InterfaceFunc, - errxform: &ErrorTransform, + settings: &CodegenSettings, ) -> TokenStream { let rt = names.runtime_mod(); let ident = names.func(&func.name); - let ctx_type = names.ctx_type(); let (wasm_params, wasm_results) = func.wasm_signature(); let param_names = (0..wasm_params.len()) @@ -37,6 +36,7 @@ pub fn define_func( }; let mut body = TokenStream::new(); + let mut required_impls = vec![names.trait_name(&module.name)]; func.call_interface( &module.name, &mut Rust { @@ -48,16 +48,22 @@ pub fn define_func( names, module, funcname: func.name.as_str(), - errxform, + settings, + required_impls: &mut required_impls, }, ); + let asyncness = if settings.is_async(&module, &func) { + quote!(async) + } else { + quote!() + }; let mod_name = &module.name.as_str(); let func_name = &func.name.as_str(); quote! { #[allow(unreachable_code)] // deals with warnings in noreturn functions - pub fn #ident( - ctx: &#ctx_type, + pub #asyncness fn #ident( + ctx: &(impl #(#required_impls)+*), memory: &dyn #rt::GuestMemory, #(#abi_params),* ) -> Result<#abi_ret, #rt::Trap> { @@ -85,7 +91,16 @@ struct Rust<'a> { names: &'a Names, module: &'a witx::Module, funcname: &'a str, - errxform: &'a ErrorTransform, + settings: &'a CodegenSettings, + required_impls: &'a mut Vec, +} + +impl Rust<'_> { + fn required_impl(&mut self, i: Ident) { + if !self.required_impls.contains(&i) { + self.required_impls.push(i); + } + } } impl witx::Bindgen for Rust<'_> { @@ -205,8 +220,16 @@ impl witx::Bindgen for Rust<'_> { let trait_name = self.names.trait_name(&self.module.name); let ident = self.names.func(&func.name); + if self.settings.is_async(&self.module, &func) { + self.src.extend(quote! { + let ret = #trait_name::#ident(ctx, #(#args),*).await; + }) + } else { + self.src.extend(quote! { + let ret = #trait_name::#ident(ctx, #(#args),*); + }) + }; self.src.extend(quote! { - let ret = #trait_name::#ident(ctx, #(#args),*); #rt::tracing::event!( #rt::tracing::Level::TRACE, result = #rt::tracing::field::debug(&ret), @@ -226,9 +249,10 @@ impl witx::Bindgen for Rust<'_> { // enum, and *then* we lower to an i32. Instruction::EnumLower { ty } => { let val = operands.pop().unwrap(); - let val = match self.errxform.for_name(ty) { + let val = match self.settings.errors.for_name(ty) { Some(custom) => { let method = self.names.user_error_conversion_method(&custom); + self.required_impl(quote::format_ident!("UserErrorConversion")); quote!(UserErrorConversion::#method(ctx, #val)?) } None => val, diff --git a/crates/wiggle/generate/src/lib.rs b/crates/wiggle/generate/src/lib.rs index 2e172633df..90cca766d1 100644 --- a/crates/wiggle/generate/src/lib.rs +++ b/crates/wiggle/generate/src/lib.rs @@ -1,5 +1,5 @@ +mod codegen_settings; pub mod config; -mod error_transform; mod funcs; mod lifetimes; mod module_trait; @@ -11,14 +11,14 @@ use lifetimes::anon_lifetime; use proc_macro2::{Literal, TokenStream}; use quote::quote; +pub use codegen_settings::{CodegenSettings, UserErrorType}; pub use config::Config; -pub use error_transform::{ErrorTransform, UserErrorType}; pub use funcs::define_func; pub use module_trait::define_module_trait; pub use names::Names; pub use types::define_datatype; -pub fn generate(doc: &witx::Document, names: &Names, errs: &ErrorTransform) -> TokenStream { +pub fn generate(doc: &witx::Document, names: &Names, settings: &CodegenSettings) -> TokenStream { // TODO at some point config should grow more ability to configure name // overrides. let rt = names.runtime_mod(); @@ -49,7 +49,7 @@ pub fn generate(doc: &witx::Document, names: &Names, errs: &ErrorTransform) -> T } }; - let user_error_methods = errs.iter().map(|errtype| { + let user_error_methods = settings.errors.iter().map(|errtype| { let abi_typename = names.type_ref(&errtype.abi_type(), anon_lifetime()); let user_typename = errtype.typename(); let methodname = names.user_error_conversion_method(&errtype); @@ -64,12 +64,10 @@ pub fn generate(doc: &witx::Document, names: &Names, errs: &ErrorTransform) -> T let modname = names.module(&module.name); let fs = module .funcs() - .map(|f| define_func(&names, &module, &f, &errs)); - let modtrait = define_module_trait(&names, &module, &errs); - let ctx_type = names.ctx_type(); + .map(|f| define_func(&names, &module, &f, &settings)); + let modtrait = define_module_trait(&names, &module, &settings); quote!( pub mod #modname { - use super::#ctx_type; use super::types::*; #(#fs)* diff --git a/crates/wiggle/generate/src/module_trait.rs b/crates/wiggle/generate/src/module_trait.rs index 30dae00076..2db27490ab 100644 --- a/crates/wiggle/generate/src/module_trait.rs +++ b/crates/wiggle/generate/src/module_trait.rs @@ -1,7 +1,7 @@ use proc_macro2::TokenStream; use quote::quote; -use crate::error_transform::ErrorTransform; +use crate::codegen_settings::CodegenSettings; use crate::lifetimes::{anon_lifetime, LifetimeExt}; use crate::names::Names; use witx::Module; @@ -15,8 +15,9 @@ pub fn passed_by_reference(ty: &witx::Type) -> bool { } } -pub fn define_module_trait(names: &Names, m: &Module, errxform: &ErrorTransform) -> TokenStream { +pub fn define_module_trait(names: &Names, m: &Module, settings: &CodegenSettings) -> TokenStream { let traitname = names.trait_name(&m.name); + let rt = names.runtime_mod(); let traitmethods = m.funcs().map(|f| { // Check if we're returning an entity anotated with a lifetime, // in which case, we'll need to annotate the function itself, and @@ -43,7 +44,6 @@ pub fn define_module_trait(names: &Names, m: &Module, errxform: &ErrorTransform) quote!(#arg_name: #arg_type) }); - let rt = names.runtime_mod(); let result = match f.results.len() { 0 if f.noreturn => quote!(#rt::Trap), 0 => quote!(()), @@ -61,7 +61,7 @@ pub fn define_module_trait(names: &Names, m: &Module, errxform: &ErrorTransform) None => quote!(()), }; let err = match err { - Some(ty) => match errxform.for_abi_error(ty) { + Some(ty) => match settings.errors.for_abi_error(ty) { Some(custom) => { let tn = custom.typename(); quote!(super::#tn) @@ -75,13 +75,22 @@ pub fn define_module_trait(names: &Names, m: &Module, errxform: &ErrorTransform) _ => unimplemented!(), }; - if is_anonymous { - quote!(fn #funcname(&self, #(#args),*) -> #result; ) + let asyncness = if settings.is_async(&m, &f) { + quote!(async) } else { - quote!(fn #funcname<#lifetime>(&self, #(#args),*) -> #result;) + quote!() + }; + + if is_anonymous { + quote!(#asyncness fn #funcname(&self, #(#args),*) -> #result; ) + } else { + quote!(#asyncness fn #funcname<#lifetime>(&self, #(#args),*) -> #result;) } }); + quote! { + use #rt::async_trait; + #[async_trait(?Send)] pub trait #traitname { #(#traitmethods)* } diff --git a/crates/wiggle/generate/src/names.rs b/crates/wiggle/generate/src/names.rs index b9eb9e4752..ae3b6bc74b 100644 --- a/crates/wiggle/generate/src/names.rs +++ b/crates/wiggle/generate/src/names.rs @@ -7,20 +7,12 @@ use witx::{BuiltinType, Id, Type, TypeRef, WasmType}; use crate::{lifetimes::LifetimeExt, UserErrorType}; pub struct Names { - ctx_type: Ident, runtime_mod: TokenStream, } impl Names { - pub fn new(ctx_type: &Ident, runtime_mod: TokenStream) -> Names { - Names { - ctx_type: ctx_type.clone(), - runtime_mod, - } - } - - pub fn ctx_type(&self) -> Ident { - self.ctx_type.clone() + pub fn new(runtime_mod: TokenStream) -> Names { + Names { runtime_mod } } pub fn runtime_mod(&self) -> TokenStream { diff --git a/crates/wiggle/macro/src/lib.rs b/crates/wiggle/macro/src/lib.rs index a3c10afb86..e7ea3455f0 100644 --- a/crates/wiggle/macro/src/lib.rs +++ b/crates/wiggle/macro/src/lib.rs @@ -15,9 +15,11 @@ use syn::parse_macro_input; /// Rust-idiomatic snake\_case. /// /// * For each `@interface func` defined in a witx module, an abi-level -/// function is generated which takes ABI-level arguments, along with a -/// "context" struct (whose type is given by the `ctx` field in the -/// macro invocation) and a `GuestMemory` implementation. +/// function is generated which takes ABI-level arguments, along with +/// a ref that impls the module trait, and a `GuestMemory` implementation. +/// Users typically won't use these abi-level functions: The `wasmtime-wiggle` +/// and `lucet-wiggle` crates adapt these to work with a particular WebAssembly +/// engine. /// /// * A public "module trait" is defined (called the module name, in /// SnakeCase) which has a `&self` method for each function in the @@ -27,57 +29,94 @@ use syn::parse_macro_input; /// Arguments are provided using Rust struct value syntax. /// /// * `witx` takes a list of string literal paths. Paths are relative to the -/// CARGO_MANIFEST_DIR of the crate where the macro is invoked. -/// * `ctx` takes a type name. This type must implement all of the module -/// traits +/// CARGO_MANIFEST_DIR of the crate where the macro is invoked. Alternatively, +/// `witx_literal` takes a string containing a complete witx document. +/// * Optional: `errors` takes a mapping of witx identifiers to types, e.g +/// `{ errno => YourErrnoType }`. This allows you to use the `UserErrorConversion` +/// trait to map these rich errors into the flat witx type, or to terminate +/// WebAssembly execution by trapping. +/// * Optional: `async_` takes a set of witx modules and functions which are +/// made Rust `async` functions in the module trait. /// /// ## Example /// /// ``` -/// use wiggle::{GuestPtr, GuestErrorType}; -/// -/// /// The test witx file `arrays.witx` lives in the test directory. For a -/// /// full-fledged example with runtime tests, see `tests/arrays.rs` and -/// /// the rest of the files in that directory. +/// use wiggle::GuestPtr; /// wiggle::from_witx!({ -/// witx: ["../tests/arrays.witx"], -/// ctx: YourCtxType, +/// witx_literal: " +/// (typename $errno +/// (enum (@witx tag u32) +/// $ok +/// $invalid_arg +/// $io +/// $overflow)) +/// (typename $alias_to_float f32) +/// (module $example +/// (@interface func (export \"int_float_args\") +/// (param $an_int u32) +/// (param $some_floats (list f32)) +/// (result $r (expected (error $errno)))) +/// (@interface func (export \"double_int_return_float\") +/// (param $an_int u32) +/// (result $r (expected $alias_to_float (error $errno))))) +/// ", +/// errors: { errno => YourRichError }, +/// async_: { example::double_int_return_float }, /// }); /// -/// /// The `ctx` type for this wiggle invocation. +/// /// Witx generates a set of traits, which the user must impl on a +/// /// type they define. We call this the ctx type. It stores any context +/// /// these functions need to execute. /// pub struct YourCtxType {} /// -/// /// `arrays.witx` contains one module called `arrays`. So, we must -/// /// implement this one method trait for our ctx type: -/// impl arrays::Arrays for YourCtxType { +/// /// Witx provides a hook to translate "rich" (arbitrary Rust type) errors +/// /// into the flat error enums used at the WebAssembly interface. You will +/// /// need to impl the `types::UserErrorConversion` trait to provide a translation +/// /// from this rich type. +/// #[derive(Debug)] +/// pub enum YourRichError { +/// InvalidArg(String), +/// Io(std::io::Error), +/// Overflow, +/// Trap(String), +/// } +/// +/// /// The above witx text contains one module called `$example`. So, we must +/// /// implement this one method trait for our ctx type. +/// #[wiggle::async_trait(?Send)] +/// /// We specified in the `async_` field that `example::double_int_return_float` +/// /// is an asynchronous method. Therefore, we use the `async_trait` proc macro +/// /// (re-exported by wiggle from the crate of the same name) to define this +/// /// trait, so that `double_int_return_float` can be an `async fn`. +/// impl example::Example for YourCtxType { /// /// The arrays module has two methods, shown here. /// /// Note that the `GuestPtr` type comes from `wiggle`, /// /// whereas the witx-defined types like `Excuse` and `Errno` come /// /// from the `pub mod types` emitted by the `wiggle::from_witx!` /// /// invocation above. -/// fn reduce_excuses(&self, _a: &GuestPtr<[GuestPtr]>) -/// -> Result { +/// fn int_float_args(&self, _int: u32, _floats: &GuestPtr<[f32]>) +/// -> Result<(), YourRichError> { /// unimplemented!() /// } -/// fn populate_excuses(&self, _a: &GuestPtr<[GuestPtr]>) -/// -> Result<(), types::Errno> { -/// unimplemented!() +/// async fn double_int_return_float(&self, int: u32) +/// -> Result { +/// Ok(int.checked_mul(2).ok_or(YourRichError::Overflow)? as f32) /// } /// } /// -/// /// For all types used in the `Error` position of a `Result` in the module -/// /// traits, you must implement `GuestErrorType` which tells wiggle-generated +/// /// For all types used in the `error` an `expected` in the witx document, +/// /// you must implement `GuestErrorType` which tells wiggle-generated /// /// code what value to return when the method returns Ok(...). -/// impl GuestErrorType for types::Errno { +/// impl wiggle::GuestErrorType for types::Errno { /// fn success() -> Self { /// unimplemented!() /// } /// } /// /// /// The `types::GuestErrorConversion` trait is also generated with a method for -/// /// each type used in the `Error` position. This trait allows wiggle-generated +/// /// each type used in the `error` position. This trait allows wiggle-generated /// /// code to convert a `wiggle::GuestError` into the right error type. The trait -/// /// must be implemented for the user's `ctx` type. +/// /// must be implemented for the user's ctx type. /// /// impl types::GuestErrorConversion for YourCtxType { /// fn into_errno(&self, _e: wiggle::GuestError) -> types::Errno { @@ -85,6 +124,26 @@ use syn::parse_macro_input; /// } /// } /// +/// /// If you specify a `error` mapping to the macro, you must implement the +/// /// `types::UserErrorConversion` for your ctx type as well. This trait gives +/// /// you an opportunity to store or log your rich error type, while returning +/// /// a basic witx enum to the WebAssembly caller. It also gives you the ability +/// /// to terminate WebAssembly execution by trapping. +/// +/// impl types::UserErrorConversion for YourCtxType { +/// fn errno_from_your_rich_error(&self, e: YourRichError) +/// -> Result +/// { +/// println!("Rich error: {:?}", e); +/// match e { +/// YourRichError::InvalidArg{..} => Ok(types::Errno::InvalidArg), +/// YourRichError::Io{..} => Ok(types::Errno::Io), +/// YourRichError::Overflow => Ok(types::Errno::Overflow), +/// YourRichError::Trap(s) => Err(wiggle::Trap::String(s)), +/// } +/// } +/// } +/// /// # fn main() { println!("this fools doc tests into compiling the above outside a function body") /// # } /// ``` @@ -93,10 +152,11 @@ pub fn from_witx(args: TokenStream) -> TokenStream { let config = parse_macro_input!(args as wiggle_generate::Config); let doc = config.load_document(); - let names = wiggle_generate::Names::new(&config.ctx.name, quote!(wiggle)); + let names = wiggle_generate::Names::new(quote!(wiggle)); - let error_transform = wiggle_generate::ErrorTransform::new(&config.errors, &doc) - .expect("validating error transform"); + let error_transform = + wiggle_generate::CodegenSettings::new(&config.errors, &config.async_, &doc) + .expect("validating codegen settings"); let code = wiggle_generate::generate(&doc, &names, &error_transform); let metadata = if cfg!(feature = "wiggle_metadata") { diff --git a/crates/wiggle/src/lib.rs b/crates/wiggle/src/lib.rs index 24a4ca7c40..76d3f0fcbb 100644 --- a/crates/wiggle/src/lib.rs +++ b/crates/wiggle/src/lib.rs @@ -6,8 +6,10 @@ use std::slice; use std::str; use std::sync::Arc; -pub use bitflags; pub use wiggle_macro::from_witx; +// re-exports so users of wiggle don't need to track the dependency: +pub use async_trait::async_trait; +pub use bitflags; #[cfg(feature = "wiggle_metadata")] pub use witx;