From df9c725fa0e7a4535cc884b3f1973fcb16d82e18 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Thu, 11 Feb 2021 09:28:36 -0800 Subject: [PATCH] Update to the next version of the witx crate This commit updates to the 0.9 version of the witx crate implemented in WebAssembly/wasi#395. This new version drastically changes code generation and how we interface with the crate. The intention is to abstract the code generation aspects and allow code generators to implement much more low-level instructions to enable more flexible APIs in the future. Additionally a bunch of `*.witx` files were updated in the WASI repository. It's worth pointing out, however, that `wasi-common` does not change as a result of this change. The shape of the APIs that we need to implement are effectively the same and the only difference is that the shim functions generated by wiggle are a bit different. --- .gitmodules | 3 +- Cargo.lock | 13 +- crates/wasi-common/WASI | 2 +- crates/wiggle/Cargo.toml | 2 +- crates/wiggle/generate/Cargo.toml | 2 +- crates/wiggle/generate/src/error_transform.rs | 6 +- crates/wiggle/generate/src/funcs.rs | 587 +++++++++--------- crates/wiggle/generate/src/lib.rs | 5 +- crates/wiggle/generate/src/lifetimes.rs | 35 +- crates/wiggle/generate/src/module_trait.rs | 64 +- crates/wiggle/generate/src/names.rs | 82 ++- crates/wiggle/generate/src/types/enum.rs | 120 ---- crates/wiggle/generate/src/types/flags.rs | 46 +- crates/wiggle/generate/src/types/handle.rs | 2 - crates/wiggle/generate/src/types/int.rs | 100 --- crates/wiggle/generate/src/types/mod.rs | 56 +- .../src/types/{struct.rs => record.rs} | 10 +- .../src/types/{union.rs => variant.rs} | 77 ++- crates/wiggle/macro/Cargo.toml | 2 +- crates/wiggle/src/lib.rs | 6 + crates/wiggle/wasmtime/Cargo.toml | 2 +- crates/wiggle/wasmtime/macro/Cargo.toml | 2 +- crates/wiggle/wasmtime/macro/src/config.rs | 53 -- crates/wiggle/wasmtime/macro/src/lib.rs | 39 +- 24 files changed, 510 insertions(+), 806 deletions(-) delete mode 100644 crates/wiggle/generate/src/types/enum.rs delete mode 100644 crates/wiggle/generate/src/types/int.rs rename crates/wiggle/generate/src/types/{struct.rs => record.rs} (94%) rename crates/wiggle/generate/src/types/{union.rs => variant.rs} (50%) diff --git a/.gitmodules b/.gitmodules index ee264b99c4..305858152d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -6,7 +6,8 @@ url = https://github.com/WebAssembly/wasm-c-api [submodule "WASI"] path = crates/wasi-common/WASI - url = https://github.com/WebAssembly/WASI + url = https://github.com/alexcrichton/WASI + branch = abis [submodule "crates/wasi-nn/spec"] path = crates/wasi-nn/spec url = https://github.com/WebAssembly/wasi-nn diff --git a/Cargo.lock b/Cargo.lock index 5948052629..544b8bd861 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3499,18 +3499,18 @@ dependencies = [ [[package]] name = "wast" -version = "22.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe1220ed7f824992b426a76125a3403d048eaf0f627918e97ade0d9b9d510d20" +checksum = "c24a3ee360d01d60ed0a0f960ab76a6acce64348cdb0bf8699c2a866fad57c7c" dependencies = [ "leb128", ] [[package]] name = "wast" -version = "32.0.0" +version = "33.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c24a3ee360d01d60ed0a0f960ab76a6acce64348cdb0bf8699c2a866fad57c7c" +checksum = "1d04fe175c7f78214971293e7d8875673804e736092206a3a4544dbc12811c1b" dependencies = [ "leb128", ] @@ -3634,15 +3634,16 @@ dependencies = [ [[package]] name = "witx" -version = "0.8.8" +version = "0.9.0" dependencies = [ "anyhow", "diff", "log", "pretty_env_logger", + "rayon", "structopt", "thiserror", - "wast 22.0.0", + "wast 33.0.0", ] [[package]] diff --git a/crates/wasi-common/WASI b/crates/wasi-common/WASI index 8deb71ddd0..7c4fd252d0 160000 --- a/crates/wasi-common/WASI +++ b/crates/wasi-common/WASI @@ -1 +1 @@ -Subproject commit 8deb71ddd0955101cb69333b08284e7b01775928 +Subproject commit 7c4fd252d0841488de4a6e724e600f1561797387 diff --git a/crates/wiggle/Cargo.toml b/crates/wiggle/Cargo.toml index d7bba56082..88c8315bdc 100644 --- a/crates/wiggle/Cargo.toml +++ b/crates/wiggle/Cargo.toml @@ -12,7 +12,7 @@ include = ["src/**/*", "LICENSE"] [dependencies] thiserror = "1" -witx = { path = "../wasi-common/WASI/tools/witx", version = "0.8.7", optional = true } +witx = { path = "../wasi-common/WASI/tools/witx", version = "0.9", optional = true } wiggle-macro = { path = "macro", version = "0.23.0" } tracing = "0.1.15" bitflags = "1.2" diff --git a/crates/wiggle/generate/Cargo.toml b/crates/wiggle/generate/Cargo.toml index 54cb70d20d..1a7459a519 100644 --- a/crates/wiggle/generate/Cargo.toml +++ b/crates/wiggle/generate/Cargo.toml @@ -14,7 +14,7 @@ include = ["src/**/*", "LICENSE"] [lib] [dependencies] -witx = { version = "0.8.7", path = "../../wasi-common/WASI/tools/witx" } +witx = { version = "0.9", path = "../../wasi-common/WASI/tools/witx" } quote = "1.0" proc-macro2 = "1.0" heck = "0.3" diff --git a/crates/wiggle/generate/src/error_transform.rs b/crates/wiggle/generate/src/error_transform.rs index 3598e221de..56fa10eaf0 100644 --- a/crates/wiggle/generate/src/error_transform.rs +++ b/crates/wiggle/generate/src/error_transform.rs @@ -49,10 +49,14 @@ impl ErrorTransform { pub fn for_abi_error(&self, tref: &TypeRef) -> Option<&UserErrorType> { match tref { - TypeRef::Name(nt) => self.m.iter().find(|u| u.abi_type.name == nt.name), + TypeRef::Name(nt) => self.for_name(nt), TypeRef::Value { .. } => None, } } + + pub fn for_name(&self, nt: &NamedType) -> Option<&UserErrorType> { + self.m.iter().find(|u| u.abi_type.name == nt.name) + } } pub struct UserErrorType { diff --git a/crates/wiggle/generate/src/funcs.rs b/crates/wiggle/generate/src/funcs.rs index 50f5eec640..8e164f878b 100644 --- a/crates/wiggle/generate/src/funcs.rs +++ b/crates/wiggle/generate/src/funcs.rs @@ -1,11 +1,12 @@ -use proc_macro2::TokenStream; -use quote::quote; - use crate::error_transform::ErrorTransform; use crate::lifetimes::anon_lifetime; use crate::module_trait::passed_by_reference; use crate::names::Names; use crate::types::WiggleType; +use proc_macro2::{Ident, Span, TokenStream}; +use quote::quote; +use std::mem; +use witx::Instruction; pub fn define_func( names: &Names, @@ -13,163 +14,54 @@ pub fn define_func( func: &witx::InterfaceFunc, errxform: &ErrorTransform, ) -> TokenStream { - let funcname = func.name.as_str(); - - let ident = names.func(&func.name); let rt = names.runtime_mod(); + let ident = names.func(&func.name); let ctx_type = names.ctx_type(); - let coretype = func.core_type(); - let params = coretype.args.iter().map(|arg| { - let name = names.func_core_arg(arg); - let atom = names.atom_type(arg.repr()); - quote!(#name : #atom) + let (wasm_params, wasm_results) = func.wasm_signature(); + let param_names = (0..wasm_params.len()) + .map(|i| Ident::new(&format!("arg{}", i), Span::call_site())) + .collect::>(); + let abi_params = wasm_params.iter().zip(¶m_names).map(|(arg, name)| { + let wasm = names.wasm_type(*arg); + quote!(#name : #wasm) }); - let abi_args = quote!( - ctx: &#ctx_type, - memory: &dyn #rt::GuestMemory, - #(#params),* + let abi_ret = match wasm_results.len() { + 0 => quote!(()), + 1 => { + let ty = names.wasm_type(wasm_results[0]); + quote!(#ty) + } + _ => unimplemented!(), + }; + + let mut body = TokenStream::new(); + func.call_interface( + &module.name, + &mut Rust { + src: &mut body, + params: ¶m_names, + block_storage: Vec::new(), + blocks: Vec::new(), + rt: &rt, + names, + module, + funcname: func.name.as_str(), + errxform, + }, ); - let abi_ret = if let Some(ret) = &coretype.ret { - match ret.signifies { - witx::CoreParamSignifies::Value(atom) => names.atom_type(atom), - _ => unreachable!("ret should always be passed by value"), - } - } else { - quote!(()) - }; - let err_type = coretype.ret.clone().map(|ret| ret.param.tref); - let ret_err = coretype - .ret - .map(|ret| { - let name = names.func_param(&ret.param.name); - let conversion = if let Some(user_err) = errxform.for_abi_error(&ret.param.tref) { - let method = names.user_error_conversion_method(&user_err); - quote!(UserErrorConversion::#method(ctx, e)) - } else { - quote!(Ok(e)) - }; - quote! { - let e = #conversion; - #rt::tracing::event!( - #rt::tracing::Level::TRACE, - #name = #rt::tracing::field::debug(&e), - ); - match e { - Ok(e) => { return Ok(#abi_ret::from(e)); }, - Err(e) => { return Err(e); }, - } - } - }) - .unwrap_or_else(|| quote!(())); - - let error_handling = |location: &str| -> TokenStream { - if let Some(tref) = &err_type { - let abi_ret = match tref.type_().passed_by() { - witx::TypePassedBy::Value(atom) => names.atom_type(atom), - _ => unreachable!("err should always be passed by value"), - }; - let err_typename = names.type_ref(&tref, anon_lifetime()); - let err_method = names.guest_error_conversion_method(&tref); - quote! { - let e = #rt::GuestError::InFunc { funcname: #funcname, location: #location, err: Box::new(e.into()) }; - let err: #err_typename = GuestErrorConversion::#err_method(ctx, e); - return Ok(#abi_ret::from(err)); - } - } else { - quote! { - panic!("error: {:?}", e) - } - } - }; - - let marshal_args = func - .params - .iter() - .map(|p| marshal_arg(names, p, error_handling(p.name.as_str()))); - let trait_args = func.params.iter().map(|param| { - let name = names.func_param(¶m.name); - if passed_by_reference(&*param.tref.type_()) { - quote!(&#name) - } else { - quote!(#name) - } - }); - - let log_marshalled_args = if func.params.len() > 0 { - let rt = names.runtime_mod(); - let args = func.params.iter().map(|param| { - let name = names.func_param(¶m.name); - if param.impls_display() { - quote!( #name = #rt::tracing::field::display(&#name) ) - } else { - quote!( #name = #rt::tracing::field::debug(&#name) ) - } - }); - quote! { - #rt::tracing::event!(#rt::tracing::Level::TRACE, #(#args),*); - } - } else { - quote!() - }; - - let (trait_rets, trait_bindings) = if func.results.len() < 2 { - (quote!({}), quote!(_)) - } else { - let trait_rets: Vec<_> = func - .results - .iter() - .skip(1) - .map(|result| names.func_param(&result.name)) - .collect(); - let bindings = quote!((#(#trait_rets),*)); - let trace_rets = func.results.iter().skip(1).map(|result| { - let name = names.func_param(&result.name); - if result.tref.impls_display() { - quote!(#name = #rt::tracing::field::display(&#name)) - } else { - quote!(#name = #rt::tracing::field::debug(&#name)) - } - }); - let rets = quote! { - #rt::tracing::event!(#rt::tracing::Level::TRACE, #(#trace_rets),*); - (#(#trait_rets),*) - }; - (rets, bindings) - }; - - // Return value pointers need to be validated before the api call, then - // assigned to afterwards. marshal_result returns these two statements as a pair. - let marshal_rets = func - .results - .iter() - .skip(1) - .map(|result| marshal_result(names, result, &error_handling)); - let marshal_rets_pre = marshal_rets.clone().map(|(pre, _post)| pre); - let marshal_rets_post = marshal_rets.map(|(_pre, post)| post); - - let success = if let Some(ref err_type) = err_type { - let err_typename = names.type_ref(&err_type, anon_lifetime()); - quote! { - let success:#err_typename = #rt::GuestErrorType::success(); - #rt::tracing::event!( - #rt::tracing::Level::TRACE, - success=#rt::tracing::field::display(&success) - ); - Ok(#abi_ret::from(success)) - } - } else { - quote!(Ok(())) - }; - - let trait_name = names.trait_name(&module.name); let mod_name = &module.name.as_str(); let func_name = &func.name.as_str(); + quote! { + pub fn #ident( + ctx: &#ctx_type, + memory: &dyn #rt::GuestMemory, + #(#abi_params),* + ) -> Result<#abi_ret, #rt::Trap> { + use std::convert::TryFrom as _; - if func.noreturn { - quote!(pub fn #ident(#abi_args) -> Result<#abi_ret, #rt::Trap> { let _span = #rt::tracing::span!( #rt::tracing::Level::TRACE, "wiggle abi", @@ -178,176 +70,259 @@ pub fn define_func( ); let _enter = _span.enter(); - #(#marshal_args)* - #log_marshalled_args - let trap = #trait_name::#ident(ctx, #(#trait_args),*); - Err(trap) - }) - } else { - quote!(pub fn #ident(#abi_args) -> Result<#abi_ret, #rt::Trap> { - let _span = #rt::tracing::span!( - #rt::tracing::Level::TRACE, - "wiggle abi", - module = #mod_name, - function = #func_name - ); - let _enter = _span.enter(); - - #(#marshal_args)* - #(#marshal_rets_pre)* - #log_marshalled_args - let #trait_bindings = match #trait_name::#ident(ctx, #(#trait_args),*) { - Ok(#trait_bindings) => { #trait_rets }, - Err(e) => { #ret_err }, - }; - #(#marshal_rets_post)* - #success - }) + #body + } } } -fn marshal_arg( - names: &Names, - param: &witx::InterfaceFuncParam, - error_handling: TokenStream, -) -> TokenStream { - let rt = names.runtime_mod(); - let tref = ¶m.tref; - let interface_typename = names.type_ref(&tref, anon_lifetime()); +struct Rust<'a> { + src: &'a mut TokenStream, + params: &'a [Ident], + block_storage: Vec, + blocks: Vec, + rt: &'a TokenStream, + names: &'a Names, + module: &'a witx::Module, + funcname: &'a str, + errxform: &'a ErrorTransform, +} - let try_into_conversion = { - let name = names.func_param(¶m.name); - quote! { - let #name: #interface_typename = { - use ::std::convert::TryInto; - match #name.try_into() { - Ok(a) => a, - Err(e) => { - #error_handling - } +impl witx::Bindgen for Rust<'_> { + type Operand = TokenStream; + + fn push_block(&mut self) { + let prev = mem::replace(self.src, TokenStream::new()); + self.block_storage.push(prev); + } + + fn finish_block(&mut self, operand: Option) { + let to_restore = self.block_storage.pop().unwrap(); + let src = mem::replace(self.src, to_restore); + match operand { + None => self.blocks.push(src), + Some(s) => { + if src.is_empty() { + self.blocks.push(s); + } else { + self.blocks.push(quote!({ #src; #s })); } - }; - } - }; - - let read_conversion = { - let pointee_type = names.type_ref(tref, anon_lifetime()); - let arg_name = names.func_ptr_binding(¶m.name); - let name = names.func_param(¶m.name); - quote! { - let #name = match #rt::GuestPtr::<#pointee_type>::new(memory, #arg_name as u32).read() { - Ok(r) => r, - Err(e) => { - #error_handling - } - }; - } - }; - - match &*tref.type_() { - witx::Type::Enum(_e) => try_into_conversion, - witx::Type::Flags(_f) => try_into_conversion, - witx::Type::Int(_i) => try_into_conversion, - witx::Type::Builtin(b) => match b { - witx::BuiltinType::U8 | witx::BuiltinType::U16 | witx::BuiltinType::Char8 => { - try_into_conversion } - witx::BuiltinType::S8 | witx::BuiltinType::S16 => { - let name = names.func_param(¶m.name); - quote! { - let #name: #interface_typename = match (#name as i32).try_into() { - Ok(a) => a, - Err(e) => { - #error_handling - } + } + } + + // This is only used for `call_wasm` at this time. + fn allocate_space(&mut self, _: usize, _: &witx::NamedType) { + unimplemented!() + } + + fn emit( + &mut self, + inst: &Instruction<'_>, + operands: &mut Vec, + results: &mut Vec, + ) { + let rt = self.rt; + let wrap_err = |location: &str| { + let funcname = self.funcname; + quote! { + |e| { + #rt::GuestError::InFunc { + funcname: #funcname, + location: #location, + err: Box::new(#rt::GuestError::from(e)), } } } - witx::BuiltinType::U32 - | witx::BuiltinType::S32 - | witx::BuiltinType::U64 - | witx::BuiltinType::S64 - | witx::BuiltinType::USize - | witx::BuiltinType::F32 - | witx::BuiltinType::F64 => { - let name = names.func_param(¶m.name); - quote! { - let #name = #name as #interface_typename; + }; + + let mut try_from = |ty: TokenStream| { + let val = operands.pop().unwrap(); + let wrap_err = wrap_err(&format!("convert {}", ty)); + results.push(quote!(#ty::try_from(#val).map_err(#wrap_err)?)); + }; + + match inst { + Instruction::GetArg { nth } => { + let param = &self.params[*nth]; + results.push(quote!(#param)); + } + + Instruction::PointerFromI32 { ty } | Instruction::ConstPointerFromI32 { ty } => { + let val = operands.pop().unwrap(); + let pointee_type = self.names.type_ref(ty, anon_lifetime()); + results.push(quote! { + #rt::GuestPtr::<#pointee_type>::new(memory, #val as u32) + }); + } + + Instruction::ListFromPointerLength { ty } => { + let ptr = &operands[0]; + let len = &operands[1]; + let ty = match &**ty.type_() { + witx::Type::Builtin(witx::BuiltinType::Char) => quote!(str), + _ => { + let ty = self.names.type_ref(ty, anon_lifetime()); + quote!([#ty]) + } + }; + results.push(quote! { + #rt::GuestPtr::<#ty>::new(memory, (#ptr as u32, #len as u32)); + }) + } + + Instruction::CallInterface { func, .. } => { + // Use the `tracing` crate to log all arguments that are going + // out, and afterwards we call the function with those bindings. + let mut args = Vec::new(); + for (i, param) in func.params.iter().enumerate() { + let name = self.names.func_param(¶m.name); + let val = &operands[i]; + self.src.extend(quote!(let #name = #val;)); + if passed_by_reference(param.tref.type_()) { + args.push(quote!(&#name)); + } else { + args.push(quote!(#name)); + } + } + if func.params.len() > 0 { + let args = func + .params + .iter() + .map(|param| { + let name = self.names.func_param(¶m.name); + if param.impls_display() { + quote!( #name = #rt::tracing::field::display(&#name) ) + } else { + quote!( #name = #rt::tracing::field::debug(&#name) ) + } + }) + .collect::>(); + self.src.extend(quote! { + #rt::tracing::event!(#rt::tracing::Level::TRACE, #(#args),*); + }); + } + + let trait_name = self.names.trait_name(&self.module.name); + let ident = self.names.func(&func.name); + self.src.extend(quote! { + let ret = #trait_name::#ident(ctx, #(#args),*); + #rt::tracing::event!( + #rt::tracing::Level::TRACE, + result = #rt::tracing::field::debug(&ret), + ); + }); + + if func.results.len() > 0 { + results.push(quote!(ret)); + } else if func.noreturn { + self.src.extend(quote!(return Err(ret))); } } - witx::BuiltinType::String => { - let lifetime = anon_lifetime(); - let ptr_name = names.func_ptr_binding(¶m.name); - let len_name = names.func_len_binding(¶m.name); - let name = names.func_param(¶m.name); - quote! { - let #name = #rt::GuestPtr::<#lifetime, str>::new(memory, (#ptr_name as u32, #len_name as u32)); - } + + // Lowering an enum is typically simple but if we have an error + // transformation registered for this enum then what we're actually + // doing is lowering from a user-defined error type to the error + // enum, and *then* we lower to an i32. + Instruction::EnumLower { ty } => { + let val = operands.pop().unwrap(); + let val = match self.errxform.for_name(ty) { + Some(custom) => { + let method = self.names.user_error_conversion_method(&custom); + quote!(UserErrorConversion::#method(ctx, #val)?) + } + None => val, + }; + results.push(quote!(#val as i32)); } - }, - witx::Type::Pointer(pointee) | witx::Type::ConstPointer(pointee) => { - let pointee_type = names.type_ref(pointee, anon_lifetime()); - let name = names.func_param(¶m.name); - quote! { - let #name = #rt::GuestPtr::<#pointee_type>::new(memory, #name as u32); + + Instruction::ResultLower { err: err_ty, .. } => { + let err = self.blocks.pop().unwrap(); + let ok = self.blocks.pop().unwrap(); + let val = operands.pop().unwrap(); + let err_typename = self.names.type_ref(err_ty.unwrap(), anon_lifetime()); + results.push(quote! { + match #val { + Ok(e) => { #ok; <#err_typename as #rt::GuestErrorType>::success() as i32 } + Err(e) => { #err } + } + }); } - } - witx::Type::Struct(_) => read_conversion, - witx::Type::Array(arr) => { - let pointee_type = names.type_ref(arr, anon_lifetime()); - let ptr_name = names.func_ptr_binding(¶m.name); - let len_name = names.func_len_binding(¶m.name); - let name = names.func_param(¶m.name); - quote! { - let #name = #rt::GuestPtr::<[#pointee_type]>::new(memory, (#ptr_name as u32, #len_name as u32)); + + Instruction::VariantPayload => results.push(quote!(e)), + + Instruction::Return { amt: 0 } => {} + Instruction::Return { amt: 1 } => { + let val = operands.pop().unwrap(); + self.src.extend(quote!(return Ok(#val))); } - } - witx::Type::Union(_u) => read_conversion, - witx::Type::Handle(_h) => { - let name = names.func_param(¶m.name); - let handle_type = names.type_ref(tref, anon_lifetime()); - quote!( let #name = #handle_type::from(#name); ) + Instruction::Return { .. } => unimplemented!(), + + Instruction::TupleLower { amt } => { + let names = (0..*amt) + .map(|i| Ident::new(&format!("t{}", i), Span::call_site())) + .collect::>(); + let val = operands.pop().unwrap(); + self.src.extend(quote!( let (#(#names,)*) = #val;)); + results.extend(names.iter().map(|i| quote!(#i))); + } + + Instruction::Store { ty } => { + let ptr = operands.pop().unwrap(); + let val = operands.pop().unwrap(); + let wrap_err = wrap_err(&format!("write {}", ty.name.as_str())); + let pointee_type = self.names.type_(&ty.name); + self.src.extend(quote! { + #rt::GuestPtr::<#pointee_type>::new(memory, #ptr as u32) + .write(#val) + .map_err(#wrap_err)?; + }); + } + + Instruction::HandleFromI32 { ty } => { + let val = operands.pop().unwrap(); + let ty = self.names.type_(&ty.name); + results.push(quote!(#ty::from(#val))); + } + + // Smaller-than-32 numerical conversions are done with `TryFrom` to + // ensure we're not losing bits. + Instruction::U8FromI32 => try_from(quote!(u8)), + Instruction::S8FromI32 => try_from(quote!(i8)), + Instruction::Char8FromI32 => try_from(quote!(u8)), + Instruction::U16FromI32 => try_from(quote!(u16)), + Instruction::S16FromI32 => try_from(quote!(i16)), + + // Conversions with matching bit-widths but different signededness + // use `as` since we're basically just reinterpreting the bits. + Instruction::U32FromI32 => { + let val = operands.pop().unwrap(); + results.push(quote!(#val as u32)); + } + Instruction::U64FromI64 => { + let val = operands.pop().unwrap(); + results.push(quote!(#val as u64)); + } + + // Conversions to enums/bitflags use `TryFrom` to ensure that the + // values are valid coming in. + Instruction::EnumLift { ty } + | Instruction::BitflagsFromI64 { ty } + | Instruction::BitflagsFromI32 { ty } => { + let ty = self.names.type_(&ty.name); + try_from(quote!(#ty)) + } + + // No conversions necessary for these, the native wasm type matches + // our own representation. + Instruction::If32FromF32 + | Instruction::If64FromF64 + | Instruction::S32FromI32 + | Instruction::S64FromI64 => results.push(operands.pop().unwrap()), + + // There's a number of other instructions we could implement but + // they're not exercised by WASI at this time. As necessary we can + // add code to implement them. + other => panic!("no implementation for {:?}", other), } } } - -fn marshal_result( - names: &Names, - result: &witx::InterfaceFuncParam, - error_handling: F, -) -> (TokenStream, TokenStream) -where - F: Fn(&str) -> TokenStream, -{ - let rt = names.runtime_mod(); - let tref = &result.tref; - - let write_val_to_ptr = { - let pointee_type = names.type_ref(tref, anon_lifetime()); - // core type is given func_ptr_binding name. - let ptr_name = names.func_ptr_binding(&result.name); - let ptr_err_handling = error_handling(&format!("{}:result_ptr_mut", result.name.as_str())); - let pre = quote! { - let #ptr_name = #rt::GuestPtr::<#pointee_type>::new(memory, #ptr_name as u32); - }; - // trait binding returns func_param name. - let val_name = names.func_param(&result.name); - let post = quote! { - if let Err(e) = #ptr_name.write(#val_name) { - #ptr_err_handling - } - }; - (pre, post) - }; - - match &*tref.type_() { - witx::Type::Builtin(b) => match b { - witx::BuiltinType::String => unimplemented!("string result types"), - _ => write_val_to_ptr, - }, - witx::Type::Pointer { .. } | witx::Type::ConstPointer { .. } | witx::Type::Array { .. } => { - unimplemented!("pointer/array result types") - } - _ => write_val_to_ptr, - } -} diff --git a/crates/wiggle/generate/src/lib.rs b/crates/wiggle/generate/src/lib.rs index b9a07f93ce..3eb359bc67 100644 --- a/crates/wiggle/generate/src/lib.rs +++ b/crates/wiggle/generate/src/lib.rs @@ -6,11 +6,10 @@ mod module_trait; mod names; mod types; +use lifetimes::anon_lifetime; use proc_macro2::TokenStream; use quote::quote; -use lifetimes::anon_lifetime; - pub use config::Config; pub use error_transform::{ErrorTransform, UserErrorType}; pub use funcs::define_func; @@ -67,6 +66,8 @@ pub fn generate(doc: &witx::Document, names: &Names, errs: &ErrorTransform) -> T quote!( pub mod types { + use std::convert::TryFrom; + #(#types)* #guest_error_conversion #user_error_conversion diff --git a/crates/wiggle/generate/src/lifetimes.rs b/crates/wiggle/generate/src/lifetimes.rs index 75b102209c..87b1aad9a4 100644 --- a/crates/wiggle/generate/src/lifetimes.rs +++ b/crates/wiggle/generate/src/lifetimes.rs @@ -19,46 +19,37 @@ impl LifetimeExt for witx::Type { fn is_transparent(&self) -> bool { match self { witx::Type::Builtin(b) => b.is_transparent(), - witx::Type::Struct(s) => s.is_transparent(), - witx::Type::Enum { .. } - | witx::Type::Flags { .. } - | witx::Type::Int { .. } - | witx::Type::Handle { .. } => true, - witx::Type::Union { .. } + witx::Type::Record(s) => s.is_transparent(), + witx::Type::Handle { .. } => true, + witx::Type::Variant { .. } | witx::Type::Pointer { .. } | witx::Type::ConstPointer { .. } - | witx::Type::Array { .. } => false, + | witx::Type::List { .. } => false, } } fn needs_lifetime(&self) -> bool { match self { witx::Type::Builtin(b) => b.needs_lifetime(), - witx::Type::Struct(s) => s.needs_lifetime(), - witx::Type::Union(u) => u.needs_lifetime(), - witx::Type::Enum { .. } - | witx::Type::Flags { .. } - | witx::Type::Int { .. } - | witx::Type::Handle { .. } => false, + witx::Type::Record(s) => s.needs_lifetime(), + witx::Type::Variant(u) => u.needs_lifetime(), + witx::Type::Handle { .. } => false, witx::Type::Pointer { .. } | witx::Type::ConstPointer { .. } - | witx::Type::Array { .. } => true, + | witx::Type::List { .. } => true, } } } impl LifetimeExt for witx::BuiltinType { fn is_transparent(&self) -> bool { - !self.needs_lifetime() + true } fn needs_lifetime(&self) -> bool { - match self { - witx::BuiltinType::String => true, - _ => false, - } + false } } -impl LifetimeExt for witx::StructDatatype { +impl LifetimeExt for witx::RecordDatatype { fn is_transparent(&self) -> bool { self.members.iter().all(|m| m.tref.is_transparent()) } @@ -67,12 +58,12 @@ impl LifetimeExt for witx::StructDatatype { } } -impl LifetimeExt for witx::UnionDatatype { +impl LifetimeExt for witx::Variant { fn is_transparent(&self) -> bool { false } fn needs_lifetime(&self) -> bool { - self.variants + self.cases .iter() .any(|m| m.tref.as_ref().map(|t| t.needs_lifetime()).unwrap_or(false)) } diff --git a/crates/wiggle/generate/src/module_trait.rs b/crates/wiggle/generate/src/module_trait.rs index a7a7289d09..92e503ddda 100644 --- a/crates/wiggle/generate/src/module_trait.rs +++ b/crates/wiggle/generate/src/module_trait.rs @@ -7,17 +7,9 @@ use crate::names::Names; use witx::Module; pub fn passed_by_reference(ty: &witx::Type) -> bool { - let passed_by = match ty.passed_by() { - witx::TypePassedBy::Value { .. } => false, - witx::TypePassedBy::Pointer { .. } | witx::TypePassedBy::PointerLengthPair { .. } => true, - }; match ty { - witx::Type::Builtin(b) => match &*b { - witx::BuiltinType::String => true, - _ => passed_by, - }, - witx::Type::Pointer(_) | witx::Type::ConstPointer(_) | witx::Type::Array(_) => true, - _ => passed_by, + witx::Type::Pointer(_) | witx::Type::ConstPointer(_) | witx::Type::List(_) => true, + _ => false, } } @@ -49,28 +41,36 @@ pub fn define_module_trait(names: &Names, m: &Module, errxform: &ErrorTransform) quote!(#arg_name: #arg_type) }); - let result = if !f.noreturn { - let rets = f - .results - .iter() - .skip(1) - .map(|ret| names.type_ref(&ret.tref, lifetime.clone())); - let err = f - .results - .get(0) - .map(|err_result| { - if let Some(custom_err) = errxform.for_abi_error(&err_result.tref) { - let tn = custom_err.typename(); - quote!(super::#tn) - } else { - names.type_ref(&err_result.tref, lifetime.clone()) - } - }) - .unwrap_or(quote!(())); - quote!( Result<(#(#rets),*), #err> ) - } else { - let rt = names.runtime_mod(); - quote!(#rt::Trap) + let rt = names.runtime_mod(); + let result = match f.results.len() { + 0 if f.noreturn => quote!(#rt::Trap), + 0 => quote!(()), + 1 => { + let (ok, err) = match &**f.results[0].tref.type_() { + witx::Type::Variant(v) => match v.as_expected() { + Some(p) => p, + None => unimplemented!("anonymous variant ref {:?}", v), + }, + _ => unimplemented!(), + }; + + let ok = match ok { + Some(ty) => names.type_ref(ty, lifetime.clone()), + None => quote!(()), + }; + let err = match err { + Some(ty) => match errxform.for_abi_error(ty) { + Some(custom) => { + let tn = custom.typename(); + quote!(super::#tn) + } + None => names.type_ref(ty, lifetime.clone()), + }, + None => quote!(()), + }; + quote!(Result<#ok, #err>) + } + _ => unimplemented!(), }; if is_anonymous { diff --git a/crates/wiggle/generate/src/names.rs b/crates/wiggle/generate/src/names.rs index f6266ec41f..b9eb9e4752 100644 --- a/crates/wiggle/generate/src/names.rs +++ b/crates/wiggle/generate/src/names.rs @@ -2,7 +2,7 @@ use escaping::{escape_id, handle_2big_enum_variant, NamingConvention}; use heck::{ShoutySnakeCase, SnakeCase}; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; -use witx::{AtomType, BuiltinType, Id, Type, TypeRef}; +use witx::{BuiltinType, Id, Type, TypeRef, WasmType}; use crate::{lifetimes::LifetimeExt, UserErrorType}; @@ -32,15 +32,11 @@ impl Names { quote!(#ident) } - pub fn builtin_type(&self, b: BuiltinType, lifetime: TokenStream) -> TokenStream { + pub fn builtin_type(&self, b: BuiltinType) -> TokenStream { match b { - BuiltinType::String => { - let rt = self.runtime_mod(); - quote!(#rt::GuestPtr<#lifetime, str>) - } - BuiltinType::U8 => quote!(u8), + BuiltinType::U8 { .. } => quote!(u8), BuiltinType::U16 => quote!(u16), - BuiltinType::U32 => quote!(u32), + BuiltinType::U32 { .. } => quote!(u32), BuiltinType::U64 => quote!(u64), BuiltinType::S8 => quote!(i8), BuiltinType::S16 => quote!(i16), @@ -48,16 +44,16 @@ impl Names { BuiltinType::S64 => quote!(i64), BuiltinType::F32 => quote!(f32), BuiltinType::F64 => quote!(f64), - BuiltinType::Char8 => quote!(u8), - BuiltinType::USize => quote!(u32), + BuiltinType::Char => quote!(char), } } - pub fn atom_type(&self, atom: AtomType) -> TokenStream { - match atom { - AtomType::I32 => quote!(i32), - AtomType::I64 => quote!(i64), - AtomType::F32 => quote!(f32), - AtomType::F64 => quote!(f64), + + pub fn wasm_type(&self, ty: WasmType) -> TokenStream { + match ty { + WasmType::I32 => quote!(i32), + WasmType::I64 => quote!(i64), + WasmType::F32 => quote!(f32), + WasmType::F64 => quote!(f64), } } @@ -72,16 +68,44 @@ impl Names { } } TypeRef::Value(ty) => match &**ty { - Type::Builtin(builtin) => self.builtin_type(*builtin, lifetime.clone()), + Type::Builtin(builtin) => self.builtin_type(*builtin), Type::Pointer(pointee) | Type::ConstPointer(pointee) => { let rt = self.runtime_mod(); let pointee_type = self.type_ref(&pointee, lifetime.clone()); quote!(#rt::GuestPtr<#lifetime, #pointee_type>) } - Type::Array(pointee) => { - let rt = self.runtime_mod(); - let pointee_type = self.type_ref(&pointee, lifetime.clone()); - quote!(#rt::GuestPtr<#lifetime, [#pointee_type]>) + Type::List(pointee) => match &**pointee.type_() { + Type::Builtin(BuiltinType::Char) => { + let rt = self.runtime_mod(); + quote!(#rt::GuestPtr<#lifetime, str>) + } + _ => { + let rt = self.runtime_mod(); + let pointee_type = self.type_ref(&pointee, lifetime.clone()); + quote!(#rt::GuestPtr<#lifetime, [#pointee_type]>) + } + }, + Type::Variant(v) => match v.as_expected() { + Some((ok, err)) => { + let ok = match ok { + Some(ty) => self.type_ref(ty, lifetime.clone()), + None => quote!(()), + }; + let err = match err { + Some(ty) => self.type_ref(ty, lifetime.clone()), + None => quote!(()), + }; + quote!(Result<#ok, #err>) + } + None => unimplemented!("anonymous variant ref {:?}", tref), + }, + Type::Record(r) if r.is_tuple() => { + let types = r + .members + .iter() + .map(|m| self.type_ref(&m.tref, lifetime.clone())) + .collect::>(); + quote!((#(#types,)*)) } _ => unimplemented!("anonymous type ref {:?}", tref), }, @@ -144,14 +168,6 @@ impl Names { escape_id(id, NamingConvention::SnakeCase) } - pub fn func_core_arg(&self, arg: &witx::CoreParamType) -> Ident { - match arg.signifies { - witx::CoreParamSignifies::Value { .. } => self.func_param(&arg.param.name), - witx::CoreParamSignifies::PointerTo => self.func_ptr_binding(&arg.param.name), - witx::CoreParamSignifies::LengthOf => self.func_len_binding(&arg.param.name), - } - } - /// For when you need a {name}_ptr binding for passing a value by reference: pub fn func_ptr_binding(&self, id: &Id) -> Ident { format_ident!("{}_ptr", id.as_str().to_snake_case()) @@ -164,10 +180,9 @@ impl Names { fn builtin_name(b: &BuiltinType) -> &'static str { match b { - BuiltinType::String => "string", - BuiltinType::U8 => "u8", + BuiltinType::U8 { .. } => "u8", BuiltinType::U16 => "u16", - BuiltinType::U32 => "u32", + BuiltinType::U32 { .. } => "u32", BuiltinType::U64 => "u64", BuiltinType::S8 => "i8", BuiltinType::S16 => "i16", @@ -175,8 +190,7 @@ impl Names { BuiltinType::S64 => "i64", BuiltinType::F32 => "f32", BuiltinType::F64 => "f64", - BuiltinType::Char8 => "char8", - BuiltinType::USize => "usize", + BuiltinType::Char => "char", } } diff --git a/crates/wiggle/generate/src/types/enum.rs b/crates/wiggle/generate/src/types/enum.rs deleted file mode 100644 index cef56065ef..0000000000 --- a/crates/wiggle/generate/src/types/enum.rs +++ /dev/null @@ -1,120 +0,0 @@ -use super::{atom_token, int_repr_tokens}; -use crate::names::Names; - -use proc_macro2::TokenStream; -use quote::quote; - -pub(super) fn define_enum(names: &Names, name: &witx::Id, e: &witx::EnumDatatype) -> TokenStream { - let ident = names.type_(&name); - let rt = names.runtime_mod(); - - let repr = int_repr_tokens(e.repr); - let abi_repr = atom_token(match e.repr { - witx::IntRepr::U8 | witx::IntRepr::U16 | witx::IntRepr::U32 => witx::AtomType::I32, - witx::IntRepr::U64 => witx::AtomType::I64, - }); - - let mut variant_names = vec![]; - let mut tryfrom_repr_cases = vec![]; - let mut to_repr_cases = vec![]; - let mut to_display = vec![]; - - for (n, variant) in e.variants.iter().enumerate() { - let variant_name = names.enum_variant(&variant.name); - let docs = variant.docs.trim(); - let ident_str = ident.to_string(); - let variant_str = variant_name.to_string(); - tryfrom_repr_cases.push(quote!(#n => Ok(#ident::#variant_name))); - to_repr_cases.push(quote!(#ident::#variant_name => #n as #repr)); - to_display.push(quote!(#ident::#variant_name => format!("{} ({}::{}({}))", #docs, #ident_str, #variant_str, #repr::from(*self)))); - variant_names.push(variant_name); - } - - quote! { - #[repr(#repr)] - #[derive(Copy, Clone, Debug, ::std::hash::Hash, Eq, PartialEq)] - pub enum #ident { - #(#variant_names),* - } - - impl ::std::fmt::Display for #ident { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - let to_str = match self { - #(#to_display,)* - }; - write!(f, "{}", to_str) - } - } - - impl ::std::convert::TryFrom<#repr> for #ident { - type Error = #rt::GuestError; - fn try_from(value: #repr) -> Result<#ident, #rt::GuestError> { - match value as usize { - #(#tryfrom_repr_cases),*, - _ => Err( #rt::GuestError::InvalidEnumValue(stringify!(#ident))), - } - } - } - - impl ::std::convert::TryFrom<#abi_repr> for #ident { - type Error = #rt::GuestError; - fn try_from(value: #abi_repr) -> Result<#ident, #rt::GuestError> { - #ident::try_from(value as #repr) - } - } - - impl From<#ident> for #repr { - fn from(e: #ident) -> #repr { - match e { - #(#to_repr_cases),* - } - } - } - - impl From<#ident> for #abi_repr { - fn from(e: #ident) -> #abi_repr { - #repr::from(e) as #abi_repr - } - } - - impl<'a> #rt::GuestType<'a> for #ident { - fn guest_size() -> u32 { - #repr::guest_size() - } - - fn guest_align() -> usize { - #repr::guest_align() - } - - fn read(location: & #rt::GuestPtr<#ident>) -> Result<#ident, #rt::GuestError> { - use std::convert::TryFrom; - let reprval = #repr::read(&location.cast())?; - let value = #ident::try_from(reprval)?; - Ok(value) - } - - fn write(location: & #rt::GuestPtr<'_, #ident>, val: Self) - -> Result<(), #rt::GuestError> - { - #repr::write(&location.cast(), #repr::from(val)) - } - } - - unsafe impl <'a> #rt::GuestTypeTransparent<'a> for #ident { - #[inline] - fn validate(location: *mut #ident) -> Result<(), #rt::GuestError> { - use std::convert::TryFrom; - // Validate value in memory using #ident::try_from(reprval) - let reprval = unsafe { (location as *mut #repr).read() }; - let _val = #ident::try_from(reprval)?; - Ok(()) - } - } - } -} - -impl super::WiggleType for witx::EnumDatatype { - fn impls_display(&self) -> bool { - true - } -} diff --git a/crates/wiggle/generate/src/types/flags.rs b/crates/wiggle/generate/src/types/flags.rs index cda13a2bfa..6082ed5f73 100644 --- a/crates/wiggle/generate/src/types/flags.rs +++ b/crates/wiggle/generate/src/types/flags.rs @@ -1,27 +1,24 @@ -use super::{atom_token, int_repr_tokens}; use crate::names::Names; use proc_macro2::{Literal, TokenStream}; use quote::quote; -use std::convert::TryFrom; -pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDatatype) -> TokenStream { +pub(super) fn define_flags( + names: &Names, + name: &witx::Id, + repr: witx::IntRepr, + record: &witx::RecordDatatype, +) -> TokenStream { let rt = names.runtime_mod(); let ident = names.type_(&name); - let repr = int_repr_tokens(f.repr); - let abi_repr = atom_token(match f.repr { - witx::IntRepr::U8 | witx::IntRepr::U16 | witx::IntRepr::U32 => witx::AtomType::I32, - witx::IntRepr::U64 => witx::AtomType::I64, - }); + let abi_repr = names.wasm_type(repr.into()); + let repr = super::int_repr_tokens(repr); let mut names_ = vec![]; let mut values_ = vec![]; - for (i, f) in f.flags.iter().enumerate() { - let name = names.flag_member(&f.name); - let value = 1u128 - .checked_shl(u32::try_from(i).expect("flag value overflow")) - .expect("flag value overflow"); - let value_token = Literal::u128_unsuffixed(value); + for (i, member) in record.members.iter().enumerate() { + let name = names.flag_member(&member.name); + let value_token = Literal::usize_unsuffixed(1 << i); names_.push(name); values_.push(value_token); } @@ -45,7 +42,7 @@ pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDataty } } - impl ::std::convert::TryFrom<#repr> for #ident { + impl TryFrom<#repr> for #ident { type Error = #rt::GuestError; fn try_from(value: #repr) -> Result { if #repr::from(!#ident::all()) & value != 0 { @@ -56,10 +53,10 @@ pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDataty } } - impl ::std::convert::TryFrom<#abi_repr> for #ident { + impl TryFrom<#abi_repr> for #ident { type Error = #rt::GuestError; - fn try_from(value: #abi_repr) -> Result<#ident, #rt::GuestError> { - #ident::try_from(value as #repr) + fn try_from(value: #abi_repr) -> Result { + #ident::try_from(#repr::try_from(value)?) } } @@ -69,12 +66,6 @@ pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDataty } } - impl From<#ident> for #abi_repr { - fn from(e: #ident) -> #abi_repr { - #repr::from(e) as #abi_repr - } - } - impl<'a> #rt::GuestType<'a> for #ident { fn guest_size() -> u32 { #repr::guest_size() @@ -106,12 +97,5 @@ pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDataty Ok(()) } } - - } -} - -impl super::WiggleType for witx::FlagsDatatype { - fn impls_display(&self) -> bool { - true } } diff --git a/crates/wiggle/generate/src/types/handle.rs b/crates/wiggle/generate/src/types/handle.rs index 310c9e7d6e..0f23419b0d 100644 --- a/crates/wiggle/generate/src/types/handle.rs +++ b/crates/wiggle/generate/src/types/handle.rs @@ -78,8 +78,6 @@ pub(super) fn define_handle( Ok(()) } } - - } } diff --git a/crates/wiggle/generate/src/types/int.rs b/crates/wiggle/generate/src/types/int.rs deleted file mode 100644 index d916b65b70..0000000000 --- a/crates/wiggle/generate/src/types/int.rs +++ /dev/null @@ -1,100 +0,0 @@ -use super::{atom_token, int_repr_tokens}; -use crate::names::Names; - -use proc_macro2::TokenStream; -use quote::quote; - -pub(super) fn define_int(names: &Names, name: &witx::Id, i: &witx::IntDatatype) -> TokenStream { - let rt = names.runtime_mod(); - let ident = names.type_(&name); - let repr = int_repr_tokens(i.repr); - let abi_repr = atom_token(match i.repr { - witx::IntRepr::U8 | witx::IntRepr::U16 | witx::IntRepr::U32 => witx::AtomType::I32, - witx::IntRepr::U64 => witx::AtomType::I64, - }); - let consts = i - .consts - .iter() - .map(|r#const| { - let const_ident = names.int_member(&r#const.name); - let value = r#const.value; - quote!(pub const #const_ident: #ident = #ident(#value)) - }) - .collect::>(); - - quote! { - #[repr(transparent)] - #[derive(Copy, Clone, Debug, ::std::hash::Hash, Eq, PartialEq)] - pub struct #ident(#repr); - - impl #ident { - #(#consts;)* - } - - impl ::std::fmt::Display for #ident { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write!(f, "{:?}", self) - } - } - - impl ::std::convert::TryFrom<#repr> for #ident { - type Error = #rt::GuestError; - fn try_from(value: #repr) -> Result { - Ok(#ident(value)) - } - } - - impl ::std::convert::TryFrom<#abi_repr> for #ident { - type Error = #rt::GuestError; - fn try_from(value: #abi_repr) -> Result<#ident, #rt::GuestError> { - #ident::try_from(value as #repr) - } - } - - impl From<#ident> for #repr { - fn from(e: #ident) -> #repr { - e.0 - } - } - - impl From<#ident> for #abi_repr { - fn from(e: #ident) -> #abi_repr { - #repr::from(e) as #abi_repr - } - } - - impl<'a> #rt::GuestType<'a> for #ident { - fn guest_size() -> u32 { - #repr::guest_size() - } - - fn guest_align() -> usize { - #repr::guest_align() - } - - fn read(location: &#rt::GuestPtr<'a, #ident>) -> Result<#ident, #rt::GuestError> { - Ok(#ident(#repr::read(&location.cast())?)) - - } - - fn write(location: &#rt::GuestPtr<'_, #ident>, val: Self) -> Result<(), #rt::GuestError> { - #repr::write(&location.cast(), val.0) - } - } - - unsafe impl<'a> #rt::GuestTypeTransparent<'a> for #ident { - #[inline] - fn validate(_location: *mut #ident) -> Result<(), #rt::GuestError> { - // All bit patterns accepted - Ok(()) - } - } - - } -} - -impl super::WiggleType for witx::IntDatatype { - fn impls_display(&self) -> bool { - true - } -} diff --git a/crates/wiggle/generate/src/types/mod.rs b/crates/wiggle/generate/src/types/mod.rs index 6ed29af94b..255c2ce7b2 100644 --- a/crates/wiggle/generate/src/types/mod.rs +++ b/crates/wiggle/generate/src/types/mod.rs @@ -1,9 +1,8 @@ -mod r#enum; +// mod r#enum; mod flags; mod handle; -mod int; -mod r#struct; -mod union; +mod record; +mod variant; use crate::lifetimes::LifetimeExt; use crate::names::Names; @@ -15,11 +14,11 @@ pub fn define_datatype(names: &Names, namedtype: &witx::NamedType) -> TokenStrea match &namedtype.tref { witx::TypeRef::Name(alias_to) => define_alias(names, &namedtype.name, &alias_to), witx::TypeRef::Value(v) => match &**v { - witx::Type::Enum(e) => r#enum::define_enum(names, &namedtype.name, &e), - witx::Type::Int(i) => int::define_int(names, &namedtype.name, &i), - witx::Type::Flags(f) => flags::define_flags(names, &namedtype.name, &f), - witx::Type::Struct(s) => r#struct::define_struct(names, &namedtype.name, &s), - witx::Type::Union(u) => union::define_union(names, &namedtype.name, &u), + witx::Type::Record(r) => match r.bitflags_repr() { + Some(repr) => flags::define_flags(names, &namedtype.name, repr, &r), + None => record::define_struct(names, &namedtype.name, &r), + }, + witx::Type::Variant(v) => variant::define_variant(names, &namedtype.name, &v), witx::Type::Handle(h) => handle::define_handle(names, &namedtype.name, &h), witx::Type::Builtin(b) => define_builtin(names, &namedtype.name, *b), witx::Type::Pointer(p) => { @@ -30,7 +29,7 @@ pub fn define_datatype(names: &Names, namedtype: &witx::NamedType) -> TokenStrea let rt = names.runtime_mod(); define_witx_pointer(names, &namedtype.name, quote!(#rt::GuestPtr), p) } - witx::Type::Array(arr) => define_witx_array(names, &namedtype.name, &arr), + witx::Type::List(arr) => define_witx_list(names, &namedtype.name, &arr), }, } } @@ -47,12 +46,8 @@ fn define_alias(names: &Names, name: &witx::Id, to: &witx::NamedType) -> TokenSt fn define_builtin(names: &Names, name: &witx::Id, builtin: witx::BuiltinType) -> TokenStream { let ident = names.type_(name); - let built = names.builtin_type(builtin, quote!('a)); - if builtin.needs_lifetime() { - quote!(pub type #ident<'a> = #built;) - } else { - quote!(pub type #ident = #built;) - } + let built = names.builtin_type(builtin); + quote!(pub type #ident = #built;) } fn define_witx_pointer( @@ -67,14 +62,14 @@ fn define_witx_pointer( quote!(pub type #ident<'a> = #pointer_type<'a, #pointee_type>;) } -fn define_witx_array(names: &Names, name: &witx::Id, arr_raw: &witx::TypeRef) -> TokenStream { +fn define_witx_list(names: &Names, name: &witx::Id, arr_raw: &witx::TypeRef) -> TokenStream { let ident = names.type_(name); let rt = names.runtime_mod(); let pointee_type = names.type_ref(arr_raw, quote!('a)); quote!(pub type #ident<'a> = #rt::GuestPtr<'a, [#pointee_type]>;) } -fn int_repr_tokens(int_repr: witx::IntRepr) -> TokenStream { +pub fn int_repr_tokens(int_repr: witx::IntRepr) -> TokenStream { match int_repr { witx::IntRepr::U8 => quote!(u8), witx::IntRepr::U16 => quote!(u16), @@ -83,15 +78,6 @@ fn int_repr_tokens(int_repr: witx::IntRepr) -> TokenStream { } } -fn atom_token(atom: witx::AtomType) -> TokenStream { - match atom { - witx::AtomType::I32 => quote!(i32), - witx::AtomType::I64 => quote!(i64), - witx::AtomType::F32 => quote!(f32), - witx::AtomType::F64 => quote!(f64), - } -} - pub trait WiggleType { fn impls_display(&self) -> bool; } @@ -114,16 +100,13 @@ impl WiggleType for witx::NamedType { impl WiggleType for witx::Type { fn impls_display(&self) -> bool { match self { - witx::Type::Enum(x) => x.impls_display(), - witx::Type::Int(x) => x.impls_display(), - witx::Type::Flags(x) => x.impls_display(), - witx::Type::Struct(x) => x.impls_display(), - witx::Type::Union(x) => x.impls_display(), + witx::Type::Record(x) => x.impls_display(), + witx::Type::Variant(x) => x.impls_display(), witx::Type::Handle(x) => x.impls_display(), witx::Type::Builtin(x) => x.impls_display(), witx::Type::Pointer { .. } | witx::Type::ConstPointer { .. } - | witx::Type::Array { .. } => false, + | witx::Type::List { .. } => false, } } } @@ -136,11 +119,6 @@ impl WiggleType for witx::BuiltinType { impl WiggleType for witx::InterfaceFuncParam { fn impls_display(&self) -> bool { - match &*self.tref.type_() { - witx::Type::Struct { .. } - | witx::Type::Union { .. } - | witx::Type::Builtin(witx::BuiltinType::String { .. }) => false, - _ => self.tref.impls_display(), - } + self.tref.impls_display() } } diff --git a/crates/wiggle/generate/src/types/struct.rs b/crates/wiggle/generate/src/types/record.rs similarity index 94% rename from crates/wiggle/generate/src/types/struct.rs rename to crates/wiggle/generate/src/types/record.rs index 2a79a6e9ac..eaabcc6426 100644 --- a/crates/wiggle/generate/src/types/struct.rs +++ b/crates/wiggle/generate/src/types/record.rs @@ -8,7 +8,7 @@ use witx::Layout; pub(super) fn define_struct( names: &Names, name: &witx::Id, - s: &witx::StructDatatype, + s: &witx::RecordDatatype, ) -> TokenStream { let rt = names.runtime_mod(); let ident = names.type_(name); @@ -28,7 +28,7 @@ pub(super) fn define_struct( } } witx::TypeRef::Value(ty) => match &**ty { - witx::Type::Builtin(builtin) => names.builtin_type(*builtin, quote!('a)), + witx::Type::Builtin(builtin) => names.builtin_type(*builtin), witx::Type::Pointer(pointee) | witx::Type::ConstPointer(pointee) => { let pointee_type = names.type_ref(&pointee, quote!('a)); quote!(#rt::GuestPtr<'a, #pointee_type>) @@ -52,9 +52,9 @@ pub(super) fn define_struct( } witx::TypeRef::Value(ty) => match &**ty { witx::Type::Builtin(builtin) => { - let type_ = names.builtin_type(*builtin, anon_lifetime()); + let type_ = names.builtin_type(*builtin); quote! { - let #name = <#type_ as #rt::GuestType>::read(&#location)?; + let #name = <#type_ as #rt::GuestType>::read(&#location)?; } } witx::Type::Pointer(pointee) | witx::Type::ConstPointer(pointee) => { @@ -141,7 +141,7 @@ pub(super) fn define_struct( } } -impl super::WiggleType for witx::StructDatatype { +impl super::WiggleType for witx::RecordDatatype { fn impls_display(&self) -> bool { false } diff --git a/crates/wiggle/generate/src/types/union.rs b/crates/wiggle/generate/src/types/variant.rs similarity index 50% rename from crates/wiggle/generate/src/types/union.rs rename to crates/wiggle/generate/src/types/variant.rs index ecc3253f3c..6568228ded 100644 --- a/crates/wiggle/generate/src/types/union.rs +++ b/crates/wiggle/generate/src/types/variant.rs @@ -1,23 +1,23 @@ use crate::lifetimes::LifetimeExt; use crate::names::Names; -use proc_macro2::TokenStream; +use proc_macro2::{Literal, TokenStream}; use quote::quote; use witx::Layout; -pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDatatype) -> TokenStream { +pub(super) fn define_variant(names: &Names, name: &witx::Id, v: &witx::Variant) -> TokenStream { let rt = names.runtime_mod(); let ident = names.type_(name); - let size = u.mem_size_align().size as u32; - let align = u.mem_size_align().align as usize; - let ulayout = u.union_layout(); - let contents_offset = ulayout.contents_offset as u32; + let size = v.mem_size_align().size as u32; + let align = v.mem_size_align().align as usize; + let contents_offset = v.payload_offset() as u32; let lifetime = quote!('a); + let tag_ty = super::int_repr_tokens(v.tag_repr); - let variants = u.variants.iter().map(|v| { - let var_name = names.enum_variant(&v.name); - if let Some(tref) = &v.tref { + let variants = v.cases.iter().map(|c| { + let var_name = names.enum_variant(&c.name); + if let Some(tref) = &c.tref { let var_type = names.type_ref(&tref, lifetime.clone()); quote!(#var_name(#var_type)) } else { @@ -25,30 +25,29 @@ pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDataty } }); - let tagname = names.type_(&u.tag.name); - - let read_variant = u.variants.iter().map(|v| { - let variantname = names.enum_variant(&v.name); - if let Some(tref) = &v.tref { + let read_variant = v.cases.iter().enumerate().map(|(i, c)| { + let i = Literal::usize_unsuffixed(i); + let variantname = names.enum_variant(&c.name); + if let Some(tref) = &c.tref { let varianttype = names.type_ref(tref, lifetime.clone()); quote! { - #tagname::#variantname => { + #i => { let variant_ptr = location.cast::().add(#contents_offset)?; let variant_val = <#varianttype as #rt::GuestType>::read(&variant_ptr.cast())?; Ok(#ident::#variantname(variant_val)) } } } else { - quote! { #tagname::#variantname => Ok(#ident::#variantname), } + quote! { #i => Ok(#ident::#variantname), } } }); - let write_variant = u.variants.iter().map(|v| { - let variantname = names.enum_variant(&v.name); + let write_variant = v.cases.iter().enumerate().map(|(i, c)| { + let variantname = names.enum_variant(&c.name); let write_tag = quote! { - location.cast().write(#tagname::#variantname)?; + location.cast().write(#i as #tag_ty)?; }; - if let Some(tref) = &v.tref { + if let Some(tref) = &c.tref { let varianttype = names.type_ref(tref, lifetime.clone()); quote! { #ident::#variantname(contents) => { @@ -66,7 +65,36 @@ pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDataty } }); - let (enum_lifetime, extra_derive) = if u.needs_lifetime() { + let enum_try_from = if v.cases.iter().all(|c| c.tref.is_none()) { + let tryfrom_repr_cases = v.cases.iter().enumerate().map(|(i, c)| { + let variant_name = names.enum_variant(&c.name); + let n = Literal::usize_unsuffixed(i); + quote!(#n => Ok(#ident::#variant_name)) + }); + let abi_ty = names.wasm_type(v.tag_repr.into()); + quote! { + impl TryFrom<#tag_ty> for #ident { + type Error = #rt::GuestError; + fn try_from(value: #tag_ty) -> Result<#ident, #rt::GuestError> { + match value { + #(#tryfrom_repr_cases),*, + _ => Err( #rt::GuestError::InvalidEnumValue(stringify!(#ident))), + } + } + } + + impl TryFrom<#abi_ty> for #ident { + type Error = #rt::GuestError; + fn try_from(value: #abi_ty) -> Result<#ident, #rt::GuestError> { + #ident::try_from(#tag_ty::try_from(value)?) + } + } + } + } else { + quote!() + }; + + let (enum_lifetime, extra_derive) = if v.needs_lifetime() { (quote!(<'a>), quote!()) } else { (quote!(), quote!(, PartialEq)) @@ -78,6 +106,8 @@ pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDataty #(#variants),* } + #enum_try_from + impl<'a> #rt::GuestType<'a> for #ident #enum_lifetime { fn guest_size() -> u32 { #size @@ -90,9 +120,10 @@ pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDataty fn read(location: &#rt::GuestPtr<'a, Self>) -> Result { - let tag = location.cast().read()?; + let tag = location.cast::<#tag_ty>().read()?; match tag { #(#read_variant)* + _ => Err(#rt::GuestError::InvalidEnumValue(stringify!(#ident))), } } @@ -109,7 +140,7 @@ pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDataty } } -impl super::WiggleType for witx::UnionDatatype { +impl super::WiggleType for witx::Variant { fn impls_display(&self) -> bool { false } diff --git a/crates/wiggle/macro/Cargo.toml b/crates/wiggle/macro/Cargo.toml index c18c3dee74..baee8ce164 100644 --- a/crates/wiggle/macro/Cargo.toml +++ b/crates/wiggle/macro/Cargo.toml @@ -22,7 +22,7 @@ doctest = false [dependencies] wiggle-generate = { path = "../generate", version = "0.23.0" } -witx = { path = "../../wasi-common/WASI/tools/witx", version = "0.8.7" } +witx = { version = "0.9", path = "../../wasi-common/WASI/tools/witx" } quote = "1.0" syn = { version = "1.0", features = ["full"] } diff --git a/crates/wiggle/src/lib.rs b/crates/wiggle/src/lib.rs index 4ecbf27da5..24a4ca7c40 100644 --- a/crates/wiggle/src/lib.rs +++ b/crates/wiggle/src/lib.rs @@ -942,3 +942,9 @@ pub enum Trap { /// Any other Trap is just an unstructured String, for reporting and debugging. String(String), } + +impl From for Trap { + fn from(err: GuestError) -> Trap { + Trap::String(err.to_string()) + } +} diff --git a/crates/wiggle/wasmtime/Cargo.toml b/crates/wiggle/wasmtime/Cargo.toml index 97ca34ebc7..31f45eeaf3 100644 --- a/crates/wiggle/wasmtime/Cargo.toml +++ b/crates/wiggle/wasmtime/Cargo.toml @@ -13,7 +13,7 @@ include = ["src/**/*", "LICENSE"] [dependencies] wasmtime = { path = "../../wasmtime", version = "0.23.0", default-features = false } wasmtime-wiggle-macro = { path = "./macro", version = "0.23.0" } -witx = { path = "../../wasi-common/WASI/tools/witx", version = "0.8.7", optional = true } +witx = { version = "0.9", path = "../../wasi-common/WASI/tools/witx", optional = true } wiggle = { path = "..", version = "0.23.0" } wiggle-borrow = { path = "../borrow", version = "0.23.0" } diff --git a/crates/wiggle/wasmtime/macro/Cargo.toml b/crates/wiggle/wasmtime/macro/Cargo.toml index 2ac0b1465d..fd4b365f0d 100644 --- a/crates/wiggle/wasmtime/macro/Cargo.toml +++ b/crates/wiggle/wasmtime/macro/Cargo.toml @@ -15,7 +15,7 @@ proc-macro = true test = false [dependencies] -witx = { path = "../../../wasi-common/WASI/tools/witx", version = "0.8.7" } +witx = { version = "0.9", path = "../../../wasi-common/WASI/tools/witx" } wiggle-generate = { path = "../../generate", version = "0.23.0" } quote = "1.0" syn = { version = "1.0", features = ["full", "extra-traits"] } diff --git a/crates/wiggle/wasmtime/macro/src/config.rs b/crates/wiggle/wasmtime/macro/src/config.rs index 6359ed0ec4..5e95bad957 100644 --- a/crates/wiggle/wasmtime/macro/src/config.rs +++ b/crates/wiggle/wasmtime/macro/src/config.rs @@ -144,7 +144,6 @@ impl Parse for TargetConf { enum ModuleConfField { Name(Ident), Docs(String), - FunctionOverride(FunctionOverrideConf), } impl Parse for ModuleConfField { @@ -159,10 +158,6 @@ impl Parse for ModuleConfField { input.parse::()?; let docs: syn::LitStr = input.parse()?; Ok(ModuleConfField::Docs(docs.value())) - } else if lookahead.peek(kw::function_override) { - input.parse::()?; - input.parse::()?; - Ok(ModuleConfField::FunctionOverride(input.parse()?)) } else { Err(lookahead.error()) } @@ -173,14 +168,12 @@ impl Parse for ModuleConfField { pub struct ModuleConf { pub name: Ident, pub docs: Option, - pub function_override: FunctionOverrideConf, } impl ModuleConf { fn build(fields: impl Iterator, err_loc: Span) -> Result { let mut name = None; let mut docs = None; - let mut function_override = None; for f in fields { match f { ModuleConfField::Name(c) => { @@ -195,18 +188,11 @@ impl ModuleConf { } docs = Some(c); } - ModuleConfField::FunctionOverride(c) => { - if function_override.is_some() { - return Err(Error::new(err_loc, "duplicate `function_override` field")); - } - function_override = Some(c); - } } } Ok(ModuleConf { name: name.ok_or_else(|| Error::new(err_loc, "`name` field required"))?, docs, - function_override: function_override.unwrap_or_default(), }) } } @@ -248,42 +234,3 @@ impl Parse for ModulesConf { }) } } - -#[derive(Debug, Clone, Default)] -pub struct FunctionOverrideConf { - pub funcs: Vec, -} -impl FunctionOverrideConf { - pub fn find(&self, name: &str) -> Option<&Ident> { - self.funcs - .iter() - .find(|f| f.name == name) - .map(|f| &f.replacement) - } -} - -impl Parse for FunctionOverrideConf { - fn parse(input: ParseStream) -> Result { - let contents; - let _lbrace = braced!(contents in input); - let fields: Punctuated = - contents.parse_terminated(FunctionOverrideField::parse)?; - Ok(FunctionOverrideConf { - funcs: fields.into_iter().collect(), - }) - } -} - -#[derive(Debug, Clone)] -pub struct FunctionOverrideField { - pub name: String, - pub replacement: Ident, -} -impl Parse for FunctionOverrideField { - fn parse(input: ParseStream) -> Result { - let name = input.parse::()?.to_string(); - input.parse::]>()?; - let replacement = input.parse::()?; - Ok(FunctionOverrideField { name, replacement }) - } -} diff --git a/crates/wiggle/wasmtime/macro/src/lib.rs b/crates/wiggle/wasmtime/macro/src/lib.rs index 7f13308f16..a2d05f729d 100644 --- a/crates/wiggle/wasmtime/macro/src/lib.rs +++ b/crates/wiggle/wasmtime/macro/src/lib.rs @@ -1,5 +1,5 @@ use proc_macro::TokenStream; -use proc_macro2::TokenStream as TokenStream2; +use proc_macro2::{Ident, Span, TokenStream as TokenStream2}; use quote::quote; use syn::parse_macro_input; use wiggle_generate::Names; @@ -88,14 +88,9 @@ fn generate_module( let module_id = names.module(&module.name); let target_module = quote! { #target_path::#module_id }; - let ctor_externs = module.funcs().map(|f| { - if let Some(func_override) = module_conf.function_override.find(&f.name.as_str()) { - let name_ident = names.func(&f.name); - quote! { let #name_ident = wasmtime::Func::wrap(store, #func_override); } - } else { - generate_func(&f, names, &target_module) - } - }); + let ctor_externs = module + .funcs() + .map(|f| generate_func(&f, names, &target_module)); let type_name = module_conf.name.clone(); let type_docs = module_conf @@ -158,23 +153,21 @@ fn generate_func( ) -> TokenStream2 { let name_ident = names.func(&func.name); - let coretype = func.core_type(); + let (params, results) = func.wasm_signature(); - let arg_decls = coretype.args.iter().map(|arg| { - let name = names.func_core_arg(arg); - let atom = names.atom_type(arg.repr()); - quote! { #name: #atom } + let arg_names = (0..params.len()) + .map(|i| Ident::new(&format!("arg{}", i), Span::call_site())) + .collect::>(); + let arg_decls = params.iter().enumerate().map(|(i, ty)| { + let name = &arg_names[i]; + let wasm = names.wasm_type(*ty); + quote! { #name: #wasm } }); - let arg_names = coretype.args.iter().map(|arg| names.func_core_arg(arg)); - let ret_ty = if let Some(ret) = &coretype.ret { - let ret_ty = match ret.signifies { - witx::CoreParamSignifies::Value(atom) => names.atom_type(atom), - _ => unreachable!("coretype ret should always be passed by value"), - }; - quote! { #ret_ty } - } else { - quote! {()} + let ret_ty = match results.len() { + 0 => quote!(()), + 1 => names.wasm_type(results[0]), + _ => unimplemented!(), }; let runtime = names.runtime_mod();