support dynamic function calls in component model (#4442)

* support dynamic function calls in component model

This addresses #4310, introducing a new `component::values::Val` type for
representing component values dynamically, as well as `component::types::Type`
for representing the corresponding interface types. It also adds a `call` method
to `component::func::Func`, which takes a slice of `Val`s as parameters and
returns a `Result<Val>` representing the result.

Note that I've moved `post_return` and `call_raw` from `TypedFunc` to `Func`
since there was nothing specific to `TypedFunc` about them, and I wanted to
reuse them.  The code in both is unchanged beyond the trivial tweaks to make
them fit in their new home.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* order variants and match cases more consistently

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* implement lift for String, Box<str>, etc.

This also removes the redundant `store` parameter from `Type::load`.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* implement code review feedback

This fixes a few issues:

- Bad offset calculation when lowering
- Missing variant padding
- Style issues regarding `types::Handle`
- Missed opportunities to reuse `Lift` and `Lower` impls

It also adds forwarding `Lift` impls for `Box<[T]>`, `Vec<T>`, etc.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* move `new_*` methods to specific `types` structs

Per review feedback, I've moved `Type::new_record` to `Record::new_val` and
added a `Type::unwrap_record` method; likewise for the other kinds of types.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* make tuple, option, and expected type comparisons recursive

These types should compare as equal across component boundaries as long as their
type parameters are equal.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* improve error diagnostic in `Type::check`

We now distinguish between more failure cases to provide an informative error
message.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* address review feedback

- Remove `WasmStr::to_str_from_memory` and `WasmList::get_from_memory`
- add `try_new` methods to various `values` types
- avoid using `ExactSizeIterator::len` where we can't trust it
- fix over-constrained bounds on forwarded `ComponentType` impls

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* rearrange code per review feedback

- Move functions from `types` to `values` module so we can make certain struct fields private
- Rename `try_new` to just `new`

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* remove special-case equality test for tuples, options, and expecteds

Instead, I've added a FIXME comment and will open an issue to do recursive
structural equality testing.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>
This commit is contained in:
Joel Dice
2022-07-25 12:38:48 -06:00
committed by GitHub
parent ee7e4f4c6b
commit 7c67e620c4
16 changed files with 2796 additions and 453 deletions

6
Cargo.lock generated
View File

@@ -3333,6 +3333,7 @@ dependencies = [
"wasmparser",
"wasmtime-cache",
"wasmtime-component-macro",
"wasmtime-component-util",
"wasmtime-cranelift",
"wasmtime-environ",
"wasmtime-fiber",
@@ -3472,8 +3473,13 @@ dependencies = [
"proc-macro2",
"quote",
"syn",
"wasmtime-component-util",
]
[[package]]
name = "wasmtime-component-util"
version = "0.40.0"
[[package]]
name = "wasmtime-cranelift"
version = "0.40.0"

View File

@@ -17,6 +17,7 @@ proc-macro = true
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "1.0", features = ["extra-traits"] }
wasmtime-component-util = { path = "../component-util", version = "=0.40.0" }
[badges]
maintenance = { status = "actively-developed" }

View File

@@ -5,6 +5,7 @@ use std::fmt;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{braced, parse_macro_input, parse_quote, Data, DeriveInput, Error, Result, Token};
use wasmtime_component_util::{DiscriminantSize, FlagsSize};
#[derive(Debug, Copy, Clone)]
enum VariantStyle {
@@ -147,64 +148,6 @@ fn add_trait_bounds(generics: &syn::Generics, bound: syn::TypeParamBound) -> syn
generics
}
#[derive(Debug, Copy, Clone)]
enum DiscriminantSize {
Size1,
Size2,
Size4,
}
impl DiscriminantSize {
fn quote(self, discriminant: usize) -> TokenStream {
match self {
Self::Size1 => {
let discriminant = u8::try_from(discriminant).unwrap();
quote!(#discriminant)
}
Self::Size2 => {
let discriminant = u16::try_from(discriminant).unwrap();
quote!(#discriminant)
}
Self::Size4 => {
let discriminant = u32::try_from(discriminant).unwrap();
quote!(#discriminant)
}
}
}
}
impl From<DiscriminantSize> for u32 {
fn from(size: DiscriminantSize) -> u32 {
match size {
DiscriminantSize::Size1 => 1,
DiscriminantSize::Size2 => 2,
DiscriminantSize::Size4 => 4,
}
}
}
impl From<DiscriminantSize> for usize {
fn from(size: DiscriminantSize) -> usize {
match size {
DiscriminantSize::Size1 => 1,
DiscriminantSize::Size2 => 2,
DiscriminantSize::Size4 => 4,
}
}
}
fn discriminant_size(case_count: usize) -> Option<DiscriminantSize> {
if case_count <= 0xFF {
Some(DiscriminantSize::Size1)
} else if case_count <= 0xFFFF {
Some(DiscriminantSize::Size2)
} else if case_count <= 0xFFFF_FFFF {
Some(DiscriminantSize::Size4)
} else {
None
}
}
struct VariantCase<'a> {
attrs: &'a [syn::Attribute],
ident: &'a syn::Ident,
@@ -288,7 +231,7 @@ fn expand_variant(
));
}
let discriminant_size = discriminant_size(body.variants.len()).ok_or_else(|| {
let discriminant_size = DiscriminantSize::from_count(body.variants.len()).ok_or_else(|| {
Error::new(
input.ident.span(),
"`enum`s with more than 2^32 variants are not supported",
@@ -417,7 +360,7 @@ fn expand_record_for_component_type(
const SIZE32: usize = {
let mut size = 0;
#sizes
size
#internal::align_to(size, Self::ALIGN32)
};
const ALIGN32: u32 = {
@@ -439,6 +382,23 @@ fn expand_record_for_component_type(
Ok(quote!(const _: () = { #expanded };))
}
fn quote(size: DiscriminantSize, discriminant: usize) -> TokenStream {
match size {
DiscriminantSize::Size1 => {
let discriminant = u8::try_from(discriminant).unwrap();
quote!(#discriminant)
}
DiscriminantSize::Size2 => {
let discriminant = u16::try_from(discriminant).unwrap();
quote!(#discriminant)
}
DiscriminantSize::Size4 => {
let discriminant = u32::try_from(discriminant).unwrap();
quote!(#discriminant)
}
}
}
#[proc_macro_derive(Lift, attributes(component))]
pub fn lift(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
expand(&LiftExpander, &parse_macro_input!(input as DeriveInput))
@@ -523,7 +483,7 @@ impl Expander for LiftExpander {
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
let index_u32 = u32::try_from(index).unwrap();
let index_quoted = discriminant_size.quote(index);
let index_quoted = quote(discriminant_size, index);
if let Some(ty) = ty {
lifts.extend(
@@ -666,7 +626,7 @@ impl Expander for LowerExpander {
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
let index_u32 = u32::try_from(index).unwrap();
let index_quoted = discriminant_size.quote(index);
let index_quoted = quote(discriminant_size, index);
let discriminant_size = usize::from(discriminant_size);
@@ -989,19 +949,6 @@ impl Parse for Flags {
}
}
enum FlagsSize {
/// Flags can fit in a u8
Size1,
/// Flags can fit in a u16
Size2,
/// Flags can fit in a specified number of u32 fields
Size4Plus(usize),
}
fn ceiling_divide(n: usize, d: usize) -> usize {
(n + d - 1) / d
}
#[proc_macro]
pub fn flags(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
expand_flags(&parse_macro_input!(input as Flags))
@@ -1010,13 +957,7 @@ pub fn flags(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
}
fn expand_flags(flags: &Flags) -> Result<TokenStream> {
let size = if flags.flags.len() <= 8 {
FlagsSize::Size1
} else if flags.flags.len() <= 16 {
FlagsSize::Size2
} else {
FlagsSize::Size4Plus(ceiling_divide(flags.flags.len(), 32))
};
let size = FlagsSize::from_count(flags.flags.len());
let ty;
let eq;

View File

@@ -0,0 +1,11 @@
[package]
name = "wasmtime-component-util"
version = "0.40.0"
authors = ["The Wasmtime Project Developers"]
description = "Utility types and functions to support the component model in Wasmtime"
license = "Apache-2.0 WITH LLVM-exception"
repository = "https://github.com/bytecodealliance/wasmtime"
documentation = "https://docs.rs/wasmtime-component-util/"
categories = ["wasm"]
keywords = ["webassembly", "wasm"]
edition = "2021"

View File

@@ -0,0 +1,75 @@
/// Represents the possible sizes in bytes of the discriminant of a variant type in the component model
#[derive(Debug, Copy, Clone)]
pub enum DiscriminantSize {
/// 8-bit discriminant
Size1,
/// 16-bit discriminant
Size2,
/// 32-bit discriminant
Size4,
}
impl DiscriminantSize {
/// Calculate the size of discriminant needed to represent a variant with the specified number of cases.
pub fn from_count(count: usize) -> Option<Self> {
if count <= 0xFF {
Some(Self::Size1)
} else if count <= 0xFFFF {
Some(Self::Size2)
} else if count <= 0xFFFF_FFFF {
Some(Self::Size4)
} else {
None
}
}
}
impl From<DiscriminantSize> for u32 {
/// Size of the discriminant as a `u32`
fn from(size: DiscriminantSize) -> u32 {
match size {
DiscriminantSize::Size1 => 1,
DiscriminantSize::Size2 => 2,
DiscriminantSize::Size4 => 4,
}
}
}
impl From<DiscriminantSize> for usize {
/// Size of the discriminant as a `usize`
fn from(size: DiscriminantSize) -> usize {
match size {
DiscriminantSize::Size1 => 1,
DiscriminantSize::Size2 => 2,
DiscriminantSize::Size4 => 4,
}
}
}
/// Represents the number of bytes required to store a flags value in the component model
pub enum FlagsSize {
/// Flags can fit in a u8
Size1,
/// Flags can fit in a u16
Size2,
/// Flags can fit in a specified number of u32 fields
Size4Plus(usize),
}
impl FlagsSize {
/// Calculate the size needed to represent a value with the specified number of flags.
pub fn from_count(count: usize) -> FlagsSize {
if count <= 8 {
FlagsSize::Size1
} else if count <= 16 {
FlagsSize::Size2
} else {
FlagsSize::Size4Plus(ceiling_divide(count, 32))
}
}
}
/// Divide `n` by `d`, rounding up in the case of a non-zero remainder.
fn ceiling_divide(n: usize, d: usize) -> usize {
(n + d - 1) / d
}

View File

@@ -20,6 +20,7 @@ wasmtime-cache = { path = "../cache", version = "=0.40.0", optional = true }
wasmtime-fiber = { path = "../fiber", version = "=0.40.0", optional = true }
wasmtime-cranelift = { path = "../cranelift", version = "=0.40.0", optional = true }
wasmtime-component-macro = { path = "../component-macro", version = "=0.40.0", optional = true }
wasmtime-component-util = { path = "../component-util", version = "=0.40.0", optional = true }
target-lexicon = { version = "0.12.0", default-features = false }
wasmparser = "0.87.0"
anyhow = "1.0.19"
@@ -115,4 +116,5 @@ component-model = [
"wasmtime-cranelift?/component-model",
"wasmtime-runtime/component-model",
"dep:wasmtime-component-macro",
"dep:wasmtime-component-util",
]

View File

@@ -1,8 +1,10 @@
use crate::component::instance::{Instance, InstanceData};
use crate::component::types::{SizeAndAlignment, Type};
use crate::component::values::Val;
use crate::store::{StoreOpaque, Stored};
use crate::{AsContext, ValRaw};
use anyhow::{Context, Result};
use std::mem::MaybeUninit;
use crate::{AsContext, AsContextMut, StoreContextMut, ValRaw};
use anyhow::{bail, Context, Result};
use std::mem::{self, MaybeUninit};
use std::ptr::NonNull;
use std::sync::Arc;
use wasmtime_environ::component::{
@@ -72,6 +74,12 @@ pub use self::host::*;
pub use self::options::*;
pub use self::typed::*;
#[repr(C)]
union ParamsAndResults<Params: Copy, Return: Copy> {
params: Params,
ret: Return,
}
/// A WebAssembly component function.
//
// FIXME: write more docs here
@@ -241,4 +249,346 @@ impl Func {
Ok(())
}
/// Get the parameter types for this function.
pub fn params(&self, store: impl AsContext) -> Box<[Type]> {
let data = &store.as_context()[self.0];
data.types[data.ty]
.params
.iter()
.map(|(_, ty)| Type::from(ty, &data.types))
.collect()
}
/// Invokes this function with the `params` given and returns the result.
///
/// The `params` here must match the type signature of this `Func`, or this will return an error. If a trap
/// occurs while executing this function, then an error will also be returned.
// TODO: say more -- most of the docs for `TypedFunc::call` apply here, too
pub fn call(&self, mut store: impl AsContextMut, args: &[Val]) -> Result<Val> {
let store = &mut store.as_context_mut();
let params;
let result;
{
let data = &store[self.0];
let ty = &data.types[data.ty];
if ty.params.len() != args.len() {
bail!(
"expected {} argument(s), got {}",
ty.params.len(),
args.len()
);
}
params = ty
.params
.iter()
.zip(args)
.map(|((_, ty), arg)| {
let ty = Type::from(ty, &data.types);
ty.check(arg).context("type mismatch with parameters")?;
Ok(ty)
})
.collect::<Result<Vec<_>>>()?;
result = Type::from(&ty.result, &data.types);
}
let param_count = params.iter().map(|ty| ty.flatten_count()).sum::<usize>();
let result_count = result.flatten_count();
self.call_raw(
store,
args,
|store, options, args, dst: &mut MaybeUninit<[ValRaw; MAX_STACK_PARAMS]>| {
if param_count > MAX_STACK_PARAMS {
self.store_args(store, &options, &params, args, dst)
} else {
dst.write([ValRaw::u64(0); MAX_STACK_PARAMS]);
let dst = unsafe {
mem::transmute::<_, &mut [MaybeUninit<ValRaw>; MAX_STACK_PARAMS]>(dst)
};
args.iter()
.try_for_each(|arg| arg.lower(store, &options, &mut dst.iter_mut()))
}
},
|store, options, src: &[ValRaw; MAX_STACK_RESULTS]| {
if result_count > MAX_STACK_RESULTS {
Self::load_result(&Memory::new(store, &options), &result, &mut src.iter())
} else {
Val::lift(&result, store, &options, &mut src.iter())
}
},
)
}
/// Invokes the underlying wasm function, lowering arguments and lifting the
/// result.
///
/// The `lower` function and `lift` function provided here are what actually
/// do the lowering and lifting. The `LowerParams` and `LowerReturn` types
/// are what will be allocated on the stack for this function call. They
/// should be appropriately sized for the lowering/lifting operation
/// happening.
fn call_raw<T, Params: ?Sized, Return, LowerParams, LowerReturn>(
&self,
store: &mut StoreContextMut<'_, T>,
params: &Params,
lower: impl FnOnce(
&mut StoreContextMut<'_, T>,
&Options,
&Params,
&mut MaybeUninit<LowerParams>,
) -> Result<()>,
lift: impl FnOnce(&StoreOpaque, &Options, &LowerReturn) -> Result<Return>,
) -> Result<Return>
where
LowerParams: Copy,
LowerReturn: Copy,
{
let FuncData {
trampoline,
export,
options,
instance,
component_instance,
..
} = store.0[self.0];
let space = &mut MaybeUninit::<ParamsAndResults<LowerParams, LowerReturn>>::uninit();
// Double-check the size/alignemnt of `space`, just in case.
//
// Note that this alone is not enough to guarantee the validity of the
// `unsafe` block below, but it's definitely required. In any case LLVM
// should be able to trivially see through these assertions and remove
// them in release mode.
let val_size = mem::size_of::<ValRaw>();
let val_align = mem::align_of::<ValRaw>();
assert!(mem::size_of_val(space) % val_size == 0);
assert!(mem::size_of_val(map_maybe_uninit!(space.params)) % val_size == 0);
assert!(mem::size_of_val(map_maybe_uninit!(space.ret)) % val_size == 0);
assert!(mem::align_of_val(space) == val_align);
assert!(mem::align_of_val(map_maybe_uninit!(space.params)) == val_align);
assert!(mem::align_of_val(map_maybe_uninit!(space.ret)) == val_align);
let instance = store.0[instance.0].as_ref().unwrap().instance();
let flags = instance.flags(component_instance);
unsafe {
// Test the "may enter" flag which is a "lock" on this instance.
// This is immediately set to `false` afterwards and note that
// there's no on-cleanup setting this flag back to true. That's an
// intentional design aspect where if anything goes wrong internally
// from this point on the instance is considered "poisoned" and can
// never be entered again. The only time this flag is set to `true`
// again is after post-return logic has completed successfully.
if !(*flags).may_enter() {
bail!("cannot reenter component instance");
}
(*flags).set_may_enter(false);
debug_assert!((*flags).may_leave());
(*flags).set_may_leave(false);
let result = lower(store, &options, params, map_maybe_uninit!(space.params));
(*flags).set_may_leave(true);
result?;
// This is unsafe as we are providing the guarantee that all the
// inputs are valid. The various pointers passed in for the function
// are all valid since they're coming from our store, and the
// `params_and_results` should have the correct layout for the core
// wasm function we're calling. Note that this latter point relies
// on the correctness of this module and `ComponentType`
// implementations, hence `ComponentType` being an `unsafe` trait.
crate::Func::call_unchecked_raw(
store,
export.anyfunc,
trampoline,
space.as_mut_ptr().cast(),
)?;
// Note that `.assume_init_ref()` here is unsafe but we're relying
// on the correctness of the structure of `LowerReturn` and the
// type-checking performed to acquire the `TypedFunc` to make this
// safe. It should be the case that `LowerReturn` is the exact
// representation of the return value when interpreted as
// `[ValRaw]`, and additionally they should have the correct types
// for the function we just called (which filled in the return
// values).
let ret = map_maybe_uninit!(space.ret).assume_init_ref();
// Lift the result into the host while managing post-return state
// here as well.
//
// After a successful lift the return value of the function, which
// is currently required to be 0 or 1 values according to the
// canonical ABI, is saved within the `Store`'s `FuncData`. This'll
// later get used in post-return.
(*flags).set_needs_post_return(true);
let val = lift(store.0, &options, ret)?;
let ret_slice = cast_storage(ret);
let data = &mut store.0[self.0];
assert!(data.post_return_arg.is_none());
match ret_slice.len() {
0 => data.post_return_arg = Some(ValRaw::i32(0)),
1 => data.post_return_arg = Some(ret_slice[0]),
_ => unreachable!(),
}
return Ok(val);
}
unsafe fn cast_storage<T>(storage: &T) -> &[ValRaw] {
assert!(std::mem::size_of_val(storage) % std::mem::size_of::<ValRaw>() == 0);
assert!(std::mem::align_of_val(storage) == std::mem::align_of::<ValRaw>());
std::slice::from_raw_parts(
(storage as *const T).cast(),
mem::size_of_val(storage) / mem::size_of::<ValRaw>(),
)
}
}
/// Invokes the `post-return` canonical ABI option, if specified, after a
/// [`Func::call`] has finished.
///
/// For some more information on when to use this function see the
/// documentation for post-return in the [`Func::call`] method.
/// Otherwise though this function is a required method call after a
/// [`Func::call`] completes successfully. After the embedder has
/// finished processing the return value then this function must be invoked.
///
/// # Errors
///
/// This function will return an error in the case of a WebAssembly trap
/// happening during the execution of the `post-return` function, if
/// specified.
///
/// # Panics
///
/// This function will panic if it's not called under the correct
/// conditions. This can only be called after a previous invocation of
/// [`Func::call`] completes successfully, and this function can only
/// be called for the same [`Func`] that was `call`'d.
///
/// If this function is called when [`Func::call`] was not previously
/// called, then it will panic. If a different [`Func`] for the same
/// component instance was invoked then this function will also panic
/// because the `post-return` needs to happen for the other function.
pub fn post_return(&self, mut store: impl AsContextMut) -> Result<()> {
let mut store = store.as_context_mut();
let data = &mut store.0[self.0];
let instance = data.instance;
let post_return = data.post_return;
let component_instance = data.component_instance;
let post_return_arg = data.post_return_arg.take();
let instance = store.0[instance.0].as_ref().unwrap().instance();
let flags = instance.flags(component_instance);
unsafe {
// First assert that the instance is in a "needs post return" state.
// This will ensure that the previous action on the instance was a
// function call above. This flag is only set after a component
// function returns so this also can't be called (as expected)
// during a host import for example.
//
// Note, though, that this assert is not sufficient because it just
// means some function on this instance needs its post-return
// called. We need a precise post-return for a particular function
// which is the second assert here (the `.expect`). That will assert
// that this function itself needs to have its post-return called.
//
// The theory at least is that these two asserts ensure component
// model semantics are upheld where the host properly calls
// `post_return` on the right function despite the call being a
// separate step in the API.
assert!(
(*flags).needs_post_return(),
"post_return can only be called after a function has previously been called",
);
let post_return_arg = post_return_arg.expect("calling post_return on wrong function");
// This is a sanity-check assert which shouldn't ever trip.
assert!(!(*flags).may_enter());
// Unset the "needs post return" flag now that post-return is being
// processed. This will cause future invocations of this method to
// panic, even if the function call below traps.
(*flags).set_needs_post_return(false);
// If the function actually had a `post-return` configured in its
// canonical options that's executed here.
//
// Note that if this traps (returns an error) this function
// intentionally leaves the instance in a "poisoned" state where it
// can no longer be entered because `may_enter` is `false`.
if let Some((func, trampoline)) = post_return {
crate::Func::call_unchecked_raw(
&mut store,
func.anyfunc,
trampoline,
&post_return_arg as *const ValRaw as *mut ValRaw,
)?;
}
// And finally if everything completed successfully then the "may
// enter" flag is set to `true` again here which enables further use
// of the component.
(*flags).set_may_enter(true);
}
Ok(())
}
fn store_args<T>(
&self,
store: &mut StoreContextMut<'_, T>,
options: &Options,
params: &[Type],
args: &[Val],
dst: &mut MaybeUninit<[ValRaw; MAX_STACK_PARAMS]>,
) -> Result<()> {
let mut size = 0;
let mut alignment = 1;
for ty in params {
alignment = alignment.max(ty.size_and_alignment().alignment);
ty.next_field(&mut size);
}
let mut memory = MemoryMut::new(store.as_context_mut(), options);
let ptr = memory.realloc(0, 0, alignment, size)?;
let mut offset = ptr;
for (ty, arg) in params.iter().zip(args) {
arg.store(&mut memory, ty.next_field(&mut offset))?;
}
map_maybe_uninit!(dst[0]).write(ValRaw::i64(ptr as i64));
Ok(())
}
fn load_result<'a>(
mem: &Memory,
ty: &Type,
src: &mut std::slice::Iter<'_, ValRaw>,
) -> Result<Val> {
let SizeAndAlignment { size, alignment } = ty.size_and_alignment();
// FIXME: needs to read an i64 for memory64
let ptr = usize::try_from(src.next().unwrap().get_u32())?;
if ptr % usize::try_from(alignment)? != 0 {
bail!("return pointer not aligned");
}
let bytes = mem
.as_slice()
.get(ptr..)
.and_then(|b| b.get(..size))
.ok_or_else(|| anyhow::anyhow!("pointer out of bounds of memory"))?;
Val::load(ty, mem, bytes)
}
}

View File

@@ -213,7 +213,7 @@ impl<'a, T> MemoryMut<'a, T> {
/// Like `MemoryMut` but for a read-only version that's used during lifting.
pub struct Memory<'a> {
store: &'a StoreOpaque,
pub(crate) store: &'a StoreOpaque,
options: &'a Options,
}

View File

@@ -158,14 +158,14 @@ where
// count)
if Params::flatten_count() <= MAX_STACK_PARAMS {
if Return::flatten_count() <= MAX_STACK_RESULTS {
self.call_raw(
self.func.call_raw(
store,
&params,
Self::lower_stack_args,
Self::lift_stack_result,
)
} else {
self.call_raw(
self.func.call_raw(
store,
&params,
Self::lower_stack_args,
@@ -174,14 +174,14 @@ where
}
} else {
if Return::flatten_count() <= MAX_STACK_RESULTS {
self.call_raw(
self.func.call_raw(
store,
&params,
Self::lower_heap_args,
Self::lift_stack_result,
)
} else {
self.call_raw(
self.func.call_raw(
store,
&params,
Self::lower_heap_args,
@@ -280,228 +280,10 @@ where
Return::load(&memory, bytes)
}
/// Invokes the underlying wasm function, lowering arguments and lifting the
/// result.
///
/// The `lower` function and `lift` function provided here are what actually
/// do the lowering and lifting. The `LowerParams` and `LowerReturn` types
/// are what will be allocated on the stack for this function call. They
/// should be appropriately sized for the lowering/lifting operation
/// happening.
fn call_raw<T, LowerParams, LowerReturn>(
&self,
store: &mut StoreContextMut<'_, T>,
params: &Params,
lower: impl FnOnce(
&mut StoreContextMut<'_, T>,
&Options,
&Params,
&mut MaybeUninit<LowerParams>,
) -> Result<()>,
lift: impl FnOnce(&StoreOpaque, &Options, &LowerReturn) -> Result<Return>,
) -> Result<Return>
where
LowerParams: Copy,
LowerReturn: Copy,
{
let super::FuncData {
trampoline,
export,
options,
instance,
component_instance,
..
} = store.0[self.func.0];
let space = &mut MaybeUninit::<ParamsAndResults<LowerParams, LowerReturn>>::uninit();
// Double-check the size/alignemnt of `space`, just in case.
//
// Note that this alone is not enough to guarantee the validity of the
// `unsafe` block below, but it's definitely required. In any case LLVM
// should be able to trivially see through these assertions and remove
// them in release mode.
let val_size = mem::size_of::<ValRaw>();
let val_align = mem::align_of::<ValRaw>();
assert!(mem::size_of_val(space) % val_size == 0);
assert!(mem::size_of_val(map_maybe_uninit!(space.params)) % val_size == 0);
assert!(mem::size_of_val(map_maybe_uninit!(space.ret)) % val_size == 0);
assert!(mem::align_of_val(space) == val_align);
assert!(mem::align_of_val(map_maybe_uninit!(space.params)) == val_align);
assert!(mem::align_of_val(map_maybe_uninit!(space.ret)) == val_align);
let instance = store.0[instance.0].as_ref().unwrap().instance();
let flags = instance.flags(component_instance);
unsafe {
// Test the "may enter" flag which is a "lock" on this instance.
// This is immediately set to `false` afterwards and note that
// there's no on-cleanup setting this flag back to true. That's an
// intentional design aspect where if anything goes wrong internally
// from this point on the instance is considered "poisoned" and can
// never be entered again. The only time this flag is set to `true`
// again is after post-return logic has completed successfully.
if !(*flags).may_enter() {
bail!("cannot reenter component instance");
}
(*flags).set_may_enter(false);
debug_assert!((*flags).may_leave());
(*flags).set_may_leave(false);
let result = lower(store, &options, params, map_maybe_uninit!(space.params));
(*flags).set_may_leave(true);
result?;
// This is unsafe as we are providing the guarantee that all the
// inputs are valid. The various pointers passed in for the function
// are all valid since they're coming from our store, and the
// `params_and_results` should have the correct layout for the core
// wasm function we're calling. Note that this latter point relies
// on the correctness of this module and `ComponentType`
// implementations, hence `ComponentType` being an `unsafe` trait.
crate::Func::call_unchecked_raw(
store,
export.anyfunc,
trampoline,
space.as_mut_ptr().cast(),
)?;
// Note that `.assume_init_ref()` here is unsafe but we're relying
// on the correctness of the structure of `LowerReturn` and the
// type-checking performed to acquire the `TypedFunc` to make this
// safe. It should be the case that `LowerReturn` is the exact
// representation of the return value when interpreted as
// `[ValRaw]`, and additionally they should have the correct types
// for the function we just called (which filled in the return
// values).
let ret = map_maybe_uninit!(space.ret).assume_init_ref();
// Lift the result into the host while managing post-return state
// here as well.
//
// After a successful lift the return value of the function, which
// is currently required to be 0 or 1 values according to the
// canonical ABI, is saved within the `Store`'s `FuncData`. This'll
// later get used in post-return.
(*flags).set_needs_post_return(true);
let val = lift(store.0, &options, ret)?;
let ret_slice = cast_storage(ret);
let data = &mut store.0[self.func.0];
assert!(data.post_return_arg.is_none());
match ret_slice.len() {
0 => data.post_return_arg = Some(ValRaw::i32(0)),
1 => data.post_return_arg = Some(ret_slice[0]),
_ => unreachable!(),
}
return Ok(val);
}
unsafe fn cast_storage<T>(storage: &T) -> &[ValRaw] {
assert!(std::mem::size_of_val(storage) % std::mem::size_of::<ValRaw>() == 0);
assert!(std::mem::align_of_val(storage) == std::mem::align_of::<ValRaw>());
std::slice::from_raw_parts(
(storage as *const T).cast(),
mem::size_of_val(storage) / mem::size_of::<ValRaw>(),
)
}
/// See [`Func::post_return`]
pub fn post_return(&self, store: impl AsContextMut) -> Result<()> {
self.func.post_return(store)
}
/// Invokes the `post-return` canonical ABI option, if specified, after a
/// [`TypedFunc::call`] has finished.
///
/// For some more information on when to use this function see the
/// documentation for post-return in the [`TypedFunc::call`] method.
/// Otherwise though this function is a required method call after a
/// [`TypedFunc::call`] completes successfully. After the embedder has
/// finished processing the return value then this function must be invoked.
///
/// # Errors
///
/// This function will return an error in the case of a WebAssembly trap
/// happening during the execution of the `post-return` function, if
/// specified.
///
/// # Panics
///
/// This function will panic if it's not called under the correct
/// conditions. This can only be called after a previous invocation of
/// [`TypedFunc::call`] completes successfully, and this function can only
/// be called for the same [`TypedFunc`] that was `call`'d.
///
/// If this function is called when [`TypedFunc::call`] was not previously
/// called, then it will panic. If a different [`TypedFunc`] for the same
/// component instance was invoked then this function will also panic
/// because the `post-return` needs to happen for the other function.
pub fn post_return(&self, mut store: impl AsContextMut) -> Result<()> {
let mut store = store.as_context_mut();
let data = &mut store.0[self.func.0];
let instance = data.instance;
let post_return = data.post_return;
let component_instance = data.component_instance;
let post_return_arg = data.post_return_arg.take();
let instance = store.0[instance.0].as_ref().unwrap().instance();
let flags = instance.flags(component_instance);
unsafe {
// First assert that the instance is in a "needs post return" state.
// This will ensure that the previous action on the instance was a
// function call above. This flag is only set after a component
// function returns so this also can't be called (as expected)
// during a host import for example.
//
// Note, though, that this assert is not sufficient because it just
// means some function on this instance needs its post-return
// called. We need a precise post-return for a particular function
// which is the second assert here (the `.expect`). That will assert
// that this function itself needs to have its post-return called.
//
// The theory at least is that these two asserts ensure component
// model semantics are upheld where the host properly calls
// `post_return` on the right function despite the call being a
// separate step in the API.
assert!(
(*flags).needs_post_return(),
"post_return can only be called after a function has previously been called",
);
let post_return_arg = post_return_arg.expect("calling post_return on wrong function");
// This is a sanity-check assert which shouldn't ever trip.
assert!(!(*flags).may_enter());
// Unset the "needs post return" flag now that post-return is being
// processed. This will cause future invocations of this method to
// panic, even if the function call below traps.
(*flags).set_needs_post_return(false);
// If the function actually had a `post-return` configured in its
// canonical options that's executed here.
//
// Note that if this traps (returns an error) this function
// intentionally leaves the instance in a "poisoned" state where it
// can no longer be entered because `may_enter` is `false`.
if let Some((func, trampoline)) = post_return {
crate::Func::call_unchecked_raw(
&mut store,
func.anyfunc,
trampoline,
&post_return_arg as *const ValRaw as *mut ValRaw,
)?;
}
// And finally if everything completed successfully then the "may
// enter" flag is set to `true` again here which enables further use
// of the component.
(*flags).set_may_enter(true);
}
Ok(())
}
}
#[repr(C)]
union ParamsAndResults<Params: Copy, Return: Copy> {
params: Params,
ret: Return,
}
/// A trait representing a static list of parameters that can be passed to a
@@ -567,11 +349,9 @@ pub unsafe trait ComponentParams: ComponentType {
// though, that correctness bugs in this trait implementation are highly likely
// to lead to security bugs, which again leads to the `unsafe` in the trait.
//
// Also note that this trait specifically is not sealed because we'll
// eventually have a proc macro that generates implementations of this trait
// for external types in a `#[derive]`-like fashion.
//
// FIXME: need to write a #[derive(ComponentType)]
// Also note that this trait specifically is not sealed because we have a proc
// macro that generates implementations of this trait for external types in a
// `#[derive]`-like fashion.
pub unsafe trait ComponentType {
/// Representation of the "lowered" form of this component value.
///
@@ -690,7 +470,7 @@ pub unsafe trait Lift: Sized + ComponentType {
// another type, used for wrappers in Rust like `&T`, `Box<T>`, etc. Note that
// these wrappers only implement lowering because lifting native Rust types
// cannot be done.
macro_rules! forward_impls {
macro_rules! forward_type_impls {
($(($($generics:tt)*) $a:ty => $b:ty,)*) => ($(
unsafe impl <$($generics)*> ComponentType for $a {
type Lower = <$b as ComponentType>::Lower;
@@ -703,7 +483,20 @@ macro_rules! forward_impls {
<$b as ComponentType>::typecheck(ty, types)
}
}
)*)
}
forward_type_impls! {
(T: ComponentType + ?Sized) &'_ T => T,
(T: ComponentType + ?Sized) Box<T> => T,
(T: ComponentType + ?Sized) std::rc::Rc<T> => T,
(T: ComponentType + ?Sized) std::sync::Arc<T> => T,
() String => str,
(T: ComponentType) Vec<T> => [T],
}
macro_rules! forward_lowers {
($(($($generics:tt)*) $a:ty => $b:ty,)*) => ($(
unsafe impl <$($generics)*> Lower for $a {
fn lower<U>(
&self,
@@ -721,7 +514,7 @@ macro_rules! forward_impls {
)*)
}
forward_impls! {
forward_lowers! {
(T: Lower + ?Sized) &'_ T => T,
(T: Lower + ?Sized) Box<T> => T,
(T: Lower + ?Sized) std::rc::Rc<T> => T,
@@ -730,6 +523,50 @@ forward_impls! {
(T: Lower) Vec<T> => [T],
}
macro_rules! forward_string_lifts {
($($a:ty,)*) => ($(
unsafe impl Lift for $a {
fn lift(store: &StoreOpaque, options: &Options, src: &Self::Lower) -> Result<Self> {
Ok(<WasmStr as Lift>::lift(store, options, src)?.to_str_from_store(store)?.into())
}
fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result<Self> {
Ok(<WasmStr as Lift>::load(memory, bytes)?.to_str_from_store(&memory.store)?.into())
}
}
)*)
}
forward_string_lifts! {
Box<str>,
std::rc::Rc<str>,
std::sync::Arc<str>,
String,
}
macro_rules! forward_list_lifts {
($($a:ty,)*) => ($(
unsafe impl <T: Lift> Lift for $a {
fn lift(store: &StoreOpaque, options: &Options, src: &Self::Lower) -> Result<Self> {
let list = <WasmList::<T> as Lift>::lift(store, options, src)?;
(0..list.len).map(|index| list.get_from_store(store, index).unwrap()).collect()
}
fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result<Self> {
let list = <WasmList::<T> as Lift>::load(memory, bytes)?;
(0..list.len).map(|index| list.get_from_store(&memory.store, index).unwrap()).collect()
}
}
)*)
}
forward_list_lifts! {
Box<[T]>,
std::rc::Rc<[T]>,
std::sync::Arc<[T]>,
Vec<T>,
}
// Macro to help generate `ComponentType` implementations for primitive types
// such as integers, char, bool, etc.
macro_rules! integers {
@@ -1092,10 +929,10 @@ impl WasmStr {
// method that returns `[u16]` after validating to avoid the utf16-to-utf8
// transcode.
pub fn to_str<'a, T: 'a>(&self, store: impl Into<StoreContext<'a, T>>) -> Result<Cow<'a, str>> {
self._to_str(store.into().0)
self.to_str_from_store(store.into().0)
}
fn _to_str<'a>(&self, store: &'a StoreOpaque) -> Result<Cow<'a, str>> {
fn to_str_from_store<'a>(&self, store: &'a StoreOpaque) -> Result<Cow<'a, str>> {
match self.options.string_encoding() {
StringEncoding::Utf8 => self.decode_utf8(store),
StringEncoding::Utf16 => self.decode_utf16(store),
@@ -1289,10 +1126,10 @@ impl<T: Lift> WasmList<T> {
// should we even expose a random access iteration API? In theory all
// consumers should be validating through the iterator.
pub fn get(&self, store: impl AsContext, index: usize) -> Option<Result<T>> {
self._get(store.as_context().0, index)
self.get_from_store(store.as_context().0, index)
}
fn _get(&self, store: &StoreOpaque, index: usize) -> Option<Result<T>> {
fn get_from_store(&self, store: &StoreOpaque, index: usize) -> Option<Result<T>> {
if index >= self.len {
return None;
}
@@ -1316,7 +1153,7 @@ impl<T: Lift> WasmList<T> {
store: impl Into<StoreContext<'a, U>>,
) -> impl ExactSizeIterator<Item = Result<T>> + 'a {
let store = store.into().0;
(0..self.len).map(move |i| self._get(store, i).unwrap())
(0..self.len).map(move |i| self.get_from_store(store, i).unwrap())
}
}

View File

@@ -9,6 +9,8 @@ mod instance;
mod linker;
mod matching;
mod store;
pub mod types;
mod values;
pub use self::component::Component;
pub use self::func::{
ComponentParams, ComponentType, Func, IntoComponentFunc, Lift, Lower, TypedFunc, WasmList,
@@ -16,6 +18,8 @@ pub use self::func::{
};
pub use self::instance::{ExportInstance, Exports, Instance, InstancePre};
pub use self::linker::{Linker, LinkerInstance};
pub use self::types::Type;
pub use self::values::Val;
pub use wasmtime_component_macro::{flags, ComponentType, Lift, Lower};
// These items are expected to be used by an eventual

View File

@@ -0,0 +1,664 @@
//! This module defines the `Type` type, representing the dynamic form of a component interface type.
use crate::component::func;
use crate::component::values::{self, Val};
use anyhow::{anyhow, Result};
use std::fmt;
use std::mem;
use std::ops::Deref;
use std::sync::Arc;
use wasmtime_component_util::{DiscriminantSize, FlagsSize};
use wasmtime_environ::component::{
ComponentTypes, InterfaceType, TypeEnumIndex, TypeExpectedIndex, TypeFlagsIndex,
TypeInterfaceIndex, TypeRecordIndex, TypeTupleIndex, TypeUnionIndex, TypeVariantIndex,
};
#[derive(Clone)]
struct Handle<T> {
index: T,
types: Arc<ComponentTypes>,
}
impl<T: fmt::Debug> fmt::Debug for Handle<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Handle")
.field("index", &self.index)
.finish()
}
}
impl<T: PartialEq> PartialEq for Handle<T> {
fn eq(&self, other: &Self) -> bool {
// FIXME: This is an overly-restrictive definition of equality in that it doesn't consider types to be
// equal unless they refer to the same declaration in the same component. It's a good shortcut for the
// common case, but we should also do a recursive structural equality test if the shortcut test fails.
self.index == other.index && Arc::ptr_eq(&self.types, &other.types)
}
}
impl<T: Eq> Eq for Handle<T> {}
/// A `list` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct List(Handle<TypeInterfaceIndex>);
impl List {
/// Instantiate this type with the specified `values`.
pub fn new_val(&self, values: Box<[Val]>) -> Result<Val> {
Ok(Val::List(values::List::new(self, values)?))
}
/// Retreive the element type of this `list`.
pub fn ty(&self) -> Type {
Type::from(&self.0.types[self.0.index], &self.0.types)
}
}
/// A field declaration belonging to a `record`
pub struct Field<'a> {
/// The name of the field
pub name: &'a str,
/// The type of the field
pub ty: Type,
}
/// A `record` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Record(Handle<TypeRecordIndex>);
impl Record {
/// Instantiate this type with the specified `values`.
pub fn new_val<'a>(&self, values: impl IntoIterator<Item = (&'a str, Val)>) -> Result<Val> {
Ok(Val::Record(values::Record::new(self, values)?))
}
/// Retrieve the fields of this `record` in declaration order.
pub fn fields(&self) -> impl ExactSizeIterator<Item = Field> {
self.0.types[self.0.index].fields.iter().map(|field| Field {
name: &field.name,
ty: Type::from(&field.ty, &self.0.types),
})
}
}
/// A `tuple` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Tuple(Handle<TypeTupleIndex>);
impl Tuple {
/// Instantiate this type ith the specified `values`.
pub fn new_val(&self, values: Box<[Val]>) -> Result<Val> {
Ok(Val::Tuple(values::Tuple::new(self, values)?))
}
/// Retrieve the types of the fields of this `tuple` in declaration order.
pub fn types(&self) -> impl ExactSizeIterator<Item = Type> + '_ {
self.0.types[self.0.index]
.types
.iter()
.map(|ty| Type::from(ty, &self.0.types))
}
}
/// A case declaration belonging to a `variant`
pub struct Case<'a> {
/// The name of the case
pub name: &'a str,
/// The type of the case
pub ty: Type,
}
/// A `variant` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Variant(Handle<TypeVariantIndex>);
impl Variant {
/// Instantiate this type with the specified case `name` and `value`.
pub fn new_val(&self, name: &str, value: Val) -> Result<Val> {
Ok(Val::Variant(values::Variant::new(self, name, value)?))
}
/// Retrieve the cases of this `variant` in declaration order.
pub fn cases(&self) -> impl ExactSizeIterator<Item = Case> {
self.0.types[self.0.index].cases.iter().map(|case| Case {
name: &case.name,
ty: Type::from(&case.ty, &self.0.types),
})
}
}
/// An `enum` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Enum(Handle<TypeEnumIndex>);
impl Enum {
/// Instantiate this type with the specified case `name`.
pub fn new_val(&self, name: &str) -> Result<Val> {
Ok(Val::Enum(values::Enum::new(self, name)?))
}
/// Retrieve the names of the cases of this `enum` in declaration order.
pub fn names(&self) -> impl ExactSizeIterator<Item = &str> {
self.0.types[self.0.index]
.names
.iter()
.map(|name| name.deref())
}
}
/// A `union` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Union(Handle<TypeUnionIndex>);
impl Union {
/// Instantiate this type with the specified `discriminant` and `value`.
pub fn new_val(&self, discriminant: u32, value: Val) -> Result<Val> {
Ok(Val::Union(values::Union::new(self, discriminant, value)?))
}
/// Retrieve the types of the cases of this `union` in declaration order.
pub fn types(&self) -> impl ExactSizeIterator<Item = Type> + '_ {
self.0.types[self.0.index]
.types
.iter()
.map(|ty| Type::from(ty, &self.0.types))
}
}
/// An `option` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Option(Handle<TypeInterfaceIndex>);
impl Option {
/// Instantiate this type with the specified `value`.
pub fn new_val(&self, value: std::option::Option<Val>) -> Result<Val> {
Ok(Val::Option(values::Option::new(self, value)?))
}
/// Retrieve the type parameter for this `option`.
pub fn ty(&self) -> Type {
Type::from(&self.0.types[self.0.index], &self.0.types)
}
}
/// An `expected` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Expected(Handle<TypeExpectedIndex>);
impl Expected {
/// Instantiate this type with the specified `value`.
pub fn new_val(&self, value: Result<Val, Val>) -> Result<Val> {
Ok(Val::Expected(values::Expected::new(self, value)?))
}
/// Retrieve the `ok` type parameter for this `option`.
pub fn ok(&self) -> Type {
Type::from(&self.0.types[self.0.index].ok, &self.0.types)
}
/// Retrieve the `err` type parameter for this `option`.
pub fn err(&self) -> Type {
Type::from(&self.0.types[self.0.index].err, &self.0.types)
}
}
/// A `flags` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Flags(Handle<TypeFlagsIndex>);
impl Flags {
/// Instantiate this type with the specified flag `names`.
pub fn new_val(&self, names: &[&str]) -> Result<Val> {
Ok(Val::Flags(values::Flags::new(self, names)?))
}
/// Retrieve the names of the flags of this `flags` type in declaration order.
pub fn names(&self) -> impl ExactSizeIterator<Item = &str> {
self.0.types[self.0.index]
.names
.iter()
.map(|name| name.deref())
}
}
/// Represents the size and alignment requirements of the heap-serialized form of a type
pub(crate) struct SizeAndAlignment {
pub(crate) size: usize,
pub(crate) alignment: u32,
}
/// Represents a component model interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum Type {
/// Unit
Unit,
/// Boolean
Bool,
/// Signed 8-bit integer
S8,
/// Unsigned 8-bit integer
U8,
/// Signed 16-bit integer
S16,
/// Unsigned 16-bit integer
U16,
/// Signed 32-bit integer
S32,
/// Unsigned 32-bit integer
U32,
/// Signed 64-bit integer
S64,
/// Unsigned 64-bit integer
U64,
/// 64-bit floating point value
Float32,
/// 64-bit floating point value
Float64,
/// 32-bit character
Char,
/// Character string
String,
/// List of values
List(List),
/// Record
Record(Record),
/// Tuple
Tuple(Tuple),
/// Variant
Variant(Variant),
/// Enum
Enum(Enum),
/// Union
Union(Union),
/// Option
Option(Option),
/// Expected
Expected(Expected),
/// Bit flags
Flags(Flags),
}
impl Type {
/// Retrieve the inner [`List`] of a [`Type::List`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::List`].
pub fn unwrap_list(&self) -> &List {
if let Type::List(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a list", self.desc())
}
}
/// Retrieve the inner [`Record`] of a [`Type::Record`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Record`].
pub fn unwrap_record(&self) -> &Record {
if let Type::Record(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a record", self.desc())
}
}
/// Retrieve the inner [`Tuple`] of a [`Type::Tuple`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Tuple`].
pub fn unwrap_tuple(&self) -> &Tuple {
if let Type::Tuple(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a tuple", self.desc())
}
}
/// Retrieve the inner [`Variant`] of a [`Type::Variant`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Variant`].
pub fn unwrap_variant(&self) -> &Variant {
if let Type::Variant(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a variant", self.desc())
}
}
/// Retrieve the inner [`Enum`] of a [`Type::Enum`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Enum`].
pub fn unwrap_enum(&self) -> &Enum {
if let Type::Enum(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a enum", self.desc())
}
}
/// Retrieve the inner [`Union`] of a [`Type::Union`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Union`].
pub fn unwrap_union(&self) -> &Union {
if let Type::Union(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a union", self.desc())
}
}
/// Retrieve the inner [`Option`] of a [`Type::Option`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Option`].
pub fn unwrap_option(&self) -> &Option {
if let Type::Option(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a option", self.desc())
}
}
/// Retrieve the inner [`Expected`] of a [`Type::Expected`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Expected`].
pub fn unwrap_expected(&self) -> &Expected {
if let Type::Expected(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a expected", self.desc())
}
}
/// Retrieve the inner [`Flags`] of a [`Type::Flags`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Flags`].
pub fn unwrap_flags(&self) -> &Flags {
if let Type::Flags(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a flags", self.desc())
}
}
pub(crate) fn check(&self, value: &Val) -> Result<()> {
let other = &value.ty();
if self == other {
Ok(())
} else if mem::discriminant(self) != mem::discriminant(other) {
Err(anyhow!(
"type mismatch: expected {}, got {}",
self.desc(),
other.desc()
))
} else {
Err(anyhow!(
"type mismatch for {}, possibly due to mixing distinct composite types",
self.desc()
))
}
}
/// Convert the specified `InterfaceType` to a `Type`.
pub(crate) fn from(ty: &InterfaceType, types: &Arc<ComponentTypes>) -> Self {
match ty {
InterfaceType::Unit => Type::Unit,
InterfaceType::Bool => Type::Bool,
InterfaceType::S8 => Type::S8,
InterfaceType::U8 => Type::U8,
InterfaceType::S16 => Type::S16,
InterfaceType::U16 => Type::U16,
InterfaceType::S32 => Type::S32,
InterfaceType::U32 => Type::U32,
InterfaceType::S64 => Type::S64,
InterfaceType::U64 => Type::U64,
InterfaceType::Float32 => Type::Float32,
InterfaceType::Float64 => Type::Float64,
InterfaceType::Char => Type::Char,
InterfaceType::String => Type::String,
InterfaceType::List(index) => Type::List(List(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Record(index) => Type::Record(Record(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Tuple(index) => Type::Tuple(Tuple(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Variant(index) => Type::Variant(Variant(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Enum(index) => Type::Enum(Enum(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Union(index) => Type::Union(Union(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Option(index) => Type::Option(Option(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Expected(index) => Type::Expected(Expected(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Flags(index) => Type::Flags(Flags(Handle {
index: *index,
types: types.clone(),
})),
}
}
/// Return the number of stack slots needed to store values of this type in lowered form.
pub(crate) fn flatten_count(&self) -> usize {
match self {
Type::Unit => 0,
Type::Bool
| Type::S8
| Type::U8
| Type::S16
| Type::U16
| Type::S32
| Type::U32
| Type::S64
| Type::U64
| Type::Float32
| Type::Float64
| Type::Char
| Type::Enum(_) => 1,
Type::String | Type::List(_) => 2,
Type::Record(handle) => handle.fields().map(|field| field.ty.flatten_count()).sum(),
Type::Tuple(handle) => handle.types().map(|ty| ty.flatten_count()).sum(),
Type::Variant(handle) => {
1 + handle
.cases()
.map(|case| case.ty.flatten_count())
.max()
.unwrap_or(0)
}
Type::Union(handle) => {
1 + handle
.types()
.map(|ty| ty.flatten_count())
.max()
.unwrap_or(0)
}
Type::Option(handle) => 1 + handle.ty().flatten_count(),
Type::Expected(handle) => {
1 + handle
.ok()
.flatten_count()
.max(handle.err().flatten_count())
}
Type::Flags(handle) => values::u32_count_for_flag_count(handle.names().len()),
}
}
fn desc(&self) -> &'static str {
match self {
Type::Unit => "unit",
Type::Bool => "bool",
Type::S8 => "s8",
Type::U8 => "u8",
Type::S16 => "s16",
Type::U16 => "u16",
Type::S32 => "s32",
Type::U32 => "u32",
Type::S64 => "s64",
Type::U64 => "u64",
Type::Float32 => "float32",
Type::Float64 => "float64",
Type::Char => "char",
Type::String => "string",
Type::List(_) => "list",
Type::Record(_) => "record",
Type::Tuple(_) => "tuple",
Type::Variant(_) => "variant",
Type::Enum(_) => "enum",
Type::Union(_) => "union",
Type::Option(_) => "option",
Type::Expected(_) => "expected",
Type::Flags(_) => "flags",
}
}
/// Calculate the size and alignment requirements for the specified type.
pub(crate) fn size_and_alignment(&self) -> SizeAndAlignment {
match self {
Type::Unit => SizeAndAlignment {
size: 0,
alignment: 1,
},
Type::Bool | Type::S8 | Type::U8 => SizeAndAlignment {
size: 1,
alignment: 1,
},
Type::S16 | Type::U16 => SizeAndAlignment {
size: 2,
alignment: 2,
},
Type::S32 | Type::U32 | Type::Char | Type::Float32 => SizeAndAlignment {
size: 4,
alignment: 4,
},
Type::S64 | Type::U64 | Type::Float64 => SizeAndAlignment {
size: 8,
alignment: 8,
},
Type::String | Type::List(_) => SizeAndAlignment {
size: 8,
alignment: 4,
},
Type::Record(handle) => {
record_size_and_alignment(handle.fields().map(|field| field.ty))
}
Type::Tuple(handle) => record_size_and_alignment(handle.types()),
Type::Variant(handle) => variant_size_and_alignment(handle.cases().map(|case| case.ty)),
Type::Enum(handle) => variant_size_and_alignment(handle.names().map(|_| Type::Unit)),
Type::Union(handle) => variant_size_and_alignment(handle.types()),
Type::Option(handle) => {
variant_size_and_alignment([Type::Unit, handle.ty()].into_iter())
}
Type::Expected(handle) => {
variant_size_and_alignment([handle.ok(), handle.err()].into_iter())
}
Type::Flags(handle) => match FlagsSize::from_count(handle.names().len()) {
FlagsSize::Size1 => SizeAndAlignment {
size: 1,
alignment: 1,
},
FlagsSize::Size2 => SizeAndAlignment {
size: 2,
alignment: 2,
},
FlagsSize::Size4Plus(n) => SizeAndAlignment {
size: n * 4,
alignment: 4,
},
},
}
}
/// Calculate the aligned offset of a field of this type, updating `offset` to point to just after that field.
pub(crate) fn next_field(&self, offset: &mut usize) -> usize {
let SizeAndAlignment { size, alignment } = self.size_and_alignment();
*offset = func::align_to(*offset, alignment);
let result = *offset;
*offset += size;
result
}
}
fn record_size_and_alignment(types: impl Iterator<Item = Type>) -> SizeAndAlignment {
let mut offset = 0;
let mut align = 1;
for ty in types {
let SizeAndAlignment { size, alignment } = ty.size_and_alignment();
offset = func::align_to(offset, alignment) + size;
align = align.max(alignment);
}
SizeAndAlignment {
size: func::align_to(offset, align),
alignment: align,
}
}
fn variant_size_and_alignment(types: impl ExactSizeIterator<Item = Type>) -> SizeAndAlignment {
let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap();
let mut alignment = u32::from(discriminant_size);
let mut size = 0;
for ty in types {
let size_and_alignment = ty.size_and_alignment();
alignment = alignment.max(size_and_alignment.alignment);
size = size.max(size_and_alignment.size);
}
SizeAndAlignment {
size: func::align_to(usize::from(discriminant_size), alignment) + size,
alignment,
}
}

View File

@@ -0,0 +1,908 @@
use crate::component::func::{self, Lift, Lower, Memory, MemoryMut, Options};
use crate::component::types::{self, SizeAndAlignment, Type};
use crate::store::StoreOpaque;
use crate::{AsContextMut, StoreContextMut, ValRaw};
use anyhow::{anyhow, bail, Context, Error, Result};
use std::collections::HashMap;
use std::iter;
use std::mem::MaybeUninit;
use std::ops::Deref;
use wasmtime_component_util::{DiscriminantSize, FlagsSize};
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct List {
ty: types::List,
values: Box<[Val]>,
}
impl List {
/// Instantiate the specified type with the specified `values`.
pub fn new(ty: &types::List, values: Box<[Val]>) -> Result<Self> {
let element_type = ty.ty();
for (index, value) in values.iter().enumerate() {
element_type
.check(value)
.with_context(|| format!("type mismatch for element {index} of list"))?;
}
Ok(Self {
ty: ty.clone(),
values,
})
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Record {
ty: types::Record,
values: Box<[Val]>,
}
impl Record {
/// Instantiate the specified type with the specified `values`.
pub fn new<'a>(
ty: &types::Record,
values: impl IntoIterator<Item = (&'a str, Val)>,
) -> Result<Self> {
let mut fields = ty.fields();
let expected_len = fields.len();
let mut iter = values.into_iter();
let mut values = Vec::with_capacity(expected_len);
loop {
match (fields.next(), iter.next()) {
(Some(field), Some((name, value))) => {
if name == field.name {
field
.ty
.check(&value)
.with_context(|| format!("type mismatch for field {name} of record"))?;
values.push(value);
} else {
bail!("field name mismatch: expected {}; got {name}", field.name)
}
}
(None, Some((_, value))) => values.push(value),
_ => break,
}
}
if values.len() != expected_len {
bail!("expected {} value(s); got {}", expected_len, values.len());
}
Ok(Self {
ty: ty.clone(),
values: values.into(),
})
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Tuple {
ty: types::Tuple,
values: Box<[Val]>,
}
impl Tuple {
/// Instantiate the specified type ith the specified `values`.
pub fn new(ty: &types::Tuple, values: Box<[Val]>) -> Result<Self> {
if values.len() != ty.types().len() {
bail!(
"expected {} value(s); got {}",
ty.types().len(),
values.len()
);
}
for (index, (value, ty)) in values.iter().zip(ty.types()).enumerate() {
ty.check(value)
.with_context(|| format!("type mismatch for field {index} of tuple"))?;
}
Ok(Self {
ty: ty.clone(),
values,
})
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Variant {
ty: types::Variant,
discriminant: u32,
value: Box<Val>,
}
impl Variant {
/// Instantiate the specified type with the specified case `name` and `value`.
pub fn new(ty: &types::Variant, name: &str, value: Val) -> Result<Self> {
let (discriminant, case_type) = ty
.cases()
.enumerate()
.find_map(|(index, case)| {
if case.name == name {
Some((index, case.ty))
} else {
None
}
})
.ok_or_else(|| anyhow!("unknown variant case: {name}"))?;
case_type
.check(&value)
.with_context(|| format!("type mismatch for case {name} of variant"))?;
Ok(Self {
ty: ty.clone(),
discriminant: u32::try_from(discriminant)?,
value: Box::new(value),
})
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Enum {
ty: types::Enum,
discriminant: u32,
}
impl Enum {
/// Instantiate the specified type with the specified case `name`.
pub fn new(ty: &types::Enum, name: &str) -> Result<Self> {
let discriminant = u32::try_from(
ty.names()
.position(|n| n == name)
.ok_or_else(|| anyhow!("unknown enum case: {name}"))?,
)?;
Ok(Self {
ty: ty.clone(),
discriminant,
})
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Union {
ty: types::Union,
discriminant: u32,
value: Box<Val>,
}
impl Union {
/// Instantiate the specified type with the specified `discriminant` and `value`.
pub fn new(ty: &types::Union, discriminant: u32, value: Val) -> Result<Self> {
if let Some(case_ty) = ty.types().nth(usize::try_from(discriminant)?) {
case_ty
.check(&value)
.with_context(|| format!("type mismatch for case {discriminant} of union"))?;
Ok(Self {
ty: ty.clone(),
discriminant,
value: Box::new(value),
})
} else {
Err(anyhow!(
"discriminant {discriminant} out of range: [0,{})",
ty.types().len()
))
}
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Option {
ty: types::Option,
discriminant: u32,
value: Box<Val>,
}
impl Option {
/// Instantiate the specified type with the specified `value`.
pub fn new(ty: &types::Option, value: std::option::Option<Val>) -> Result<Self> {
let value = value
.map(|value| {
ty.ty().check(&value).context("type mismatch for option")?;
Ok::<_, Error>(value)
})
.transpose()?;
Ok(Self {
ty: ty.clone(),
discriminant: if value.is_none() { 0 } else { 1 },
value: Box::new(value.unwrap_or(Val::Unit)),
})
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Expected {
ty: types::Expected,
discriminant: u32,
value: Box<Val>,
}
impl Expected {
/// Instantiate the specified type with the specified `value`.
pub fn new(ty: &types::Expected, value: Result<Val, Val>) -> Result<Self> {
Ok(Self {
ty: ty.clone(),
discriminant: if value.is_ok() { 0 } else { 1 },
value: Box::new(match value {
Ok(value) => {
ty.ok()
.check(&value)
.context("type mismatch for ok case of expected")?;
value
}
Err(value) => {
ty.err()
.check(&value)
.context("type mismatch for err case of expected")?;
value
}
}),
})
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Flags {
ty: types::Flags,
count: u32,
value: Box<[u32]>,
}
impl Flags {
/// Instantiate the specified type with the specified flag `names`.
pub fn new(ty: &types::Flags, names: &[&str]) -> Result<Self> {
let map = ty
.names()
.enumerate()
.map(|(index, name)| (name, index))
.collect::<HashMap<_, _>>();
let mut values = vec![0_u32; u32_count_for_flag_count(ty.names().len())];
for name in names {
let index = map
.get(name)
.ok_or_else(|| anyhow!("unknown flag: {name}"))?;
values[index / 32] |= 1 << (index % 32);
}
Ok(Self {
ty: ty.clone(),
count: u32::try_from(map.len())?,
value: values.into(),
})
}
}
/// Represents possible runtime values which a component function can either consume or produce
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Val {
/// Unit
Unit,
/// Boolean
Bool(bool),
/// Signed 8-bit integer
S8(i8),
/// Unsigned 8-bit integer
U8(u8),
/// Signed 16-bit integer
S16(i16),
/// Unsigned 16-bit integer
U16(u16),
/// Signed 32-bit integer
S32(i32),
/// Unsigned 32-bit integer
U32(u32),
/// Signed 64-bit integer
S64(i64),
/// Unsigned 64-bit integer
U64(u64),
/// 32-bit floating point value
Float32(u32),
/// 64-bit floating point value
Float64(u64),
/// 32-bit character
Char(char),
/// Character string
String(Box<str>),
/// List of values
List(List),
/// Record
Record(Record),
/// Tuple
Tuple(Tuple),
/// Variant
Variant(Variant),
/// Enum
Enum(Enum),
/// Union
Union(Union),
/// Option
Option(Option),
/// Expected
Expected(Expected),
/// Bit flags
Flags(Flags),
}
impl Val {
/// Retrieve the [`Type`] of this value.
pub fn ty(&self) -> Type {
match self {
Val::Unit => Type::Unit,
Val::Bool(_) => Type::Bool,
Val::S8(_) => Type::S8,
Val::U8(_) => Type::U8,
Val::S16(_) => Type::S16,
Val::U16(_) => Type::U16,
Val::S32(_) => Type::S32,
Val::U32(_) => Type::U32,
Val::S64(_) => Type::S64,
Val::U64(_) => Type::U64,
Val::Float32(_) => Type::Float32,
Val::Float64(_) => Type::Float64,
Val::Char(_) => Type::Char,
Val::String(_) => Type::String,
Val::List(List { ty, .. }) => Type::List(ty.clone()),
Val::Record(Record { ty, .. }) => Type::Record(ty.clone()),
Val::Tuple(Tuple { ty, .. }) => Type::Tuple(ty.clone()),
Val::Variant(Variant { ty, .. }) => Type::Variant(ty.clone()),
Val::Enum(Enum { ty, .. }) => Type::Enum(ty.clone()),
Val::Union(Union { ty, .. }) => Type::Union(ty.clone()),
Val::Option(Option { ty, .. }) => Type::Option(ty.clone()),
Val::Expected(Expected { ty, .. }) => Type::Expected(ty.clone()),
Val::Flags(Flags { ty, .. }) => Type::Flags(ty.clone()),
}
}
/// Deserialize a value of this type from core Wasm stack values.
pub(crate) fn lift<'a>(
ty: &Type,
store: &StoreOpaque,
options: &Options,
src: &mut std::slice::Iter<'_, ValRaw>,
) -> Result<Val> {
Ok(match ty {
Type::Unit => Val::Unit,
Type::Bool => Val::Bool(bool::lift(store, options, next(src))?),
Type::S8 => Val::S8(i8::lift(store, options, next(src))?),
Type::U8 => Val::U8(u8::lift(store, options, next(src))?),
Type::S16 => Val::S16(i16::lift(store, options, next(src))?),
Type::U16 => Val::U16(u16::lift(store, options, next(src))?),
Type::S32 => Val::S32(i32::lift(store, options, next(src))?),
Type::U32 => Val::U32(u32::lift(store, options, next(src))?),
Type::S64 => Val::S64(i64::lift(store, options, next(src))?),
Type::U64 => Val::U64(u64::lift(store, options, next(src))?),
Type::Float32 => Val::Float32(u32::lift(store, options, next(src))?),
Type::Float64 => Val::Float64(u64::lift(store, options, next(src))?),
Type::Char => Val::Char(char::lift(store, options, next(src))?),
Type::String => {
Val::String(Box::<str>::lift(store, options, &[*next(src), *next(src)])?)
}
Type::List(handle) => {
// FIXME: needs memory64 treatment
let ptr = u32::lift(store, options, next(src))? as usize;
let len = u32::lift(store, options, next(src))? as usize;
load_list(handle, &Memory::new(store, options), ptr, len)?
}
Type::Record(handle) => Val::Record(Record {
ty: handle.clone(),
values: handle
.fields()
.map(|field| Self::lift(&field.ty, store, options, src))
.collect::<Result<_>>()?,
}),
Type::Tuple(handle) => Val::Tuple(Tuple {
ty: handle.clone(),
values: handle
.types()
.map(|ty| Self::lift(&ty, store, options, src))
.collect::<Result<_>>()?,
}),
Type::Variant(handle) => {
let (discriminant, value) = lift_variant(
ty.flatten_count(),
handle.cases().map(|case| case.ty),
store,
options,
src,
)?;
Val::Variant(Variant {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Enum(handle) => {
let (discriminant, _) = lift_variant(
ty.flatten_count(),
handle.names().map(|_| Type::Unit),
store,
options,
src,
)?;
Val::Enum(Enum {
ty: handle.clone(),
discriminant,
})
}
Type::Union(handle) => {
let (discriminant, value) =
lift_variant(ty.flatten_count(), handle.types(), store, options, src)?;
Val::Union(Union {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Option(handle) => {
let (discriminant, value) = lift_variant(
ty.flatten_count(),
[Type::Unit, handle.ty()].into_iter(),
store,
options,
src,
)?;
Val::Option(Option {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Expected(handle) => {
let (discriminant, value) = lift_variant(
ty.flatten_count(),
[handle.ok(), handle.err()].into_iter(),
store,
options,
src,
)?;
Val::Expected(Expected {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Flags(handle) => {
let count = u32::try_from(handle.names().len()).unwrap();
assert!(count <= 32);
let value = iter::once(u32::lift(store, options, next(src))?).collect();
Val::Flags(Flags {
ty: handle.clone(),
count,
value,
})
}
})
}
/// Deserialize a value of this type from the heap.
pub(crate) fn load(ty: &Type, mem: &Memory, bytes: &[u8]) -> Result<Val> {
Ok(match ty {
Type::Unit => Val::Unit,
Type::Bool => Val::Bool(bool::load(mem, bytes)?),
Type::S8 => Val::S8(i8::load(mem, bytes)?),
Type::U8 => Val::U8(u8::load(mem, bytes)?),
Type::S16 => Val::S16(i16::load(mem, bytes)?),
Type::U16 => Val::U16(u16::load(mem, bytes)?),
Type::S32 => Val::S32(i32::load(mem, bytes)?),
Type::U32 => Val::U32(u32::load(mem, bytes)?),
Type::S64 => Val::S64(i64::load(mem, bytes)?),
Type::U64 => Val::U64(u64::load(mem, bytes)?),
Type::Float32 => Val::Float32(u32::load(mem, bytes)?),
Type::Float64 => Val::Float64(u64::load(mem, bytes)?),
Type::Char => Val::Char(char::load(mem, bytes)?),
Type::String => Val::String(Box::<str>::load(mem, bytes)?),
Type::List(handle) => {
// FIXME: needs memory64 treatment
let ptr = u32::from_le_bytes(bytes[..4].try_into().unwrap()) as usize;
let len = u32::from_le_bytes(bytes[4..].try_into().unwrap()) as usize;
load_list(handle, mem, ptr, len)?
}
Type::Record(handle) => Val::Record(Record {
ty: handle.clone(),
values: load_record(handle.fields().map(|field| field.ty), mem, bytes)?,
}),
Type::Tuple(handle) => Val::Tuple(Tuple {
ty: handle.clone(),
values: load_record(handle.types(), mem, bytes)?,
}),
Type::Variant(handle) => {
let (discriminant, value) =
load_variant(ty, handle.cases().map(|case| case.ty), mem, bytes)?;
Val::Variant(Variant {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Enum(handle) => {
let (discriminant, _) =
load_variant(ty, handle.names().map(|_| Type::Unit), mem, bytes)?;
Val::Enum(Enum {
ty: handle.clone(),
discriminant,
})
}
Type::Union(handle) => {
let (discriminant, value) = load_variant(ty, handle.types(), mem, bytes)?;
Val::Union(Union {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Option(handle) => {
let (discriminant, value) =
load_variant(ty, [Type::Unit, handle.ty()].into_iter(), mem, bytes)?;
Val::Option(Option {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Expected(handle) => {
let (discriminant, value) =
load_variant(ty, [handle.ok(), handle.err()].into_iter(), mem, bytes)?;
Val::Expected(Expected {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Flags(handle) => Val::Flags(Flags {
ty: handle.clone(),
count: u32::try_from(handle.names().len())?,
value: match FlagsSize::from_count(handle.names().len()) {
FlagsSize::Size1 => iter::once(u8::load(mem, bytes)? as u32).collect(),
FlagsSize::Size2 => iter::once(u16::load(mem, bytes)? as u32).collect(),
FlagsSize::Size4Plus(n) => (0..n)
.map(|index| u32::load(mem, &bytes[index * 4..][..4]))
.collect::<Result<_>>()?,
},
}),
})
}
/// Serialize this value as core Wasm stack values.
pub(crate) fn lower<T>(
&self,
store: &mut StoreContextMut<T>,
options: &Options,
dst: &mut std::slice::IterMut<'_, MaybeUninit<ValRaw>>,
) -> Result<()> {
match self {
Val::Unit => (),
Val::Bool(value) => value.lower(store, options, next_mut(dst))?,
Val::S8(value) => value.lower(store, options, next_mut(dst))?,
Val::U8(value) => value.lower(store, options, next_mut(dst))?,
Val::S16(value) => value.lower(store, options, next_mut(dst))?,
Val::U16(value) => value.lower(store, options, next_mut(dst))?,
Val::S32(value) => value.lower(store, options, next_mut(dst))?,
Val::U32(value) => value.lower(store, options, next_mut(dst))?,
Val::S64(value) => value.lower(store, options, next_mut(dst))?,
Val::U64(value) => value.lower(store, options, next_mut(dst))?,
Val::Float32(value) => value.lower(store, options, next_mut(dst))?,
Val::Float64(value) => value.lower(store, options, next_mut(dst))?,
Val::Char(value) => value.lower(store, options, next_mut(dst))?,
Val::String(value) => {
let my_dst = &mut MaybeUninit::<[ValRaw; 2]>::uninit();
value.lower(store, options, my_dst)?;
let my_dst = unsafe { my_dst.assume_init() };
next_mut(dst).write(my_dst[0]);
next_mut(dst).write(my_dst[1]);
}
Val::List(List { values, ty }) => {
let (ptr, len) = lower_list(
&ty.ty(),
&mut MemoryMut::new(store.as_context_mut(), options),
values,
)?;
next_mut(dst).write(ValRaw::i64(ptr as i64));
next_mut(dst).write(ValRaw::i64(len as i64));
}
Val::Record(Record { values, .. }) | Val::Tuple(Tuple { values, .. }) => {
for value in values.deref() {
value.lower(store, options, dst)?;
}
}
Val::Variant(Variant {
discriminant,
value,
..
})
| Val::Union(Union {
discriminant,
value,
..
})
| Val::Option(Option {
discriminant,
value,
..
})
| Val::Expected(Expected {
discriminant,
value,
..
}) => {
next_mut(dst).write(ValRaw::u32(*discriminant));
value.lower(store, options, dst)?;
for _ in (1 + value.ty().flatten_count())..self.ty().flatten_count() {
next_mut(dst).write(ValRaw::u32(0));
}
}
Val::Enum(Enum { discriminant, .. }) => {
next_mut(dst).write(ValRaw::u32(*discriminant));
}
Val::Flags(Flags { value, .. }) => {
for value in value.deref() {
next_mut(dst).write(ValRaw::u32(*value));
}
}
}
Ok(())
}
/// Serialize this value to the heap at the specified memory location.
pub(crate) fn store<T>(&self, mem: &mut MemoryMut<'_, T>, offset: usize) -> Result<()> {
debug_assert!(offset % usize::try_from(self.ty().size_and_alignment().alignment)? == 0);
match self {
Val::Unit => (),
Val::Bool(value) => value.store(mem, offset)?,
Val::S8(value) => value.store(mem, offset)?,
Val::U8(value) => value.store(mem, offset)?,
Val::S16(value) => value.store(mem, offset)?,
Val::U16(value) => value.store(mem, offset)?,
Val::S32(value) => value.store(mem, offset)?,
Val::U32(value) => value.store(mem, offset)?,
Val::S64(value) => value.store(mem, offset)?,
Val::U64(value) => value.store(mem, offset)?,
Val::Float32(value) => value.store(mem, offset)?,
Val::Float64(value) => value.store(mem, offset)?,
Val::Char(value) => value.store(mem, offset)?,
Val::String(value) => value.store(mem, offset)?,
Val::List(List { values, ty }) => {
let (ptr, len) = lower_list(&ty.ty(), mem, values)?;
// FIXME: needs memory64 handling
*mem.get(offset + 0) = (ptr as i32).to_le_bytes();
*mem.get(offset + 4) = (len as i32).to_le_bytes();
}
Val::Record(Record { values, .. }) | Val::Tuple(Tuple { values, .. }) => {
let mut offset = offset;
for value in values.deref() {
value.store(mem, value.ty().next_field(&mut offset))?;
}
}
Val::Variant(Variant {
discriminant,
value,
ty,
}) => self.store_variant(*discriminant, value, ty.cases().len(), mem, offset)?,
Val::Enum(Enum { discriminant, ty }) => {
self.store_variant(*discriminant, &Val::Unit, ty.names().len(), mem, offset)?
}
Val::Union(Union {
discriminant,
value,
ty,
}) => self.store_variant(*discriminant, value, ty.types().len(), mem, offset)?,
Val::Option(Option {
discriminant,
value,
..
})
| Val::Expected(Expected {
discriminant,
value,
..
}) => self.store_variant(*discriminant, value, 2, mem, offset)?,
Val::Flags(Flags { count, value, .. }) => {
match FlagsSize::from_count(*count as usize) {
FlagsSize::Size1 => u8::try_from(value[0]).unwrap().store(mem, offset)?,
FlagsSize::Size2 => u16::try_from(value[0]).unwrap().store(mem, offset)?,
FlagsSize::Size4Plus(_) => {
let mut offset = offset;
for value in value.deref() {
value.store(mem, offset)?;
offset += 4;
}
}
}
}
}
Ok(())
}
fn store_variant<T>(
&self,
discriminant: u32,
value: &Val,
case_count: usize,
mem: &mut MemoryMut<'_, T>,
offset: usize,
) -> Result<()> {
let discriminant_size = DiscriminantSize::from_count(case_count).unwrap();
match discriminant_size {
DiscriminantSize::Size1 => u8::try_from(discriminant).unwrap().store(mem, offset)?,
DiscriminantSize::Size2 => u16::try_from(discriminant).unwrap().store(mem, offset)?,
DiscriminantSize::Size4 => (discriminant).store(mem, offset)?,
}
value.store(
mem,
offset
+ func::align_to(
discriminant_size.into(),
self.ty().size_and_alignment().alignment,
),
)
}
}
fn load_list(handle: &types::List, mem: &Memory, ptr: usize, len: usize) -> Result<Val> {
let element_type = handle.ty();
let SizeAndAlignment {
size: element_size,
alignment: element_alignment,
} = element_type.size_and_alignment();
match len
.checked_mul(element_size)
.and_then(|len| ptr.checked_add(len))
{
Some(n) if n <= mem.as_slice().len() => {}
_ => bail!("list pointer/length out of bounds of memory"),
}
if ptr % usize::try_from(element_alignment)? != 0 {
bail!("list pointer is not aligned")
}
Ok(Val::List(List {
ty: handle.clone(),
values: (0..len)
.map(|index| {
Val::load(
&element_type,
mem,
&mem.as_slice()[ptr + (index * element_size)..][..element_size],
)
})
.collect::<Result<_>>()?,
}))
}
fn load_record(
types: impl Iterator<Item = Type>,
mem: &Memory,
bytes: &[u8],
) -> Result<Box<[Val]>> {
let mut offset = 0;
types
.map(|ty| {
Val::load(
&ty,
mem,
&bytes[ty.next_field(&mut offset)..][..ty.size_and_alignment().size],
)
})
.collect()
}
fn load_variant(
ty: &Type,
mut types: impl ExactSizeIterator<Item = Type>,
mem: &Memory,
bytes: &[u8],
) -> Result<(u32, Val)> {
let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap();
let discriminant = match discriminant_size {
DiscriminantSize::Size1 => u8::load(mem, &bytes[..1])? as u32,
DiscriminantSize::Size2 => u16::load(mem, &bytes[..2])? as u32,
DiscriminantSize::Size4 => u32::load(mem, &bytes[..4])?,
};
let case_ty = types.nth(discriminant as usize).ok_or_else(|| {
anyhow!(
"discriminant {} out of range [0..{})",
discriminant,
types.len()
)
})?;
let value = Val::load(
&case_ty,
mem,
&bytes[func::align_to(
usize::from(discriminant_size),
ty.size_and_alignment().alignment,
)..][..case_ty.size_and_alignment().size],
)?;
Ok((discriminant, value))
}
fn lift_variant<'a>(
flatten_count: usize,
mut types: impl ExactSizeIterator<Item = Type>,
store: &StoreOpaque,
options: &Options,
src: &mut std::slice::Iter<'_, ValRaw>,
) -> Result<(u32, Val)> {
let len = types.len();
let discriminant = next(src).get_u32();
let ty = types
.nth(discriminant as usize)
.ok_or_else(|| anyhow!("discriminant {} out of range [0..{})", discriminant, len))?;
let value = Val::lift(&ty, store, options, src)?;
for _ in (1 + ty.flatten_count())..flatten_count {
next(src);
}
Ok((discriminant, value))
}
/// Lower a list with the specified element type and values.
fn lower_list<T>(
element_type: &Type,
mem: &mut MemoryMut<'_, T>,
items: &[Val],
) -> Result<(usize, usize)> {
let SizeAndAlignment {
size: element_size,
alignment: element_alignment,
} = element_type.size_and_alignment();
let size = items
.len()
.checked_mul(element_size)
.ok_or_else(|| anyhow::anyhow!("size overflow copying a list"))?;
let ptr = mem.realloc(0, 0, element_alignment, size)?;
let mut element_ptr = ptr;
for item in items {
item.store(mem, element_ptr)?;
element_ptr += element_size;
}
Ok((ptr, items.len()))
}
/// Calculate the size of a u32 array needed to represent the specified number of bit flags.
///
/// Note that this will always return at least 1, even if the `count` parameter is zero.
pub(crate) fn u32_count_for_flag_count(count: usize) -> usize {
match FlagsSize::from_count(count) {
FlagsSize::Size1 | FlagsSize::Size2 => 1,
FlagsSize::Size4Plus(n) => n,
}
}
fn next<'a>(src: &mut std::slice::Iter<'a, ValRaw>) -> &'a ValRaw {
src.next().unwrap()
}
fn next_mut<'a>(
dst: &mut std::slice::IterMut<'a, MaybeUninit<ValRaw>>,
) -> &'a mut MaybeUninit<ValRaw> {
dst.next().unwrap()
}

View File

@@ -41,6 +41,7 @@ const CRATES_TO_PUBLISH: &[&str] = &[
"wiggle-macro",
// wasmtime
"wasmtime-asm-macros",
"wasmtime-component-util",
"wasmtime-component-macro",
"wasmtime-jit-debug",
"wasmtime-fiber",

View File

@@ -1,7 +1,10 @@
use anyhow::Result;
use std::fmt::Write;
use std::iter;
use wasmtime::component::{Component, ComponentParams, Lift, Lower, TypedFunc};
use wasmtime::{AsContextMut, Config, Engine};
mod dynamic;
mod func;
mod import;
mod instance;
@@ -148,3 +151,128 @@ fn components_importing_modules() -> Result<()> {
Ok(())
}
#[derive(Copy, Clone, PartialEq, Eq)]
enum Type {
S8,
U8,
S16,
U16,
I32,
I64,
F32,
F64,
}
impl Type {
fn store(&self) -> &'static str {
match self {
Self::S8 | Self::U8 => "store8",
Self::S16 | Self::U16 => "store16",
Self::I32 | Self::F32 | Self::I64 | Self::F64 => "store",
}
}
fn primitive(&self) -> &'static str {
match self {
Self::S8 | Self::U8 | Self::S16 | Self::U16 | Self::I32 => "i32",
Self::I64 => "i64",
Self::F32 => "f32",
Self::F64 => "f64",
}
}
}
#[derive(Copy, Clone, PartialEq, Eq)]
struct Param(Type, Option<usize>);
fn make_echo_component(type_definition: &str, type_size: u32) -> String {
let mut offset = 0;
make_echo_component_with_params(
type_definition,
&iter::repeat(Type::I32)
.map(|ty| {
let param = Param(ty, Some(offset));
offset += 4;
param
})
.take(usize::try_from(type_size).unwrap() / 4)
.collect::<Vec<_>>(),
)
}
fn make_echo_component_with_params(type_definition: &str, params: &[Param]) -> String {
let func = if params.len() == 1 || params.len() > 16 {
let primitive = if params.len() == 1 {
params[0].0.primitive()
} else {
"i32"
};
format!(
r#"
(func (export "echo") (param {primitive}) (result {primitive})
local.get 0
)"#,
)
} else {
let mut param_string = String::new();
let mut store = String::new();
let mut size = 8;
for (index, Param(ty, offset)) in params.iter().enumerate() {
let primitive = ty.primitive();
write!(&mut param_string, " {primitive}").unwrap();
if let Some(offset) = offset {
write!(
&mut store,
"({primitive}.{} offset={offset} (local.get $base) (local.get {index}))",
ty.store(),
)
.unwrap();
size = size.max(offset + 8);
}
}
format!(
r#"
(func (export "echo") (param{param_string}) (result i32)
(local $base i32)
(local.set $base
(call $realloc
(i32.const 0)
(i32.const 0)
(i32.const 4)
(i32.const {size})))
{store}
local.get $base
)"#
)
};
format!(
r#"
(component
(core module $m
{func}
(memory (export "memory") 1)
{REALLOC_AND_FREE}
)
(core instance $i (instantiate $m))
(type $Foo {type_definition})
(func (export "echo") (param $Foo) (result $Foo)
(canon lift
(core func $i "echo")
(memory $i "memory")
(realloc (func $i "realloc"))
)
)
)"#
)
}

View File

@@ -0,0 +1,511 @@
use super::{make_echo_component, make_echo_component_with_params, Param, Type};
use anyhow::Result;
use wasmtime::component::{self, Component, Func, Linker, Val};
use wasmtime::{AsContextMut, Store};
trait FuncExt {
fn call_and_post_return(&self, store: impl AsContextMut, args: &[Val]) -> Result<Val>;
}
impl FuncExt for Func {
fn call_and_post_return(&self, mut store: impl AsContextMut, args: &[Val]) -> Result<Val> {
let result = self.call(&mut store, args)?;
self.post_return(&mut store)?;
Ok(result)
}
}
#[test]
fn primitives() -> Result<()> {
let engine = super::engine();
let mut store = Store::new(&engine, ());
for (input, ty, param) in [
(Val::Bool(true), "bool", Param(Type::U8, Some(0))),
(Val::S8(-42), "s8", Param(Type::S8, Some(0))),
(Val::U8(42), "u8", Param(Type::U8, Some(0))),
(Val::S16(-4242), "s16", Param(Type::S16, Some(0))),
(Val::U16(4242), "u16", Param(Type::U16, Some(0))),
(Val::S32(-314159265), "s32", Param(Type::I32, Some(0))),
(Val::U32(314159265), "u32", Param(Type::I32, Some(0))),
(Val::S64(-31415926535897), "s64", Param(Type::I64, Some(0))),
(Val::U64(31415926535897), "u64", Param(Type::I64, Some(0))),
(
Val::Float32(3.14159265_f32.to_bits()),
"float32",
Param(Type::F32, Some(0)),
),
(
Val::Float64(3.14159265_f64.to_bits()),
"float64",
Param(Type::F64, Some(0)),
),
(Val::Char('🦀'), "char", Param(Type::I32, Some(0))),
] {
let component = Component::new(&engine, make_echo_component_with_params(ty, &[param]))?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let func = instance.get_func(&mut store, "echo").unwrap();
let output = func.call_and_post_return(&mut store, &[input.clone()])?;
assert_eq!(input, output);
}
// Sad path: type mismatch
let component = Component::new(
&engine,
make_echo_component_with_params("float64", &[Param(Type::F64, Some(0))]),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let func = instance.get_func(&mut store, "echo").unwrap();
let err = func
.call_and_post_return(&mut store, &[Val::U64(42)])
.unwrap_err();
assert!(err.to_string().contains("type mismatch"), "{err}");
// Sad path: arity mismatch (too many)
let err = func
.call_and_post_return(
&mut store,
&[
Val::Float64(3.14159265_f64.to_bits()),
Val::Float64(3.14159265_f64.to_bits()),
],
)
.unwrap_err();
assert!(
err.to_string().contains("expected 1 argument(s), got 2"),
"{err}"
);
// Sad path: arity mismatch (too few)
let err = func.call_and_post_return(&mut store, &[]).unwrap_err();
assert!(
err.to_string().contains("expected 1 argument(s), got 0"),
"{err}"
);
Ok(())
}
#[test]
fn strings() -> Result<()> {
let engine = super::engine();
let mut store = Store::new(&engine, ());
let component = Component::new(&engine, make_echo_component("string", 8))?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let func = instance.get_func(&mut store, "echo").unwrap();
let input = Val::String(Box::from("hello, component!"));
let output = func.call_and_post_return(&mut store, &[input.clone()])?;
assert_eq!(input, output);
Ok(())
}
#[test]
fn lists() -> Result<()> {
let engine = super::engine();
let mut store = Store::new(&engine, ());
let component = Component::new(&engine, make_echo_component("(list u32)", 8))?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let func = instance.get_func(&mut store, "echo").unwrap();
let ty = &func.params(&store)[0];
let input = ty.unwrap_list().new_val(Box::new([
Val::U32(32343),
Val::U32(79023439),
Val::U32(2084037802),
]))?;
let output = func.call_and_post_return(&mut store, &[input.clone()])?;
assert_eq!(input, output);
// Sad path: type mismatch
let err = ty
.unwrap_list()
.new_val(Box::new([
Val::U32(32343),
Val::U32(79023439),
Val::Float32(3.14159265_f32.to_bits()),
]))
.unwrap_err();
assert!(err.to_string().contains("type mismatch"), "{err}");
Ok(())
}
#[test]
fn records() -> Result<()> {
let engine = super::engine();
let mut store = Store::new(&engine, ());
let component = Component::new(
&engine,
make_echo_component_with_params(
r#"(record (field "A" u32) (field "B" float64) (field "C" (record (field "D" bool) (field "E" u32))))"#,
&[
Param(Type::I32, Some(0)),
Param(Type::F64, Some(8)),
Param(Type::U8, Some(16)),
Param(Type::I32, Some(20)),
],
),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let func = instance.get_func(&mut store, "echo").unwrap();
let ty = &func.params(&store)[0];
let inner_type = &ty.unwrap_record().fields().nth(2).unwrap().ty;
let input = ty.unwrap_record().new_val([
("A", Val::U32(32343)),
("B", Val::Float64(3.14159265_f64.to_bits())),
(
"C",
inner_type
.unwrap_record()
.new_val([("D", Val::Bool(false)), ("E", Val::U32(2084037802))])?,
),
])?;
let output = func.call_and_post_return(&mut store, &[input.clone()])?;
assert_eq!(input, output);
// Sad path: type mismatch
let err = ty
.unwrap_record()
.new_val([
("A", Val::S32(32343)),
("B", Val::Float64(3.14159265_f64.to_bits())),
(
"C",
inner_type
.unwrap_record()
.new_val([("D", Val::Bool(false)), ("E", Val::U32(2084037802))])?,
),
])
.unwrap_err();
assert!(err.to_string().contains("type mismatch"), "{err}");
// Sad path: too many fields
let err = ty
.unwrap_record()
.new_val([
("A", Val::U32(32343)),
("B", Val::Float64(3.14159265_f64.to_bits())),
(
"C",
inner_type
.unwrap_record()
.new_val([("D", Val::Bool(false)), ("E", Val::U32(2084037802))])?,
),
("F", Val::Unit),
])
.unwrap_err();
assert!(
err.to_string().contains("expected 3 value(s); got 4"),
"{err}"
);
// Sad path: too few fields
let err = ty
.unwrap_record()
.new_val([
("A", Val::U32(32343)),
("B", Val::Float64(3.14159265_f64.to_bits())),
])
.unwrap_err();
assert!(
err.to_string().contains("expected 3 value(s); got 2"),
"{err}"
);
Ok(())
}
#[test]
fn variants() -> Result<()> {
let engine = super::engine();
let mut store = Store::new(&engine, ());
let component = Component::new(
&engine,
make_echo_component_with_params(
r#"(variant (case "A" u32) (case "B" float64) (case "C" (record (field "D" bool) (field "E" u32))))"#,
&[
Param(Type::U8, Some(0)),
Param(Type::I64, Some(8)),
Param(Type::I32, None),
],
),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let func = instance.get_func(&mut store, "echo").unwrap();
let ty = &func.params(&store)[0];
let input = ty
.unwrap_variant()
.new_val("B", Val::Float64(3.14159265_f64.to_bits()))?;
let output = func.call_and_post_return(&mut store, &[input.clone()])?;
assert_eq!(input, output);
// Do it again, this time using case "C"
let component = Component::new(
&engine,
dbg!(make_echo_component_with_params(
r#"(variant (case "A" u32) (case "B" float64) (case "C" (record (field "D" bool) (field "E" u32))))"#,
&[
Param(Type::U8, Some(0)),
Param(Type::I64, Some(8)),
Param(Type::I32, Some(12)),
],
)),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let func = instance.get_func(&mut store, "echo").unwrap();
let ty = &func.params(&store)[0];
let c_type = &ty.unwrap_variant().cases().nth(2).unwrap().ty;
let input = ty.unwrap_variant().new_val(
"C",
c_type
.unwrap_record()
.new_val([("D", Val::Bool(true)), ("E", Val::U32(314159265))])?,
)?;
let output = func.call_and_post_return(&mut store, &[input.clone()])?;
assert_eq!(input, output);
// Sad path: type mismatch
let err = ty
.unwrap_variant()
.new_val("B", Val::U64(314159265))
.unwrap_err();
assert!(err.to_string().contains("type mismatch"), "{err}");
// Sad path: unknown case
let err = ty
.unwrap_variant()
.new_val("D", Val::U64(314159265))
.unwrap_err();
assert!(err.to_string().contains("unknown variant case"), "{err}");
// Make sure we lift variants which have cases of different sizes with the correct alignment
let component = Component::new(
&engine,
make_echo_component_with_params(
r#"
(record
(field "A" (variant
(case "A" u32)
(case "B" float64)
(case "C" (record (field "D" bool) (field "E" u32)))))
(field "B" u32))"#,
&[
Param(Type::U8, Some(0)),
Param(Type::I64, Some(8)),
Param(Type::I32, None),
Param(Type::I32, Some(16)),
],
),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let func = instance.get_func(&mut store, "echo").unwrap();
let ty = &func.params(&store)[0];
let a_type = &ty.unwrap_record().fields().nth(0).unwrap().ty;
let input = ty.unwrap_record().new_val([
(
"A",
a_type.unwrap_variant().new_val("A", Val::U32(314159265))?,
),
("B", Val::U32(628318530)),
])?;
let output = func.call_and_post_return(&mut store, &[input.clone()])?;
assert_eq!(input, output);
Ok(())
}
#[test]
fn flags() -> Result<()> {
let engine = super::engine();
let mut store = Store::new(&engine, ());
let component = Component::new(
&engine,
make_echo_component_with_params(
r#"(flags "A" "B" "C" "D" "E")"#,
&[Param(Type::U8, Some(0))],
),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let func = instance.get_func(&mut store, "echo").unwrap();
let ty = &func.params(&store)[0];
let input = ty.unwrap_flags().new_val(&["B", "D"])?;
let output = func.call_and_post_return(&mut store, &[input.clone()])?;
assert_eq!(input, output);
// Sad path: unknown flags
let err = ty.unwrap_flags().new_val(&["B", "D", "F"]).unwrap_err();
assert!(err.to_string().contains("unknown flag"), "{err}");
Ok(())
}
#[test]
fn everything() -> Result<()> {
// This serves to test both nested types and storing parameters on the heap (i.e. exceeding `MAX_STACK_PARAMS`)
let engine = super::engine();
let mut store = Store::new(&engine, ());
let component = Component::new(
&engine,
make_echo_component_with_params(
r#"
(record
(field "A" u32)
(field "B" (enum "1" "2"))
(field "C" (record (field "D" bool) (field "E" u32)))
(field "F" (list (flags "G" "H" "I")))
(field "J" (variant
(case "K" u32)
(case "L" float64)
(case "M" (record (field "N" bool) (field "O" u32)))))
(field "P" s8)
(field "Q" s16)
(field "R" s32)
(field "S" s64)
(field "T" float32)
(field "U" float64)
(field "V" string)
(field "W" char)
(field "X" unit)
(field "Y" (tuple u32 u32))
(field "Z" (union u32 float64))
(field "AA" (option u32))
(field "BB" (expected string string))
)"#,
&[
Param(Type::I32, Some(0)),
Param(Type::U8, Some(4)),
Param(Type::U8, Some(5)),
Param(Type::I32, Some(8)),
Param(Type::I32, Some(12)),
Param(Type::I32, Some(16)),
Param(Type::U8, Some(20)),
Param(Type::I64, Some(28)),
Param(Type::I32, Some(32)),
Param(Type::S8, Some(36)),
Param(Type::S16, Some(38)),
Param(Type::I32, Some(40)),
Param(Type::I64, Some(48)),
Param(Type::F32, Some(56)),
Param(Type::F64, Some(64)),
Param(Type::I32, Some(72)),
Param(Type::I32, Some(76)),
Param(Type::I32, Some(80)),
Param(Type::I32, Some(84)),
Param(Type::I32, Some(88)),
Param(Type::I64, Some(96)),
Param(Type::U8, Some(104)),
Param(Type::I32, Some(108)),
Param(Type::U8, Some(112)),
Param(Type::I32, Some(116)),
Param(Type::I32, Some(120)),
],
),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let func = instance.get_func(&mut store, "echo").unwrap();
let ty = &func.params(&store)[0];
let types = ty
.unwrap_record()
.fields()
.map(|field| field.ty)
.collect::<Box<[component::Type]>>();
let (b_type, c_type, f_type, j_type, y_type, z_type, aa_type, bb_type) = (
&types[1], &types[2], &types[3], &types[4], &types[14], &types[15], &types[16], &types[17],
);
let f_element_type = &f_type.unwrap_list().ty();
let input = ty.unwrap_record().new_val([
("A", Val::U32(32343)),
("B", b_type.unwrap_enum().new_val("2")?),
(
"C",
c_type
.unwrap_record()
.new_val([("D", Val::Bool(false)), ("E", Val::U32(2084037802))])?,
),
(
"F",
f_type.unwrap_list().new_val(Box::new([f_element_type
.unwrap_flags()
.new_val(&["G", "I"])?]))?,
),
(
"J",
j_type
.unwrap_variant()
.new_val("L", Val::Float64(3.14159265_f64.to_bits()))?,
),
("P", Val::S8(42)),
("Q", Val::S16(4242)),
("R", Val::S32(42424242)),
("S", Val::S64(424242424242424242)),
("T", Val::Float32(3.14159265_f32.to_bits())),
("U", Val::Float64(3.14159265_f64.to_bits())),
("V", Val::String(Box::from("wow, nice types"))),
("W", Val::Char('🦀')),
("X", Val::Unit),
(
"Y",
y_type
.unwrap_tuple()
.new_val(Box::new([Val::U32(42), Val::U32(24)]))?,
),
(
"Z",
z_type
.unwrap_union()
.new_val(1, Val::Float64(3.14159265_f64.to_bits()))?,
),
(
"AA",
aa_type.unwrap_option().new_val(Some(Val::U32(314159265)))?,
),
(
"BB",
bb_type
.unwrap_expected()
.new_val(Ok(Val::String(Box::from("no problem"))))?,
),
])?;
let output = func.call_and_post_return(&mut store, &[input.clone()])?;
assert_eq!(input, output);
Ok(())
}

View File

@@ -1,75 +1,9 @@
use super::TypedFuncExt;
use super::{make_echo_component, TypedFuncExt};
use anyhow::Result;
use component_macro_test::{add_variants, flags_test};
use std::fmt::Write;
use wasmtime::component::{Component, ComponentType, Lift, Linker, Lower};
use wasmtime::Store;
fn make_echo_component(type_definition: &str, type_size: u32) -> String {
if type_size <= 4 {
format!(
r#"
(component
(core module $m
(func (export "echo") (param i32) (result i32)
local.get 0
)
(memory (export "memory") 1)
)
(core instance $i (instantiate $m))
{}
(func (export "echo") (param $Foo) (result $Foo)
(canon lift (core func $i "echo") (memory $i "memory"))
)
)"#,
type_definition
)
} else {
let mut params = String::new();
let mut store = String::new();
for index in 0..(type_size / 4) {
params.push_str(" i32");
write!(
&mut store,
"(i32.store offset={} (local.get $base) (local.get {}))",
index * 4,
index,
)
.unwrap();
}
format!(
r#"
(component
(core module $m
(func (export "echo") (param{}) (result i32)
(local $base i32)
(local.set $base (i32.const 0))
{}
local.get $base
)
(memory (export "memory") 1)
)
(core instance $i (instantiate $m))
{}
(func (export "echo") (param $Foo) (result $Foo)
(canon lift (core func $i "echo") (memory $i "memory"))
)
)"#,
params, store, type_definition
)
}
}
#[test]
fn record_derive() -> Result<()> {
#[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)]
@@ -87,10 +21,7 @@ fn record_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(
r#"(type $Foo (record (field "foo-bar-baz" s32) (field "b" u32)))"#,
8,
),
make_echo_component(r#"(record (field "foo-bar-baz" s32) (field "b" u32))"#, 8),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
@@ -105,7 +36,7 @@ fn record_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(r#"(type $Foo (record (field "foo-bar-baz" s32)))"#, 4),
make_echo_component(r#"(record (field "foo-bar-baz" s32))"#, 4),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
@@ -118,7 +49,7 @@ fn record_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(
r#"(type $Foo (record (field "foo-bar-baz" s32) (field "b" u32) (field "c" u32)))"#,
r#"(record (field "foo-bar-baz" s32) (field "b" u32) (field "c" u32))"#,
12,
),
)?;
@@ -132,7 +63,7 @@ fn record_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(r#"(type $Foo (record (field "a" s32) (field "b" u32)))"#, 8),
make_echo_component(r#"(record (field "a" s32) (field "b" u32))"#, 8),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
@@ -144,10 +75,7 @@ fn record_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(
r#"(type $Foo (record (field "foo-bar-baz" s32) (field "b" s32)))"#,
8,
),
make_echo_component(r#"(record (field "foo-bar-baz" s32) (field "b" s32))"#, 8),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
@@ -172,10 +100,7 @@ fn record_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(
r#"(type $Foo (record (field "foo-bar-baz" s32) (field "b" u32)))"#,
8,
),
make_echo_component(r#"(record (field "foo-bar-baz" s32) (field "b" u32))"#, 8),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
@@ -203,10 +128,7 @@ fn union_derive() -> Result<()> {
// Happy path: component type matches case count and types
let component = Component::new(
&engine,
make_echo_component(r#"(type $Foo (union s32 u32 s32))"#, 8),
)?;
let component = Component::new(&engine, make_echo_component("(union s32 u32 s32)", 8))?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let func = instance.get_typed_func::<(Foo,), Foo, _>(&mut store, "echo")?;
@@ -218,10 +140,7 @@ fn union_derive() -> Result<()> {
// Sad path: case count mismatch (too few)
let component = Component::new(
&engine,
make_echo_component(r#"(type $Foo (union s32 u32))"#, 8),
)?;
let component = Component::new(&engine, make_echo_component("(union s32 u32)", 8))?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
assert!(instance
@@ -232,7 +151,7 @@ fn union_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(r#"(type $Foo (union s32 u32 s32 s32))"#, 8),
make_echo_component(r#"(union s32 u32 s32 s32)"#, 8),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
@@ -246,10 +165,7 @@ fn union_derive() -> Result<()> {
// Sad path: case type mismatch
let component = Component::new(
&engine,
make_echo_component(r#"(type $Foo (union s32 s32 s32))"#, 8),
)?;
let component = Component::new(&engine, make_echo_component("(union s32 s32 s32)", 8))?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
assert!(instance
@@ -266,10 +182,7 @@ fn union_derive() -> Result<()> {
C(C),
}
let component = Component::new(
&engine,
make_echo_component(r#"(type $Foo (union s32 u32 s32))"#, 8),
)?;
let component = Component::new(&engine, make_echo_component("(union s32 u32 s32)", 8))?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let func = instance.get_typed_func::<(Generic<i32, u32, i32>,), Generic<i32, u32, i32>, _>(
&mut store, "echo",
@@ -307,7 +220,7 @@ fn variant_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(
r#"(type $Foo (variant (case "foo-bar-baz" s32) (case "B" u32) (case "C" unit)))"#,
r#"(variant (case "foo-bar-baz" s32) (case "B" u32) (case "C" unit))"#,
8,
),
)?;
@@ -324,10 +237,7 @@ fn variant_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(
r#"(type $Foo (variant (case "foo-bar-baz" s32) (case "B" u32)))"#,
8,
),
make_echo_component(r#"(variant (case "foo-bar-baz" s32) (case "B" u32))"#, 8),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
@@ -340,7 +250,7 @@ fn variant_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(
r#"(type $Foo (variant (case "foo-bar-baz" s32) (case "B" u32) (case "C" unit) (case "D" u32)))"#,
r#"(variant (case "foo-bar-baz" s32) (case "B" u32) (case "C" unit) (case "D" u32))"#,
8,
),
)?;
@@ -355,7 +265,7 @@ fn variant_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(
r#"(type $Foo (variant (case "A" s32) (case "B" u32) (case "C" unit)))"#,
r#"(variant (case "A" s32) (case "B" u32) (case "C" unit))"#,
8,
),
)?;
@@ -370,7 +280,7 @@ fn variant_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(
r#"(type $Foo (variant (case "foo-bar-baz" s32) (case "B" s32) (case "C" unit)))"#,
r#"(variant (case "foo-bar-baz" s32) (case "B" s32) (case "C" unit))"#,
8,
),
)?;
@@ -394,7 +304,7 @@ fn variant_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(
r#"(type $Foo (variant (case "foo-bar-baz" s32) (case "B" u32) (case "C" unit)))"#,
r#"(variant (case "foo-bar-baz" s32) (case "B" u32) (case "C" unit))"#,
8,
),
)?;
@@ -429,7 +339,7 @@ fn enum_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(r#"(type $Foo (enum "foo-bar-baz" "B" "C"))"#, 4),
make_echo_component(r#"(enum "foo-bar-baz" "B" "C")"#, 4),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let func = instance.get_typed_func::<(Foo,), Foo, _>(&mut store, "echo")?;
@@ -444,7 +354,7 @@ fn enum_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(r#"(type $Foo (enum "foo-bar-baz" "B"))"#, 4),
make_echo_component(r#"(enum "foo-bar-baz" "B")"#, 4),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
@@ -456,7 +366,7 @@ fn enum_derive() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(r#"(type $Foo (enum "foo-bar-baz" "B" "C" "D"))"#, 4),
make_echo_component(r#"(enum "foo-bar-baz" "B" "C" "D")"#, 4),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
@@ -466,10 +376,7 @@ fn enum_derive() -> Result<()> {
// Sad path: case name mismatch
let component = Component::new(
&engine,
make_echo_component(r#"(type $Foo (enum "A" "B" "C"))"#, 4),
)?;
let component = Component::new(&engine, make_echo_component(r#"(enum "A" "B" "C")"#, 4))?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
assert!(instance
@@ -487,7 +394,7 @@ fn enum_derive() -> Result<()> {
&engine,
make_echo_component(
&format!(
r#"(type $Foo (enum {}))"#,
"(enum {})",
(0..257)
.map(|index| format!(r#""V{}""#, index))
.collect::<Vec<_>>()
@@ -542,7 +449,7 @@ fn flags() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(r#"(type $Foo (flags "foo-bar-baz" "B" "C"))"#, 4),
make_echo_component(r#"(flags "foo-bar-baz" "B" "C")"#, 4),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let func = instance.get_typed_func::<(Foo,), Foo, _>(&mut store, "echo")?;
@@ -568,7 +475,7 @@ fn flags() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(r#"(type $Foo (flags "foo-bar-baz" "B"))"#, 4),
make_echo_component(r#"(flags "foo-bar-baz" "B")"#, 4),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
@@ -580,7 +487,7 @@ fn flags() -> Result<()> {
let component = Component::new(
&engine,
make_echo_component(r#"(type $Foo (flags "foo-bar-baz" "B" "C" "D"))"#, 4),
make_echo_component(r#"(flags "foo-bar-baz" "B" "C" "D")"#, 4),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
@@ -590,10 +497,7 @@ fn flags() -> Result<()> {
// Sad path: flag name mismatch
let component = Component::new(
&engine,
make_echo_component(r#"(type $Foo (flags "A" "B" "C"))"#, 4),
)?;
let component = Component::new(&engine, make_echo_component(r#"(flags "A" "B" "C")"#, 4))?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
assert!(instance
@@ -633,7 +537,7 @@ fn flags() -> Result<()> {
&engine,
make_echo_component(
&format!(
r#"(type $Foo (flags {}))"#,
r#"(flags {})"#,
(0..8)
.map(|index| format!(r#""F{}""#, index))
.collect::<Vec<_>>()
@@ -682,7 +586,7 @@ fn flags() -> Result<()> {
&engine,
make_echo_component(
&format!(
r#"(type $Foo (flags {}))"#,
"(flags {})",
(0..9)
.map(|index| format!(r#""F{}""#, index))
.collect::<Vec<_>>()
@@ -730,7 +634,7 @@ fn flags() -> Result<()> {
&engine,
make_echo_component(
&format!(
r#"(type $Foo (flags {}))"#,
r#"(flags {})"#,
(0..16)
.map(|index| format!(r#""F{}""#, index))
.collect::<Vec<_>>()
@@ -769,7 +673,7 @@ fn flags() -> Result<()> {
&engine,
make_echo_component(
&format!(
r#"(type $Foo (flags {}))"#,
"(flags {})",
(0..17)
.map(|index| format!(r#""F{}""#, index))
.collect::<Vec<_>>()
@@ -817,7 +721,7 @@ fn flags() -> Result<()> {
&engine,
make_echo_component(
&format!(
r#"(type $Foo (flags {}))"#,
r#"(flags {})"#,
(0..32)
.map(|index| format!(r#""F{}""#, index))
.collect::<Vec<_>>()
@@ -856,7 +760,7 @@ fn flags() -> Result<()> {
&engine,
make_echo_component(
&format!(
r#"(type $Foo (flags {}))"#,
"(flags {})",
(0..33)
.map(|index| format!(r#""F{}""#, index))
.collect::<Vec<_>>()
@@ -889,7 +793,7 @@ fn flags() -> Result<()> {
&engine,
make_echo_component(
&format!(
r#"(type $Foo (flags {}))"#,
"(flags {})",
(0..65)
.map(|index| format!(r#""F{}""#, index))
.collect::<Vec<_>>()