diff --git a/crates/wiggle/generate/src/config.rs b/crates/wiggle/generate/src/config.rs index 0d4ada6402..5b09f9dee6 100644 --- a/crates/wiggle/generate/src/config.rs +++ b/crates/wiggle/generate/src/config.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::path::{Path, PathBuf}; use proc_macro2::Span; @@ -12,12 +13,14 @@ use syn::{ pub struct Config { pub witx: WitxConf, pub ctx: CtxConf, + pub errors: ErrorConf, } #[derive(Debug, Clone)] pub enum ConfigField { Witx(WitxConf), Ctx(CtxConf), + Error(ErrorConf), } impl ConfigField { @@ -25,7 +28,8 @@ impl ConfigField { match ident { "witx" => Ok(ConfigField::Witx(value.parse()?)), "ctx" => Ok(ConfigField::Ctx(value.parse()?)), - _ => Err(Error::new(err_loc, "expected `witx` or `ctx`")), + "errors" => Ok(ConfigField::Error(value.parse()?)), + _ => Err(Error::new(err_loc, "expected `witx`, `ctx`, or `errors`")), } } } @@ -42,14 +46,27 @@ impl Config { pub fn build(fields: impl Iterator, err_loc: Span) -> Result { let mut witx = None; let mut ctx = None; + let mut errors = None; for f in fields { match f { ConfigField::Witx(c) => { + if witx.is_some() { + return Err(Error::new(err_loc, "duplicate `witx` field")); + } witx = Some(c); } ConfigField::Ctx(c) => { + if ctx.is_some() { + return Err(Error::new(err_loc, "duplicate `ctx` field")); + } ctx = Some(c); } + ConfigField::Error(c) => { + if errors.is_some() { + return Err(Error::new(err_loc, "duplicate `errors` field")); + } + errors = Some(c); + } } } Ok(Config { @@ -59,6 +76,7 @@ impl Config { ctx: ctx .take() .ok_or_else(|| Error::new(err_loc, "`ctx` field required"))?, + errors: errors.take().unwrap_or_default(), }) } } @@ -113,3 +131,69 @@ impl Parse for CtxConf { }) } } + +#[derive(Clone, Default, Debug)] +/// Map from abi error type to rich error type +pub struct ErrorConf(HashMap); + +impl ErrorConf { + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } +} + +impl Parse for ErrorConf { + fn parse(input: ParseStream) -> Result { + let content; + let _ = braced!(content in input); + let items: Punctuated = + content.parse_terminated(Parse::parse)?; + let mut m = HashMap::new(); + for i in items { + match m.insert(i.abi_error.clone(), i.clone()) { + None => {} + Some(prev_def) => { + return Err(Error::new( + i.err_loc, + format!( + "duplicate definition of rich error type for {:?}: previously defined at {:?}", + i.abi_error, prev_def.err_loc, + ), + )) + } + } + } + Ok(ErrorConf(m)) + } +} + +#[derive(Clone)] +pub struct ErrorConfField { + pub abi_error: Ident, + pub rich_error: syn::Path, + pub err_loc: Span, +} + +impl std::fmt::Debug for ErrorConfField { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ErrorConfField") + .field("abi_error", &self.abi_error) + .field("rich_error", &"(...)") + .field("err_loc", &self.err_loc) + .finish() + } +} + +impl Parse for ErrorConfField { + fn parse(input: ParseStream) -> Result { + let err_loc = input.span(); + let abi_error = input.parse::()?; + let _arrow: Token![=>] = input.parse()?; + let rich_error = input.parse::()?; + Ok(ErrorConfField { + abi_error, + rich_error, + err_loc, + }) + } +} diff --git a/crates/wiggle/generate/src/error_transform.rs b/crates/wiggle/generate/src/error_transform.rs new file mode 100644 index 0000000000..80cb56f925 --- /dev/null +++ b/crates/wiggle/generate/src/error_transform.rs @@ -0,0 +1,72 @@ +use crate::config::ErrorConf; +use anyhow::{anyhow, Error}; +use proc_macro2::TokenStream; +use quote::quote; +use std::collections::HashMap; +use std::rc::Rc; +use witx::{Document, Id, NamedType, TypeRef}; + +pub struct ErrorTransform { + m: Vec, +} + +impl ErrorTransform { + pub fn new(conf: &ErrorConf, doc: &Document) -> Result { + let mut richtype_identifiers = HashMap::new(); + let m = conf.iter().map(|(ident, field)| + if let Some(abi_type) = doc.typename(&Id::new(ident.to_string())) { + if let Some(ident) = field.rich_error.get_ident() { + if let Some(prior_def) = richtype_identifiers.insert(ident.clone(), field.err_loc.clone()) + { + return Err(anyhow!( + "duplicate rich type identifier of {:?} not allowed. prior definition at {:?}", + ident, prior_def + )); + } + Ok(UserErrorType { + abi_type, + rich_type: field.rich_error.clone(), + method_fragment: ident.to_string() + }) + } else { + return Err(anyhow!( + "rich error type must be identifier for now - TODO add ability to provide a corresponding identifier: {:?}", + field.err_loc + )) + } + } + else { Err(anyhow!("No witx typename \"{}\" found", ident.to_string())) } + ).collect::, Error>>()?; + Ok(Self { m }) + } + + pub fn iter(&self) -> impl Iterator { + self.m.iter() + } + + pub fn for_abi_error(&self, tref: &TypeRef) -> Option<&UserErrorType> { + match tref { + TypeRef::Name(nt) => self.m.iter().find(|u| u.abi_type.name == nt.name), + TypeRef::Value { .. } => None, + } + } +} + +pub struct UserErrorType { + abi_type: Rc, + rich_type: syn::Path, + method_fragment: String, +} + +impl UserErrorType { + pub fn abi_type(&self) -> TypeRef { + TypeRef::Name(self.abi_type.clone()) + } + pub fn typename(&self) -> TokenStream { + let t = &self.rich_type; + quote!(#t) + } + pub fn method_fragment(&self) -> &str { + &self.method_fragment + } +} diff --git a/crates/wiggle/generate/src/funcs.rs b/crates/wiggle/generate/src/funcs.rs index 690ced5512..2c8cf2b350 100644 --- a/crates/wiggle/generate/src/funcs.rs +++ b/crates/wiggle/generate/src/funcs.rs @@ -1,6 +1,7 @@ use proc_macro2::TokenStream; use quote::quote; +use crate::error_transform::ErrorTransform; use crate::lifetimes::anon_lifetime; use crate::module_trait::passed_by_reference; use crate::names::Names; @@ -9,6 +10,7 @@ pub fn define_func( names: &Names, func: &witx::InterfaceFunc, trait_name: TokenStream, + errxform: &ErrorTransform, ) -> TokenStream { let funcname = func.name.as_str(); @@ -43,16 +45,23 @@ pub fn define_func( quote!(()) }; - let err_type = coretype.ret.map(|ret| ret.param.tref); - let ret_err = err_type - .clone() - .map(|_res| { + let err_type = coretype.ret.clone().map(|ret| ret.param.tref); + let ret_err = coretype + .ret + .map(|ret| { + let name = ret.param.name.as_str(); + let conversion = if let Some(user_err) = errxform.for_abi_error(&ret.param.tref) { + let method = names.user_error_conversion_method(&user_err); + quote!(#abi_ret::from(UserErrorConversion::#method(ctx, e))) + } else { + quote!(#abi_ret::from(e)) + }; quote! { #[cfg(feature = "trace_log")] { - log::trace!(" | errno={}", e); + log::trace!(" | {}={:?}", #name, e); } - return #abi_ret::from(e); + return #conversion; } }) .unwrap_or_else(|| quote!(())); diff --git a/crates/wiggle/generate/src/lib.rs b/crates/wiggle/generate/src/lib.rs index 699d33e216..bd6d5d9e40 100644 --- a/crates/wiggle/generate/src/lib.rs +++ b/crates/wiggle/generate/src/lib.rs @@ -1,4 +1,5 @@ pub mod config; +mod error_transform; mod funcs; mod lifetimes; mod module_trait; @@ -11,12 +12,13 @@ use quote::quote; use lifetimes::anon_lifetime; pub use config::Config; +pub use error_transform::{ErrorTransform, UserErrorType}; pub use funcs::define_func; pub use module_trait::define_module_trait; pub use names::Names; pub use types::define_datatype; -pub fn generate(doc: &witx::Document, names: &Names) -> TokenStream { +pub fn generate(doc: &witx::Document, names: &Names, errs: &ErrorTransform) -> TokenStream { // TODO at some point config should grow more ability to configure name // overrides. let rt = names.runtime_mod(); @@ -34,13 +36,24 @@ pub fn generate(doc: &witx::Document, names: &Names) -> TokenStream { } }; + let user_error_methods = errs.iter().map(|errtype| { + let abi_typename = names.type_ref(&errtype.abi_type(), anon_lifetime()); + let user_typename = errtype.typename(); + let methodname = names.user_error_conversion_method(&errtype); + quote!(fn #methodname(&self, e: super::#user_typename) -> #abi_typename;) + }); + let user_error_conversion = quote! { + pub trait UserErrorConversion { + #(#user_error_methods)* + } + }; let modules = doc.modules().map(|module| { let modname = names.module(&module.name); let trait_name = names.trait_name(&module.name); let fs = module .funcs() - .map(|f| define_func(&names, &f, quote!(#trait_name))); - let modtrait = define_module_trait(&names, &module); + .map(|f| define_func(&names, &f, quote!(#trait_name), &errs)); + let modtrait = define_module_trait(&names, &module, &errs); let ctx_type = names.ctx_type(); quote!( pub mod #modname { @@ -57,6 +70,7 @@ pub fn generate(doc: &witx::Document, names: &Names) -> TokenStream { pub mod types { #(#types)* #guest_error_conversion + #user_error_conversion } #(#modules)* ) diff --git a/crates/wiggle/generate/src/module_trait.rs b/crates/wiggle/generate/src/module_trait.rs index b09f237028..c73cd052bb 100644 --- a/crates/wiggle/generate/src/module_trait.rs +++ b/crates/wiggle/generate/src/module_trait.rs @@ -1,6 +1,7 @@ use proc_macro2::TokenStream; use quote::quote; +use crate::error_transform::ErrorTransform; use crate::lifetimes::{anon_lifetime, LifetimeExt}; use crate::names::Names; use witx::Module; @@ -20,7 +21,7 @@ pub fn passed_by_reference(ty: &witx::Type) -> bool { } } -pub fn define_module_trait(names: &Names, m: &Module) -> TokenStream { +pub fn define_module_trait(names: &Names, m: &Module, errxform: &ErrorTransform) -> TokenStream { let traitname = names.trait_name(&m.name); let traitmethods = m.funcs().map(|f| { // Check if we're returning an entity anotated with a lifetime, @@ -55,7 +56,14 @@ pub fn define_module_trait(names: &Names, m: &Module) -> TokenStream { let err = f .results .get(0) - .map(|err_result| names.type_ref(&err_result.tref, lifetime.clone())) + .map(|err_result| { + if let Some(custom_err) = errxform.for_abi_error(&err_result.tref) { + let tn = custom_err.typename(); + quote!(super::#tn) + } else { + names.type_ref(&err_result.tref, lifetime.clone()) + } + }) .unwrap_or(quote!(())); if is_anonymous { diff --git a/crates/wiggle/generate/src/names.rs b/crates/wiggle/generate/src/names.rs index 644b877d8c..47ab1312a6 100644 --- a/crates/wiggle/generate/src/names.rs +++ b/crates/wiggle/generate/src/names.rs @@ -3,7 +3,7 @@ use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; use witx::{AtomType, BuiltinType, Id, Type, TypeRef}; -use crate::lifetimes::LifetimeExt; +use crate::{lifetimes::LifetimeExt, UserErrorType}; pub struct Names { ctx_type: Ident, @@ -150,27 +150,45 @@ impl Names { format_ident!("{}_len", id.as_str().to_snake_case()) } - pub fn guest_error_conversion_method(&self, tref: &TypeRef) -> Ident { + fn builtin_name(b: &BuiltinType) -> &'static str { + match b { + BuiltinType::String => "string", + BuiltinType::U8 => "u8", + BuiltinType::U16 => "u16", + BuiltinType::U32 => "u32", + BuiltinType::U64 => "u64", + BuiltinType::S8 => "i8", + BuiltinType::S16 => "i16", + BuiltinType::S32 => "i32", + BuiltinType::S64 => "i64", + BuiltinType::F32 => "f32", + BuiltinType::F64 => "f64", + BuiltinType::Char8 => "char8", + BuiltinType::USize => "usize", + } + } + + fn snake_typename(tref: &TypeRef) -> String { match tref { - TypeRef::Name(nt) => format_ident!("into_{}", nt.name.as_str().to_snake_case()), + TypeRef::Name(nt) => nt.name.as_str().to_snake_case(), TypeRef::Value(ty) => match &**ty { - Type::Builtin(b) => match b { - BuiltinType::String => unreachable!("error type must be atom"), - BuiltinType::U8 => format_ident!("into_u8"), - BuiltinType::U16 => format_ident!("into_u16"), - BuiltinType::U32 => format_ident!("into_u32"), - BuiltinType::U64 => format_ident!("into_u64"), - BuiltinType::S8 => format_ident!("into_i8"), - BuiltinType::S16 => format_ident!("into_i16"), - BuiltinType::S32 => format_ident!("into_i32"), - BuiltinType::S64 => format_ident!("into_i64"), - BuiltinType::F32 => format_ident!("into_f32"), - BuiltinType::F64 => format_ident!("into_f64"), - BuiltinType::Char8 => format_ident!("into_char8"), - BuiltinType::USize => format_ident!("into_usize"), - }, - _ => panic!("unexpected anonymous error type: {:?}", ty), + Type::Builtin(b) => Self::builtin_name(&b).to_owned(), + _ => panic!("unexpected anonymous type: {:?}", ty), }, } } + + pub fn guest_error_conversion_method(&self, tref: &TypeRef) -> Ident { + let suffix = Self::snake_typename(tref); + format_ident!("into_{}", suffix) + } + + pub fn user_error_conversion_method(&self, user_type: &UserErrorType) -> Ident { + let abi_type = Self::snake_typename(&user_type.abi_type()); + format_ident!( + "{}_from_{}", + abi_type, + user_type.method_fragment().to_snake_case() + ) + } } diff --git a/crates/wiggle/macro/src/lib.rs b/crates/wiggle/macro/src/lib.rs index 8fd6f680a7..214d14dd72 100644 --- a/crates/wiggle/macro/src/lib.rs +++ b/crates/wiggle/macro/src/lib.rs @@ -98,7 +98,10 @@ pub fn from_witx(args: TokenStream) -> TokenStream { let doc = witx::load(&config.witx.paths).expect("loading witx"); let names = wiggle_generate::Names::new(&config.ctx.name, quote!(wiggle)); - let code = wiggle_generate::generate(&doc, &names); + let error_transform = wiggle_generate::ErrorTransform::new(&config.errors, &doc) + .expect("validating error transform"); + + let code = wiggle_generate::generate(&doc, &names, &error_transform); let metadata = if cfg!(feature = "wiggle_metadata") { wiggle_generate::generate_metadata(&doc, &names) } else { diff --git a/crates/wiggle/tests/errors.rs b/crates/wiggle/tests/errors.rs new file mode 100644 index 0000000000..a63199ddb9 --- /dev/null +++ b/crates/wiggle/tests/errors.rs @@ -0,0 +1,35 @@ +use wiggle_test::{impl_errno, WasiCtx}; + +#[derive(Debug, thiserror::Error)] +pub enum RichError { + #[error("Invalid argument: {0}")] + InvalidArg(String), + #[error("Won't cross picket line: {0}")] + PicketLine(String), +} + +wiggle::from_witx!({ + witx: ["tests/arrays.witx"], + ctx: WasiCtx, + errors: { errno => RichError }, +}); + +impl_errno!(types::Errno, types::GuestErrorConversion); + +impl<'a> types::UserErrorConversion for WasiCtx<'a> { + fn errno_from_rich_error(&self, _e: RichError) -> types::Errno { + unimplemented!(); + } +} + +impl<'a> arrays::Arrays for WasiCtx<'a> { + fn reduce_excuses( + &self, + _excuses: &types::ConstExcuseArray, + ) -> Result { + unimplemented!() + } + fn populate_excuses(&self, _excuses: &types::ExcuseArray) -> Result<(), RichError> { + unimplemented!() + } +}