Rewrite for recursive safety

This commit rewrites the runtime crate to provide safety in the face
of recursive calls to the guest. The basic principle is that
`GuestMemory` is now a trait which dynamically returns the
pointer/length pair. This also has an implicit contract (hence the
`unsafe` trait) that the pointer/length pair point to a valid list of
bytes in host memory "until something is reentrant".

After this changes the various suite of `Guest*` types were rewritten.
`GuestRef` and `GuestRefMut` were both removed since they cannot safely
exist. The `GuestPtrMut` type was removed for simplicity, and the final
`GuestPtr` type subsumes `GuestString` and `GuestArray`. This means
that there's only one guest pointer type, `GuestPtr<'a, T>`, where `'a`
is the borrow into host memory, basically borrowing the `GuestMemory`
trait object itself.

Some core utilities are exposed on `GuestPtr`, but they're all 100%
safe. Unsafety is now entirely contained within a few small locations:

* Implementations of the `GuestType` for primitive types (e.g. `i8`,
  `u8`, etc) use `unsafe` to read/write memory. The `unsafe` trait of
  `GuestMemory` though should prove that they're safe.

* `GuestPtr<'_, str>` has a method which validates utf-8 contents, and
  this requires `unsafe` internally to read all the bytes. This is
  guaranteed to be safe however given the contract of `GuestMemory`.

And that's it! Everything else is a bunch of safe combinators all built
up on the various utilities provided by `GuestPtr`. The general idioms
are roughly the same as before, with various tweaks here and there. A
summary of expected idioms are:

* For small values you'd `.read()` or `.write()` very quickly. You'd
  pass around the type itself.

* For strings, you'd pass `GuestPtr<'_, str>` down to the point where
  it's actually consumed. At that moment you'd either decide to copy it
  out (a safe operation) or you'd get a raw view to the string (an
  unsafe operation) and assert that you won't call back into wasm while
  you're holding that pointer.

* Arrays are similar to strings, passing around `GuestPtr<'_, [T]>`.
  Arrays also have a `iter()` method which yields an iterator of
  `GuestPtr<'_, T>` for convenience.

Overall there's still a lot of missing documentation on the runtime
crate specifically around the safety of the `GuestMemory` trait as well
as how the utilities/methods are expected to be used. Additionally
there's utilities which aren't currently implemented which would be easy
to implement. For example there's no method to copy out a string or a
slice, although that would be pretty easy to add.

In any case I'm curious to get feedback on this approach and see what
y'all think!
This commit is contained in:
Alex Crichton
2020-03-04 10:21:34 -08:00
parent 3764204250
commit ca9f33b6d9
28 changed files with 751 additions and 2013 deletions

View File

@@ -1,7 +1,7 @@
use proc_macro2::TokenStream;
use quote::quote;
use crate::lifetimes::{anon_lifetime, LifetimeExt};
use crate::lifetimes::anon_lifetime;
use crate::names::Names;
pub fn define_func(names: &Names, func: &witx::InterfaceFunc) -> TokenStream {
@@ -30,7 +30,7 @@ pub fn define_func(names: &Names, func: &witx::InterfaceFunc) -> TokenStream {
});
let abi_args = quote!(
ctx: &mut #ctx_type, memory: &mut wiggle_runtime::GuestMemory,
ctx: &#ctx_type, memory: &dyn wiggle_runtime::GuestMemory,
#(#params),*
);
let abi_ret = if let Some(ret) = &coretype.ret {
@@ -158,13 +158,8 @@ fn marshal_arg(
let arg_name = names.func_ptr_binding(&param.name);
let name = names.func_param(&param.name);
quote! {
let #name = match memory.ptr::<#pointee_type>(#arg_name as u32) {
Ok(p) => match p.read() {
Ok(r) => r,
Err(e) => {
#error_handling
}
},
let #name = match wiggle_runtime::GuestPtr::<#pointee_type>::new(memory, #arg_name as u32).read() {
Ok(r) => r,
Err(e) => {
#error_handling
}
@@ -209,102 +204,25 @@ fn marshal_arg(
let len_name = names.func_len_binding(&param.name);
let name = names.func_param(&param.name);
quote! {
let num_elems = match memory.ptr::<u32>(#len_name as u32) {
Ok(p) => match p.as_ref() {
Ok(r) => r,
Err(e) => {
#error_handling
}
}
Err(e) => {
#error_handling
}
};
let #name: wiggle_runtime::GuestString<#lifetime> = match memory.ptr::<u8>(#ptr_name as u32) {
Ok(p) => match p.array(*num_elems) {
Ok(s) => s.into(),
Err(e) => {
#error_handling
}
}
Err(e) => {
#error_handling
}
};
let #name = wiggle_runtime::GuestPtr::<#lifetime, str>::new(memory, (#ptr_name as u32, #len_name as u32));
}
}
},
witx::Type::Pointer(pointee) => {
witx::Type::Pointer(pointee) | witx::Type::ConstPointer(pointee) => {
let pointee_type = names.type_ref(pointee, anon_lifetime());
let name = names.func_param(&param.name);
quote! {
let #name = match memory.ptr_mut::<#pointee_type>(#name as u32) {
Ok(p) => p,
Err(e) => {
#error_handling
}
};
let #name = wiggle_runtime::GuestPtr::<#pointee_type>::new(memory, #name as u32);
}
}
witx::Type::ConstPointer(pointee) => {
let pointee_type = names.type_ref(pointee, anon_lifetime());
let name = names.func_param(&param.name);
quote! {
let #name = match memory.ptr::<#pointee_type>(#name as u32) {
Ok(p) => p,
Err(e) => {
#error_handling
}
};
}
}
witx::Type::Struct(s) if !s.needs_lifetime() => {
let pointee_type = names.type_ref(tref, anon_lifetime());
let arg_name = names.func_ptr_binding(&param.name);
let name = names.func_param(&param.name);
quote! {
let #name = match memory.ptr::<#pointee_type>(#arg_name as u32) {
Ok(p) => match p.as_ref() {
Ok(r) => r,
Err(e) => {
#error_handling
}
},
Err(e) => {
#error_handling
}
};
}
}
witx::Type::Struct(s) if s.needs_lifetime() => read_conversion,
witx::Type::Struct(_) => read_conversion,
witx::Type::Array(arr) => {
let pointee_type = names.type_ref(arr, anon_lifetime());
let ptr_name = names.func_ptr_binding(&param.name);
let len_name = names.func_len_binding(&param.name);
let name = names.func_param(&param.name);
quote! {
let num_elems = match memory.ptr::<u32>(#len_name as u32) {
Ok(p) => match p.as_ref() {
Ok(r) => r,
Err(e) => {
#error_handling
}
}
Err(e) => {
#error_handling
}
};
let #name = match memory.ptr::<#pointee_type>(#ptr_name as u32) {
Ok(p) => match p.array(*num_elems) {
Ok(s) => s,
Err(e) => {
#error_handling
}
}
Err(e) => {
#error_handling
}
};
let #name = wiggle_runtime::GuestPtr::<[#pointee_type]>::new(memory, (#ptr_name as u32, #len_name as u32));
}
}
witx::Type::Union(_u) => read_conversion,
@@ -313,7 +231,6 @@ fn marshal_arg(
let handle_type = names.type_ref(tref, anon_lifetime());
quote!( let #name = #handle_type::from(#name); )
}
_ => unimplemented!("argument type marshalling"),
}
}
@@ -333,17 +250,14 @@ where
let ptr_name = names.func_ptr_binding(&result.name);
let ptr_err_handling = error_handling(&format!("{}:result_ptr_mut", result.name.as_str()));
let pre = quote! {
let mut #ptr_name = match memory.ptr_mut::<#pointee_type>(#ptr_name as u32) {
Ok(p) => p,
Err(e) => {
#ptr_err_handling
}
};
let #ptr_name = wiggle_runtime::GuestPtr::<#pointee_type>::new(memory, #ptr_name as u32);
};
// trait binding returns func_param name.
let val_name = names.func_param(&result.name);
let post = quote! {
#ptr_name.write(&#val_name);
if let Err(e) = #ptr_name.write(#val_name) {
#ptr_err_handling
}
};
(pre, post)
};

View File

@@ -11,7 +11,12 @@ pub fn define_module_trait(names: &Names, m: &Module) -> TokenStream {
// Check if we're returning an entity anotated with a lifetime,
// in which case, we'll need to annotate the function itself, and
// hence will need an explicit lifetime (rather than anonymous)
let (lifetime, is_anonymous) = if f.results.iter().any(|ret| ret.tref.needs_lifetime()) {
let (lifetime, is_anonymous) = if f
.params
.iter()
.chain(&f.results)
.any(|ret| ret.tref.needs_lifetime())
{
(quote!('a), false)
} else {
(anon_lifetime(), true)

View File

@@ -24,7 +24,7 @@ impl Names {
}
pub fn builtin_type(&self, b: BuiltinType, lifetime: TokenStream) -> TokenStream {
match b {
BuiltinType::String => quote!(wiggle_runtime::GuestString<#lifetime>),
BuiltinType::String => quote!(wiggle_runtime::GuestPtr<#lifetime, str>),
BuiltinType::U8 => quote!(u8),
BuiltinType::U16 => quote!(u16),
BuiltinType::U32 => quote!(u32),
@@ -60,11 +60,7 @@ impl Names {
}
TypeRef::Value(ty) => match &**ty {
witx::Type::Builtin(builtin) => self.builtin_type(*builtin, lifetime.clone()),
witx::Type::Pointer(pointee) => {
let pointee_type = self.type_ref(&pointee, lifetime.clone());
quote!(wiggle_runtime::GuestPtrMut<#lifetime, #pointee_type>)
}
witx::Type::ConstPointer(pointee) => {
witx::Type::Pointer(pointee) | witx::Type::ConstPointer(pointee) => {
let pointee_type = self.type_ref(&pointee, lifetime.clone());
quote!(wiggle_runtime::GuestPtr<#lifetime, #pointee_type>)
}

View File

@@ -77,37 +77,25 @@ pub(super) fn define_enum(names: &Names, name: &witx::Id, e: &witx::EnumDatatype
}
impl<'a> wiggle_runtime::GuestType<'a> for #ident {
fn size() -> u32 {
::std::mem::size_of::<#repr>() as u32
fn guest_size() -> u32 {
#repr::guest_size()
}
fn align() -> u32 {
::std::mem::align_of::<#repr>() as u32
}
fn name() -> String {
stringify!(#ident).to_owned()
}
fn validate(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<(), wiggle_runtime::GuestError> {
use ::std::convert::TryFrom;
let raw: #repr = unsafe { (location.as_raw() as *const #repr).read() };
let _ = #ident::try_from(raw)?;
Ok(())
fn guest_align() -> usize {
#repr::guest_align()
}
fn read(location: &wiggle_runtime::GuestPtr<#ident>) -> Result<#ident, wiggle_runtime::GuestError> {
// Perform validation as part of as_ref:
let r = location.as_ref()?;
Ok(*r)
use std::convert::TryFrom;
let val = #repr::read(&location.cast())?;
#ident::try_from(val)
}
fn write(&self, location: &wiggle_runtime::GuestPtrMut<#ident>) {
let val: #repr = #repr::from(*self);
unsafe { (location.as_raw() as *mut #repr).write(val) };
fn write(location: &wiggle_runtime::GuestPtr<'_, #ident>, val: Self)
-> Result<(), wiggle_runtime::GuestError>
{
#repr::write(&location.cast(), #repr::from(val))
}
}
impl<'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident {}
}
}

View File

@@ -126,35 +126,24 @@ pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDataty
}
impl<'a> wiggle_runtime::GuestType<'a> for #ident {
fn size() -> u32 {
::std::mem::size_of::<#repr>() as u32
fn guest_size() -> u32 {
#repr::guest_size()
}
fn align() -> u32 {
::std::mem::align_of::<#repr>() as u32
fn guest_align() -> usize {
#repr::guest_align()
}
fn name() -> String {
stringify!(#ident).to_owned()
fn read(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<#ident, wiggle_runtime::GuestError> {
use std::convert::TryFrom;
let bits = #repr::read(&location.cast())?;
#ident::try_from(bits)
}
fn validate(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<(), wiggle_runtime::GuestError> {
use ::std::convert::TryFrom;
let raw: #repr = unsafe { (location.as_raw() as *const #repr).read() };
let _ = #ident::try_from(raw)?;
Ok(())
}
fn read(location: &wiggle_runtime::GuestPtr<#ident>) -> Result<#ident, wiggle_runtime::GuestError> {
Ok(*location.as_ref()?)
}
fn write(&self, location: &wiggle_runtime::GuestPtrMut<#ident>) {
let val: #repr = #repr::from(*self);
unsafe { (location.as_raw() as *mut #repr).write(val) };
fn write(location: &wiggle_runtime::GuestPtr<'_, #ident>, val: Self) -> Result<(), wiggle_runtime::GuestError> {
let val: #repr = #repr::from(val);
#repr::write(&location.cast(), val)
}
}
impl<'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident {}
}
}

View File

@@ -11,7 +11,7 @@ pub(super) fn define_handle(
) -> TokenStream {
let ident = names.type_(name);
let size = h.mem_size_align().size as u32;
let align = h.mem_size_align().align as u32;
let align = h.mem_size_align().align as usize;
quote! {
#[derive(Copy, Clone, Debug, ::std::hash::Hash, Eq, PartialEq)]
pub struct #ident(u32);
@@ -46,32 +46,21 @@ pub(super) fn define_handle(
}
impl<'a> wiggle_runtime::GuestType<'a> for #ident {
fn size() -> u32 {
fn guest_size() -> u32 {
#size
}
fn align() -> u32 {
fn guest_align() -> usize {
#align
}
fn name() -> String {
stringify!(#ident).to_owned()
}
fn validate(ptr: &wiggle_runtime::GuestPtr<#ident>) -> Result<(), wiggle_runtime::GuestError> {
Ok(())
}
fn read(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<#ident, wiggle_runtime::GuestError> {
let r = location.as_ref()?;
Ok(*r)
Ok(#ident(u32::read(&location.cast())?))
}
fn write(&self, location: &wiggle_runtime::GuestPtrMut<'a, Self>) {
unsafe { (location.as_raw() as *mut #ident).write(*self) };
fn write(location: &wiggle_runtime::GuestPtr<'_, Self>, val: Self) -> Result<(), wiggle_runtime::GuestError> {
u32::write(&location.cast(), val.0)
}
}
impl<'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident {}
}
}

View File

@@ -63,35 +63,21 @@ pub(super) fn define_int(names: &Names, name: &witx::Id, i: &witx::IntDatatype)
}
impl<'a> wiggle_runtime::GuestType<'a> for #ident {
fn size() -> u32 {
::std::mem::size_of::<#repr>() as u32
fn guest_size() -> u32 {
#repr::guest_size()
}
fn align() -> u32 {
::std::mem::align_of::<#repr>() as u32
fn guest_align() -> usize {
#repr::guest_align()
}
fn name() -> String {
stringify!(#ident).to_owned()
fn read(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<#ident, wiggle_runtime::GuestError> {
Ok(#ident(#repr::read(&location.cast())?))
}
fn validate(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<(), wiggle_runtime::GuestError> {
use ::std::convert::TryFrom;
let raw: #repr = unsafe { (location.as_raw() as *const #repr).read() };
let _ = #ident::try_from(raw)?;
Ok(())
}
fn read(location: &wiggle_runtime::GuestPtr<#ident>) -> Result<#ident, wiggle_runtime::GuestError> {
Ok(*location.as_ref()?)
}
fn write(&self, location: &wiggle_runtime::GuestPtrMut<#ident>) {
let val: #repr = #repr::from(*self);
unsafe { (location.as_raw() as *mut #repr).write(val) };
fn write(location: &wiggle_runtime::GuestPtr<'_, #ident>, val: Self) -> Result<(), wiggle_runtime::GuestError> {
#repr::write(&location.cast(), val.0)
}
}
impl<'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident {}
}
}

View File

@@ -25,7 +25,7 @@ pub fn define_datatype(names: &Names, namedtype: &witx::NamedType) -> TokenStrea
witx::Type::Pointer(p) => define_witx_pointer(
names,
&namedtype.name,
quote!(wiggle_runtime::GuestPtrMut),
quote!(wiggle_runtime::GuestPtr),
p,
),
witx::Type::ConstPointer(p) => {
@@ -71,7 +71,7 @@ fn define_witx_pointer(
fn define_witx_array(names: &Names, name: &witx::Id, arr_raw: &witx::TypeRef) -> TokenStream {
let ident = names.type_(name);
let pointee_type = names.type_ref(arr_raw, quote!('a));
quote!(pub type #ident<'a> = wiggle_runtime::GuestArray<'a, #pointee_type>;)
quote!(pub type #ident<'a> = wiggle_runtime::GuestPtr<'a, [#pointee_type]>;)
}
fn int_repr_tokens(int_repr: witx::IntRepr) -> TokenStream {

View File

@@ -10,87 +10,9 @@ pub(super) fn define_struct(
name: &witx::Id,
s: &witx::StructDatatype,
) -> TokenStream {
if !s.needs_lifetime() {
define_copy_struct(names, name, s)
} else {
define_ptr_struct(names, name, s)
}
}
fn define_copy_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) -> TokenStream {
let ident = names.type_(name);
let size = s.mem_size_align().size as u32;
let align = s.mem_size_align().align as u32;
let member_decls = s.members.iter().map(|m| {
let name = names.struct_member(&m.name);
let type_ = names.type_ref(&m.tref, anon_lifetime());
quote!(pub #name: #type_)
});
let member_valids = s.member_layout().into_iter().map(|ml| {
let type_ = names.type_ref(&ml.member.tref, anon_lifetime());
let offset = ml.offset as u32;
let fieldname = names.struct_member(&ml.member.name);
quote! {
#type_::validate(
&ptr.cast(#offset).map_err(|e|
wiggle_runtime::GuestError::InDataField{
typename: stringify!(#ident).to_owned(),
field: stringify!(#fieldname).to_owned(),
err: Box::new(e),
})?
).map_err(|e|
wiggle_runtime::GuestError::InDataField {
typename: stringify!(#ident).to_owned(),
field: stringify!(#fieldname).to_owned(),
err: Box::new(e),
})?;
}
});
quote! {
#[repr(C)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct #ident {
#(#member_decls),*
}
impl<'a> wiggle_runtime::GuestType<'a> for #ident {
fn size() -> u32 {
#size
}
fn align() -> u32 {
#align
}
fn name() -> String {
stringify!(#ident).to_owned()
}
fn validate(ptr: &wiggle_runtime::GuestPtr<#ident>) -> Result<(), wiggle_runtime::GuestError> {
#(#member_valids)*
Ok(())
}
fn read(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<#ident, wiggle_runtime::GuestError> {
let r = location.as_ref()?;
Ok(*r)
}
fn write(&self, location: &wiggle_runtime::GuestPtrMut<'a, Self>) {
unsafe { (location.as_raw() as *mut #ident).write(*self) };
}
}
impl<'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident {}
}
}
fn define_ptr_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) -> TokenStream {
let ident = names.type_(name);
let size = s.mem_size_align().size as u32;
let align = s.mem_size_align().align as u32;
let align = s.mem_size_align().align as usize;
let member_names = s.members.iter().map(|m| names.struct_member(&m.name));
let member_decls = s.members.iter().map(|m| {
@@ -99,11 +21,7 @@ fn define_ptr_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) -
witx::TypeRef::Name(nt) => names.type_(&nt.name),
witx::TypeRef::Value(ty) => match &**ty {
witx::Type::Builtin(builtin) => names.builtin_type(*builtin, quote!('a)),
witx::Type::Pointer(pointee) => {
let pointee_type = names.type_ref(&pointee, quote!('a));
quote!(wiggle_runtime::GuestPtrMut<'a, #pointee_type>)
}
witx::Type::ConstPointer(pointee) => {
witx::Type::Pointer(pointee) | witx::Type::ConstPointer(pointee) => {
let pointee_type = names.type_ref(&pointee, quote!('a));
quote!(wiggle_runtime::GuestPtr<'a, #pointee_type>)
}
@@ -112,68 +30,29 @@ fn define_ptr_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) -
};
quote!(pub #name: #type_)
});
let member_valids = s.member_layout().into_iter().map(|ml| {
let type_ = match &ml.member.tref {
witx::TypeRef::Name(nt) => names.type_(&nt.name),
witx::TypeRef::Value(ty) => match &**ty {
witx::Type::Builtin(builtin) => names.builtin_type(*builtin, quote!('a)),
witx::Type::Pointer(pointee) => {
let pointee_type = names.type_ref(&pointee, anon_lifetime());
quote!(wiggle_runtime::GuestPtrMut::<#pointee_type>)
}
witx::Type::ConstPointer(pointee) => {
let pointee_type = names.type_ref(&pointee, anon_lifetime());
quote!(wiggle_runtime::GuestPtr::<#pointee_type>)
}
_ => unimplemented!("other anonymous struct members"),
},
};
let offset = ml.offset as u32;
let fieldname = names.struct_member(&ml.member.name);
quote! {
#type_::validate(
&ptr.cast(#offset).map_err(|e|
wiggle_runtime::GuestError::InDataField{
typename: stringify!(#ident).to_owned(),
field: stringify!(#fieldname).to_owned(),
err: Box::new(e),
})?
).map_err(|e|
wiggle_runtime::GuestError::InDataField {
typename: stringify!(#ident).to_owned(),
field: stringify!(#fieldname).to_owned(),
err: Box::new(e),
})?;
}
});
let member_reads = s.member_layout().into_iter().map(|ml| {
let name = names.struct_member(&ml.member.name);
let offset = ml.offset as u32;
let location = quote!(location.cast::<u8>().add(#offset)?.cast());
match &ml.member.tref {
witx::TypeRef::Name(nt) => {
let type_ = names.type_(&nt.name);
quote! {
let #name = <#type_ as wiggle_runtime::GuestType>::read(&location.cast(#offset)?)?;
let #name = <#type_ as wiggle_runtime::GuestType>::read(&#location)?;
}
}
witx::TypeRef::Value(ty) => match &**ty {
witx::Type::Builtin(builtin) => {
let type_ = names.builtin_type(*builtin, anon_lifetime());
quote! {
let #name = <#type_ as wiggle_runtime::GuestType>::read(&location.cast(#offset)?)?;
let #name = <#type_ as wiggle_runtime::GuestType>::read(&#location)?;
}
}
witx::Type::Pointer(pointee) => {
witx::Type::Pointer(pointee) | witx::Type::ConstPointer(pointee) => {
let pointee_type = names.type_ref(&pointee, anon_lifetime());
quote! {
let #name = <wiggle_runtime::GuestPtrMut::<#pointee_type> as wiggle_runtime::GuestType>::read(&location.cast(#offset)?)?;
}
}
witx::Type::ConstPointer(pointee) => {
let pointee_type = names.type_ref(&pointee, anon_lifetime());
quote! {
let #name = <wiggle_runtime::GuestPtr::<#pointee_type> as wiggle_runtime::GuestType>::read(&location.cast(#offset)?)?;
let #name = <wiggle_runtime::GuestPtr::<#pointee_type> as wiggle_runtime::GuestType>::read(&#location)?;
}
}
_ => unimplemented!("other anonymous struct members"),
@@ -185,41 +64,42 @@ fn define_ptr_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) -
let name = names.struct_member(&ml.member.name);
let offset = ml.offset as u32;
quote! {
wiggle_runtime::GuestType::write(&self.#name, &location.cast(#offset).expect("cast to inner member"));
wiggle_runtime::GuestType::write(
&location.cast::<u8>().add(#offset)?.cast(),
val.#name,
)?;
}
});
let (struct_lifetime, extra_derive) = if s.needs_lifetime() {
(quote!(<'a>), quote!())
} else {
(quote!(), quote!(, Copy, PartialEq))
};
quote! {
#[derive(Clone, Debug)]
pub struct #ident<'a> {
#[derive(Clone, Debug #extra_derive)]
pub struct #ident #struct_lifetime {
#(#member_decls),*
}
impl<'a> wiggle_runtime::GuestType<'a> for #ident<'a> {
fn size() -> u32 {
impl<'a> wiggle_runtime::GuestType<'a> for #ident #struct_lifetime {
fn guest_size() -> u32 {
#size
}
fn align() -> u32 {
fn guest_align() -> usize {
#align
}
fn name() -> String {
stringify!(#ident).to_owned()
}
fn validate(ptr: &wiggle_runtime::GuestPtr<'a, #ident<'a>>) -> Result<(), wiggle_runtime::GuestError> {
#(#member_valids)*
Ok(())
}
fn read(location: &wiggle_runtime::GuestPtr<'a, #ident<'a>>) -> Result<#ident<'a>, wiggle_runtime::GuestError> {
fn read(location: &wiggle_runtime::GuestPtr<'a, Self>) -> Result<Self, wiggle_runtime::GuestError> {
#(#member_reads)*
Ok(#ident { #(#member_names),* })
}
fn write(&self, location: &wiggle_runtime::GuestPtrMut<'a, Self>) {
fn write(location: &wiggle_runtime::GuestPtr<'_, Self>, val: Self) -> Result<(), wiggle_runtime::GuestError> {
#(#member_writes)*
Ok(())
}
}
}

View File

@@ -1,4 +1,4 @@
use crate::lifetimes::{anon_lifetime, LifetimeExt};
use crate::lifetimes::LifetimeExt;
use crate::names::Names;
use proc_macro2::TokenStream;
@@ -8,7 +8,7 @@ use witx::Layout;
pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDatatype) -> TokenStream {
let ident = names.type_(name);
let size = u.mem_size_align().size as u32;
let align = u.mem_size_align().align as u32;
let align = u.mem_size_align().align as usize;
let ulayout = u.union_layout();
let contents_offset = ulayout.contents_offset as u32;
@@ -32,8 +32,8 @@ pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDataty
let varianttype = names.type_ref(tref, lifetime.clone());
quote! {
#tagname::#variantname => {
let variant_ptr = location.cast::<#varianttype>(#contents_offset).expect("union variant ptr validated");
let variant_val = <#varianttype as wiggle_runtime::GuestType>::read(&variant_ptr)?;
let variant_ptr = location.cast::<u8>().add(#contents_offset)?;
let variant_val = <#varianttype as wiggle_runtime::GuestType>::read(&variant_ptr.cast())?;
Ok(#ident::#variantname(variant_val))
}
}
@@ -45,17 +45,15 @@ pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDataty
let write_variant = u.variants.iter().map(|v| {
let variantname = names.enum_variant(&v.name);
let write_tag = quote! {
let tag_ptr = location.cast::<#tagname>(0).expect("union tag ptr TODO error report");
let mut tag_ref = tag_ptr.as_ref_mut().expect("union tag ref TODO error report");
*tag_ref = #tagname::#variantname;
location.cast().write(#tagname::#variantname)?;
};
if let Some(tref) = &v.tref {
let varianttype = names.type_ref(tref, lifetime.clone());
quote! {
#ident::#variantname(contents) => {
#write_tag
let variant_ptr = location.cast::<#varianttype>(#contents_offset).expect("union variant ptr validated");
<#varianttype as wiggle_runtime::GuestType>::write(&contents, &variant_ptr);
let variant_ptr = location.cast::<u8>().add(#contents_offset)?;
<#varianttype as wiggle_runtime::GuestType>::write(&variant_ptr.cast(), contents)?;
}
}
} else {
@@ -66,134 +64,46 @@ pub(super) fn define_union(names: &Names, name: &witx::Id, u: &witx::UnionDataty
}
}
});
let validate = union_validate(names, ident.clone(), u, &ulayout);
if !u.needs_lifetime() {
// Type does not have a lifetime parameter:
quote! {
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum #ident {
#(#variants),*
}
impl<'a> wiggle_runtime::GuestType<'a> for #ident {
fn size() -> u32 {
#size
}
fn align() -> u32 {
#align
}
fn name() -> String {
stringify!(#ident).to_owned()
}
fn validate(ptr: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<(), wiggle_runtime::GuestError> {
#validate
}
fn read(location: &wiggle_runtime::GuestPtr<'a, #ident>)
-> Result<Self, wiggle_runtime::GuestError> {
<#ident as wiggle_runtime::GuestType>::validate(location)?;
let tag = *location.cast::<#tagname>(0).expect("validated tag ptr").as_ref().expect("validated tag ref");
match tag {
#(#read_variant)*
}
}
fn write(&self, location: &wiggle_runtime::GuestPtrMut<'a, #ident>) {
match self {
#(#write_variant)*
}
}
}
}
let (enum_lifetime, extra_derive) = if u.needs_lifetime() {
(quote!(<'a>), quote!())
} else {
quote! {
#[derive(Clone, Debug)]
pub enum #ident<#lifetime> {
#(#variants),*
}
impl<#lifetime> wiggle_runtime::GuestType<#lifetime> for #ident<#lifetime> {
fn size() -> u32 {
#size
}
fn align() -> u32 {
#align
}
fn name() -> String {
stringify!(#ident).to_owned()
}
fn validate(ptr: &wiggle_runtime::GuestPtr<#lifetime, #ident<#lifetime>>) -> Result<(), wiggle_runtime::GuestError> {
#validate
}
fn read(location: &wiggle_runtime::GuestPtr<#lifetime, #ident<#lifetime>>)
-> Result<Self, wiggle_runtime::GuestError> {
<#ident as wiggle_runtime::GuestType>::validate(location)?;
let tag = *location.cast::<#tagname>(0).expect("validated tag ptr").as_ref().expect("validated tag ref");
match tag {
#(#read_variant)*
}
}
fn write(&self, location: &wiggle_runtime::GuestPtrMut<#lifetime, #ident<#lifetime>>) {
match self {
#(#write_variant)*
}
}
}
}
}
}
fn union_validate(
names: &Names,
typename: TokenStream,
u: &witx::UnionDatatype,
ulayout: &witx::UnionLayout,
) -> TokenStream {
let tagname = names.type_(&u.tag.name);
let contents_offset = ulayout.contents_offset as u32;
let with_err = |f: &str| -> TokenStream {
quote!(|e| wiggle_runtime::GuestError::InDataField {
typename: stringify!(#typename).to_owned(),
field: #f.to_owned(),
err: Box::new(e),
})
(quote!(), quote!(, Copy, PartialEq))
};
let tag_err = with_err("<tag>");
let variant_validation = u.variants.iter().map(|v| {
let err = with_err(v.name.as_str());
let variantname = names.enum_variant(&v.name);
if let Some(tref) = &v.tref {
let lifetime = anon_lifetime();
let varianttype = names.type_ref(tref, lifetime.clone());
quote! {
#tagname::#variantname => {
let variant_ptr = ptr.cast::<#varianttype>(#contents_offset).map_err(#err)?;
<#varianttype as wiggle_runtime::GuestType>::validate(&variant_ptr).map_err(#err)?;
}
}
} else {
quote! { #tagname::#variantname => {} }
}
});
quote! {
let tag = *ptr.cast::<#tagname>(0).map_err(#tag_err)?.as_ref().map_err(#tag_err)?;
match tag {
#(#variant_validation)*
#[derive(Clone, Debug #extra_derive)]
pub enum #ident #enum_lifetime {
#(#variants),*
}
impl<'a> wiggle_runtime::GuestType<'a> for #ident #enum_lifetime {
fn guest_size() -> u32 {
#size
}
fn guest_align() -> usize {
#align
}
fn read(location: &wiggle_runtime::GuestPtr<'a, Self>)
-> Result<Self, wiggle_runtime::GuestError>
{
let tag = location.cast().read()?;
match tag {
#(#read_variant)*
}
}
fn write(location: &wiggle_runtime::GuestPtr<'_, Self>, val: Self)
-> Result<(), wiggle_runtime::GuestError>
{
match val {
#(#write_variant)*
}
Ok(())
}
}
Ok(())
}
}