support enums with more than 256 variants in derive macro (#4370)
* support enums with more than 256 variants in derive macro This addresses #4361. Technically, we now support up to 2^32 variants, which is the maximum for the canonical ABI. In practice, though, the derived code for enums with even just 2^16 variants takes a prohibitively long time to compile. Signed-off-by: Joel Dice <joel.dice@fermyon.com> * simplify `LowerExpander::expand_variant` code Signed-off-by: Joel Dice <joel.dice@fermyon.com>
This commit is contained in:
10
Cargo.lock
generated
10
Cargo.lock
generated
@@ -450,6 +450,15 @@ dependencies = [
|
|||||||
"cc",
|
"cc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "component-macro-test"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "console"
|
name = "console"
|
||||||
version = "0.15.0"
|
version = "0.15.0"
|
||||||
@@ -3413,6 +3422,7 @@ dependencies = [
|
|||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"clap 3.1.15",
|
"clap 3.1.15",
|
||||||
|
"component-macro-test",
|
||||||
"criterion",
|
"criterion",
|
||||||
"env_logger 0.9.0",
|
"env_logger 0.9.0",
|
||||||
"filecheck",
|
"filecheck",
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ async-trait = "0.1"
|
|||||||
wat = "1.0.43"
|
wat = "1.0.43"
|
||||||
once_cell = "1.9.0"
|
once_cell = "1.9.0"
|
||||||
rayon = "1.5.0"
|
rayon = "1.5.0"
|
||||||
|
component-macro-test = { path = "crates/misc/component-macro-test" }
|
||||||
|
|
||||||
[target.'cfg(windows)'.dev-dependencies]
|
[target.'cfg(windows)'.dev-dependencies]
|
||||||
windows-sys = { version = "0.36.0", features = ["Win32_System_Memory"] }
|
windows-sys = { version = "0.36.0", features = ["Win32_System_Memory"] }
|
||||||
|
|||||||
@@ -145,6 +145,54 @@ fn add_trait_bounds(generics: &syn::Generics, bound: syn::TypeParamBound) -> syn
|
|||||||
generics
|
generics
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
enum DiscriminantSize {
|
||||||
|
Size1,
|
||||||
|
Size2,
|
||||||
|
Size4,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DiscriminantSize {
|
||||||
|
fn quote(self, discriminant: usize) -> TokenStream {
|
||||||
|
match self {
|
||||||
|
Self::Size1 => {
|
||||||
|
let discriminant = u8::try_from(discriminant).unwrap();
|
||||||
|
quote!(#discriminant)
|
||||||
|
}
|
||||||
|
Self::Size2 => {
|
||||||
|
let discriminant = u16::try_from(discriminant).unwrap();
|
||||||
|
quote!(#discriminant)
|
||||||
|
}
|
||||||
|
Self::Size4 => {
|
||||||
|
let discriminant = u32::try_from(discriminant).unwrap();
|
||||||
|
quote!(#discriminant)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<DiscriminantSize> for u32 {
|
||||||
|
fn from(size: DiscriminantSize) -> u32 {
|
||||||
|
match size {
|
||||||
|
DiscriminantSize::Size1 => 1,
|
||||||
|
DiscriminantSize::Size2 => 2,
|
||||||
|
DiscriminantSize::Size4 => 4,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn discriminant_size(case_count: usize) -> Option<DiscriminantSize> {
|
||||||
|
if case_count <= 0xFF {
|
||||||
|
Some(DiscriminantSize::Size1)
|
||||||
|
} else if case_count <= 0xFFFF {
|
||||||
|
Some(DiscriminantSize::Size2)
|
||||||
|
} else if case_count <= 0xFFFF_FFFF {
|
||||||
|
Some(DiscriminantSize::Size4)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct VariantCase<'a> {
|
struct VariantCase<'a> {
|
||||||
attrs: &'a [syn::Attribute],
|
attrs: &'a [syn::Attribute],
|
||||||
ident: &'a syn::Ident,
|
ident: &'a syn::Ident,
|
||||||
@@ -157,6 +205,7 @@ trait Expander {
|
|||||||
fn expand_variant(
|
fn expand_variant(
|
||||||
&self,
|
&self,
|
||||||
input: &DeriveInput,
|
input: &DeriveInput,
|
||||||
|
discriminant_size: DiscriminantSize,
|
||||||
cases: &[VariantCase],
|
cases: &[VariantCase],
|
||||||
style: VariantStyle,
|
style: VariantStyle,
|
||||||
) -> Result<TokenStream>;
|
) -> Result<TokenStream>;
|
||||||
@@ -217,6 +266,13 @@ fn expand_variant(
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let discriminant_size = discriminant_size(body.variants.len()).ok_or_else(|| {
|
||||||
|
Error::new(
|
||||||
|
input.ident.span(),
|
||||||
|
"`enum`s with more than 2^32 variants are not supported",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
let cases = body
|
let cases = body
|
||||||
.variants
|
.variants
|
||||||
.iter()
|
.iter()
|
||||||
@@ -240,8 +296,13 @@ fn expand_variant(
|
|||||||
name.span(),
|
name.span(),
|
||||||
format!(
|
format!(
|
||||||
"`{}` component types can only be derived for Rust `enum`s \
|
"`{}` component types can only be derived for Rust `enum`s \
|
||||||
containing variants with at most one unnamed field each",
|
containing variants with {}",
|
||||||
style
|
style,
|
||||||
|
match style {
|
||||||
|
VariantStyle::Variant => "at most one unnamed field each",
|
||||||
|
VariantStyle::Enum => "no fields",
|
||||||
|
VariantStyle::Union => "exactly one unnamed field each",
|
||||||
|
}
|
||||||
),
|
),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
@@ -251,7 +312,7 @@ fn expand_variant(
|
|||||||
)
|
)
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
expander.expand_variant(input, &cases, style)
|
expander.expand_variant(input, discriminant_size, &cases, style)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[proc_macro_derive(Lift, attributes(component))]
|
#[proc_macro_derive(Lift, attributes(component))]
|
||||||
@@ -321,6 +382,7 @@ impl Expander for LiftExpander {
|
|||||||
fn expand_variant(
|
fn expand_variant(
|
||||||
&self,
|
&self,
|
||||||
input: &DeriveInput,
|
input: &DeriveInput,
|
||||||
|
discriminant_size: DiscriminantSize,
|
||||||
cases: &[VariantCase],
|
cases: &[VariantCase],
|
||||||
_style: VariantStyle,
|
_style: VariantStyle,
|
||||||
) -> Result<TokenStream> {
|
) -> Result<TokenStream> {
|
||||||
@@ -330,31 +392,26 @@ impl Expander for LiftExpander {
|
|||||||
let mut loads = TokenStream::new();
|
let mut loads = TokenStream::new();
|
||||||
|
|
||||||
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
|
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
|
||||||
let index_u8 = u8::try_from(index).map_err(|_| {
|
let index_u32 = u32::try_from(index).unwrap();
|
||||||
Error::new(
|
|
||||||
input.ident.span(),
|
|
||||||
"`enum`s with more than 256 variants not yet supported",
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let index_i32 = index_u8 as i32;
|
let index_quoted = discriminant_size.quote(index);
|
||||||
|
|
||||||
if let Some(ty) = ty {
|
if let Some(ty) = ty {
|
||||||
lifts.extend(
|
lifts.extend(
|
||||||
quote!(#index_i32 => Self::#ident(<#ty as wasmtime::component::Lift>::lift(
|
quote!(#index_u32 => Self::#ident(<#ty as wasmtime::component::Lift>::lift(
|
||||||
store, options, unsafe { &src.payload.#ident }
|
store, options, unsafe { &src.payload.#ident }
|
||||||
)?),),
|
)?),),
|
||||||
);
|
);
|
||||||
|
|
||||||
loads.extend(
|
loads.extend(
|
||||||
quote!(#index_u8 => Self::#ident(<#ty as wasmtime::component::Lift>::load(
|
quote!(#index_quoted => Self::#ident(<#ty as wasmtime::component::Lift>::load(
|
||||||
memory, &payload[..<#ty as wasmtime::component::ComponentType>::SIZE32]
|
memory, &payload[..<#ty as wasmtime::component::ComponentType>::SIZE32]
|
||||||
)?),),
|
)?),),
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
lifts.extend(quote!(#index_i32 => Self::#ident,));
|
lifts.extend(quote!(#index_u32 => Self::#ident,));
|
||||||
|
|
||||||
loads.extend(quote!(#index_u8 => Self::#ident,));
|
loads.extend(quote!(#index_quoted => Self::#ident,));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -362,6 +419,14 @@ impl Expander for LiftExpander {
|
|||||||
let generics = add_trait_bounds(&input.generics, parse_quote!(wasmtime::component::Lift));
|
let generics = add_trait_bounds(&input.generics, parse_quote!(wasmtime::component::Lift));
|
||||||
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
||||||
|
|
||||||
|
let from_bytes = match discriminant_size {
|
||||||
|
DiscriminantSize::Size1 => quote!(bytes[0]),
|
||||||
|
DiscriminantSize::Size2 => quote!(u16::from_le_bytes(bytes[0..2].try_into()?)),
|
||||||
|
DiscriminantSize::Size4 => quote!(u32::from_le_bytes(bytes[0..4].try_into()?)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let payload_offset = u32::from(discriminant_size) as usize;
|
||||||
|
|
||||||
let expanded = quote! {
|
let expanded = quote! {
|
||||||
unsafe impl #impl_generics wasmtime::component::Lift for #name #ty_generics #where_clause {
|
unsafe impl #impl_generics wasmtime::component::Lift for #name #ty_generics #where_clause {
|
||||||
#[inline]
|
#[inline]
|
||||||
@@ -370,7 +435,7 @@ impl Expander for LiftExpander {
|
|||||||
options: &#internal::Options,
|
options: &#internal::Options,
|
||||||
src: &Self::Lower,
|
src: &Self::Lower,
|
||||||
) -> #internal::anyhow::Result<Self> {
|
) -> #internal::anyhow::Result<Self> {
|
||||||
Ok(match src.tag.get_i32() {
|
Ok(match src.tag.get_u32() {
|
||||||
#lifts
|
#lifts
|
||||||
discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),
|
discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),
|
||||||
})
|
})
|
||||||
@@ -380,8 +445,8 @@ impl Expander for LiftExpander {
|
|||||||
fn load(memory: &#internal::Memory, bytes: &[u8]) -> #internal::anyhow::Result<Self> {
|
fn load(memory: &#internal::Memory, bytes: &[u8]) -> #internal::anyhow::Result<Self> {
|
||||||
let align = <Self as wasmtime::component::ComponentType>::ALIGN32;
|
let align = <Self as wasmtime::component::ComponentType>::ALIGN32;
|
||||||
debug_assert!((bytes.as_ptr() as usize) % (align as usize) == 0);
|
debug_assert!((bytes.as_ptr() as usize) % (align as usize) == 0);
|
||||||
let discrim = bytes[0];
|
let discrim = #from_bytes;
|
||||||
let payload = &bytes[#internal::align_to(1, align)..];
|
let payload = &bytes[#internal::align_to(#payload_offset, align)..];
|
||||||
Ok(match discrim {
|
Ok(match discrim {
|
||||||
#loads
|
#loads
|
||||||
discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),
|
discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),
|
||||||
@@ -456,6 +521,7 @@ impl Expander for LowerExpander {
|
|||||||
fn expand_variant(
|
fn expand_variant(
|
||||||
&self,
|
&self,
|
||||||
input: &DeriveInput,
|
input: &DeriveInput,
|
||||||
|
discriminant_size: DiscriminantSize,
|
||||||
cases: &[VariantCase],
|
cases: &[VariantCase],
|
||||||
_style: VariantStyle,
|
_style: VariantStyle,
|
||||||
) -> Result<TokenStream> {
|
) -> Result<TokenStream> {
|
||||||
@@ -465,14 +531,9 @@ impl Expander for LowerExpander {
|
|||||||
let mut stores = TokenStream::new();
|
let mut stores = TokenStream::new();
|
||||||
|
|
||||||
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
|
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
|
||||||
let index_u8 = u8::try_from(index).map_err(|_| {
|
let index_u32 = u32::try_from(index).unwrap();
|
||||||
Error::new(
|
|
||||||
input.ident.span(),
|
|
||||||
"`enum`s with more than 256 variants not yet supported",
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let index_i32 = index_u8 as i32;
|
let index_quoted = discriminant_size.quote(index);
|
||||||
|
|
||||||
let pattern;
|
let pattern;
|
||||||
let lower;
|
let lower;
|
||||||
@@ -492,12 +553,14 @@ impl Expander for LowerExpander {
|
|||||||
}
|
}
|
||||||
|
|
||||||
lowers.extend(quote!(#pattern => {
|
lowers.extend(quote!(#pattern => {
|
||||||
#internal::map_maybe_uninit!(dst.tag).write(wasmtime::ValRaw::i32(#index_i32));
|
#internal::map_maybe_uninit!(dst.tag).write(wasmtime::ValRaw::i32(#index_u32 as i32));
|
||||||
#lower
|
#lower
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
let discriminant_size = u32::from(discriminant_size) as usize;
|
||||||
|
|
||||||
stores.extend(quote!(#pattern => {
|
stores.extend(quote!(#pattern => {
|
||||||
memory.get::<1>(offset)[0] = #index_u8;
|
*memory.get::<#discriminant_size>(offset) = #index_quoted.to_le_bytes();
|
||||||
#store
|
#store
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
@@ -668,6 +731,7 @@ impl Expander for ComponentTypeExpander {
|
|||||||
fn expand_variant(
|
fn expand_variant(
|
||||||
&self,
|
&self,
|
||||||
input: &DeriveInput,
|
input: &DeriveInput,
|
||||||
|
discriminant_size: DiscriminantSize,
|
||||||
cases: &[VariantCase],
|
cases: &[VariantCase],
|
||||||
style: VariantStyle,
|
style: VariantStyle,
|
||||||
) -> Result<TokenStream> {
|
) -> Result<TokenStream> {
|
||||||
@@ -766,6 +830,7 @@ impl Expander for ComponentTypeExpander {
|
|||||||
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
||||||
let lower = format_ident!("Lower{}", name);
|
let lower = format_ident!("Lower{}", name);
|
||||||
let lower_payload = format_ident!("LowerPayload{}", name);
|
let lower_payload = format_ident!("LowerPayload{}", name);
|
||||||
|
let discriminant_size = u32::from(discriminant_size);
|
||||||
|
|
||||||
// You may wonder why we make the types of all the fields of the #lower struct and #lower_payload union
|
// You may wonder why we make the types of all the fields of the #lower struct and #lower_payload union
|
||||||
// generic. This is to work around a [normalization bug in
|
// generic. This is to work around a [normalization bug in
|
||||||
@@ -806,11 +871,11 @@ impl Expander for ComponentTypeExpander {
|
|||||||
const SIZE32: usize = {
|
const SIZE32: usize = {
|
||||||
let mut size = 0;
|
let mut size = 0;
|
||||||
#sizes
|
#sizes
|
||||||
#internal::align_to(1, Self::ALIGN32) + size
|
#internal::align_to(#discriminant_size as usize, Self::ALIGN32) + size
|
||||||
};
|
};
|
||||||
|
|
||||||
const ALIGN32: u32 = {
|
const ALIGN32: u32 = {
|
||||||
let mut align = 1;
|
let mut align = #discriminant_size;
|
||||||
#alignments
|
#alignments
|
||||||
align
|
align
|
||||||
};
|
};
|
||||||
|
|||||||
15
crates/misc/component-macro-test/Cargo.toml
Normal file
15
crates/misc/component-macro-test/Cargo.toml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
[package]
|
||||||
|
name = "component-macro-test"
|
||||||
|
authors = ["The Wasmtime Project Developers"]
|
||||||
|
license = "Apache-2.0 WITH LLVM-exception"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
publish = false
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
proc-macro = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
proc-macro2 = "1.0"
|
||||||
|
quote = "1.0"
|
||||||
|
syn = { version = "1.0", features = ["full"] }
|
||||||
34
crates/misc/component-macro-test/src/lib.rs
Normal file
34
crates/misc/component-macro-test/src/lib.rs
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
use proc_macro2::{Span, TokenStream};
|
||||||
|
use quote::quote;
|
||||||
|
use syn::parse_macro_input;
|
||||||
|
|
||||||
|
#[proc_macro_attribute]
|
||||||
|
pub fn add_variants(
|
||||||
|
attr: proc_macro::TokenStream,
|
||||||
|
item: proc_macro::TokenStream,
|
||||||
|
) -> proc_macro::TokenStream {
|
||||||
|
expand_variants(
|
||||||
|
&parse_macro_input!(attr as syn::LitInt),
|
||||||
|
parse_macro_input!(item as syn::ItemEnum),
|
||||||
|
)
|
||||||
|
.unwrap_or_else(syn::Error::into_compile_error)
|
||||||
|
.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn expand_variants(count: &syn::LitInt, mut ty: syn::ItemEnum) -> syn::Result<TokenStream> {
|
||||||
|
let count = count
|
||||||
|
.base10_digits()
|
||||||
|
.parse::<usize>()
|
||||||
|
.map_err(|_| syn::Error::new(count.span(), "expected unsigned integer"))?;
|
||||||
|
|
||||||
|
ty.variants = (0..count)
|
||||||
|
.map(|index| syn::Variant {
|
||||||
|
attrs: Vec::new(),
|
||||||
|
ident: syn::Ident::new(&format!("V{}", index), Span::call_site()),
|
||||||
|
fields: syn::Fields::Unit,
|
||||||
|
discriminant: None,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(quote!(#ty))
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
use super::TypedFuncExt;
|
use super::TypedFuncExt;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use component_macro_test::add_variants;
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
use wasmtime::component::{Component, ComponentType, Lift, Linker, Lower};
|
use wasmtime::component::{Component, ComponentType, Lift, Linker, Lower};
|
||||||
use wasmtime::Store;
|
use wasmtime::Store;
|
||||||
@@ -475,5 +476,41 @@ fn enum_derive() -> Result<()> {
|
|||||||
.get_typed_func::<(Foo,), Foo, _>(&mut store, "echo")
|
.get_typed_func::<(Foo,), Foo, _>(&mut store, "echo")
|
||||||
.is_err());
|
.is_err());
|
||||||
|
|
||||||
|
#[add_variants(257)]
|
||||||
|
#[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)]
|
||||||
|
#[component(enum)]
|
||||||
|
enum Many {}
|
||||||
|
|
||||||
|
let component = Component::new(
|
||||||
|
&engine,
|
||||||
|
make_echo_component(
|
||||||
|
&format!(
|
||||||
|
r#"(type $Foo (enum {}))"#,
|
||||||
|
(0..257)
|
||||||
|
.map(|index| format!(r#""V{}""#, index))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" ")
|
||||||
|
),
|
||||||
|
4,
|
||||||
|
),
|
||||||
|
)?;
|
||||||
|
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
|
||||||
|
let func = instance.get_typed_func::<(Many,), Many, _>(&mut store, "echo")?;
|
||||||
|
|
||||||
|
for &input in &[Many::V0, Many::V1, Many::V254, Many::V255, Many::V256] {
|
||||||
|
let output = func.call_and_post_return(&mut store, (input,))?;
|
||||||
|
|
||||||
|
assert_eq!(input, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: The following case takes forever (i.e. I gave up after 30 minutes) to compile; we'll need to profile
|
||||||
|
// the compiler to find out why, which may point the way to a more efficient option. On the other hand, this
|
||||||
|
// may not be worth spending time on. Enums with over 2^16 variants are rare enough.
|
||||||
|
|
||||||
|
// #[add_variants(65537)]
|
||||||
|
// #[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)]
|
||||||
|
// #[component(enum)]
|
||||||
|
// enum ManyMore {}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user