diff --git a/crates/generate/src/funcs.rs b/crates/generate/src/funcs.rs index 57a9e63da7..861de4c258 100644 --- a/crates/generate/src/funcs.rs +++ b/crates/generate/src/funcs.rs @@ -154,6 +154,7 @@ fn marshal_arg( match &*tref.type_() { witx::Type::Enum(_e) => try_into_conversion, witx::Type::Flags(_f) => try_into_conversion, + witx::Type::Int(_i) => try_into_conversion, witx::Type::Builtin(b) => match b { witx::BuiltinType::U8 | witx::BuiltinType::U16 | witx::BuiltinType::Char8 => { try_into_conversion @@ -359,7 +360,7 @@ where | witx::BuiltinType::Char8 => write_val_to_ptr, witx::BuiltinType::String => unimplemented!("string types"), }, - witx::Type::Enum(_) | witx::Type::Flags(_) => write_val_to_ptr, + witx::Type::Enum(_) | witx::Type::Flags(_) | witx::Type::Int(_) => write_val_to_ptr, _ => unimplemented!("marshal result"), } } diff --git a/crates/generate/src/names.rs b/crates/generate/src/names.rs index 3337a18619..c967f6e1da 100644 --- a/crates/generate/src/names.rs +++ b/crates/generate/src/names.rs @@ -86,6 +86,10 @@ impl Names { format_ident!("{}", id.as_str().to_shouty_snake_case()) } + pub fn int_member(&self, id: &Id) -> Ident { + format_ident!("{}", id.as_str().to_shouty_snake_case()) + } + pub fn struct_member(&self, id: &Id) -> Ident { format_ident!("{}", id.as_str().to_snake_case()) } diff --git a/crates/generate/src/types.rs b/crates/generate/src/types.rs index 61e21b71a5..d2f94353da 100644 --- a/crates/generate/src/types.rs +++ b/crates/generate/src/types.rs @@ -10,7 +10,7 @@ pub fn define_datatype(names: &Names, namedtype: &witx::NamedType) -> TokenStrea witx::TypeRef::Name(alias_to) => define_alias(names, &namedtype.name, &alias_to), witx::TypeRef::Value(v) => match &**v { witx::Type::Enum(e) => define_enum(names, &namedtype.name, &e), - witx::Type::Int(_) => unimplemented!("int types"), + witx::Type::Int(i) => define_int(names, &namedtype.name, &i), witx::Type::Flags(f) => define_flags(names, &namedtype.name, &f), witx::Type::Struct(s) => { if struct_is_copy(s) { @@ -46,6 +46,92 @@ fn define_alias(names: &Names, name: &witx::Id, to: &witx::NamedType) -> TokenSt } } +fn define_int(names: &Names, name: &witx::Id, i: &witx::IntDatatype) -> TokenStream { + let ident = names.type_(&name); + let repr = int_repr_tokens(i.repr); + let abi_repr = atom_token(match i.repr { + witx::IntRepr::U8 | witx::IntRepr::U16 | witx::IntRepr::U32 => witx::AtomType::I32, + witx::IntRepr::U64 => witx::AtomType::I64, + }); + let consts = i + .consts + .iter() + .map(|r#const| { + let const_ident = names.int_member(&r#const.name); + let value = r#const.value; + quote!(pub const #const_ident: #ident = #ident(#value)) + }) + .collect::>(); + + quote! { + #[repr(transparent)] + #[derive(Copy, Clone, Debug, ::std::hash::Hash, Eq, PartialEq)] + pub struct #ident(#repr); + + impl #ident { + #(#consts;)* + } + + impl ::std::convert::TryFrom<#repr> for #ident { + type Error = wiggle_runtime::GuestError; + fn try_from(value: #repr) -> Result { + Ok(#ident(value)) + } + } + + impl ::std::convert::TryFrom<#abi_repr> for #ident { + type Error = wiggle_runtime::GuestError; + fn try_from(value: #abi_repr) -> Result<#ident, wiggle_runtime::GuestError> { + #ident::try_from(value as #repr) + } + } + + impl From<#ident> for #repr { + fn from(e: #ident) -> #repr { + e.0 + } + } + + impl From<#ident> for #abi_repr { + fn from(e: #ident) -> #abi_repr { + #repr::from(e) as #abi_repr + } + } + + impl wiggle_runtime::GuestType for #ident { + fn size() -> u32 { + ::std::mem::size_of::<#repr>() as u32 + } + + fn align() -> u32 { + ::std::mem::align_of::<#repr>() as u32 + } + + fn name() -> String { + stringify!(#ident).to_owned() + } + + fn validate<'a>(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<(), wiggle_runtime::GuestError> { + use ::std::convert::TryFrom; + let raw: #repr = unsafe { (location.as_raw() as *const #repr).read() }; + let _ = #ident::try_from(raw)?; + Ok(()) + } + } + + impl wiggle_runtime::GuestTypeCopy for #ident {} + impl<'a> wiggle_runtime::GuestTypeClone<'a> for #ident { + fn read_from_guest(location: &wiggle_runtime::GuestPtr<#ident>) -> Result<#ident, wiggle_runtime::GuestError> { + Ok(*location.as_ref()?) + } + fn write_to_guest(&self, location: &wiggle_runtime::GuestPtrMut<#ident>) { + let val: #repr = #repr::from(*self); + unsafe { (location.as_raw() as *mut #repr).write(val) }; + } + } + } +} + fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDatatype) -> TokenStream { let ident = names.type_(&name); let repr = int_repr_tokens(f.repr); diff --git a/tests/main.rs b/tests/main.rs index a64395ffa3..422173b7af 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -143,6 +143,15 @@ impl foo::Foo for WasiCtx { println!("a_string='{}'", as_str); Ok(as_str.len() as u32) } + + fn cookie_cutter(&mut self, init_cookie: types::Cookie) -> Result { + let res = if init_cookie == types::Cookie::START { + types::Bool::True + } else { + types::Bool::False + }; + Ok(res) + } } // Errno is used as a first return value in the functions above, therefore // it must implement GuestErrorType with type Context = WasiCtx. @@ -958,3 +967,62 @@ proptest! { e.test() } } + +fn cookie_strat() -> impl Strategy { + (0..std::u64::MAX) + .prop_map(|x| types::Cookie::try_from(x).expect("within range of cookie")) + .boxed() +} + +#[derive(Debug)] +struct CookieCutterExercise { + cookie: types::Cookie, + return_ptr_loc: MemArea, +} + +impl CookieCutterExercise { + pub fn strat() -> BoxedStrategy { + (cookie_strat(), HostMemory::mem_area_strat(4)) + .prop_map(|(cookie, return_ptr_loc)| Self { + cookie, + return_ptr_loc, + }) + .boxed() + } + + pub fn test(&self) { + let mut ctx = WasiCtx::new(); + let mut host_memory = HostMemory::new(); + let mut guest_memory = GuestMemory::new(host_memory.as_mut_ptr(), host_memory.len() as u32); + + let res = foo::cookie_cutter( + &mut ctx, + &mut guest_memory, + self.cookie.into(), + self.return_ptr_loc.ptr as i32, + ); + assert_eq!(res, types::Errno::Ok.into(), "cookie cutter errno"); + + let is_cookie_start = *guest_memory + .ptr::(self.return_ptr_loc.ptr) + .expect("ptr to returned Bool") + .as_ref() + .expect("deref to Bool value"); + + assert_eq!( + if is_cookie_start == types::Bool::True { + true + } else { + false + }, + self.cookie == types::Cookie::START, + "returned Bool should test if input was Cookie::START", + ); + } +} +proptest! { + #[test] + fn cookie_cutter(e in CookieCutterExercise::strat()) { + e.test() + } +} diff --git a/tests/test.witx b/tests/test.witx index a6edac306f..4a0930a616 100644 --- a/tests/test.witx +++ b/tests/test.witx @@ -18,6 +18,15 @@ $awd $suv)) +(typename $cookie + (int u64 + (const $start 0))) + +(typename $bool + (enum u8 + $false + $true)) + (typename $pair_ints (struct (field $first s32) @@ -77,4 +86,9 @@ (result $error $errno) (result $total_bytes u32) ) + (@interface func (export "cookie_cutter") + (param $init_cookie $cookie) + (result $error $errno) + (result $is_start $bool) + ) )