diff --git a/crates/wiggle/generate/src/config.rs b/crates/wiggle/generate/src/config.rs index 5f5fcb1472..6eecf28c96 100644 --- a/crates/wiggle/generate/src/config.rs +++ b/crates/wiggle/generate/src/config.rs @@ -1,6 +1,7 @@ use { proc_macro2::Span, std::{ + collections::HashMap, iter::FromIterator, path::{Path, PathBuf}, }, @@ -16,12 +17,14 @@ use { pub struct Config { pub witx: WitxConf, pub ctx: CtxConf, + pub errors: ErrorConf, } #[derive(Debug, Clone)] pub enum ConfigField { Witx(WitxConf), Ctx(CtxConf), + Error(ErrorConf), } impl ConfigField { @@ -30,7 +33,8 @@ impl ConfigField { "witx" => Ok(ConfigField::Witx(WitxConf::Paths(value.parse()?))), "witx_literal" => Ok(ConfigField::Witx(WitxConf::Literal(value.parse()?))), "ctx" => Ok(ConfigField::Ctx(value.parse()?)), - _ => Err(Error::new(err_loc, "expected `witx` or `ctx`")), + "errors" => Ok(ConfigField::Error(value.parse()?)), + _ => Err(Error::new(err_loc, "expected `witx`, `ctx`, or `errors`")), } } } @@ -47,14 +51,27 @@ impl Config { pub fn build(fields: impl Iterator, err_loc: Span) -> Result { let mut witx = None; let mut ctx = None; + let mut errors = None; for f in fields { match f { ConfigField::Witx(c) => { + if witx.is_some() { + return Err(Error::new(err_loc, "duplicate `witx` field")); + } witx = Some(c); } ConfigField::Ctx(c) => { + if ctx.is_some() { + return Err(Error::new(err_loc, "duplicate `ctx` field")); + } ctx = Some(c); } + ConfigField::Error(c) => { + if errors.is_some() { + return Err(Error::new(err_loc, "duplicate `errors` field")); + } + errors = Some(c); + } } } Ok(Config { @@ -64,6 +81,7 @@ impl Config { ctx: ctx .take() .ok_or_else(|| Error::new(err_loc, "`ctx` field required"))?, + errors: errors.take().unwrap_or_default(), }) } @@ -206,3 +224,69 @@ impl Parse for CtxConf { }) } } + +#[derive(Clone, Default, Debug)] +/// Map from abi error type to rich error type +pub struct ErrorConf(HashMap); + +impl ErrorConf { + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } +} + +impl Parse for ErrorConf { + fn parse(input: ParseStream) -> Result { + let content; + let _ = braced!(content in input); + let items: Punctuated = + content.parse_terminated(Parse::parse)?; + let mut m = HashMap::new(); + for i in items { + match m.insert(i.abi_error.clone(), i.clone()) { + None => {} + Some(prev_def) => { + return Err(Error::new( + i.err_loc, + format!( + "duplicate definition of rich error type for {:?}: previously defined at {:?}", + i.abi_error, prev_def.err_loc, + ), + )) + } + } + } + Ok(ErrorConf(m)) + } +} + +#[derive(Clone)] +pub struct ErrorConfField { + pub abi_error: Ident, + pub rich_error: syn::Path, + pub err_loc: Span, +} + +impl std::fmt::Debug for ErrorConfField { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ErrorConfField") + .field("abi_error", &self.abi_error) + .field("rich_error", &"(...)") + .field("err_loc", &self.err_loc) + .finish() + } +} + +impl Parse for ErrorConfField { + fn parse(input: ParseStream) -> Result { + let err_loc = input.span(); + let abi_error = input.parse::()?; + let _arrow: Token![=>] = input.parse()?; + let rich_error = input.parse::()?; + Ok(ErrorConfField { + abi_error, + rich_error, + err_loc, + }) + } +} diff --git a/crates/wiggle/generate/src/error_transform.rs b/crates/wiggle/generate/src/error_transform.rs new file mode 100644 index 0000000000..80cb56f925 --- /dev/null +++ b/crates/wiggle/generate/src/error_transform.rs @@ -0,0 +1,72 @@ +use crate::config::ErrorConf; +use anyhow::{anyhow, Error}; +use proc_macro2::TokenStream; +use quote::quote; +use std::collections::HashMap; +use std::rc::Rc; +use witx::{Document, Id, NamedType, TypeRef}; + +pub struct ErrorTransform { + m: Vec, +} + +impl ErrorTransform { + pub fn new(conf: &ErrorConf, doc: &Document) -> Result { + let mut richtype_identifiers = HashMap::new(); + let m = conf.iter().map(|(ident, field)| + if let Some(abi_type) = doc.typename(&Id::new(ident.to_string())) { + if let Some(ident) = field.rich_error.get_ident() { + if let Some(prior_def) = richtype_identifiers.insert(ident.clone(), field.err_loc.clone()) + { + return Err(anyhow!( + "duplicate rich type identifier of {:?} not allowed. prior definition at {:?}", + ident, prior_def + )); + } + Ok(UserErrorType { + abi_type, + rich_type: field.rich_error.clone(), + method_fragment: ident.to_string() + }) + } else { + return Err(anyhow!( + "rich error type must be identifier for now - TODO add ability to provide a corresponding identifier: {:?}", + field.err_loc + )) + } + } + else { Err(anyhow!("No witx typename \"{}\" found", ident.to_string())) } + ).collect::, Error>>()?; + Ok(Self { m }) + } + + pub fn iter(&self) -> impl Iterator { + self.m.iter() + } + + pub fn for_abi_error(&self, tref: &TypeRef) -> Option<&UserErrorType> { + match tref { + TypeRef::Name(nt) => self.m.iter().find(|u| u.abi_type.name == nt.name), + TypeRef::Value { .. } => None, + } + } +} + +pub struct UserErrorType { + abi_type: Rc, + rich_type: syn::Path, + method_fragment: String, +} + +impl UserErrorType { + pub fn abi_type(&self) -> TypeRef { + TypeRef::Name(self.abi_type.clone()) + } + pub fn typename(&self) -> TokenStream { + let t = &self.rich_type; + quote!(#t) + } + pub fn method_fragment(&self) -> &str { + &self.method_fragment + } +} diff --git a/crates/wiggle/generate/src/funcs.rs b/crates/wiggle/generate/src/funcs.rs index 690ced5512..2c8cf2b350 100644 --- a/crates/wiggle/generate/src/funcs.rs +++ b/crates/wiggle/generate/src/funcs.rs @@ -1,6 +1,7 @@ use proc_macro2::TokenStream; use quote::quote; +use crate::error_transform::ErrorTransform; use crate::lifetimes::anon_lifetime; use crate::module_trait::passed_by_reference; use crate::names::Names; @@ -9,6 +10,7 @@ pub fn define_func( names: &Names, func: &witx::InterfaceFunc, trait_name: TokenStream, + errxform: &ErrorTransform, ) -> TokenStream { let funcname = func.name.as_str(); @@ -43,16 +45,23 @@ pub fn define_func( quote!(()) }; - let err_type = coretype.ret.map(|ret| ret.param.tref); - let ret_err = err_type - .clone() - .map(|_res| { + let err_type = coretype.ret.clone().map(|ret| ret.param.tref); + let ret_err = coretype + .ret + .map(|ret| { + let name = ret.param.name.as_str(); + let conversion = if let Some(user_err) = errxform.for_abi_error(&ret.param.tref) { + let method = names.user_error_conversion_method(&user_err); + quote!(#abi_ret::from(UserErrorConversion::#method(ctx, e))) + } else { + quote!(#abi_ret::from(e)) + }; quote! { #[cfg(feature = "trace_log")] { - log::trace!(" | errno={}", e); + log::trace!(" | {}={:?}", #name, e); } - return #abi_ret::from(e); + return #conversion; } }) .unwrap_or_else(|| quote!(())); diff --git a/crates/wiggle/generate/src/lib.rs b/crates/wiggle/generate/src/lib.rs index 699d33e216..bd6d5d9e40 100644 --- a/crates/wiggle/generate/src/lib.rs +++ b/crates/wiggle/generate/src/lib.rs @@ -1,4 +1,5 @@ pub mod config; +mod error_transform; mod funcs; mod lifetimes; mod module_trait; @@ -11,12 +12,13 @@ use quote::quote; use lifetimes::anon_lifetime; pub use config::Config; +pub use error_transform::{ErrorTransform, UserErrorType}; pub use funcs::define_func; pub use module_trait::define_module_trait; pub use names::Names; pub use types::define_datatype; -pub fn generate(doc: &witx::Document, names: &Names) -> TokenStream { +pub fn generate(doc: &witx::Document, names: &Names, errs: &ErrorTransform) -> TokenStream { // TODO at some point config should grow more ability to configure name // overrides. let rt = names.runtime_mod(); @@ -34,13 +36,24 @@ pub fn generate(doc: &witx::Document, names: &Names) -> TokenStream { } }; + let user_error_methods = errs.iter().map(|errtype| { + let abi_typename = names.type_ref(&errtype.abi_type(), anon_lifetime()); + let user_typename = errtype.typename(); + let methodname = names.user_error_conversion_method(&errtype); + quote!(fn #methodname(&self, e: super::#user_typename) -> #abi_typename;) + }); + let user_error_conversion = quote! { + pub trait UserErrorConversion { + #(#user_error_methods)* + } + }; let modules = doc.modules().map(|module| { let modname = names.module(&module.name); let trait_name = names.trait_name(&module.name); let fs = module .funcs() - .map(|f| define_func(&names, &f, quote!(#trait_name))); - let modtrait = define_module_trait(&names, &module); + .map(|f| define_func(&names, &f, quote!(#trait_name), &errs)); + let modtrait = define_module_trait(&names, &module, &errs); let ctx_type = names.ctx_type(); quote!( pub mod #modname { @@ -57,6 +70,7 @@ pub fn generate(doc: &witx::Document, names: &Names) -> TokenStream { pub mod types { #(#types)* #guest_error_conversion + #user_error_conversion } #(#modules)* ) diff --git a/crates/wiggle/generate/src/module_trait.rs b/crates/wiggle/generate/src/module_trait.rs index b09f237028..c73cd052bb 100644 --- a/crates/wiggle/generate/src/module_trait.rs +++ b/crates/wiggle/generate/src/module_trait.rs @@ -1,6 +1,7 @@ use proc_macro2::TokenStream; use quote::quote; +use crate::error_transform::ErrorTransform; use crate::lifetimes::{anon_lifetime, LifetimeExt}; use crate::names::Names; use witx::Module; @@ -20,7 +21,7 @@ pub fn passed_by_reference(ty: &witx::Type) -> bool { } } -pub fn define_module_trait(names: &Names, m: &Module) -> TokenStream { +pub fn define_module_trait(names: &Names, m: &Module, errxform: &ErrorTransform) -> TokenStream { let traitname = names.trait_name(&m.name); let traitmethods = m.funcs().map(|f| { // Check if we're returning an entity anotated with a lifetime, @@ -55,7 +56,14 @@ pub fn define_module_trait(names: &Names, m: &Module) -> TokenStream { let err = f .results .get(0) - .map(|err_result| names.type_ref(&err_result.tref, lifetime.clone())) + .map(|err_result| { + if let Some(custom_err) = errxform.for_abi_error(&err_result.tref) { + let tn = custom_err.typename(); + quote!(super::#tn) + } else { + names.type_ref(&err_result.tref, lifetime.clone()) + } + }) .unwrap_or(quote!(())); if is_anonymous { diff --git a/crates/wiggle/generate/src/names.rs b/crates/wiggle/generate/src/names.rs index 84cbf5df70..f6266ec41f 100644 --- a/crates/wiggle/generate/src/names.rs +++ b/crates/wiggle/generate/src/names.rs @@ -1,10 +1,11 @@ -use crate::lifetimes::LifetimeExt; use escaping::{escape_id, handle_2big_enum_variant, NamingConvention}; use heck::{ShoutySnakeCase, SnakeCase}; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; use witx::{AtomType, BuiltinType, Id, Type, TypeRef}; +use crate::{lifetimes::LifetimeExt, UserErrorType}; + pub struct Names { ctx_type: Ident, runtime_mod: TokenStream, @@ -161,29 +162,47 @@ impl Names { format_ident!("{}_len", id.as_str().to_snake_case()) } - pub fn guest_error_conversion_method(&self, tref: &TypeRef) -> Ident { + fn builtin_name(b: &BuiltinType) -> &'static str { + match b { + BuiltinType::String => "string", + BuiltinType::U8 => "u8", + BuiltinType::U16 => "u16", + BuiltinType::U32 => "u32", + BuiltinType::U64 => "u64", + BuiltinType::S8 => "i8", + BuiltinType::S16 => "i16", + BuiltinType::S32 => "i32", + BuiltinType::S64 => "i64", + BuiltinType::F32 => "f32", + BuiltinType::F64 => "f64", + BuiltinType::Char8 => "char8", + BuiltinType::USize => "usize", + } + } + + fn snake_typename(tref: &TypeRef) -> String { match tref { - TypeRef::Name(nt) => format_ident!("into_{}", nt.name.as_str().to_snake_case()), + TypeRef::Name(nt) => nt.name.as_str().to_snake_case(), TypeRef::Value(ty) => match &**ty { - Type::Builtin(b) => match b { - BuiltinType::String => unreachable!("error type must be atom"), - BuiltinType::U8 => format_ident!("into_u8"), - BuiltinType::U16 => format_ident!("into_u16"), - BuiltinType::U32 => format_ident!("into_u32"), - BuiltinType::U64 => format_ident!("into_u64"), - BuiltinType::S8 => format_ident!("into_i8"), - BuiltinType::S16 => format_ident!("into_i16"), - BuiltinType::S32 => format_ident!("into_i32"), - BuiltinType::S64 => format_ident!("into_i64"), - BuiltinType::F32 => format_ident!("into_f32"), - BuiltinType::F64 => format_ident!("into_f64"), - BuiltinType::Char8 => format_ident!("into_char8"), - BuiltinType::USize => format_ident!("into_usize"), - }, - _ => panic!("unexpected anonymous error type: {:?}", ty), + Type::Builtin(b) => Self::builtin_name(&b).to_owned(), + _ => panic!("unexpected anonymous type: {:?}", ty), }, } } + + pub fn guest_error_conversion_method(&self, tref: &TypeRef) -> Ident { + let suffix = Self::snake_typename(tref); + format_ident!("into_{}", suffix) + } + + pub fn user_error_conversion_method(&self, user_type: &UserErrorType) -> Ident { + let abi_type = Self::snake_typename(&user_type.abi_type()); + format_ident!( + "{}_from_{}", + abi_type, + user_type.method_fragment().to_snake_case() + ) + } } /// Identifier escaping utilities. diff --git a/crates/wiggle/macro/src/lib.rs b/crates/wiggle/macro/src/lib.rs index 7587a42451..ed4a3dc6f0 100644 --- a/crates/wiggle/macro/src/lib.rs +++ b/crates/wiggle/macro/src/lib.rs @@ -98,7 +98,10 @@ pub fn from_witx(args: TokenStream) -> TokenStream { let doc = config.load_document(); let names = wiggle_generate::Names::new(&config.ctx.name, quote!(wiggle)); - let code = wiggle_generate::generate(&doc, &names); + let error_transform = wiggle_generate::ErrorTransform::new(&config.errors, &doc) + .expect("validating error transform"); + + let code = wiggle_generate::generate(&doc, &names, &error_transform); let metadata = if cfg!(feature = "wiggle_metadata") { wiggle_generate::generate_metadata(&doc, &names) } else { diff --git a/crates/wiggle/test-helpers/src/lib.rs b/crates/wiggle/test-helpers/src/lib.rs index 40d468a435..5d03d27088 100644 --- a/crates/wiggle/test-helpers/src/lib.rs +++ b/crates/wiggle/test-helpers/src/lib.rs @@ -306,6 +306,7 @@ use wiggle::GuestError; // on the test as well. pub struct WasiCtx<'a> { pub guest_errors: RefCell>, + pub log: RefCell>, lifetime: marker::PhantomData<&'a ()>, } @@ -313,6 +314,7 @@ impl<'a> WasiCtx<'a> { pub fn new() -> Self { Self { guest_errors: RefCell::new(vec![]), + log: RefCell::new(vec![]), lifetime: marker::PhantomData, } } @@ -333,6 +335,7 @@ macro_rules! impl_errno { impl<'a> $convert for WasiCtx<'a> { fn into_errno(&self, e: wiggle::GuestError) -> $errno { eprintln!("GuestError: {:?}", e); + self.guest_errors.borrow_mut().push(e); <$errno>::InvalidArg } } diff --git a/crates/wiggle/tests/errors.rs b/crates/wiggle/tests/errors.rs new file mode 100644 index 0000000000..522982c845 --- /dev/null +++ b/crates/wiggle/tests/errors.rs @@ -0,0 +1,179 @@ +/// Execute the wiggle guest conversion code to exercise it +mod convert_just_errno { + use wiggle_test::{impl_errno, HostMemory, WasiCtx}; + + /// The `errors` argument to the wiggle gives us a hook to map a rich error + /// type like this one (typical of wiggle use cases in wasi-common and beyond) + /// down to the flat error enums that witx can specify. + #[derive(Debug, thiserror::Error)] + pub enum RichError { + #[error("Invalid argument: {0}")] + InvalidArg(String), + #[error("Won't cross picket line: {0}")] + PicketLine(String), + } + + // Define an errno with variants corresponding to RichError. Use it in a + // trivial function. + wiggle::from_witx!({ + witx_literal: " +(typename $errno (enum u8 $ok $invalid_arg $picket_line)) +(module $one_error_conversion + (@interface func (export \"foo\") + (param $strike u32) + (result $err $errno))) + ", + ctx: WasiCtx, + errors: { errno => RichError }, + }); + + // The impl of GuestErrorConversion works just like in every other test where + // we have a single error type with witx `$errno` with the success called `$ok` + impl_errno!(types::Errno, types::GuestErrorConversion); + + /// When the `errors` mapping in witx is non-empty, we need to impl the + /// types::UserErrorConversion trait that wiggle generates from that mapping. + impl<'a> types::UserErrorConversion for WasiCtx<'a> { + fn errno_from_rich_error(&self, e: RichError) -> types::Errno { + // WasiCtx can collect a Vec log so we can test this. We're + // logging the Display impl that `thiserror::Error` provides us. + self.log.borrow_mut().push(e.to_string()); + // Then do the trivial mapping down to the flat enum. + match e { + RichError::InvalidArg { .. } => types::Errno::InvalidArg, + RichError::PicketLine { .. } => types::Errno::PicketLine, + } + } + } + + impl<'a> one_error_conversion::OneErrorConversion for WasiCtx<'a> { + fn foo(&self, strike: u32) -> Result<(), RichError> { + // We use the argument to this function to exercise all of the + // possible error cases we could hit here + match strike { + 0 => Ok(()), + 1 => Err(RichError::PicketLine(format!("I'm not a scab"))), + _ => Err(RichError::InvalidArg(format!("out-of-bounds: {}", strike))), + } + } + } + + #[test] + fn one_error_conversion_test() { + let ctx = WasiCtx::new(); + let host_memory = HostMemory::new(); + + // Exercise each of the branches in `foo`. + // Start with the success case: + let r0 = one_error_conversion::foo(&ctx, &host_memory, 0); + assert_eq!( + r0, + i32::from(types::Errno::Ok), + "Expected return value for strike=0" + ); + assert!(ctx.log.borrow().is_empty(), "No error log for strike=0"); + + // First error case: + let r1 = one_error_conversion::foo(&ctx, &host_memory, 1); + assert_eq!( + r1, + i32::from(types::Errno::PicketLine), + "Expected return value for strike=1" + ); + assert_eq!( + ctx.log.borrow_mut().pop().expect("one log entry"), + "Won't cross picket line: I'm not a scab", + "Expected log entry for strike=1", + ); + + // Second error case: + let r2 = one_error_conversion::foo(&ctx, &host_memory, 2); + assert_eq!( + r2, + i32::from(types::Errno::InvalidArg), + "Expected return value for strike=2" + ); + assert_eq!( + ctx.log.borrow_mut().pop().expect("one log entry"), + "Invalid argument: out-of-bounds: 2", + "Expected log entry for strike=2", + ); + } +} + +/// Type-check the wiggle guest conversion code against a more complex case where +/// we use two distinct error types. +mod convert_multiple_error_types { + pub use super::convert_just_errno::RichError; + use wiggle_test::WasiCtx; + + /// Test that we can map multiple types of errors. + #[derive(Debug, thiserror::Error)] + #[allow(dead_code)] + pub enum AnotherRichError { + #[error("I've had this many cups of coffee and can't even think straight: {0}")] + TooMuchCoffee(usize), + } + + // Just like the other error, except that we have a second errno type: + // trivial function. + wiggle::from_witx!({ + witx_literal: " +(typename $errno (enum u8 $ok $invalid_arg $picket_line)) +(typename $errno2 (enum u8 $ok $too_much_coffee)) +(module $two_error_conversions + (@interface func (export \"foo\") + (param $strike u32) + (result $err $errno)) + (@interface func (export \"bar\") + (param $drink u32) + (result $err $errno2))) + ", + ctx: WasiCtx, + errors: { errno => RichError, errno2 => AnotherRichError }, + }); + + // Can't use the impl_errno! macro as usual here because the conversion + // trait ends up having two methods. + // We aren't going to execute this code, so the bodies are elided. + impl<'a> types::GuestErrorConversion for WasiCtx<'a> { + fn into_errno(&self, _e: wiggle::GuestError) -> types::Errno { + unimplemented!() + } + fn into_errno2(&self, _e: wiggle::GuestError) -> types::Errno2 { + unimplemented!() + } + } + impl wiggle::GuestErrorType for types::Errno { + fn success() -> types::Errno { + ::Ok + } + } + impl wiggle::GuestErrorType for types::Errno2 { + fn success() -> types::Errno2 { + ::Ok + } + } + + // The UserErrorConversion trait will also have two methods for this test. They correspond to + // each member of the `errors` mapping. + // Bodies elided. + impl<'a> types::UserErrorConversion for WasiCtx<'a> { + fn errno_from_rich_error(&self, _e: RichError) -> types::Errno { + unimplemented!() + } + fn errno2_from_another_rich_error(&self, _e: AnotherRichError) -> types::Errno2 { + unimplemented!() + } + } + + // And here's the witx module trait impl, bodies elided + impl<'a> two_error_conversions::TwoErrorConversions for WasiCtx<'a> { + fn foo(&self, _: u32) -> Result<(), RichError> { + unimplemented!() + } + fn bar(&self, _: u32) -> Result<(), AnotherRichError> { + unimplemented!() + } + } +}