Wasmtime component bindgen: opt-in trappable error types (#5397)

* wip

* start trying to write a runtime test

* cut out all the more complex test cases until i get this one working

* add macro parsing for the trappable error type config

* runtime result tests works for an empty and a string error type

* debugging: macro is broken because interfaces dont have names???

* thats how you name interfaces

* record error and variant error work

* show a concrete trap type, remove debug

* delete clap annotations from wit-bindgen crate

these are not used - clap isnt even an optional dep here - but were a holdover from the old home
This commit is contained in:
Pat Hickey
2022-12-14 10:44:05 -08:00
committed by GitHub
parent f0af622208
commit 2e0bc7dab6
8 changed files with 714 additions and 176 deletions

View File

@@ -2,8 +2,7 @@ use proc_macro2::{Span, TokenStream};
use std::path::{Path, PathBuf};
use syn::parse::{Error, Parse, ParseStream, Result};
use syn::punctuated::Punctuated;
use syn::token;
use syn::Token;
use syn::{braced, token, Ident, Token};
use wasmtime_wit_bindgen::Opts;
use wit_parser::{Document, World};
@@ -68,6 +67,7 @@ impl Parse for Config {
}
Opt::Tracing(val) => ret.opts.tracing = val,
Opt::Async(val) => ret.opts.async_ = val,
Opt::TrappableErrorType(val) => ret.opts.trappable_error_type = val,
}
}
} else {
@@ -99,6 +99,7 @@ mod kw {
syn::custom_keyword!(path);
syn::custom_keyword!(inline);
syn::custom_keyword!(tracing);
syn::custom_keyword!(trappable_error_type);
}
enum Opt {
@@ -106,6 +107,7 @@ enum Opt {
Inline(Span, World),
Tracing(bool),
Async(bool),
TrappableErrorType(Vec<(String, String, String)>),
}
impl Parse for Opt {
@@ -132,8 +134,25 @@ impl Parse for Opt {
input.parse::<Token![async]>()?;
input.parse::<Token![:]>()?;
Ok(Opt::Async(input.parse::<syn::LitBool>()?.value))
} else if l.peek(kw::trappable_error_type) {
input.parse::<kw::trappable_error_type>()?;
input.parse::<Token![:]>()?;
let contents;
let _lbrace = braced!(contents in input);
let fields: Punctuated<(String, String, String), Token![,]> =
contents.parse_terminated(trappable_error_field_parse)?;
Ok(Opt::TrappableErrorType(fields.into_iter().collect()))
} else {
Err(l.error())
}
}
}
fn trappable_error_field_parse(input: ParseStream<'_>) -> Result<(String, String, String)> {
let interface = input.parse::<Ident>()?.to_string();
input.parse::<Token![::]>()?;
let type_ = input.parse::<Ident>()?.to_string();
input.parse::<Token![:]>()?;
let rust_type = input.parse::<Ident>()?.to_string();
Ok((interface, type_, rust_type))
}

View File

@@ -1,126 +0,0 @@
/// Type alias for the standard library [`Result`](std::result::Result) type
/// to specifie [`Error`] as the error payload.
pub type Result<A, E> = std::result::Result<A, Error<E>>;
/// Error type used by the [`bindgen!`](crate::component::bindgen) macro.
///
/// This error type represents either the typed error `T` specified here or a
/// trap, represented with [`anyhow::Error`].
pub struct Error<T> {
err: anyhow::Error,
ty: std::marker::PhantomData<T>,
}
impl<T> Error<T> {
/// Creates a new typed version of this error from the `T` specified.
///
/// This error, if it makes its way to the guest, will be returned to the
/// guest and the guest will be able to act upon it.
///
/// Alternatively errors can be created with [`Error::trap`] which will
/// cause the guest to trap and be unable to act upon it.
pub fn new(err: T) -> Error<T>
where
T: std::error::Error + Send + Sync + 'static,
{
Error {
err: err.into(),
ty: std::marker::PhantomData,
}
}
/// Creates a custom "trap" which will abort guest execution and have the
/// specified `err` as the payload context returned from the original
/// invocation.
///
/// Note that if `err` here actually has type `T` then the error will not be
/// considered a trap and will instead be dynamically detected as a normal
/// error to communicate to the original module.
pub fn trap(err: impl std::error::Error + Send + Sync + 'static) -> Error<T> {
Error {
err: anyhow::Error::from(err),
ty: std::marker::PhantomData,
}
}
/// Attempts to dynamically downcast this error internally to the `T`
/// representation.
///
/// If this error is internally represented as a `T` then `Ok(val)` will be
/// returned. If this error is instead represented as a trap then
/// `Err(trap)` will be returned instead.
pub fn downcast(self) -> anyhow::Result<T>
where
T: std::error::Error + Send + Sync + 'static,
{
self.err.downcast::<T>()
}
/// Attempts to dynamically downcast this error to peek at the inner
/// contents of `T` if present.
pub fn downcast_ref(&self) -> Option<&T>
where
T: std::error::Error + Send + Sync + 'static,
{
self.err.downcast_ref::<T>()
}
/// Attempts to dynamically downcast this error to peek at the inner
/// contents of `T` if present.
pub fn downcast_mut(&mut self) -> Option<&mut T>
where
T: std::error::Error + Send + Sync + 'static,
{
self.err.downcast_mut::<T>()
}
/// Converts this error into an `anyhow::Error` which loses the `T` type
/// information tagged to this error.
pub fn into_inner(self) -> anyhow::Error {
self.err
}
/// Same as [`anyhow::Error::context`], attaches a contextual message to
/// this error.
pub fn context<C>(self, context: C) -> Error<T>
where
C: std::fmt::Display + Send + Sync + 'static,
{
self.err.context(context).into()
}
}
impl<T> std::ops::Deref for Error<T> {
type Target = dyn std::error::Error + Send + Sync + 'static;
fn deref(&self) -> &Self::Target {
self.err.deref()
}
}
impl<T> std::ops::DerefMut for Error<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.err.deref_mut()
}
}
impl<T> std::fmt::Display for Error<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.err.fmt(f)
}
}
impl<T> std::fmt::Debug for Error<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.err.fmt(f)
}
}
impl<T> std::error::Error for Error<T> {}
impl<T> From<anyhow::Error> for Error<T> {
fn from(err: anyhow::Error) -> Error<T> {
Error {
err,
ty: std::marker::PhantomData,
}
}
}

View File

@@ -4,7 +4,6 @@
//! probably buggy implementation of the component model.
mod component;
mod error;
mod func;
mod instance;
mod linker;
@@ -14,7 +13,6 @@ mod store;
pub mod types;
mod values;
pub use self::component::Component;
pub use self::error::{Error, Result};
pub use self::func::{
ComponentNamedList, ComponentType, Func, Lift, Lower, TypedFunc, WasmList, WasmStr,
};

View File

@@ -9,5 +9,6 @@ documentation = "https://docs.rs/wasmtime-wit-bindgen/"
edition.workspace = true
[dependencies]
anyhow = { workspace = true }
heck = { workspace = true }
wit-parser = { workspace = true }

View File

@@ -40,19 +40,19 @@ struct Exports {
}
#[derive(Default, Debug, Clone)]
#[cfg_attr(feature = "clap", derive(clap::Args))]
pub struct Opts {
/// Whether or not `rustfmt` is executed to format generated code.
#[cfg_attr(feature = "clap", arg(long))]
pub rustfmt: bool,
/// Whether or not to emit `tracing` macro calls on function entry/exit.
#[cfg_attr(feature = "clap", arg(long))]
pub tracing: bool,
/// Whether or not to use async rust functions and traits.
#[cfg_attr(feature = "clap", arg(long = "async"))]
pub async_: bool,
/// For a given wit interface and type name, generate a "trappable error type"
/// of the following Rust type name
pub trappable_error_type: Vec<(String, String, String)>,
}
impl Opts {
@@ -80,7 +80,7 @@ impl Wasmtime {
fn import(&mut self, name: &str, iface: &Interface) {
let mut gen = InterfaceGenerator::new(self, iface, TypeMode::Owned);
gen.types();
gen.generate_from_error_impls();
gen.generate_trappable_error_types();
gen.generate_add_to_linker(name);
let snake = name.to_snake_case();
@@ -105,7 +105,7 @@ impl Wasmtime {
fn export(&mut self, name: &str, iface: &Interface) {
let mut gen = InterfaceGenerator::new(self, iface, TypeMode::AllBorrowed("'a"));
gen.types();
gen.generate_from_error_impls();
gen.generate_trappable_error_types();
let camel = name.to_upper_camel_case();
uwriteln!(gen.src, "pub struct {camel} {{");
@@ -183,7 +183,7 @@ impl Wasmtime {
fn export_default(&mut self, _name: &str, iface: &Interface) {
let mut gen = InterfaceGenerator::new(self, iface, TypeMode::AllBorrowed("'a"));
gen.types();
gen.generate_from_error_impls();
gen.generate_trappable_error_types();
let fields = gen.extract_typed_functions();
for (name, getter) in fields {
let prev = gen
@@ -302,6 +302,22 @@ impl Wasmtime {
}
}
impl Wasmtime {
fn trappable_error_types<'a>(
&'a self,
iface: &'a Interface,
) -> impl Iterator<Item = (&String, &TypeId, &String)> + 'a {
self.opts
.trappable_error_type
.iter()
.filter(|(interface_name, _, _)| iface.name == *interface_name)
.filter_map(|(_, wit_typename, rust_typename)| {
let wit_type = iface.type_lookup.get(wit_typename)?;
Some((wit_typename, wit_type, rust_typename))
})
}
}
struct InterfaceGenerator<'a> {
src: Source,
gen: &'a mut Wasmtime,
@@ -774,16 +790,20 @@ impl<'a> InterfaceGenerator<'a> {
}
}
fn special_case_host_error(&self, results: &Results) -> Option<&Result_> {
// We only support the Error case when
// a function has just one result, which is itself a `result<a, e>`, and the
// `e` is *not* a primitive (i.e. defined in std) type.
fn special_case_trappable_error(&self, results: &Results) -> Option<(Result_, String)> {
// We fillin a special trappable error type in the case when a function has just one
// result, which is itself a `result<a, e>`, and the `e` is *not* a primitive
// (i.e. defined in std) type, and matches the typename given by the user.
let mut i = results.iter_types();
if i.len() == 1 {
match i.next().unwrap() {
Type::Id(id) => match &self.iface.types[*id].kind {
TypeDefKind::Result(r) => match r.err {
Some(Type::Id(_)) => Some(&r),
Some(Type::Id(error_typeid)) => self
.gen
.trappable_error_types(&self.iface)
.find(|(_, wit_error_typeid, _)| error_typeid == **wit_error_typeid)
.map(|(_, _, rust_errortype)| (r.clone(), rust_errortype.clone())),
_ => None,
},
_ => None,
@@ -823,22 +843,18 @@ impl<'a> InterfaceGenerator<'a> {
self.push_str(")");
self.push_str(" -> ");
if let Some(r) = self.special_case_host_error(&func.results).cloned() {
if let Some((r, error_typename)) = self.special_case_trappable_error(&func.results) {
// Functions which have a single result `result<ok,err>` get special
// cased to use the host_wasmtime_rust::Error<err>, making it possible
// for them to trap or use `?` to propogate their errors
self.push_str("wasmtime::component::Result<");
self.push_str("Result<");
if let Some(ok) = r.ok {
self.print_ty(&ok, TypeMode::Owned);
} else {
self.push_str("()");
}
self.push_str(",");
if let Some(err) = r.err {
self.print_ty(&err, TypeMode::Owned);
} else {
self.push_str("()");
}
self.push_str(&error_typename);
self.push_str(">");
} else {
// All other functions get their return values wrapped in an anyhow::Result.
@@ -936,7 +952,7 @@ impl<'a> InterfaceGenerator<'a> {
uwrite!(self.src, ");\n");
}
if self.special_case_host_error(&func.results).is_some() {
if self.special_case_trappable_error(&func.results).is_some() {
uwrite!(
self.src,
"match r {{
@@ -1081,33 +1097,56 @@ impl<'a> InterfaceGenerator<'a> {
self.src.push_str("}\n");
}
fn generate_from_error_impls(&mut self) {
for (id, ty) in self.iface.types.iter() {
if ty.name.is_none() {
continue;
}
let info = self.info(id);
if info.error {
for (name, mode) in self.modes_of(id) {
let name = name.to_upper_camel_case();
if self.lifetime_for(&info, mode).is_some() {
continue;
}
self.push_str("impl From<");
self.push_str(&name);
self.push_str("> for wasmtime::component::Error<");
self.push_str(&name);
self.push_str("> {\n");
self.push_str("fn from(e: ");
self.push_str(&name);
self.push_str(") -> wasmtime::component::Error::< ");
self.push_str(&name);
self.push_str("> {\n");
self.push_str("wasmtime::component::Error::new(e)\n");
self.push_str("}\n");
self.push_str("}\n");
}
fn generate_trappable_error_types(&mut self) {
for (wit_typename, wit_type, trappable_type) in self.gen.trappable_error_types(&self.iface)
{
let info = self.info(*wit_type);
if self.lifetime_for(&info, TypeMode::Owned).is_some() {
panic!(
"type {:?} in interface {:?} is not 'static",
wit_typename, self.iface.name
)
}
let abi_type = self.param_name(*wit_type);
uwriteln!(
self.src,
"
#[derive(Debug)]
pub struct {trappable_type} {{
inner: anyhow::Error,
}}
impl std::fmt::Display for {trappable_type} {{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
write!(f, \"{{}}\", self.inner)
}}
}}
impl std::error::Error for {trappable_type} {{
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {{
self.inner.source()
}}
}}
impl {trappable_type} {{
pub fn trap(inner: anyhow::Error) -> Self {{
Self {{ inner }}
}}
pub fn downcast(self) -> Result<{abi_type}, anyhow::Error> {{
self.inner.downcast()
}}
pub fn downcast_ref(&self) -> Option<&{abi_type}> {{
self.inner.downcast_ref()
}}
pub fn context(self, s: impl Into<String>) -> Self {{
Self {{ inner: self.inner.context(s.into()) }}
}}
}}
impl From<{abi_type}> for {trappable_type} {{
fn from(abi: {abi_type}) -> {trappable_type} {{
{trappable_type} {{ inner: anyhow::Error::from(abi) }}
}}
}}
"
);
}
}