squash all tagged union work into one commit

This commit is contained in:
Pat Hickey
2020-02-20 14:36:53 -08:00
committed by Jakub Konka
parent 25a411d7fd
commit f6a732b6cf
5 changed files with 297 additions and 12 deletions

View File

@@ -305,6 +305,10 @@ fn marshal_arg(
};
}
}
witx::Type::Union(_u) => {
let name = names.func_param(&param.name);
quote!(let #name = unimplemented!("union argument marshaling");)
}
_ => unimplemented!("argument type marshalling"),
}
}

View File

@@ -19,7 +19,7 @@ pub fn define_datatype(names: &Names, namedtype: &witx::NamedType) -> TokenStrea
define_ptr_struct(names, &namedtype.name, &s)
}
}
witx::Type::Union(_) => unimplemented!("union types"),
witx::Type::Union(u) => define_union(names, &namedtype.name, &u),
witx::Type::Handle(_h) => unimplemented!("handle types"),
witx::Type::Builtin(b) => define_builtin(names, &namedtype.name, *b),
witx::Type::Pointer(p) => define_witx_pointer(
@@ -384,10 +384,9 @@ fn define_enum(names: &Names, name: &witx::Id, e: &witx::EnumDatatype) -> TokenS
impl<'a> wiggle_runtime::GuestTypeCopy<'a> for #ident {}
impl<'a> wiggle_runtime::GuestTypeClone<'a> for #ident {
fn read_from_guest(location: &wiggle_runtime::GuestPtr<#ident>) -> Result<#ident, wiggle_runtime::GuestError> {
use ::std::convert::TryFrom;
let raw: #repr = unsafe { (location.as_raw() as *const #repr).read() };
let val = #ident::try_from(raw)?;
Ok(val)
// Perform validation as part of as_ref:
let r = location.as_ref()?;
Ok(*r)
}
fn write_to_guest(&self, location: &wiggle_runtime::GuestPtrMut<#ident>) {
let val: #repr = #repr::from(*self);
@@ -408,6 +407,8 @@ fn define_builtin(names: &Names, name: &witx::Id, builtin: witx::BuiltinType) ->
}
}
// XXX DRY - should move these funcs to be a trait that Type, BuiltinType, StructDatatype,
// UnionDatatype all implement
pub fn type_needs_lifetime(tref: &witx::TypeRef) -> bool {
match &*tref.type_() {
witx::Type::Builtin(b) => match b {
@@ -419,7 +420,7 @@ pub fn type_needs_lifetime(tref: &witx::TypeRef) -> bool {
| witx::Type::Int { .. }
| witx::Type::Handle { .. } => false,
witx::Type::Struct(s) => !struct_is_copy(&s),
witx::Type::Union { .. } => true,
witx::Type::Union(u) => !union_is_copy(&u),
witx::Type::Pointer { .. } | witx::Type::ConstPointer { .. } => true,
witx::Type::Array { .. } => true,
}
@@ -432,16 +433,39 @@ pub fn struct_is_copy(s: &witx::StructDatatype) -> bool {
witx::BuiltinType::String => false,
_ => true,
},
witx::Type::ConstPointer { .. }
| witx::Type::Pointer { .. }
| witx::Type::Array { .. }
| witx::Type::Union { .. } => false,
witx::Type::ConstPointer { .. } | witx::Type::Pointer { .. } | witx::Type::Array { .. } => {
false
}
witx::Type::Union(u) => union_is_copy(u),
witx::Type::Enum { .. }
| witx::Type::Int { .. }
| witx::Type::Flags { .. }
| witx::Type::Handle { .. } => true,
})
}
pub fn union_is_copy(u: &witx::UnionDatatype) -> bool {
u.variants.iter().all(|m| {
if let Some(tref) = &m.tref {
match &*tref.type_() {
witx::Type::Struct(s) => struct_is_copy(&s),
witx::Type::Builtin(b) => match &*b {
witx::BuiltinType::String => false,
_ => true,
},
witx::Type::ConstPointer { .. }
| witx::Type::Pointer { .. }
| witx::Type::Array { .. } => false,
witx::Type::Union(u) => union_is_copy(u),
witx::Type::Enum { .. }
| witx::Type::Int { .. }
| witx::Type::Flags { .. }
| witx::Type::Handle { .. } => true,
}
} else {
true
}
})
}
fn define_copy_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) -> TokenStream {
let ident = names.type_(name);
@@ -498,7 +522,8 @@ fn define_copy_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype)
}
impl<'a> wiggle_runtime::GuestTypeClone<'a> for #ident {
fn read_from_guest(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<#ident, wiggle_runtime::GuestError> {
Ok(*location.as_ref()?)
let r = location.as_ref()?;
Ok(*r)
}
fn write_to_guest(&self, location: &wiggle_runtime::GuestPtrMut<'a, Self>) {
unsafe { (location.as_raw() as *mut #ident).write(*self) };
@@ -643,6 +668,191 @@ fn define_ptr_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) -
}
}
fn union_validate(
names: &Names,
typename: TokenStream,
u: &witx::UnionDatatype,
ulayout: &witx::UnionLayout,
) -> TokenStream {
let tagname = names.type_(&u.tag.name);
let contents_offset = ulayout.contents_offset as u32;
let with_err = |f: &str| -> TokenStream {
quote!(|e| wiggle_runtime::GuestError::InDataField {
typename: stringify!(#typename).to_owned(),
field: #f.to_owned(),
err: Box::new(e),
})
};
let tag_err = with_err("<tag>");
let variant_validation = u.variants.iter().map(|v| {
let err = with_err(v.name.as_str());
let variantname = names.enum_variant(&v.name);
if let Some(tref) = &v.tref {
let varianttype = names.type_ref(tref, anon_lifetime());
quote! {
#tagname::#variantname => {
let variant_ptr = ptr.cast::<#varianttype>(#contents_offset).map_err(#err)?;
wiggle_runtime::GuestType::validate(&variant_ptr).map_err(#err)?;
}
}
} else {
quote! { #tagname::#variantname => {} }
}
});
quote! {
let tag = *ptr.cast::<#tagname>(0).map_err(#tag_err)?.as_ref().map_err(#tag_err)?;
match tag {
#(#variant_validation)*
}
Ok(())
}
}
fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDatatype) -> TokenStream {
let ident = names.type_(name);
let size = u.mem_size_align().size as u32;
let align = u.mem_size_align().align as u32;
let ulayout = u.union_layout();
let contents_offset = ulayout.contents_offset as u32;
let lifetime = quote!('a);
let variants = u.variants.iter().map(|v| {
let var_name = names.enum_variant(&v.name);
if let Some(tref) = &v.tref {
let var_type = names.type_ref(&tref, lifetime.clone());
quote!(#var_name(#var_type))
} else {
quote!(#var_name)
}
});
let tagname = names.type_(&u.tag.name);
let read_variant = u.variants.iter().map(|v| {
let variantname = names.enum_variant(&v.name);
if let Some(tref) = &v.tref {
let varianttype = names.type_ref(tref, lifetime.clone());
quote! {
#tagname::#variantname => {
let variant_ptr = location.cast::<#varianttype>(#contents_offset).expect("union variant ptr validated");
let variant_val = wiggle_runtime::GuestTypeClone::read_from_guest(&variant_ptr)?;
Ok(#ident::#variantname(variant_val))
}
}
} else {
quote! { #tagname::#variantname => Ok(#ident::#variantname), }
}
});
let write_variant = u.variants.iter().map(|v| {
let variantname = names.enum_variant(&v.name);
let write_tag = quote! {
let tag_ptr = location.cast::<#tagname>(0).expect("union tag ptr TODO error report");
let mut tag_ref = tag_ptr.as_ref_mut().expect("union tag ref TODO error report");
*tag_ref = #tagname::#variantname;
};
if let Some(tref) = &v.tref {
let varianttype = names.type_ref(tref, lifetime.clone());
quote! {
#ident::#variantname(contents) => {
#write_tag
let variant_ptr = location.cast::<#varianttype>(#contents_offset).expect("union variant ptr validated");
contents.write_to_guest(&variant_ptr);
}
}
} else {
quote! {
#ident::#variantname => {
#write_tag
}
}
}
});
let validate = union_validate(names, ident.clone(), u, &ulayout);
if union_is_copy(u) {
// Type does not have a lifetime parameter:
quote! {
#[derive(Clone, Debug, PartialEq)]
pub enum #ident {
#(#variants),*
}
impl wiggle_runtime::GuestType for #ident {
fn size() -> u32 {
#size
}
fn align() -> u32 {
#align
}
fn name() -> String {
stringify!(#ident).to_owned()
}
fn validate(ptr: &wiggle_runtime::GuestPtr<#ident>) -> Result<(), wiggle_runtime::GuestError> {
#validate
}
}
impl<#lifetime> wiggle_runtime::GuestTypeClone<#lifetime> for #ident {
fn read_from_guest(location: &wiggle_runtime::GuestPtr<'a, #ident>)
-> Result<Self, wiggle_runtime::GuestError> {
wiggle_runtime::GuestType::validate(location)?;
let tag = *location.cast::<#tagname>(0).expect("validated tag ptr").as_ref().expect("validated tag ref");
match tag {
#(#read_variant)*
}
}
fn write_to_guest(&self, location: &wiggle_runtime::GuestPtrMut<'a, #ident>) {
match self {
#(#write_variant)*
}
}
}
}
} else {
quote! {
#[derive(Clone)]
pub enum #ident<#lifetime> {
#(#variants),*
}
impl<#lifetime> wiggle_runtime::GuestType for #ident<#lifetime> {
fn size() -> u32 {
#size
}
fn align() -> u32 {
#align
}
fn name() -> String {
stringify!(#ident).to_owned()
}
fn validate(ptr: &wiggle_runtime::GuestPtr<#ident>) -> Result<(), wiggle_runtime::GuestError> {
#validate
}
}
impl<#lifetime> wiggle_runtime::GuestTypeClone<#lifetime> for #ident<#lifetime> {
fn read_from_guest(location: &wiggle_runtime::GuestPtr<'a, #ident>)
-> Result<Self, wiggle_runtime::GuestError> {
wiggle_runtime::GuestType::validate(location)?;
let tag = *location.cast::<#tagname>(0).expect("validated tag ptr").as_ref().expect("validated tag ref");
match tag {
#(#read_variant)*
}
}
fn write_to_guest(&self, location: &wiggle_runtime::GuestPtrMut<'a, #ident>) {
match self {
#(#write_variant)*
}
}
}
}
}
}
fn define_witx_pointer(
names: &Names,
name: &witx::Id,

43
tests/union.rs Normal file
View File

@@ -0,0 +1,43 @@
use proptest::prelude::*;
use wiggle_runtime::{
GuestError, GuestErrorType, GuestMemory, GuestPtr, GuestPtrMut, GuestRef, GuestRefMut,
};
use wiggle_test::{impl_errno, HostMemory, MemArea, WasiCtx};
wiggle_generate::from_witx!({
witx: ["tests/union.witx"],
ctx: WasiCtx,
});
impl_errno!(types::Errno);
impl union_example::UnionExample for WasiCtx {
fn get_tag(&mut self, u: &types::Reason) -> Result<types::Excuse, types::Errno> {
println!("GET TAG: {:?}", u);
Ok(types::Excuse::DogAte)
}
fn reason_mult(
&mut self,
u: &types::ReasonMut<'_>,
multiply_by: u32,
) -> Result<(), types::Errno> {
match u {
types::ReasonMut::DogAte(fptr) => {
let mut f = fptr.as_ref_mut().expect("valid pointer");
let val = *f;
println!("REASON MULT DogAte({})", val);
*f = val * multiply_by as f32;
}
types::ReasonMut::Traffic(iptr) => {
let mut i = iptr.as_ref_mut().expect("valid pointer");
let val = *i;
println!("REASON MULT Traffic({})", val);
*i = val * multiply_by as i32;
}
types::ReasonMut::Sleeping => {
println!("REASON MULT Sleeping");
}
}
Ok(())
}
}

28
tests/union.witx Normal file
View File

@@ -0,0 +1,28 @@
(use "errno.witx")
(use "excuse.witx")
(typename $reason
(union $excuse
(field $dog_ate f32)
(field $traffic s32)
(empty $sleeping)))
(typename $reason_mut
(union $excuse
(field $dog_ate (@witx pointer f32))
(field $traffic (@witx pointer s32))
(empty $sleeping)))
(module $union_example
(@interface func (export "get_tag")
(param $r $reason)
(result $error $errno)
(result $t $excuse)
)
(@interface func (export "reason_mult")
(param $r $reason_mut)
(param $multiply_by u32)
(result $error $errno)
)
)