support enums with more than 256 variants in derive macro (#4370)
* support enums with more than 256 variants in derive macro This addresses #4361. Technically, we now support up to 2^32 variants, which is the maximum for the canonical ABI. In practice, though, the derived code for enums with even just 2^16 variants takes a prohibitively long time to compile. Signed-off-by: Joel Dice <joel.dice@fermyon.com> * simplify `LowerExpander::expand_variant` code Signed-off-by: Joel Dice <joel.dice@fermyon.com>
This commit is contained in:
@@ -145,6 +145,54 @@ fn add_trait_bounds(generics: &syn::Generics, bound: syn::TypeParamBound) -> syn
|
||||
generics
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
enum DiscriminantSize {
|
||||
Size1,
|
||||
Size2,
|
||||
Size4,
|
||||
}
|
||||
|
||||
impl DiscriminantSize {
|
||||
fn quote(self, discriminant: usize) -> TokenStream {
|
||||
match self {
|
||||
Self::Size1 => {
|
||||
let discriminant = u8::try_from(discriminant).unwrap();
|
||||
quote!(#discriminant)
|
||||
}
|
||||
Self::Size2 => {
|
||||
let discriminant = u16::try_from(discriminant).unwrap();
|
||||
quote!(#discriminant)
|
||||
}
|
||||
Self::Size4 => {
|
||||
let discriminant = u32::try_from(discriminant).unwrap();
|
||||
quote!(#discriminant)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DiscriminantSize> for u32 {
|
||||
fn from(size: DiscriminantSize) -> u32 {
|
||||
match size {
|
||||
DiscriminantSize::Size1 => 1,
|
||||
DiscriminantSize::Size2 => 2,
|
||||
DiscriminantSize::Size4 => 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn discriminant_size(case_count: usize) -> Option<DiscriminantSize> {
|
||||
if case_count <= 0xFF {
|
||||
Some(DiscriminantSize::Size1)
|
||||
} else if case_count <= 0xFFFF {
|
||||
Some(DiscriminantSize::Size2)
|
||||
} else if case_count <= 0xFFFF_FFFF {
|
||||
Some(DiscriminantSize::Size4)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
struct VariantCase<'a> {
|
||||
attrs: &'a [syn::Attribute],
|
||||
ident: &'a syn::Ident,
|
||||
@@ -157,6 +205,7 @@ trait Expander {
|
||||
fn expand_variant(
|
||||
&self,
|
||||
input: &DeriveInput,
|
||||
discriminant_size: DiscriminantSize,
|
||||
cases: &[VariantCase],
|
||||
style: VariantStyle,
|
||||
) -> Result<TokenStream>;
|
||||
@@ -217,6 +266,13 @@ fn expand_variant(
|
||||
));
|
||||
}
|
||||
|
||||
let discriminant_size = discriminant_size(body.variants.len()).ok_or_else(|| {
|
||||
Error::new(
|
||||
input.ident.span(),
|
||||
"`enum`s with more than 2^32 variants are not supported",
|
||||
)
|
||||
})?;
|
||||
|
||||
let cases = body
|
||||
.variants
|
||||
.iter()
|
||||
@@ -240,8 +296,13 @@ fn expand_variant(
|
||||
name.span(),
|
||||
format!(
|
||||
"`{}` component types can only be derived for Rust `enum`s \
|
||||
containing variants with at most one unnamed field each",
|
||||
style
|
||||
containing variants with {}",
|
||||
style,
|
||||
match style {
|
||||
VariantStyle::Variant => "at most one unnamed field each",
|
||||
VariantStyle::Enum => "no fields",
|
||||
VariantStyle::Union => "exactly one unnamed field each",
|
||||
}
|
||||
),
|
||||
))
|
||||
}
|
||||
@@ -251,7 +312,7 @@ fn expand_variant(
|
||||
)
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
expander.expand_variant(input, &cases, style)
|
||||
expander.expand_variant(input, discriminant_size, &cases, style)
|
||||
}
|
||||
|
||||
#[proc_macro_derive(Lift, attributes(component))]
|
||||
@@ -321,6 +382,7 @@ impl Expander for LiftExpander {
|
||||
fn expand_variant(
|
||||
&self,
|
||||
input: &DeriveInput,
|
||||
discriminant_size: DiscriminantSize,
|
||||
cases: &[VariantCase],
|
||||
_style: VariantStyle,
|
||||
) -> Result<TokenStream> {
|
||||
@@ -330,31 +392,26 @@ impl Expander for LiftExpander {
|
||||
let mut loads = TokenStream::new();
|
||||
|
||||
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
|
||||
let index_u8 = u8::try_from(index).map_err(|_| {
|
||||
Error::new(
|
||||
input.ident.span(),
|
||||
"`enum`s with more than 256 variants not yet supported",
|
||||
)
|
||||
})?;
|
||||
let index_u32 = u32::try_from(index).unwrap();
|
||||
|
||||
let index_i32 = index_u8 as i32;
|
||||
let index_quoted = discriminant_size.quote(index);
|
||||
|
||||
if let Some(ty) = ty {
|
||||
lifts.extend(
|
||||
quote!(#index_i32 => Self::#ident(<#ty as wasmtime::component::Lift>::lift(
|
||||
quote!(#index_u32 => Self::#ident(<#ty as wasmtime::component::Lift>::lift(
|
||||
store, options, unsafe { &src.payload.#ident }
|
||||
)?),),
|
||||
);
|
||||
|
||||
loads.extend(
|
||||
quote!(#index_u8 => Self::#ident(<#ty as wasmtime::component::Lift>::load(
|
||||
quote!(#index_quoted => Self::#ident(<#ty as wasmtime::component::Lift>::load(
|
||||
memory, &payload[..<#ty as wasmtime::component::ComponentType>::SIZE32]
|
||||
)?),),
|
||||
);
|
||||
} else {
|
||||
lifts.extend(quote!(#index_i32 => Self::#ident,));
|
||||
lifts.extend(quote!(#index_u32 => Self::#ident,));
|
||||
|
||||
loads.extend(quote!(#index_u8 => Self::#ident,));
|
||||
loads.extend(quote!(#index_quoted => Self::#ident,));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -362,6 +419,14 @@ impl Expander for LiftExpander {
|
||||
let generics = add_trait_bounds(&input.generics, parse_quote!(wasmtime::component::Lift));
|
||||
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
||||
|
||||
let from_bytes = match discriminant_size {
|
||||
DiscriminantSize::Size1 => quote!(bytes[0]),
|
||||
DiscriminantSize::Size2 => quote!(u16::from_le_bytes(bytes[0..2].try_into()?)),
|
||||
DiscriminantSize::Size4 => quote!(u32::from_le_bytes(bytes[0..4].try_into()?)),
|
||||
};
|
||||
|
||||
let payload_offset = u32::from(discriminant_size) as usize;
|
||||
|
||||
let expanded = quote! {
|
||||
unsafe impl #impl_generics wasmtime::component::Lift for #name #ty_generics #where_clause {
|
||||
#[inline]
|
||||
@@ -370,7 +435,7 @@ impl Expander for LiftExpander {
|
||||
options: &#internal::Options,
|
||||
src: &Self::Lower,
|
||||
) -> #internal::anyhow::Result<Self> {
|
||||
Ok(match src.tag.get_i32() {
|
||||
Ok(match src.tag.get_u32() {
|
||||
#lifts
|
||||
discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),
|
||||
})
|
||||
@@ -380,8 +445,8 @@ impl Expander for LiftExpander {
|
||||
fn load(memory: &#internal::Memory, bytes: &[u8]) -> #internal::anyhow::Result<Self> {
|
||||
let align = <Self as wasmtime::component::ComponentType>::ALIGN32;
|
||||
debug_assert!((bytes.as_ptr() as usize) % (align as usize) == 0);
|
||||
let discrim = bytes[0];
|
||||
let payload = &bytes[#internal::align_to(1, align)..];
|
||||
let discrim = #from_bytes;
|
||||
let payload = &bytes[#internal::align_to(#payload_offset, align)..];
|
||||
Ok(match discrim {
|
||||
#loads
|
||||
discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),
|
||||
@@ -456,6 +521,7 @@ impl Expander for LowerExpander {
|
||||
fn expand_variant(
|
||||
&self,
|
||||
input: &DeriveInput,
|
||||
discriminant_size: DiscriminantSize,
|
||||
cases: &[VariantCase],
|
||||
_style: VariantStyle,
|
||||
) -> Result<TokenStream> {
|
||||
@@ -465,14 +531,9 @@ impl Expander for LowerExpander {
|
||||
let mut stores = TokenStream::new();
|
||||
|
||||
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
|
||||
let index_u8 = u8::try_from(index).map_err(|_| {
|
||||
Error::new(
|
||||
input.ident.span(),
|
||||
"`enum`s with more than 256 variants not yet supported",
|
||||
)
|
||||
})?;
|
||||
let index_u32 = u32::try_from(index).unwrap();
|
||||
|
||||
let index_i32 = index_u8 as i32;
|
||||
let index_quoted = discriminant_size.quote(index);
|
||||
|
||||
let pattern;
|
||||
let lower;
|
||||
@@ -492,12 +553,14 @@ impl Expander for LowerExpander {
|
||||
}
|
||||
|
||||
lowers.extend(quote!(#pattern => {
|
||||
#internal::map_maybe_uninit!(dst.tag).write(wasmtime::ValRaw::i32(#index_i32));
|
||||
#internal::map_maybe_uninit!(dst.tag).write(wasmtime::ValRaw::i32(#index_u32 as i32));
|
||||
#lower
|
||||
}));
|
||||
|
||||
let discriminant_size = u32::from(discriminant_size) as usize;
|
||||
|
||||
stores.extend(quote!(#pattern => {
|
||||
memory.get::<1>(offset)[0] = #index_u8;
|
||||
*memory.get::<#discriminant_size>(offset) = #index_quoted.to_le_bytes();
|
||||
#store
|
||||
}));
|
||||
}
|
||||
@@ -668,6 +731,7 @@ impl Expander for ComponentTypeExpander {
|
||||
fn expand_variant(
|
||||
&self,
|
||||
input: &DeriveInput,
|
||||
discriminant_size: DiscriminantSize,
|
||||
cases: &[VariantCase],
|
||||
style: VariantStyle,
|
||||
) -> Result<TokenStream> {
|
||||
@@ -766,6 +830,7 @@ impl Expander for ComponentTypeExpander {
|
||||
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
||||
let lower = format_ident!("Lower{}", name);
|
||||
let lower_payload = format_ident!("LowerPayload{}", name);
|
||||
let discriminant_size = u32::from(discriminant_size);
|
||||
|
||||
// You may wonder why we make the types of all the fields of the #lower struct and #lower_payload union
|
||||
// generic. This is to work around a [normalization bug in
|
||||
@@ -806,11 +871,11 @@ impl Expander for ComponentTypeExpander {
|
||||
const SIZE32: usize = {
|
||||
let mut size = 0;
|
||||
#sizes
|
||||
#internal::align_to(1, Self::ALIGN32) + size
|
||||
#internal::align_to(#discriminant_size as usize, Self::ALIGN32) + size
|
||||
};
|
||||
|
||||
const ALIGN32: u32 = {
|
||||
let mut align = 1;
|
||||
let mut align = #discriminant_size;
|
||||
#alignments
|
||||
align
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user