diff --git a/.gitmodules b/.gitmodules index ee264b99c4..305858152d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -6,7 +6,8 @@ url = https://github.com/WebAssembly/wasm-c-api [submodule "WASI"] path = crates/wasi-common/WASI - url = https://github.com/WebAssembly/WASI + url = https://github.com/alexcrichton/WASI + branch = abis [submodule "crates/wasi-nn/spec"] path = crates/wasi-nn/spec url = https://github.com/WebAssembly/wasi-nn diff --git a/Cargo.lock b/Cargo.lock index 5948052629..544b8bd861 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3499,18 +3499,18 @@ dependencies = [ [[package]] name = "wast" -version = "22.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe1220ed7f824992b426a76125a3403d048eaf0f627918e97ade0d9b9d510d20" +checksum = "c24a3ee360d01d60ed0a0f960ab76a6acce64348cdb0bf8699c2a866fad57c7c" dependencies = [ "leb128", ] [[package]] name = "wast" -version = "32.0.0" +version = "33.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c24a3ee360d01d60ed0a0f960ab76a6acce64348cdb0bf8699c2a866fad57c7c" +checksum = "1d04fe175c7f78214971293e7d8875673804e736092206a3a4544dbc12811c1b" dependencies = [ "leb128", ] @@ -3634,15 +3634,16 @@ dependencies = [ [[package]] name = "witx" -version = "0.8.8" +version = "0.9.0" dependencies = [ "anyhow", "diff", "log", "pretty_env_logger", + "rayon", "structopt", "thiserror", - "wast 22.0.0", + "wast 33.0.0", ] [[package]] diff --git a/crates/wasi-common/WASI b/crates/wasi-common/WASI index 8deb71ddd0..7c4fd252d0 160000 --- a/crates/wasi-common/WASI +++ b/crates/wasi-common/WASI @@ -1 +1 @@ -Subproject commit 8deb71ddd0955101cb69333b08284e7b01775928 +Subproject commit 7c4fd252d0841488de4a6e724e600f1561797387 diff --git a/crates/wiggle/Cargo.toml b/crates/wiggle/Cargo.toml index d7bba56082..88c8315bdc 100644 --- a/crates/wiggle/Cargo.toml +++ b/crates/wiggle/Cargo.toml @@ -12,7 +12,7 @@ include = ["src/**/*", "LICENSE"] [dependencies] thiserror = "1" -witx = { path = "../wasi-common/WASI/tools/witx", version = "0.8.7", optional = true } +witx = { path = "../wasi-common/WASI/tools/witx", version = "0.9", optional = true } wiggle-macro = { path = "macro", version = "0.23.0" } tracing = "0.1.15" bitflags = "1.2" diff --git a/crates/wiggle/generate/Cargo.toml b/crates/wiggle/generate/Cargo.toml index 54cb70d20d..1a7459a519 100644 --- a/crates/wiggle/generate/Cargo.toml +++ b/crates/wiggle/generate/Cargo.toml @@ -14,7 +14,7 @@ include = ["src/**/*", "LICENSE"] [lib] [dependencies] -witx = { version = "0.8.7", path = "../../wasi-common/WASI/tools/witx" } +witx = { version = "0.9", path = "../../wasi-common/WASI/tools/witx" } quote = "1.0" proc-macro2 = "1.0" heck = "0.3" diff --git a/crates/wiggle/generate/src/error_transform.rs b/crates/wiggle/generate/src/error_transform.rs index 3598e221de..56fa10eaf0 100644 --- a/crates/wiggle/generate/src/error_transform.rs +++ b/crates/wiggle/generate/src/error_transform.rs @@ -49,10 +49,14 @@ impl ErrorTransform { pub fn for_abi_error(&self, tref: &TypeRef) -> Option<&UserErrorType> { match tref { - TypeRef::Name(nt) => self.m.iter().find(|u| u.abi_type.name == nt.name), + TypeRef::Name(nt) => self.for_name(nt), TypeRef::Value { .. } => None, } } + + pub fn for_name(&self, nt: &NamedType) -> Option<&UserErrorType> { + self.m.iter().find(|u| u.abi_type.name == nt.name) + } } pub struct UserErrorType { diff --git a/crates/wiggle/generate/src/funcs.rs b/crates/wiggle/generate/src/funcs.rs index 50f5eec640..8e164f878b 100644 --- a/crates/wiggle/generate/src/funcs.rs +++ b/crates/wiggle/generate/src/funcs.rs @@ -1,11 +1,12 @@ -use proc_macro2::TokenStream; -use quote::quote; - use crate::error_transform::ErrorTransform; use crate::lifetimes::anon_lifetime; use crate::module_trait::passed_by_reference; use crate::names::Names; use crate::types::WiggleType; +use proc_macro2::{Ident, Span, TokenStream}; +use quote::quote; +use std::mem; +use witx::Instruction; pub fn define_func( names: &Names, @@ -13,163 +14,54 @@ pub fn define_func( func: &witx::InterfaceFunc, errxform: &ErrorTransform, ) -> TokenStream { - let funcname = func.name.as_str(); - - let ident = names.func(&func.name); let rt = names.runtime_mod(); + let ident = names.func(&func.name); let ctx_type = names.ctx_type(); - let coretype = func.core_type(); - let params = coretype.args.iter().map(|arg| { - let name = names.func_core_arg(arg); - let atom = names.atom_type(arg.repr()); - quote!(#name : #atom) + let (wasm_params, wasm_results) = func.wasm_signature(); + let param_names = (0..wasm_params.len()) + .map(|i| Ident::new(&format!("arg{}", i), Span::call_site())) + .collect::>(); + let abi_params = wasm_params.iter().zip(¶m_names).map(|(arg, name)| { + let wasm = names.wasm_type(*arg); + quote!(#name : #wasm) }); - let abi_args = quote!( - ctx: &#ctx_type, - memory: &dyn #rt::GuestMemory, - #(#params),* + let abi_ret = match wasm_results.len() { + 0 => quote!(()), + 1 => { + let ty = names.wasm_type(wasm_results[0]); + quote!(#ty) + } + _ => unimplemented!(), + }; + + let mut body = TokenStream::new(); + func.call_interface( + &module.name, + &mut Rust { + src: &mut body, + params: ¶m_names, + block_storage: Vec::new(), + blocks: Vec::new(), + rt: &rt, + names, + module, + funcname: func.name.as_str(), + errxform, + }, ); - let abi_ret = if let Some(ret) = &coretype.ret { - match ret.signifies { - witx::CoreParamSignifies::Value(atom) => names.atom_type(atom), - _ => unreachable!("ret should always be passed by value"), - } - } else { - quote!(()) - }; - let err_type = coretype.ret.clone().map(|ret| ret.param.tref); - let ret_err = coretype - .ret - .map(|ret| { - let name = names.func_param(&ret.param.name); - let conversion = if let Some(user_err) = errxform.for_abi_error(&ret.param.tref) { - let method = names.user_error_conversion_method(&user_err); - quote!(UserErrorConversion::#method(ctx, e)) - } else { - quote!(Ok(e)) - }; - quote! { - let e = #conversion; - #rt::tracing::event!( - #rt::tracing::Level::TRACE, - #name = #rt::tracing::field::debug(&e), - ); - match e { - Ok(e) => { return Ok(#abi_ret::from(e)); }, - Err(e) => { return Err(e); }, - } - } - }) - .unwrap_or_else(|| quote!(())); - - let error_handling = |location: &str| -> TokenStream { - if let Some(tref) = &err_type { - let abi_ret = match tref.type_().passed_by() { - witx::TypePassedBy::Value(atom) => names.atom_type(atom), - _ => unreachable!("err should always be passed by value"), - }; - let err_typename = names.type_ref(&tref, anon_lifetime()); - let err_method = names.guest_error_conversion_method(&tref); - quote! { - let e = #rt::GuestError::InFunc { funcname: #funcname, location: #location, err: Box::new(e.into()) }; - let err: #err_typename = GuestErrorConversion::#err_method(ctx, e); - return Ok(#abi_ret::from(err)); - } - } else { - quote! { - panic!("error: {:?}", e) - } - } - }; - - let marshal_args = func - .params - .iter() - .map(|p| marshal_arg(names, p, error_handling(p.name.as_str()))); - let trait_args = func.params.iter().map(|param| { - let name = names.func_param(¶m.name); - if passed_by_reference(&*param.tref.type_()) { - quote!(&#name) - } else { - quote!(#name) - } - }); - - let log_marshalled_args = if func.params.len() > 0 { - let rt = names.runtime_mod(); - let args = func.params.iter().map(|param| { - let name = names.func_param(¶m.name); - if param.impls_display() { - quote!( #name = #rt::tracing::field::display(&#name) ) - } else { - quote!( #name = #rt::tracing::field::debug(&#name) ) - } - }); - quote! { - #rt::tracing::event!(#rt::tracing::Level::TRACE, #(#args),*); - } - } else { - quote!() - }; - - let (trait_rets, trait_bindings) = if func.results.len() < 2 { - (quote!({}), quote!(_)) - } else { - let trait_rets: Vec<_> = func - .results - .iter() - .skip(1) - .map(|result| names.func_param(&result.name)) - .collect(); - let bindings = quote!((#(#trait_rets),*)); - let trace_rets = func.results.iter().skip(1).map(|result| { - let name = names.func_param(&result.name); - if result.tref.impls_display() { - quote!(#name = #rt::tracing::field::display(&#name)) - } else { - quote!(#name = #rt::tracing::field::debug(&#name)) - } - }); - let rets = quote! { - #rt::tracing::event!(#rt::tracing::Level::TRACE, #(#trace_rets),*); - (#(#trait_rets),*) - }; - (rets, bindings) - }; - - // Return value pointers need to be validated before the api call, then - // assigned to afterwards. marshal_result returns these two statements as a pair. - let marshal_rets = func - .results - .iter() - .skip(1) - .map(|result| marshal_result(names, result, &error_handling)); - let marshal_rets_pre = marshal_rets.clone().map(|(pre, _post)| pre); - let marshal_rets_post = marshal_rets.map(|(_pre, post)| post); - - let success = if let Some(ref err_type) = err_type { - let err_typename = names.type_ref(&err_type, anon_lifetime()); - quote! { - let success:#err_typename = #rt::GuestErrorType::success(); - #rt::tracing::event!( - #rt::tracing::Level::TRACE, - success=#rt::tracing::field::display(&success) - ); - Ok(#abi_ret::from(success)) - } - } else { - quote!(Ok(())) - }; - - let trait_name = names.trait_name(&module.name); let mod_name = &module.name.as_str(); let func_name = &func.name.as_str(); + quote! { + pub fn #ident( + ctx: &#ctx_type, + memory: &dyn #rt::GuestMemory, + #(#abi_params),* + ) -> Result<#abi_ret, #rt::Trap> { + use std::convert::TryFrom as _; - if func.noreturn { - quote!(pub fn #ident(#abi_args) -> Result<#abi_ret, #rt::Trap> { let _span = #rt::tracing::span!( #rt::tracing::Level::TRACE, "wiggle abi", @@ -178,176 +70,259 @@ pub fn define_func( ); let _enter = _span.enter(); - #(#marshal_args)* - #log_marshalled_args - let trap = #trait_name::#ident(ctx, #(#trait_args),*); - Err(trap) - }) - } else { - quote!(pub fn #ident(#abi_args) -> Result<#abi_ret, #rt::Trap> { - let _span = #rt::tracing::span!( - #rt::tracing::Level::TRACE, - "wiggle abi", - module = #mod_name, - function = #func_name - ); - let _enter = _span.enter(); - - #(#marshal_args)* - #(#marshal_rets_pre)* - #log_marshalled_args - let #trait_bindings = match #trait_name::#ident(ctx, #(#trait_args),*) { - Ok(#trait_bindings) => { #trait_rets }, - Err(e) => { #ret_err }, - }; - #(#marshal_rets_post)* - #success - }) + #body + } } } -fn marshal_arg( - names: &Names, - param: &witx::InterfaceFuncParam, - error_handling: TokenStream, -) -> TokenStream { - let rt = names.runtime_mod(); - let tref = ¶m.tref; - let interface_typename = names.type_ref(&tref, anon_lifetime()); +struct Rust<'a> { + src: &'a mut TokenStream, + params: &'a [Ident], + block_storage: Vec, + blocks: Vec, + rt: &'a TokenStream, + names: &'a Names, + module: &'a witx::Module, + funcname: &'a str, + errxform: &'a ErrorTransform, +} - let try_into_conversion = { - let name = names.func_param(¶m.name); - quote! { - let #name: #interface_typename = { - use ::std::convert::TryInto; - match #name.try_into() { - Ok(a) => a, - Err(e) => { - #error_handling - } +impl witx::Bindgen for Rust<'_> { + type Operand = TokenStream; + + fn push_block(&mut self) { + let prev = mem::replace(self.src, TokenStream::new()); + self.block_storage.push(prev); + } + + fn finish_block(&mut self, operand: Option) { + let to_restore = self.block_storage.pop().unwrap(); + let src = mem::replace(self.src, to_restore); + match operand { + None => self.blocks.push(src), + Some(s) => { + if src.is_empty() { + self.blocks.push(s); + } else { + self.blocks.push(quote!({ #src; #s })); } - }; - } - }; - - let read_conversion = { - let pointee_type = names.type_ref(tref, anon_lifetime()); - let arg_name = names.func_ptr_binding(¶m.name); - let name = names.func_param(¶m.name); - quote! { - let #name = match #rt::GuestPtr::<#pointee_type>::new(memory, #arg_name as u32).read() { - Ok(r) => r, - Err(e) => { - #error_handling - } - }; - } - }; - - match &*tref.type_() { - witx::Type::Enum(_e) => try_into_conversion, - witx::Type::Flags(_f) => try_into_conversion, - witx::Type::Int(_i) => try_into_conversion, - witx::Type::Builtin(b) => match b { - witx::BuiltinType::U8 | witx::BuiltinType::U16 | witx::BuiltinType::Char8 => { - try_into_conversion } - witx::BuiltinType::S8 | witx::BuiltinType::S16 => { - let name = names.func_param(¶m.name); - quote! { - let #name: #interface_typename = match (#name as i32).try_into() { - Ok(a) => a, - Err(e) => { - #error_handling - } + } + } + + // This is only used for `call_wasm` at this time. + fn allocate_space(&mut self, _: usize, _: &witx::NamedType) { + unimplemented!() + } + + fn emit( + &mut self, + inst: &Instruction<'_>, + operands: &mut Vec, + results: &mut Vec, + ) { + let rt = self.rt; + let wrap_err = |location: &str| { + let funcname = self.funcname; + quote! { + |e| { + #rt::GuestError::InFunc { + funcname: #funcname, + location: #location, + err: Box::new(#rt::GuestError::from(e)), } } } - witx::BuiltinType::U32 - | witx::BuiltinType::S32 - | witx::BuiltinType::U64 - | witx::BuiltinType::S64 - | witx::BuiltinType::USize - | witx::BuiltinType::F32 - | witx::BuiltinType::F64 => { - let name = names.func_param(¶m.name); - quote! { - let #name = #name as #interface_typename; + }; + + let mut try_from = |ty: TokenStream| { + let val = operands.pop().unwrap(); + let wrap_err = wrap_err(&format!("convert {}", ty)); + results.push(quote!(#ty::try_from(#val).map_err(#wrap_err)?)); + }; + + match inst { + Instruction::GetArg { nth } => { + let param = &self.params[*nth]; + results.push(quote!(#param)); + } + + Instruction::PointerFromI32 { ty } | Instruction::ConstPointerFromI32 { ty } => { + let val = operands.pop().unwrap(); + let pointee_type = self.names.type_ref(ty, anon_lifetime()); + results.push(quote! { + #rt::GuestPtr::<#pointee_type>::new(memory, #val as u32) + }); + } + + Instruction::ListFromPointerLength { ty } => { + let ptr = &operands[0]; + let len = &operands[1]; + let ty = match &**ty.type_() { + witx::Type::Builtin(witx::BuiltinType::Char) => quote!(str), + _ => { + let ty = self.names.type_ref(ty, anon_lifetime()); + quote!([#ty]) + } + }; + results.push(quote! { + #rt::GuestPtr::<#ty>::new(memory, (#ptr as u32, #len as u32)); + }) + } + + Instruction::CallInterface { func, .. } => { + // Use the `tracing` crate to log all arguments that are going + // out, and afterwards we call the function with those bindings. + let mut args = Vec::new(); + for (i, param) in func.params.iter().enumerate() { + let name = self.names.func_param(¶m.name); + let val = &operands[i]; + self.src.extend(quote!(let #name = #val;)); + if passed_by_reference(param.tref.type_()) { + args.push(quote!(&#name)); + } else { + args.push(quote!(#name)); + } + } + if func.params.len() > 0 { + let args = func + .params + .iter() + .map(|param| { + let name = self.names.func_param(¶m.name); + if param.impls_display() { + quote!( #name = #rt::tracing::field::display(&#name) ) + } else { + quote!( #name = #rt::tracing::field::debug(&#name) ) + } + }) + .collect::>(); + self.src.extend(quote! { + #rt::tracing::event!(#rt::tracing::Level::TRACE, #(#args),*); + }); + } + + let trait_name = self.names.trait_name(&self.module.name); + let ident = self.names.func(&func.name); + self.src.extend(quote! { + let ret = #trait_name::#ident(ctx, #(#args),*); + #rt::tracing::event!( + #rt::tracing::Level::TRACE, + result = #rt::tracing::field::debug(&ret), + ); + }); + + if func.results.len() > 0 { + results.push(quote!(ret)); + } else if func.noreturn { + self.src.extend(quote!(return Err(ret))); } } - witx::BuiltinType::String => { - let lifetime = anon_lifetime(); - let ptr_name = names.func_ptr_binding(¶m.name); - let len_name = names.func_len_binding(¶m.name); - let name = names.func_param(¶m.name); - quote! { - let #name = #rt::GuestPtr::<#lifetime, str>::new(memory, (#ptr_name as u32, #len_name as u32)); - } + + // Lowering an enum is typically simple but if we have an error + // transformation registered for this enum then what we're actually + // doing is lowering from a user-defined error type to the error + // enum, and *then* we lower to an i32. + Instruction::EnumLower { ty } => { + let val = operands.pop().unwrap(); + let val = match self.errxform.for_name(ty) { + Some(custom) => { + let method = self.names.user_error_conversion_method(&custom); + quote!(UserErrorConversion::#method(ctx, #val)?) + } + None => val, + }; + results.push(quote!(#val as i32)); } - }, - witx::Type::Pointer(pointee) | witx::Type::ConstPointer(pointee) => { - let pointee_type = names.type_ref(pointee, anon_lifetime()); - let name = names.func_param(¶m.name); - quote! { - let #name = #rt::GuestPtr::<#pointee_type>::new(memory, #name as u32); + + Instruction::ResultLower { err: err_ty, .. } => { + let err = self.blocks.pop().unwrap(); + let ok = self.blocks.pop().unwrap(); + let val = operands.pop().unwrap(); + let err_typename = self.names.type_ref(err_ty.unwrap(), anon_lifetime()); + results.push(quote! { + match #val { + Ok(e) => { #ok; <#err_typename as #rt::GuestErrorType>::success() as i32 } + Err(e) => { #err } + } + }); } - } - witx::Type::Struct(_) => read_conversion, - witx::Type::Array(arr) => { - let pointee_type = names.type_ref(arr, anon_lifetime()); - let ptr_name = names.func_ptr_binding(¶m.name); - let len_name = names.func_len_binding(¶m.name); - let name = names.func_param(¶m.name); - quote! { - let #name = #rt::GuestPtr::<[#pointee_type]>::new(memory, (#ptr_name as u32, #len_name as u32)); + + Instruction::VariantPayload => results.push(quote!(e)), + + Instruction::Return { amt: 0 } => {} + Instruction::Return { amt: 1 } => { + let val = operands.pop().unwrap(); + self.src.extend(quote!(return Ok(#val))); } - } - witx::Type::Union(_u) => read_conversion, - witx::Type::Handle(_h) => { - let name = names.func_param(¶m.name); - let handle_type = names.type_ref(tref, anon_lifetime()); - quote!( let #name = #handle_type::from(#name); ) + Instruction::Return { .. } => unimplemented!(), + + Instruction::TupleLower { amt } => { + let names = (0..*amt) + .map(|i| Ident::new(&format!("t{}", i), Span::call_site())) + .collect::>(); + let val = operands.pop().unwrap(); + self.src.extend(quote!( let (#(#names,)*) = #val;)); + results.extend(names.iter().map(|i| quote!(#i))); + } + + Instruction::Store { ty } => { + let ptr = operands.pop().unwrap(); + let val = operands.pop().unwrap(); + let wrap_err = wrap_err(&format!("write {}", ty.name.as_str())); + let pointee_type = self.names.type_(&ty.name); + self.src.extend(quote! { + #rt::GuestPtr::<#pointee_type>::new(memory, #ptr as u32) + .write(#val) + .map_err(#wrap_err)?; + }); + } + + Instruction::HandleFromI32 { ty } => { + let val = operands.pop().unwrap(); + let ty = self.names.type_(&ty.name); + results.push(quote!(#ty::from(#val))); + } + + // Smaller-than-32 numerical conversions are done with `TryFrom` to + // ensure we're not losing bits. + Instruction::U8FromI32 => try_from(quote!(u8)), + Instruction::S8FromI32 => try_from(quote!(i8)), + Instruction::Char8FromI32 => try_from(quote!(u8)), + Instruction::U16FromI32 => try_from(quote!(u16)), + Instruction::S16FromI32 => try_from(quote!(i16)), + + // Conversions with matching bit-widths but different signededness + // use `as` since we're basically just reinterpreting the bits. + Instruction::U32FromI32 => { + let val = operands.pop().unwrap(); + results.push(quote!(#val as u32)); + } + Instruction::U64FromI64 => { + let val = operands.pop().unwrap(); + results.push(quote!(#val as u64)); + } + + // Conversions to enums/bitflags use `TryFrom` to ensure that the + // values are valid coming in. + Instruction::EnumLift { ty } + | Instruction::BitflagsFromI64 { ty } + | Instruction::BitflagsFromI32 { ty } => { + let ty = self.names.type_(&ty.name); + try_from(quote!(#ty)) + } + + // No conversions necessary for these, the native wasm type matches + // our own representation. + Instruction::If32FromF32 + | Instruction::If64FromF64 + | Instruction::S32FromI32 + | Instruction::S64FromI64 => results.push(operands.pop().unwrap()), + + // There's a number of other instructions we could implement but + // they're not exercised by WASI at this time. As necessary we can + // add code to implement them. + other => panic!("no implementation for {:?}", other), } } } - -fn marshal_result( - names: &Names, - result: &witx::InterfaceFuncParam, - error_handling: F, -) -> (TokenStream, TokenStream) -where - F: Fn(&str) -> TokenStream, -{ - let rt = names.runtime_mod(); - let tref = &result.tref; - - let write_val_to_ptr = { - let pointee_type = names.type_ref(tref, anon_lifetime()); - // core type is given func_ptr_binding name. - let ptr_name = names.func_ptr_binding(&result.name); - let ptr_err_handling = error_handling(&format!("{}:result_ptr_mut", result.name.as_str())); - let pre = quote! { - let #ptr_name = #rt::GuestPtr::<#pointee_type>::new(memory, #ptr_name as u32); - }; - // trait binding returns func_param name. - let val_name = names.func_param(&result.name); - let post = quote! { - if let Err(e) = #ptr_name.write(#val_name) { - #ptr_err_handling - } - }; - (pre, post) - }; - - match &*tref.type_() { - witx::Type::Builtin(b) => match b { - witx::BuiltinType::String => unimplemented!("string result types"), - _ => write_val_to_ptr, - }, - witx::Type::Pointer { .. } | witx::Type::ConstPointer { .. } | witx::Type::Array { .. } => { - unimplemented!("pointer/array result types") - } - _ => write_val_to_ptr, - } -} diff --git a/crates/wiggle/generate/src/lib.rs b/crates/wiggle/generate/src/lib.rs index b9a07f93ce..3eb359bc67 100644 --- a/crates/wiggle/generate/src/lib.rs +++ b/crates/wiggle/generate/src/lib.rs @@ -6,11 +6,10 @@ mod module_trait; mod names; mod types; +use lifetimes::anon_lifetime; use proc_macro2::TokenStream; use quote::quote; -use lifetimes::anon_lifetime; - pub use config::Config; pub use error_transform::{ErrorTransform, UserErrorType}; pub use funcs::define_func; @@ -67,6 +66,8 @@ pub fn generate(doc: &witx::Document, names: &Names, errs: &ErrorTransform) -> T quote!( pub mod types { + use std::convert::TryFrom; + #(#types)* #guest_error_conversion #user_error_conversion diff --git a/crates/wiggle/generate/src/lifetimes.rs b/crates/wiggle/generate/src/lifetimes.rs index 75b102209c..87b1aad9a4 100644 --- a/crates/wiggle/generate/src/lifetimes.rs +++ b/crates/wiggle/generate/src/lifetimes.rs @@ -19,46 +19,37 @@ impl LifetimeExt for witx::Type { fn is_transparent(&self) -> bool { match self { witx::Type::Builtin(b) => b.is_transparent(), - witx::Type::Struct(s) => s.is_transparent(), - witx::Type::Enum { .. } - | witx::Type::Flags { .. } - | witx::Type::Int { .. } - | witx::Type::Handle { .. } => true, - witx::Type::Union { .. } + witx::Type::Record(s) => s.is_transparent(), + witx::Type::Handle { .. } => true, + witx::Type::Variant { .. } | witx::Type::Pointer { .. } | witx::Type::ConstPointer { .. } - | witx::Type::Array { .. } => false, + | witx::Type::List { .. } => false, } } fn needs_lifetime(&self) -> bool { match self { witx::Type::Builtin(b) => b.needs_lifetime(), - witx::Type::Struct(s) => s.needs_lifetime(), - witx::Type::Union(u) => u.needs_lifetime(), - witx::Type::Enum { .. } - | witx::Type::Flags { .. } - | witx::Type::Int { .. } - | witx::Type::Handle { .. } => false, + witx::Type::Record(s) => s.needs_lifetime(), + witx::Type::Variant(u) => u.needs_lifetime(), + witx::Type::Handle { .. } => false, witx::Type::Pointer { .. } | witx::Type::ConstPointer { .. } - | witx::Type::Array { .. } => true, + | witx::Type::List { .. } => true, } } } impl LifetimeExt for witx::BuiltinType { fn is_transparent(&self) -> bool { - !self.needs_lifetime() + true } fn needs_lifetime(&self) -> bool { - match self { - witx::BuiltinType::String => true, - _ => false, - } + false } } -impl LifetimeExt for witx::StructDatatype { +impl LifetimeExt for witx::RecordDatatype { fn is_transparent(&self) -> bool { self.members.iter().all(|m| m.tref.is_transparent()) } @@ -67,12 +58,12 @@ impl LifetimeExt for witx::StructDatatype { } } -impl LifetimeExt for witx::UnionDatatype { +impl LifetimeExt for witx::Variant { fn is_transparent(&self) -> bool { false } fn needs_lifetime(&self) -> bool { - self.variants + self.cases .iter() .any(|m| m.tref.as_ref().map(|t| t.needs_lifetime()).unwrap_or(false)) } diff --git a/crates/wiggle/generate/src/module_trait.rs b/crates/wiggle/generate/src/module_trait.rs index a7a7289d09..92e503ddda 100644 --- a/crates/wiggle/generate/src/module_trait.rs +++ b/crates/wiggle/generate/src/module_trait.rs @@ -7,17 +7,9 @@ use crate::names::Names; use witx::Module; pub fn passed_by_reference(ty: &witx::Type) -> bool { - let passed_by = match ty.passed_by() { - witx::TypePassedBy::Value { .. } => false, - witx::TypePassedBy::Pointer { .. } | witx::TypePassedBy::PointerLengthPair { .. } => true, - }; match ty { - witx::Type::Builtin(b) => match &*b { - witx::BuiltinType::String => true, - _ => passed_by, - }, - witx::Type::Pointer(_) | witx::Type::ConstPointer(_) | witx::Type::Array(_) => true, - _ => passed_by, + witx::Type::Pointer(_) | witx::Type::ConstPointer(_) | witx::Type::List(_) => true, + _ => false, } } @@ -49,28 +41,36 @@ pub fn define_module_trait(names: &Names, m: &Module, errxform: &ErrorTransform) quote!(#arg_name: #arg_type) }); - let result = if !f.noreturn { - let rets = f - .results - .iter() - .skip(1) - .map(|ret| names.type_ref(&ret.tref, lifetime.clone())); - let err = f - .results - .get(0) - .map(|err_result| { - if let Some(custom_err) = errxform.for_abi_error(&err_result.tref) { - let tn = custom_err.typename(); - quote!(super::#tn) - } else { - names.type_ref(&err_result.tref, lifetime.clone()) - } - }) - .unwrap_or(quote!(())); - quote!( Result<(#(#rets),*), #err> ) - } else { - let rt = names.runtime_mod(); - quote!(#rt::Trap) + let rt = names.runtime_mod(); + let result = match f.results.len() { + 0 if f.noreturn => quote!(#rt::Trap), + 0 => quote!(()), + 1 => { + let (ok, err) = match &**f.results[0].tref.type_() { + witx::Type::Variant(v) => match v.as_expected() { + Some(p) => p, + None => unimplemented!("anonymous variant ref {:?}", v), + }, + _ => unimplemented!(), + }; + + let ok = match ok { + Some(ty) => names.type_ref(ty, lifetime.clone()), + None => quote!(()), + }; + let err = match err { + Some(ty) => match errxform.for_abi_error(ty) { + Some(custom) => { + let tn = custom.typename(); + quote!(super::#tn) + } + None => names.type_ref(ty, lifetime.clone()), + }, + None => quote!(()), + }; + quote!(Result<#ok, #err>) + } + _ => unimplemented!(), }; if is_anonymous { diff --git a/crates/wiggle/generate/src/names.rs b/crates/wiggle/generate/src/names.rs index f6266ec41f..b9eb9e4752 100644 --- a/crates/wiggle/generate/src/names.rs +++ b/crates/wiggle/generate/src/names.rs @@ -2,7 +2,7 @@ use escaping::{escape_id, handle_2big_enum_variant, NamingConvention}; use heck::{ShoutySnakeCase, SnakeCase}; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; -use witx::{AtomType, BuiltinType, Id, Type, TypeRef}; +use witx::{BuiltinType, Id, Type, TypeRef, WasmType}; use crate::{lifetimes::LifetimeExt, UserErrorType}; @@ -32,15 +32,11 @@ impl Names { quote!(#ident) } - pub fn builtin_type(&self, b: BuiltinType, lifetime: TokenStream) -> TokenStream { + pub fn builtin_type(&self, b: BuiltinType) -> TokenStream { match b { - BuiltinType::String => { - let rt = self.runtime_mod(); - quote!(#rt::GuestPtr<#lifetime, str>) - } - BuiltinType::U8 => quote!(u8), + BuiltinType::U8 { .. } => quote!(u8), BuiltinType::U16 => quote!(u16), - BuiltinType::U32 => quote!(u32), + BuiltinType::U32 { .. } => quote!(u32), BuiltinType::U64 => quote!(u64), BuiltinType::S8 => quote!(i8), BuiltinType::S16 => quote!(i16), @@ -48,16 +44,16 @@ impl Names { BuiltinType::S64 => quote!(i64), BuiltinType::F32 => quote!(f32), BuiltinType::F64 => quote!(f64), - BuiltinType::Char8 => quote!(u8), - BuiltinType::USize => quote!(u32), + BuiltinType::Char => quote!(char), } } - pub fn atom_type(&self, atom: AtomType) -> TokenStream { - match atom { - AtomType::I32 => quote!(i32), - AtomType::I64 => quote!(i64), - AtomType::F32 => quote!(f32), - AtomType::F64 => quote!(f64), + + pub fn wasm_type(&self, ty: WasmType) -> TokenStream { + match ty { + WasmType::I32 => quote!(i32), + WasmType::I64 => quote!(i64), + WasmType::F32 => quote!(f32), + WasmType::F64 => quote!(f64), } } @@ -72,16 +68,44 @@ impl Names { } } TypeRef::Value(ty) => match &**ty { - Type::Builtin(builtin) => self.builtin_type(*builtin, lifetime.clone()), + Type::Builtin(builtin) => self.builtin_type(*builtin), Type::Pointer(pointee) | Type::ConstPointer(pointee) => { let rt = self.runtime_mod(); let pointee_type = self.type_ref(&pointee, lifetime.clone()); quote!(#rt::GuestPtr<#lifetime, #pointee_type>) } - Type::Array(pointee) => { - let rt = self.runtime_mod(); - let pointee_type = self.type_ref(&pointee, lifetime.clone()); - quote!(#rt::GuestPtr<#lifetime, [#pointee_type]>) + Type::List(pointee) => match &**pointee.type_() { + Type::Builtin(BuiltinType::Char) => { + let rt = self.runtime_mod(); + quote!(#rt::GuestPtr<#lifetime, str>) + } + _ => { + let rt = self.runtime_mod(); + let pointee_type = self.type_ref(&pointee, lifetime.clone()); + quote!(#rt::GuestPtr<#lifetime, [#pointee_type]>) + } + }, + Type::Variant(v) => match v.as_expected() { + Some((ok, err)) => { + let ok = match ok { + Some(ty) => self.type_ref(ty, lifetime.clone()), + None => quote!(()), + }; + let err = match err { + Some(ty) => self.type_ref(ty, lifetime.clone()), + None => quote!(()), + }; + quote!(Result<#ok, #err>) + } + None => unimplemented!("anonymous variant ref {:?}", tref), + }, + Type::Record(r) if r.is_tuple() => { + let types = r + .members + .iter() + .map(|m| self.type_ref(&m.tref, lifetime.clone())) + .collect::>(); + quote!((#(#types,)*)) } _ => unimplemented!("anonymous type ref {:?}", tref), }, @@ -144,14 +168,6 @@ impl Names { escape_id(id, NamingConvention::SnakeCase) } - pub fn func_core_arg(&self, arg: &witx::CoreParamType) -> Ident { - match arg.signifies { - witx::CoreParamSignifies::Value { .. } => self.func_param(&arg.param.name), - witx::CoreParamSignifies::PointerTo => self.func_ptr_binding(&arg.param.name), - witx::CoreParamSignifies::LengthOf => self.func_len_binding(&arg.param.name), - } - } - /// For when you need a {name}_ptr binding for passing a value by reference: pub fn func_ptr_binding(&self, id: &Id) -> Ident { format_ident!("{}_ptr", id.as_str().to_snake_case()) @@ -164,10 +180,9 @@ impl Names { fn builtin_name(b: &BuiltinType) -> &'static str { match b { - BuiltinType::String => "string", - BuiltinType::U8 => "u8", + BuiltinType::U8 { .. } => "u8", BuiltinType::U16 => "u16", - BuiltinType::U32 => "u32", + BuiltinType::U32 { .. } => "u32", BuiltinType::U64 => "u64", BuiltinType::S8 => "i8", BuiltinType::S16 => "i16", @@ -175,8 +190,7 @@ impl Names { BuiltinType::S64 => "i64", BuiltinType::F32 => "f32", BuiltinType::F64 => "f64", - BuiltinType::Char8 => "char8", - BuiltinType::USize => "usize", + BuiltinType::Char => "char", } } diff --git a/crates/wiggle/generate/src/types/enum.rs b/crates/wiggle/generate/src/types/enum.rs deleted file mode 100644 index cef56065ef..0000000000 --- a/crates/wiggle/generate/src/types/enum.rs +++ /dev/null @@ -1,120 +0,0 @@ -use super::{atom_token, int_repr_tokens}; -use crate::names::Names; - -use proc_macro2::TokenStream; -use quote::quote; - -pub(super) fn define_enum(names: &Names, name: &witx::Id, e: &witx::EnumDatatype) -> TokenStream { - let ident = names.type_(&name); - let rt = names.runtime_mod(); - - let repr = int_repr_tokens(e.repr); - let abi_repr = atom_token(match e.repr { - witx::IntRepr::U8 | witx::IntRepr::U16 | witx::IntRepr::U32 => witx::AtomType::I32, - witx::IntRepr::U64 => witx::AtomType::I64, - }); - - let mut variant_names = vec![]; - let mut tryfrom_repr_cases = vec![]; - let mut to_repr_cases = vec![]; - let mut to_display = vec![]; - - for (n, variant) in e.variants.iter().enumerate() { - let variant_name = names.enum_variant(&variant.name); - let docs = variant.docs.trim(); - let ident_str = ident.to_string(); - let variant_str = variant_name.to_string(); - tryfrom_repr_cases.push(quote!(#n => Ok(#ident::#variant_name))); - to_repr_cases.push(quote!(#ident::#variant_name => #n as #repr)); - to_display.push(quote!(#ident::#variant_name => format!("{} ({}::{}({}))", #docs, #ident_str, #variant_str, #repr::from(*self)))); - variant_names.push(variant_name); - } - - quote! { - #[repr(#repr)] - #[derive(Copy, Clone, Debug, ::std::hash::Hash, Eq, PartialEq)] - pub enum #ident { - #(#variant_names),* - } - - impl ::std::fmt::Display for #ident { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - let to_str = match self { - #(#to_display,)* - }; - write!(f, "{}", to_str) - } - } - - impl ::std::convert::TryFrom<#repr> for #ident { - type Error = #rt::GuestError; - fn try_from(value: #repr) -> Result<#ident, #rt::GuestError> { - match value as usize { - #(#tryfrom_repr_cases),*, - _ => Err( #rt::GuestError::InvalidEnumValue(stringify!(#ident))), - } - } - } - - impl ::std::convert::TryFrom<#abi_repr> for #ident { - type Error = #rt::GuestError; - fn try_from(value: #abi_repr) -> Result<#ident, #rt::GuestError> { - #ident::try_from(value as #repr) - } - } - - impl From<#ident> for #repr { - fn from(e: #ident) -> #repr { - match e { - #(#to_repr_cases),* - } - } - } - - impl From<#ident> for #abi_repr { - fn from(e: #ident) -> #abi_repr { - #repr::from(e) as #abi_repr - } - } - - impl<'a> #rt::GuestType<'a> for #ident { - fn guest_size() -> u32 { - #repr::guest_size() - } - - fn guest_align() -> usize { - #repr::guest_align() - } - - fn read(location: & #rt::GuestPtr<#ident>) -> Result<#ident, #rt::GuestError> { - use std::convert::TryFrom; - let reprval = #repr::read(&location.cast())?; - let value = #ident::try_from(reprval)?; - Ok(value) - } - - fn write(location: & #rt::GuestPtr<'_, #ident>, val: Self) - -> Result<(), #rt::GuestError> - { - #repr::write(&location.cast(), #repr::from(val)) - } - } - - unsafe impl <'a> #rt::GuestTypeTransparent<'a> for #ident { - #[inline] - fn validate(location: *mut #ident) -> Result<(), #rt::GuestError> { - use std::convert::TryFrom; - // Validate value in memory using #ident::try_from(reprval) - let reprval = unsafe { (location as *mut #repr).read() }; - let _val = #ident::try_from(reprval)?; - Ok(()) - } - } - } -} - -impl super::WiggleType for witx::EnumDatatype { - fn impls_display(&self) -> bool { - true - } -} diff --git a/crates/wiggle/generate/src/types/flags.rs b/crates/wiggle/generate/src/types/flags.rs index cda13a2bfa..6082ed5f73 100644 --- a/crates/wiggle/generate/src/types/flags.rs +++ b/crates/wiggle/generate/src/types/flags.rs @@ -1,27 +1,24 @@ -use super::{atom_token, int_repr_tokens}; use crate::names::Names; use proc_macro2::{Literal, TokenStream}; use quote::quote; -use std::convert::TryFrom; -pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDatatype) -> TokenStream { +pub(super) fn define_flags( + names: &Names, + name: &witx::Id, + repr: witx::IntRepr, + record: &witx::RecordDatatype, +) -> TokenStream { let rt = names.runtime_mod(); let ident = names.type_(&name); - let repr = int_repr_tokens(f.repr); - let abi_repr = atom_token(match f.repr { - witx::IntRepr::U8 | witx::IntRepr::U16 | witx::IntRepr::U32 => witx::AtomType::I32, - witx::IntRepr::U64 => witx::AtomType::I64, - }); + let abi_repr = names.wasm_type(repr.into()); + let repr = super::int_repr_tokens(repr); let mut names_ = vec![]; let mut values_ = vec![]; - for (i, f) in f.flags.iter().enumerate() { - let name = names.flag_member(&f.name); - let value = 1u128 - .checked_shl(u32::try_from(i).expect("flag value overflow")) - .expect("flag value overflow"); - let value_token = Literal::u128_unsuffixed(value); + for (i, member) in record.members.iter().enumerate() { + let name = names.flag_member(&member.name); + let value_token = Literal::usize_unsuffixed(1 << i); names_.push(name); values_.push(value_token); } @@ -45,7 +42,7 @@ pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDataty } } - impl ::std::convert::TryFrom<#repr> for #ident { + impl TryFrom<#repr> for #ident { type Error = #rt::GuestError; fn try_from(value: #repr) -> Result { if #repr::from(!#ident::all()) & value != 0 { @@ -56,10 +53,10 @@ pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDataty } } - impl ::std::convert::TryFrom<#abi_repr> for #ident { + impl TryFrom<#abi_repr> for #ident { type Error = #rt::GuestError; - fn try_from(value: #abi_repr) -> Result<#ident, #rt::GuestError> { - #ident::try_from(value as #repr) + fn try_from(value: #abi_repr) -> Result { + #ident::try_from(#repr::try_from(value)?) } } @@ -69,12 +66,6 @@ pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDataty } } - impl From<#ident> for #abi_repr { - fn from(e: #ident) -> #abi_repr { - #repr::from(e) as #abi_repr - } - } - impl<'a> #rt::GuestType<'a> for #ident { fn guest_size() -> u32 { #repr::guest_size() @@ -106,12 +97,5 @@ pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDataty Ok(()) } } - - } -} - -impl super::WiggleType for witx::FlagsDatatype { - fn impls_display(&self) -> bool { - true } } diff --git a/crates/wiggle/generate/src/types/handle.rs b/crates/wiggle/generate/src/types/handle.rs index 310c9e7d6e..0f23419b0d 100644 --- a/crates/wiggle/generate/src/types/handle.rs +++ b/crates/wiggle/generate/src/types/handle.rs @@ -78,8 +78,6 @@ pub(super) fn define_handle( Ok(()) } } - - } } diff --git a/crates/wiggle/generate/src/types/int.rs b/crates/wiggle/generate/src/types/int.rs deleted file mode 100644 index d916b65b70..0000000000 --- a/crates/wiggle/generate/src/types/int.rs +++ /dev/null @@ -1,100 +0,0 @@ -use super::{atom_token, int_repr_tokens}; -use crate::names::Names; - -use proc_macro2::TokenStream; -use quote::quote; - -pub(super) fn define_int(names: &Names, name: &witx::Id, i: &witx::IntDatatype) -> TokenStream { - let rt = names.runtime_mod(); - let ident = names.type_(&name); - let repr = int_repr_tokens(i.repr); - let abi_repr = atom_token(match i.repr { - witx::IntRepr::U8 | witx::IntRepr::U16 | witx::IntRepr::U32 => witx::AtomType::I32, - witx::IntRepr::U64 => witx::AtomType::I64, - }); - let consts = i - .consts - .iter() - .map(|r#const| { - let const_ident = names.int_member(&r#const.name); - let value = r#const.value; - quote!(pub const #const_ident: #ident = #ident(#value)) - }) - .collect::>(); - - quote! { - #[repr(transparent)] - #[derive(Copy, Clone, Debug, ::std::hash::Hash, Eq, PartialEq)] - pub struct #ident(#repr); - - impl #ident { - #(#consts;)* - } - - impl ::std::fmt::Display for #ident { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write!(f, "{:?}", self) - } - } - - impl ::std::convert::TryFrom<#repr> for #ident { - type Error = #rt::GuestError; - fn try_from(value: #repr) -> Result { - Ok(#ident(value)) - } - } - - impl ::std::convert::TryFrom<#abi_repr> for #ident { - type Error = #rt::GuestError; - fn try_from(value: #abi_repr) -> Result<#ident, #rt::GuestError> { - #ident::try_from(value as #repr) - } - } - - impl From<#ident> for #repr { - fn from(e: #ident) -> #repr { - e.0 - } - } - - impl From<#ident> for #abi_repr { - fn from(e: #ident) -> #abi_repr { - #repr::from(e) as #abi_repr - } - } - - impl<'a> #rt::GuestType<'a> for #ident { - fn guest_size() -> u32 { - #repr::guest_size() - } - - fn guest_align() -> usize { - #repr::guest_align() - } - - fn read(location: &#rt::GuestPtr<'a, #ident>) -> Result<#ident, #rt::GuestError> { - Ok(#ident(#repr::read(&location.cast())?)) - - } - - fn write(location: &#rt::GuestPtr<'_, #ident>, val: Self) -> Result<(), #rt::GuestError> { - #repr::write(&location.cast(), val.0) - } - } - - unsafe impl<'a> #rt::GuestTypeTransparent<'a> for #ident { - #[inline] - fn validate(_location: *mut #ident) -> Result<(), #rt::GuestError> { - // All bit patterns accepted - Ok(()) - } - } - - } -} - -impl super::WiggleType for witx::IntDatatype { - fn impls_display(&self) -> bool { - true - } -} diff --git a/crates/wiggle/generate/src/types/mod.rs b/crates/wiggle/generate/src/types/mod.rs index 6ed29af94b..255c2ce7b2 100644 --- a/crates/wiggle/generate/src/types/mod.rs +++ b/crates/wiggle/generate/src/types/mod.rs @@ -1,9 +1,8 @@ -mod r#enum; +// mod r#enum; mod flags; mod handle; -mod int; -mod r#struct; -mod union; +mod record; +mod variant; use crate::lifetimes::LifetimeExt; use crate::names::Names; @@ -15,11 +14,11 @@ pub fn define_datatype(names: &Names, namedtype: &witx::NamedType) -> TokenStrea match &namedtype.tref { witx::TypeRef::Name(alias_to) => define_alias(names, &namedtype.name, &alias_to), witx::TypeRef::Value(v) => match &**v { - witx::Type::Enum(e) => r#enum::define_enum(names, &namedtype.name, &e), - witx::Type::Int(i) => int::define_int(names, &namedtype.name, &i), - witx::Type::Flags(f) => flags::define_flags(names, &namedtype.name, &f), - witx::Type::Struct(s) => r#struct::define_struct(names, &namedtype.name, &s), - witx::Type::Union(u) => union::define_union(names, &namedtype.name, &u), + witx::Type::Record(r) => match r.bitflags_repr() { + Some(repr) => flags::define_flags(names, &namedtype.name, repr, &r), + None => record::define_struct(names, &namedtype.name, &r), + }, + witx::Type::Variant(v) => variant::define_variant(names, &namedtype.name, &v), witx::Type::Handle(h) => handle::define_handle(names, &namedtype.name, &h), witx::Type::Builtin(b) => define_builtin(names, &namedtype.name, *b), witx::Type::Pointer(p) => { @@ -30,7 +29,7 @@ pub fn define_datatype(names: &Names, namedtype: &witx::NamedType) -> TokenStrea let rt = names.runtime_mod(); define_witx_pointer(names, &namedtype.name, quote!(#rt::GuestPtr), p) } - witx::Type::Array(arr) => define_witx_array(names, &namedtype.name, &arr), + witx::Type::List(arr) => define_witx_list(names, &namedtype.name, &arr), }, } } @@ -47,12 +46,8 @@ fn define_alias(names: &Names, name: &witx::Id, to: &witx::NamedType) -> TokenSt fn define_builtin(names: &Names, name: &witx::Id, builtin: witx::BuiltinType) -> TokenStream { let ident = names.type_(name); - let built = names.builtin_type(builtin, quote!('a)); - if builtin.needs_lifetime() { - quote!(pub type #ident<'a> = #built;) - } else { - quote!(pub type #ident = #built;) - } + let built = names.builtin_type(builtin); + quote!(pub type #ident = #built;) } fn define_witx_pointer( @@ -67,14 +62,14 @@ fn define_witx_pointer( quote!(pub type #ident<'a> = #pointer_type<'a, #pointee_type>;) } -fn define_witx_array(names: &Names, name: &witx::Id, arr_raw: &witx::TypeRef) -> TokenStream { +fn define_witx_list(names: &Names, name: &witx::Id, arr_raw: &witx::TypeRef) -> TokenStream { let ident = names.type_(name); let rt = names.runtime_mod(); let pointee_type = names.type_ref(arr_raw, quote!('a)); quote!(pub type #ident<'a> = #rt::GuestPtr<'a, [#pointee_type]>;) } -fn int_repr_tokens(int_repr: witx::IntRepr) -> TokenStream { +pub fn int_repr_tokens(int_repr: witx::IntRepr) -> TokenStream { match int_repr { witx::IntRepr::U8 => quote!(u8), witx::IntRepr::U16 => quote!(u16), @@ -83,15 +78,6 @@ fn int_repr_tokens(int_repr: witx::IntRepr) -> TokenStream { } } -fn atom_token(atom: witx::AtomType) -> TokenStream { - match atom { - witx::AtomType::I32 => quote!(i32), - witx::AtomType::I64 => quote!(i64), - witx::AtomType::F32 => quote!(f32), - witx::AtomType::F64 => quote!(f64), - } -} - pub trait WiggleType { fn impls_display(&self) -> bool; } @@ -114,16 +100,13 @@ impl WiggleType for witx::NamedType { impl WiggleType for witx::Type { fn impls_display(&self) -> bool { match self { - witx::Type::Enum(x) => x.impls_display(), - witx::Type::Int(x) => x.impls_display(), - witx::Type::Flags(x) => x.impls_display(), - witx::Type::Struct(x) => x.impls_display(), - witx::Type::Union(x) => x.impls_display(), + witx::Type::Record(x) => x.impls_display(), + witx::Type::Variant(x) => x.impls_display(), witx::Type::Handle(x) => x.impls_display(), witx::Type::Builtin(x) => x.impls_display(), witx::Type::Pointer { .. } | witx::Type::ConstPointer { .. } - | witx::Type::Array { .. } => false, + | witx::Type::List { .. } => false, } } } @@ -136,11 +119,6 @@ impl WiggleType for witx::BuiltinType { impl WiggleType for witx::InterfaceFuncParam { fn impls_display(&self) -> bool { - match &*self.tref.type_() { - witx::Type::Struct { .. } - | witx::Type::Union { .. } - | witx::Type::Builtin(witx::BuiltinType::String { .. }) => false, - _ => self.tref.impls_display(), - } + self.tref.impls_display() } } diff --git a/crates/wiggle/generate/src/types/struct.rs b/crates/wiggle/generate/src/types/record.rs similarity index 94% rename from crates/wiggle/generate/src/types/struct.rs rename to crates/wiggle/generate/src/types/record.rs index 2a79a6e9ac..eaabcc6426 100644 --- a/crates/wiggle/generate/src/types/struct.rs +++ b/crates/wiggle/generate/src/types/record.rs @@ -8,7 +8,7 @@ use witx::Layout; pub(super) fn define_struct( names: &Names, name: &witx::Id, - s: &witx::StructDatatype, + s: &witx::RecordDatatype, ) -> TokenStream { let rt = names.runtime_mod(); let ident = names.type_(name); @@ -28,7 +28,7 @@ pub(super) fn define_struct( } } witx::TypeRef::Value(ty) => match &**ty { - witx::Type::Builtin(builtin) => names.builtin_type(*builtin, quote!('a)), + witx::Type::Builtin(builtin) => names.builtin_type(*builtin), witx::Type::Pointer(pointee) | witx::Type::ConstPointer(pointee) => { let pointee_type = names.type_ref(&pointee, quote!('a)); quote!(#rt::GuestPtr<'a, #pointee_type>) @@ -52,9 +52,9 @@ pub(super) fn define_struct( } witx::TypeRef::Value(ty) => match &**ty { witx::Type::Builtin(builtin) => { - let type_ = names.builtin_type(*builtin, anon_lifetime()); + let type_ = names.builtin_type(*builtin); quote! { - let #name = <#type_ as #rt::GuestType>::read(&#location)?; + let #name = <#type_ as #rt::GuestType>::read(&#location)?; } } witx::Type::Pointer(pointee) | witx::Type::ConstPointer(pointee) => { @@ -141,7 +141,7 @@ pub(super) fn define_struct( } } -impl super::WiggleType for witx::StructDatatype { +impl super::WiggleType for witx::RecordDatatype { fn impls_display(&self) -> bool { false } diff --git a/crates/wiggle/generate/src/types/union.rs b/crates/wiggle/generate/src/types/variant.rs similarity index 50% rename from crates/wiggle/generate/src/types/union.rs rename to crates/wiggle/generate/src/types/variant.rs index ecc3253f3c..6568228ded 100644 --- a/crates/wiggle/generate/src/types/union.rs +++ b/crates/wiggle/generate/src/types/variant.rs @@ -1,23 +1,23 @@ use crate::lifetimes::LifetimeExt; use crate::names::Names; -use proc_macro2::TokenStream; +use proc_macro2::{Literal, TokenStream}; use quote::quote; use witx::Layout; -pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDatatype) -> TokenStream { +pub(super) fn define_variant(names: &Names, name: &witx::Id, v: &witx::Variant) -> TokenStream { let rt = names.runtime_mod(); let ident = names.type_(name); - let size = u.mem_size_align().size as u32; - let align = u.mem_size_align().align as usize; - let ulayout = u.union_layout(); - let contents_offset = ulayout.contents_offset as u32; + let size = v.mem_size_align().size as u32; + let align = v.mem_size_align().align as usize; + let contents_offset = v.payload_offset() as u32; let lifetime = quote!('a); + let tag_ty = super::int_repr_tokens(v.tag_repr); - let variants = u.variants.iter().map(|v| { - let var_name = names.enum_variant(&v.name); - if let Some(tref) = &v.tref { + let variants = v.cases.iter().map(|c| { + let var_name = names.enum_variant(&c.name); + if let Some(tref) = &c.tref { let var_type = names.type_ref(&tref, lifetime.clone()); quote!(#var_name(#var_type)) } else { @@ -25,30 +25,29 @@ pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDataty } }); - let tagname = names.type_(&u.tag.name); - - let read_variant = u.variants.iter().map(|v| { - let variantname = names.enum_variant(&v.name); - if let Some(tref) = &v.tref { + let read_variant = v.cases.iter().enumerate().map(|(i, c)| { + let i = Literal::usize_unsuffixed(i); + let variantname = names.enum_variant(&c.name); + if let Some(tref) = &c.tref { let varianttype = names.type_ref(tref, lifetime.clone()); quote! { - #tagname::#variantname => { + #i => { let variant_ptr = location.cast::().add(#contents_offset)?; let variant_val = <#varianttype as #rt::GuestType>::read(&variant_ptr.cast())?; Ok(#ident::#variantname(variant_val)) } } } else { - quote! { #tagname::#variantname => Ok(#ident::#variantname), } + quote! { #i => Ok(#ident::#variantname), } } }); - let write_variant = u.variants.iter().map(|v| { - let variantname = names.enum_variant(&v.name); + let write_variant = v.cases.iter().enumerate().map(|(i, c)| { + let variantname = names.enum_variant(&c.name); let write_tag = quote! { - location.cast().write(#tagname::#variantname)?; + location.cast().write(#i as #tag_ty)?; }; - if let Some(tref) = &v.tref { + if let Some(tref) = &c.tref { let varianttype = names.type_ref(tref, lifetime.clone()); quote! { #ident::#variantname(contents) => { @@ -66,7 +65,36 @@ pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDataty } }); - let (enum_lifetime, extra_derive) = if u.needs_lifetime() { + let enum_try_from = if v.cases.iter().all(|c| c.tref.is_none()) { + let tryfrom_repr_cases = v.cases.iter().enumerate().map(|(i, c)| { + let variant_name = names.enum_variant(&c.name); + let n = Literal::usize_unsuffixed(i); + quote!(#n => Ok(#ident::#variant_name)) + }); + let abi_ty = names.wasm_type(v.tag_repr.into()); + quote! { + impl TryFrom<#tag_ty> for #ident { + type Error = #rt::GuestError; + fn try_from(value: #tag_ty) -> Result<#ident, #rt::GuestError> { + match value { + #(#tryfrom_repr_cases),*, + _ => Err( #rt::GuestError::InvalidEnumValue(stringify!(#ident))), + } + } + } + + impl TryFrom<#abi_ty> for #ident { + type Error = #rt::GuestError; + fn try_from(value: #abi_ty) -> Result<#ident, #rt::GuestError> { + #ident::try_from(#tag_ty::try_from(value)?) + } + } + } + } else { + quote!() + }; + + let (enum_lifetime, extra_derive) = if v.needs_lifetime() { (quote!(<'a>), quote!()) } else { (quote!(), quote!(, PartialEq)) @@ -78,6 +106,8 @@ pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDataty #(#variants),* } + #enum_try_from + impl<'a> #rt::GuestType<'a> for #ident #enum_lifetime { fn guest_size() -> u32 { #size @@ -90,9 +120,10 @@ pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDataty fn read(location: &#rt::GuestPtr<'a, Self>) -> Result { - let tag = location.cast().read()?; + let tag = location.cast::<#tag_ty>().read()?; match tag { #(#read_variant)* + _ => Err(#rt::GuestError::InvalidEnumValue(stringify!(#ident))), } } @@ -109,7 +140,7 @@ pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDataty } } -impl super::WiggleType for witx::UnionDatatype { +impl super::WiggleType for witx::Variant { fn impls_display(&self) -> bool { false } diff --git a/crates/wiggle/macro/Cargo.toml b/crates/wiggle/macro/Cargo.toml index c18c3dee74..baee8ce164 100644 --- a/crates/wiggle/macro/Cargo.toml +++ b/crates/wiggle/macro/Cargo.toml @@ -22,7 +22,7 @@ doctest = false [dependencies] wiggle-generate = { path = "../generate", version = "0.23.0" } -witx = { path = "../../wasi-common/WASI/tools/witx", version = "0.8.7" } +witx = { version = "0.9", path = "../../wasi-common/WASI/tools/witx" } quote = "1.0" syn = { version = "1.0", features = ["full"] } diff --git a/crates/wiggle/src/lib.rs b/crates/wiggle/src/lib.rs index 4ecbf27da5..24a4ca7c40 100644 --- a/crates/wiggle/src/lib.rs +++ b/crates/wiggle/src/lib.rs @@ -942,3 +942,9 @@ pub enum Trap { /// Any other Trap is just an unstructured String, for reporting and debugging. String(String), } + +impl From for Trap { + fn from(err: GuestError) -> Trap { + Trap::String(err.to_string()) + } +} diff --git a/crates/wiggle/wasmtime/Cargo.toml b/crates/wiggle/wasmtime/Cargo.toml index 97ca34ebc7..31f45eeaf3 100644 --- a/crates/wiggle/wasmtime/Cargo.toml +++ b/crates/wiggle/wasmtime/Cargo.toml @@ -13,7 +13,7 @@ include = ["src/**/*", "LICENSE"] [dependencies] wasmtime = { path = "../../wasmtime", version = "0.23.0", default-features = false } wasmtime-wiggle-macro = { path = "./macro", version = "0.23.0" } -witx = { path = "../../wasi-common/WASI/tools/witx", version = "0.8.7", optional = true } +witx = { version = "0.9", path = "../../wasi-common/WASI/tools/witx", optional = true } wiggle = { path = "..", version = "0.23.0" } wiggle-borrow = { path = "../borrow", version = "0.23.0" } diff --git a/crates/wiggle/wasmtime/macro/Cargo.toml b/crates/wiggle/wasmtime/macro/Cargo.toml index 2ac0b1465d..fd4b365f0d 100644 --- a/crates/wiggle/wasmtime/macro/Cargo.toml +++ b/crates/wiggle/wasmtime/macro/Cargo.toml @@ -15,7 +15,7 @@ proc-macro = true test = false [dependencies] -witx = { path = "../../../wasi-common/WASI/tools/witx", version = "0.8.7" } +witx = { version = "0.9", path = "../../../wasi-common/WASI/tools/witx" } wiggle-generate = { path = "../../generate", version = "0.23.0" } quote = "1.0" syn = { version = "1.0", features = ["full", "extra-traits"] } diff --git a/crates/wiggle/wasmtime/macro/src/config.rs b/crates/wiggle/wasmtime/macro/src/config.rs index 6359ed0ec4..5e95bad957 100644 --- a/crates/wiggle/wasmtime/macro/src/config.rs +++ b/crates/wiggle/wasmtime/macro/src/config.rs @@ -144,7 +144,6 @@ impl Parse for TargetConf { enum ModuleConfField { Name(Ident), Docs(String), - FunctionOverride(FunctionOverrideConf), } impl Parse for ModuleConfField { @@ -159,10 +158,6 @@ impl Parse for ModuleConfField { input.parse::()?; let docs: syn::LitStr = input.parse()?; Ok(ModuleConfField::Docs(docs.value())) - } else if lookahead.peek(kw::function_override) { - input.parse::()?; - input.parse::()?; - Ok(ModuleConfField::FunctionOverride(input.parse()?)) } else { Err(lookahead.error()) } @@ -173,14 +168,12 @@ impl Parse for ModuleConfField { pub struct ModuleConf { pub name: Ident, pub docs: Option, - pub function_override: FunctionOverrideConf, } impl ModuleConf { fn build(fields: impl Iterator, err_loc: Span) -> Result { let mut name = None; let mut docs = None; - let mut function_override = None; for f in fields { match f { ModuleConfField::Name(c) => { @@ -195,18 +188,11 @@ impl ModuleConf { } docs = Some(c); } - ModuleConfField::FunctionOverride(c) => { - if function_override.is_some() { - return Err(Error::new(err_loc, "duplicate `function_override` field")); - } - function_override = Some(c); - } } } Ok(ModuleConf { name: name.ok_or_else(|| Error::new(err_loc, "`name` field required"))?, docs, - function_override: function_override.unwrap_or_default(), }) } } @@ -248,42 +234,3 @@ impl Parse for ModulesConf { }) } } - -#[derive(Debug, Clone, Default)] -pub struct FunctionOverrideConf { - pub funcs: Vec, -} -impl FunctionOverrideConf { - pub fn find(&self, name: &str) -> Option<&Ident> { - self.funcs - .iter() - .find(|f| f.name == name) - .map(|f| &f.replacement) - } -} - -impl Parse for FunctionOverrideConf { - fn parse(input: ParseStream) -> Result { - let contents; - let _lbrace = braced!(contents in input); - let fields: Punctuated = - contents.parse_terminated(FunctionOverrideField::parse)?; - Ok(FunctionOverrideConf { - funcs: fields.into_iter().collect(), - }) - } -} - -#[derive(Debug, Clone)] -pub struct FunctionOverrideField { - pub name: String, - pub replacement: Ident, -} -impl Parse for FunctionOverrideField { - fn parse(input: ParseStream) -> Result { - let name = input.parse::()?.to_string(); - input.parse::]>()?; - let replacement = input.parse::()?; - Ok(FunctionOverrideField { name, replacement }) - } -} diff --git a/crates/wiggle/wasmtime/macro/src/lib.rs b/crates/wiggle/wasmtime/macro/src/lib.rs index 7f13308f16..a2d05f729d 100644 --- a/crates/wiggle/wasmtime/macro/src/lib.rs +++ b/crates/wiggle/wasmtime/macro/src/lib.rs @@ -1,5 +1,5 @@ use proc_macro::TokenStream; -use proc_macro2::TokenStream as TokenStream2; +use proc_macro2::{Ident, Span, TokenStream as TokenStream2}; use quote::quote; use syn::parse_macro_input; use wiggle_generate::Names; @@ -88,14 +88,9 @@ fn generate_module( let module_id = names.module(&module.name); let target_module = quote! { #target_path::#module_id }; - let ctor_externs = module.funcs().map(|f| { - if let Some(func_override) = module_conf.function_override.find(&f.name.as_str()) { - let name_ident = names.func(&f.name); - quote! { let #name_ident = wasmtime::Func::wrap(store, #func_override); } - } else { - generate_func(&f, names, &target_module) - } - }); + let ctor_externs = module + .funcs() + .map(|f| generate_func(&f, names, &target_module)); let type_name = module_conf.name.clone(); let type_docs = module_conf @@ -158,23 +153,21 @@ fn generate_func( ) -> TokenStream2 { let name_ident = names.func(&func.name); - let coretype = func.core_type(); + let (params, results) = func.wasm_signature(); - let arg_decls = coretype.args.iter().map(|arg| { - let name = names.func_core_arg(arg); - let atom = names.atom_type(arg.repr()); - quote! { #name: #atom } + let arg_names = (0..params.len()) + .map(|i| Ident::new(&format!("arg{}", i), Span::call_site())) + .collect::>(); + let arg_decls = params.iter().enumerate().map(|(i, ty)| { + let name = &arg_names[i]; + let wasm = names.wasm_type(*ty); + quote! { #name: #wasm } }); - let arg_names = coretype.args.iter().map(|arg| names.func_core_arg(arg)); - let ret_ty = if let Some(ret) = &coretype.ret { - let ret_ty = match ret.signifies { - witx::CoreParamSignifies::Value(atom) => names.atom_type(atom), - _ => unreachable!("coretype ret should always be passed by value"), - }; - quote! { #ret_ty } - } else { - quote! {()} + let ret_ty = match results.len() { + 0 => quote!(()), + 1 => names.wasm_type(results[0]), + _ => unimplemented!(), }; let runtime = names.runtime_mod();