support enums with more than 256 variants in derive macro (#4370)

* support enums with more than 256 variants in derive macro

This addresses #4361.  Technically, we now support up to 2^32 variants, which is
the maximum for the canonical ABI.  In practice, though, the derived code for
enums with even just 2^16 variants takes a prohibitively long time to compile.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* simplify `LowerExpander::expand_variant` code

Signed-off-by: Joel Dice <joel.dice@fermyon.com>
This commit is contained in:
Joel Dice
2022-07-05 09:36:43 -06:00
committed by GitHub
parent 7320db98d1
commit 5542c4ef26
6 changed files with 190 additions and 28 deletions

View File

@@ -145,6 +145,54 @@ fn add_trait_bounds(generics: &syn::Generics, bound: syn::TypeParamBound) -> syn
generics
}
#[derive(Debug, Copy, Clone)]
enum DiscriminantSize {
Size1,
Size2,
Size4,
}
impl DiscriminantSize {
fn quote(self, discriminant: usize) -> TokenStream {
match self {
Self::Size1 => {
let discriminant = u8::try_from(discriminant).unwrap();
quote!(#discriminant)
}
Self::Size2 => {
let discriminant = u16::try_from(discriminant).unwrap();
quote!(#discriminant)
}
Self::Size4 => {
let discriminant = u32::try_from(discriminant).unwrap();
quote!(#discriminant)
}
}
}
}
impl From<DiscriminantSize> for u32 {
fn from(size: DiscriminantSize) -> u32 {
match size {
DiscriminantSize::Size1 => 1,
DiscriminantSize::Size2 => 2,
DiscriminantSize::Size4 => 4,
}
}
}
fn discriminant_size(case_count: usize) -> Option<DiscriminantSize> {
if case_count <= 0xFF {
Some(DiscriminantSize::Size1)
} else if case_count <= 0xFFFF {
Some(DiscriminantSize::Size2)
} else if case_count <= 0xFFFF_FFFF {
Some(DiscriminantSize::Size4)
} else {
None
}
}
struct VariantCase<'a> {
attrs: &'a [syn::Attribute],
ident: &'a syn::Ident,
@@ -157,6 +205,7 @@ trait Expander {
fn expand_variant(
&self,
input: &DeriveInput,
discriminant_size: DiscriminantSize,
cases: &[VariantCase],
style: VariantStyle,
) -> Result<TokenStream>;
@@ -217,6 +266,13 @@ fn expand_variant(
));
}
let discriminant_size = discriminant_size(body.variants.len()).ok_or_else(|| {
Error::new(
input.ident.span(),
"`enum`s with more than 2^32 variants are not supported",
)
})?;
let cases = body
.variants
.iter()
@@ -240,8 +296,13 @@ fn expand_variant(
name.span(),
format!(
"`{}` component types can only be derived for Rust `enum`s \
containing variants with at most one unnamed field each",
style
containing variants with {}",
style,
match style {
VariantStyle::Variant => "at most one unnamed field each",
VariantStyle::Enum => "no fields",
VariantStyle::Union => "exactly one unnamed field each",
}
),
))
}
@@ -251,7 +312,7 @@ fn expand_variant(
)
.collect::<Result<Vec<_>>>()?;
expander.expand_variant(input, &cases, style)
expander.expand_variant(input, discriminant_size, &cases, style)
}
#[proc_macro_derive(Lift, attributes(component))]
@@ -321,6 +382,7 @@ impl Expander for LiftExpander {
fn expand_variant(
&self,
input: &DeriveInput,
discriminant_size: DiscriminantSize,
cases: &[VariantCase],
_style: VariantStyle,
) -> Result<TokenStream> {
@@ -330,31 +392,26 @@ impl Expander for LiftExpander {
let mut loads = TokenStream::new();
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
let index_u8 = u8::try_from(index).map_err(|_| {
Error::new(
input.ident.span(),
"`enum`s with more than 256 variants not yet supported",
)
})?;
let index_u32 = u32::try_from(index).unwrap();
let index_i32 = index_u8 as i32;
let index_quoted = discriminant_size.quote(index);
if let Some(ty) = ty {
lifts.extend(
quote!(#index_i32 => Self::#ident(<#ty as wasmtime::component::Lift>::lift(
quote!(#index_u32 => Self::#ident(<#ty as wasmtime::component::Lift>::lift(
store, options, unsafe { &src.payload.#ident }
)?),),
);
loads.extend(
quote!(#index_u8 => Self::#ident(<#ty as wasmtime::component::Lift>::load(
quote!(#index_quoted => Self::#ident(<#ty as wasmtime::component::Lift>::load(
memory, &payload[..<#ty as wasmtime::component::ComponentType>::SIZE32]
)?),),
);
} else {
lifts.extend(quote!(#index_i32 => Self::#ident,));
lifts.extend(quote!(#index_u32 => Self::#ident,));
loads.extend(quote!(#index_u8 => Self::#ident,));
loads.extend(quote!(#index_quoted => Self::#ident,));
}
}
@@ -362,6 +419,14 @@ impl Expander for LiftExpander {
let generics = add_trait_bounds(&input.generics, parse_quote!(wasmtime::component::Lift));
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let from_bytes = match discriminant_size {
DiscriminantSize::Size1 => quote!(bytes[0]),
DiscriminantSize::Size2 => quote!(u16::from_le_bytes(bytes[0..2].try_into()?)),
DiscriminantSize::Size4 => quote!(u32::from_le_bytes(bytes[0..4].try_into()?)),
};
let payload_offset = u32::from(discriminant_size) as usize;
let expanded = quote! {
unsafe impl #impl_generics wasmtime::component::Lift for #name #ty_generics #where_clause {
#[inline]
@@ -370,7 +435,7 @@ impl Expander for LiftExpander {
options: &#internal::Options,
src: &Self::Lower,
) -> #internal::anyhow::Result<Self> {
Ok(match src.tag.get_i32() {
Ok(match src.tag.get_u32() {
#lifts
discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),
})
@@ -380,8 +445,8 @@ impl Expander for LiftExpander {
fn load(memory: &#internal::Memory, bytes: &[u8]) -> #internal::anyhow::Result<Self> {
let align = <Self as wasmtime::component::ComponentType>::ALIGN32;
debug_assert!((bytes.as_ptr() as usize) % (align as usize) == 0);
let discrim = bytes[0];
let payload = &bytes[#internal::align_to(1, align)..];
let discrim = #from_bytes;
let payload = &bytes[#internal::align_to(#payload_offset, align)..];
Ok(match discrim {
#loads
discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),
@@ -456,6 +521,7 @@ impl Expander for LowerExpander {
fn expand_variant(
&self,
input: &DeriveInput,
discriminant_size: DiscriminantSize,
cases: &[VariantCase],
_style: VariantStyle,
) -> Result<TokenStream> {
@@ -465,14 +531,9 @@ impl Expander for LowerExpander {
let mut stores = TokenStream::new();
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
let index_u8 = u8::try_from(index).map_err(|_| {
Error::new(
input.ident.span(),
"`enum`s with more than 256 variants not yet supported",
)
})?;
let index_u32 = u32::try_from(index).unwrap();
let index_i32 = index_u8 as i32;
let index_quoted = discriminant_size.quote(index);
let pattern;
let lower;
@@ -492,12 +553,14 @@ impl Expander for LowerExpander {
}
lowers.extend(quote!(#pattern => {
#internal::map_maybe_uninit!(dst.tag).write(wasmtime::ValRaw::i32(#index_i32));
#internal::map_maybe_uninit!(dst.tag).write(wasmtime::ValRaw::i32(#index_u32 as i32));
#lower
}));
let discriminant_size = u32::from(discriminant_size) as usize;
stores.extend(quote!(#pattern => {
memory.get::<1>(offset)[0] = #index_u8;
*memory.get::<#discriminant_size>(offset) = #index_quoted.to_le_bytes();
#store
}));
}
@@ -668,6 +731,7 @@ impl Expander for ComponentTypeExpander {
fn expand_variant(
&self,
input: &DeriveInput,
discriminant_size: DiscriminantSize,
cases: &[VariantCase],
style: VariantStyle,
) -> Result<TokenStream> {
@@ -766,6 +830,7 @@ impl Expander for ComponentTypeExpander {
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let lower = format_ident!("Lower{}", name);
let lower_payload = format_ident!("LowerPayload{}", name);
let discriminant_size = u32::from(discriminant_size);
// You may wonder why we make the types of all the fields of the #lower struct and #lower_payload union
// generic. This is to work around a [normalization bug in
@@ -806,11 +871,11 @@ impl Expander for ComponentTypeExpander {
const SIZE32: usize = {
let mut size = 0;
#sizes
#internal::align_to(1, Self::ALIGN32) + size
#internal::align_to(#discriminant_size as usize, Self::ALIGN32) + size
};
const ALIGN32: u32 = {
let mut align = 1;
let mut align = #discriminant_size;
#alignments
align
};

View File

@@ -0,0 +1,15 @@
[package]
name = "component-macro-test"
authors = ["The Wasmtime Project Developers"]
license = "Apache-2.0 WITH LLVM-exception"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
proc-macro = true
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "1.0", features = ["full"] }

View File

@@ -0,0 +1,34 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::parse_macro_input;
#[proc_macro_attribute]
pub fn add_variants(
attr: proc_macro::TokenStream,
item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
expand_variants(
&parse_macro_input!(attr as syn::LitInt),
parse_macro_input!(item as syn::ItemEnum),
)
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}
fn expand_variants(count: &syn::LitInt, mut ty: syn::ItemEnum) -> syn::Result<TokenStream> {
let count = count
.base10_digits()
.parse::<usize>()
.map_err(|_| syn::Error::new(count.span(), "expected unsigned integer"))?;
ty.variants = (0..count)
.map(|index| syn::Variant {
attrs: Vec::new(),
ident: syn::Ident::new(&format!("V{}", index), Span::call_site()),
fields: syn::Fields::Unit,
discriminant: None,
})
.collect();
Ok(quote!(#ty))
}