diff --git a/crates/generate/src/funcs.rs b/crates/generate/src/funcs.rs index 861de4c258..2cf4507442 100644 --- a/crates/generate/src/funcs.rs +++ b/crates/generate/src/funcs.rs @@ -141,11 +141,13 @@ fn marshal_arg( let try_into_conversion = { let name = names.func_param(¶m.name); quote! { - use ::std::convert::TryInto; - let #name: #interface_typename = match #name.try_into() { - Ok(a) => a, - Err(e) => { - #error_handling + let #name: #interface_typename = { + use ::std::convert::TryInto; + match #name.try_into() { + Ok(a) => a, + Err(e) => { + #error_handling + } } }; } @@ -322,15 +324,9 @@ where // core type is given func_ptr_binding name. let ptr_name = names.func_ptr_binding(&result.name); let ptr_err_handling = error_handling(&format!("{}:result_ptr_mut", result.name.as_str())); - let ref_err_handling = error_handling(&format!("{}:result_ref_mut", result.name.as_str())); let pre = quote! { let mut #ptr_name = match memory.ptr_mut::<#pointee_type>(#ptr_name as u32) { - Ok(p) => match p.as_ref_mut() { - Ok(r) => r, - Err(e) => { - #ref_err_handling - } - }, + Ok(p) => p, Err(e) => { #ptr_err_handling } @@ -339,7 +335,7 @@ where // trait binding returns func_param name. let val_name = names.func_param(&result.name); let post = quote! { - *#ptr_name = #val_name; + #ptr_name.write_ptr_to_guest(&#val_name); }; (pre, post) }; @@ -361,6 +357,7 @@ where witx::BuiltinType::String => unimplemented!("string types"), }, witx::Type::Enum(_) | witx::Type::Flags(_) | witx::Type::Int(_) => write_val_to_ptr, - _ => unimplemented!("marshal result"), + witx::Type::Struct(_) => write_val_to_ptr, + _ => unimplemented!("missing marshalling result for {:?}", &*tref.type_()), } } diff --git a/crates/generate/src/module_trait.rs b/crates/generate/src/module_trait.rs index 100144dca5..4cb4c45045 100644 --- a/crates/generate/src/module_trait.rs +++ b/crates/generate/src/module_trait.rs @@ -2,16 +2,25 @@ use proc_macro2::TokenStream; use quote::quote; use crate::names::Names; -use crate::types::anon_lifetime; +use crate::types::{anon_lifetime, type_needs_lifetime}; use witx::Module; pub fn define_module_trait(names: &Names, m: &Module) -> TokenStream { let traitname = names.trait_name(&m.name); let traitmethods = m.funcs().map(|f| { + // Check if we're returning an entity anotated with a lifetime, + // in which case, we'll need to annotate the function itself, and + // hence will need an explicit lifetime (rather than anonymous) + let (lifetime, is_anonymous) = if f.results.iter().any(|ret| type_needs_lifetime(&ret.tref)) + { + (quote!('a), false) + } else { + (anon_lifetime(), true) + }; let funcname = names.func(&f.name); let args = f.params.iter().map(|arg| { let arg_name = names.func_param(&arg.name); - let arg_typename = names.type_ref(&arg.tref, anon_lifetime()); + let arg_typename = names.type_ref(&arg.tref, lifetime.clone()); let arg_type = match arg.tref.type_().passed_by() { witx::TypePassedBy::Value { .. } => quote!(#arg_typename), witx::TypePassedBy::Pointer { .. } => quote!(&#arg_typename), @@ -23,13 +32,18 @@ pub fn define_module_trait(names: &Names, m: &Module) -> TokenStream { .results .iter() .skip(1) - .map(|ret| names.type_ref(&ret.tref, anon_lifetime())); + .map(|ret| names.type_ref(&ret.tref, lifetime.clone())); let err = f .results .get(0) - .map(|err_result| names.type_ref(&err_result.tref, anon_lifetime())) + .map(|err_result| names.type_ref(&err_result.tref, lifetime.clone())) .unwrap_or(quote!(())); - quote!(fn #funcname(&mut self, #(#args),*) -> Result<(#(#rets),*), #err>;) + + if is_anonymous { + quote!(fn #funcname(&mut self, #(#args),*) -> Result<(#(#rets),*), #err>;) + } else { + quote!(fn #funcname<#lifetime>(&mut self, #(#args),*) -> Result<(#(#rets),*), #err>;) + } }); quote! { pub trait #traitname { diff --git a/tests/structs.rs b/tests/structs.rs index 0d9aa5c5a0..d282f0a91b 100644 --- a/tests/structs.rs +++ b/tests/structs.rs @@ -1,5 +1,5 @@ use proptest::prelude::*; -use wiggle_runtime::GuestError; +use wiggle_runtime::{GuestError, GuestPtr}; use wiggle_test::{impl_errno, HostMemory, MemArea, WasiCtx}; wiggle_generate::from_witx!({ @@ -34,6 +34,21 @@ impl structs::Structs for WasiCtx { let second = an_pair.second as i64; Ok(first as i64 + second) } + + fn return_pair_ints(&mut self) -> Result { + Ok(types::PairInts { + first: 10, + second: 20, + }) + } + + fn return_pair_of_ptrs<'a>( + &mut self, + first: GuestPtr<'a, i32>, + second: GuestPtr<'a, i32>, + ) -> Result, types::Errno> { + Ok(types::PairIntPtrs { first, second }) + } } #[derive(Debug)] @@ -297,3 +312,146 @@ proptest! { e.test() } } + +#[derive(Debug)] +struct ReturnPairInts { + pub return_loc: MemArea, +} + +impl ReturnPairInts { + pub fn strat() -> BoxedStrategy { + HostMemory::mem_area_strat(8) + .prop_map(|return_loc| ReturnPairInts { return_loc }) + .boxed() + } + + pub fn test(&self) { + let mut ctx = WasiCtx::new(); + let mut host_memory = HostMemory::new(); + let mut guest_memory = host_memory.guest_memory(); + + let err = + structs::return_pair_ints(&mut ctx, &mut guest_memory, self.return_loc.ptr as i32); + + assert_eq!(err, types::Errno::Ok.into(), "return struct errno"); + + let return_struct: types::PairInts = *guest_memory + .ptr(self.return_loc.ptr) + .expect("return ptr") + .as_ref() + .expect("return ref"); + + assert_eq!( + return_struct, + types::PairInts { + first: 10, + second: 20 + }, + "return_pair_ints return value" + ); + } +} + +proptest! { + #[test] + fn return_pair_ints(e in ReturnPairInts::strat()) { + e.test(); + } +} + +#[derive(Debug)] +struct ReturnPairPtrsExercise { + input_first: i32, + input_second: i32, + input_first_loc: MemArea, + input_second_loc: MemArea, + return_loc: MemArea, +} + +impl ReturnPairPtrsExercise { + pub fn strat() -> BoxedStrategy { + ( + prop::num::i32::ANY, + prop::num::i32::ANY, + HostMemory::mem_area_strat(4), + HostMemory::mem_area_strat(4), + HostMemory::mem_area_strat(8), + ) + .prop_map( + |(input_first, input_second, input_first_loc, input_second_loc, return_loc)| { + ReturnPairPtrsExercise { + input_first, + input_second, + input_first_loc, + input_second_loc, + return_loc, + } + }, + ) + .prop_filter("non-overlapping pointers", |e| { + MemArea::non_overlapping_set(&[ + &e.input_first_loc, + &e.input_second_loc, + &e.return_loc, + ]) + }) + .boxed() + } + pub fn test(&self) { + let mut ctx = WasiCtx::new(); + let mut host_memory = HostMemory::new(); + let mut guest_memory = host_memory.guest_memory(); + + *guest_memory + .ptr_mut(self.input_first_loc.ptr) + .expect("input_first ptr") + .as_ref_mut() + .expect("input_first ref") = self.input_first; + *guest_memory + .ptr_mut(self.input_second_loc.ptr) + .expect("input_second ptr") + .as_ref_mut() + .expect("input_second ref") = self.input_second; + + let res = structs::return_pair_of_ptrs( + &mut ctx, + &mut guest_memory, + self.input_first_loc.ptr as i32, + self.input_second_loc.ptr as i32, + self.return_loc.ptr as i32, + ); + + assert_eq!(res, types::Errno::Ok.into(), "return pair of ptrs errno"); + + let ptr_pair_int_ptrs: GuestPtr> = + guest_memory.ptr(self.return_loc.ptr).expect("return ptr"); + let ret_first_ptr: GuestPtr = ptr_pair_int_ptrs + .cast::>(0u32) + .expect("extract ptr to first element in struct") + .clone_from_guest() + .expect("read ptr to first element in struct"); + let ret_second_ptr: GuestPtr = ptr_pair_int_ptrs + .cast::>(4u32) + .expect("extract ptr to second element in struct") + .clone_from_guest() + .expect("read ptr to second element in struct"); + assert_eq!( + self.input_first, + *ret_first_ptr + .as_ref() + .expect("deref extracted ptr to first element") + ); + assert_eq!( + self.input_second, + *ret_second_ptr + .as_ref() + .expect("deref extracted ptr to second element") + ); + } +} +proptest! { + #[test] + fn return_pair_of_ptrs(e in ReturnPairPtrsExercise::strat()) { + e.test() + } +} diff --git a/tests/structs.witx b/tests/structs.witx index c6bb8bb6b0..0542bc68fa 100644 --- a/tests/structs.witx +++ b/tests/structs.witx @@ -29,4 +29,12 @@ (param $an_pair $pair_int_and_ptr) (result $error $errno) (result $double s64)) + (@interface func (export "return_pair_ints") + (result $error $errno) + (result $an_pair $pair_ints)) + (@interface func (export "return_pair_of_ptrs") + (param $first (@witx const_pointer s32)) + (param $second (@witx const_pointer s32)) + (result $error $errno) + (result $an_pair $pair_int_ptrs)) )