support dynamic function calls in component model (#4442)

* support dynamic function calls in component model

This addresses #4310, introducing a new `component::values::Val` type for
representing component values dynamically, as well as `component::types::Type`
for representing the corresponding interface types. It also adds a `call` method
to `component::func::Func`, which takes a slice of `Val`s as parameters and
returns a `Result<Val>` representing the result.

Note that I've moved `post_return` and `call_raw` from `TypedFunc` to `Func`
since there was nothing specific to `TypedFunc` about them, and I wanted to
reuse them.  The code in both is unchanged beyond the trivial tweaks to make
them fit in their new home.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* order variants and match cases more consistently

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* implement lift for String, Box<str>, etc.

This also removes the redundant `store` parameter from `Type::load`.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* implement code review feedback

This fixes a few issues:

- Bad offset calculation when lowering
- Missing variant padding
- Style issues regarding `types::Handle`
- Missed opportunities to reuse `Lift` and `Lower` impls

It also adds forwarding `Lift` impls for `Box<[T]>`, `Vec<T>`, etc.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* move `new_*` methods to specific `types` structs

Per review feedback, I've moved `Type::new_record` to `Record::new_val` and
added a `Type::unwrap_record` method; likewise for the other kinds of types.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* make tuple, option, and expected type comparisons recursive

These types should compare as equal across component boundaries as long as their
type parameters are equal.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* improve error diagnostic in `Type::check`

We now distinguish between more failure cases to provide an informative error
message.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* address review feedback

- Remove `WasmStr::to_str_from_memory` and `WasmList::get_from_memory`
- add `try_new` methods to various `values` types
- avoid using `ExactSizeIterator::len` where we can't trust it
- fix over-constrained bounds on forwarded `ComponentType` impls

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* rearrange code per review feedback

- Move functions from `types` to `values` module so we can make certain struct fields private
- Rename `try_new` to just `new`

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* remove special-case equality test for tuples, options, and expecteds

Instead, I've added a FIXME comment and will open an issue to do recursive
structural equality testing.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>
This commit is contained in:
Joel Dice
2022-07-25 12:38:48 -06:00
committed by GitHub
parent ee7e4f4c6b
commit 7c67e620c4
16 changed files with 2796 additions and 453 deletions

View File

@@ -17,6 +17,7 @@ proc-macro = true
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "1.0", features = ["extra-traits"] }
wasmtime-component-util = { path = "../component-util", version = "=0.40.0" }
[badges]
maintenance = { status = "actively-developed" }

View File

@@ -5,6 +5,7 @@ use std::fmt;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{braced, parse_macro_input, parse_quote, Data, DeriveInput, Error, Result, Token};
use wasmtime_component_util::{DiscriminantSize, FlagsSize};
#[derive(Debug, Copy, Clone)]
enum VariantStyle {
@@ -147,64 +148,6 @@ fn add_trait_bounds(generics: &syn::Generics, bound: syn::TypeParamBound) -> syn
generics
}
#[derive(Debug, Copy, Clone)]
enum DiscriminantSize {
Size1,
Size2,
Size4,
}
impl DiscriminantSize {
fn quote(self, discriminant: usize) -> TokenStream {
match self {
Self::Size1 => {
let discriminant = u8::try_from(discriminant).unwrap();
quote!(#discriminant)
}
Self::Size2 => {
let discriminant = u16::try_from(discriminant).unwrap();
quote!(#discriminant)
}
Self::Size4 => {
let discriminant = u32::try_from(discriminant).unwrap();
quote!(#discriminant)
}
}
}
}
impl From<DiscriminantSize> for u32 {
fn from(size: DiscriminantSize) -> u32 {
match size {
DiscriminantSize::Size1 => 1,
DiscriminantSize::Size2 => 2,
DiscriminantSize::Size4 => 4,
}
}
}
impl From<DiscriminantSize> for usize {
fn from(size: DiscriminantSize) -> usize {
match size {
DiscriminantSize::Size1 => 1,
DiscriminantSize::Size2 => 2,
DiscriminantSize::Size4 => 4,
}
}
}
fn discriminant_size(case_count: usize) -> Option<DiscriminantSize> {
if case_count <= 0xFF {
Some(DiscriminantSize::Size1)
} else if case_count <= 0xFFFF {
Some(DiscriminantSize::Size2)
} else if case_count <= 0xFFFF_FFFF {
Some(DiscriminantSize::Size4)
} else {
None
}
}
struct VariantCase<'a> {
attrs: &'a [syn::Attribute],
ident: &'a syn::Ident,
@@ -288,7 +231,7 @@ fn expand_variant(
));
}
let discriminant_size = discriminant_size(body.variants.len()).ok_or_else(|| {
let discriminant_size = DiscriminantSize::from_count(body.variants.len()).ok_or_else(|| {
Error::new(
input.ident.span(),
"`enum`s with more than 2^32 variants are not supported",
@@ -417,7 +360,7 @@ fn expand_record_for_component_type(
const SIZE32: usize = {
let mut size = 0;
#sizes
size
#internal::align_to(size, Self::ALIGN32)
};
const ALIGN32: u32 = {
@@ -439,6 +382,23 @@ fn expand_record_for_component_type(
Ok(quote!(const _: () = { #expanded };))
}
fn quote(size: DiscriminantSize, discriminant: usize) -> TokenStream {
match size {
DiscriminantSize::Size1 => {
let discriminant = u8::try_from(discriminant).unwrap();
quote!(#discriminant)
}
DiscriminantSize::Size2 => {
let discriminant = u16::try_from(discriminant).unwrap();
quote!(#discriminant)
}
DiscriminantSize::Size4 => {
let discriminant = u32::try_from(discriminant).unwrap();
quote!(#discriminant)
}
}
}
#[proc_macro_derive(Lift, attributes(component))]
pub fn lift(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
expand(&LiftExpander, &parse_macro_input!(input as DeriveInput))
@@ -523,7 +483,7 @@ impl Expander for LiftExpander {
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
let index_u32 = u32::try_from(index).unwrap();
let index_quoted = discriminant_size.quote(index);
let index_quoted = quote(discriminant_size, index);
if let Some(ty) = ty {
lifts.extend(
@@ -666,7 +626,7 @@ impl Expander for LowerExpander {
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
let index_u32 = u32::try_from(index).unwrap();
let index_quoted = discriminant_size.quote(index);
let index_quoted = quote(discriminant_size, index);
let discriminant_size = usize::from(discriminant_size);
@@ -989,19 +949,6 @@ impl Parse for Flags {
}
}
enum FlagsSize {
/// Flags can fit in a u8
Size1,
/// Flags can fit in a u16
Size2,
/// Flags can fit in a specified number of u32 fields
Size4Plus(usize),
}
fn ceiling_divide(n: usize, d: usize) -> usize {
(n + d - 1) / d
}
#[proc_macro]
pub fn flags(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
expand_flags(&parse_macro_input!(input as Flags))
@@ -1010,13 +957,7 @@ pub fn flags(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
}
fn expand_flags(flags: &Flags) -> Result<TokenStream> {
let size = if flags.flags.len() <= 8 {
FlagsSize::Size1
} else if flags.flags.len() <= 16 {
FlagsSize::Size2
} else {
FlagsSize::Size4Plus(ceiling_divide(flags.flags.len(), 32))
};
let size = FlagsSize::from_count(flags.flags.len());
let ty;
let eq;

View File

@@ -0,0 +1,11 @@
[package]
name = "wasmtime-component-util"
version = "0.40.0"
authors = ["The Wasmtime Project Developers"]
description = "Utility types and functions to support the component model in Wasmtime"
license = "Apache-2.0 WITH LLVM-exception"
repository = "https://github.com/bytecodealliance/wasmtime"
documentation = "https://docs.rs/wasmtime-component-util/"
categories = ["wasm"]
keywords = ["webassembly", "wasm"]
edition = "2021"

View File

@@ -0,0 +1,75 @@
/// Represents the possible sizes in bytes of the discriminant of a variant type in the component model
#[derive(Debug, Copy, Clone)]
pub enum DiscriminantSize {
/// 8-bit discriminant
Size1,
/// 16-bit discriminant
Size2,
/// 32-bit discriminant
Size4,
}
impl DiscriminantSize {
/// Calculate the size of discriminant needed to represent a variant with the specified number of cases.
pub fn from_count(count: usize) -> Option<Self> {
if count <= 0xFF {
Some(Self::Size1)
} else if count <= 0xFFFF {
Some(Self::Size2)
} else if count <= 0xFFFF_FFFF {
Some(Self::Size4)
} else {
None
}
}
}
impl From<DiscriminantSize> for u32 {
/// Size of the discriminant as a `u32`
fn from(size: DiscriminantSize) -> u32 {
match size {
DiscriminantSize::Size1 => 1,
DiscriminantSize::Size2 => 2,
DiscriminantSize::Size4 => 4,
}
}
}
impl From<DiscriminantSize> for usize {
/// Size of the discriminant as a `usize`
fn from(size: DiscriminantSize) -> usize {
match size {
DiscriminantSize::Size1 => 1,
DiscriminantSize::Size2 => 2,
DiscriminantSize::Size4 => 4,
}
}
}
/// Represents the number of bytes required to store a flags value in the component model
pub enum FlagsSize {
/// Flags can fit in a u8
Size1,
/// Flags can fit in a u16
Size2,
/// Flags can fit in a specified number of u32 fields
Size4Plus(usize),
}
impl FlagsSize {
/// Calculate the size needed to represent a value with the specified number of flags.
pub fn from_count(count: usize) -> FlagsSize {
if count <= 8 {
FlagsSize::Size1
} else if count <= 16 {
FlagsSize::Size2
} else {
FlagsSize::Size4Plus(ceiling_divide(count, 32))
}
}
}
/// Divide `n` by `d`, rounding up in the case of a non-zero remainder.
fn ceiling_divide(n: usize, d: usize) -> usize {
(n + d - 1) / d
}

View File

@@ -20,6 +20,7 @@ wasmtime-cache = { path = "../cache", version = "=0.40.0", optional = true }
wasmtime-fiber = { path = "../fiber", version = "=0.40.0", optional = true }
wasmtime-cranelift = { path = "../cranelift", version = "=0.40.0", optional = true }
wasmtime-component-macro = { path = "../component-macro", version = "=0.40.0", optional = true }
wasmtime-component-util = { path = "../component-util", version = "=0.40.0", optional = true }
target-lexicon = { version = "0.12.0", default-features = false }
wasmparser = "0.87.0"
anyhow = "1.0.19"
@@ -115,4 +116,5 @@ component-model = [
"wasmtime-cranelift?/component-model",
"wasmtime-runtime/component-model",
"dep:wasmtime-component-macro",
"dep:wasmtime-component-util",
]

View File

@@ -1,8 +1,10 @@
use crate::component::instance::{Instance, InstanceData};
use crate::component::types::{SizeAndAlignment, Type};
use crate::component::values::Val;
use crate::store::{StoreOpaque, Stored};
use crate::{AsContext, ValRaw};
use anyhow::{Context, Result};
use std::mem::MaybeUninit;
use crate::{AsContext, AsContextMut, StoreContextMut, ValRaw};
use anyhow::{bail, Context, Result};
use std::mem::{self, MaybeUninit};
use std::ptr::NonNull;
use std::sync::Arc;
use wasmtime_environ::component::{
@@ -72,6 +74,12 @@ pub use self::host::*;
pub use self::options::*;
pub use self::typed::*;
#[repr(C)]
union ParamsAndResults<Params: Copy, Return: Copy> {
params: Params,
ret: Return,
}
/// A WebAssembly component function.
//
// FIXME: write more docs here
@@ -241,4 +249,346 @@ impl Func {
Ok(())
}
/// Get the parameter types for this function.
pub fn params(&self, store: impl AsContext) -> Box<[Type]> {
let data = &store.as_context()[self.0];
data.types[data.ty]
.params
.iter()
.map(|(_, ty)| Type::from(ty, &data.types))
.collect()
}
/// Invokes this function with the `params` given and returns the result.
///
/// The `params` here must match the type signature of this `Func`, or this will return an error. If a trap
/// occurs while executing this function, then an error will also be returned.
// TODO: say more -- most of the docs for `TypedFunc::call` apply here, too
pub fn call(&self, mut store: impl AsContextMut, args: &[Val]) -> Result<Val> {
let store = &mut store.as_context_mut();
let params;
let result;
{
let data = &store[self.0];
let ty = &data.types[data.ty];
if ty.params.len() != args.len() {
bail!(
"expected {} argument(s), got {}",
ty.params.len(),
args.len()
);
}
params = ty
.params
.iter()
.zip(args)
.map(|((_, ty), arg)| {
let ty = Type::from(ty, &data.types);
ty.check(arg).context("type mismatch with parameters")?;
Ok(ty)
})
.collect::<Result<Vec<_>>>()?;
result = Type::from(&ty.result, &data.types);
}
let param_count = params.iter().map(|ty| ty.flatten_count()).sum::<usize>();
let result_count = result.flatten_count();
self.call_raw(
store,
args,
|store, options, args, dst: &mut MaybeUninit<[ValRaw; MAX_STACK_PARAMS]>| {
if param_count > MAX_STACK_PARAMS {
self.store_args(store, &options, &params, args, dst)
} else {
dst.write([ValRaw::u64(0); MAX_STACK_PARAMS]);
let dst = unsafe {
mem::transmute::<_, &mut [MaybeUninit<ValRaw>; MAX_STACK_PARAMS]>(dst)
};
args.iter()
.try_for_each(|arg| arg.lower(store, &options, &mut dst.iter_mut()))
}
},
|store, options, src: &[ValRaw; MAX_STACK_RESULTS]| {
if result_count > MAX_STACK_RESULTS {
Self::load_result(&Memory::new(store, &options), &result, &mut src.iter())
} else {
Val::lift(&result, store, &options, &mut src.iter())
}
},
)
}
/// Invokes the underlying wasm function, lowering arguments and lifting the
/// result.
///
/// The `lower` function and `lift` function provided here are what actually
/// do the lowering and lifting. The `LowerParams` and `LowerReturn` types
/// are what will be allocated on the stack for this function call. They
/// should be appropriately sized for the lowering/lifting operation
/// happening.
fn call_raw<T, Params: ?Sized, Return, LowerParams, LowerReturn>(
&self,
store: &mut StoreContextMut<'_, T>,
params: &Params,
lower: impl FnOnce(
&mut StoreContextMut<'_, T>,
&Options,
&Params,
&mut MaybeUninit<LowerParams>,
) -> Result<()>,
lift: impl FnOnce(&StoreOpaque, &Options, &LowerReturn) -> Result<Return>,
) -> Result<Return>
where
LowerParams: Copy,
LowerReturn: Copy,
{
let FuncData {
trampoline,
export,
options,
instance,
component_instance,
..
} = store.0[self.0];
let space = &mut MaybeUninit::<ParamsAndResults<LowerParams, LowerReturn>>::uninit();
// Double-check the size/alignemnt of `space`, just in case.
//
// Note that this alone is not enough to guarantee the validity of the
// `unsafe` block below, but it's definitely required. In any case LLVM
// should be able to trivially see through these assertions and remove
// them in release mode.
let val_size = mem::size_of::<ValRaw>();
let val_align = mem::align_of::<ValRaw>();
assert!(mem::size_of_val(space) % val_size == 0);
assert!(mem::size_of_val(map_maybe_uninit!(space.params)) % val_size == 0);
assert!(mem::size_of_val(map_maybe_uninit!(space.ret)) % val_size == 0);
assert!(mem::align_of_val(space) == val_align);
assert!(mem::align_of_val(map_maybe_uninit!(space.params)) == val_align);
assert!(mem::align_of_val(map_maybe_uninit!(space.ret)) == val_align);
let instance = store.0[instance.0].as_ref().unwrap().instance();
let flags = instance.flags(component_instance);
unsafe {
// Test the "may enter" flag which is a "lock" on this instance.
// This is immediately set to `false` afterwards and note that
// there's no on-cleanup setting this flag back to true. That's an
// intentional design aspect where if anything goes wrong internally
// from this point on the instance is considered "poisoned" and can
// never be entered again. The only time this flag is set to `true`
// again is after post-return logic has completed successfully.
if !(*flags).may_enter() {
bail!("cannot reenter component instance");
}
(*flags).set_may_enter(false);
debug_assert!((*flags).may_leave());
(*flags).set_may_leave(false);
let result = lower(store, &options, params, map_maybe_uninit!(space.params));
(*flags).set_may_leave(true);
result?;
// This is unsafe as we are providing the guarantee that all the
// inputs are valid. The various pointers passed in for the function
// are all valid since they're coming from our store, and the
// `params_and_results` should have the correct layout for the core
// wasm function we're calling. Note that this latter point relies
// on the correctness of this module and `ComponentType`
// implementations, hence `ComponentType` being an `unsafe` trait.
crate::Func::call_unchecked_raw(
store,
export.anyfunc,
trampoline,
space.as_mut_ptr().cast(),
)?;
// Note that `.assume_init_ref()` here is unsafe but we're relying
// on the correctness of the structure of `LowerReturn` and the
// type-checking performed to acquire the `TypedFunc` to make this
// safe. It should be the case that `LowerReturn` is the exact
// representation of the return value when interpreted as
// `[ValRaw]`, and additionally they should have the correct types
// for the function we just called (which filled in the return
// values).
let ret = map_maybe_uninit!(space.ret).assume_init_ref();
// Lift the result into the host while managing post-return state
// here as well.
//
// After a successful lift the return value of the function, which
// is currently required to be 0 or 1 values according to the
// canonical ABI, is saved within the `Store`'s `FuncData`. This'll
// later get used in post-return.
(*flags).set_needs_post_return(true);
let val = lift(store.0, &options, ret)?;
let ret_slice = cast_storage(ret);
let data = &mut store.0[self.0];
assert!(data.post_return_arg.is_none());
match ret_slice.len() {
0 => data.post_return_arg = Some(ValRaw::i32(0)),
1 => data.post_return_arg = Some(ret_slice[0]),
_ => unreachable!(),
}
return Ok(val);
}
unsafe fn cast_storage<T>(storage: &T) -> &[ValRaw] {
assert!(std::mem::size_of_val(storage) % std::mem::size_of::<ValRaw>() == 0);
assert!(std::mem::align_of_val(storage) == std::mem::align_of::<ValRaw>());
std::slice::from_raw_parts(
(storage as *const T).cast(),
mem::size_of_val(storage) / mem::size_of::<ValRaw>(),
)
}
}
/// Invokes the `post-return` canonical ABI option, if specified, after a
/// [`Func::call`] has finished.
///
/// For some more information on when to use this function see the
/// documentation for post-return in the [`Func::call`] method.
/// Otherwise though this function is a required method call after a
/// [`Func::call`] completes successfully. After the embedder has
/// finished processing the return value then this function must be invoked.
///
/// # Errors
///
/// This function will return an error in the case of a WebAssembly trap
/// happening during the execution of the `post-return` function, if
/// specified.
///
/// # Panics
///
/// This function will panic if it's not called under the correct
/// conditions. This can only be called after a previous invocation of
/// [`Func::call`] completes successfully, and this function can only
/// be called for the same [`Func`] that was `call`'d.
///
/// If this function is called when [`Func::call`] was not previously
/// called, then it will panic. If a different [`Func`] for the same
/// component instance was invoked then this function will also panic
/// because the `post-return` needs to happen for the other function.
pub fn post_return(&self, mut store: impl AsContextMut) -> Result<()> {
let mut store = store.as_context_mut();
let data = &mut store.0[self.0];
let instance = data.instance;
let post_return = data.post_return;
let component_instance = data.component_instance;
let post_return_arg = data.post_return_arg.take();
let instance = store.0[instance.0].as_ref().unwrap().instance();
let flags = instance.flags(component_instance);
unsafe {
// First assert that the instance is in a "needs post return" state.
// This will ensure that the previous action on the instance was a
// function call above. This flag is only set after a component
// function returns so this also can't be called (as expected)
// during a host import for example.
//
// Note, though, that this assert is not sufficient because it just
// means some function on this instance needs its post-return
// called. We need a precise post-return for a particular function
// which is the second assert here (the `.expect`). That will assert
// that this function itself needs to have its post-return called.
//
// The theory at least is that these two asserts ensure component
// model semantics are upheld where the host properly calls
// `post_return` on the right function despite the call being a
// separate step in the API.
assert!(
(*flags).needs_post_return(),
"post_return can only be called after a function has previously been called",
);
let post_return_arg = post_return_arg.expect("calling post_return on wrong function");
// This is a sanity-check assert which shouldn't ever trip.
assert!(!(*flags).may_enter());
// Unset the "needs post return" flag now that post-return is being
// processed. This will cause future invocations of this method to
// panic, even if the function call below traps.
(*flags).set_needs_post_return(false);
// If the function actually had a `post-return` configured in its
// canonical options that's executed here.
//
// Note that if this traps (returns an error) this function
// intentionally leaves the instance in a "poisoned" state where it
// can no longer be entered because `may_enter` is `false`.
if let Some((func, trampoline)) = post_return {
crate::Func::call_unchecked_raw(
&mut store,
func.anyfunc,
trampoline,
&post_return_arg as *const ValRaw as *mut ValRaw,
)?;
}
// And finally if everything completed successfully then the "may
// enter" flag is set to `true` again here which enables further use
// of the component.
(*flags).set_may_enter(true);
}
Ok(())
}
fn store_args<T>(
&self,
store: &mut StoreContextMut<'_, T>,
options: &Options,
params: &[Type],
args: &[Val],
dst: &mut MaybeUninit<[ValRaw; MAX_STACK_PARAMS]>,
) -> Result<()> {
let mut size = 0;
let mut alignment = 1;
for ty in params {
alignment = alignment.max(ty.size_and_alignment().alignment);
ty.next_field(&mut size);
}
let mut memory = MemoryMut::new(store.as_context_mut(), options);
let ptr = memory.realloc(0, 0, alignment, size)?;
let mut offset = ptr;
for (ty, arg) in params.iter().zip(args) {
arg.store(&mut memory, ty.next_field(&mut offset))?;
}
map_maybe_uninit!(dst[0]).write(ValRaw::i64(ptr as i64));
Ok(())
}
fn load_result<'a>(
mem: &Memory,
ty: &Type,
src: &mut std::slice::Iter<'_, ValRaw>,
) -> Result<Val> {
let SizeAndAlignment { size, alignment } = ty.size_and_alignment();
// FIXME: needs to read an i64 for memory64
let ptr = usize::try_from(src.next().unwrap().get_u32())?;
if ptr % usize::try_from(alignment)? != 0 {
bail!("return pointer not aligned");
}
let bytes = mem
.as_slice()
.get(ptr..)
.and_then(|b| b.get(..size))
.ok_or_else(|| anyhow::anyhow!("pointer out of bounds of memory"))?;
Val::load(ty, mem, bytes)
}
}

View File

@@ -213,7 +213,7 @@ impl<'a, T> MemoryMut<'a, T> {
/// Like `MemoryMut` but for a read-only version that's used during lifting.
pub struct Memory<'a> {
store: &'a StoreOpaque,
pub(crate) store: &'a StoreOpaque,
options: &'a Options,
}

View File

@@ -158,14 +158,14 @@ where
// count)
if Params::flatten_count() <= MAX_STACK_PARAMS {
if Return::flatten_count() <= MAX_STACK_RESULTS {
self.call_raw(
self.func.call_raw(
store,
&params,
Self::lower_stack_args,
Self::lift_stack_result,
)
} else {
self.call_raw(
self.func.call_raw(
store,
&params,
Self::lower_stack_args,
@@ -174,14 +174,14 @@ where
}
} else {
if Return::flatten_count() <= MAX_STACK_RESULTS {
self.call_raw(
self.func.call_raw(
store,
&params,
Self::lower_heap_args,
Self::lift_stack_result,
)
} else {
self.call_raw(
self.func.call_raw(
store,
&params,
Self::lower_heap_args,
@@ -280,228 +280,10 @@ where
Return::load(&memory, bytes)
}
/// Invokes the underlying wasm function, lowering arguments and lifting the
/// result.
///
/// The `lower` function and `lift` function provided here are what actually
/// do the lowering and lifting. The `LowerParams` and `LowerReturn` types
/// are what will be allocated on the stack for this function call. They
/// should be appropriately sized for the lowering/lifting operation
/// happening.
fn call_raw<T, LowerParams, LowerReturn>(
&self,
store: &mut StoreContextMut<'_, T>,
params: &Params,
lower: impl FnOnce(
&mut StoreContextMut<'_, T>,
&Options,
&Params,
&mut MaybeUninit<LowerParams>,
) -> Result<()>,
lift: impl FnOnce(&StoreOpaque, &Options, &LowerReturn) -> Result<Return>,
) -> Result<Return>
where
LowerParams: Copy,
LowerReturn: Copy,
{
let super::FuncData {
trampoline,
export,
options,
instance,
component_instance,
..
} = store.0[self.func.0];
let space = &mut MaybeUninit::<ParamsAndResults<LowerParams, LowerReturn>>::uninit();
// Double-check the size/alignemnt of `space`, just in case.
//
// Note that this alone is not enough to guarantee the validity of the
// `unsafe` block below, but it's definitely required. In any case LLVM
// should be able to trivially see through these assertions and remove
// them in release mode.
let val_size = mem::size_of::<ValRaw>();
let val_align = mem::align_of::<ValRaw>();
assert!(mem::size_of_val(space) % val_size == 0);
assert!(mem::size_of_val(map_maybe_uninit!(space.params)) % val_size == 0);
assert!(mem::size_of_val(map_maybe_uninit!(space.ret)) % val_size == 0);
assert!(mem::align_of_val(space) == val_align);
assert!(mem::align_of_val(map_maybe_uninit!(space.params)) == val_align);
assert!(mem::align_of_val(map_maybe_uninit!(space.ret)) == val_align);
let instance = store.0[instance.0].as_ref().unwrap().instance();
let flags = instance.flags(component_instance);
unsafe {
// Test the "may enter" flag which is a "lock" on this instance.
// This is immediately set to `false` afterwards and note that
// there's no on-cleanup setting this flag back to true. That's an
// intentional design aspect where if anything goes wrong internally
// from this point on the instance is considered "poisoned" and can
// never be entered again. The only time this flag is set to `true`
// again is after post-return logic has completed successfully.
if !(*flags).may_enter() {
bail!("cannot reenter component instance");
}
(*flags).set_may_enter(false);
debug_assert!((*flags).may_leave());
(*flags).set_may_leave(false);
let result = lower(store, &options, params, map_maybe_uninit!(space.params));
(*flags).set_may_leave(true);
result?;
// This is unsafe as we are providing the guarantee that all the
// inputs are valid. The various pointers passed in for the function
// are all valid since they're coming from our store, and the
// `params_and_results` should have the correct layout for the core
// wasm function we're calling. Note that this latter point relies
// on the correctness of this module and `ComponentType`
// implementations, hence `ComponentType` being an `unsafe` trait.
crate::Func::call_unchecked_raw(
store,
export.anyfunc,
trampoline,
space.as_mut_ptr().cast(),
)?;
// Note that `.assume_init_ref()` here is unsafe but we're relying
// on the correctness of the structure of `LowerReturn` and the
// type-checking performed to acquire the `TypedFunc` to make this
// safe. It should be the case that `LowerReturn` is the exact
// representation of the return value when interpreted as
// `[ValRaw]`, and additionally they should have the correct types
// for the function we just called (which filled in the return
// values).
let ret = map_maybe_uninit!(space.ret).assume_init_ref();
// Lift the result into the host while managing post-return state
// here as well.
//
// After a successful lift the return value of the function, which
// is currently required to be 0 or 1 values according to the
// canonical ABI, is saved within the `Store`'s `FuncData`. This'll
// later get used in post-return.
(*flags).set_needs_post_return(true);
let val = lift(store.0, &options, ret)?;
let ret_slice = cast_storage(ret);
let data = &mut store.0[self.func.0];
assert!(data.post_return_arg.is_none());
match ret_slice.len() {
0 => data.post_return_arg = Some(ValRaw::i32(0)),
1 => data.post_return_arg = Some(ret_slice[0]),
_ => unreachable!(),
}
return Ok(val);
}
unsafe fn cast_storage<T>(storage: &T) -> &[ValRaw] {
assert!(std::mem::size_of_val(storage) % std::mem::size_of::<ValRaw>() == 0);
assert!(std::mem::align_of_val(storage) == std::mem::align_of::<ValRaw>());
std::slice::from_raw_parts(
(storage as *const T).cast(),
mem::size_of_val(storage) / mem::size_of::<ValRaw>(),
)
}
/// See [`Func::post_return`]
pub fn post_return(&self, store: impl AsContextMut) -> Result<()> {
self.func.post_return(store)
}
/// Invokes the `post-return` canonical ABI option, if specified, after a
/// [`TypedFunc::call`] has finished.
///
/// For some more information on when to use this function see the
/// documentation for post-return in the [`TypedFunc::call`] method.
/// Otherwise though this function is a required method call after a
/// [`TypedFunc::call`] completes successfully. After the embedder has
/// finished processing the return value then this function must be invoked.
///
/// # Errors
///
/// This function will return an error in the case of a WebAssembly trap
/// happening during the execution of the `post-return` function, if
/// specified.
///
/// # Panics
///
/// This function will panic if it's not called under the correct
/// conditions. This can only be called after a previous invocation of
/// [`TypedFunc::call`] completes successfully, and this function can only
/// be called for the same [`TypedFunc`] that was `call`'d.
///
/// If this function is called when [`TypedFunc::call`] was not previously
/// called, then it will panic. If a different [`TypedFunc`] for the same
/// component instance was invoked then this function will also panic
/// because the `post-return` needs to happen for the other function.
pub fn post_return(&self, mut store: impl AsContextMut) -> Result<()> {
let mut store = store.as_context_mut();
let data = &mut store.0[self.func.0];
let instance = data.instance;
let post_return = data.post_return;
let component_instance = data.component_instance;
let post_return_arg = data.post_return_arg.take();
let instance = store.0[instance.0].as_ref().unwrap().instance();
let flags = instance.flags(component_instance);
unsafe {
// First assert that the instance is in a "needs post return" state.
// This will ensure that the previous action on the instance was a
// function call above. This flag is only set after a component
// function returns so this also can't be called (as expected)
// during a host import for example.
//
// Note, though, that this assert is not sufficient because it just
// means some function on this instance needs its post-return
// called. We need a precise post-return for a particular function
// which is the second assert here (the `.expect`). That will assert
// that this function itself needs to have its post-return called.
//
// The theory at least is that these two asserts ensure component
// model semantics are upheld where the host properly calls
// `post_return` on the right function despite the call being a
// separate step in the API.
assert!(
(*flags).needs_post_return(),
"post_return can only be called after a function has previously been called",
);
let post_return_arg = post_return_arg.expect("calling post_return on wrong function");
// This is a sanity-check assert which shouldn't ever trip.
assert!(!(*flags).may_enter());
// Unset the "needs post return" flag now that post-return is being
// processed. This will cause future invocations of this method to
// panic, even if the function call below traps.
(*flags).set_needs_post_return(false);
// If the function actually had a `post-return` configured in its
// canonical options that's executed here.
//
// Note that if this traps (returns an error) this function
// intentionally leaves the instance in a "poisoned" state where it
// can no longer be entered because `may_enter` is `false`.
if let Some((func, trampoline)) = post_return {
crate::Func::call_unchecked_raw(
&mut store,
func.anyfunc,
trampoline,
&post_return_arg as *const ValRaw as *mut ValRaw,
)?;
}
// And finally if everything completed successfully then the "may
// enter" flag is set to `true` again here which enables further use
// of the component.
(*flags).set_may_enter(true);
}
Ok(())
}
}
#[repr(C)]
union ParamsAndResults<Params: Copy, Return: Copy> {
params: Params,
ret: Return,
}
/// A trait representing a static list of parameters that can be passed to a
@@ -567,11 +349,9 @@ pub unsafe trait ComponentParams: ComponentType {
// though, that correctness bugs in this trait implementation are highly likely
// to lead to security bugs, which again leads to the `unsafe` in the trait.
//
// Also note that this trait specifically is not sealed because we'll
// eventually have a proc macro that generates implementations of this trait
// for external types in a `#[derive]`-like fashion.
//
// FIXME: need to write a #[derive(ComponentType)]
// Also note that this trait specifically is not sealed because we have a proc
// macro that generates implementations of this trait for external types in a
// `#[derive]`-like fashion.
pub unsafe trait ComponentType {
/// Representation of the "lowered" form of this component value.
///
@@ -690,7 +470,7 @@ pub unsafe trait Lift: Sized + ComponentType {
// another type, used for wrappers in Rust like `&T`, `Box<T>`, etc. Note that
// these wrappers only implement lowering because lifting native Rust types
// cannot be done.
macro_rules! forward_impls {
macro_rules! forward_type_impls {
($(($($generics:tt)*) $a:ty => $b:ty,)*) => ($(
unsafe impl <$($generics)*> ComponentType for $a {
type Lower = <$b as ComponentType>::Lower;
@@ -703,7 +483,20 @@ macro_rules! forward_impls {
<$b as ComponentType>::typecheck(ty, types)
}
}
)*)
}
forward_type_impls! {
(T: ComponentType + ?Sized) &'_ T => T,
(T: ComponentType + ?Sized) Box<T> => T,
(T: ComponentType + ?Sized) std::rc::Rc<T> => T,
(T: ComponentType + ?Sized) std::sync::Arc<T> => T,
() String => str,
(T: ComponentType) Vec<T> => [T],
}
macro_rules! forward_lowers {
($(($($generics:tt)*) $a:ty => $b:ty,)*) => ($(
unsafe impl <$($generics)*> Lower for $a {
fn lower<U>(
&self,
@@ -721,7 +514,7 @@ macro_rules! forward_impls {
)*)
}
forward_impls! {
forward_lowers! {
(T: Lower + ?Sized) &'_ T => T,
(T: Lower + ?Sized) Box<T> => T,
(T: Lower + ?Sized) std::rc::Rc<T> => T,
@@ -730,6 +523,50 @@ forward_impls! {
(T: Lower) Vec<T> => [T],
}
macro_rules! forward_string_lifts {
($($a:ty,)*) => ($(
unsafe impl Lift for $a {
fn lift(store: &StoreOpaque, options: &Options, src: &Self::Lower) -> Result<Self> {
Ok(<WasmStr as Lift>::lift(store, options, src)?.to_str_from_store(store)?.into())
}
fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result<Self> {
Ok(<WasmStr as Lift>::load(memory, bytes)?.to_str_from_store(&memory.store)?.into())
}
}
)*)
}
forward_string_lifts! {
Box<str>,
std::rc::Rc<str>,
std::sync::Arc<str>,
String,
}
macro_rules! forward_list_lifts {
($($a:ty,)*) => ($(
unsafe impl <T: Lift> Lift for $a {
fn lift(store: &StoreOpaque, options: &Options, src: &Self::Lower) -> Result<Self> {
let list = <WasmList::<T> as Lift>::lift(store, options, src)?;
(0..list.len).map(|index| list.get_from_store(store, index).unwrap()).collect()
}
fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result<Self> {
let list = <WasmList::<T> as Lift>::load(memory, bytes)?;
(0..list.len).map(|index| list.get_from_store(&memory.store, index).unwrap()).collect()
}
}
)*)
}
forward_list_lifts! {
Box<[T]>,
std::rc::Rc<[T]>,
std::sync::Arc<[T]>,
Vec<T>,
}
// Macro to help generate `ComponentType` implementations for primitive types
// such as integers, char, bool, etc.
macro_rules! integers {
@@ -1092,10 +929,10 @@ impl WasmStr {
// method that returns `[u16]` after validating to avoid the utf16-to-utf8
// transcode.
pub fn to_str<'a, T: 'a>(&self, store: impl Into<StoreContext<'a, T>>) -> Result<Cow<'a, str>> {
self._to_str(store.into().0)
self.to_str_from_store(store.into().0)
}
fn _to_str<'a>(&self, store: &'a StoreOpaque) -> Result<Cow<'a, str>> {
fn to_str_from_store<'a>(&self, store: &'a StoreOpaque) -> Result<Cow<'a, str>> {
match self.options.string_encoding() {
StringEncoding::Utf8 => self.decode_utf8(store),
StringEncoding::Utf16 => self.decode_utf16(store),
@@ -1289,10 +1126,10 @@ impl<T: Lift> WasmList<T> {
// should we even expose a random access iteration API? In theory all
// consumers should be validating through the iterator.
pub fn get(&self, store: impl AsContext, index: usize) -> Option<Result<T>> {
self._get(store.as_context().0, index)
self.get_from_store(store.as_context().0, index)
}
fn _get(&self, store: &StoreOpaque, index: usize) -> Option<Result<T>> {
fn get_from_store(&self, store: &StoreOpaque, index: usize) -> Option<Result<T>> {
if index >= self.len {
return None;
}
@@ -1316,7 +1153,7 @@ impl<T: Lift> WasmList<T> {
store: impl Into<StoreContext<'a, U>>,
) -> impl ExactSizeIterator<Item = Result<T>> + 'a {
let store = store.into().0;
(0..self.len).map(move |i| self._get(store, i).unwrap())
(0..self.len).map(move |i| self.get_from_store(store, i).unwrap())
}
}

View File

@@ -9,6 +9,8 @@ mod instance;
mod linker;
mod matching;
mod store;
pub mod types;
mod values;
pub use self::component::Component;
pub use self::func::{
ComponentParams, ComponentType, Func, IntoComponentFunc, Lift, Lower, TypedFunc, WasmList,
@@ -16,6 +18,8 @@ pub use self::func::{
};
pub use self::instance::{ExportInstance, Exports, Instance, InstancePre};
pub use self::linker::{Linker, LinkerInstance};
pub use self::types::Type;
pub use self::values::Val;
pub use wasmtime_component_macro::{flags, ComponentType, Lift, Lower};
// These items are expected to be used by an eventual

View File

@@ -0,0 +1,664 @@
//! This module defines the `Type` type, representing the dynamic form of a component interface type.
use crate::component::func;
use crate::component::values::{self, Val};
use anyhow::{anyhow, Result};
use std::fmt;
use std::mem;
use std::ops::Deref;
use std::sync::Arc;
use wasmtime_component_util::{DiscriminantSize, FlagsSize};
use wasmtime_environ::component::{
ComponentTypes, InterfaceType, TypeEnumIndex, TypeExpectedIndex, TypeFlagsIndex,
TypeInterfaceIndex, TypeRecordIndex, TypeTupleIndex, TypeUnionIndex, TypeVariantIndex,
};
#[derive(Clone)]
struct Handle<T> {
index: T,
types: Arc<ComponentTypes>,
}
impl<T: fmt::Debug> fmt::Debug for Handle<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Handle")
.field("index", &self.index)
.finish()
}
}
impl<T: PartialEq> PartialEq for Handle<T> {
fn eq(&self, other: &Self) -> bool {
// FIXME: This is an overly-restrictive definition of equality in that it doesn't consider types to be
// equal unless they refer to the same declaration in the same component. It's a good shortcut for the
// common case, but we should also do a recursive structural equality test if the shortcut test fails.
self.index == other.index && Arc::ptr_eq(&self.types, &other.types)
}
}
impl<T: Eq> Eq for Handle<T> {}
/// A `list` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct List(Handle<TypeInterfaceIndex>);
impl List {
/// Instantiate this type with the specified `values`.
pub fn new_val(&self, values: Box<[Val]>) -> Result<Val> {
Ok(Val::List(values::List::new(self, values)?))
}
/// Retreive the element type of this `list`.
pub fn ty(&self) -> Type {
Type::from(&self.0.types[self.0.index], &self.0.types)
}
}
/// A field declaration belonging to a `record`
pub struct Field<'a> {
/// The name of the field
pub name: &'a str,
/// The type of the field
pub ty: Type,
}
/// A `record` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Record(Handle<TypeRecordIndex>);
impl Record {
/// Instantiate this type with the specified `values`.
pub fn new_val<'a>(&self, values: impl IntoIterator<Item = (&'a str, Val)>) -> Result<Val> {
Ok(Val::Record(values::Record::new(self, values)?))
}
/// Retrieve the fields of this `record` in declaration order.
pub fn fields(&self) -> impl ExactSizeIterator<Item = Field> {
self.0.types[self.0.index].fields.iter().map(|field| Field {
name: &field.name,
ty: Type::from(&field.ty, &self.0.types),
})
}
}
/// A `tuple` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Tuple(Handle<TypeTupleIndex>);
impl Tuple {
/// Instantiate this type ith the specified `values`.
pub fn new_val(&self, values: Box<[Val]>) -> Result<Val> {
Ok(Val::Tuple(values::Tuple::new(self, values)?))
}
/// Retrieve the types of the fields of this `tuple` in declaration order.
pub fn types(&self) -> impl ExactSizeIterator<Item = Type> + '_ {
self.0.types[self.0.index]
.types
.iter()
.map(|ty| Type::from(ty, &self.0.types))
}
}
/// A case declaration belonging to a `variant`
pub struct Case<'a> {
/// The name of the case
pub name: &'a str,
/// The type of the case
pub ty: Type,
}
/// A `variant` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Variant(Handle<TypeVariantIndex>);
impl Variant {
/// Instantiate this type with the specified case `name` and `value`.
pub fn new_val(&self, name: &str, value: Val) -> Result<Val> {
Ok(Val::Variant(values::Variant::new(self, name, value)?))
}
/// Retrieve the cases of this `variant` in declaration order.
pub fn cases(&self) -> impl ExactSizeIterator<Item = Case> {
self.0.types[self.0.index].cases.iter().map(|case| Case {
name: &case.name,
ty: Type::from(&case.ty, &self.0.types),
})
}
}
/// An `enum` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Enum(Handle<TypeEnumIndex>);
impl Enum {
/// Instantiate this type with the specified case `name`.
pub fn new_val(&self, name: &str) -> Result<Val> {
Ok(Val::Enum(values::Enum::new(self, name)?))
}
/// Retrieve the names of the cases of this `enum` in declaration order.
pub fn names(&self) -> impl ExactSizeIterator<Item = &str> {
self.0.types[self.0.index]
.names
.iter()
.map(|name| name.deref())
}
}
/// A `union` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Union(Handle<TypeUnionIndex>);
impl Union {
/// Instantiate this type with the specified `discriminant` and `value`.
pub fn new_val(&self, discriminant: u32, value: Val) -> Result<Val> {
Ok(Val::Union(values::Union::new(self, discriminant, value)?))
}
/// Retrieve the types of the cases of this `union` in declaration order.
pub fn types(&self) -> impl ExactSizeIterator<Item = Type> + '_ {
self.0.types[self.0.index]
.types
.iter()
.map(|ty| Type::from(ty, &self.0.types))
}
}
/// An `option` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Option(Handle<TypeInterfaceIndex>);
impl Option {
/// Instantiate this type with the specified `value`.
pub fn new_val(&self, value: std::option::Option<Val>) -> Result<Val> {
Ok(Val::Option(values::Option::new(self, value)?))
}
/// Retrieve the type parameter for this `option`.
pub fn ty(&self) -> Type {
Type::from(&self.0.types[self.0.index], &self.0.types)
}
}
/// An `expected` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Expected(Handle<TypeExpectedIndex>);
impl Expected {
/// Instantiate this type with the specified `value`.
pub fn new_val(&self, value: Result<Val, Val>) -> Result<Val> {
Ok(Val::Expected(values::Expected::new(self, value)?))
}
/// Retrieve the `ok` type parameter for this `option`.
pub fn ok(&self) -> Type {
Type::from(&self.0.types[self.0.index].ok, &self.0.types)
}
/// Retrieve the `err` type parameter for this `option`.
pub fn err(&self) -> Type {
Type::from(&self.0.types[self.0.index].err, &self.0.types)
}
}
/// A `flags` interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Flags(Handle<TypeFlagsIndex>);
impl Flags {
/// Instantiate this type with the specified flag `names`.
pub fn new_val(&self, names: &[&str]) -> Result<Val> {
Ok(Val::Flags(values::Flags::new(self, names)?))
}
/// Retrieve the names of the flags of this `flags` type in declaration order.
pub fn names(&self) -> impl ExactSizeIterator<Item = &str> {
self.0.types[self.0.index]
.names
.iter()
.map(|name| name.deref())
}
}
/// Represents the size and alignment requirements of the heap-serialized form of a type
pub(crate) struct SizeAndAlignment {
pub(crate) size: usize,
pub(crate) alignment: u32,
}
/// Represents a component model interface type
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum Type {
/// Unit
Unit,
/// Boolean
Bool,
/// Signed 8-bit integer
S8,
/// Unsigned 8-bit integer
U8,
/// Signed 16-bit integer
S16,
/// Unsigned 16-bit integer
U16,
/// Signed 32-bit integer
S32,
/// Unsigned 32-bit integer
U32,
/// Signed 64-bit integer
S64,
/// Unsigned 64-bit integer
U64,
/// 64-bit floating point value
Float32,
/// 64-bit floating point value
Float64,
/// 32-bit character
Char,
/// Character string
String,
/// List of values
List(List),
/// Record
Record(Record),
/// Tuple
Tuple(Tuple),
/// Variant
Variant(Variant),
/// Enum
Enum(Enum),
/// Union
Union(Union),
/// Option
Option(Option),
/// Expected
Expected(Expected),
/// Bit flags
Flags(Flags),
}
impl Type {
/// Retrieve the inner [`List`] of a [`Type::List`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::List`].
pub fn unwrap_list(&self) -> &List {
if let Type::List(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a list", self.desc())
}
}
/// Retrieve the inner [`Record`] of a [`Type::Record`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Record`].
pub fn unwrap_record(&self) -> &Record {
if let Type::Record(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a record", self.desc())
}
}
/// Retrieve the inner [`Tuple`] of a [`Type::Tuple`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Tuple`].
pub fn unwrap_tuple(&self) -> &Tuple {
if let Type::Tuple(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a tuple", self.desc())
}
}
/// Retrieve the inner [`Variant`] of a [`Type::Variant`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Variant`].
pub fn unwrap_variant(&self) -> &Variant {
if let Type::Variant(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a variant", self.desc())
}
}
/// Retrieve the inner [`Enum`] of a [`Type::Enum`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Enum`].
pub fn unwrap_enum(&self) -> &Enum {
if let Type::Enum(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a enum", self.desc())
}
}
/// Retrieve the inner [`Union`] of a [`Type::Union`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Union`].
pub fn unwrap_union(&self) -> &Union {
if let Type::Union(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a union", self.desc())
}
}
/// Retrieve the inner [`Option`] of a [`Type::Option`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Option`].
pub fn unwrap_option(&self) -> &Option {
if let Type::Option(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a option", self.desc())
}
}
/// Retrieve the inner [`Expected`] of a [`Type::Expected`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Expected`].
pub fn unwrap_expected(&self) -> &Expected {
if let Type::Expected(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a expected", self.desc())
}
}
/// Retrieve the inner [`Flags`] of a [`Type::Flags`].
///
/// # Panics
///
/// This will panic if `self` is not a [`Type::Flags`].
pub fn unwrap_flags(&self) -> &Flags {
if let Type::Flags(handle) = self {
&handle
} else {
panic!("attempted to unwrap a {} as a flags", self.desc())
}
}
pub(crate) fn check(&self, value: &Val) -> Result<()> {
let other = &value.ty();
if self == other {
Ok(())
} else if mem::discriminant(self) != mem::discriminant(other) {
Err(anyhow!(
"type mismatch: expected {}, got {}",
self.desc(),
other.desc()
))
} else {
Err(anyhow!(
"type mismatch for {}, possibly due to mixing distinct composite types",
self.desc()
))
}
}
/// Convert the specified `InterfaceType` to a `Type`.
pub(crate) fn from(ty: &InterfaceType, types: &Arc<ComponentTypes>) -> Self {
match ty {
InterfaceType::Unit => Type::Unit,
InterfaceType::Bool => Type::Bool,
InterfaceType::S8 => Type::S8,
InterfaceType::U8 => Type::U8,
InterfaceType::S16 => Type::S16,
InterfaceType::U16 => Type::U16,
InterfaceType::S32 => Type::S32,
InterfaceType::U32 => Type::U32,
InterfaceType::S64 => Type::S64,
InterfaceType::U64 => Type::U64,
InterfaceType::Float32 => Type::Float32,
InterfaceType::Float64 => Type::Float64,
InterfaceType::Char => Type::Char,
InterfaceType::String => Type::String,
InterfaceType::List(index) => Type::List(List(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Record(index) => Type::Record(Record(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Tuple(index) => Type::Tuple(Tuple(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Variant(index) => Type::Variant(Variant(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Enum(index) => Type::Enum(Enum(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Union(index) => Type::Union(Union(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Option(index) => Type::Option(Option(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Expected(index) => Type::Expected(Expected(Handle {
index: *index,
types: types.clone(),
})),
InterfaceType::Flags(index) => Type::Flags(Flags(Handle {
index: *index,
types: types.clone(),
})),
}
}
/// Return the number of stack slots needed to store values of this type in lowered form.
pub(crate) fn flatten_count(&self) -> usize {
match self {
Type::Unit => 0,
Type::Bool
| Type::S8
| Type::U8
| Type::S16
| Type::U16
| Type::S32
| Type::U32
| Type::S64
| Type::U64
| Type::Float32
| Type::Float64
| Type::Char
| Type::Enum(_) => 1,
Type::String | Type::List(_) => 2,
Type::Record(handle) => handle.fields().map(|field| field.ty.flatten_count()).sum(),
Type::Tuple(handle) => handle.types().map(|ty| ty.flatten_count()).sum(),
Type::Variant(handle) => {
1 + handle
.cases()
.map(|case| case.ty.flatten_count())
.max()
.unwrap_or(0)
}
Type::Union(handle) => {
1 + handle
.types()
.map(|ty| ty.flatten_count())
.max()
.unwrap_or(0)
}
Type::Option(handle) => 1 + handle.ty().flatten_count(),
Type::Expected(handle) => {
1 + handle
.ok()
.flatten_count()
.max(handle.err().flatten_count())
}
Type::Flags(handle) => values::u32_count_for_flag_count(handle.names().len()),
}
}
fn desc(&self) -> &'static str {
match self {
Type::Unit => "unit",
Type::Bool => "bool",
Type::S8 => "s8",
Type::U8 => "u8",
Type::S16 => "s16",
Type::U16 => "u16",
Type::S32 => "s32",
Type::U32 => "u32",
Type::S64 => "s64",
Type::U64 => "u64",
Type::Float32 => "float32",
Type::Float64 => "float64",
Type::Char => "char",
Type::String => "string",
Type::List(_) => "list",
Type::Record(_) => "record",
Type::Tuple(_) => "tuple",
Type::Variant(_) => "variant",
Type::Enum(_) => "enum",
Type::Union(_) => "union",
Type::Option(_) => "option",
Type::Expected(_) => "expected",
Type::Flags(_) => "flags",
}
}
/// Calculate the size and alignment requirements for the specified type.
pub(crate) fn size_and_alignment(&self) -> SizeAndAlignment {
match self {
Type::Unit => SizeAndAlignment {
size: 0,
alignment: 1,
},
Type::Bool | Type::S8 | Type::U8 => SizeAndAlignment {
size: 1,
alignment: 1,
},
Type::S16 | Type::U16 => SizeAndAlignment {
size: 2,
alignment: 2,
},
Type::S32 | Type::U32 | Type::Char | Type::Float32 => SizeAndAlignment {
size: 4,
alignment: 4,
},
Type::S64 | Type::U64 | Type::Float64 => SizeAndAlignment {
size: 8,
alignment: 8,
},
Type::String | Type::List(_) => SizeAndAlignment {
size: 8,
alignment: 4,
},
Type::Record(handle) => {
record_size_and_alignment(handle.fields().map(|field| field.ty))
}
Type::Tuple(handle) => record_size_and_alignment(handle.types()),
Type::Variant(handle) => variant_size_and_alignment(handle.cases().map(|case| case.ty)),
Type::Enum(handle) => variant_size_and_alignment(handle.names().map(|_| Type::Unit)),
Type::Union(handle) => variant_size_and_alignment(handle.types()),
Type::Option(handle) => {
variant_size_and_alignment([Type::Unit, handle.ty()].into_iter())
}
Type::Expected(handle) => {
variant_size_and_alignment([handle.ok(), handle.err()].into_iter())
}
Type::Flags(handle) => match FlagsSize::from_count(handle.names().len()) {
FlagsSize::Size1 => SizeAndAlignment {
size: 1,
alignment: 1,
},
FlagsSize::Size2 => SizeAndAlignment {
size: 2,
alignment: 2,
},
FlagsSize::Size4Plus(n) => SizeAndAlignment {
size: n * 4,
alignment: 4,
},
},
}
}
/// Calculate the aligned offset of a field of this type, updating `offset` to point to just after that field.
pub(crate) fn next_field(&self, offset: &mut usize) -> usize {
let SizeAndAlignment { size, alignment } = self.size_and_alignment();
*offset = func::align_to(*offset, alignment);
let result = *offset;
*offset += size;
result
}
}
fn record_size_and_alignment(types: impl Iterator<Item = Type>) -> SizeAndAlignment {
let mut offset = 0;
let mut align = 1;
for ty in types {
let SizeAndAlignment { size, alignment } = ty.size_and_alignment();
offset = func::align_to(offset, alignment) + size;
align = align.max(alignment);
}
SizeAndAlignment {
size: func::align_to(offset, align),
alignment: align,
}
}
fn variant_size_and_alignment(types: impl ExactSizeIterator<Item = Type>) -> SizeAndAlignment {
let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap();
let mut alignment = u32::from(discriminant_size);
let mut size = 0;
for ty in types {
let size_and_alignment = ty.size_and_alignment();
alignment = alignment.max(size_and_alignment.alignment);
size = size.max(size_and_alignment.size);
}
SizeAndAlignment {
size: func::align_to(usize::from(discriminant_size), alignment) + size,
alignment,
}
}

View File

@@ -0,0 +1,908 @@
use crate::component::func::{self, Lift, Lower, Memory, MemoryMut, Options};
use crate::component::types::{self, SizeAndAlignment, Type};
use crate::store::StoreOpaque;
use crate::{AsContextMut, StoreContextMut, ValRaw};
use anyhow::{anyhow, bail, Context, Error, Result};
use std::collections::HashMap;
use std::iter;
use std::mem::MaybeUninit;
use std::ops::Deref;
use wasmtime_component_util::{DiscriminantSize, FlagsSize};
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct List {
ty: types::List,
values: Box<[Val]>,
}
impl List {
/// Instantiate the specified type with the specified `values`.
pub fn new(ty: &types::List, values: Box<[Val]>) -> Result<Self> {
let element_type = ty.ty();
for (index, value) in values.iter().enumerate() {
element_type
.check(value)
.with_context(|| format!("type mismatch for element {index} of list"))?;
}
Ok(Self {
ty: ty.clone(),
values,
})
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Record {
ty: types::Record,
values: Box<[Val]>,
}
impl Record {
/// Instantiate the specified type with the specified `values`.
pub fn new<'a>(
ty: &types::Record,
values: impl IntoIterator<Item = (&'a str, Val)>,
) -> Result<Self> {
let mut fields = ty.fields();
let expected_len = fields.len();
let mut iter = values.into_iter();
let mut values = Vec::with_capacity(expected_len);
loop {
match (fields.next(), iter.next()) {
(Some(field), Some((name, value))) => {
if name == field.name {
field
.ty
.check(&value)
.with_context(|| format!("type mismatch for field {name} of record"))?;
values.push(value);
} else {
bail!("field name mismatch: expected {}; got {name}", field.name)
}
}
(None, Some((_, value))) => values.push(value),
_ => break,
}
}
if values.len() != expected_len {
bail!("expected {} value(s); got {}", expected_len, values.len());
}
Ok(Self {
ty: ty.clone(),
values: values.into(),
})
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Tuple {
ty: types::Tuple,
values: Box<[Val]>,
}
impl Tuple {
/// Instantiate the specified type ith the specified `values`.
pub fn new(ty: &types::Tuple, values: Box<[Val]>) -> Result<Self> {
if values.len() != ty.types().len() {
bail!(
"expected {} value(s); got {}",
ty.types().len(),
values.len()
);
}
for (index, (value, ty)) in values.iter().zip(ty.types()).enumerate() {
ty.check(value)
.with_context(|| format!("type mismatch for field {index} of tuple"))?;
}
Ok(Self {
ty: ty.clone(),
values,
})
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Variant {
ty: types::Variant,
discriminant: u32,
value: Box<Val>,
}
impl Variant {
/// Instantiate the specified type with the specified case `name` and `value`.
pub fn new(ty: &types::Variant, name: &str, value: Val) -> Result<Self> {
let (discriminant, case_type) = ty
.cases()
.enumerate()
.find_map(|(index, case)| {
if case.name == name {
Some((index, case.ty))
} else {
None
}
})
.ok_or_else(|| anyhow!("unknown variant case: {name}"))?;
case_type
.check(&value)
.with_context(|| format!("type mismatch for case {name} of variant"))?;
Ok(Self {
ty: ty.clone(),
discriminant: u32::try_from(discriminant)?,
value: Box::new(value),
})
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Enum {
ty: types::Enum,
discriminant: u32,
}
impl Enum {
/// Instantiate the specified type with the specified case `name`.
pub fn new(ty: &types::Enum, name: &str) -> Result<Self> {
let discriminant = u32::try_from(
ty.names()
.position(|n| n == name)
.ok_or_else(|| anyhow!("unknown enum case: {name}"))?,
)?;
Ok(Self {
ty: ty.clone(),
discriminant,
})
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Union {
ty: types::Union,
discriminant: u32,
value: Box<Val>,
}
impl Union {
/// Instantiate the specified type with the specified `discriminant` and `value`.
pub fn new(ty: &types::Union, discriminant: u32, value: Val) -> Result<Self> {
if let Some(case_ty) = ty.types().nth(usize::try_from(discriminant)?) {
case_ty
.check(&value)
.with_context(|| format!("type mismatch for case {discriminant} of union"))?;
Ok(Self {
ty: ty.clone(),
discriminant,
value: Box::new(value),
})
} else {
Err(anyhow!(
"discriminant {discriminant} out of range: [0,{})",
ty.types().len()
))
}
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Option {
ty: types::Option,
discriminant: u32,
value: Box<Val>,
}
impl Option {
/// Instantiate the specified type with the specified `value`.
pub fn new(ty: &types::Option, value: std::option::Option<Val>) -> Result<Self> {
let value = value
.map(|value| {
ty.ty().check(&value).context("type mismatch for option")?;
Ok::<_, Error>(value)
})
.transpose()?;
Ok(Self {
ty: ty.clone(),
discriminant: if value.is_none() { 0 } else { 1 },
value: Box::new(value.unwrap_or(Val::Unit)),
})
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Expected {
ty: types::Expected,
discriminant: u32,
value: Box<Val>,
}
impl Expected {
/// Instantiate the specified type with the specified `value`.
pub fn new(ty: &types::Expected, value: Result<Val, Val>) -> Result<Self> {
Ok(Self {
ty: ty.clone(),
discriminant: if value.is_ok() { 0 } else { 1 },
value: Box::new(match value {
Ok(value) => {
ty.ok()
.check(&value)
.context("type mismatch for ok case of expected")?;
value
}
Err(value) => {
ty.err()
.check(&value)
.context("type mismatch for err case of expected")?;
value
}
}),
})
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Flags {
ty: types::Flags,
count: u32,
value: Box<[u32]>,
}
impl Flags {
/// Instantiate the specified type with the specified flag `names`.
pub fn new(ty: &types::Flags, names: &[&str]) -> Result<Self> {
let map = ty
.names()
.enumerate()
.map(|(index, name)| (name, index))
.collect::<HashMap<_, _>>();
let mut values = vec![0_u32; u32_count_for_flag_count(ty.names().len())];
for name in names {
let index = map
.get(name)
.ok_or_else(|| anyhow!("unknown flag: {name}"))?;
values[index / 32] |= 1 << (index % 32);
}
Ok(Self {
ty: ty.clone(),
count: u32::try_from(map.len())?,
value: values.into(),
})
}
}
/// Represents possible runtime values which a component function can either consume or produce
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Val {
/// Unit
Unit,
/// Boolean
Bool(bool),
/// Signed 8-bit integer
S8(i8),
/// Unsigned 8-bit integer
U8(u8),
/// Signed 16-bit integer
S16(i16),
/// Unsigned 16-bit integer
U16(u16),
/// Signed 32-bit integer
S32(i32),
/// Unsigned 32-bit integer
U32(u32),
/// Signed 64-bit integer
S64(i64),
/// Unsigned 64-bit integer
U64(u64),
/// 32-bit floating point value
Float32(u32),
/// 64-bit floating point value
Float64(u64),
/// 32-bit character
Char(char),
/// Character string
String(Box<str>),
/// List of values
List(List),
/// Record
Record(Record),
/// Tuple
Tuple(Tuple),
/// Variant
Variant(Variant),
/// Enum
Enum(Enum),
/// Union
Union(Union),
/// Option
Option(Option),
/// Expected
Expected(Expected),
/// Bit flags
Flags(Flags),
}
impl Val {
/// Retrieve the [`Type`] of this value.
pub fn ty(&self) -> Type {
match self {
Val::Unit => Type::Unit,
Val::Bool(_) => Type::Bool,
Val::S8(_) => Type::S8,
Val::U8(_) => Type::U8,
Val::S16(_) => Type::S16,
Val::U16(_) => Type::U16,
Val::S32(_) => Type::S32,
Val::U32(_) => Type::U32,
Val::S64(_) => Type::S64,
Val::U64(_) => Type::U64,
Val::Float32(_) => Type::Float32,
Val::Float64(_) => Type::Float64,
Val::Char(_) => Type::Char,
Val::String(_) => Type::String,
Val::List(List { ty, .. }) => Type::List(ty.clone()),
Val::Record(Record { ty, .. }) => Type::Record(ty.clone()),
Val::Tuple(Tuple { ty, .. }) => Type::Tuple(ty.clone()),
Val::Variant(Variant { ty, .. }) => Type::Variant(ty.clone()),
Val::Enum(Enum { ty, .. }) => Type::Enum(ty.clone()),
Val::Union(Union { ty, .. }) => Type::Union(ty.clone()),
Val::Option(Option { ty, .. }) => Type::Option(ty.clone()),
Val::Expected(Expected { ty, .. }) => Type::Expected(ty.clone()),
Val::Flags(Flags { ty, .. }) => Type::Flags(ty.clone()),
}
}
/// Deserialize a value of this type from core Wasm stack values.
pub(crate) fn lift<'a>(
ty: &Type,
store: &StoreOpaque,
options: &Options,
src: &mut std::slice::Iter<'_, ValRaw>,
) -> Result<Val> {
Ok(match ty {
Type::Unit => Val::Unit,
Type::Bool => Val::Bool(bool::lift(store, options, next(src))?),
Type::S8 => Val::S8(i8::lift(store, options, next(src))?),
Type::U8 => Val::U8(u8::lift(store, options, next(src))?),
Type::S16 => Val::S16(i16::lift(store, options, next(src))?),
Type::U16 => Val::U16(u16::lift(store, options, next(src))?),
Type::S32 => Val::S32(i32::lift(store, options, next(src))?),
Type::U32 => Val::U32(u32::lift(store, options, next(src))?),
Type::S64 => Val::S64(i64::lift(store, options, next(src))?),
Type::U64 => Val::U64(u64::lift(store, options, next(src))?),
Type::Float32 => Val::Float32(u32::lift(store, options, next(src))?),
Type::Float64 => Val::Float64(u64::lift(store, options, next(src))?),
Type::Char => Val::Char(char::lift(store, options, next(src))?),
Type::String => {
Val::String(Box::<str>::lift(store, options, &[*next(src), *next(src)])?)
}
Type::List(handle) => {
// FIXME: needs memory64 treatment
let ptr = u32::lift(store, options, next(src))? as usize;
let len = u32::lift(store, options, next(src))? as usize;
load_list(handle, &Memory::new(store, options), ptr, len)?
}
Type::Record(handle) => Val::Record(Record {
ty: handle.clone(),
values: handle
.fields()
.map(|field| Self::lift(&field.ty, store, options, src))
.collect::<Result<_>>()?,
}),
Type::Tuple(handle) => Val::Tuple(Tuple {
ty: handle.clone(),
values: handle
.types()
.map(|ty| Self::lift(&ty, store, options, src))
.collect::<Result<_>>()?,
}),
Type::Variant(handle) => {
let (discriminant, value) = lift_variant(
ty.flatten_count(),
handle.cases().map(|case| case.ty),
store,
options,
src,
)?;
Val::Variant(Variant {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Enum(handle) => {
let (discriminant, _) = lift_variant(
ty.flatten_count(),
handle.names().map(|_| Type::Unit),
store,
options,
src,
)?;
Val::Enum(Enum {
ty: handle.clone(),
discriminant,
})
}
Type::Union(handle) => {
let (discriminant, value) =
lift_variant(ty.flatten_count(), handle.types(), store, options, src)?;
Val::Union(Union {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Option(handle) => {
let (discriminant, value) = lift_variant(
ty.flatten_count(),
[Type::Unit, handle.ty()].into_iter(),
store,
options,
src,
)?;
Val::Option(Option {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Expected(handle) => {
let (discriminant, value) = lift_variant(
ty.flatten_count(),
[handle.ok(), handle.err()].into_iter(),
store,
options,
src,
)?;
Val::Expected(Expected {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Flags(handle) => {
let count = u32::try_from(handle.names().len()).unwrap();
assert!(count <= 32);
let value = iter::once(u32::lift(store, options, next(src))?).collect();
Val::Flags(Flags {
ty: handle.clone(),
count,
value,
})
}
})
}
/// Deserialize a value of this type from the heap.
pub(crate) fn load(ty: &Type, mem: &Memory, bytes: &[u8]) -> Result<Val> {
Ok(match ty {
Type::Unit => Val::Unit,
Type::Bool => Val::Bool(bool::load(mem, bytes)?),
Type::S8 => Val::S8(i8::load(mem, bytes)?),
Type::U8 => Val::U8(u8::load(mem, bytes)?),
Type::S16 => Val::S16(i16::load(mem, bytes)?),
Type::U16 => Val::U16(u16::load(mem, bytes)?),
Type::S32 => Val::S32(i32::load(mem, bytes)?),
Type::U32 => Val::U32(u32::load(mem, bytes)?),
Type::S64 => Val::S64(i64::load(mem, bytes)?),
Type::U64 => Val::U64(u64::load(mem, bytes)?),
Type::Float32 => Val::Float32(u32::load(mem, bytes)?),
Type::Float64 => Val::Float64(u64::load(mem, bytes)?),
Type::Char => Val::Char(char::load(mem, bytes)?),
Type::String => Val::String(Box::<str>::load(mem, bytes)?),
Type::List(handle) => {
// FIXME: needs memory64 treatment
let ptr = u32::from_le_bytes(bytes[..4].try_into().unwrap()) as usize;
let len = u32::from_le_bytes(bytes[4..].try_into().unwrap()) as usize;
load_list(handle, mem, ptr, len)?
}
Type::Record(handle) => Val::Record(Record {
ty: handle.clone(),
values: load_record(handle.fields().map(|field| field.ty), mem, bytes)?,
}),
Type::Tuple(handle) => Val::Tuple(Tuple {
ty: handle.clone(),
values: load_record(handle.types(), mem, bytes)?,
}),
Type::Variant(handle) => {
let (discriminant, value) =
load_variant(ty, handle.cases().map(|case| case.ty), mem, bytes)?;
Val::Variant(Variant {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Enum(handle) => {
let (discriminant, _) =
load_variant(ty, handle.names().map(|_| Type::Unit), mem, bytes)?;
Val::Enum(Enum {
ty: handle.clone(),
discriminant,
})
}
Type::Union(handle) => {
let (discriminant, value) = load_variant(ty, handle.types(), mem, bytes)?;
Val::Union(Union {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Option(handle) => {
let (discriminant, value) =
load_variant(ty, [Type::Unit, handle.ty()].into_iter(), mem, bytes)?;
Val::Option(Option {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Expected(handle) => {
let (discriminant, value) =
load_variant(ty, [handle.ok(), handle.err()].into_iter(), mem, bytes)?;
Val::Expected(Expected {
ty: handle.clone(),
discriminant,
value: Box::new(value),
})
}
Type::Flags(handle) => Val::Flags(Flags {
ty: handle.clone(),
count: u32::try_from(handle.names().len())?,
value: match FlagsSize::from_count(handle.names().len()) {
FlagsSize::Size1 => iter::once(u8::load(mem, bytes)? as u32).collect(),
FlagsSize::Size2 => iter::once(u16::load(mem, bytes)? as u32).collect(),
FlagsSize::Size4Plus(n) => (0..n)
.map(|index| u32::load(mem, &bytes[index * 4..][..4]))
.collect::<Result<_>>()?,
},
}),
})
}
/// Serialize this value as core Wasm stack values.
pub(crate) fn lower<T>(
&self,
store: &mut StoreContextMut<T>,
options: &Options,
dst: &mut std::slice::IterMut<'_, MaybeUninit<ValRaw>>,
) -> Result<()> {
match self {
Val::Unit => (),
Val::Bool(value) => value.lower(store, options, next_mut(dst))?,
Val::S8(value) => value.lower(store, options, next_mut(dst))?,
Val::U8(value) => value.lower(store, options, next_mut(dst))?,
Val::S16(value) => value.lower(store, options, next_mut(dst))?,
Val::U16(value) => value.lower(store, options, next_mut(dst))?,
Val::S32(value) => value.lower(store, options, next_mut(dst))?,
Val::U32(value) => value.lower(store, options, next_mut(dst))?,
Val::S64(value) => value.lower(store, options, next_mut(dst))?,
Val::U64(value) => value.lower(store, options, next_mut(dst))?,
Val::Float32(value) => value.lower(store, options, next_mut(dst))?,
Val::Float64(value) => value.lower(store, options, next_mut(dst))?,
Val::Char(value) => value.lower(store, options, next_mut(dst))?,
Val::String(value) => {
let my_dst = &mut MaybeUninit::<[ValRaw; 2]>::uninit();
value.lower(store, options, my_dst)?;
let my_dst = unsafe { my_dst.assume_init() };
next_mut(dst).write(my_dst[0]);
next_mut(dst).write(my_dst[1]);
}
Val::List(List { values, ty }) => {
let (ptr, len) = lower_list(
&ty.ty(),
&mut MemoryMut::new(store.as_context_mut(), options),
values,
)?;
next_mut(dst).write(ValRaw::i64(ptr as i64));
next_mut(dst).write(ValRaw::i64(len as i64));
}
Val::Record(Record { values, .. }) | Val::Tuple(Tuple { values, .. }) => {
for value in values.deref() {
value.lower(store, options, dst)?;
}
}
Val::Variant(Variant {
discriminant,
value,
..
})
| Val::Union(Union {
discriminant,
value,
..
})
| Val::Option(Option {
discriminant,
value,
..
})
| Val::Expected(Expected {
discriminant,
value,
..
}) => {
next_mut(dst).write(ValRaw::u32(*discriminant));
value.lower(store, options, dst)?;
for _ in (1 + value.ty().flatten_count())..self.ty().flatten_count() {
next_mut(dst).write(ValRaw::u32(0));
}
}
Val::Enum(Enum { discriminant, .. }) => {
next_mut(dst).write(ValRaw::u32(*discriminant));
}
Val::Flags(Flags { value, .. }) => {
for value in value.deref() {
next_mut(dst).write(ValRaw::u32(*value));
}
}
}
Ok(())
}
/// Serialize this value to the heap at the specified memory location.
pub(crate) fn store<T>(&self, mem: &mut MemoryMut<'_, T>, offset: usize) -> Result<()> {
debug_assert!(offset % usize::try_from(self.ty().size_and_alignment().alignment)? == 0);
match self {
Val::Unit => (),
Val::Bool(value) => value.store(mem, offset)?,
Val::S8(value) => value.store(mem, offset)?,
Val::U8(value) => value.store(mem, offset)?,
Val::S16(value) => value.store(mem, offset)?,
Val::U16(value) => value.store(mem, offset)?,
Val::S32(value) => value.store(mem, offset)?,
Val::U32(value) => value.store(mem, offset)?,
Val::S64(value) => value.store(mem, offset)?,
Val::U64(value) => value.store(mem, offset)?,
Val::Float32(value) => value.store(mem, offset)?,
Val::Float64(value) => value.store(mem, offset)?,
Val::Char(value) => value.store(mem, offset)?,
Val::String(value) => value.store(mem, offset)?,
Val::List(List { values, ty }) => {
let (ptr, len) = lower_list(&ty.ty(), mem, values)?;
// FIXME: needs memory64 handling
*mem.get(offset + 0) = (ptr as i32).to_le_bytes();
*mem.get(offset + 4) = (len as i32).to_le_bytes();
}
Val::Record(Record { values, .. }) | Val::Tuple(Tuple { values, .. }) => {
let mut offset = offset;
for value in values.deref() {
value.store(mem, value.ty().next_field(&mut offset))?;
}
}
Val::Variant(Variant {
discriminant,
value,
ty,
}) => self.store_variant(*discriminant, value, ty.cases().len(), mem, offset)?,
Val::Enum(Enum { discriminant, ty }) => {
self.store_variant(*discriminant, &Val::Unit, ty.names().len(), mem, offset)?
}
Val::Union(Union {
discriminant,
value,
ty,
}) => self.store_variant(*discriminant, value, ty.types().len(), mem, offset)?,
Val::Option(Option {
discriminant,
value,
..
})
| Val::Expected(Expected {
discriminant,
value,
..
}) => self.store_variant(*discriminant, value, 2, mem, offset)?,
Val::Flags(Flags { count, value, .. }) => {
match FlagsSize::from_count(*count as usize) {
FlagsSize::Size1 => u8::try_from(value[0]).unwrap().store(mem, offset)?,
FlagsSize::Size2 => u16::try_from(value[0]).unwrap().store(mem, offset)?,
FlagsSize::Size4Plus(_) => {
let mut offset = offset;
for value in value.deref() {
value.store(mem, offset)?;
offset += 4;
}
}
}
}
}
Ok(())
}
fn store_variant<T>(
&self,
discriminant: u32,
value: &Val,
case_count: usize,
mem: &mut MemoryMut<'_, T>,
offset: usize,
) -> Result<()> {
let discriminant_size = DiscriminantSize::from_count(case_count).unwrap();
match discriminant_size {
DiscriminantSize::Size1 => u8::try_from(discriminant).unwrap().store(mem, offset)?,
DiscriminantSize::Size2 => u16::try_from(discriminant).unwrap().store(mem, offset)?,
DiscriminantSize::Size4 => (discriminant).store(mem, offset)?,
}
value.store(
mem,
offset
+ func::align_to(
discriminant_size.into(),
self.ty().size_and_alignment().alignment,
),
)
}
}
fn load_list(handle: &types::List, mem: &Memory, ptr: usize, len: usize) -> Result<Val> {
let element_type = handle.ty();
let SizeAndAlignment {
size: element_size,
alignment: element_alignment,
} = element_type.size_and_alignment();
match len
.checked_mul(element_size)
.and_then(|len| ptr.checked_add(len))
{
Some(n) if n <= mem.as_slice().len() => {}
_ => bail!("list pointer/length out of bounds of memory"),
}
if ptr % usize::try_from(element_alignment)? != 0 {
bail!("list pointer is not aligned")
}
Ok(Val::List(List {
ty: handle.clone(),
values: (0..len)
.map(|index| {
Val::load(
&element_type,
mem,
&mem.as_slice()[ptr + (index * element_size)..][..element_size],
)
})
.collect::<Result<_>>()?,
}))
}
fn load_record(
types: impl Iterator<Item = Type>,
mem: &Memory,
bytes: &[u8],
) -> Result<Box<[Val]>> {
let mut offset = 0;
types
.map(|ty| {
Val::load(
&ty,
mem,
&bytes[ty.next_field(&mut offset)..][..ty.size_and_alignment().size],
)
})
.collect()
}
fn load_variant(
ty: &Type,
mut types: impl ExactSizeIterator<Item = Type>,
mem: &Memory,
bytes: &[u8],
) -> Result<(u32, Val)> {
let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap();
let discriminant = match discriminant_size {
DiscriminantSize::Size1 => u8::load(mem, &bytes[..1])? as u32,
DiscriminantSize::Size2 => u16::load(mem, &bytes[..2])? as u32,
DiscriminantSize::Size4 => u32::load(mem, &bytes[..4])?,
};
let case_ty = types.nth(discriminant as usize).ok_or_else(|| {
anyhow!(
"discriminant {} out of range [0..{})",
discriminant,
types.len()
)
})?;
let value = Val::load(
&case_ty,
mem,
&bytes[func::align_to(
usize::from(discriminant_size),
ty.size_and_alignment().alignment,
)..][..case_ty.size_and_alignment().size],
)?;
Ok((discriminant, value))
}
fn lift_variant<'a>(
flatten_count: usize,
mut types: impl ExactSizeIterator<Item = Type>,
store: &StoreOpaque,
options: &Options,
src: &mut std::slice::Iter<'_, ValRaw>,
) -> Result<(u32, Val)> {
let len = types.len();
let discriminant = next(src).get_u32();
let ty = types
.nth(discriminant as usize)
.ok_or_else(|| anyhow!("discriminant {} out of range [0..{})", discriminant, len))?;
let value = Val::lift(&ty, store, options, src)?;
for _ in (1 + ty.flatten_count())..flatten_count {
next(src);
}
Ok((discriminant, value))
}
/// Lower a list with the specified element type and values.
fn lower_list<T>(
element_type: &Type,
mem: &mut MemoryMut<'_, T>,
items: &[Val],
) -> Result<(usize, usize)> {
let SizeAndAlignment {
size: element_size,
alignment: element_alignment,
} = element_type.size_and_alignment();
let size = items
.len()
.checked_mul(element_size)
.ok_or_else(|| anyhow::anyhow!("size overflow copying a list"))?;
let ptr = mem.realloc(0, 0, element_alignment, size)?;
let mut element_ptr = ptr;
for item in items {
item.store(mem, element_ptr)?;
element_ptr += element_size;
}
Ok((ptr, items.len()))
}
/// Calculate the size of a u32 array needed to represent the specified number of bit flags.
///
/// Note that this will always return at least 1, even if the `count` parameter is zero.
pub(crate) fn u32_count_for_flag_count(count: usize) -> usize {
match FlagsSize::from_count(count) {
FlagsSize::Size1 | FlagsSize::Size2 => 1,
FlagsSize::Size4Plus(n) => n,
}
}
fn next<'a>(src: &mut std::slice::Iter<'a, ValRaw>) -> &'a ValRaw {
src.next().unwrap()
}
fn next_mut<'a>(
dst: &mut std::slice::IterMut<'a, MaybeUninit<ValRaw>>,
) -> &'a mut MaybeUninit<ValRaw> {
dst.next().unwrap()
}