structs that contain pointers work!

This commit is contained in:
Pat Hickey
2020-01-28 18:17:48 -08:00
parent 35d9373976
commit 814dd19488
5 changed files with 162 additions and 1 deletions

View File

@@ -223,6 +223,24 @@ fn marshal_arg(
};
}
}
witx::Type::Struct(s) if !struct_is_copy(&s) => {
let pointee_type = names.type_ref(tref);
let arg_name = names.func_ptr_binding(&param.name);
let name = names.func_param(&param.name);
quote! {
let #name = match memory.ptr_mut::<#pointee_type>(#arg_name as u32) {
Ok(p) => match p.read_ptr_from_guest() {
Ok(r) => r,
Err(e) => {
#error_handling
}
},
Err(e) => {
#error_handling
}
};
}
}
_ => unimplemented!("argument type marshalling"),
}
}

View File

@@ -15,7 +15,7 @@ pub fn define_datatype(names: &Names, namedtype: &witx::NamedType) -> TokenStrea
if struct_is_copy(s) {
define_copy_struct(names, &namedtype.name, &s)
} else {
unimplemented!("non-Copy struct")
define_ptr_struct(names, &namedtype.name, &s)
}
}
witx::Type::Union(_) => unimplemented!("union types"),
@@ -193,6 +193,123 @@ fn define_copy_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype)
}
}
fn define_ptr_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) -> TokenStream {
let ident = names.type_(name);
let size = s.mem_size_align().size as u32;
let align = s.mem_size_align().align as u32;
let member_names = s.members.iter().map(|m| names.struct_member(&m.name));
let member_decls = s.members.iter().map(|m| {
let name = names.struct_member(&m.name);
let type_ = match &m.tref {
witx::TypeRef::Name(nt) => names.type_(&nt.name),
witx::TypeRef::Value(ty) => match &**ty {
witx::Type::Builtin(builtin) => names.builtin_type(*builtin),
witx::Type::Pointer(pointee) => {
let pointee_type = names.type_ref(&pointee);
quote!(::memory::GuestPtrMut<'a, #pointee_type>)
}
witx::Type::ConstPointer(pointee) => {
let pointee_type = names.type_ref(&pointee);
quote!(::memory::GuestPtr<'a, #pointee_type>)
}
_ => unimplemented!("other anonymous struct members"),
},
};
quote!(pub #name: #type_)
});
let member_valids = s.member_layout().into_iter().map(|ml| {
let type_ = match &ml.member.tref {
witx::TypeRef::Name(nt) => names.type_(&nt.name),
witx::TypeRef::Value(ty) => match &**ty {
witx::Type::Builtin(builtin) => names.builtin_type(*builtin),
witx::Type::Pointer(pointee) => {
let pointee_type = names.type_ref(&pointee);
quote!(::memory::GuestPtrMut::<#pointee_type>)
}
witx::Type::ConstPointer(pointee) => {
let pointee_type = names.type_ref(&pointee);
quote!(::memory::GuestPtr::<#pointee_type>)
}
_ => unimplemented!("other anonymous struct members"),
},
};
let offset = ml.offset as u32;
quote!( #type_::validate(&ptr.cast(#offset)?)?; )
});
let member_reads = s.member_layout().into_iter().map(|ml| {
let name = names.struct_member(&ml.member.name);
let offset = ml.offset as u32;
match &ml.member.tref {
witx::TypeRef::Name(nt) => {
let type_ = names.type_(&nt.name);
quote! {
let #name = #type_::read_from_guest(&location.cast(#offset)?)?;
}
}
witx::TypeRef::Value(ty) => match &**ty {
witx::Type::Builtin(builtin) => {
let type_ = names.builtin_type(*builtin);
quote! {
let #name = #type_::read_from_guest(&location.cast(#offset)?)?;
}
}
witx::Type::Pointer(pointee) => {
let pointee_type = names.type_ref(&pointee);
quote! {
let #name = ::memory::GuestPtrMut::<#pointee_type>::read_from_guest(&location.cast(#offset)?)?;
}
}
witx::Type::ConstPointer(pointee) => {
let pointee_type = names.type_ref(&pointee);
quote! {
let #name = ::memory::GuestPtr::<#pointee_type>::read_from_guest(&location.cast(#offset)?)?;
}
}
_ => unimplemented!("other anonymous struct members"),
},
}
});
let member_writes = s.member_layout().into_iter().map(|ml| {
let name = names.struct_member(&ml.member.name);
let offset = ml.offset as u32;
quote!( self.#name.write_to_guest(&location.cast(#offset).expect("cast to inner member")); )
});
quote! {
#[derive(Clone)]
pub struct #ident<'a> {
#(#member_decls),*
}
impl<'a> ::memory::GuestType for #ident<'a> {
fn size() -> u32 {
#size
}
fn align() -> u32 {
#align
}
fn name() -> String {
stringify!(#ident).to_owned()
}
fn validate(ptr: &::memory::GuestPtr<#ident>) -> Result<(), ::memory::GuestError> {
#(#member_valids)*
Ok(())
}
}
impl<'a> ::memory::GuestTypePtr<'a> for #ident<'a> {
fn read_from_guest(location: &::memory::GuestPtr<'a, #ident<'a>>) -> Result<#ident<'a>, ::memory::GuestError> {
#(#member_reads)*
Ok(#ident { #(#member_names),* })
}
fn write_to_guest(&self, location: &::memory::GuestPtrMut<'a, Self>) {
#(#member_writes)*
}
}
}
}
fn int_repr_tokens(int_repr: witx::IntRepr) -> TokenStream {
match int_repr {
witx::IntRepr::U8 => quote!(u8),

View File

@@ -170,6 +170,14 @@ impl<'a, T: GuestType> GuestPtrMut<'a, T> {
pub fn as_raw(&self) -> *const u8 {
self.as_immut().as_raw()
}
pub fn elem(&self, elements: i32) -> Result<GuestPtrMut<'a, T>, GuestError> {
self.mem
.ptr_mut(self.region.start + (elements * self.region.len as i32) as u32)
}
pub fn cast<TT: GuestType>(&self, offset: u32) -> Result<GuestPtrMut<'a, TT>, GuestError> {
self.mem.ptr_mut(self.region.start + offset)
}
}
impl<'a, T: GuestTypeCopy> GuestPtrMut<'a, T> {

View File

@@ -67,6 +67,15 @@ pub mod test {
println!("sum of pair: {:?}", an_pair);
Ok(an_pair.first as i64 + an_pair.second as i64)
}
fn sum_of_pair_of_ptrs(
&mut self,
an_pair: &types::PairIntPtrs,
) -> Result<i64, types::Errno> {
let first = *an_pair.first.as_ref().unwrap();
let second = *an_pair.second.as_ref().unwrap();
println!("sum of pair of ptrs: {} + {}", first, second);
Ok(first as i64 + second as i64)
}
}
// Errno is used as a first return value in the functions above, therefore
// it must implement GuestErrorType with type Context = WasiCtx.

View File

@@ -19,6 +19,11 @@
(field $first s32)
(field $second s32)))
(typename $pair_int_ptrs
(struct
(field $first (@witx const_pointer s32))
(field $second (@witx const_pointer s32))))
(module $foo
(@interface func (export "bar")
(param $an_int u32)
@@ -38,4 +43,8 @@
(param $an_pair $pair_ints)
(result $error $errno)
(result $doubled s64))
(@interface func (export "sum_of_pair_of_ptrs")
(param $an_pair $pair_int_ptrs)
(result $error $errno)
(result $doubled s64))
)