From e31ff9dc67b6ea2cc2611307291e86fbaff548de Mon Sep 17 00:00:00 2001 From: Joel Dice Date: Tue, 12 Jul 2022 17:47:58 -0600 Subject: [PATCH] implement wasmtime::component::flags! per #4308 (#4414) * implement wasmtime::component::flags! per #4308 This is the last macro needed to complete #4308. It supports generating a Rust type that represents a `flags` component type, analogous to how the [bitflags crate](https://crates.io/crates/bitflags) operates. Signed-off-by: Joel Dice * wrap `format_flags` output in parens This ensures we generate non-empty output even when no flags are set. Empty output for a `Debug` implementation would be confusing. Signed-off-by: Joel Dice * unconditionally derive `Lift` and `Lower` in wasmtime::component::flags! Per feedback on #4414, we now derive impls for those traits unconditionally, which simplifies the syntax of the macro. Also, I happened to notice an alignment bug in `LowerExpander::expand_variant`, so I fixed that and cleaned up some related code. Finally, I used @jameysharp's trick to calculate bit masks without looping. Signed-off-by: Joel Dice * fix shift overflow regression in previous commit Jamey pointed out my mistake: I didn't consider the case when the flag count was evenly divisible by the representation size. This fixes the problem and adds test cases to cover it. Signed-off-by: Joel Dice --- crates/component-macro/src/lib.rs | 680 ++++++++++++++++---- crates/misc/component-macro-test/src/lib.rs | 48 +- crates/wasmtime/src/component/func/typed.rs | 50 ++ crates/wasmtime/src/component/mod.rs | 6 +- tests/all/component_model/macros.rs | 397 +++++++++++- 5 files changed, 1042 insertions(+), 139 deletions(-) diff --git a/crates/component-macro/src/lib.rs b/crates/component-macro/src/lib.rs index 878c53f681..e3f290922d 100644 --- a/crates/component-macro/src/lib.rs +++ b/crates/component-macro/src/lib.rs @@ -2,7 +2,9 @@ use proc_macro2::{Literal, TokenStream, TokenTree}; use quote::{format_ident, quote}; use std::collections::HashSet; use std::fmt; -use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Error, Result}; +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::{braced, parse_macro_input, parse_quote, Data, DeriveInput, Error, Result, Token}; #[derive(Debug, Copy, Clone)] enum VariantStyle { @@ -181,6 +183,16 @@ impl From for u32 { } } +impl From for usize { + fn from(size: DiscriminantSize) -> usize { + match size { + DiscriminantSize::Size1 => 1, + DiscriminantSize::Size2 => 2, + DiscriminantSize::Size4 => 4, + } + } +} + fn discriminant_size(case_count: usize) -> Option { if case_count <= 0xFF { Some(DiscriminantSize::Size1) @@ -200,11 +212,17 @@ struct VariantCase<'a> { } trait Expander { - fn expand_record(&self, input: &DeriveInput, fields: &syn::FieldsNamed) -> Result; + fn expand_record( + &self, + name: &syn::Ident, + generics: &syn::Generics, + fields: &[&syn::Field], + ) -> Result; fn expand_variant( &self, - input: &DeriveInput, + name: &syn::Ident, + generics: &syn::Generics, discriminant_size: DiscriminantSize, cases: &[VariantCase], style: VariantStyle, @@ -231,7 +249,11 @@ fn expand_record(expander: &dyn Expander, input: &DeriveInput) -> Result expander.expand_record(input, fields), + syn::Fields::Named(fields) => expander.expand_record( + &input.ident, + &input.generics, + &fields.named.iter().collect::>(), + ), syn::Fields::Unnamed(_) | syn::Fields::Unit => Err(Error::new( name.span(), @@ -312,7 +334,109 @@ fn expand_variant( ) .collect::>>()?; - expander.expand_variant(input, discriminant_size, &cases, style) + expander.expand_variant( + &input.ident, + &input.generics, + discriminant_size, + &cases, + style, + ) +} + +fn expand_record_for_component_type( + name: &syn::Ident, + generics: &syn::Generics, + fields: &[&syn::Field], + typecheck: TokenStream, + typecheck_argument: TokenStream, +) -> Result { + let internal = quote!(wasmtime::component::__internal); + + let mut lower_generic_params = TokenStream::new(); + let mut lower_generic_args = TokenStream::new(); + let mut lower_field_declarations = TokenStream::new(); + let mut sizes = TokenStream::new(); + let mut unique_types = HashSet::new(); + + for (index, syn::Field { ident, ty, .. }) in fields.iter().enumerate() { + let generic = format_ident!("T{}", index); + + lower_generic_params.extend(quote!(#generic: Copy,)); + lower_generic_args.extend(quote!(<#ty as wasmtime::component::ComponentType>::Lower,)); + + lower_field_declarations.extend(quote!(#ident: #generic,)); + + sizes.extend(quote!( + size = #internal::align_to(size, <#ty as wasmtime::component::ComponentType>::ALIGN32); + size += <#ty as wasmtime::component::ComponentType>::SIZE32; + )); + + unique_types.insert(ty); + } + + let alignments = unique_types + .into_iter() + .map(|ty| { + let align = quote!(<#ty as wasmtime::component::ComponentType>::ALIGN32); + quote!(if #align > align { + align = #align; + }) + }) + .collect::(); + + let generics = add_trait_bounds(generics, parse_quote!(wasmtime::component::ComponentType)); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let lower = format_ident!("Lower{}", name); + + // You may wonder why we make the types of all the fields of the #lower struct generic. This is to work + // around the lack of [perfect derive support in + // rustc](https://smallcultfollowing.com/babysteps//blog/2022/04/12/implied-bounds-and-perfect-derive/#what-is-perfect-derive) + // as of this writing. + // + // If the struct we're deriving a `ComponentType` impl for has any generic parameters, then #lower needs + // generic parameters too. And if we just copy the parameters and bounds from the impl to #lower, then the + // `#[derive(Clone, Copy)]` will fail unless the original generics were declared with those bounds, which + // we don't want to require. + // + // Alternatively, we could just pass the `Lower` associated type of each generic type as arguments to + // #lower, but that would require distinguishing between generic and concrete types when generating + // #lower_field_declarations, which would require some form of symbol resolution. That doesn't seem worth + // the trouble. + + let expanded = quote! { + #[doc(hidden)] + #[derive(Clone, Copy)] + #[repr(C)] + pub struct #lower <#lower_generic_params> { + #lower_field_declarations + } + + unsafe impl #impl_generics wasmtime::component::ComponentType for #name #ty_generics #where_clause { + type Lower = #lower <#lower_generic_args>; + + const SIZE32: usize = { + let mut size = 0; + #sizes + size + }; + + const ALIGN32: u32 = { + let mut align = 1; + #alignments + align + }; + + #[inline] + fn typecheck( + ty: &#internal::InterfaceType, + types: &#internal::ComponentTypes, + ) -> #internal::anyhow::Result<()> { + #internal::#typecheck(ty, types, &[#typecheck_argument]) + } + } + }; + + Ok(quote!(const _: () = { #expanded };)) } #[proc_macro_derive(Lift, attributes(component))] @@ -325,13 +449,18 @@ pub fn lift(input: proc_macro::TokenStream) -> proc_macro::TokenStream { struct LiftExpander; impl Expander for LiftExpander { - fn expand_record(&self, input: &DeriveInput, fields: &syn::FieldsNamed) -> Result { + fn expand_record( + &self, + name: &syn::Ident, + generics: &syn::Generics, + fields: &[&syn::Field], + ) -> Result { let internal = quote!(wasmtime::component::__internal); let mut lifts = TokenStream::new(); let mut loads = TokenStream::new(); - for syn::Field { ident, ty, .. } in &fields.named { + for syn::Field { ident, ty, .. } in fields { lifts.extend(quote!(#ident: <#ty as wasmtime::component::Lift>::lift( store, options, &src.#ident )?,)); @@ -344,8 +473,7 @@ impl Expander for LiftExpander { )?,)); } - let name = &input.ident; - let generics = add_trait_bounds(&input.generics, parse_quote!(wasmtime::component::Lift)); + let generics = add_trait_bounds(generics, parse_quote!(wasmtime::component::Lift)); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let expanded = quote! { @@ -381,7 +509,8 @@ impl Expander for LiftExpander { fn expand_variant( &self, - input: &DeriveInput, + name: &syn::Ident, + generics: &syn::Generics, discriminant_size: DiscriminantSize, cases: &[VariantCase], _style: VariantStyle, @@ -415,8 +544,7 @@ impl Expander for LiftExpander { } } - let name = &input.ident; - let generics = add_trait_bounds(&input.generics, parse_quote!(wasmtime::component::Lift)); + let generics = add_trait_bounds(generics, parse_quote!(wasmtime::component::Lift)); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let from_bytes = match discriminant_size { @@ -425,7 +553,7 @@ impl Expander for LiftExpander { DiscriminantSize::Size4 => quote!(u32::from_le_bytes(bytes[0..4].try_into()?)), }; - let payload_offset = u32::from(discriminant_size) as usize; + let payload_offset = usize::from(discriminant_size); let expanded = quote! { unsafe impl #impl_generics wasmtime::component::Lift for #name #ty_generics #where_clause { @@ -469,13 +597,18 @@ pub fn lower(input: proc_macro::TokenStream) -> proc_macro::TokenStream { struct LowerExpander; impl Expander for LowerExpander { - fn expand_record(&self, input: &DeriveInput, fields: &syn::FieldsNamed) -> Result { + fn expand_record( + &self, + name: &syn::Ident, + generics: &syn::Generics, + fields: &[&syn::Field], + ) -> Result { let internal = quote!(wasmtime::component::__internal); let mut lowers = TokenStream::new(); let mut stores = TokenStream::new(); - for syn::Field { ident, ty, .. } in &fields.named { + for syn::Field { ident, ty, .. } in fields { lowers.extend(quote!(wasmtime::component::Lower::lower( &self.#ident, store, options, #internal::map_maybe_uninit!(dst.#ident) )?;)); @@ -485,8 +618,7 @@ impl Expander for LowerExpander { )?;)); } - let name = &input.ident; - let generics = add_trait_bounds(&input.generics, parse_quote!(wasmtime::component::Lower)); + let generics = add_trait_bounds(generics, parse_quote!(wasmtime::component::Lower)); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let expanded = quote! { @@ -520,7 +652,8 @@ impl Expander for LowerExpander { fn expand_variant( &self, - input: &DeriveInput, + name: &syn::Ident, + generics: &syn::Generics, discriminant_size: DiscriminantSize, cases: &[VariantCase], _style: VariantStyle, @@ -535,6 +668,8 @@ impl Expander for LowerExpander { let index_quoted = discriminant_size.quote(index); + let discriminant_size = usize::from(discriminant_size); + let pattern; let lower; let store; @@ -544,7 +679,10 @@ impl Expander for LowerExpander { lower = quote!(value.lower(store, options, #internal::map_maybe_uninit!(dst.payload.#ident))); store = quote!(value.store( memory, - offset + #internal::align_to(1, ::ALIGN32) + offset + #internal::align_to( + #discriminant_size, + ::ALIGN32 + ) )); } else { pattern = quote!(Self::#ident); @@ -557,16 +695,13 @@ impl Expander for LowerExpander { #lower })); - let discriminant_size = u32::from(discriminant_size) as usize; - stores.extend(quote!(#pattern => { *memory.get::<#discriminant_size>(offset) = #index_quoted.to_le_bytes(); #store })); } - let name = &input.ident; - let generics = add_trait_bounds(&input.generics, parse_quote!(wasmtime::component::Lower)); + let generics = add_trait_bounds(generics, parse_quote!(wasmtime::component::Lower)); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let expanded = quote! { @@ -621,116 +756,38 @@ pub fn component_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream struct ComponentTypeExpander; impl Expander for ComponentTypeExpander { - fn expand_record(&self, input: &DeriveInput, fields: &syn::FieldsNamed) -> Result { - let internal = quote!(wasmtime::component::__internal); + fn expand_record( + &self, + name: &syn::Ident, + generics: &syn::Generics, + fields: &[&syn::Field], + ) -> Result { + expand_record_for_component_type( + name, + generics, + fields, + quote!(typecheck_record), + fields + .iter() + .map( + |syn::Field { + attrs, ident, ty, .. + }| { + let name = find_rename(attrs)?.unwrap_or_else(|| { + Literal::string(&ident.as_ref().unwrap().to_string()) + }); - let mut field_names_and_checks = TokenStream::new(); - let mut lower_generic_params = TokenStream::new(); - let mut lower_generic_args = TokenStream::new(); - let mut lower_field_declarations = TokenStream::new(); - let mut sizes = TokenStream::new(); - let mut unique_types = HashSet::new(); - - for ( - index, - syn::Field { - attrs, ident, ty, .. - }, - ) in fields.named.iter().enumerate() - { - let name = find_rename(attrs)? - .unwrap_or_else(|| Literal::string(&ident.as_ref().unwrap().to_string())); - - let generic = format_ident!("T{}", index); - - lower_generic_params.extend(quote!(#generic: Copy,)); - lower_generic_args.extend(quote!(<#ty as wasmtime::component::ComponentType>::Lower,)); - - lower_field_declarations.extend(quote!(#ident: #generic,)); - - field_names_and_checks - .extend(quote!((#name, <#ty as wasmtime::component::ComponentType>::typecheck),)); - - sizes.extend(quote!( - size = #internal::align_to(size, <#ty as wasmtime::component::ComponentType>::ALIGN32); - size += <#ty as wasmtime::component::ComponentType>::SIZE32; - )); - - unique_types.insert(ty); - } - - let alignments = unique_types - .into_iter() - .map(|ty| { - let align = quote!(<#ty as wasmtime::component::ComponentType>::ALIGN32); - quote!(if #align > align { - align = #align; - }) - }) - .collect::(); - - let name = &input.ident; - let generics = add_trait_bounds( - &input.generics, - parse_quote!(wasmtime::component::ComponentType), - ); - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - let lower = format_ident!("Lower{}", name); - - // You may wonder why we make the types of all the fields of the #lower struct generic. This is to work - // around the lack of [perfect derive support in - // rustc](https://smallcultfollowing.com/babysteps//blog/2022/04/12/implied-bounds-and-perfect-derive/#what-is-perfect-derive) - // as of this writing. - // - // If the struct we're deriving a `ComponentType` impl for has any generic parameters, then #lower needs - // generic parameters too. And if we just copy the parameters and bounds from the impl to #lower, then the - // `#[derive(Clone, Copy)]` will fail unless the original generics were declared with those bounds, which - // we don't want to require. - // - // Alternatively, we could just pass the `Lower` associated type of each generic type as arguments to - // #lower, but that would require distinguishing between generic and concrete types when generating - // #lower_field_declarations, which would require some form of symbol resolution. That doesn't seem worth - // the trouble. - - let expanded = quote! { - #[doc(hidden)] - #[derive(Clone, Copy)] - #[repr(C)] - pub struct #lower <#lower_generic_params> { - #lower_field_declarations - } - - unsafe impl #impl_generics wasmtime::component::ComponentType for #name #ty_generics #where_clause { - type Lower = #lower <#lower_generic_args>; - - const SIZE32: usize = { - let mut size = 0; - #sizes - size - }; - - const ALIGN32: u32 = { - let mut align = 1; - #alignments - align - }; - - #[inline] - fn typecheck( - ty: &#internal::InterfaceType, - types: &#internal::ComponentTypes, - ) -> #internal::anyhow::Result<()> { - #internal::typecheck_record(ty, types, &[#field_names_and_checks]) - } - } - }; - - Ok(quote!(const _: () = { #expanded };)) + Ok(quote!((#name, <#ty as wasmtime::component::ComponentType>::typecheck),)) + }, + ) + .collect::>()?, + ) } fn expand_variant( &self, - input: &DeriveInput, + name: &syn::Ident, + generics: &syn::Generics, discriminant_size: DiscriminantSize, cases: &[VariantCase], style: VariantStyle, @@ -822,11 +879,7 @@ impl Expander for ComponentTypeExpander { VariantStyle::Enum => quote!(typecheck_enum), }; - let name = &input.ident; - let generics = add_trait_bounds( - &input.generics, - parse_quote!(wasmtime::component::ComponentType), - ); + let generics = add_trait_bounds(generics, parse_quote!(wasmtime::component::ComponentType)); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let lower = format_ident!("Lower{}", name); let lower_payload = format_ident!("LowerPayload{}", name); @@ -885,3 +938,364 @@ impl Expander for ComponentTypeExpander { Ok(quote!(const _: () = { #expanded };)) } } + +#[derive(Debug)] +struct Flag { + rename: Option, + name: String, +} + +impl Parse for Flag { + fn parse(input: ParseStream) -> Result { + let attributes = syn::Attribute::parse_outer(input)?; + + let rename = find_rename(&attributes)? + .map(|literal| { + let s = literal.to_string(); + + s.strip_prefix('"') + .and_then(|s| s.strip_suffix('"')) + .map(|s| s.to_owned()) + .ok_or_else(|| Error::new(literal.span(), "expected string literal")) + }) + .transpose()?; + + input.parse::()?; + let name = input.parse::()?.to_string(); + + Ok(Self { rename, name }) + } +} + +#[derive(Debug)] +struct Flags { + name: String, + flags: Vec, +} + +impl Parse for Flags { + fn parse(input: ParseStream) -> Result { + let name = input.parse::()?.to_string(); + + let content; + braced!(content in input); + + let flags = content + .parse_terminated::<_, Token![;]>(Flag::parse)? + .into_iter() + .collect(); + + Ok(Self { name, flags }) + } +} + +enum FlagsSize { + /// Flags can fit in a u8 + Size1, + /// Flags can fit in a u16 + Size2, + /// Flags can fit in a specified number of u32 fields + Size4Plus(usize), +} + +fn ceiling_divide(n: usize, d: usize) -> usize { + (n + d - 1) / d +} + +#[proc_macro] +pub fn flags(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + expand_flags(&parse_macro_input!(input as Flags)) + .unwrap_or_else(Error::into_compile_error) + .into() +} + +fn expand_flags(flags: &Flags) -> Result { + let size = if flags.flags.len() <= 8 { + FlagsSize::Size1 + } else if flags.flags.len() <= 16 { + FlagsSize::Size2 + } else { + FlagsSize::Size4Plus(ceiling_divide(flags.flags.len(), 32)) + }; + + let ty; + let eq; + + let count = flags.flags.len(); + + match size { + FlagsSize::Size1 => { + ty = quote!(u8); + + eq = if count == 8 { + quote!(self.__inner0.eq(&rhs.__inner0)) + } else { + let mask = !(0xFF_u8 << count); + + quote!((self.__inner0 & #mask).eq(&(rhs.__inner0 & #mask))) + }; + } + FlagsSize::Size2 => { + ty = quote!(u16); + + eq = if count == 16 { + quote!(self.__inner0.eq(&rhs.__inner0)) + } else { + let mask = !(0xFFFF_u16 << count); + + quote!((self.__inner0 & #mask).eq(&(rhs.__inner0 & #mask))) + }; + } + FlagsSize::Size4Plus(n) => { + ty = quote!(u32); + + let comparisons = (0..(n - 1)) + .map(|index| { + let field = format_ident!("__inner{}", index); + + quote!(self.#field.eq(&rhs.#field) &&) + }) + .collect::(); + + let field = format_ident!("__inner{}", n - 1); + + eq = if count % 32 == 0 { + quote!(#comparisons self.#field.eq(&rhs.#field)) + } else { + let mask = !(0xFFFF_FFFF_u32 << (count % 32)); + + quote!(#comparisons (self.#field & #mask).eq(&(rhs.#field & #mask))) + } + } + } + + let count; + let mut as_array; + let mut bitor; + let mut bitor_assign; + let mut bitand; + let mut bitand_assign; + let mut bitxor; + let mut bitxor_assign; + let mut not; + + match size { + FlagsSize::Size1 | FlagsSize::Size2 => { + count = 1; + as_array = quote!([self.__inner0 as u32]); + bitor = quote!(Self { + __inner0: self.__inner0.bitor(rhs.__inner0) + }); + bitor_assign = quote!(self.__inner0.bitor_assign(rhs.__inner0)); + bitand = quote!(Self { + __inner0: self.__inner0.bitand(rhs.__inner0) + }); + bitand_assign = quote!(self.__inner0.bitand_assign(rhs.__inner0)); + bitxor = quote!(Self { + __inner0: self.__inner0.bitxor(rhs.__inner0) + }); + bitxor_assign = quote!(self.__inner0.bitxor_assign(rhs.__inner0)); + not = quote!(Self { + __inner0: self.__inner0.not() + }); + } + FlagsSize::Size4Plus(n) => { + count = n; + as_array = TokenStream::new(); + bitor = TokenStream::new(); + bitor_assign = TokenStream::new(); + bitand = TokenStream::new(); + bitand_assign = TokenStream::new(); + bitxor = TokenStream::new(); + bitxor_assign = TokenStream::new(); + not = TokenStream::new(); + + for index in 0..n { + let field = format_ident!("__inner{}", index); + + as_array.extend(quote!(self.#field,)); + bitor.extend(quote!(#field: self.#field.bitor(rhs.#field),)); + bitor_assign.extend(quote!(self.#field.bitor_assign(rhs.#field);)); + bitand.extend(quote!(#field: self.#field.bitand(rhs.#field),)); + bitand_assign.extend(quote!(self.#field.bitand_assign(rhs.#field);)); + bitxor.extend(quote!(#field: self.#field.bitxor(rhs.#field),)); + bitxor_assign.extend(quote!(self.#field.bitxor_assign(rhs.#field);)); + not.extend(quote!(#field: self.#field.not(),)); + } + + as_array = quote!([#as_array]); + bitor = quote!(Self { #bitor }); + bitand = quote!(Self { #bitand }); + bitxor = quote!(Self { #bitxor }); + not = quote!(Self { #not }); + } + }; + + let name = format_ident!("{}", flags.name); + + let mut constants = TokenStream::new(); + let mut rust_names = TokenStream::new(); + let mut component_names = TokenStream::new(); + + for (index, Flag { name, rename }) in flags.flags.iter().enumerate() { + rust_names.extend(quote!(#name,)); + + let component_name = rename.as_ref().unwrap_or(name); + component_names.extend(quote!(#component_name,)); + + let fields = match size { + FlagsSize::Size1 => { + let init = 1_u8 << index; + quote!(__inner0: #init) + } + FlagsSize::Size2 => { + let init = 1_u16 << index; + quote!(__inner0: #init) + } + FlagsSize::Size4Plus(n) => (0..n) + .map(|i| { + let field = format_ident!("__inner{}", i); + + let init = if index / 32 == i { + 1_u32 << (index % 32) + } else { + 0 + }; + + quote!(#field: #init,) + }) + .collect::(), + }; + + let name = format_ident!("{}", name); + + constants.extend(quote!(const #name: Self = Self { #fields };)); + } + + let generics = syn::Generics { + lt_token: None, + params: Punctuated::new(), + gt_token: None, + where_clause: None, + }; + + let fields = { + let ty = syn::parse2::(ty.clone())?; + + (0..count) + .map(|index| syn::Field { + attrs: Vec::new(), + vis: syn::Visibility::Inherited, + ident: Some(format_ident!("__inner{}", index)), + colon_token: None, + ty: ty.clone(), + }) + .collect::>() + }; + + let fields = fields.iter().collect::>(); + + let component_type_impl = expand_record_for_component_type( + &name, + &generics, + &fields, + quote!(typecheck_flags), + component_names, + )?; + + let lower_impl = LowerExpander.expand_record(&name, &generics, &fields)?; + + let lift_impl = LiftExpander.expand_record(&name, &generics, &fields)?; + + let internal = quote!(wasmtime::component::__internal); + + let fields = fields + .iter() + .map(|syn::Field { ident, .. }| quote!(#[doc(hidden)] #ident: #ty,)) + .collect::(); + + let expanded = quote! { + #[derive(Copy, Clone, Default)] + struct #name { #fields } + + impl #name { + #constants + + fn as_array(&self) -> [u32; #count] { + #as_array + } + } + + impl std::cmp::PartialEq for #name { + fn eq(&self, rhs: &#name) -> bool { + #eq + } + } + + impl std::cmp::Eq for #name { } + + impl std::fmt::Debug for #name { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + #internal::format_flags(&self.as_array(), &[#rust_names], f) + } + } + + impl std::ops::BitOr for #name { + type Output = #name; + + fn bitor(self, rhs: #name) -> #name { + #bitor + } + } + + impl std::ops::BitOrAssign for #name { + fn bitor_assign(&mut self, rhs: #name) { + #bitor_assign + } + } + + impl std::ops::BitAnd for #name { + type Output = #name; + + fn bitand(self, rhs: #name) -> #name { + #bitand + } + } + + impl std::ops::BitAndAssign for #name { + fn bitand_assign(&mut self, rhs: #name) { + #bitand_assign + } + } + + impl std::ops::BitXor for #name { + type Output = #name; + + fn bitxor(self, rhs: #name) -> #name { + #bitxor + } + } + + impl std::ops::BitXorAssign for #name { + fn bitxor_assign(&mut self, rhs: #name) { + #bitxor_assign + } + } + + impl std::ops::Not for #name { + type Output = #name; + + fn not(self) -> #name { + #not + } + } + + #component_type_impl + + #lower_impl + + #lift_impl + }; + + Ok(expanded) +} diff --git a/crates/misc/component-macro-test/src/lib.rs b/crates/misc/component-macro-test/src/lib.rs index 59a62de385..26fc265465 100644 --- a/crates/misc/component-macro-test/src/lib.rs +++ b/crates/misc/component-macro-test/src/lib.rs @@ -1,6 +1,7 @@ use proc_macro2::{Span, TokenStream}; -use quote::quote; -use syn::parse_macro_input; +use quote::{format_ident, quote}; +use syn::parse::{Parse, ParseStream}; +use syn::{parse_macro_input, Error, Result, Token}; #[proc_macro_attribute] pub fn add_variants( @@ -32,3 +33,46 @@ fn expand_variants(count: &syn::LitInt, mut ty: syn::ItemEnum) -> syn::Result Result { + let name = input.parse::()?.to_string(); + input.parse::()?; + let flag_count = input.parse::()?.base10_parse()?; + + Ok(Self { name, flag_count }) + } +} + +#[proc_macro] +pub fn flags_test(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + expand_flags_test(&parse_macro_input!(input as FlagsTest)) + .unwrap_or_else(Error::into_compile_error) + .into() +} + +fn expand_flags_test(test: &FlagsTest) -> Result { + let name = format_ident!("{}", test.name); + let flags = (0..test.flag_count) + .map(|index| { + let name = format_ident!("F{}", index); + quote!(const #name;) + }) + .collect::(); + + let expanded = quote! { + wasmtime::component::flags! { + #name { + #flags + } + } + }; + + Ok(expanded) +} diff --git a/crates/wasmtime/src/component/func/typed.rs b/crates/wasmtime/src/component/func/typed.rs index 3dac62039d..4e79d69f19 100644 --- a/crates/wasmtime/src/component/func/typed.rs +++ b/crates/wasmtime/src/component/func/typed.rs @@ -5,6 +5,7 @@ use crate::store::StoreOpaque; use crate::{AsContext, AsContextMut, StoreContext, StoreContextMut, ValRaw}; use anyhow::{bail, Context, Result}; use std::borrow::Cow; +use std::fmt; use std::marker; use std::mem::{self, MaybeUninit}; use std::str; @@ -1581,6 +1582,55 @@ pub fn typecheck_union( } } +/// Verify that the given wasm type is a flags type with the expected flags in the right order and with the right +/// names. +pub fn typecheck_flags( + ty: &InterfaceType, + types: &ComponentTypes, + expected: &[&str], +) -> Result<()> { + match ty { + InterfaceType::Flags(index) => { + let names = &types[*index].names; + + if names.len() != expected.len() { + bail!( + "expected flags type with {} names, found {} names", + expected.len(), + names.len() + ); + } + + for (name, expected) in names.iter().zip(expected) { + if name != expected { + bail!("expected flag named {}, found {}", expected, name); + } + } + + Ok(()) + } + other => bail!("expected `flags` found `{}`", desc(other)), + } +} + +/// Format the specified bitflags using the specified names for debugging +pub fn format_flags(bits: &[u32], names: &[&str], f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("(")?; + let mut wrote = false; + for (index, name) in names.iter().enumerate() { + if ((bits[index / 32] >> (index % 32)) & 1) != 0 { + if wrote { + f.write_str("|")?; + } else { + wrote = true; + } + + f.write_str(name)?; + } + } + f.write_str(")") +} + unsafe impl ComponentType for Option where T: ComponentType, diff --git a/crates/wasmtime/src/component/mod.rs b/crates/wasmtime/src/component/mod.rs index a96d3f1d90..9f85b65ee0 100644 --- a/crates/wasmtime/src/component/mod.rs +++ b/crates/wasmtime/src/component/mod.rs @@ -16,7 +16,7 @@ pub use self::func::{ }; pub use self::instance::{ExportInstance, Exports, Instance, InstancePre}; pub use self::linker::{Linker, LinkerInstance}; -pub use wasmtime_component_macro::{ComponentType, Lift, Lower}; +pub use wasmtime_component_macro::{flags, ComponentType, Lift, Lower}; // These items are expected to be used by an eventual // `#[derive(ComponentType)]`, they are not part of Wasmtime's API stability @@ -24,8 +24,8 @@ pub use wasmtime_component_macro::{ComponentType, Lift, Lower}; #[doc(hidden)] pub mod __internal { pub use super::func::{ - align_to, next_field, typecheck_enum, typecheck_record, typecheck_union, typecheck_variant, - MaybeUninitExt, Memory, MemoryMut, Options, + align_to, format_flags, next_field, typecheck_enum, typecheck_flags, typecheck_record, + typecheck_union, typecheck_variant, MaybeUninitExt, Memory, MemoryMut, Options, }; pub use crate::map_maybe_uninit; pub use crate::store::StoreOpaque; diff --git a/tests/all/component_model/macros.rs b/tests/all/component_model/macros.rs index ca519d2f42..73b18f2a6a 100644 --- a/tests/all/component_model/macros.rs +++ b/tests/all/component_model/macros.rs @@ -1,6 +1,6 @@ use super::TypedFuncExt; use anyhow::Result; -use component_macro_test::add_variants; +use component_macro_test::{add_variants, flags_test}; use std::fmt::Write; use wasmtime::component::{Component, ComponentType, Lift, Linker, Lower}; use wasmtime::Store; @@ -476,6 +476,8 @@ fn enum_derive() -> Result<()> { .get_typed_func::<(Foo,), Foo, _>(&mut store, "echo") .is_err()); + // Happy path redux, with large enums (i.e. more than 2^8 cases) + #[add_variants(257)] #[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)] #[component(enum)] @@ -514,3 +516,396 @@ fn enum_derive() -> Result<()> { Ok(()) } + +#[test] +fn flags() -> Result<()> { + wasmtime::component::flags! { + Foo { + #[component(name = "foo-bar-baz")] + const A; + const B; + const C; + } + } + + assert_eq!(Foo::default(), (Foo::A | Foo::B) & Foo::C); + assert_eq!(Foo::B, (Foo::A | Foo::B) & Foo::B); + assert_eq!(Foo::A, (Foo::A | Foo::B) & Foo::A); + assert_eq!(Foo::A | Foo::B, Foo::A ^ Foo::B); + assert_eq!(Foo::default(), Foo::A ^ Foo::A); + assert_eq!(Foo::B | Foo::C, !Foo::A); + + let engine = super::engine(); + let mut store = Store::new(&engine, ()); + + // Happy path: component type matches flag count and names + + let component = Component::new( + &engine, + make_echo_component(r#"(type $Foo (flags "foo-bar-baz" "B" "C"))"#, 4), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_typed_func::<(Foo,), Foo, _>(&mut store, "echo")?; + + for n in 0..8 { + let mut input = Foo::default(); + if (n & 1) != 0 { + input |= Foo::A; + } + if (n & 2) != 0 { + input |= Foo::B; + } + if (n & 4) != 0 { + input |= Foo::C; + } + + let output = func.call_and_post_return(&mut store, (input,))?; + + assert_eq!(input, output); + } + + // Sad path: flag count mismatch (too few) + + let component = Component::new( + &engine, + make_echo_component(r#"(type $Foo (flags "foo-bar-baz" "B"))"#, 4), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + + assert!(instance + .get_typed_func::<(Foo,), Foo, _>(&mut store, "echo") + .is_err()); + + // Sad path: flag count mismatch (too many) + + let component = Component::new( + &engine, + make_echo_component(r#"(type $Foo (flags "foo-bar-baz" "B" "C" "D"))"#, 4), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + + assert!(instance + .get_typed_func::<(Foo,), Foo, _>(&mut store, "echo") + .is_err()); + + // Sad path: flag name mismatch + + let component = Component::new( + &engine, + make_echo_component(r#"(type $Foo (flags "A" "B" "C"))"#, 4), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + + assert!(instance + .get_typed_func::<(Foo,), Foo, _>(&mut store, "echo") + .is_err()); + + // Happy path redux, with large flag count (exactly 8) + + flags_test!(Foo8Exact, 8); + + assert_eq!( + Foo8Exact::default(), + (Foo8Exact::F0 | Foo8Exact::F6) & Foo8Exact::F7 + ); + assert_eq!( + Foo8Exact::F6, + (Foo8Exact::F0 | Foo8Exact::F6) & Foo8Exact::F6 + ); + assert_eq!( + Foo8Exact::F0, + (Foo8Exact::F0 | Foo8Exact::F6) & Foo8Exact::F0 + ); + assert_eq!(Foo8Exact::F0 | Foo8Exact::F6, Foo8Exact::F0 ^ Foo8Exact::F6); + assert_eq!(Foo8Exact::default(), Foo8Exact::F0 ^ Foo8Exact::F0); + assert_eq!( + Foo8Exact::F1 + | Foo8Exact::F2 + | Foo8Exact::F3 + | Foo8Exact::F4 + | Foo8Exact::F5 + | Foo8Exact::F6 + | Foo8Exact::F7, + !Foo8Exact::F0 + ); + + let component = Component::new( + &engine, + make_echo_component( + &format!( + r#"(type $Foo (flags {}))"#, + (0..8) + .map(|index| format!(r#""F{}""#, index)) + .collect::>() + .join(" ") + ), + 4, + ), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_typed_func::<(Foo8Exact,), Foo8Exact, _>(&mut store, "echo")?; + + for &input in &[ + Foo8Exact::F0, + Foo8Exact::F1, + Foo8Exact::F5, + Foo8Exact::F6, + Foo8Exact::F7, + ] { + let output = func.call_and_post_return(&mut store, (input,))?; + + assert_eq!(input, output); + } + + // Happy path redux, with large flag count (more than 8) + + flags_test!(Foo16, 9); + + assert_eq!(Foo16::default(), (Foo16::F0 | Foo16::F7) & Foo16::F8); + assert_eq!(Foo16::F7, (Foo16::F0 | Foo16::F7) & Foo16::F7); + assert_eq!(Foo16::F0, (Foo16::F0 | Foo16::F7) & Foo16::F0); + assert_eq!(Foo16::F0 | Foo16::F7, Foo16::F0 ^ Foo16::F7); + assert_eq!(Foo16::default(), Foo16::F0 ^ Foo16::F0); + assert_eq!( + Foo16::F1 + | Foo16::F2 + | Foo16::F3 + | Foo16::F4 + | Foo16::F5 + | Foo16::F6 + | Foo16::F7 + | Foo16::F8, + !Foo16::F0 + ); + + let component = Component::new( + &engine, + make_echo_component( + &format!( + r#"(type $Foo (flags {}))"#, + (0..9) + .map(|index| format!(r#""F{}""#, index)) + .collect::>() + .join(" ") + ), + 4, + ), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_typed_func::<(Foo16,), Foo16, _>(&mut store, "echo")?; + + for &input in &[Foo16::F0, Foo16::F1, Foo16::F6, Foo16::F7, Foo16::F8] { + let output = func.call_and_post_return(&mut store, (input,))?; + + assert_eq!(input, output); + } + + // Happy path redux, with large flag count (exactly 16) + + flags_test!(Foo16Exact, 16); + + assert_eq!( + Foo16Exact::default(), + (Foo16Exact::F0 | Foo16Exact::F14) & Foo16Exact::F5 + ); + assert_eq!( + Foo16Exact::F14, + (Foo16Exact::F0 | Foo16Exact::F14) & Foo16Exact::F14 + ); + assert_eq!( + Foo16Exact::F0, + (Foo16Exact::F0 | Foo16Exact::F14) & Foo16Exact::F0 + ); + assert_eq!( + Foo16Exact::F0 | Foo16Exact::F14, + Foo16Exact::F0 ^ Foo16Exact::F14 + ); + assert_eq!(Foo16Exact::default(), Foo16Exact::F0 ^ Foo16Exact::F0); + assert_eq!( + Foo16Exact::F0 | Foo16Exact::F15, + !((!Foo16Exact::F0) & (!Foo16Exact::F15)) + ); + + let component = Component::new( + &engine, + make_echo_component( + &format!( + r#"(type $Foo (flags {}))"#, + (0..16) + .map(|index| format!(r#""F{}""#, index)) + .collect::>() + .join(" ") + ), + 4, + ), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_typed_func::<(Foo16Exact,), Foo16Exact, _>(&mut store, "echo")?; + + for &input in &[ + Foo16Exact::F0, + Foo16Exact::F1, + Foo16Exact::F13, + Foo16Exact::F14, + Foo16Exact::F15, + ] { + let output = func.call_and_post_return(&mut store, (input,))?; + + assert_eq!(input, output); + } + + // Happy path redux, with large flag count (more than 16) + + flags_test!(Foo32, 17); + + assert_eq!(Foo32::default(), (Foo32::F0 | Foo32::F15) & Foo32::F16); + assert_eq!(Foo32::F15, (Foo32::F0 | Foo32::F15) & Foo32::F15); + assert_eq!(Foo32::F0, (Foo32::F0 | Foo32::F15) & Foo32::F0); + assert_eq!(Foo32::F0 | Foo32::F15, Foo32::F0 ^ Foo32::F15); + assert_eq!(Foo32::default(), Foo32::F0 ^ Foo32::F0); + assert_eq!(Foo32::F0 | Foo32::F16, !((!Foo32::F0) & (!Foo32::F16))); + + let component = Component::new( + &engine, + make_echo_component( + &format!( + r#"(type $Foo (flags {}))"#, + (0..17) + .map(|index| format!(r#""F{}""#, index)) + .collect::>() + .join(" ") + ), + 4, + ), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_typed_func::<(Foo32,), Foo32, _>(&mut store, "echo")?; + + for &input in &[Foo32::F0, Foo32::F1, Foo32::F14, Foo32::F15, Foo32::F16] { + let output = func.call_and_post_return(&mut store, (input,))?; + + assert_eq!(input, output); + } + + // Happy path redux, with large flag count (exactly 32) + + flags_test!(Foo32Exact, 32); + + assert_eq!( + Foo32Exact::default(), + (Foo32Exact::F0 | Foo32Exact::F30) & Foo32Exact::F31 + ); + assert_eq!( + Foo32Exact::F30, + (Foo32Exact::F0 | Foo32Exact::F30) & Foo32Exact::F30 + ); + assert_eq!( + Foo32Exact::F0, + (Foo32Exact::F0 | Foo32Exact::F30) & Foo32Exact::F0 + ); + assert_eq!( + Foo32Exact::F0 | Foo32Exact::F30, + Foo32Exact::F0 ^ Foo32Exact::F30 + ); + assert_eq!(Foo32Exact::default(), Foo32Exact::F0 ^ Foo32Exact::F0); + assert_eq!( + Foo32Exact::F0 | Foo32Exact::F15, + !((!Foo32Exact::F0) & (!Foo32Exact::F15)) + ); + + let component = Component::new( + &engine, + make_echo_component( + &format!( + r#"(type $Foo (flags {}))"#, + (0..32) + .map(|index| format!(r#""F{}""#, index)) + .collect::>() + .join(" ") + ), + 4, + ), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_typed_func::<(Foo32Exact,), Foo32Exact, _>(&mut store, "echo")?; + + for &input in &[ + Foo32Exact::F0, + Foo32Exact::F1, + Foo32Exact::F29, + Foo32Exact::F30, + Foo32Exact::F31, + ] { + let output = func.call_and_post_return(&mut store, (input,))?; + + assert_eq!(input, output); + } + + // Happy path redux, with large flag count (more than 32) + + flags_test!(Foo64, 33); + + assert_eq!(Foo64::default(), (Foo64::F0 | Foo64::F31) & Foo64::F32); + assert_eq!(Foo64::F31, (Foo64::F0 | Foo64::F31) & Foo64::F31); + assert_eq!(Foo64::F0, (Foo64::F0 | Foo64::F31) & Foo64::F0); + assert_eq!(Foo64::F0 | Foo64::F31, Foo64::F0 ^ Foo64::F31); + assert_eq!(Foo64::default(), Foo64::F0 ^ Foo64::F0); + assert_eq!(Foo64::F0 | Foo64::F32, !((!Foo64::F0) & (!Foo64::F32))); + + let component = Component::new( + &engine, + make_echo_component( + &format!( + r#"(type $Foo (flags {}))"#, + (0..33) + .map(|index| format!(r#""F{}""#, index)) + .collect::>() + .join(" ") + ), + 8, + ), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_typed_func::<(Foo64,), Foo64, _>(&mut store, "echo")?; + + for &input in &[Foo64::F0, Foo64::F1, Foo64::F30, Foo64::F31, Foo64::F32] { + let output = func.call_and_post_return(&mut store, (input,))?; + + assert_eq!(input, output); + } + + // Happy path redux, with large flag count (more than 64) + + flags_test!(Foo96, 65); + + assert_eq!(Foo96::default(), (Foo96::F0 | Foo96::F63) & Foo96::F64); + assert_eq!(Foo96::F63, (Foo96::F0 | Foo96::F63) & Foo96::F63); + assert_eq!(Foo96::F0, (Foo96::F0 | Foo96::F63) & Foo96::F0); + assert_eq!(Foo96::F0 | Foo96::F63, Foo96::F0 ^ Foo96::F63); + assert_eq!(Foo96::default(), Foo96::F0 ^ Foo96::F0); + assert_eq!(Foo96::F0 | Foo96::F64, !((!Foo96::F0) & (!Foo96::F64))); + + let component = Component::new( + &engine, + make_echo_component( + &format!( + r#"(type $Foo (flags {}))"#, + (0..65) + .map(|index| format!(r#""F{}""#, index)) + .collect::>() + .join(" ") + ), + 12, + ), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_typed_func::<(Foo96,), Foo96, _>(&mut store, "echo")?; + + for &input in &[Foo96::F0, Foo96::F1, Foo96::F62, Foo96::F63, Foo96::F64] { + let output = func.call_and_post_return(&mut store, (input,))?; + + assert_eq!(input, output); + } + + Ok(()) +}