* support dynamic function calls in component model This addresses #4310, introducing a new `component::values::Val` type for representing component values dynamically, as well as `component::types::Type` for representing the corresponding interface types. It also adds a `call` method to `component::func::Func`, which takes a slice of `Val`s as parameters and returns a `Result<Val>` representing the result. Note that I've moved `post_return` and `call_raw` from `TypedFunc` to `Func` since there was nothing specific to `TypedFunc` about them, and I wanted to reuse them. The code in both is unchanged beyond the trivial tweaks to make them fit in their new home. Signed-off-by: Joel Dice <joel.dice@fermyon.com> * order variants and match cases more consistently Signed-off-by: Joel Dice <joel.dice@fermyon.com> * implement lift for String, Box<str>, etc. This also removes the redundant `store` parameter from `Type::load`. Signed-off-by: Joel Dice <joel.dice@fermyon.com> * implement code review feedback This fixes a few issues: - Bad offset calculation when lowering - Missing variant padding - Style issues regarding `types::Handle` - Missed opportunities to reuse `Lift` and `Lower` impls It also adds forwarding `Lift` impls for `Box<[T]>`, `Vec<T>`, etc. Signed-off-by: Joel Dice <joel.dice@fermyon.com> * move `new_*` methods to specific `types` structs Per review feedback, I've moved `Type::new_record` to `Record::new_val` and added a `Type::unwrap_record` method; likewise for the other kinds of types. Signed-off-by: Joel Dice <joel.dice@fermyon.com> * make tuple, option, and expected type comparisons recursive These types should compare as equal across component boundaries as long as their type parameters are equal. Signed-off-by: Joel Dice <joel.dice@fermyon.com> * improve error diagnostic in `Type::check` We now distinguish between more failure cases to provide an informative error message. 
Signed-off-by: Joel Dice <joel.dice@fermyon.com> * address review feedback - Remove `WasmStr::to_str_from_memory` and `WasmList::get_from_memory` - add `try_new` methods to various `values` types - avoid using `ExactSizeIterator::len` where we can't trust it - fix over-constrained bounds on forwarded `ComponentType` impls Signed-off-by: Joel Dice <joel.dice@fermyon.com> * rearrange code per review feedback - Move functions from `types` to `values` module so we can make certain struct fields private - Rename `try_new` to just `new` Signed-off-by: Joel Dice <joel.dice@fermyon.com> * remove special-case equality test for tuples, options, and expecteds Instead, I've added a FIXME comment and will open an issue to do recursive structural equality testing. Signed-off-by: Joel Dice <joel.dice@fermyon.com>
1243 lines
40 KiB
Rust
1243 lines
40 KiB
Rust
use proc_macro2::{Literal, TokenStream, TokenTree};
|
|
use quote::{format_ident, quote};
|
|
use std::collections::HashSet;
|
|
use std::fmt;
|
|
use syn::parse::{Parse, ParseStream};
|
|
use syn::punctuated::Punctuated;
|
|
use syn::{braced, parse_macro_input, parse_quote, Data, DeriveInput, Error, Result, Token};
|
|
use wasmtime_component_util::{DiscriminantSize, FlagsSize};
|
|
|
|
/// Which component-model type a Rust `enum` is being derived as.
///
/// All three flavors are backed by a Rust `enum`; they differ in how cases are named and
/// type-checked (`variant` cases are named and may carry a payload, `enum` cases carry no
/// payload, and `union` cases are identified only by their payload type).
#[derive(Debug, Copy, Clone)]
enum VariantStyle {
    /// Selected by `#[component(variant)]`.
    Variant,
    /// Selected by `#[component(enum)]`.
    Enum,
    /// Selected by `#[component(union)]`.
    Union,
}

impl fmt::Display for VariantStyle {
    /// Writes the keyword that selects this style inside `#[component(...)]`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            Self::Variant => "variant",
            Self::Enum => "enum",
            Self::Union => "union",
        };
        f.write_str(keyword)
    }
}
|
|
|
|
#[derive(Debug, Copy, Clone)]
|
|
enum Style {
|
|
Record,
|
|
Variant(VariantStyle),
|
|
}
|
|
|
|
fn find_style(input: &DeriveInput) -> Result<Style> {
|
|
let mut style = None;
|
|
|
|
for attribute in &input.attrs {
|
|
if attribute.path.leading_colon.is_some() || attribute.path.segments.len() != 1 {
|
|
continue;
|
|
}
|
|
|
|
let ident = &attribute.path.segments[0].ident;
|
|
|
|
if "component" != &ident.to_string() {
|
|
continue;
|
|
}
|
|
|
|
let syntax_error = || {
|
|
Err(Error::new_spanned(
|
|
&attribute.tokens,
|
|
"expected `component(<style>)` syntax",
|
|
))
|
|
};
|
|
|
|
let style_string = if let [TokenTree::Group(group)] =
|
|
&attribute.tokens.clone().into_iter().collect::<Vec<_>>()[..]
|
|
{
|
|
if let [TokenTree::Ident(style)] = &group.stream().into_iter().collect::<Vec<_>>()[..] {
|
|
style.to_string()
|
|
} else {
|
|
return syntax_error();
|
|
}
|
|
} else {
|
|
return syntax_error();
|
|
};
|
|
|
|
if style.is_some() {
|
|
return Err(Error::new(ident.span(), "duplicate `component` attribute"));
|
|
}
|
|
|
|
style = Some(match style_string.as_ref() {
|
|
"record" => Style::Record,
|
|
"variant" => Style::Variant(VariantStyle::Variant),
|
|
"enum" => Style::Variant(VariantStyle::Enum),
|
|
"union" => Style::Variant(VariantStyle::Union),
|
|
"flags" => {
|
|
return Err(Error::new_spanned(
|
|
&attribute.tokens,
|
|
"`flags` not allowed here; \
|
|
use `wasmtime::component::flags!` macro to define `flags` types",
|
|
))
|
|
}
|
|
_ => {
|
|
return Err(Error::new_spanned(
|
|
&attribute.tokens,
|
|
"unrecognized component type keyword \
|
|
(expected `record`, `variant`, `enum`, or `union`)",
|
|
))
|
|
}
|
|
});
|
|
}
|
|
|
|
style.ok_or_else(|| Error::new_spanned(input, "missing `component` attribute"))
|
|
}
|
|
|
|
fn find_rename(attributes: &[syn::Attribute]) -> Result<Option<Literal>> {
|
|
let mut name = None;
|
|
|
|
for attribute in attributes {
|
|
if attribute.path.leading_colon.is_some() || attribute.path.segments.len() != 1 {
|
|
continue;
|
|
}
|
|
|
|
let ident = &attribute.path.segments[0].ident;
|
|
|
|
if "component" != &ident.to_string() {
|
|
continue;
|
|
}
|
|
|
|
let syntax_error = || {
|
|
Err(Error::new_spanned(
|
|
&attribute.tokens,
|
|
"expected `component(name = <name literal>)` syntax",
|
|
))
|
|
};
|
|
|
|
let name_literal = if let [TokenTree::Group(group)] =
|
|
&attribute.tokens.clone().into_iter().collect::<Vec<_>>()[..]
|
|
{
|
|
match &group.stream().into_iter().collect::<Vec<_>>()[..] {
|
|
[TokenTree::Ident(key), TokenTree::Punct(op), TokenTree::Literal(literal)]
|
|
if "name" == &key.to_string() && '=' == op.as_char() =>
|
|
{
|
|
literal.clone()
|
|
}
|
|
_ => return syntax_error(),
|
|
}
|
|
} else {
|
|
return syntax_error();
|
|
};
|
|
|
|
if name.is_some() {
|
|
return Err(Error::new(ident.span(), "duplicate field rename attribute"));
|
|
}
|
|
|
|
name = Some(name_literal);
|
|
}
|
|
|
|
Ok(name)
|
|
}
|
|
|
|
/// Returns a copy of `generics` in which every type parameter also carries `bound`.
///
/// Lifetime and const parameters are left unchanged.
fn add_trait_bounds(generics: &syn::Generics, bound: syn::TypeParamBound) -> syn::Generics {
    let mut bounded = generics.clone();
    bounded
        .type_params_mut()
        .for_each(|param| param.bounds.push(bound.clone()));
    bounded
}
|
|
|
|
struct VariantCase<'a> {
|
|
attrs: &'a [syn::Attribute],
|
|
ident: &'a syn::Ident,
|
|
ty: Option<&'a syn::Type>,
|
|
}
|
|
|
|
trait Expander {
|
|
fn expand_record(
|
|
&self,
|
|
name: &syn::Ident,
|
|
generics: &syn::Generics,
|
|
fields: &[&syn::Field],
|
|
) -> Result<TokenStream>;
|
|
|
|
fn expand_variant(
|
|
&self,
|
|
name: &syn::Ident,
|
|
generics: &syn::Generics,
|
|
discriminant_size: DiscriminantSize,
|
|
cases: &[VariantCase],
|
|
style: VariantStyle,
|
|
) -> Result<TokenStream>;
|
|
}
|
|
|
|
fn expand(expander: &dyn Expander, input: &DeriveInput) -> Result<TokenStream> {
|
|
match find_style(input)? {
|
|
Style::Record => expand_record(expander, input),
|
|
Style::Variant(style) => expand_variant(expander, input, style),
|
|
}
|
|
}
|
|
|
|
fn expand_record(expander: &dyn Expander, input: &DeriveInput) -> Result<TokenStream> {
|
|
let name = &input.ident;
|
|
|
|
let body = if let Data::Struct(body) = &input.data {
|
|
body
|
|
} else {
|
|
return Err(Error::new(
|
|
name.span(),
|
|
"`record` component types can only be derived for Rust `struct`s",
|
|
));
|
|
};
|
|
|
|
match &body.fields {
|
|
syn::Fields::Named(fields) => expander.expand_record(
|
|
&input.ident,
|
|
&input.generics,
|
|
&fields.named.iter().collect::<Vec<_>>(),
|
|
),
|
|
|
|
syn::Fields::Unnamed(_) | syn::Fields::Unit => Err(Error::new(
|
|
name.span(),
|
|
"`record` component types can only be derived for `struct`s with named fields",
|
|
)),
|
|
}
|
|
}
|
|
|
|
fn expand_variant(
|
|
expander: &dyn Expander,
|
|
input: &DeriveInput,
|
|
style: VariantStyle,
|
|
) -> Result<TokenStream> {
|
|
let name = &input.ident;
|
|
|
|
let body = if let Data::Enum(body) = &input.data {
|
|
body
|
|
} else {
|
|
return Err(Error::new(
|
|
name.span(),
|
|
format!(
|
|
"`{}` component types can only be derived for Rust `enum`s",
|
|
style
|
|
),
|
|
));
|
|
};
|
|
|
|
if body.variants.is_empty() {
|
|
return Err(Error::new(
|
|
name.span(),
|
|
format!("`{}` component types can only be derived for Rust `enum`s with at least one variant", style),
|
|
));
|
|
}
|
|
|
|
let discriminant_size = DiscriminantSize::from_count(body.variants.len()).ok_or_else(|| {
|
|
Error::new(
|
|
input.ident.span(),
|
|
"`enum`s with more than 2^32 variants are not supported",
|
|
)
|
|
})?;
|
|
|
|
let cases = body
|
|
.variants
|
|
.iter()
|
|
.map(
|
|
|syn::Variant {
|
|
attrs,
|
|
ident,
|
|
fields,
|
|
..
|
|
}| {
|
|
Ok(VariantCase {
|
|
attrs,
|
|
ident,
|
|
ty: match fields {
|
|
syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
|
|
Some(&fields.unnamed[0].ty)
|
|
}
|
|
syn::Fields::Unit => None,
|
|
_ => {
|
|
return Err(Error::new(
|
|
name.span(),
|
|
format!(
|
|
"`{}` component types can only be derived for Rust `enum`s \
|
|
containing variants with {}",
|
|
style,
|
|
match style {
|
|
VariantStyle::Variant => "at most one unnamed field each",
|
|
VariantStyle::Enum => "no fields",
|
|
VariantStyle::Union => "exactly one unnamed field each",
|
|
}
|
|
),
|
|
))
|
|
}
|
|
},
|
|
})
|
|
},
|
|
)
|
|
.collect::<Result<Vec<_>>>()?;
|
|
|
|
expander.expand_variant(
|
|
&input.ident,
|
|
&input.generics,
|
|
discriminant_size,
|
|
&cases,
|
|
style,
|
|
)
|
|
}
|
|
|
|
fn expand_record_for_component_type(
|
|
name: &syn::Ident,
|
|
generics: &syn::Generics,
|
|
fields: &[&syn::Field],
|
|
typecheck: TokenStream,
|
|
typecheck_argument: TokenStream,
|
|
) -> Result<TokenStream> {
|
|
let internal = quote!(wasmtime::component::__internal);
|
|
|
|
let mut lower_generic_params = TokenStream::new();
|
|
let mut lower_generic_args = TokenStream::new();
|
|
let mut lower_field_declarations = TokenStream::new();
|
|
let mut sizes = TokenStream::new();
|
|
let mut unique_types = HashSet::new();
|
|
|
|
for (index, syn::Field { ident, ty, .. }) in fields.iter().enumerate() {
|
|
let generic = format_ident!("T{}", index);
|
|
|
|
lower_generic_params.extend(quote!(#generic: Copy,));
|
|
lower_generic_args.extend(quote!(<#ty as wasmtime::component::ComponentType>::Lower,));
|
|
|
|
lower_field_declarations.extend(quote!(#ident: #generic,));
|
|
|
|
sizes.extend(quote!(
|
|
size = #internal::align_to(size, <#ty as wasmtime::component::ComponentType>::ALIGN32);
|
|
size += <#ty as wasmtime::component::ComponentType>::SIZE32;
|
|
));
|
|
|
|
unique_types.insert(ty);
|
|
}
|
|
|
|
let alignments = unique_types
|
|
.into_iter()
|
|
.map(|ty| {
|
|
let align = quote!(<#ty as wasmtime::component::ComponentType>::ALIGN32);
|
|
quote!(if #align > align {
|
|
align = #align;
|
|
})
|
|
})
|
|
.collect::<TokenStream>();
|
|
|
|
let generics = add_trait_bounds(generics, parse_quote!(wasmtime::component::ComponentType));
|
|
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
|
let lower = format_ident!("Lower{}", name);
|
|
|
|
// You may wonder why we make the types of all the fields of the #lower struct generic. This is to work
|
|
// around the lack of [perfect derive support in
|
|
// rustc](https://smallcultfollowing.com/babysteps//blog/2022/04/12/implied-bounds-and-perfect-derive/#what-is-perfect-derive)
|
|
// as of this writing.
|
|
//
|
|
// If the struct we're deriving a `ComponentType` impl for has any generic parameters, then #lower needs
|
|
// generic parameters too. And if we just copy the parameters and bounds from the impl to #lower, then the
|
|
// `#[derive(Clone, Copy)]` will fail unless the original generics were declared with those bounds, which
|
|
// we don't want to require.
|
|
//
|
|
// Alternatively, we could just pass the `Lower` associated type of each generic type as arguments to
|
|
// #lower, but that would require distinguishing between generic and concrete types when generating
|
|
// #lower_field_declarations, which would require some form of symbol resolution. That doesn't seem worth
|
|
// the trouble.
|
|
|
|
let expanded = quote! {
|
|
#[doc(hidden)]
|
|
#[derive(Clone, Copy)]
|
|
#[repr(C)]
|
|
pub struct #lower <#lower_generic_params> {
|
|
#lower_field_declarations
|
|
}
|
|
|
|
unsafe impl #impl_generics wasmtime::component::ComponentType for #name #ty_generics #where_clause {
|
|
type Lower = #lower <#lower_generic_args>;
|
|
|
|
const SIZE32: usize = {
|
|
let mut size = 0;
|
|
#sizes
|
|
#internal::align_to(size, Self::ALIGN32)
|
|
};
|
|
|
|
const ALIGN32: u32 = {
|
|
let mut align = 1;
|
|
#alignments
|
|
align
|
|
};
|
|
|
|
#[inline]
|
|
fn typecheck(
|
|
ty: &#internal::InterfaceType,
|
|
types: &#internal::ComponentTypes,
|
|
) -> #internal::anyhow::Result<()> {
|
|
#internal::#typecheck(ty, types, &[#typecheck_argument])
|
|
}
|
|
}
|
|
};
|
|
|
|
Ok(quote!(const _: () = { #expanded };))
|
|
}
|
|
|
|
fn quote(size: DiscriminantSize, discriminant: usize) -> TokenStream {
|
|
match size {
|
|
DiscriminantSize::Size1 => {
|
|
let discriminant = u8::try_from(discriminant).unwrap();
|
|
quote!(#discriminant)
|
|
}
|
|
DiscriminantSize::Size2 => {
|
|
let discriminant = u16::try_from(discriminant).unwrap();
|
|
quote!(#discriminant)
|
|
}
|
|
DiscriminantSize::Size4 => {
|
|
let discriminant = u32::try_from(discriminant).unwrap();
|
|
quote!(#discriminant)
|
|
}
|
|
}
|
|
}
|
|
|
|
#[proc_macro_derive(Lift, attributes(component))]
|
|
pub fn lift(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
|
|
expand(&LiftExpander, &parse_macro_input!(input as DeriveInput))
|
|
.unwrap_or_else(Error::into_compile_error)
|
|
.into()
|
|
}
|
|
|
|
struct LiftExpander;
|
|
|
|
impl Expander for LiftExpander {
|
|
fn expand_record(
|
|
&self,
|
|
name: &syn::Ident,
|
|
generics: &syn::Generics,
|
|
fields: &[&syn::Field],
|
|
) -> Result<TokenStream> {
|
|
let internal = quote!(wasmtime::component::__internal);
|
|
|
|
let mut lifts = TokenStream::new();
|
|
let mut loads = TokenStream::new();
|
|
|
|
for syn::Field { ident, ty, .. } in fields {
|
|
lifts.extend(quote!(#ident: <#ty as wasmtime::component::Lift>::lift(
|
|
store, options, &src.#ident
|
|
)?,));
|
|
|
|
loads.extend(quote!(#ident: <#ty as wasmtime::component::Lift>::load(
|
|
memory,
|
|
&bytes
|
|
[#internal::next_field::<#ty>(&mut offset)..]
|
|
[..<#ty as wasmtime::component::ComponentType>::SIZE32]
|
|
)?,));
|
|
}
|
|
|
|
let generics = add_trait_bounds(generics, parse_quote!(wasmtime::component::Lift));
|
|
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
|
|
|
let expanded = quote! {
|
|
unsafe impl #impl_generics wasmtime::component::Lift for #name #ty_generics #where_clause {
|
|
#[inline]
|
|
fn lift(
|
|
store: &#internal::StoreOpaque,
|
|
options: &#internal::Options,
|
|
src: &Self::Lower,
|
|
) -> #internal::anyhow::Result<Self> {
|
|
Ok(Self {
|
|
#lifts
|
|
})
|
|
}
|
|
|
|
#[inline]
|
|
fn load(memory: &#internal::Memory, bytes: &[u8]) -> #internal::anyhow::Result<Self> {
|
|
debug_assert!(
|
|
(bytes.as_ptr() as usize)
|
|
% (<Self as wasmtime::component::ComponentType>::ALIGN32 as usize)
|
|
== 0
|
|
);
|
|
let mut offset = 0;
|
|
Ok(Self {
|
|
#loads
|
|
})
|
|
}
|
|
}
|
|
};
|
|
|
|
Ok(expanded)
|
|
}
|
|
|
|
fn expand_variant(
|
|
&self,
|
|
name: &syn::Ident,
|
|
generics: &syn::Generics,
|
|
discriminant_size: DiscriminantSize,
|
|
cases: &[VariantCase],
|
|
_style: VariantStyle,
|
|
) -> Result<TokenStream> {
|
|
let internal = quote!(wasmtime::component::__internal);
|
|
|
|
let mut lifts = TokenStream::new();
|
|
let mut loads = TokenStream::new();
|
|
|
|
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
|
|
let index_u32 = u32::try_from(index).unwrap();
|
|
|
|
let index_quoted = quote(discriminant_size, index);
|
|
|
|
if let Some(ty) = ty {
|
|
lifts.extend(
|
|
quote!(#index_u32 => Self::#ident(<#ty as wasmtime::component::Lift>::lift(
|
|
store, options, unsafe { &src.payload.#ident }
|
|
)?),),
|
|
);
|
|
|
|
loads.extend(
|
|
quote!(#index_quoted => Self::#ident(<#ty as wasmtime::component::Lift>::load(
|
|
memory, &payload[..<#ty as wasmtime::component::ComponentType>::SIZE32]
|
|
)?),),
|
|
);
|
|
} else {
|
|
lifts.extend(quote!(#index_u32 => Self::#ident,));
|
|
|
|
loads.extend(quote!(#index_quoted => Self::#ident,));
|
|
}
|
|
}
|
|
|
|
let generics = add_trait_bounds(generics, parse_quote!(wasmtime::component::Lift));
|
|
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
|
|
|
let from_bytes = match discriminant_size {
|
|
DiscriminantSize::Size1 => quote!(bytes[0]),
|
|
DiscriminantSize::Size2 => quote!(u16::from_le_bytes(bytes[0..2].try_into()?)),
|
|
DiscriminantSize::Size4 => quote!(u32::from_le_bytes(bytes[0..4].try_into()?)),
|
|
};
|
|
|
|
let payload_offset = usize::from(discriminant_size);
|
|
|
|
let expanded = quote! {
|
|
unsafe impl #impl_generics wasmtime::component::Lift for #name #ty_generics #where_clause {
|
|
#[inline]
|
|
fn lift(
|
|
store: &#internal::StoreOpaque,
|
|
options: &#internal::Options,
|
|
src: &Self::Lower,
|
|
) -> #internal::anyhow::Result<Self> {
|
|
Ok(match src.tag.get_u32() {
|
|
#lifts
|
|
discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),
|
|
})
|
|
}
|
|
|
|
#[inline]
|
|
fn load(memory: &#internal::Memory, bytes: &[u8]) -> #internal::anyhow::Result<Self> {
|
|
let align = <Self as wasmtime::component::ComponentType>::ALIGN32;
|
|
debug_assert!((bytes.as_ptr() as usize) % (align as usize) == 0);
|
|
let discrim = #from_bytes;
|
|
let payload = &bytes[#internal::align_to(#payload_offset, align)..];
|
|
Ok(match discrim {
|
|
#loads
|
|
discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),
|
|
})
|
|
}
|
|
}
|
|
};
|
|
|
|
Ok(expanded)
|
|
}
|
|
}
|
|
|
|
#[proc_macro_derive(Lower, attributes(component))]
|
|
pub fn lower(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
|
|
expand(&LowerExpander, &parse_macro_input!(input as DeriveInput))
|
|
.unwrap_or_else(Error::into_compile_error)
|
|
.into()
|
|
}
|
|
|
|
struct LowerExpander;
|
|
|
|
impl Expander for LowerExpander {
|
|
fn expand_record(
|
|
&self,
|
|
name: &syn::Ident,
|
|
generics: &syn::Generics,
|
|
fields: &[&syn::Field],
|
|
) -> Result<TokenStream> {
|
|
let internal = quote!(wasmtime::component::__internal);
|
|
|
|
let mut lowers = TokenStream::new();
|
|
let mut stores = TokenStream::new();
|
|
|
|
for syn::Field { ident, ty, .. } in fields {
|
|
lowers.extend(quote!(wasmtime::component::Lower::lower(
|
|
&self.#ident, store, options, #internal::map_maybe_uninit!(dst.#ident)
|
|
)?;));
|
|
|
|
stores.extend(quote!(wasmtime::component::Lower::store(
|
|
&self.#ident, memory, #internal::next_field::<#ty>(&mut offset)
|
|
)?;));
|
|
}
|
|
|
|
let generics = add_trait_bounds(generics, parse_quote!(wasmtime::component::Lower));
|
|
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
|
|
|
let expanded = quote! {
|
|
unsafe impl #impl_generics wasmtime::component::Lower for #name #ty_generics #where_clause {
|
|
#[inline]
|
|
fn lower<T>(
|
|
&self,
|
|
store: &mut wasmtime::StoreContextMut<T>,
|
|
options: &#internal::Options,
|
|
dst: &mut std::mem::MaybeUninit<Self::Lower>,
|
|
) -> #internal::anyhow::Result<()> {
|
|
#lowers
|
|
Ok(())
|
|
}
|
|
|
|
#[inline]
|
|
fn store<T>(
|
|
&self,
|
|
memory: &mut #internal::MemoryMut<'_, T>,
|
|
mut offset: usize
|
|
) -> #internal::anyhow::Result<()> {
|
|
debug_assert!(offset % (<Self as wasmtime::component::ComponentType>::ALIGN32 as usize) == 0);
|
|
#stores
|
|
Ok(())
|
|
}
|
|
}
|
|
};
|
|
|
|
Ok(expanded)
|
|
}
|
|
|
|
fn expand_variant(
|
|
&self,
|
|
name: &syn::Ident,
|
|
generics: &syn::Generics,
|
|
discriminant_size: DiscriminantSize,
|
|
cases: &[VariantCase],
|
|
_style: VariantStyle,
|
|
) -> Result<TokenStream> {
|
|
let internal = quote!(wasmtime::component::__internal);
|
|
|
|
let mut lowers = TokenStream::new();
|
|
let mut stores = TokenStream::new();
|
|
|
|
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
|
|
let index_u32 = u32::try_from(index).unwrap();
|
|
|
|
let index_quoted = quote(discriminant_size, index);
|
|
|
|
let discriminant_size = usize::from(discriminant_size);
|
|
|
|
let pattern;
|
|
let lower;
|
|
let store;
|
|
|
|
if ty.is_some() {
|
|
pattern = quote!(Self::#ident(value));
|
|
lower = quote!(value.lower(store, options, #internal::map_maybe_uninit!(dst.payload.#ident)));
|
|
store = quote!(value.store(
|
|
memory,
|
|
offset + #internal::align_to(
|
|
#discriminant_size,
|
|
<Self as wasmtime::component::ComponentType>::ALIGN32
|
|
)
|
|
));
|
|
} else {
|
|
pattern = quote!(Self::#ident);
|
|
lower = quote!(Ok(()));
|
|
store = quote!(Ok(()));
|
|
}
|
|
|
|
lowers.extend(quote!(#pattern => {
|
|
#internal::map_maybe_uninit!(dst.tag).write(wasmtime::ValRaw::i32(#index_u32 as i32));
|
|
#lower
|
|
}));
|
|
|
|
stores.extend(quote!(#pattern => {
|
|
*memory.get::<#discriminant_size>(offset) = #index_quoted.to_le_bytes();
|
|
#store
|
|
}));
|
|
}
|
|
|
|
let generics = add_trait_bounds(generics, parse_quote!(wasmtime::component::Lower));
|
|
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
|
|
|
let expanded = quote! {
|
|
unsafe impl #impl_generics wasmtime::component::Lower for #name #ty_generics #where_clause {
|
|
#[inline]
|
|
fn lower<T>(
|
|
&self,
|
|
store: &mut wasmtime::StoreContextMut<T>,
|
|
options: &#internal::Options,
|
|
dst: &mut std::mem::MaybeUninit<Self::Lower>,
|
|
) -> #internal::anyhow::Result<()> {
|
|
// See comment in <Result<T, E> as Lower>::lower for why we zero out the payload here
|
|
unsafe {
|
|
#internal::map_maybe_uninit!(dst.payload)
|
|
.as_mut_ptr()
|
|
.write_bytes(0u8, 1);
|
|
}
|
|
|
|
match self {
|
|
#lowers
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
fn store<T>(
|
|
&self,
|
|
memory: &mut #internal::MemoryMut<'_, T>,
|
|
mut offset: usize
|
|
) -> #internal::anyhow::Result<()> {
|
|
debug_assert!(offset % (<Self as wasmtime::component::ComponentType>::ALIGN32 as usize) == 0);
|
|
match self {
|
|
#stores
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
Ok(expanded)
|
|
}
|
|
}
|
|
|
|
#[proc_macro_derive(ComponentType, attributes(component))]
|
|
pub fn component_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
|
|
expand(
|
|
&ComponentTypeExpander,
|
|
&parse_macro_input!(input as DeriveInput),
|
|
)
|
|
.unwrap_or_else(Error::into_compile_error)
|
|
.into()
|
|
}
|
|
|
|
struct ComponentTypeExpander;
|
|
|
|
impl Expander for ComponentTypeExpander {
|
|
fn expand_record(
|
|
&self,
|
|
name: &syn::Ident,
|
|
generics: &syn::Generics,
|
|
fields: &[&syn::Field],
|
|
) -> Result<TokenStream> {
|
|
expand_record_for_component_type(
|
|
name,
|
|
generics,
|
|
fields,
|
|
quote!(typecheck_record),
|
|
fields
|
|
.iter()
|
|
.map(
|
|
|syn::Field {
|
|
attrs, ident, ty, ..
|
|
}| {
|
|
let name = find_rename(attrs)?.unwrap_or_else(|| {
|
|
Literal::string(&ident.as_ref().unwrap().to_string())
|
|
});
|
|
|
|
Ok(quote!((#name, <#ty as wasmtime::component::ComponentType>::typecheck),))
|
|
},
|
|
)
|
|
.collect::<Result<_>>()?,
|
|
)
|
|
}
|
|
|
|
fn expand_variant(
|
|
&self,
|
|
name: &syn::Ident,
|
|
generics: &syn::Generics,
|
|
discriminant_size: DiscriminantSize,
|
|
cases: &[VariantCase],
|
|
style: VariantStyle,
|
|
) -> Result<TokenStream> {
|
|
let internal = quote!(wasmtime::component::__internal);
|
|
|
|
let mut case_names_and_checks = TokenStream::new();
|
|
let mut lower_payload_generic_params = TokenStream::new();
|
|
let mut lower_payload_generic_args = TokenStream::new();
|
|
let mut lower_payload_case_declarations = TokenStream::new();
|
|
let mut lower_generic_args = TokenStream::new();
|
|
let mut sizes = TokenStream::new();
|
|
let mut unique_types = HashSet::new();
|
|
|
|
for (index, VariantCase { attrs, ident, ty }) in cases.iter().enumerate() {
|
|
let rename = find_rename(attrs)?;
|
|
|
|
if let (Some(_), VariantStyle::Union) = (&rename, style) {
|
|
return Err(Error::new(
|
|
ident.span(),
|
|
"renaming `union` cases is not permitted; only the type is used",
|
|
));
|
|
}
|
|
|
|
let name = rename.unwrap_or_else(|| Literal::string(&ident.to_string()));
|
|
|
|
if let Some(ty) = ty {
|
|
sizes.extend({
|
|
let size = quote!(<#ty as wasmtime::component::ComponentType>::SIZE32);
|
|
quote!(if #size > size {
|
|
size = #size;
|
|
})
|
|
});
|
|
|
|
case_names_and_checks.extend(match style {
|
|
VariantStyle::Variant => {
|
|
quote!((#name, <#ty as wasmtime::component::ComponentType>::typecheck),)
|
|
}
|
|
VariantStyle::Union => {
|
|
quote!(<#ty as wasmtime::component::ComponentType>::typecheck,)
|
|
}
|
|
VariantStyle::Enum => {
|
|
return Err(Error::new(
|
|
ident.span(),
|
|
"payloads are not permitted for `enum` cases",
|
|
))
|
|
}
|
|
});
|
|
|
|
let generic = format_ident!("T{}", index);
|
|
|
|
lower_payload_generic_params.extend(quote!(#generic: Copy,));
|
|
lower_payload_generic_args.extend(quote!(#generic,));
|
|
lower_payload_case_declarations.extend(quote!(#ident: #generic,));
|
|
lower_generic_args
|
|
.extend(quote!(<#ty as wasmtime::component::ComponentType>::Lower,));
|
|
|
|
unique_types.insert(ty);
|
|
} else {
|
|
case_names_and_checks.extend(match style {
|
|
VariantStyle::Variant => {
|
|
quote!((#name, <() as wasmtime::component::ComponentType>::typecheck),)
|
|
}
|
|
VariantStyle::Union => {
|
|
quote!(<() as wasmtime::component::ComponentType>::typecheck,)
|
|
}
|
|
VariantStyle::Enum => quote!(#name,),
|
|
});
|
|
}
|
|
}
|
|
|
|
if lower_payload_case_declarations.is_empty() {
|
|
lower_payload_case_declarations.extend(quote!(_dummy: ()));
|
|
}
|
|
|
|
let alignments = unique_types
|
|
.into_iter()
|
|
.map(|ty| {
|
|
let align = quote!(<#ty as wasmtime::component::ComponentType>::ALIGN32);
|
|
quote!(if #align > align {
|
|
align = #align;
|
|
})
|
|
})
|
|
.collect::<TokenStream>();
|
|
|
|
let typecheck = match style {
|
|
VariantStyle::Variant => quote!(typecheck_variant),
|
|
VariantStyle::Union => quote!(typecheck_union),
|
|
VariantStyle::Enum => quote!(typecheck_enum),
|
|
};
|
|
|
|
let generics = add_trait_bounds(generics, parse_quote!(wasmtime::component::ComponentType));
|
|
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
|
let lower = format_ident!("Lower{}", name);
|
|
let lower_payload = format_ident!("LowerPayload{}", name);
|
|
let discriminant_size = u32::from(discriminant_size);
|
|
|
|
// You may wonder why we make the types of all the fields of the #lower struct and #lower_payload union
|
|
// generic. This is to work around a [normalization bug in
|
|
// rustc](https://github.com/rust-lang/rust/issues/90903) such that the compiler does not understand that
|
|
// e.g. `<i32 as ComponentType>::Lower` is `Copy` despite the bound specified in `ComponentType`'s
|
|
// definition.
|
|
//
|
|
// See also the comment in `Self::expand_record` above for another reason why we do this.
|
|
|
|
let expanded = quote! {
|
|
#[doc(hidden)]
|
|
#[derive(Clone, Copy)]
|
|
#[repr(C)]
|
|
pub struct #lower<#lower_payload_generic_params> {
|
|
tag: wasmtime::ValRaw,
|
|
payload: #lower_payload<#lower_payload_generic_args>
|
|
}
|
|
|
|
#[doc(hidden)]
|
|
#[allow(non_snake_case)]
|
|
#[derive(Clone, Copy)]
|
|
#[repr(C)]
|
|
union #lower_payload<#lower_payload_generic_params> {
|
|
#lower_payload_case_declarations
|
|
}
|
|
|
|
unsafe impl #impl_generics wasmtime::component::ComponentType for #name #ty_generics #where_clause {
|
|
type Lower = #lower<#lower_generic_args>;
|
|
|
|
#[inline]
|
|
fn typecheck(
|
|
ty: &#internal::InterfaceType,
|
|
types: &#internal::ComponentTypes,
|
|
) -> #internal::anyhow::Result<()> {
|
|
#internal::#typecheck(ty, types, &[#case_names_and_checks])
|
|
}
|
|
|
|
const SIZE32: usize = {
|
|
let mut size = 0;
|
|
#sizes
|
|
#internal::align_to(#discriminant_size as usize, Self::ALIGN32) + size
|
|
};
|
|
|
|
const ALIGN32: u32 = {
|
|
let mut align = #discriminant_size;
|
|
#alignments
|
|
align
|
|
};
|
|
}
|
|
};
|
|
|
|
Ok(quote!(const _: () = { #expanded };))
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct Flag {
|
|
rename: Option<String>,
|
|
name: String,
|
|
}
|
|
|
|
impl Parse for Flag {
|
|
fn parse(input: ParseStream) -> Result<Self> {
|
|
let attributes = syn::Attribute::parse_outer(input)?;
|
|
|
|
let rename = find_rename(&attributes)?
|
|
.map(|literal| {
|
|
let s = literal.to_string();
|
|
|
|
s.strip_prefix('"')
|
|
.and_then(|s| s.strip_suffix('"'))
|
|
.map(|s| s.to_owned())
|
|
.ok_or_else(|| Error::new(literal.span(), "expected string literal"))
|
|
})
|
|
.transpose()?;
|
|
|
|
input.parse::<Token![const]>()?;
|
|
let name = input.parse::<syn::Ident>()?.to_string();
|
|
|
|
Ok(Self { rename, name })
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct Flags {
|
|
name: String,
|
|
flags: Vec<Flag>,
|
|
}
|
|
|
|
impl Parse for Flags {
|
|
fn parse(input: ParseStream) -> Result<Self> {
|
|
let name = input.parse::<syn::Ident>()?.to_string();
|
|
|
|
let content;
|
|
braced!(content in input);
|
|
|
|
let flags = content
|
|
.parse_terminated::<_, Token![;]>(Flag::parse)?
|
|
.into_iter()
|
|
.collect();
|
|
|
|
Ok(Self { name, flags })
|
|
}
|
|
}
|
|
|
|
#[proc_macro]
|
|
pub fn flags(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
|
|
expand_flags(&parse_macro_input!(input as Flags))
|
|
.unwrap_or_else(Error::into_compile_error)
|
|
.into()
|
|
}
|
|
|
|
fn expand_flags(flags: &Flags) -> Result<TokenStream> {
|
|
let size = FlagsSize::from_count(flags.flags.len());
|
|
|
|
let ty;
|
|
let eq;
|
|
|
|
let count = flags.flags.len();
|
|
|
|
match size {
|
|
FlagsSize::Size1 => {
|
|
ty = quote!(u8);
|
|
|
|
eq = if count == 8 {
|
|
quote!(self.__inner0.eq(&rhs.__inner0))
|
|
} else {
|
|
let mask = !(0xFF_u8 << count);
|
|
|
|
quote!((self.__inner0 & #mask).eq(&(rhs.__inner0 & #mask)))
|
|
};
|
|
}
|
|
FlagsSize::Size2 => {
|
|
ty = quote!(u16);
|
|
|
|
eq = if count == 16 {
|
|
quote!(self.__inner0.eq(&rhs.__inner0))
|
|
} else {
|
|
let mask = !(0xFFFF_u16 << count);
|
|
|
|
quote!((self.__inner0 & #mask).eq(&(rhs.__inner0 & #mask)))
|
|
};
|
|
}
|
|
FlagsSize::Size4Plus(n) => {
|
|
ty = quote!(u32);
|
|
|
|
let comparisons = (0..(n - 1))
|
|
.map(|index| {
|
|
let field = format_ident!("__inner{}", index);
|
|
|
|
quote!(self.#field.eq(&rhs.#field) &&)
|
|
})
|
|
.collect::<TokenStream>();
|
|
|
|
let field = format_ident!("__inner{}", n - 1);
|
|
|
|
eq = if count % 32 == 0 {
|
|
quote!(#comparisons self.#field.eq(&rhs.#field))
|
|
} else {
|
|
let mask = !(0xFFFF_FFFF_u32 << (count % 32));
|
|
|
|
quote!(#comparisons (self.#field & #mask).eq(&(rhs.#field & #mask)))
|
|
}
|
|
}
|
|
}
|
|
|
|
let count;
|
|
let mut as_array;
|
|
let mut bitor;
|
|
let mut bitor_assign;
|
|
let mut bitand;
|
|
let mut bitand_assign;
|
|
let mut bitxor;
|
|
let mut bitxor_assign;
|
|
let mut not;
|
|
|
|
match size {
|
|
FlagsSize::Size1 | FlagsSize::Size2 => {
|
|
count = 1;
|
|
as_array = quote!([self.__inner0 as u32]);
|
|
bitor = quote!(Self {
|
|
__inner0: self.__inner0.bitor(rhs.__inner0)
|
|
});
|
|
bitor_assign = quote!(self.__inner0.bitor_assign(rhs.__inner0));
|
|
bitand = quote!(Self {
|
|
__inner0: self.__inner0.bitand(rhs.__inner0)
|
|
});
|
|
bitand_assign = quote!(self.__inner0.bitand_assign(rhs.__inner0));
|
|
bitxor = quote!(Self {
|
|
__inner0: self.__inner0.bitxor(rhs.__inner0)
|
|
});
|
|
bitxor_assign = quote!(self.__inner0.bitxor_assign(rhs.__inner0));
|
|
not = quote!(Self {
|
|
__inner0: self.__inner0.not()
|
|
});
|
|
}
|
|
FlagsSize::Size4Plus(n) => {
|
|
count = n;
|
|
as_array = TokenStream::new();
|
|
bitor = TokenStream::new();
|
|
bitor_assign = TokenStream::new();
|
|
bitand = TokenStream::new();
|
|
bitand_assign = TokenStream::new();
|
|
bitxor = TokenStream::new();
|
|
bitxor_assign = TokenStream::new();
|
|
not = TokenStream::new();
|
|
|
|
for index in 0..n {
|
|
let field = format_ident!("__inner{}", index);
|
|
|
|
as_array.extend(quote!(self.#field,));
|
|
bitor.extend(quote!(#field: self.#field.bitor(rhs.#field),));
|
|
bitor_assign.extend(quote!(self.#field.bitor_assign(rhs.#field);));
|
|
bitand.extend(quote!(#field: self.#field.bitand(rhs.#field),));
|
|
bitand_assign.extend(quote!(self.#field.bitand_assign(rhs.#field);));
|
|
bitxor.extend(quote!(#field: self.#field.bitxor(rhs.#field),));
|
|
bitxor_assign.extend(quote!(self.#field.bitxor_assign(rhs.#field);));
|
|
not.extend(quote!(#field: self.#field.not(),));
|
|
}
|
|
|
|
as_array = quote!([#as_array]);
|
|
bitor = quote!(Self { #bitor });
|
|
bitand = quote!(Self { #bitand });
|
|
bitxor = quote!(Self { #bitxor });
|
|
not = quote!(Self { #not });
|
|
}
|
|
};
|
|
|
|
let name = format_ident!("{}", flags.name);
|
|
|
|
let mut constants = TokenStream::new();
|
|
let mut rust_names = TokenStream::new();
|
|
let mut component_names = TokenStream::new();
|
|
|
|
for (index, Flag { name, rename }) in flags.flags.iter().enumerate() {
|
|
rust_names.extend(quote!(#name,));
|
|
|
|
let component_name = rename.as_ref().unwrap_or(name);
|
|
component_names.extend(quote!(#component_name,));
|
|
|
|
let fields = match size {
|
|
FlagsSize::Size1 => {
|
|
let init = 1_u8 << index;
|
|
quote!(__inner0: #init)
|
|
}
|
|
FlagsSize::Size2 => {
|
|
let init = 1_u16 << index;
|
|
quote!(__inner0: #init)
|
|
}
|
|
FlagsSize::Size4Plus(n) => (0..n)
|
|
.map(|i| {
|
|
let field = format_ident!("__inner{}", i);
|
|
|
|
let init = if index / 32 == i {
|
|
1_u32 << (index % 32)
|
|
} else {
|
|
0
|
|
};
|
|
|
|
quote!(#field: #init,)
|
|
})
|
|
.collect::<TokenStream>(),
|
|
};
|
|
|
|
let name = format_ident!("{}", name);
|
|
|
|
constants.extend(quote!(const #name: Self = Self { #fields };));
|
|
}
|
|
|
|
let generics = syn::Generics {
|
|
lt_token: None,
|
|
params: Punctuated::new(),
|
|
gt_token: None,
|
|
where_clause: None,
|
|
};
|
|
|
|
let fields = {
|
|
let ty = syn::parse2::<syn::Type>(ty.clone())?;
|
|
|
|
(0..count)
|
|
.map(|index| syn::Field {
|
|
attrs: Vec::new(),
|
|
vis: syn::Visibility::Inherited,
|
|
ident: Some(format_ident!("__inner{}", index)),
|
|
colon_token: None,
|
|
ty: ty.clone(),
|
|
})
|
|
.collect::<Vec<_>>()
|
|
};
|
|
|
|
let fields = fields.iter().collect::<Vec<_>>();
|
|
|
|
let component_type_impl = expand_record_for_component_type(
|
|
&name,
|
|
&generics,
|
|
&fields,
|
|
quote!(typecheck_flags),
|
|
component_names,
|
|
)?;
|
|
|
|
let lower_impl = LowerExpander.expand_record(&name, &generics, &fields)?;
|
|
|
|
let lift_impl = LiftExpander.expand_record(&name, &generics, &fields)?;
|
|
|
|
let internal = quote!(wasmtime::component::__internal);
|
|
|
|
let fields = fields
|
|
.iter()
|
|
.map(|syn::Field { ident, .. }| quote!(#[doc(hidden)] #ident: #ty,))
|
|
.collect::<TokenStream>();
|
|
|
|
let expanded = quote! {
|
|
#[derive(Copy, Clone, Default)]
|
|
struct #name { #fields }
|
|
|
|
impl #name {
|
|
#constants
|
|
|
|
fn as_array(&self) -> [u32; #count] {
|
|
#as_array
|
|
}
|
|
}
|
|
|
|
impl std::cmp::PartialEq for #name {
|
|
fn eq(&self, rhs: &#name) -> bool {
|
|
#eq
|
|
}
|
|
}
|
|
|
|
impl std::cmp::Eq for #name { }
|
|
|
|
impl std::fmt::Debug for #name {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
#internal::format_flags(&self.as_array(), &[#rust_names], f)
|
|
}
|
|
}
|
|
|
|
impl std::ops::BitOr for #name {
|
|
type Output = #name;
|
|
|
|
fn bitor(self, rhs: #name) -> #name {
|
|
#bitor
|
|
}
|
|
}
|
|
|
|
impl std::ops::BitOrAssign for #name {
|
|
fn bitor_assign(&mut self, rhs: #name) {
|
|
#bitor_assign
|
|
}
|
|
}
|
|
|
|
impl std::ops::BitAnd for #name {
|
|
type Output = #name;
|
|
|
|
fn bitand(self, rhs: #name) -> #name {
|
|
#bitand
|
|
}
|
|
}
|
|
|
|
impl std::ops::BitAndAssign for #name {
|
|
fn bitand_assign(&mut self, rhs: #name) {
|
|
#bitand_assign
|
|
}
|
|
}
|
|
|
|
impl std::ops::BitXor for #name {
|
|
type Output = #name;
|
|
|
|
fn bitxor(self, rhs: #name) -> #name {
|
|
#bitxor
|
|
}
|
|
}
|
|
|
|
impl std::ops::BitXorAssign for #name {
|
|
fn bitxor_assign(&mut self, rhs: #name) {
|
|
#bitxor_assign
|
|
}
|
|
}
|
|
|
|
impl std::ops::Not for #name {
|
|
type Output = #name;
|
|
|
|
fn not(self) -> #name {
|
|
#not
|
|
}
|
|
}
|
|
|
|
#component_type_impl
|
|
|
|
#lower_impl
|
|
|
|
#lift_impl
|
|
};
|
|
|
|
Ok(expanded)
|
|
}
|