diff --git a/crates/generate/src/funcs.rs b/crates/generate/src/funcs.rs index f624e15146..e886308f86 100644 --- a/crates/generate/src/funcs.rs +++ b/crates/generate/src/funcs.rs @@ -8,39 +8,34 @@ use crate::names::Names; // pub fn define_func(names: &Names, func: &witx::InterfaceFunc) -> TokenStream { let ident = names.func(&func.name); + let coretype = func.core_type(); - let arg_signature = |param: &witx::InterfaceFuncParam| -> TokenStream { - let name = names.func_param(¶m.name); - match param.tref.type_().passed_by() { - witx::TypePassedBy::Value(atom) => { - let atom = names.atom_type(atom); - quote!(#name: #atom) - } - witx::TypePassedBy::Pointer => { - let atom = names.atom_type(witx::AtomType::I32); - quote!(#name: #atom) - } - witx::TypePassedBy::PointerLengthPair => { - let atom = names.atom_type(witx::AtomType::I32); - let len_name = names.func_len_param(¶m.name); - quote!(#name: #atom, #len_name: #atom) - } + let params = coretype.args.iter().map(|arg| match arg.signifies { + witx::CoreParamSignifies::Value(atom) => { + let atom = names.atom_type(atom); + let name = names.func_param(&arg.param.name); + quote!(#name : #atom) } - }; + witx::CoreParamSignifies::PointerTo => { + let atom = names.atom_type(witx::AtomType::I32); + let name = names.func_ptr_binding(&arg.param.name); + quote!(#name: #atom) + } + witx::CoreParamSignifies::LengthOf => { + let atom = names.atom_type(witx::AtomType::I32); + let name = names.func_len_binding(&arg.param.name); + quote!(#name: #atom) + } + }); - let params = func - .params - .iter() - .chain(func.results.iter().skip(1)) - .map(arg_signature); let abi_args = quote!( ctx: &mut WasiCtx, memory: ::memory::GuestMemory, #(#params),* ); - let abi_ret = if let Some(first_result) = func.results.get(0) { - match first_result.tref.type_().passed_by() { - witx::TypePassedBy::Value(atom) => names.atom_type(atom), - _ => unreachable!("first result should always be passed by value"), + let abi_ret = if let Some(ret) = &coretype.ret { + match ret.signifies { + witx::CoreParamSignifies::Value(atom) => names.atom_type(atom), + _ => unreachable!("ret should always be passed by value"), } } else if func.noreturn { // Ideally we would return `quote!(!)` here, but, we'd have to change @@ -52,21 +47,34 @@ pub fn define_func(names: &Names, func: &witx::InterfaceFunc) -> TokenStream { quote!(()) }; - let err_type = func - .results - .get(0) - .map(|res| names.type_ref(&res.tref)) - .unwrap_or_else(|| abi_ret.clone()); - let err_val = func - .results - .get(0) + let err_type = coretype.ret.map(|ret| ret.param.tref); + let err_val = err_type + .clone() .map(|_res| quote!(#abi_ret::from(e))) .unwrap_or_else(|| quote!(())); + let error_handling: TokenStream = { + if let Some(tref) = &err_type { + let abi_ret = match tref.type_().passed_by() { + witx::TypePassedBy::Value(atom) => names.atom_type(atom), + _ => unreachable!("err should always be passed by value"), + }; + let err_typename = names.type_ref(&tref); + quote! { + let err: #err_typename = ::memory::GuestErrorType::from_error(e, ctx); + return #abi_ret::from(err); + } + } else { + quote! { + panic!("error: {:?}", e) + } + } + }; + let marshal_args = func .params .iter() - .map(|p| marshal_arg(names, p, func.results.get(0).map(|r| &r.tref))); + .map(|p| marshal_arg(names, p, error_handling.clone())); let trait_args = func .params .iter() @@ -84,88 +92,93 @@ pub fn define_func(names: &Names, func: &witx::InterfaceFunc) -> TokenStream { (tuple.clone(), tuple) }; + // Return value pointers need to be validated before the api call, then + // assigned to afterwards. marshal_result returns these two statements as a pair. let marshal_rets = func .results .iter() .skip(1) - .map(|_result| quote! { unimplemented!("convert result..."); }); + .map(|result| marshal_result(names, result, error_handling.clone())); + let marshal_rets_pre = marshal_rets.clone().map(|(pre, _post)| pre); + let marshal_rets_post = marshal_rets.map(|(_pre, post)| post); + + let success = if let Some(err_type) = err_type { + let err_typename = names.type_ref(&err_type); + quote! { + let success:#err_typename = ::memory::GuestErrorType::success(); + #abi_ret::from(success) + } + } else { + quote!() + }; quote!(pub fn #ident(#abi_args) -> #abi_ret { #(#marshal_args)* + #(#marshal_rets_pre)* let #trait_bindings = match ctx.#ident(#(#trait_args),*) { Ok(#trait_bindings) => #trait_rets, Err(e) => { return #err_val; }, }; - #(#marshal_rets)* - let success:#err_type = ::memory::GuestErrorType::success(); - #abi_ret::from(success) + #(#marshal_rets_post)* + #success }) } fn marshal_arg( names: &Names, param: &witx::InterfaceFuncParam, - error_type: Option<&witx::TypeRef>, + error_handling: TokenStream, ) -> TokenStream { let tref = ¶m.tref; let interface_typename = names.type_ref(&tref); - let name = names.func_param(¶m.name); - let error_handling: TokenStream = { - if let Some(tref) = error_type { - let abi_ret = match tref.type_().passed_by() { - witx::TypePassedBy::Value(atom) => names.atom_type(atom), - _ => unreachable!("err should always be passed by value"), + let try_into_conversion = { + let name = names.func_param(¶m.name); + quote! { + use ::std::convert::TryInto; + let #name: #interface_typename = match #name.try_into() { + Ok(a) => a, + Err(e) => { + #error_handling + } }; - let err_typename = names.type_ref(&tref); - quote! { - let err: #err_typename = ::memory::GuestErrorType::from_error(e, ctx); - return #abi_ret::from(err); - } - } else { - quote! { - panic!("error: {:?}", e) - } } }; - let try_into_conversion = quote! { - use ::std::convert::TryInto; - let #name: #interface_typename = match #name.try_into() { - Ok(a) => a, - Err(e) => { - #error_handling - } - }; - }; - match &*tref.type_() { witx::Type::Enum(_e) => try_into_conversion, witx::Type::Builtin(b) => match b { witx::BuiltinType::U8 | witx::BuiltinType::U16 | witx::BuiltinType::Char8 => { try_into_conversion } - witx::BuiltinType::S8 | witx::BuiltinType::S16 => quote! { - let #name: #interface_typename = match (#name as i32).try_into() { - Ok(a) => a, - Err(e) => { - #error_handling + witx::BuiltinType::S8 | witx::BuiltinType::S16 => { + let name = names.func_param(¶m.name); + quote! { + let #name: #interface_typename = match (#name as i32).try_into() { + Ok(a) => a, + Err(e) => { + #error_handling + } } } - }, + } witx::BuiltinType::U32 | witx::BuiltinType::S32 | witx::BuiltinType::U64 | witx::BuiltinType::S64 | witx::BuiltinType::USize | witx::BuiltinType::F32 - | witx::BuiltinType::F64 => quote! { - let #name = #name as #interface_typename; - }, + | witx::BuiltinType::F64 => { + let name = names.func_param(¶m.name); + quote! { + let #name = #name as #interface_typename; + } + } witx::BuiltinType::String => unimplemented!("string types unimplemented"), }, witx::Type::Pointer(pointee) => { let pointee_type = names.type_ref(pointee); + let name = names.func_param(¶m.name); quote! { let #name = match memory.ptr_mut::<#pointee_type>(#name as u32) { Ok(p) => p, @@ -177,6 +190,7 @@ fn marshal_arg( } witx::Type::ConstPointer(pointee) => { let pointee_type = names.type_ref(pointee); + let name = names.func_param(¶m.name); quote! { let #name = match memory.ptr::<#pointee_type>(#name as u32) { Ok(p) => p, @@ -189,3 +203,52 @@ fn marshal_arg( _ => unimplemented!("argument type marshalling"), } } + +fn marshal_result( + names: &Names, + result: &witx::InterfaceFuncParam, + error_handling: TokenStream, +) -> (TokenStream, TokenStream) { + let tref = &result.tref; + + let write_val_to_ptr = { + let pointee_type = names.type_ref(tref); + // core type is given func_ptr_binding name. + let ptr_name = names.func_ptr_binding(&result.name); + let pre = quote! { + let #ptr_name = match memory.ptr_mut::<#pointee_type>(#ptr_name as u32) { + Ok(p) => p, + Err(e) => { + #error_handling + } + }; + }; + // trait binding returns func_param name. + let val_name = names.func_param(&result.name); + let post = quote! { + use ::memory::GuestTypeCopy; + #pointee_type::write_val(#val_name, &#ptr_name); + }; + (pre, post) + }; + + match &*tref.type_() { + witx::Type::Builtin(b) => match b { + witx::BuiltinType::U8 + | witx::BuiltinType::S8 + | witx::BuiltinType::U16 + | witx::BuiltinType::S16 + | witx::BuiltinType::U32 + | witx::BuiltinType::S32 + | witx::BuiltinType::U64 + | witx::BuiltinType::S64 + | witx::BuiltinType::F32 + | witx::BuiltinType::F64 + | witx::BuiltinType::USize + | witx::BuiltinType::Char8 => write_val_to_ptr, + witx::BuiltinType::String => unimplemented!("string types"), + }, + witx::Type::Enum(_e) => write_val_to_ptr, + _ => unimplemented!("marshal result"), + } +} diff --git a/crates/generate/src/names.rs b/crates/generate/src/names.rs index 6465e7f832..acbb3b00aa 100644 --- a/crates/generate/src/names.rs +++ b/crates/generate/src/names.rs @@ -88,8 +88,13 @@ impl Names { format_ident!("{}", id.as_str().to_snake_case()) } - /// For when you need a {name}_len parameter for passing an array: - pub fn func_len_param(&self, id: &Id) -> Ident { + /// For when you need a {name}_ptr binding for passing a value by reference: + pub fn func_ptr_binding(&self, id: &Id) -> Ident { + format_ident!("{}_ptr", id.as_str().to_snake_case()) + } + + /// For when you need a {name}_len binding for passing an array: + pub fn func_len_binding(&self, id: &Id) -> Ident { format_ident!("{}_len", id.as_str().to_snake_case()) } } diff --git a/src/lib.rs b/src/lib.rs index fd287a0f0c..cf36236e83 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,8 +62,12 @@ pub mod test { ); Ok(()) } - } + fn bat(&mut self, an_int: u32) -> Result { + println!("bat: {}", an_int); + Ok((an_int as f32) * 2.0) + } + } // Errno is used as a first return value in the functions above, therefore // it must implement GuestErrorType with type Context = WasiCtx. // The context type should let you do logging or debugging or whatever you need diff --git a/test.witx b/test.witx index 16964d8eea..422d99bd72 100644 --- a/test.witx +++ b/test.witx @@ -25,4 +25,8 @@ (param $a_lamer_excuse (@witx const_pointer $excuse)) (param $two_layers_of_excuses (@witx pointer (@witx const_pointer $excuse))) (result $error $errno)) + (@interface func (export "bat") + (param $an_int u32) + (result $error $errno) + (result $doubled_it f32)) )