diff --git a/crates/generate/src/funcs.rs b/crates/generate/src/funcs.rs index 9a3c1d4719..fe13723f2c 100644 --- a/crates/generate/src/funcs.rs +++ b/crates/generate/src/funcs.rs @@ -2,6 +2,7 @@ use proc_macro2::TokenStream; use quote::quote; use crate::names::Names; +use crate::types::struct_is_copy; // FIXME need to template what argument is required to an import function - some context // struct (e.g. WasiCtx) should be provided at the invocation of the `gen` proc macro. @@ -75,10 +76,14 @@ pub fn define_func(names: &Names, func: &witx::InterfaceFunc) -> TokenStream { .params .iter() .map(|p| marshal_arg(names, p, error_handling.clone())); - let trait_args = func - .params - .iter() - .map(|param| names.func_param(¶m.name)); + let trait_args = func.params.iter().map(|param| { + let name = names.func_param(¶m.name); + match param.tref.type_().passed_by() { + witx::TypePassedBy::Value { .. } => quote!(#name), + witx::TypePassedBy::Pointer { .. } => quote!(&#name), + witx::TypePassedBy::PointerLengthPair { .. } => unimplemented!(), + } + }); let (trait_rets, trait_bindings) = if func.results.len() < 2 { (quote!({}), quote!(_)) @@ -200,6 +205,24 @@ fn marshal_arg( }; } } + witx::Type::Struct(s) if struct_is_copy(&s) => { + let pointee_type = names.type_ref(tref); + let arg_name = names.func_ptr_binding(¶m.name); + let name = names.func_param(¶m.name); + quote! { + let #name = match memory.ptr::<#pointee_type>(#arg_name as u32) { + Ok(p) => match p.as_ref() { + Ok(r) => r, + Err(e) => { + #error_handling + } + }, + Err(e) => { + #error_handling + } + }; + } + } _ => unimplemented!("argument type marshalling"), } } diff --git a/crates/generate/src/module_trait.rs b/crates/generate/src/module_trait.rs index a25addd710..19669fa410 100644 --- a/crates/generate/src/module_trait.rs +++ b/crates/generate/src/module_trait.rs @@ -10,7 +10,12 @@ pub fn define_module_trait(names: &Names, m: &Module) -> TokenStream { let funcname = names.func(&f.name); let args = f.params.iter().map(|arg| { let arg_name = names.func_param(&arg.name); - let arg_type = names.type_ref(&arg.tref); + let arg_typename = names.type_ref(&arg.tref); + let arg_type = match arg.tref.type_().passed_by() { + witx::TypePassedBy::Value { .. } => quote!(#arg_typename), + witx::TypePassedBy::Pointer { .. } => quote!(&#arg_typename), + witx::TypePassedBy::PointerLengthPair { .. } => unimplemented!(), + }; quote!(#arg_name: #arg_type) }); let rets = f diff --git a/crates/generate/src/names.rs b/crates/generate/src/names.rs index acbb3b00aa..558a7e28dc 100644 --- a/crates/generate/src/names.rs +++ b/crates/generate/src/names.rs @@ -72,6 +72,10 @@ impl Names { } } + pub fn struct_member(&self, id: &Id) -> Ident { + format_ident!("{}", id.as_str().to_snake_case()) + } + pub fn module(&self, id: &Id) -> Ident { format_ident!("{}", id.as_str().to_snake_case()) } diff --git a/crates/generate/src/types.rs b/crates/generate/src/types.rs index a49995b260..c5185a51a8 100644 --- a/crates/generate/src/types.rs +++ b/crates/generate/src/types.rs @@ -2,6 +2,7 @@ use crate::names::Names; use proc_macro2::TokenStream; use quote::quote; +use witx::Layout; pub fn define_datatype(names: &Names, namedtype: &witx::NamedType) -> TokenStream { match &namedtype.tref { @@ -10,7 +11,13 @@ pub fn define_datatype(names: &Names, namedtype: &witx::NamedType) -> TokenStrea witx::Type::Enum(e) => define_enum(names, &namedtype.name, &e), witx::Type::Int(_) => unimplemented!("int types"), witx::Type::Flags(_) => unimplemented!("flag types"), - witx::Type::Struct(_) => unimplemented!("struct types"), + witx::Type::Struct(s) => { + if struct_is_copy(s) { + define_copy_struct(names, &namedtype.name, &s) + } else { + unimplemented!("non-Copy struct") + } + } witx::Type::Union(_) => unimplemented!("union types"), witx::Type::Handle(_h) => unimplemented!("handle types"), witx::Type::Builtin(b) => define_builtin(names, &namedtype.name, *b), @@ -126,6 +133,59 @@ fn define_builtin(names: &Names, name: &witx::Id, builtin: witx::BuiltinType) -> quote!(pub type #ident = #built;) } +pub fn struct_is_copy(s: &witx::StructDatatype) -> bool { + s.members.iter().all(|m| match &*m.tref.type_() { + witx::Type::Struct(s) => struct_is_copy(&s), + witx::Type::Builtin(b) => match &*b { + witx::BuiltinType::String => false, + _ => true, + }, + witx::Type::ConstPointer { .. } + | witx::Type::Pointer { .. } + | witx::Type::Array { .. } + | witx::Type::Union { .. } => false, + witx::Type::Enum { .. } + | witx::Type::Int { .. } + | witx::Type::Flags { .. } + | witx::Type::Handle { .. } => true, + }) +} + +fn define_copy_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) -> TokenStream { + let ident = names.type_(name); + let member_decls = s.members.iter().map(|m| { + let name = names.struct_member(&m.name); + let type_ = names.type_ref(&m.tref); + quote!(pub #name: #type_) + }); + let size = s.mem_size_align().size as u32; + let align = s.mem_size_align().align as u32; + + quote! { + #[repr(C)] + #[derive(Copy, Clone, Debug, ::std::hash::Hash, Eq, PartialEq)] + pub struct #ident { + #(#member_decls),* + } + + impl ::memory::GuestType for #ident { + fn size() -> u32 { + #size + } + fn align() -> u32 { + #align + } + fn name() -> String { + stringify!(#ident).to_owned() + } + fn validate(_ptr: &::memory::GuestPtr<#ident>) -> Result<(), ::memory::GuestError> { + Ok(()) // FIXME + } + } + impl ::memory::GuestTypeCopy for #ident {} + } +} + fn int_repr_tokens(int_repr: witx::IntRepr) -> TokenStream { match int_repr { witx::IntRepr::U8 => quote!(u8), diff --git a/src/lib.rs b/src/lib.rs index b907f0cb1e..b1898b77c4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,6 +62,11 @@ pub mod test { println!("bat: {}", an_int); Ok((an_int as f32) * 2.0) } + + fn sum_of_pair(&mut self, an_pair: &types::PairInts) -> Result { + println!("sum of pair: {:?}", an_pair); + Ok(an_pair.first as i64 + an_pair.second as i64) + } } // Errno is used as a first return value in the functions above, therefore // it must implement GuestErrorType with type Context = WasiCtx. diff --git a/test.witx b/test.witx index 422d99bd72..af5af5441f 100644 --- a/test.witx +++ b/test.witx @@ -14,6 +14,11 @@ $traffic $sleeping)) +(typename $pair_ints + (struct + (field $first s32) + (field $second s32))) + (module $foo (@interface func (export "bar") (param $an_int u32) @@ -29,4 +34,9 @@ (param $an_int u32) (result $error $errno) (result $doubled_it f32)) + + (@interface func (export "sum_of_pair") + (param $an_pair $pair_ints) + (result $error $errno) + (result $doubled s64)) )