From 814dd19488d536ee3a3b7efea9c3804fdcd5feec Mon Sep 17 00:00:00 2001 From: Pat Hickey Date: Tue, 28 Jan 2020 18:17:48 -0800 Subject: [PATCH] structs that contain pointers work! --- crates/generate/src/funcs.rs | 18 ++++++ crates/generate/src/types.rs | 119 ++++++++++++++++++++++++++++++++++- crates/memory/src/memory.rs | 8 +++ src/lib.rs | 9 +++ test.witx | 9 +++ 5 files changed, 162 insertions(+), 1 deletion(-) diff --git a/crates/generate/src/funcs.rs b/crates/generate/src/funcs.rs index fe13723f2c..15f3a8fae5 100644 --- a/crates/generate/src/funcs.rs +++ b/crates/generate/src/funcs.rs @@ -223,6 +223,24 @@ fn marshal_arg( }; } } + witx::Type::Struct(s) if !struct_is_copy(&s) => { + let pointee_type = names.type_ref(tref); + let arg_name = names.func_ptr_binding(¶m.name); + let name = names.func_param(¶m.name); + quote! { + let #name = match memory.ptr_mut::<#pointee_type>(#arg_name as u32) { + Ok(p) => match p.read_ptr_from_guest() { + Ok(r) => r, + Err(e) => { + #error_handling + } + }, + Err(e) => { + #error_handling + } + }; + } + } _ => unimplemented!("argument type marshalling"), } } diff --git a/crates/generate/src/types.rs b/crates/generate/src/types.rs index 13338769e8..21dcaaa4bb 100644 --- a/crates/generate/src/types.rs +++ b/crates/generate/src/types.rs @@ -15,7 +15,7 @@ pub fn define_datatype(names: &Names, namedtype: &witx::NamedType) -> TokenStrea if struct_is_copy(s) { define_copy_struct(names, &namedtype.name, &s) } else { - unimplemented!("non-Copy struct") + define_ptr_struct(names, &namedtype.name, &s) } } witx::Type::Union(_) => unimplemented!("union types"), @@ -193,6 +193,123 @@ fn define_copy_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) } } +fn define_ptr_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) -> TokenStream { + let ident = names.type_(name); + let size = s.mem_size_align().size as u32; + let align = s.mem_size_align().align as u32; + + let member_names = s.members.iter().map(|m| names.struct_member(&m.name)); + let member_decls = s.members.iter().map(|m| { + let name = names.struct_member(&m.name); + let type_ = match &m.tref { + witx::TypeRef::Name(nt) => names.type_(&nt.name), + witx::TypeRef::Value(ty) => match &**ty { + witx::Type::Builtin(builtin) => names.builtin_type(*builtin), + witx::Type::Pointer(pointee) => { + let pointee_type = names.type_ref(&pointee); + quote!(::memory::GuestPtrMut<'a, #pointee_type>) + } + witx::Type::ConstPointer(pointee) => { + let pointee_type = names.type_ref(&pointee); + quote!(::memory::GuestPtr<'a, #pointee_type>) + } + _ => unimplemented!("other anonymous struct members"), + }, + }; + quote!(pub #name: #type_) + }); + let member_valids = s.member_layout().into_iter().map(|ml| { + let type_ = match &ml.member.tref { + witx::TypeRef::Name(nt) => names.type_(&nt.name), + witx::TypeRef::Value(ty) => match &**ty { + witx::Type::Builtin(builtin) => names.builtin_type(*builtin), + witx::Type::Pointer(pointee) => { + let pointee_type = names.type_ref(&pointee); + quote!(::memory::GuestPtrMut::<#pointee_type>) + } + witx::Type::ConstPointer(pointee) => { + let pointee_type = names.type_ref(&pointee); + quote!(::memory::GuestPtr::<#pointee_type>) + } + _ => unimplemented!("other anonymous struct members"), + }, + }; + let offset = ml.offset as u32; + quote!( #type_::validate(&ptr.cast(#offset)?)?; ) + }); + + let member_reads = s.member_layout().into_iter().map(|ml| { + let name = names.struct_member(&ml.member.name); + let offset = ml.offset as u32; + match &ml.member.tref { + witx::TypeRef::Name(nt) => { + let type_ = names.type_(&nt.name); + quote! { + let #name = #type_::read_from_guest(&location.cast(#offset)?)?; + } + } + witx::TypeRef::Value(ty) => match &**ty { + witx::Type::Builtin(builtin) => { + let type_ = names.builtin_type(*builtin); + quote! { + let #name = #type_::read_from_guest(&location.cast(#offset)?)?; + } + } + witx::Type::Pointer(pointee) => { + let pointee_type = names.type_ref(&pointee); + quote! { + let #name = ::memory::GuestPtrMut::<#pointee_type>::read_from_guest(&location.cast(#offset)?)?; + } + } + witx::Type::ConstPointer(pointee) => { + let pointee_type = names.type_ref(&pointee); + quote! { + let #name = ::memory::GuestPtr::<#pointee_type>::read_from_guest(&location.cast(#offset)?)?; + } + } + _ => unimplemented!("other anonymous struct members"), + }, + } + }); + + let member_writes = s.member_layout().into_iter().map(|ml| { + let name = names.struct_member(&ml.member.name); + let offset = ml.offset as u32; + quote!( self.#name.write_to_guest(&location.cast(#offset).expect("cast to inner member")); ) + }); + + quote! { + #[derive(Clone)] + pub struct #ident<'a> { + #(#member_decls),* + } + + impl<'a> ::memory::GuestType for #ident<'a> { + fn size() -> u32 { + #size + } + fn align() -> u32 { + #align + } + fn name() -> String { + stringify!(#ident).to_owned() + } + fn validate(ptr: &::memory::GuestPtr<#ident>) -> Result<(), ::memory::GuestError> { + #(#member_valids)* + Ok(()) + } + } + impl<'a> ::memory::GuestTypePtr<'a> for #ident<'a> { + fn read_from_guest(location: &::memory::GuestPtr<'a, #ident<'a>>) -> Result<#ident<'a>, ::memory::GuestError> { + #(#member_reads)* + Ok(#ident { #(#member_names),* }) + } + fn write_to_guest(&self, location: &::memory::GuestPtrMut<'a, Self>) { + #(#member_writes)* + } + } + } +} fn int_repr_tokens(int_repr: witx::IntRepr) -> TokenStream { match int_repr { witx::IntRepr::U8 => quote!(u8), diff --git a/crates/memory/src/memory.rs b/crates/memory/src/memory.rs index d33f4d707c..f301647c14 100644 --- a/crates/memory/src/memory.rs +++ b/crates/memory/src/memory.rs @@ -170,6 +170,14 @@ impl<'a, T: GuestType> GuestPtrMut<'a, T> { pub fn as_raw(&self) -> *const u8 { self.as_immut().as_raw() } + pub fn elem(&self, elements: i32) -> Result, GuestError> { + self.mem + .ptr_mut(self.region.start + (elements * self.region.len as i32) as u32) + } + + pub fn cast(&self, offset: u32) -> Result, GuestError> { + self.mem.ptr_mut(self.region.start + offset) + } } impl<'a, T: GuestTypeCopy> GuestPtrMut<'a, T> { diff --git a/src/lib.rs b/src/lib.rs index b1898b77c4..c28934faf9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -67,6 +67,15 @@ pub mod test { println!("sum of pair: {:?}", an_pair); Ok(an_pair.first as i64 + an_pair.second as i64) } + fn sum_of_pair_of_ptrs( + &mut self, + an_pair: &types::PairIntPtrs, + ) -> Result { + let first = *an_pair.first.as_ref().unwrap(); + let second = *an_pair.second.as_ref().unwrap(); + println!("sum of pair of ptrs: {} + {}", first, second); + Ok(first as i64 + second as i64) + } } // Errno is used as a first return value in the functions above, therefore // it must implement GuestErrorType with type Context = WasiCtx. diff --git a/test.witx b/test.witx index 87e11396d7..3fa2ae5ac7 100644 --- a/test.witx +++ b/test.witx @@ -19,6 +19,11 @@ (field $first s32) (field $second s32))) +(typename $pair_int_ptrs + (struct + (field $first (@witx const_pointer s32)) + (field $second (@witx const_pointer s32)))) + (module $foo (@interface func (export "bar") (param $an_int u32) @@ -38,4 +43,8 @@ (param $an_pair $pair_ints) (result $error $errno) (result $doubled s64)) + (@interface func (export "sum_of_pair_of_ptrs") + (param $an_pair $pair_int_ptrs) + (result $error $errno) + (result $doubled s64)) )