Wasmtime component bindgen: opt-in trappable error types (#5397)
* wip * start trying to write a runtime test * cut out all the more complex test cases until i get this one working * add macro parsing for the trappable error type config * runtime result tests works for an empty and a string error type * debugging: macro is broken because interfaces dont have names??? * thats how you name interfaces * record error and variant error work * show a concrete trap type, remove debug * delete clap annotations from wit-bindgen crate these are not used - clap isnt even an optional dep here - but were a holdover from the old home
This commit is contained in:
@@ -2,8 +2,7 @@ use proc_macro2::{Span, TokenStream};
|
||||
use std::path::{Path, PathBuf};
|
||||
use syn::parse::{Error, Parse, ParseStream, Result};
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::token;
|
||||
use syn::Token;
|
||||
use syn::{braced, token, Ident, Token};
|
||||
use wasmtime_wit_bindgen::Opts;
|
||||
use wit_parser::{Document, World};
|
||||
|
||||
@@ -68,6 +67,7 @@ impl Parse for Config {
|
||||
}
|
||||
Opt::Tracing(val) => ret.opts.tracing = val,
|
||||
Opt::Async(val) => ret.opts.async_ = val,
|
||||
Opt::TrappableErrorType(val) => ret.opts.trappable_error_type = val,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -99,6 +99,7 @@ mod kw {
|
||||
syn::custom_keyword!(path);
|
||||
syn::custom_keyword!(inline);
|
||||
syn::custom_keyword!(tracing);
|
||||
syn::custom_keyword!(trappable_error_type);
|
||||
}
|
||||
|
||||
enum Opt {
|
||||
@@ -106,6 +107,7 @@ enum Opt {
|
||||
Inline(Span, World),
|
||||
Tracing(bool),
|
||||
Async(bool),
|
||||
TrappableErrorType(Vec<(String, String, String)>),
|
||||
}
|
||||
|
||||
impl Parse for Opt {
|
||||
@@ -132,8 +134,25 @@ impl Parse for Opt {
|
||||
input.parse::<Token![async]>()?;
|
||||
input.parse::<Token![:]>()?;
|
||||
Ok(Opt::Async(input.parse::<syn::LitBool>()?.value))
|
||||
} else if l.peek(kw::trappable_error_type) {
|
||||
input.parse::<kw::trappable_error_type>()?;
|
||||
input.parse::<Token![:]>()?;
|
||||
let contents;
|
||||
let _lbrace = braced!(contents in input);
|
||||
let fields: Punctuated<(String, String, String), Token![,]> =
|
||||
contents.parse_terminated(trappable_error_field_parse)?;
|
||||
Ok(Opt::TrappableErrorType(fields.into_iter().collect()))
|
||||
} else {
|
||||
Err(l.error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn trappable_error_field_parse(input: ParseStream<'_>) -> Result<(String, String, String)> {
|
||||
let interface = input.parse::<Ident>()?.to_string();
|
||||
input.parse::<Token![::]>()?;
|
||||
let type_ = input.parse::<Ident>()?.to_string();
|
||||
input.parse::<Token![:]>()?;
|
||||
let rust_type = input.parse::<Ident>()?.to_string();
|
||||
Ok((interface, type_, rust_type))
|
||||
}
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
/// Type alias for the standard library [`Result`](std::result::Result) type
|
||||
/// to specifie [`Error`] as the error payload.
|
||||
pub type Result<A, E> = std::result::Result<A, Error<E>>;
|
||||
|
||||
/// Error type used by the [`bindgen!`](crate::component::bindgen) macro.
|
||||
///
|
||||
/// This error type represents either the typed error `T` specified here or a
|
||||
/// trap, represented with [`anyhow::Error`].
|
||||
pub struct Error<T> {
|
||||
err: anyhow::Error,
|
||||
ty: std::marker::PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> Error<T> {
|
||||
/// Creates a new typed version of this error from the `T` specified.
|
||||
///
|
||||
/// This error, if it makes its way to the guest, will be returned to the
|
||||
/// guest and the guest will be able to act upon it.
|
||||
///
|
||||
/// Alternatively errors can be created with [`Error::trap`] which will
|
||||
/// cause the guest to trap and be unable to act upon it.
|
||||
pub fn new(err: T) -> Error<T>
|
||||
where
|
||||
T: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
Error {
|
||||
err: err.into(),
|
||||
ty: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a custom "trap" which will abort guest execution and have the
|
||||
/// specified `err` as the payload context returned from the original
|
||||
/// invocation.
|
||||
///
|
||||
/// Note that if `err` here actually has type `T` then the error will not be
|
||||
/// considered a trap and will instead be dynamically detected as a normal
|
||||
/// error to communicate to the original module.
|
||||
pub fn trap(err: impl std::error::Error + Send + Sync + 'static) -> Error<T> {
|
||||
Error {
|
||||
err: anyhow::Error::from(err),
|
||||
ty: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempts to dynamically downcast this error internally to the `T`
|
||||
/// representation.
|
||||
///
|
||||
/// If this error is internally represented as a `T` then `Ok(val)` will be
|
||||
/// returned. If this error is instead represented as a trap then
|
||||
/// `Err(trap)` will be returned instead.
|
||||
pub fn downcast(self) -> anyhow::Result<T>
|
||||
where
|
||||
T: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
self.err.downcast::<T>()
|
||||
}
|
||||
|
||||
/// Attempts to dynamically downcast this error to peek at the inner
|
||||
/// contents of `T` if present.
|
||||
pub fn downcast_ref(&self) -> Option<&T>
|
||||
where
|
||||
T: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
self.err.downcast_ref::<T>()
|
||||
}
|
||||
|
||||
/// Attempts to dynamically downcast this error to peek at the inner
|
||||
/// contents of `T` if present.
|
||||
pub fn downcast_mut(&mut self) -> Option<&mut T>
|
||||
where
|
||||
T: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
self.err.downcast_mut::<T>()
|
||||
}
|
||||
|
||||
/// Converts this error into an `anyhow::Error` which loses the `T` type
|
||||
/// information tagged to this error.
|
||||
pub fn into_inner(self) -> anyhow::Error {
|
||||
self.err
|
||||
}
|
||||
|
||||
/// Same as [`anyhow::Error::context`], attaches a contextual message to
|
||||
/// this error.
|
||||
pub fn context<C>(self, context: C) -> Error<T>
|
||||
where
|
||||
C: std::fmt::Display + Send + Sync + 'static,
|
||||
{
|
||||
self.err.context(context).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::ops::Deref for Error<T> {
|
||||
type Target = dyn std::error::Error + Send + Sync + 'static;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.err.deref()
|
||||
}
|
||||
}
|
||||
impl<T> std::ops::DerefMut for Error<T> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.err.deref_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::fmt::Display for Error<T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.err.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::fmt::Debug for Error<T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.err.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::error::Error for Error<T> {}
|
||||
|
||||
impl<T> From<anyhow::Error> for Error<T> {
|
||||
fn from(err: anyhow::Error) -> Error<T> {
|
||||
Error {
|
||||
err,
|
||||
ty: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,6 @@
|
||||
//! probably buggy implementation of the component model.
|
||||
|
||||
mod component;
|
||||
mod error;
|
||||
mod func;
|
||||
mod instance;
|
||||
mod linker;
|
||||
@@ -14,7 +13,6 @@ mod store;
|
||||
pub mod types;
|
||||
mod values;
|
||||
pub use self::component::Component;
|
||||
pub use self::error::{Error, Result};
|
||||
pub use self::func::{
|
||||
ComponentNamedList, ComponentType, Func, Lift, Lower, TypedFunc, WasmList, WasmStr,
|
||||
};
|
||||
|
||||
@@ -9,5 +9,6 @@ documentation = "https://docs.rs/wasmtime-wit-bindgen/"
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
heck = { workspace = true }
|
||||
wit-parser = { workspace = true }
|
||||
|
||||
@@ -40,19 +40,19 @@ struct Exports {
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
#[cfg_attr(feature = "clap", derive(clap::Args))]
|
||||
pub struct Opts {
|
||||
/// Whether or not `rustfmt` is executed to format generated code.
|
||||
#[cfg_attr(feature = "clap", arg(long))]
|
||||
pub rustfmt: bool,
|
||||
|
||||
/// Whether or not to emit `tracing` macro calls on function entry/exit.
|
||||
#[cfg_attr(feature = "clap", arg(long))]
|
||||
pub tracing: bool,
|
||||
|
||||
/// Whether or not to use async rust functions and traits.
|
||||
#[cfg_attr(feature = "clap", arg(long = "async"))]
|
||||
pub async_: bool,
|
||||
|
||||
/// For a given wit interface and type name, generate a "trappable error type"
|
||||
/// of the following Rust type name
|
||||
pub trappable_error_type: Vec<(String, String, String)>,
|
||||
}
|
||||
|
||||
impl Opts {
|
||||
@@ -80,7 +80,7 @@ impl Wasmtime {
|
||||
fn import(&mut self, name: &str, iface: &Interface) {
|
||||
let mut gen = InterfaceGenerator::new(self, iface, TypeMode::Owned);
|
||||
gen.types();
|
||||
gen.generate_from_error_impls();
|
||||
gen.generate_trappable_error_types();
|
||||
gen.generate_add_to_linker(name);
|
||||
|
||||
let snake = name.to_snake_case();
|
||||
@@ -105,7 +105,7 @@ impl Wasmtime {
|
||||
fn export(&mut self, name: &str, iface: &Interface) {
|
||||
let mut gen = InterfaceGenerator::new(self, iface, TypeMode::AllBorrowed("'a"));
|
||||
gen.types();
|
||||
gen.generate_from_error_impls();
|
||||
gen.generate_trappable_error_types();
|
||||
|
||||
let camel = name.to_upper_camel_case();
|
||||
uwriteln!(gen.src, "pub struct {camel} {{");
|
||||
@@ -183,7 +183,7 @@ impl Wasmtime {
|
||||
fn export_default(&mut self, _name: &str, iface: &Interface) {
|
||||
let mut gen = InterfaceGenerator::new(self, iface, TypeMode::AllBorrowed("'a"));
|
||||
gen.types();
|
||||
gen.generate_from_error_impls();
|
||||
gen.generate_trappable_error_types();
|
||||
let fields = gen.extract_typed_functions();
|
||||
for (name, getter) in fields {
|
||||
let prev = gen
|
||||
@@ -302,6 +302,22 @@ impl Wasmtime {
|
||||
}
|
||||
}
|
||||
|
||||
impl Wasmtime {
|
||||
fn trappable_error_types<'a>(
|
||||
&'a self,
|
||||
iface: &'a Interface,
|
||||
) -> impl Iterator<Item = (&String, &TypeId, &String)> + 'a {
|
||||
self.opts
|
||||
.trappable_error_type
|
||||
.iter()
|
||||
.filter(|(interface_name, _, _)| iface.name == *interface_name)
|
||||
.filter_map(|(_, wit_typename, rust_typename)| {
|
||||
let wit_type = iface.type_lookup.get(wit_typename)?;
|
||||
Some((wit_typename, wit_type, rust_typename))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct InterfaceGenerator<'a> {
|
||||
src: Source,
|
||||
gen: &'a mut Wasmtime,
|
||||
@@ -774,16 +790,20 @@ impl<'a> InterfaceGenerator<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn special_case_host_error(&self, results: &Results) -> Option<&Result_> {
|
||||
// We only support the Error case when
|
||||
// a function has just one result, which is itself a `result<a, e>`, and the
|
||||
// `e` is *not* a primitive (i.e. defined in std) type.
|
||||
fn special_case_trappable_error(&self, results: &Results) -> Option<(Result_, String)> {
|
||||
// We fillin a special trappable error type in the case when a function has just one
|
||||
// result, which is itself a `result<a, e>`, and the `e` is *not* a primitive
|
||||
// (i.e. defined in std) type, and matches the typename given by the user.
|
||||
let mut i = results.iter_types();
|
||||
if i.len() == 1 {
|
||||
match i.next().unwrap() {
|
||||
Type::Id(id) => match &self.iface.types[*id].kind {
|
||||
TypeDefKind::Result(r) => match r.err {
|
||||
Some(Type::Id(_)) => Some(&r),
|
||||
Some(Type::Id(error_typeid)) => self
|
||||
.gen
|
||||
.trappable_error_types(&self.iface)
|
||||
.find(|(_, wit_error_typeid, _)| error_typeid == **wit_error_typeid)
|
||||
.map(|(_, _, rust_errortype)| (r.clone(), rust_errortype.clone())),
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
@@ -823,22 +843,18 @@ impl<'a> InterfaceGenerator<'a> {
|
||||
self.push_str(")");
|
||||
self.push_str(" -> ");
|
||||
|
||||
if let Some(r) = self.special_case_host_error(&func.results).cloned() {
|
||||
if let Some((r, error_typename)) = self.special_case_trappable_error(&func.results) {
|
||||
// Functions which have a single result `result<ok,err>` get special
|
||||
// cased to use the host_wasmtime_rust::Error<err>, making it possible
|
||||
// for them to trap or use `?` to propogate their errors
|
||||
self.push_str("wasmtime::component::Result<");
|
||||
self.push_str("Result<");
|
||||
if let Some(ok) = r.ok {
|
||||
self.print_ty(&ok, TypeMode::Owned);
|
||||
} else {
|
||||
self.push_str("()");
|
||||
}
|
||||
self.push_str(",");
|
||||
if let Some(err) = r.err {
|
||||
self.print_ty(&err, TypeMode::Owned);
|
||||
} else {
|
||||
self.push_str("()");
|
||||
}
|
||||
self.push_str(&error_typename);
|
||||
self.push_str(">");
|
||||
} else {
|
||||
// All other functions get their return values wrapped in an anyhow::Result.
|
||||
@@ -936,7 +952,7 @@ impl<'a> InterfaceGenerator<'a> {
|
||||
uwrite!(self.src, ");\n");
|
||||
}
|
||||
|
||||
if self.special_case_host_error(&func.results).is_some() {
|
||||
if self.special_case_trappable_error(&func.results).is_some() {
|
||||
uwrite!(
|
||||
self.src,
|
||||
"match r {{
|
||||
@@ -1081,33 +1097,56 @@ impl<'a> InterfaceGenerator<'a> {
|
||||
self.src.push_str("}\n");
|
||||
}
|
||||
|
||||
fn generate_from_error_impls(&mut self) {
|
||||
for (id, ty) in self.iface.types.iter() {
|
||||
if ty.name.is_none() {
|
||||
continue;
|
||||
}
|
||||
let info = self.info(id);
|
||||
if info.error {
|
||||
for (name, mode) in self.modes_of(id) {
|
||||
let name = name.to_upper_camel_case();
|
||||
if self.lifetime_for(&info, mode).is_some() {
|
||||
continue;
|
||||
}
|
||||
self.push_str("impl From<");
|
||||
self.push_str(&name);
|
||||
self.push_str("> for wasmtime::component::Error<");
|
||||
self.push_str(&name);
|
||||
self.push_str("> {\n");
|
||||
self.push_str("fn from(e: ");
|
||||
self.push_str(&name);
|
||||
self.push_str(") -> wasmtime::component::Error::< ");
|
||||
self.push_str(&name);
|
||||
self.push_str("> {\n");
|
||||
self.push_str("wasmtime::component::Error::new(e)\n");
|
||||
self.push_str("}\n");
|
||||
self.push_str("}\n");
|
||||
}
|
||||
fn generate_trappable_error_types(&mut self) {
|
||||
for (wit_typename, wit_type, trappable_type) in self.gen.trappable_error_types(&self.iface)
|
||||
{
|
||||
let info = self.info(*wit_type);
|
||||
if self.lifetime_for(&info, TypeMode::Owned).is_some() {
|
||||
panic!(
|
||||
"type {:?} in interface {:?} is not 'static",
|
||||
wit_typename, self.iface.name
|
||||
)
|
||||
}
|
||||
let abi_type = self.param_name(*wit_type);
|
||||
|
||||
uwriteln!(
|
||||
self.src,
|
||||
"
|
||||
#[derive(Debug)]
|
||||
pub struct {trappable_type} {{
|
||||
inner: anyhow::Error,
|
||||
}}
|
||||
impl std::fmt::Display for {trappable_type} {{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
|
||||
write!(f, \"{{}}\", self.inner)
|
||||
}}
|
||||
}}
|
||||
impl std::error::Error for {trappable_type} {{
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {{
|
||||
self.inner.source()
|
||||
}}
|
||||
}}
|
||||
impl {trappable_type} {{
|
||||
pub fn trap(inner: anyhow::Error) -> Self {{
|
||||
Self {{ inner }}
|
||||
}}
|
||||
pub fn downcast(self) -> Result<{abi_type}, anyhow::Error> {{
|
||||
self.inner.downcast()
|
||||
}}
|
||||
pub fn downcast_ref(&self) -> Option<&{abi_type}> {{
|
||||
self.inner.downcast_ref()
|
||||
}}
|
||||
pub fn context(self, s: impl Into<String>) -> Self {{
|
||||
Self {{ inner: self.inner.context(s.into()) }}
|
||||
}}
|
||||
}}
|
||||
impl From<{abi_type}> for {trappable_type} {{
|
||||
fn from(abi: {abi_type}) -> {trappable_type} {{
|
||||
{trappable_type} {{ inner: anyhow::Error::from(abi) }}
|
||||
}}
|
||||
}}
|
||||
"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user