From f6a732b6cfaf4c8b1e12538a9f0146e63017454c Mon Sep 17 00:00:00 2001 From: Pat Hickey Date: Thu, 20 Feb 2020 14:36:53 -0800 Subject: [PATCH] squash all tagged union work into one commit --- crates/WASI | 2 +- crates/generate/src/funcs.rs | 4 + crates/generate/src/types.rs | 232 +++++++++++++++++++++++++++++++++-- tests/union.rs | 43 +++++++ tests/union.witx | 28 +++++ 5 files changed, 297 insertions(+), 12 deletions(-) create mode 100644 tests/union.rs create mode 100644 tests/union.witx diff --git a/crates/WASI b/crates/WASI index 77629f3442..6d96ec08bf 160000 --- a/crates/WASI +++ b/crates/WASI @@ -1 +1 @@ -Subproject commit 77629f34429c1bc65af797dac687fd47fc73df4b +Subproject commit 6d96ec08bf976ee5518ae2734bb01f9a8e919e7c diff --git a/crates/generate/src/funcs.rs b/crates/generate/src/funcs.rs index 4045447628..8814df19ab 100644 --- a/crates/generate/src/funcs.rs +++ b/crates/generate/src/funcs.rs @@ -305,6 +305,10 @@ fn marshal_arg( }; } } + witx::Type::Union(_u) => { + let name = names.func_param(¶m.name); + quote!(let #name = unimplemented!("union argument marshaling");) + } _ => unimplemented!("argument type marshalling"), } } diff --git a/crates/generate/src/types.rs b/crates/generate/src/types.rs index ebdb7d3567..e6767f9144 100644 --- a/crates/generate/src/types.rs +++ b/crates/generate/src/types.rs @@ -19,7 +19,7 @@ pub fn define_datatype(names: &Names, namedtype: &witx::NamedType) -> TokenStrea define_ptr_struct(names, &namedtype.name, &s) } } - witx::Type::Union(_) => unimplemented!("union types"), + witx::Type::Union(u) => define_union(names, &namedtype.name, &u), witx::Type::Handle(_h) => unimplemented!("handle types"), witx::Type::Builtin(b) => define_builtin(names, &namedtype.name, *b), witx::Type::Pointer(p) => define_witx_pointer( @@ -384,10 +384,9 @@ fn define_enum(names: &Names, name: &witx::Id, e: &witx::EnumDatatype) -> TokenS impl<'a> wiggle_runtime::GuestTypeCopy<'a> for #ident {} impl<'a> wiggle_runtime::GuestTypeClone<'a> for #ident { fn read_from_guest(location: &wiggle_runtime::GuestPtr<#ident>) -> Result<#ident, wiggle_runtime::GuestError> { - use ::std::convert::TryFrom; - let raw: #repr = unsafe { (location.as_raw() as *const #repr).read() }; - let val = #ident::try_from(raw)?; - Ok(val) + // Perform validation as part of as_ref: + let r = location.as_ref()?; + Ok(*r) } fn write_to_guest(&self, location: &wiggle_runtime::GuestPtrMut<#ident>) { let val: #repr = #repr::from(*self); @@ -408,6 +407,8 @@ fn define_builtin(names: &Names, name: &witx::Id, builtin: witx::BuiltinType) -> } } +// XXX DRY - should move these funcs to be a trait that Type, BuiltinType, StructDatatype, +// UnionDatatype all implement pub fn type_needs_lifetime(tref: &witx::TypeRef) -> bool { match &*tref.type_() { witx::Type::Builtin(b) => match b { @@ -419,7 +420,7 @@ pub fn type_needs_lifetime(tref: &witx::TypeRef) -> bool { | witx::Type::Int { .. } | witx::Type::Handle { .. } => false, witx::Type::Struct(s) => !struct_is_copy(&s), - witx::Type::Union { .. } => true, + witx::Type::Union(u) => !union_is_copy(&u), witx::Type::Pointer { .. } | witx::Type::ConstPointer { .. } => true, witx::Type::Array { .. } => true, } @@ -432,16 +433,39 @@ pub fn struct_is_copy(s: &witx::StructDatatype) -> bool { witx::BuiltinType::String => false, _ => true, }, - witx::Type::ConstPointer { .. } - | witx::Type::Pointer { .. } - | witx::Type::Array { .. } - | witx::Type::Union { .. } => false, + witx::Type::ConstPointer { .. } | witx::Type::Pointer { .. } | witx::Type::Array { .. } => { + false + } + witx::Type::Union(u) => union_is_copy(u), witx::Type::Enum { .. } | witx::Type::Int { .. } | witx::Type::Flags { .. } | witx::Type::Handle { .. } => true, }) } +pub fn union_is_copy(u: &witx::UnionDatatype) -> bool { + u.variants.iter().all(|m| { + if let Some(tref) = &m.tref { + match &*tref.type_() { + witx::Type::Struct(s) => struct_is_copy(&s), + witx::Type::Builtin(b) => match &*b { + witx::BuiltinType::String => false, + _ => true, + }, + witx::Type::ConstPointer { .. } + | witx::Type::Pointer { .. } + | witx::Type::Array { .. } => false, + witx::Type::Union(u) => union_is_copy(u), + witx::Type::Enum { .. } + | witx::Type::Int { .. } + | witx::Type::Flags { .. } + | witx::Type::Handle { .. } => true, + } + } else { + true + } + }) +} fn define_copy_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) -> TokenStream { let ident = names.type_(name); @@ -498,7 +522,8 @@ fn define_copy_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) } impl<'a> wiggle_runtime::GuestTypeClone<'a> for #ident { fn read_from_guest(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<#ident, wiggle_runtime::GuestError> { - Ok(*location.as_ref()?) + let r = location.as_ref()?; + Ok(*r) } fn write_to_guest(&self, location: &wiggle_runtime::GuestPtrMut<'a, Self>) { unsafe { (location.as_raw() as *mut #ident).write(*self) }; @@ -643,6 +668,191 @@ fn define_ptr_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) - } } +fn union_validate( + names: &Names, + typename: TokenStream, + u: &witx::UnionDatatype, + ulayout: &witx::UnionLayout, +) -> TokenStream { + let tagname = names.type_(&u.tag.name); + let contents_offset = ulayout.contents_offset as u32; + + let with_err = |f: &str| -> TokenStream { + quote!(|e| wiggle_runtime::GuestError::InDataField { + typename: stringify!(#typename).to_owned(), + field: #f.to_owned(), + err: Box::new(e), + }) + }; + + let tag_err = with_err(""); + let variant_validation = u.variants.iter().map(|v| { + let err = with_err(v.name.as_str()); + let variantname = names.enum_variant(&v.name); + if let Some(tref) = &v.tref { + let varianttype = names.type_ref(tref, anon_lifetime()); + quote! { + #tagname::#variantname => { + let variant_ptr = ptr.cast::<#varianttype>(#contents_offset).map_err(#err)?; + wiggle_runtime::GuestType::validate(&variant_ptr).map_err(#err)?; + } + } + } else { + quote! { #tagname::#variantname => {} } + } + }); + + quote! { + let tag = *ptr.cast::<#tagname>(0).map_err(#tag_err)?.as_ref().map_err(#tag_err)?; + match tag { + #(#variant_validation)* + } + Ok(()) + } +} + +fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDatatype) -> TokenStream { + let ident = names.type_(name); + let size = u.mem_size_align().size as u32; + let align = u.mem_size_align().align as u32; + let ulayout = u.union_layout(); + let contents_offset = ulayout.contents_offset as u32; + + let lifetime = quote!('a); + + let variants = u.variants.iter().map(|v| { + let var_name = names.enum_variant(&v.name); + if let Some(tref) = &v.tref { + let var_type = names.type_ref(&tref, lifetime.clone()); + quote!(#var_name(#var_type)) + } else { + quote!(#var_name) + } + }); + + let tagname = names.type_(&u.tag.name); + + let read_variant = u.variants.iter().map(|v| { + let variantname = names.enum_variant(&v.name); + if let Some(tref) = &v.tref { + let varianttype = names.type_ref(tref, lifetime.clone()); + quote! { + #tagname::#variantname => { + let variant_ptr = location.cast::<#varianttype>(#contents_offset).expect("union variant ptr validated"); + let variant_val = wiggle_runtime::GuestTypeClone::read_from_guest(&variant_ptr)?; + Ok(#ident::#variantname(variant_val)) + } + } + } else { + quote! { #tagname::#variantname => Ok(#ident::#variantname), } + } + }); + + let write_variant = u.variants.iter().map(|v| { + let variantname = names.enum_variant(&v.name); + let write_tag = quote! { + let tag_ptr = location.cast::<#tagname>(0).expect("union tag ptr TODO error report"); + let mut tag_ref = tag_ptr.as_ref_mut().expect("union tag ref TODO error report"); + *tag_ref = #tagname::#variantname; + }; + if let Some(tref) = &v.tref { + let varianttype = names.type_ref(tref, lifetime.clone()); + quote! { + #ident::#variantname(contents) => { + #write_tag + let variant_ptr = location.cast::<#varianttype>(#contents_offset).expect("union variant ptr validated"); + contents.write_to_guest(&variant_ptr); + } + } + } else { + quote! { + #ident::#variantname => { + #write_tag + } + } + } + }); + let validate = union_validate(names, ident.clone(), u, &ulayout); + + if union_is_copy(u) { + // Type does not have a lifetime parameter: + quote! { + #[derive(Clone, Debug, PartialEq)] + pub enum #ident { + #(#variants),* + } + impl wiggle_runtime::GuestType for #ident { + fn size() -> u32 { + #size + } + fn align() -> u32 { + #align + } + fn name() -> String { + stringify!(#ident).to_owned() + } + fn validate(ptr: &wiggle_runtime::GuestPtr<#ident>) -> Result<(), wiggle_runtime::GuestError> { + #validate + } + } + impl<#lifetime> wiggle_runtime::GuestTypeClone<#lifetime> for #ident { + fn read_from_guest(location: &wiggle_runtime::GuestPtr<'a, #ident>) + -> Result { + wiggle_runtime::GuestType::validate(location)?; + let tag = *location.cast::<#tagname>(0).expect("validated tag ptr").as_ref().expect("validated tag ref"); + match tag { + #(#read_variant)* + } + + } + fn write_to_guest(&self, location: &wiggle_runtime::GuestPtrMut<'a, #ident>) { + match self { + #(#write_variant)* + } + } + } + } + } else { + quote! { + #[derive(Clone)] + pub enum #ident<#lifetime> { + #(#variants),* + } + + impl<#lifetime> wiggle_runtime::GuestType for #ident<#lifetime> { + fn size() -> u32 { + #size + } + fn align() -> u32 { + #align + } + fn name() -> String { + stringify!(#ident).to_owned() + } + fn validate(ptr: &wiggle_runtime::GuestPtr<#ident>) -> Result<(), wiggle_runtime::GuestError> { + #validate + } + } + impl<#lifetime> wiggle_runtime::GuestTypeClone<#lifetime> for #ident<#lifetime> { + fn read_from_guest(location: &wiggle_runtime::GuestPtr<'a, #ident>) + -> Result { + wiggle_runtime::GuestType::validate(location)?; + let tag = *location.cast::<#tagname>(0).expect("validated tag ptr").as_ref().expect("validated tag ref"); + match tag { + #(#read_variant)* + } + + } + fn write_to_guest(&self, location: &wiggle_runtime::GuestPtrMut<'a, #ident>) { + match self { + #(#write_variant)* + } + } + } + } + } +} + fn define_witx_pointer( names: &Names, name: &witx::Id, diff --git a/tests/union.rs b/tests/union.rs new file mode 100644 index 0000000000..d5da1bff96 --- /dev/null +++ b/tests/union.rs @@ -0,0 +1,43 @@ +use proptest::prelude::*; +use wiggle_runtime::{ + GuestError, GuestErrorType, GuestMemory, GuestPtr, GuestPtrMut, GuestRef, GuestRefMut, +}; +use wiggle_test::{impl_errno, HostMemory, MemArea, WasiCtx}; + +wiggle_generate::from_witx!({ + witx: ["tests/union.witx"], + ctx: WasiCtx, +}); + +impl_errno!(types::Errno); + +impl union_example::UnionExample for WasiCtx { + fn get_tag(&mut self, u: &types::Reason) -> Result { + println!("GET TAG: {:?}", u); + Ok(types::Excuse::DogAte) + } + fn reason_mult( + &mut self, + u: &types::ReasonMut<'_>, + multiply_by: u32, + ) -> Result<(), types::Errno> { + match u { + types::ReasonMut::DogAte(fptr) => { + let mut f = fptr.as_ref_mut().expect("valid pointer"); + let val = *f; + println!("REASON MULT DogAte({})", val); + *f = val * multiply_by as f32; + } + types::ReasonMut::Traffic(iptr) => { + let mut i = iptr.as_ref_mut().expect("valid pointer"); + let val = *i; + println!("REASON MULT Traffic({})", val); + *i = val * multiply_by as i32; + } + types::ReasonMut::Sleeping => { + println!("REASON MULT Sleeping"); + } + } + Ok(()) + } +} diff --git a/tests/union.witx b/tests/union.witx new file mode 100644 index 0000000000..7dda97f6e9 --- /dev/null +++ b/tests/union.witx @@ -0,0 +1,28 @@ +(use "errno.witx") +(use "excuse.witx") + +(typename $reason + (union $excuse + (field $dog_ate f32) + (field $traffic s32) + (empty $sleeping))) + +(typename $reason_mut + (union $excuse + (field $dog_ate (@witx pointer f32)) + (field $traffic (@witx pointer s32)) + (empty $sleeping))) + +(module $union_example + (@interface func (export "get_tag") + (param $r $reason) + (result $error $errno) + (result $t $excuse) + ) + + (@interface func (export "reason_mult") + (param $r $reason_mut) + (param $multiply_by u32) + (result $error $errno) + ) +)